commit
7f89fbba2d
|
@ -1381,4 +1381,17 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
|
||||||
assertNotNull(ds.getFeatures());
|
assertNotNull(ds.getFeatures());
|
||||||
assertNull(ds.getLabels());
|
assertNull(ds.getLabels());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCollectMetaData(){
|
||||||
|
RecordReaderDataSetIterator trainIter = new RecordReaderDataSetIterator.Builder(new CollectionRecordReader(Collections.<List<Writable>>emptyList()), 1)
|
||||||
|
.collectMetaData(true)
|
||||||
|
.build();
|
||||||
|
assertTrue(trainIter.isCollectMetaData());
|
||||||
|
trainIter.setCollectMetaData(false);
|
||||||
|
assertFalse(trainIter.isCollectMetaData());
|
||||||
|
trainIter.setCollectMetaData(true);
|
||||||
|
assertTrue(trainIter.isCollectMetaData());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,7 +33,6 @@ import org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
|
||||||
import org.deeplearning4j.eval.meta.Prediction;
|
|
||||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
import org.deeplearning4j.nn.conf.*;
|
import org.deeplearning4j.nn.conf.*;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
|
@ -52,19 +51,13 @@ import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
|
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
|
||||||
import org.nd4j.linalg.learning.config.Sgd;
|
import org.nd4j.linalg.learning.config.Sgd;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
import org.nd4j.linalg.util.FeatureUtil;
|
|
||||||
import org.nd4j.resources.Resources;
|
import org.nd4j.resources.Resources;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
|
|
||||||
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Created by agibsonccc on 12/22/14.
|
* Created by agibsonccc on 12/22/14.
|
||||||
|
@ -165,7 +158,7 @@ public class EvalTest extends BaseDL4JTest {
|
||||||
assertEquals(evalExpected.getConfusionMatrix(), evalActual.getConfusionMatrix());
|
assertEquals(evalExpected.getConfusionMatrix(), evalActual.getConfusionMatrix());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 300000)
|
@Test
|
||||||
public void testEvaluationWithMetaData() throws Exception {
|
public void testEvaluationWithMetaData() throws Exception {
|
||||||
|
|
||||||
RecordReader csv = new CSVRecordReader();
|
RecordReader csv = new CSVRecordReader();
|
||||||
|
@ -256,6 +249,30 @@ public class EvalTest extends BaseDL4JTest {
|
||||||
assertEquals(actualCounts[i], actualClassI.size());
|
assertEquals(actualCounts[i], actualClassI.size());
|
||||||
assertEquals(predictedCounts[i], predictedClassI.size());
|
assertEquals(predictedCounts[i], predictedClassI.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//Finally: test doEvaluation methods
|
||||||
|
rrdsi.reset();
|
||||||
|
org.nd4j.evaluation.classification.Evaluation e2 = new org.nd4j.evaluation.classification.Evaluation();
|
||||||
|
net.doEvaluation(rrdsi, e2);
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
List<org.nd4j.evaluation.meta.Prediction> actualClassI = e2.getPredictionsByActualClass(i);
|
||||||
|
List<org.nd4j.evaluation.meta.Prediction> predictedClassI = e2.getPredictionByPredictedClass(i);
|
||||||
|
assertEquals(actualCounts[i], actualClassI.size());
|
||||||
|
assertEquals(predictedCounts[i], predictedClassI.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
ComputationGraph cg = net.toComputationGraph();
|
||||||
|
rrdsi.reset();
|
||||||
|
e2 = new org.nd4j.evaluation.classification.Evaluation();
|
||||||
|
cg.doEvaluation(rrdsi, e2);
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
List<org.nd4j.evaluation.meta.Prediction> actualClassI = e2.getPredictionsByActualClass(i);
|
||||||
|
List<org.nd4j.evaluation.meta.Prediction> predictedClassI = e2.getPredictionByPredictedClass(i);
|
||||||
|
assertEquals(actualCounts[i], actualClassI.size());
|
||||||
|
assertEquals(predictedCounts[i], predictedClassI.size());
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void apply(org.nd4j.evaluation.classification.Evaluation e, int nTimes, INDArray predicted, INDArray actual) {
|
private static void apply(org.nd4j.evaluation.classification.Evaluation e, int nTimes, INDArray predicted, INDArray actual) {
|
||||||
|
@ -504,11 +521,11 @@ public class EvalTest extends BaseDL4JTest {
|
||||||
list.add(new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[]{ds.getFeatures()}, new INDArray[]{ds.getLabels(), ds.getLabels()}));
|
list.add(new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[]{ds.getFeatures()}, new INDArray[]{ds.getLabels(), ds.getLabels()}));
|
||||||
}
|
}
|
||||||
|
|
||||||
Evaluation e = new Evaluation();
|
org.nd4j.evaluation.classification.Evaluation e = new org.nd4j.evaluation.classification.Evaluation();
|
||||||
RegressionEvaluation e2 = new RegressionEvaluation();
|
org.nd4j.evaluation.regression.RegressionEvaluation e2 = new org.nd4j.evaluation.regression.RegressionEvaluation();
|
||||||
Map<Integer,IEvaluation[]> evals = new HashMap<>();
|
Map<Integer,org.nd4j.evaluation.IEvaluation[]> evals = new HashMap<>();
|
||||||
evals.put(0, new IEvaluation[]{(IEvaluation) e});
|
evals.put(0, new org.nd4j.evaluation.IEvaluation[]{e});
|
||||||
evals.put(1, new IEvaluation[]{(IEvaluation) e2});
|
evals.put(1, new org.nd4j.evaluation.IEvaluation[]{e2});
|
||||||
|
|
||||||
cg.evaluate(new IteratorMultiDataSetIterator(list.iterator(), 30), evals);
|
cg.evaluate(new IteratorMultiDataSetIterator(list.iterator(), 30), evals);
|
||||||
|
|
||||||
|
@ -567,14 +584,14 @@ public class EvalTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
net.evaluateROC(iter);
|
net.evaluateROC(iter, 0);
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (IllegalStateException e){
|
} catch (IllegalStateException e){
|
||||||
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC"));
|
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC"));
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
net.evaluateROCMultiClass(iter);
|
net.evaluateROCMultiClass(iter, 0);
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (IllegalStateException e){
|
} catch (IllegalStateException e){
|
||||||
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass"));
|
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass"));
|
||||||
|
@ -589,14 +606,14 @@ public class EvalTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
cg.evaluateROC(iter);
|
cg.evaluateROC(iter, 0);
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (IllegalStateException e){
|
} catch (IllegalStateException e){
|
||||||
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC"));
|
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC"));
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
cg.evaluateROCMultiClass(iter);
|
cg.evaluateROCMultiClass(iter, 0);
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (IllegalStateException e){
|
} catch (IllegalStateException e){
|
||||||
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass"));
|
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass"));
|
||||||
|
@ -606,10 +623,10 @@ public class EvalTest extends BaseDL4JTest {
|
||||||
//Disable validation, and check same thing:
|
//Disable validation, and check same thing:
|
||||||
net.getLayerWiseConfigurations().setValidateOutputLayerConfig(false);
|
net.getLayerWiseConfigurations().setValidateOutputLayerConfig(false);
|
||||||
net.evaluate(iter);
|
net.evaluate(iter);
|
||||||
net.evaluateROCMultiClass(iter);
|
net.evaluateROCMultiClass(iter, 0);
|
||||||
|
|
||||||
cg.getConfiguration().setValidateOutputLayerConfig(false);
|
cg.getConfiguration().setValidateOutputLayerConfig(false);
|
||||||
cg.evaluate(iter);
|
cg.evaluate(iter);
|
||||||
cg.evaluateROCMultiClass(iter);
|
cg.evaluateROCMultiClass(iter, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,7 +61,7 @@ public class RegressionEvalTest extends BaseDL4JTest {
|
||||||
|
|
||||||
DataSet ds = new DataSet(f, l);
|
DataSet ds = new DataSet(f, l);
|
||||||
DataSetIterator iter = new ExistingDataSetIterator(Collections.singletonList(ds));
|
DataSetIterator iter = new ExistingDataSetIterator(Collections.singletonList(ds));
|
||||||
RegressionEvaluation re = net.evaluateRegression(iter);
|
org.nd4j.evaluation.regression.RegressionEvaluation re = net.evaluateRegression(iter);
|
||||||
|
|
||||||
for (int i = 0; i < 5; i++) {
|
for (int i = 0; i < 5; i++) {
|
||||||
assertEquals(1.0, re.meanSquaredError(i), 1e-6);
|
assertEquals(1.0, re.meanSquaredError(i), 1e-6);
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.nn.dtypes;
|
package org.deeplearning4j.nn.dtypes;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
|
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
|
||||||
|
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
|
||||||
import org.nd4j.shade.guava.collect.ImmutableSet;
|
import org.nd4j.shade.guava.collect.ImmutableSet;
|
||||||
import org.nd4j.shade.guava.reflect.ClassPath;
|
import org.nd4j.shade.guava.reflect.ClassPath;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
@ -128,7 +129,7 @@ public class DTypeTests extends BaseDL4JTest {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface()) {
|
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface() || TFOpLayer.class == clazz) { //Skip TFOpLayer here - dtype depends on imported model dtype
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -86,7 +86,7 @@ public abstract class CacheableExtractableDataSetFetcher implements CacheableDat
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
ArchiveUtils.unzipFileTo(tmpFile.getAbsolutePath(), localCacheDir.getAbsolutePath());
|
ArchiveUtils.unzipFileTo(tmpFile.getAbsolutePath(), localCacheDir.getAbsolutePath(), false);
|
||||||
} catch (Throwable t){
|
} catch (Throwable t){
|
||||||
//Catch any errors during extraction, and delete the directory to avoid leaving the dir in an invalid state
|
//Catch any errors during extraction, and delete the directory to avoid leaving the dir in an invalid state
|
||||||
if(localCacheDir.exists())
|
if(localCacheDir.exists())
|
||||||
|
|
|
@ -205,6 +205,7 @@ public class RecordReaderDataSetIterator implements DataSetIterator {
|
||||||
this.numPossibleLabels = b.numPossibleLabels;
|
this.numPossibleLabels = b.numPossibleLabels;
|
||||||
this.regression = b.regression;
|
this.regression = b.regression;
|
||||||
this.preProcessor = b.preProcessor;
|
this.preProcessor = b.preProcessor;
|
||||||
|
this.collectMetaData = b.collectMetaData;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -105,6 +105,14 @@
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<artifactId>nd4j-tensorflow</artifactId>
|
||||||
|
<version>${nd4j.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
|
|
@ -103,4 +103,6 @@ public class Keras2LayerConfiguration extends KerasLayerConfiguration {
|
||||||
|
|
||||||
/* Keras weight initializers. */
|
/* Keras weight initializers. */
|
||||||
private final String LAYER_FIELD_INIT = "kernel_initializer";
|
private final String LAYER_FIELD_INIT = "kernel_initializer";
|
||||||
|
|
||||||
|
private final String TENSORFLOW_OP_LAYER = "TensorFlowOpLayer";
|
||||||
}
|
}
|
|
@ -0,0 +1,74 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.nn.modelimport.keras.layers;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||||
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||||
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
|
||||||
|
public class KerasTFOpLayer extends KerasLayer {
|
||||||
|
|
||||||
|
public KerasTFOpLayer(Integer kerasVersion) throws UnsupportedKerasConfigurationException {
|
||||||
|
super(kerasVersion);
|
||||||
|
if (kerasVersion != 2){
|
||||||
|
throw new UnsupportedKerasConfigurationException("KerasTFOpLayer expects Keras version 2");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructor from parsed Keras layer configuration dictionary.
|
||||||
|
*
|
||||||
|
* @param layerConfig dictionary containing Keras layer configuration
|
||||||
|
* @throws InvalidKerasConfigurationException Invalid Keras config
|
||||||
|
* @throws UnsupportedKerasConfigurationException Unsupported Keras config
|
||||||
|
*/
|
||||||
|
public KerasTFOpLayer(Map<String, Object> layerConfig)
|
||||||
|
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||||
|
this(layerConfig, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructor from parsed Keras layer configuration dictionary.
|
||||||
|
*
|
||||||
|
* @param layerConfig dictionary containing Keras layer configuration
|
||||||
|
* @param enforceTrainingConfig whether to enforce training-related configuration options
|
||||||
|
* @throws InvalidKerasConfigurationException Invalid Keras config
|
||||||
|
* @throws UnsupportedKerasConfigurationException Unsupported Keras config
|
||||||
|
*/
|
||||||
|
public KerasTFOpLayer(Map<String, Object> layerConfig, boolean enforceTrainingConfig) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException{
|
||||||
|
super(layerConfig, enforceTrainingConfig);
|
||||||
|
this.layer = new TFOpLayer((Map)((Map)layerConfig.get("config")).get("node_def"), (Map)((Map)layerConfig.get("config")).get("constants"));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get layer output type.
|
||||||
|
*
|
||||||
|
* @param inputType Array of InputTypes
|
||||||
|
* @return output type as InputType
|
||||||
|
* @throws InvalidKerasConfigurationException Invalid Keras configuration
|
||||||
|
*/
|
||||||
|
public InputType getOutputType(InputType... inputType){
|
||||||
|
return this.layer.getOutputType(0, inputType[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,106 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.nn.modelimport.keras.layers;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||||
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.Layer;
|
||||||
|
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
|
||||||
|
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayerImpl;
|
||||||
|
import org.deeplearning4j.nn.params.EmptyParamInitializer;
|
||||||
|
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
|
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
|
||||||
|
public class TFOpLayer extends Layer {
|
||||||
|
|
||||||
|
private Map nodeDef;
|
||||||
|
private Map constants;
|
||||||
|
public TFOpLayer(Map nodeDef, Map constants){
|
||||||
|
super();
|
||||||
|
this.nodeDef = nodeDef;
|
||||||
|
this.constants = constants;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ParamInitializer initializer() {
|
||||||
|
return EmptyParamInitializer.getInstance();
|
||||||
|
}
|
||||||
|
@Override
|
||||||
|
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isPretrainParam(String param){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public InputType getOutputType(int idx, InputType inputType){
|
||||||
|
long[] shape = inputType.getShape(true);
|
||||||
|
TFOpLayerImpl tempLayer = new TFOpLayerImpl(nodeDef, constants, null, null);
|
||||||
|
long[] outputShape = tempLayer.getOutputShape(shape);
|
||||||
|
return InputType.inferInputType(Nd4j.create(outputShape));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setNIn(InputType inputType, boolean override){}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public GradientNormalization getGradientNormalization(){return null;}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
|
||||||
|
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
|
||||||
|
boolean initializeParams, DataType networkDataType) {
|
||||||
|
|
||||||
|
TFOpLayerImpl tfOpLayerImpl = new TFOpLayerImpl(nodeDef, constants, conf, networkDataType);
|
||||||
|
tfOpLayerImpl.setListeners(trainingListeners);
|
||||||
|
tfOpLayerImpl.setIndex(layerIndex);
|
||||||
|
return tfOpLayerImpl;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getGradientNormalizationThreshold(){return 0.;}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Regularization> getRegularizationByParam(String paramName){return null;}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
||||||
|
return new LayerMemoryReport(); //TODO
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,169 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.nn.modelimport.keras.layers;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.nn.layers.AbstractLayer;
|
||||||
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
|
import org.nd4j.TFGraphRunnerService;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
import org.tensorflow.framework.AttrValue;
|
||||||
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
import com.google.gson.Gson;
|
||||||
|
import org.nd4j.shade.protobuf.Message;
|
||||||
|
import org.nd4j.shade.protobuf.TextFormat;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
@Data
|
||||||
|
public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
|
||||||
|
|
||||||
|
|
||||||
|
private Map nodeDef;
|
||||||
|
private Map constants;
|
||||||
|
private List<String> inputNames;
|
||||||
|
TFGraphRunnerService graphRunnerService;
|
||||||
|
|
||||||
|
public TFOpLayerImpl(Map nodeDef, Map constants, NeuralNetConfiguration conf, DataType dtype){
|
||||||
|
super(conf, dtype);
|
||||||
|
this.nodeDef = nodeDef;
|
||||||
|
this.constants = constants;
|
||||||
|
setGraphRunner();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr){
|
||||||
|
throw new RuntimeException("Backprop through TFOpLayerImpl is not supported yet." +
|
||||||
|
" TFOpLayerImpl is created when importing TensorFlow 2.0 Keras models " +
|
||||||
|
"(tf.keras) into DL4J, that contains TensorFlow operations not just Keras layers.");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts a Map representation of Nodedef to a singleton TF Graph and instantiates a GraphRunner.
|
||||||
|
*/
|
||||||
|
private void setGraphRunner() {
|
||||||
|
try{
|
||||||
|
String json = new Gson().toJson(nodeDef);
|
||||||
|
NodeDef.Builder builder = NodeDef.newBuilder();
|
||||||
|
org.nd4j.shade.protobuf.util.JsonFormat.parser().merge(json, builder);
|
||||||
|
NodeDef nodeDef = builder.build();
|
||||||
|
List<String> allInputNames = new ArrayList<>(); // including constants
|
||||||
|
Map<String, String> inputDataTypes = new HashMap<>();
|
||||||
|
Map<String, INDArray> constArrays = new HashMap();
|
||||||
|
this.inputNames = new ArrayList<>();
|
||||||
|
List<String> outputNames = Arrays.asList(nodeDef.getName());
|
||||||
|
Map<String, AttrValue> attrMap = nodeDef.getAttrMap();
|
||||||
|
for (int i = 0; i < nodeDef.getInputCount(); i++){
|
||||||
|
String inputName = nodeDef.getInput(i);
|
||||||
|
String[] split = inputName.split("/");
|
||||||
|
String attrKey;
|
||||||
|
if (split.length == 1){
|
||||||
|
attrKey = "T";
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
attrKey = "T" + split[split.length - 1];
|
||||||
|
}
|
||||||
|
allInputNames.add(nodeDef.getInput(i));
|
||||||
|
inputDataTypes.put(nodeDef.getInput(i), attrMap.get(attrKey).getType().toString());
|
||||||
|
if (constants.containsKey(String.valueOf(i))){
|
||||||
|
constArrays.put(nodeDef.getInput(i), Nd4j.create((List<Number>)constants.get(String.valueOf(i))));
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
this.inputNames.add(nodeDef.getInput(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
String graph = "node{\n" + nodeDef.toString() + "\n}\nversions {\n producer: 22\n}";
|
||||||
|
for (int i = 0; i < allInputNames.size(); i++){
|
||||||
|
String inpName = allInputNames.get(i);
|
||||||
|
String dtype = inputDataTypes.get(inpName);
|
||||||
|
graph = "node{\nname: \"" + inpName + "\"\nop: \"Placeholder\"\nattr{\nkey: \"dtype\"\n value {\n type: " + dtype + "}\n}\n}\n" + graph;
|
||||||
|
}
|
||||||
|
log.info(graph);
|
||||||
|
GraphDef.Builder graphDefBuilder = GraphDef.newBuilder();
|
||||||
|
TextFormat.getParser().merge(graph, graphDefBuilder);
|
||||||
|
GraphDef graphDef = graphDefBuilder.build();
|
||||||
|
org.nd4j.shade.protobuf.ByteString serialized = graphDef.toByteString();
|
||||||
|
byte[] graphBytes = serialized.toByteArray();
|
||||||
|
|
||||||
|
ServiceLoader<TFGraphRunnerService> sl = ServiceLoader.load(TFGraphRunnerService.class);
|
||||||
|
Iterator<TFGraphRunnerService> iter = sl.iterator();
|
||||||
|
if (!iter.hasNext()){
|
||||||
|
throw new RuntimeException("The model contains a Tensorflow Op, which requires the nd4j-tensorflow dependency to execute.");
|
||||||
|
}
|
||||||
|
|
||||||
|
this.graphRunnerService = iter.next().init(allInputNames, outputNames, graphBytes, constArrays, inputDataTypes);
|
||||||
|
}
|
||||||
|
catch (Exception e){
|
||||||
|
throw new RuntimeException("Error parsing protobuf", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private INDArray runGraph(INDArray input){
|
||||||
|
if (input.rank() == 3){
|
||||||
|
// TODO make this a preprocessor
|
||||||
|
input = input.permute(0, 2, 1);
|
||||||
|
}
|
||||||
|
Map<String, INDArray> inputMap = new HashMap<>();
|
||||||
|
inputMap.put(inputNames.get(0), input);
|
||||||
|
INDArray out = graphRunnerService.run(inputMap).values().toArray(new INDArray[0])[0];
|
||||||
|
if (out.rank() == 3){
|
||||||
|
out = out.permute(0, 2, 1); // TODO post-processing?
|
||||||
|
}
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
public long[] getOutputShape(long[] inputShape){
|
||||||
|
long[] shape = ArrayUtils.clone(inputShape);
|
||||||
|
for(int i = 0; i < shape.length; i++){
|
||||||
|
if (shape[i] < 0){
|
||||||
|
shape[i] = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
INDArray dummyArr = Nd4j.zeros(shape);
|
||||||
|
return runGraph(dummyArr).shape();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){
|
||||||
|
return runGraph(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isPretrainLayer(){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void clearNoiseWeightParams(){
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -190,7 +190,7 @@ public class KerasBidirectional extends KerasLayer {
|
||||||
"Keras Bidirectional layer accepts only one input (received " + inputType.length + ")");
|
"Keras Bidirectional layer accepts only one input (received " + inputType.length + ")");
|
||||||
InputPreProcessor preProcessor = getInputPreprocessor(inputType);
|
InputPreProcessor preProcessor = getInputPreprocessor(inputType);
|
||||||
if (preProcessor != null)
|
if (preProcessor != null)
|
||||||
return preProcessor.getOutputType(inputType[0]);
|
return this.getBidirectionalLayer().getOutputType(-1, preProcessor.getOutputType(inputType[0]));
|
||||||
else
|
else
|
||||||
return this.getBidirectionalLayer().getOutputType(-1, inputType[0]);
|
return this.getBidirectionalLayer().getOutputType(-1, inputType[0]);
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,10 +21,12 @@ import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
|
||||||
import org.deeplearning4j.nn.conf.layers.Layer;
|
import org.deeplearning4j.nn.conf.layers.Layer;
|
||||||
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
|
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||||
|
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput;
|
import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput;
|
||||||
|
import org.deeplearning4j.nn.modelimport.keras.layers.KerasTFOpLayer;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*;
|
import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*;
|
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.layers.core.*;
|
import org.deeplearning4j.nn.modelimport.keras.layers.core.*;
|
||||||
|
@ -317,6 +319,11 @@ public class KerasLayerUtils {
|
||||||
layer = new KerasELU(layerConfig, enforceTrainingConfig);
|
layer = new KerasELU(layerConfig, enforceTrainingConfig);
|
||||||
} else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){
|
} else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){
|
||||||
layer = new KerasSoftmax(layerConfig, enforceTrainingConfig);
|
layer = new KerasSoftmax(layerConfig, enforceTrainingConfig);
|
||||||
|
} else if (conf instanceof Keras2LayerConfiguration){
|
||||||
|
Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf;
|
||||||
|
if (layerClassName.equals(k2conf.getTENSORFLOW_OP_LAYER())){
|
||||||
|
layer = new KerasTFOpLayer(layerConfig, enforceTrainingConfig);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (layer == null){
|
if (layer == null){
|
||||||
Class<? extends KerasLayer> customConfig = customLayers.get(layerClassName);
|
Class<? extends KerasLayer> customConfig = customLayers.get(layerClassName);
|
||||||
|
@ -402,6 +409,16 @@ public class KerasLayerUtils {
|
||||||
public static String getLayerNameFromConfig(Map<String, Object> layerConfig,
|
public static String getLayerNameFromConfig(Map<String, Object> layerConfig,
|
||||||
KerasLayerConfiguration conf)
|
KerasLayerConfiguration conf)
|
||||||
throws InvalidKerasConfigurationException {
|
throws InvalidKerasConfigurationException {
|
||||||
|
if(conf instanceof Keras2LayerConfiguration){
|
||||||
|
Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf;
|
||||||
|
if (getClassNameFromConfig(layerConfig, conf).equals(((Keras2LayerConfiguration) conf).getTENSORFLOW_OP_LAYER())){
|
||||||
|
if (!layerConfig.containsKey(conf.getLAYER_FIELD_NAME()))
|
||||||
|
throw new InvalidKerasConfigurationException("Field " + conf.getLAYER_FIELD_NAME()
|
||||||
|
+ " missing from layer config");
|
||||||
|
return (String) layerConfig.get(conf.getLAYER_FIELD_NAME());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
|
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
|
||||||
if (!innerConfig.containsKey(conf.getLAYER_FIELD_NAME()))
|
if (!innerConfig.containsKey(conf.getLAYER_FIELD_NAME()))
|
||||||
throw new InvalidKerasConfigurationException("Field " + conf.getLAYER_FIELD_NAME()
|
throw new InvalidKerasConfigurationException("Field " + conf.getLAYER_FIELD_NAME()
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.nn.modelimport.keras;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.resources.Resources;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
public class TFKerasTests extends BaseDL4JTest{
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testModelWithTFOp1() throws Exception{
|
||||||
|
File f = Resources.asFile("modelimport/keras/tfkeras/reshape.h5");
|
||||||
|
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
|
||||||
|
INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
|
||||||
|
Assert.assertArrayEquals(new long[]{12, 3}, out.shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testModelWithTFOp2() throws Exception{
|
||||||
|
File f = Resources.asFile("modelimport/keras/tfkeras/permute.h5");
|
||||||
|
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
|
||||||
|
INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
|
||||||
|
// dl4j's feedforward doesn't support 3D output, so batch and time axes gets squashed
|
||||||
|
long[] expectedShape = new long[]{12 * 2, 5};
|
||||||
|
Assert.assertArrayEquals(expectedShape, out.shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -67,7 +67,7 @@ public class KuromojiBinFilesFetcher {
|
||||||
new URL("https://dl4jdata.blob.core.windows.net/kuromoji/kuromoji_bin_files.tar.gz"),
|
new URL("https://dl4jdata.blob.core.windows.net/kuromoji/kuromoji_bin_files.tar.gz"),
|
||||||
tarFile);
|
tarFile);
|
||||||
}
|
}
|
||||||
ArchiveUtils.unzipFileTo(tarFile.getAbsolutePath(), rootDir.getAbsolutePath());
|
ArchiveUtils.unzipFileTo(tarFile.getAbsolutePath(), rootDir.getAbsolutePath(), false);
|
||||||
|
|
||||||
return rootDir.getAbsoluteFile();
|
return rootDir.getAbsoluteFile();
|
||||||
}
|
}
|
||||||
|
|
|
@ -77,7 +77,11 @@
|
||||||
<artifactId>nd4j-common</artifactId>
|
<artifactId>nd4j-common</artifactId>
|
||||||
<version>${nd4j.version}</version>
|
<version>${nd4j.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.google.code.gson</groupId>
|
||||||
|
<artifactId>gson</artifactId>
|
||||||
|
<version>${gson.version}</version>
|
||||||
|
</dependency>
|
||||||
<!-- ND4J Shaded Jackson Dependency -->
|
<!-- ND4J Shaded Jackson Dependency -->
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
|
|
@ -4170,6 +4170,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
INDArray[] featuresMasks = next.getFeaturesMaskArrays();
|
INDArray[] featuresMasks = next.getFeaturesMaskArrays();
|
||||||
INDArray[] labels = next.getLabels();
|
INDArray[] labels = next.getLabels();
|
||||||
INDArray[] labelMasks = next.getLabelsMaskArrays();
|
INDArray[] labelMasks = next.getLabelsMaskArrays();
|
||||||
|
List<Serializable> meta = next.getExampleMetaData();
|
||||||
|
|
||||||
try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) {
|
try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) {
|
||||||
INDArray[] out = outputOfLayersDetached(false, FwdPassType.STANDARD, getOutputLayerIndices(), features, featuresMasks, labelMasks, true, false, ws);
|
INDArray[] out = outputOfLayersDetached(false, FwdPassType.STANDARD, getOutputLayerIndices(), features, featuresMasks, labelMasks, true, false, ws);
|
||||||
|
@ -4188,7 +4189,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
|
|
||||||
try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||||
for (IEvaluation evaluation : evalsThisOutput)
|
for (IEvaluation evaluation : evalsThisOutput)
|
||||||
evaluation.eval(currLabel, currOut, next.getLabelsMaskArray(i));
|
evaluation.eval(currLabel, currOut, next.getLabelsMaskArray(i), meta);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,9 @@ import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public class ComputationGraphUtil {
|
public class ComputationGraphUtil {
|
||||||
|
|
||||||
private ComputationGraphUtil() {}
|
private ComputationGraphUtil() {}
|
||||||
|
@ -33,13 +36,16 @@ public class ComputationGraphUtil {
|
||||||
INDArray l = dataSet.getLabels();
|
INDArray l = dataSet.getLabels();
|
||||||
INDArray fMask = dataSet.getFeaturesMaskArray();
|
INDArray fMask = dataSet.getFeaturesMaskArray();
|
||||||
INDArray lMask = dataSet.getLabelsMaskArray();
|
INDArray lMask = dataSet.getLabelsMaskArray();
|
||||||
|
List<Serializable> meta = dataSet.getExampleMetaData();
|
||||||
|
|
||||||
INDArray[] fNew = f == null ? null : new INDArray[] {f};
|
INDArray[] fNew = f == null ? null : new INDArray[] {f};
|
||||||
INDArray[] lNew = l == null ? null : new INDArray[] {l};
|
INDArray[] lNew = l == null ? null : new INDArray[] {l};
|
||||||
INDArray[] fMaskNew = (fMask != null ? new INDArray[] {fMask} : null);
|
INDArray[] fMaskNew = (fMask != null ? new INDArray[] {fMask} : null);
|
||||||
INDArray[] lMaskNew = (lMask != null ? new INDArray[] {lMask} : null);
|
INDArray[] lMaskNew = (lMask != null ? new INDArray[] {lMask} : null);
|
||||||
|
|
||||||
return new org.nd4j.linalg.dataset.MultiDataSet(fNew, lNew, fMaskNew, lMaskNew);
|
org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(fNew, lNew, fMaskNew, lMaskNew);
|
||||||
|
mds.setExampleMetaData(meta);
|
||||||
|
return mds;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Convert a DataSetIterator to a MultiDataSetIterator, via an adaptor class */
|
/** Convert a DataSetIterator to a MultiDataSetIterator, via an adaptor class */
|
||||||
|
|
|
@ -62,6 +62,7 @@ public abstract class AbstractLayer<LayerConfT extends org.deeplearning4j.nn.con
|
||||||
|
|
||||||
public AbstractLayer(NeuralNetConfiguration conf, DataType dataType) {
|
public AbstractLayer(NeuralNetConfiguration conf, DataType dataType) {
|
||||||
this.conf = conf;
|
this.conf = conf;
|
||||||
|
if (conf != null)
|
||||||
cacheMode = conf.getCacheMode();
|
cacheMode = conf.getCacheMode();
|
||||||
this.dataType = dataType;
|
this.dataType = dataType;
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,14 +25,11 @@ import lombok.val;
|
||||||
import org.apache.commons.lang3.ArrayUtils;
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
import org.nd4j.adapters.OutputAdapter;
|
|
||||||
import org.nd4j.linalg.dataset.AsyncDataSetIterator;;
|
|
||||||
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
||||||
import org.deeplearning4j.eval.RegressionEvaluation;
|
|
||||||
import org.deeplearning4j.exception.DL4JException;
|
import org.deeplearning4j.exception.DL4JException;
|
||||||
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
||||||
import org.deeplearning4j.nn.api.*;
|
|
||||||
import org.deeplearning4j.nn.api.Updater;
|
import org.deeplearning4j.nn.api.Updater;
|
||||||
|
import org.deeplearning4j.nn.api.*;
|
||||||
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
||||||
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
|
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
|
||||||
import org.deeplearning4j.nn.conf.*;
|
import org.deeplearning4j.nn.conf.*;
|
||||||
|
@ -44,8 +41,8 @@ import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.layers.FrozenLayer;
|
import org.deeplearning4j.nn.layers.FrozenLayer;
|
||||||
import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
|
import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
|
||||||
import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer;
|
|
||||||
import org.deeplearning4j.nn.layers.LayerHelper;
|
import org.deeplearning4j.nn.layers.LayerHelper;
|
||||||
|
import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer;
|
||||||
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
|
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
|
||||||
import org.deeplearning4j.nn.updater.UpdaterCreator;
|
import org.deeplearning4j.nn.updater.UpdaterCreator;
|
||||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||||
|
@ -58,19 +55,23 @@ import org.deeplearning4j.util.CrashReportingUtil;
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
import org.deeplearning4j.util.NetworkUtils;
|
import org.deeplearning4j.util.NetworkUtils;
|
||||||
import org.deeplearning4j.util.OutputLayerUtil;
|
import org.deeplearning4j.util.OutputLayerUtil;
|
||||||
|
import org.nd4j.adapters.OutputAdapter;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.evaluation.IEvaluation;
|
import org.nd4j.evaluation.IEvaluation;
|
||||||
import org.nd4j.evaluation.classification.Evaluation;
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
import org.nd4j.evaluation.classification.ROC;
|
import org.nd4j.evaluation.classification.ROC;
|
||||||
import org.nd4j.evaluation.classification.ROCMultiClass;
|
import org.nd4j.evaluation.classification.ROCMultiClass;
|
||||||
|
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
|
import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace;
|
||||||
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
||||||
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
|
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
|
||||||
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
|
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
|
||||||
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
|
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
|
||||||
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
|
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
@ -84,7 +85,6 @@ import org.nd4j.linalg.heartbeat.reports.Task;
|
||||||
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
|
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
|
||||||
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
|
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace;
|
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.linalg.primitives.Triple;
|
import org.nd4j.linalg.primitives.Triple;
|
||||||
import org.nd4j.linalg.schedule.ISchedule;
|
import org.nd4j.linalg.schedule.ISchedule;
|
||||||
|
@ -96,6 +96,8 @@ import org.nd4j.util.OneTimeLogger;
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
|
;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* MultiLayerNetwork is a neural network with multiple layers in a stack, and usually an output layer.<br>
|
* MultiLayerNetwork is a neural network with multiple layers in a stack, and usually an output layer.<br>
|
||||||
|
@ -3315,19 +3317,39 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
* @param iterator Iterator to evaluate on
|
* @param iterator Iterator to evaluate on
|
||||||
* @return Evaluation object; results of evaluation on all examples in the data set
|
* @return Evaluation object; results of evaluation on all examples in the data set
|
||||||
*/
|
*/
|
||||||
public <T extends Evaluation> T evaluate(DataSetIterator iterator) {
|
public <T extends Evaluation> T evaluate(@NonNull DataSetIterator iterator) {
|
||||||
return (T)evaluate(iterator, null);
|
return (T)evaluate(iterator, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Evaluate the network (classification performance).
|
||||||
|
* Can only be used with MultiDataSetIterator instances with a single input/output array
|
||||||
|
*
|
||||||
|
* @param iterator Iterator to evaluate on
|
||||||
|
* @return Evaluation object; results of evaluation on all examples in the data set
|
||||||
|
*/
|
||||||
|
public Evaluation evaluate(@NonNull MultiDataSetIterator iterator) {
|
||||||
|
return evaluate(new MultiDataSetWrapperIterator(iterator));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Evaluate the network for regression performance
|
* Evaluate the network for regression performance
|
||||||
* @param iterator Data to evaluate on
|
* @param iterator Data to evaluate on
|
||||||
* @return
|
* @return Regression evaluation
|
||||||
*/
|
*/
|
||||||
public <T extends RegressionEvaluation> T evaluateRegression(DataSetIterator iterator) {
|
public <T extends RegressionEvaluation> T evaluateRegression(DataSetIterator iterator) {
|
||||||
return (T)doEvaluation(iterator, new RegressionEvaluation(iterator.totalOutcomes()))[0];
|
return (T)doEvaluation(iterator, new RegressionEvaluation(iterator.totalOutcomes()))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Evaluate the network for regression performance
|
||||||
|
* Can only be used with MultiDataSetIterator instances with a single input/output array
|
||||||
|
* @param iterator Data to evaluate on
|
||||||
|
*/
|
||||||
|
public org.nd4j.evaluation.regression.RegressionEvaluation evaluateRegression(MultiDataSetIterator iterator) {
|
||||||
|
return evaluateRegression(new MultiDataSetWrapperIterator(iterator));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @deprecated To be removed - use {@link #evaluateROC(DataSetIterator, int)} to enforce selection of appropriate ROC/threshold configuration
|
* @deprecated To be removed - use {@link #evaluateROC(DataSetIterator, int)} to enforce selection of appropriate ROC/threshold configuration
|
||||||
*/
|
*/
|
||||||
|
@ -3424,6 +3446,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
INDArray labels = next.getLabels();
|
INDArray labels = next.getLabels();
|
||||||
INDArray fMask = next.getFeaturesMaskArray();
|
INDArray fMask = next.getFeaturesMaskArray();
|
||||||
INDArray lMask = next.getLabelsMaskArray();
|
INDArray lMask = next.getLabelsMaskArray();
|
||||||
|
List<Serializable> meta = next.getExampleMetaData();
|
||||||
|
|
||||||
|
|
||||||
if (!useRnnSegments) {
|
if (!useRnnSegments) {
|
||||||
|
@ -3433,7 +3456,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
|
|
||||||
try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||||
for (T evaluation : evaluations)
|
for (T evaluation : evaluations)
|
||||||
evaluation.eval(labels, out, lMask);
|
evaluation.eval(labels, out, lMask, meta);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -222,8 +222,11 @@ public class AdaptiveThresholdAlgorithm implements ThresholdAlgorithm {
|
||||||
if(a == null || Double.isNaN(a.lastThreshold))
|
if(a == null || Double.isNaN(a.lastThreshold))
|
||||||
return;
|
return;
|
||||||
|
|
||||||
|
|
||||||
lastThresholdSum += a.lastThreshold;
|
lastThresholdSum += a.lastThreshold;
|
||||||
|
if (!Double.isNaN(a.lastSparsity)) {
|
||||||
lastSparsitySum += a.lastSparsity;
|
lastSparsitySum += a.lastSparsity;
|
||||||
|
}
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,16 +38,22 @@
|
||||||
<artifactId>nd4j-aeron</artifactId>
|
<artifactId>nd4j-aeron</artifactId>
|
||||||
<version>${nd4j.version}</version>
|
<version>${nd4j.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-parameter-server-node_2.11</artifactId>
|
|
||||||
<version>${nd4j.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.deeplearning4j</groupId>
|
<groupId>org.deeplearning4j</groupId>
|
||||||
<artifactId>dl4j-spark_2.11</artifactId>
|
<artifactId>dl4j-spark_2.11</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<artifactId>nd4j-parameter-server-node_2.11</artifactId>
|
||||||
|
<version>${nd4j.version}</version>
|
||||||
|
<exclusions>
|
||||||
|
<exclusion>
|
||||||
|
<groupId>net.jpountz.lz4</groupId>
|
||||||
|
<artifactId>lz4</artifactId>
|
||||||
|
</exclusion>
|
||||||
|
</exclusions>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.projectlombok</groupId>
|
<groupId>org.projectlombok</groupId>
|
||||||
<artifactId>lombok</artifactId>
|
<artifactId>lombok</artifactId>
|
||||||
|
|
|
@ -23,6 +23,7 @@ import org.nd4j.linalg.dataset.api.iterator.ParallelMultiDataSetIterator;
|
||||||
|
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This MultiDataSetIterator implementation does accumulation of MultiDataSets from different Spark executors, wrt Thread/Device Affinity
|
* This MultiDataSetIterator implementation does accumulation of MultiDataSets from different Spark executors, wrt Thread/Device Affinity
|
||||||
|
@ -32,14 +33,16 @@ import java.util.List;
|
||||||
public class VirtualMultiDataSetIterator implements ParallelMultiDataSetIterator {
|
public class VirtualMultiDataSetIterator implements ParallelMultiDataSetIterator {
|
||||||
|
|
||||||
protected final List<Iterator<MultiDataSet>> iterators;
|
protected final List<Iterator<MultiDataSet>> iterators;
|
||||||
|
protected final AtomicInteger position;
|
||||||
|
|
||||||
public VirtualMultiDataSetIterator(@NonNull List<Iterator<MultiDataSet>> iterators) {
|
public VirtualMultiDataSetIterator(@NonNull List<Iterator<MultiDataSet>> iterators) {
|
||||||
this.iterators = iterators;
|
this.iterators = iterators;
|
||||||
|
this.position = new AtomicInteger(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public MultiDataSet next(int num) {
|
public MultiDataSet next(int num) {
|
||||||
return null;
|
return next();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -59,27 +62,34 @@ public class VirtualMultiDataSetIterator implements ParallelMultiDataSetIterator
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean asyncSupported() {
|
public boolean asyncSupported() {
|
||||||
return false;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void reset() {
|
public void reset() {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean hasNext() {
|
public boolean hasNext() {
|
||||||
return false;
|
// just checking if that's not the last iterator, or if that's the last one - check if it has something
|
||||||
|
boolean ret = position.get() < iterators.size() - 1
|
||||||
|
|| (position.get() < iterators.size() && iterators.get(position.get()).hasNext());
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public MultiDataSet next() {
|
public MultiDataSet next() {
|
||||||
return null;
|
// TODO: this solution isn't ideal, it assumes non-empty iterators all the time. Would be nice to do something here
|
||||||
|
if (!iterators.get(position.get()).hasNext())
|
||||||
|
position.getAndIncrement();
|
||||||
|
|
||||||
|
return iterators.get(position.get()).next();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void remove() {
|
public void remove() {
|
||||||
|
// no-op
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -109,6 +109,7 @@ public class SharedTrainingWrapper {
|
||||||
|
|
||||||
// now we're creating DataSetIterators, to feed ParallelWrapper
|
// now we're creating DataSetIterators, to feed ParallelWrapper
|
||||||
iteratorDS = new VirtualDataSetIterator(iteratorsDS);
|
iteratorDS = new VirtualDataSetIterator(iteratorsDS);
|
||||||
|
iteratorMDS = new VirtualMultiDataSetIterator(iteratorsMDS);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static synchronized SharedTrainingWrapper getInstance(long id) {
|
public static synchronized SharedTrainingWrapper getInstance(long id) {
|
||||||
|
@ -447,17 +448,19 @@ public class SharedTrainingWrapper {
|
||||||
throw new DL4JInvalidConfigException("No iterators were defined for training");
|
throw new DL4JInvalidConfigException("No iterators were defined for training");
|
||||||
|
|
||||||
try {
|
try {
|
||||||
while((iteratorDS != null && iteratorDS.hasNext()) || (iteratorMDS != null && iteratorMDS.hasNext())) {
|
boolean dsNext;
|
||||||
|
boolean mdsNext;
|
||||||
|
while((dsNext = iteratorDS != null && iteratorDS.hasNext()) || (mdsNext = iteratorMDS != null && iteratorMDS.hasNext())) {
|
||||||
//Loop as a guard against concurrent modifications and RCs
|
//Loop as a guard against concurrent modifications and RCs
|
||||||
|
|
||||||
if (wrapper != null) {
|
if (wrapper != null) {
|
||||||
if (iteratorDS != null)
|
if (dsNext)
|
||||||
wrapper.fit(iteratorDS);
|
wrapper.fit(iteratorDS);
|
||||||
else
|
else
|
||||||
wrapper.fit(iteratorMDS);
|
wrapper.fit(iteratorMDS);
|
||||||
} else {
|
} else {
|
||||||
// if wrapper is null, we're fitting standalone model then
|
// if wrapper is null, we're fitting standalone model then
|
||||||
if (iteratorDS != null) {
|
if (dsNext) {
|
||||||
if (model instanceof ComputationGraph) {
|
if (model instanceof ComputationGraph) {
|
||||||
((ComputationGraph) originalModel).fit(iteratorDS);
|
((ComputationGraph) originalModel).fit(iteratorDS);
|
||||||
} else if (model instanceof MultiLayerNetwork) {
|
} else if (model instanceof MultiLayerNetwork) {
|
||||||
|
@ -472,6 +475,7 @@ public class SharedTrainingWrapper {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(consumer != null)
|
||||||
consumer.getUpdatesQueue().purge();
|
consumer.getUpdatesQueue().purge();
|
||||||
}
|
}
|
||||||
} catch (Throwable t){
|
} catch (Throwable t){
|
||||||
|
|
|
@ -116,8 +116,7 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable
|
||||||
}
|
}
|
||||||
|
|
||||||
protected int numExecutors() {
|
protected int numExecutors() {
|
||||||
int numProc = Runtime.getRuntime().availableProcessors();
|
return 4;
|
||||||
return Math.min(4, numProc);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected MultiLayerConfiguration getBasicConf() {
|
protected MultiLayerConfiguration getBasicConf() {
|
||||||
|
|
|
@ -49,6 +49,7 @@ import org.junit.rules.TemporaryFolder;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
import org.nd4j.linalg.learning.config.AMSGrad;
|
import org.nd4j.linalg.learning.config.AMSGrad;
|
||||||
|
@ -66,20 +67,26 @@ import java.util.concurrent.ConcurrentHashMap;
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Ignore("AB 2019/05/21 - Failing - Issue #7657")
|
//@Ignore("AB 2019/05/21 - Failing - Issue #7657")
|
||||||
public class GradientSharingTrainingTest extends BaseSparkTest {
|
public class GradientSharingTrainingTest extends BaseSparkTest {
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 90000L;
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void trainSanityCheck() throws Exception {
|
public void trainSanityCheck() throws Exception {
|
||||||
|
|
||||||
|
for(boolean mds : new boolean[]{false, true}) {
|
||||||
INDArray last = null;
|
INDArray last = null;
|
||||||
INDArray lastDup = null;
|
INDArray lastDup = null;
|
||||||
for (String s : new String[]{"paths", "direct", "export"}) {
|
for (String s : new String[]{"paths", "direct", "export"}) {
|
||||||
System.out.println("--------------------------------------------------------------------------------------------------------------");
|
System.out.println("--------------------------------------------------------------------------------------------------------------");
|
||||||
log.info("Starting: {}", s);
|
log.info("Starting: {} - {}", s, (mds ? "MultiDataSet" : "DataSet"));
|
||||||
boolean isPaths = "paths".equals(s);
|
boolean isPaths = "paths".equals(s);
|
||||||
|
|
||||||
RDDTrainingApproach rddTrainingApproach;
|
RDDTrainingApproach rddTrainingApproach;
|
||||||
|
@ -144,7 +151,11 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
|
||||||
DataSet d = iter.next();
|
DataSet d = iter.next();
|
||||||
if (isPaths) {
|
if (isPaths) {
|
||||||
File out = new File(f, count + ".bin");
|
File out = new File(f, count + ".bin");
|
||||||
|
if(mds){
|
||||||
|
d.toMultiDataSet().save(out);
|
||||||
|
} else {
|
||||||
d.save(out);
|
d.save(out);
|
||||||
|
}
|
||||||
String path = "file:///" + out.getAbsolutePath().replaceAll("\\\\", "/");
|
String path = "file:///" + out.getAbsolutePath().replaceAll("\\\\", "/");
|
||||||
paths.add(path);
|
paths.add(path);
|
||||||
}
|
}
|
||||||
|
@ -160,6 +171,27 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
|
||||||
|
|
||||||
INDArray paramsBefore = sparkNet.getNetwork().params().dup();
|
INDArray paramsBefore = sparkNet.getNetwork().params().dup();
|
||||||
ComputationGraph after;
|
ComputationGraph after;
|
||||||
|
if(mds) {
|
||||||
|
//Fitting from MultiDataSet
|
||||||
|
List<MultiDataSet> mdsList = new ArrayList<>();
|
||||||
|
for(DataSet d : ds){
|
||||||
|
mdsList.add(d.toMultiDataSet());
|
||||||
|
}
|
||||||
|
switch (s) {
|
||||||
|
case "direct":
|
||||||
|
case "export":
|
||||||
|
JavaRDD<MultiDataSet> dsRDD = sc.parallelize(mdsList);
|
||||||
|
after = sparkNet.fitMultiDataSet(dsRDD);
|
||||||
|
break;
|
||||||
|
case "paths":
|
||||||
|
JavaRDD<String> pathRdd = sc.parallelize(paths);
|
||||||
|
after = sparkNet.fitPathsMultiDataSet(pathRdd);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new RuntimeException();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
//Fitting from DataSet
|
||||||
switch (s) {
|
switch (s) {
|
||||||
case "direct":
|
case "direct":
|
||||||
case "export":
|
case "export":
|
||||||
|
@ -173,6 +205,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
|
||||||
default:
|
default:
|
||||||
throw new RuntimeException();
|
throw new RuntimeException();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
INDArray paramsAfter = after.params();
|
INDArray paramsAfter = after.params();
|
||||||
System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
|
System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
|
||||||
|
@ -199,6 +232,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
|
||||||
lastDup = last.dup();
|
lastDup = last.dup();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -289,7 +323,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test @Ignore
|
||||||
public void testEpochUpdating() throws Exception {
|
public void testEpochUpdating() throws Exception {
|
||||||
//Ensure that epoch counter is incremented properly on the workers
|
//Ensure that epoch counter is incremented properly on the workers
|
||||||
|
|
||||||
|
@ -316,7 +350,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
|
||||||
|
|
||||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.updater(new AMSGrad(0.1))
|
.updater(new AMSGrad(0.001))
|
||||||
.graphBuilder()
|
.graphBuilder()
|
||||||
.addInputs("in")
|
.addInputs("in")
|
||||||
.layer("out", new OutputLayer.Builder().nIn(784).nOut(10).activation(Activation.SOFTMAX)
|
.layer("out", new OutputLayer.Builder().nIn(784).nOut(10).activation(Activation.SOFTMAX)
|
||||||
|
|
|
@ -20,12 +20,12 @@ log4j.appender.Console.layout=org.apache.log4j.PatternLayout
|
||||||
log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n
|
log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n
|
||||||
|
|
||||||
log4j.appender.org.springframework=DEBUG
|
log4j.appender.org.springframework=DEBUG
|
||||||
log4j.appender.org.deeplearning4j=DEBUG
|
log4j.appender.org.deeplearning4j=INFO
|
||||||
log4j.appender.org.nd4j=DEBUG
|
log4j.appender.org.nd4j=INFO
|
||||||
|
|
||||||
log4j.logger.org.springframework=INFO
|
log4j.logger.org.springframework=INFO
|
||||||
log4j.logger.org.deeplearning4j=DEBUG
|
log4j.logger.org.deeplearning4j=INFO
|
||||||
log4j.logger.org.nd4j=DEBUG
|
log4j.logger.org.nd4j=INFO
|
||||||
log4j.logger.org.apache.spark=WARN
|
log4j.logger.org.apache.spark=WARN
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@
|
||||||
|
|
||||||
<logger name="org.apache.catalina.core" level="DEBUG" />
|
<logger name="org.apache.catalina.core" level="DEBUG" />
|
||||||
<logger name="org.springframework" level="DEBUG" />
|
<logger name="org.springframework" level="DEBUG" />
|
||||||
<logger name="org.deeplearning4j" level="DEBUG" />
|
<logger name="org.deeplearning4j" level="INFO" />
|
||||||
<logger name="org.datavec" level="INFO" />
|
<logger name="org.datavec" level="INFO" />
|
||||||
<logger name="org.nd4j" level="INFO" />
|
<logger name="org.nd4j" level="INFO" />
|
||||||
<logger name="opennlp.uima.util" level="OFF" />
|
<logger name="opennlp.uima.util" level="OFF" />
|
||||||
|
|
|
@ -25,10 +25,6 @@
|
||||||
|
|
||||||
<artifactId>deeplearning4j-ui-components</artifactId>
|
<artifactId>deeplearning4j-ui-components</artifactId>
|
||||||
|
|
||||||
<properties>
|
|
||||||
<freemarker.version>2.3.23</freemarker.version>
|
|
||||||
</properties>
|
|
||||||
|
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.projectlombok</groupId>
|
<groupId>org.projectlombok</groupId>
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.deeplearning4j.ui.components.chart.style.StyleChart;
|
||||||
import org.deeplearning4j.ui.components.table.ComponentTable;
|
import org.deeplearning4j.ui.components.table.ComponentTable;
|
||||||
import org.deeplearning4j.ui.components.table.style.StyleTable;
|
import org.deeplearning4j.ui.components.table.style.StyleTable;
|
||||||
import org.deeplearning4j.ui.standalone.StaticPageUtil;
|
import org.deeplearning4j.ui.standalone.StaticPageUtil;
|
||||||
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import java.awt.*;
|
import java.awt.*;
|
||||||
|
|
|
@ -60,7 +60,7 @@
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.freemarker</groupId>
|
<groupId>org.freemarker</groupId>
|
||||||
<artifactId>freemarker</artifactId>
|
<artifactId>freemarker</artifactId>
|
||||||
<version>2.3.29</version>
|
<version>${freemarker.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
|
|
|
@ -200,6 +200,7 @@ public class TrainModule implements UIModule {
|
||||||
}));
|
}));
|
||||||
r.add(new Route("/train/:sessionId/info", HttpMethod.GET, (path, rc) -> this.sessionInfoForSession(path.get(0), rc)));
|
r.add(new Route("/train/:sessionId/info", HttpMethod.GET, (path, rc) -> this.sessionInfoForSession(path.get(0), rc)));
|
||||||
} else {
|
} else {
|
||||||
|
r.add(new Route("/train", HttpMethod.GET, (path, rc) -> rc.reroute("/train/overview")));
|
||||||
r.add(new Route("/train/sessions/current", HttpMethod.GET, (path, rc) -> rc.response().end(currentSessionID == null ? "" : currentSessionID)));
|
r.add(new Route("/train/sessions/current", HttpMethod.GET, (path, rc) -> rc.response().end(currentSessionID == null ? "" : currentSessionID)));
|
||||||
r.add(new Route("/train/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc)));
|
r.add(new Route("/train/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc)));
|
||||||
r.add(new Route("/train/overview", HttpMethod.GET, (path, rc) -> this.renderFtl("TrainingOverview.html.ftl", rc)));
|
r.add(new Route("/train/overview", HttpMethod.GET, (path, rc) -> this.renderFtl("TrainingOverview.html.ftl", rc)));
|
||||||
|
|
|
@ -33,7 +33,7 @@ OP_IMPL(mergeadd, -1, 1, false) {
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
std::vector<NDArray*> inArrs(block.width());
|
std::vector<const NDArray*> inArrs(block.width());
|
||||||
|
|
||||||
for(int i = 0; i < block.width(); ++i)
|
for(int i = 0; i < block.width(); ++i)
|
||||||
inArrs[i] = INPUT_VARIABLE(i);
|
inArrs[i] = INPUT_VARIABLE(i);
|
||||||
|
@ -42,7 +42,6 @@ OP_IMPL(mergeadd, -1, 1, false) {
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SYN(mergesum, mergeadd);
|
DECLARE_SYN(mergesum, mergeadd);
|
||||||
DECLARE_SYN(add_n, mergeadd);
|
DECLARE_SYN(add_n, mergeadd);
|
||||||
DECLARE_SYN(addn, mergeadd);
|
DECLARE_SYN(addn, mergeadd);
|
||||||
|
@ -54,6 +53,45 @@ DECLARE_SYN(accumulate_n, mergeadd);
|
||||||
->setAllowedInputTypes(sd::DataType::ANY)
|
->setAllowedInputTypes(sd::DataType::ANY)
|
||||||
->setAllowedOutputTypes(sd::DataType::ANY);
|
->setAllowedOutputTypes(sd::DataType::ANY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
CUSTOM_OP_IMPL(mergeadd_bp, 2, 1, false, 0, 0) {
|
||||||
|
|
||||||
|
auto inSize = block.width() - 1;
|
||||||
|
|
||||||
|
REQUIRE_OK(this->validateInputDimensionsMatch(block));
|
||||||
|
|
||||||
|
std::vector<NDArray*> outArrs(inSize);
|
||||||
|
|
||||||
|
const auto gradient = INPUT_VARIABLE(inSize);
|
||||||
|
|
||||||
|
for (int i = 0; i < inSize; ++i) {
|
||||||
|
outArrs[i] = OUTPUT_VARIABLE(i);
|
||||||
|
}
|
||||||
|
helpers::mergeAddBp(block.launchContext(), *gradient, outArrs);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(mergeadd_bp) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(sd::DataType::ANY)
|
||||||
|
->setAllowedOutputTypes(sd::DataType::ANY);
|
||||||
|
}
|
||||||
|
DECLARE_SHAPE_FN(mergeadd_bp) {
|
||||||
|
|
||||||
|
const int numOfInArrs = block.width() - 1;
|
||||||
|
|
||||||
|
auto shapeList = SHAPELIST();
|
||||||
|
|
||||||
|
for (int e = 0; e < numOfInArrs; e++) {
|
||||||
|
auto inShape = inputShape->at(e);
|
||||||
|
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), shape::shapeOf(inShape), shape::rank(inShape))));
|
||||||
|
}
|
||||||
|
|
||||||
|
return shapeList;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ OP_IMPL(mergeavg, -1, 1, false) {
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
std::vector<NDArray*> inArrs(block.width());
|
std::vector<const NDArray*> inArrs(block.width());
|
||||||
|
|
||||||
for(int i = 0; i < block.width(); ++i)
|
for(int i = 0; i < block.width(); ++i)
|
||||||
inArrs[i] = INPUT_VARIABLE(i);
|
inArrs[i] = INPUT_VARIABLE(i);
|
||||||
|
@ -48,6 +48,44 @@ OP_IMPL(mergeavg, -1, 1, false) {
|
||||||
->setAllowedInputTypes({ALL_FLOATS})
|
->setAllowedInputTypes({ALL_FLOATS})
|
||||||
->setAllowedOutputTypes({ALL_FLOATS});
|
->setAllowedOutputTypes({ALL_FLOATS});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
CUSTOM_OP_IMPL(mergeavg_bp, 2, 1, false, 0, 0) {
|
||||||
|
|
||||||
|
auto inSize = block.width() - 1;
|
||||||
|
|
||||||
|
REQUIRE_OK(this->validateInputDimensionsMatch(block));
|
||||||
|
|
||||||
|
std::vector<NDArray*> outArrs(inSize);
|
||||||
|
|
||||||
|
const auto gradient = INPUT_VARIABLE(inSize);
|
||||||
|
|
||||||
|
for (int i = 0; i < inSize; ++i) {
|
||||||
|
outArrs[i] = OUTPUT_VARIABLE(i);
|
||||||
|
}
|
||||||
|
helpers::mergeAvgBp(block.launchContext(), *gradient, outArrs);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(mergeavg_bp) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(sd::DataType::ANY)
|
||||||
|
->setAllowedOutputTypes(sd::DataType::ANY);
|
||||||
|
}
|
||||||
|
DECLARE_SHAPE_FN(mergeavg_bp) {
|
||||||
|
|
||||||
|
const int numOfInArrs = block.width() - 1;
|
||||||
|
|
||||||
|
auto shapeList = SHAPELIST();
|
||||||
|
|
||||||
|
for (int e = 0; e < numOfInArrs; e++) {
|
||||||
|
auto inShape = inputShape->at(e);
|
||||||
|
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), shape::shapeOf(inShape), shape::rank(inShape))));
|
||||||
|
}
|
||||||
|
|
||||||
|
return shapeList;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ OP_IMPL(mergemax, -1, 1, false) {
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
std::vector<NDArray*> inArrs(block.width());
|
std::vector<const NDArray*> inArrs(block.width());
|
||||||
|
|
||||||
for(int i = 0; i < block.width(); ++i)
|
for(int i = 0; i < block.width(); ++i)
|
||||||
inArrs[i] = INPUT_VARIABLE(i);
|
inArrs[i] = INPUT_VARIABLE(i);
|
||||||
|
@ -42,7 +42,6 @@ OP_IMPL(mergemax, -1, 1, false) {
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SYN(MergeMax, mergemax);
|
DECLARE_SYN(MergeMax, mergemax);
|
||||||
|
|
||||||
DECLARE_TYPES(mergemax) {
|
DECLARE_TYPES(mergemax) {
|
||||||
|
@ -51,6 +50,47 @@ DECLARE_SYN(MergeMax, mergemax);
|
||||||
->setAllowedOutputTypes(sd::DataType::ANY);
|
->setAllowedOutputTypes(sd::DataType::ANY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
CUSTOM_OP_IMPL(mergemax_bp, 2, 1, false, 0, 0) {
|
||||||
|
|
||||||
|
auto inSize = block.width();
|
||||||
|
|
||||||
|
REQUIRE_OK(this->validateInputDimensionsMatch(block));
|
||||||
|
|
||||||
|
std::vector<const NDArray*> inArrs(inSize);
|
||||||
|
std::vector<NDArray*> outArrs(inSize - 1);
|
||||||
|
|
||||||
|
for (int i = 0; i < inSize; ++i)
|
||||||
|
inArrs[i] = INPUT_VARIABLE(i);
|
||||||
|
|
||||||
|
for (int i = 0; i < (inSize - 1); ++i) {
|
||||||
|
outArrs[i] = OUTPUT_NULLIFIED(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
helpers::mergeMaxBp(block.launchContext(), inArrs, outArrs);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(mergemax_bp) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(sd::DataType::ANY)
|
||||||
|
->setAllowedOutputTypes(sd::DataType::ANY);
|
||||||
|
}
|
||||||
|
DECLARE_SHAPE_FN(mergemax_bp) {
|
||||||
|
|
||||||
|
const int numOfInArrs = block.width() - 1;
|
||||||
|
|
||||||
|
auto shapeList = SHAPELIST();
|
||||||
|
|
||||||
|
for (int e = 0; e < numOfInArrs; e++) {
|
||||||
|
auto inShape = inputShape->at(e);
|
||||||
|
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), shape::shapeOf(inShape), shape::rank(inShape))));
|
||||||
|
}
|
||||||
|
|
||||||
|
return shapeList;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ CUSTOM_OP_IMPL(mergemaxindex, -1, 1, false, 0, 0) {
|
||||||
REQUIRE_OK(this->validateInputDimensionsMatch(block));
|
REQUIRE_OK(this->validateInputDimensionsMatch(block));
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
std::vector<NDArray*> inArrs(block.width());
|
std::vector<const NDArray*> inArrs(block.width());
|
||||||
|
|
||||||
for(int i = 0; i < block.width(); ++i)
|
for(int i = 0; i < block.width(); ++i)
|
||||||
inArrs[i] = INPUT_VARIABLE(i);
|
inArrs[i] = INPUT_VARIABLE(i);
|
||||||
|
|
|
@ -64,6 +64,7 @@ namespace sd {
|
||||||
|
|
||||||
#if NOT_EXCLUDED(OP_mergemax)
|
#if NOT_EXCLUDED(OP_mergemax)
|
||||||
DECLARE_OP(mergemax, -1, 1, false);
|
DECLARE_OP(mergemax, -1, 1, false);
|
||||||
|
DECLARE_CUSTOM_OP(mergemax_bp, 2, 1, false, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
/*
|
/*
|
||||||
* Complete tensor with max indices merged from all input tensors list
|
* Complete tensor with max indices merged from all input tensors list
|
||||||
|
@ -78,10 +79,12 @@ namespace sd {
|
||||||
|
|
||||||
#if NOT_EXCLUDED(OP_mergeadd)
|
#if NOT_EXCLUDED(OP_mergeadd)
|
||||||
DECLARE_OP(mergeadd, -1, 1, false);
|
DECLARE_OP(mergeadd, -1, 1, false);
|
||||||
|
DECLARE_CUSTOM_OP(mergeadd_bp, 2, 1, false, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if NOT_EXCLUDED(OP_mergeavg)
|
#if NOT_EXCLUDED(OP_mergeavg)
|
||||||
DECLARE_OP(mergeavg, -1, 1, false);
|
DECLARE_OP(mergeavg, -1, 1, false);
|
||||||
|
DECLARE_CUSTOM_OP(mergeavg_bp, 2, 1, false, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if NOT_EXCLUDED(OP_scatter_update)
|
#if NOT_EXCLUDED(OP_scatter_update)
|
||||||
|
|
|
@ -0,0 +1,274 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/transforms.h>
|
||||||
|
#include <helpers/Loops.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
static void clipByNorm_(NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
|
||||||
|
|
||||||
|
const int rank = input.rankOf();
|
||||||
|
const auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions);
|
||||||
|
|
||||||
|
const T normActual = norm2.e<T>(0);
|
||||||
|
const T normClip = clipNorm.e<T>(0);
|
||||||
|
|
||||||
|
if (isInplace) {
|
||||||
|
|
||||||
|
if(norm2.lengthOf() == 1) {
|
||||||
|
|
||||||
|
if(normActual > normClip)
|
||||||
|
input *= (normClip / normActual);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
|
auto listOfInSubArrs = input.allTensorsAlongDimension(dimensions);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
const T iNormActual = norm2.e<T>(i);
|
||||||
|
if (iNormActual > normClip)
|
||||||
|
*listOfInSubArrs.at(i) *= normClip / iNormActual;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
|
if(norm2.lengthOf() == 1) {
|
||||||
|
|
||||||
|
if(normActual > normClip)
|
||||||
|
output.assign(input * (normClip / normActual));
|
||||||
|
else
|
||||||
|
output.assign(input);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
|
auto listOfInSubArrs = input.allTensorsAlongDimension(dimensions);
|
||||||
|
auto listOfOutSubArrs = output.allTensorsAlongDimension(dimensions);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
auto inputSubArr = listOfInSubArrs.at(i);
|
||||||
|
auto outputSubArr = listOfOutSubArrs.at(i);
|
||||||
|
outputSubArr->assign(inputSubArr);
|
||||||
|
|
||||||
|
const T iNormActual = norm2.e<T>(i);
|
||||||
|
|
||||||
|
if (iNormActual > clipNorm.e<T>(0))
|
||||||
|
*outputSubArr *= clipNorm / iNormActual;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
|
||||||
|
BUILD_SINGLE_SELECTOR(output.dataType(), clipByNorm_, (input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void clipByGlobalNorm_(std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
|
||||||
|
T globalNorm = 0; //NDArrayFactory::create<T>(0, inputs[0]->getContext()); //sqrt(sum([l2norm(t)**2 for t in t_list]))
|
||||||
|
// PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(sumT : globalNorm)
|
||||||
|
for (size_t i = 0; i < inputs.size(); i++) {
|
||||||
|
auto input = inputs[i];
|
||||||
|
auto l2norm = input->reduceNumber(reduce::Norm2);
|
||||||
|
globalNorm += l2norm.t<T>(0) * l2norm.t<T>(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
//globalNorm.applyTransform(transform::Sqrt, nullptr, nullptr);// = sd::math::nd4j_sqrt(globalNorm);
|
||||||
|
auto normS = sd::math::nd4j_sqrt<T,T>(globalNorm);
|
||||||
|
outputs[inputs.size()]->p(0, normS);
|
||||||
|
|
||||||
|
const T factor = clipNorm / normS;
|
||||||
|
|
||||||
|
// PRAGMA_OMP_PARALLEL_FOR
|
||||||
|
for (size_t e = 0; e < inputs.size(); e++) {
|
||||||
|
// all-reduce
|
||||||
|
auto input = inputs[e];
|
||||||
|
auto output = outputs[e];
|
||||||
|
|
||||||
|
if (normS <= clipNorm) {
|
||||||
|
output->assign(input);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
|
auto lambda = LAMBDA_T(_x, factor) { return _x * factor; };
|
||||||
|
input->applyLambda<T>(lambda, *output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void clipByGlobalNorm(sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
|
||||||
|
BUILD_SINGLE_SELECTOR(outputs[0]->dataType(), clipByGlobalNorm_, (inputs, clipNorm, workspace, outputs, isInplace), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, (std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace), FLOAT_TYPES);
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm) {
|
||||||
|
|
||||||
|
const int rank = input.rankOf();
|
||||||
|
|
||||||
|
auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions);
|
||||||
|
|
||||||
|
if(norm2.lengthOf() == 1) {
|
||||||
|
|
||||||
|
const T N = norm2.e<T>(0);
|
||||||
|
|
||||||
|
auto cn = clipNorm.e<T>(0);
|
||||||
|
|
||||||
|
if(N > cn) {
|
||||||
|
|
||||||
|
const T sumOfProd = (input * gradO).reduceNumber(reduce::Sum).e<T>(0); // reduce to scalar
|
||||||
|
const T factor1 = static_cast<T>(1.f) / N;
|
||||||
|
const T factor3 = factor1 / (N * N); // 1 / (N*N*N)
|
||||||
|
|
||||||
|
auto lambda = LAMBDA_TT(elem1, elem2, cn, sumOfProd, factor1, factor3) {
|
||||||
|
return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd);
|
||||||
|
};
|
||||||
|
|
||||||
|
(const_cast<NDArray&>(input)).applyPairwiseLambda<T>(const_cast<NDArray&>(gradO), lambda, gradI);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
gradI.assign(gradO);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
|
auto gradISubArrs = gradI.allTensorsAlongDimension({dimensions});
|
||||||
|
auto gradOSubArrs = gradO.allTensorsAlongDimension({dimensions});
|
||||||
|
auto inputSubArrs = input.allTensorsAlongDimension({dimensions});
|
||||||
|
|
||||||
|
auto cn = clipNorm.e<T>(0);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
T N = norm2.e<T>(i);
|
||||||
|
|
||||||
|
auto gradOSubArr = gradOSubArrs.at(i);
|
||||||
|
auto gradISubArr = gradISubArrs.at(i);
|
||||||
|
|
||||||
|
if (N > cn) {
|
||||||
|
auto inputSubArr = inputSubArrs.at(i);
|
||||||
|
const T sumOfProd = (*inputSubArr * *gradOSubArr).reduceNumber(reduce::Sum).e<T>(0); // reduce to scalar
|
||||||
|
const T factor1 = static_cast<T>(1.f) / N;
|
||||||
|
const T factor3 = factor1 / (N * N); // 1 / (N*N*N)
|
||||||
|
|
||||||
|
auto lambda = LAMBDA_TT(elem1, elem2, cn, sumOfProd, factor1, factor3) {
|
||||||
|
return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd);
|
||||||
|
};
|
||||||
|
|
||||||
|
inputSubArr->applyPairwiseLambda<T>(*gradOSubArr, lambda, *gradISubArr);
|
||||||
|
} else
|
||||||
|
gradISubArr->assign(gradOSubArr);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
samediff::Threads::parallel_tad(func, 0, gradISubArrs.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void clipByNormBP(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm) {
|
||||||
|
BUILD_SINGLE_SELECTOR(gradI.dataType(), clipByNormBP_, (input, gradO, gradI, dimensions, clipNorm), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void clipByNormBP_, (const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm), FLOAT_TYPES);
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
static void clipByAveraged_(NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
|
||||||
|
|
||||||
|
auto cn = clipNorm.e<T>(0);
|
||||||
|
if (dimensions.size() == 0) {
|
||||||
|
// all-reduce
|
||||||
|
T n2 = input.reduceNumber(reduce::Norm2).e<T>(0) / input.lengthOf();
|
||||||
|
if (n2 <= cn) {
|
||||||
|
if (!isInplace)
|
||||||
|
output.assign(input);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
const T factor = cn / n2;
|
||||||
|
auto lambda = LAMBDA_T(_x, factor) { return _x * factor; };
|
||||||
|
input.applyLambda<T>(lambda, output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// along dimension
|
||||||
|
auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions, false);
|
||||||
|
if (!isInplace)
|
||||||
|
output.assign(input);
|
||||||
|
auto tads = output.allTensorsAlongDimension(dimensions);
|
||||||
|
// TODO: make this CUDA-compliant somehow
|
||||||
|
for (int e = 0; e < tads.size(); e++) {
|
||||||
|
T n2 = norm2.e<T>(e) / tads.at(e)->lengthOf();
|
||||||
|
const T factor = cn / n2;
|
||||||
|
if (n2 > cn) {
|
||||||
|
auto lambda = LAMBDA_T(_x, factor) {return _x * factor;};
|
||||||
|
tads.at(e)->applyLambda<T>(lambda, output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void clipByAveraged(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
|
||||||
|
BUILD_SINGLE_SELECTOR(input.dataType(), clipByAveraged_, (input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void clipByAveraged_, (NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES);
|
||||||
|
|
||||||
|
/*
|
||||||
|
if (d1 > params[1])
|
||||||
|
return params[1];
|
||||||
|
else if (d1 < params[0])
|
||||||
|
return params[0];
|
||||||
|
else return d1;
|
||||||
|
*/
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void clipByValue_(NDArray& input, double leftBound, double rightBound, NDArray& output) {
|
||||||
|
auto routine = LAMBDA_T(_x, leftBound, rightBound) {
|
||||||
|
if (_x > rightBound) return rightBound;
|
||||||
|
if (_x < leftBound) return leftBound;
|
||||||
|
return _x;
|
||||||
|
};
|
||||||
|
|
||||||
|
input.applyLambda<T>(routine, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(input.dataType(), clipByValue_, (input, leftBound, rightBound, output), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void clipByValue_, (NDArray& input, double leftBound, double rightBound, NDArray& output);, FLOAT_TYPES);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,45 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/transforms.h>
|
||||||
|
#include <helpers/Loops.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void eye(sd::LaunchContext * context, NDArray& output) {
|
||||||
|
|
||||||
|
const int rank = output.rankOf();
|
||||||
|
auto arrs = output.allTensorsAlongDimension({rank-2, rank-1});
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto i = start; i < stop; i++)
|
||||||
|
arrs.at(i)->setIdentity();
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, arrs.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,183 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/transforms.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
#include <numeric>
|
||||||
|
#include <helpers/Loops.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename X, typename Y>
|
||||||
|
static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) {
|
||||||
|
|
||||||
|
const X* x = reinterpret_cast<X*>(input.getBuffer());
|
||||||
|
const Y* y = reinterpret_cast<Y*>(indices.getBuffer());
|
||||||
|
X* z = reinterpret_cast<X*>(output.getBuffer());
|
||||||
|
|
||||||
|
const int xRank = input.rankOf();
|
||||||
|
const int yRank = indices.rankOf();
|
||||||
|
const int zRank = output.rankOf();
|
||||||
|
const int maxRank = sd::math::nd4j_max<int>(yRank, sd::math::nd4j_max<int>(xRank, zRank));
|
||||||
|
|
||||||
|
const Nd4jLong zLen = output.lengthOf();
|
||||||
|
|
||||||
|
const uint yLastDim = indices.sizeAt(-1);
|
||||||
|
|
||||||
|
const int diff = zRank - xRank;
|
||||||
|
const bool bEqual = yLastDim == xRank;
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
|
||||||
|
int xCoords[MAX_RANK], zCoords[MAX_RANK], temp;
|
||||||
|
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
|
||||||
|
shape::index2coordsCPU(start, i, output.getShapeInfo(), zCoords);
|
||||||
|
|
||||||
|
const auto zOffset = shape::getOffset(output.getShapeInfo(), zCoords);
|
||||||
|
|
||||||
|
temp = zCoords[yRank - 1];
|
||||||
|
zCoords[yRank - 1] = 0;
|
||||||
|
const auto yOffset = shape::getOffset(indices.getShapeInfo(), zCoords);
|
||||||
|
zCoords[yRank - 1] = temp;
|
||||||
|
|
||||||
|
if(bEqual)
|
||||||
|
memcpy(xCoords, zCoords, zRank * sizeof(int));
|
||||||
|
else if(diff >= 0)
|
||||||
|
memcpy(xCoords, zCoords + diff, xRank * sizeof(int));
|
||||||
|
else
|
||||||
|
memcpy(xCoords - diff, zCoords, zRank * sizeof(int));
|
||||||
|
|
||||||
|
for (uint j = 0; j < yLastDim; ++j)
|
||||||
|
xCoords[j] = y[yOffset + j * indices.stridesOf()[yRank - 1]]; // last stride
|
||||||
|
|
||||||
|
const auto xOffset = shape::getOffset(input.getShapeInfo(), xCoords);
|
||||||
|
|
||||||
|
z[zOffset] = x[xOffset];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, zLen);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
void gatherND(sd::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(input.dataType(), indices.dataType(), gatherND_, (input, indices, output), LIBND4J_TYPES, INDEXING_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
static void gather_(NDArray* input, const NDArray* indices, NDArray* output, const std::vector<int>& intArgs) {
|
||||||
|
|
||||||
|
int axis = intArgs.size() > 0 ? intArgs[0] : 0;
|
||||||
|
const int inputRank = input->rankOf();
|
||||||
|
if(axis < 0)
|
||||||
|
axis += inputRank;
|
||||||
|
|
||||||
|
const int numOfIntArgs = intArgs.size();
|
||||||
|
|
||||||
|
if (indices != nullptr) {
|
||||||
|
|
||||||
|
for(Nd4jLong i = 0; i < indices->lengthOf(); ++i)
|
||||||
|
if(indices->e<Nd4jLong>(i) >= input->sizeAt(axis))
|
||||||
|
throw std::runtime_error("helpers::gather function: indices array contains wrong elements, each element must be smaller than corresponding dimension of input array !");
|
||||||
|
|
||||||
|
// first case: indices consist of only one scalar
|
||||||
|
if(indices->isScalar()) {
|
||||||
|
if(input->rankOf() <= 1){
|
||||||
|
//For scalar indices, rank 0 or 1 input: can't do tensor along dimension 0 as this is whole array... instead, we want to get a scalar
|
||||||
|
auto idx = indices->e<Nd4jLong>(0);
|
||||||
|
auto scalarNDArray = input->e(idx);
|
||||||
|
output->assign(scalarNDArray);
|
||||||
|
} else {
|
||||||
|
auto dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis});
|
||||||
|
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
|
||||||
|
auto tadArr = NDArray(reinterpret_cast<void *>(reinterpret_cast<T*>(input->getBuffer()) + tadPack.primaryOffsets()[indices->e<Nd4jLong>(0)]), tadPack.primaryShapeInfo(), output->getContext());
|
||||||
|
output->assign(&tadArr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (input->rankOf() == 1 && indices->isVector()) {
|
||||||
|
// special case
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto e = start; e < stop; e++)
|
||||||
|
output->p(e, input->e<T>(indices->e<Nd4jLong>(e)));
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, indices->lengthOf());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
|
std::vector<int> dimsOut(indices->rankOf());
|
||||||
|
std::iota(dimsOut.begin(), dimsOut.end(), axis); // fill with axis, axis+1, ... indices->rankOf()-1
|
||||||
|
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(output->getShapeInfo(), dimsOut);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
NDArray subArrOut = (*output)(i, dimsOut);
|
||||||
|
NDArray subArrIn = (*input)(indices->e<Nd4jLong>(i), {axis});
|
||||||
|
subArrOut.assign(subArrIn);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
|
for(int i = 1; i < numOfIntArgs; ++i)
|
||||||
|
if(intArgs[i] >= input->sizeAt(axis))
|
||||||
|
throw std::runtime_error("helpers::gather function: some of input indexes is larger than corresponding shape of input array !");
|
||||||
|
|
||||||
|
// we only allow scalar/vector case here
|
||||||
|
if (numOfIntArgs == 2) { // scalar case
|
||||||
|
output->assign((*input)(intArgs[1], {axis}));
|
||||||
|
}
|
||||||
|
else { // vector case
|
||||||
|
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(output->getShapeInfo(), {axis});
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
NDArray subArrOut = (*output)(i, {axis});
|
||||||
|
NDArray subArrIn = (*input)(intArgs[i + 1], {axis});
|
||||||
|
subArrOut.assign(subArrIn);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void gather(NDArray* input, const NDArray* indices, NDArray* output, const std::vector<int>& intArgs) {
|
||||||
|
BUILD_SINGLE_SELECTOR(input->dataType(), gather_, (input, indices, output, intArgs), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/transforms.h>
|
||||||
|
#include <helpers/Loops.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
void invertPermutation(sd::LaunchContext * context, const NDArray& input, NDArray& output) {
|
||||||
|
|
||||||
|
std::set<int> uniqueElems;
|
||||||
|
const int length = input.lengthOf();
|
||||||
|
|
||||||
|
for(int i = 0; i < length; ++i) {
|
||||||
|
|
||||||
|
int elem = input.e<int>(i);
|
||||||
|
|
||||||
|
if(!uniqueElems.insert(elem).second) // this operation forbids us to use #pragma omp
|
||||||
|
throw std::runtime_error("helpers::invertPermutation function: input array contains duplicates !");
|
||||||
|
|
||||||
|
if(elem < 0 || elem > length - 1)
|
||||||
|
throw std::runtime_error("helpers::invertPermutation function: element of input array is out of range (0, length-1) !");
|
||||||
|
|
||||||
|
output.p<int>(elem, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,277 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/transforms.h>
|
||||||
|
#include <helpers/Loops.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
static void mergeMaxIndex_(const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
|
|
||||||
|
const Nd4jLong numArgs = inArrs.size();
|
||||||
|
auto x = inArrs[0];
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto e = start; e < stop; e++) {
|
||||||
|
T max = -DataTypeUtils::max<T>();
|
||||||
|
Nd4jLong idx = 0;
|
||||||
|
|
||||||
|
for (Nd4jLong i = 0; i < numArgs; i++) {
|
||||||
|
T v = inArrs[i]->e<T>(e);
|
||||||
|
if (v > max) {
|
||||||
|
max = v;
|
||||||
|
idx = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output.p(e, idx);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, x->lengthOf());
|
||||||
|
}
|
||||||
|
|
||||||
|
void mergeMaxIndex(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), mergeMaxIndex_, (inArrs, output), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
static void mergeMax_(const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
|
|
||||||
|
const Nd4jLong numArgs = inArrs.size();
|
||||||
|
auto x = inArrs[0];
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto e = start; e < stop; e++) {
|
||||||
|
T max = -DataTypeUtils::max<T>();
|
||||||
|
for (Nd4jLong i = 0; i < numArgs; i++) {
|
||||||
|
T v = inArrs[i]->e<T>(e);
|
||||||
|
if (v > max)
|
||||||
|
max = v;
|
||||||
|
}
|
||||||
|
output.p(e, max);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, x->lengthOf());
|
||||||
|
}
|
||||||
|
|
||||||
|
void mergeMax(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (inArrs, output), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
static void mergeMaxBp_(const std::vector<const NDArray*>& inArrs, std::vector<NDArray*>& outArrs) {
|
||||||
|
|
||||||
|
// outArrs.size() == inArrs.size() - 1
|
||||||
|
const Nd4jLong numArgs = outArrs.size();
|
||||||
|
// last array is gradient
|
||||||
|
const auto gradient = inArrs[numArgs]->bufferAsT<T>();
|
||||||
|
auto length = inArrs[numArgs]->lengthOf();
|
||||||
|
|
||||||
|
bool bSameOrderAndEws1 = (1 == inArrs[numArgs]->ews());
|
||||||
|
|
||||||
|
if (bSameOrderAndEws1) {
|
||||||
|
auto gradOrdering = inArrs[numArgs]->ordering();
|
||||||
|
|
||||||
|
for (int i = 0; i < numArgs; ++i) {
|
||||||
|
bSameOrderAndEws1 &= (gradOrdering == inArrs[i]->ordering());
|
||||||
|
bSameOrderAndEws1 &= (1 == inArrs[i]->ews());
|
||||||
|
bSameOrderAndEws1 &= (gradOrdering == outArrs[i]->ordering());
|
||||||
|
bSameOrderAndEws1 &= (1 == outArrs[i]->ews());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if(bSameOrderAndEws1){
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
for (auto e = start; e < stop; e++) {
|
||||||
|
T max = -DataTypeUtils::max<T>();
|
||||||
|
Nd4jLong nMaxIndex = 0;
|
||||||
|
for (Nd4jLong i = 0; i < numArgs; i++) {
|
||||||
|
const T* v = inArrs[i]->bufferAsT<T>();
|
||||||
|
if (v[e] > max) {
|
||||||
|
max = v[e];
|
||||||
|
nMaxIndex = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
T* z = outArrs[nMaxIndex]->bufferAsT<T>();
|
||||||
|
z[e] = gradient[e];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, length);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto gradShape = inArrs[numArgs]->getShapeInfo();
|
||||||
|
std::vector<bool> vbSameShaepeAndStrides(numArgs);
|
||||||
|
for (int i = 0; i < numArgs; ++i) {
|
||||||
|
vbSameShaepeAndStrides[i] = shape::haveSameShapeAndStrides(gradShape, inArrs[i]->getShapeInfo());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
for (auto e = start; e < stop; e++) {
|
||||||
|
|
||||||
|
shape::index2coordsCPU(start, e, gradShape, coords);
|
||||||
|
|
||||||
|
const auto gradOffset = shape::getOffset(gradShape, coords);
|
||||||
|
|
||||||
|
T max = -DataTypeUtils::max<T>();
|
||||||
|
Nd4jLong nMaxIndex = 0;
|
||||||
|
|
||||||
|
for (Nd4jLong i = 0; i < numArgs; i++) {
|
||||||
|
|
||||||
|
const auto xOffset = vbSameShaepeAndStrides[i] ? gradOffset : shape::getOffset(inArrs[i]->getShapeInfo(), coords);
|
||||||
|
const T* v = inArrs[i]->bufferAsT<T>();
|
||||||
|
if (v[xOffset] > max) {
|
||||||
|
max = v[xOffset];
|
||||||
|
nMaxIndex = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto zOffset = vbSameShaepeAndStrides[nMaxIndex] ? gradOffset : shape::getOffset(outArrs[nMaxIndex]->getShapeInfo(), coords);
|
||||||
|
|
||||||
|
T* z = outArrs[nMaxIndex]->bufferAsT<T>();
|
||||||
|
z[zOffset] = gradient[gradOffset];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, length);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void mergeMaxBp(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, std::vector<NDArray*>& outArrs) {
|
||||||
|
BUILD_SINGLE_SELECTOR(outArrs[0]->dataType(), mergeMaxBp_, (inArrs, outArrs), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
static void mergeAvg_(const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
|
const Nd4jLong numArgs = inArrs.size();
|
||||||
|
const T factor = 1.f / numArgs;
|
||||||
|
auto x = inArrs[0];
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto e = start; e < stop; e++) {
|
||||||
|
T sum = 0.;
|
||||||
|
for (Nd4jLong i = 0; i < numArgs; i++) {
|
||||||
|
T v = inArrs[i]->e<T>(e);
|
||||||
|
sum += v;
|
||||||
|
}
|
||||||
|
output.p<T>(e, sum * factor);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, x->lengthOf());
|
||||||
|
}
|
||||||
|
|
||||||
|
void mergeAvg(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (inArrs, output), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
static void mergeAvgBp_(const NDArray& gradient, std::vector<NDArray*>& outArrs) {
|
||||||
|
|
||||||
|
const Nd4jLong numArgs = outArrs.size();
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
for (auto e = start; e < stop; e++) {
|
||||||
|
|
||||||
|
T v = gradient.e<T>(e) / numArgs;
|
||||||
|
|
||||||
|
for (Nd4jLong i = 0; i < numArgs; i++) {
|
||||||
|
outArrs[i]->p<T>(e, v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf());
|
||||||
|
}
|
||||||
|
|
||||||
|
void mergeAvgBp(sd::LaunchContext* context, const NDArray& gradient, std::vector<NDArray*>& outArrs) {
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAvgBp_, (gradient, outArrs), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
static void mergeAdd_(const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
|
|
||||||
|
const Nd4jLong numArgs = inArrs.size();
|
||||||
|
auto x = inArrs[0];
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto e = start; e < stop; e++) {
|
||||||
|
T sum = (T) 0.f;
|
||||||
|
for (Nd4jLong i = 0; i < numArgs; i++)
|
||||||
|
sum += inArrs[i]->e<T>(e);
|
||||||
|
|
||||||
|
output.p(e, sum);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, x->lengthOf());
|
||||||
|
}
|
||||||
|
void mergeAdd(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (inArrs, output), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
static void mergeAddBp_(const NDArray& gradient, std::vector<NDArray*>& outArrs) {
|
||||||
|
|
||||||
|
const Nd4jLong numArgs = outArrs.size();
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
for (auto e = start; e < stop; e++) {
|
||||||
|
|
||||||
|
T v = gradient.e<T>(e);
|
||||||
|
|
||||||
|
for (Nd4jLong i = 0; i < numArgs; i++) {
|
||||||
|
outArrs[i]->p<T>(e, v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf());
|
||||||
|
}
|
||||||
|
|
||||||
|
void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, std::vector<NDArray*>& outArrs) {
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAddBp_, (gradient, outArrs), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,483 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/transforms.h>
|
||||||
|
#include <helpers/Loops.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, const NDArray& padValue) {
|
||||||
|
|
||||||
|
const T* x = input.bufferAsT<T>();
|
||||||
|
T* z = output.bufferAsT<T>();
|
||||||
|
|
||||||
|
const Nd4jLong* xShape = input.shapeOf();
|
||||||
|
const Nd4jLong* zShape = output.shapeOf();
|
||||||
|
|
||||||
|
const int rank = input.rankOf(); // both input and output have the same rank
|
||||||
|
const int rankMinusOne = rank - 1;
|
||||||
|
|
||||||
|
const auto zLen = output.lengthOf();
|
||||||
|
|
||||||
|
if(mode == 0) { // CONSTANT case
|
||||||
|
|
||||||
|
const T padVal = padValue.e<T>(0);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
|
||||||
|
int zCoords[MAX_RANK], xCoords[MAX_RANK];
|
||||||
|
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
|
||||||
|
shape::index2coordsCPU(start, i, output.getShapeInfo(), zCoords);
|
||||||
|
const auto zOffset = shape::getOffset(output.getShapeInfo(), zCoords);
|
||||||
|
|
||||||
|
memcpy(xCoords, zCoords, rank * sizeof(int));
|
||||||
|
|
||||||
|
bool within = true;
|
||||||
|
|
||||||
|
for (int j = rankMinusOne; j >= 0; --j) {
|
||||||
|
|
||||||
|
if (xShape[j] == zShape[j])
|
||||||
|
continue;
|
||||||
|
|
||||||
|
const auto left = paddings.e<Nd4jLong>(j, 0);
|
||||||
|
|
||||||
|
if (zCoords[j] < left || zCoords[j] >= left + xShape[j]) {
|
||||||
|
within = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
xCoords[j] = zCoords[j] - left;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (within)
|
||||||
|
z[zOffset] = x[shape::getOffset(input.getShapeInfo(), xCoords)];
|
||||||
|
else
|
||||||
|
z[zOffset] = padVal;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, zLen);
|
||||||
|
}
|
||||||
|
else { // REFLECT and SYMMETRIC cases
|
||||||
|
|
||||||
|
const Nd4jLong shift1 = mode == 1 ? 0 : 1; // REFLECT : SYMMETRIC
|
||||||
|
const Nd4jLong shift2 = mode == 1 ? 2 : 1; // REFLECT : SYMMETRIC
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
|
||||||
|
int zCoords[MAX_RANK], xCoords[MAX_RANK];
|
||||||
|
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
|
||||||
|
shape::index2coordsCPU(start, i, output.getShapeInfo(), zCoords);
|
||||||
|
const auto zOffset = shape::getOffset(output.getShapeInfo(), zCoords);
|
||||||
|
|
||||||
|
memcpy(xCoords, zCoords, rank * sizeof(int));
|
||||||
|
|
||||||
|
for (int j = rankMinusOne; j >= 0; --j) {
|
||||||
|
|
||||||
|
if (xShape[j] == zShape[j])
|
||||||
|
continue;
|
||||||
|
|
||||||
|
xCoords[j] = zCoords[j] - paddings.e<Nd4jLong>(j, 0); // are ready to fill middle (within input dimension range)
|
||||||
|
|
||||||
|
if (xCoords[j] < 0)
|
||||||
|
xCoords[j] = -xCoords[j] - shift1; // means fill from left
|
||||||
|
else if (xCoords[j] >= xShape[j])
|
||||||
|
xCoords[j] = 2 * xShape[j] - xCoords[j] - shift2; // means fill from right
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto xOffset = shape::getOffset(input.getShapeInfo(), xCoords);
|
||||||
|
z[zOffset] = x[xOffset];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, zLen);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// //////////////////////////////////////////////////////////////////////////
|
||||||
|
// template<typename T>
|
||||||
|
// void pad2_(const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, NDArray const& padValue) {
|
||||||
|
|
||||||
|
// const int rank = output.rankOf();
|
||||||
|
// std::vector<int> dimsToExclude(rank);
|
||||||
|
// std::iota(dimsToExclude.begin(), dimsToExclude.end(), 0); // fill with 0, 1, ... rank-1
|
||||||
|
|
||||||
|
// Nd4jLong numLeft = paddings.e<Nd4jLong>(rank-1,0);
|
||||||
|
// Nd4jLong numRight = paddings.e<Nd4jLong>(rank-1,1);
|
||||||
|
// Nd4jLong inDimSize = input.sizeAt(rank-1);
|
||||||
|
// Nd4jLong outDimSize = output.sizeAt(rank-1);
|
||||||
|
|
||||||
|
// std::vector<std::vector<Nd4jLong>> outIdx = { std::vector<Nd4jLong>(2*rank), {numLeft, numLeft + inDimSize}, {0, numLeft}, {numLeft + inDimSize, outDimSize} };
|
||||||
|
|
||||||
|
// for(int i = 0; i < rank-1; ++i) {
|
||||||
|
// outIdx[0][2*i] = paddings.e<Nd4jLong>(i, 0);
|
||||||
|
// outIdx[0][2*i + 1] = outIdx[0][2*i] + input.sizeAt(i);
|
||||||
|
// }
|
||||||
|
// outIdx[0][2*rank-1] = outIdx[0][2*rank-2] = 0;
|
||||||
|
|
||||||
|
// // ***** populate innermost sub-arrays firstly ***** //
|
||||||
|
// dimsToExclude.pop_back();
|
||||||
|
|
||||||
|
// Nd4jLong startL = mode == 1 ? 1 : 0; // REFLECT or SYMMETRIC
|
||||||
|
// Nd4jLong startR = mode == 1 ? inDimSize-2 : inDimSize-1; // REFLECT or SYMMETRIC
|
||||||
|
|
||||||
|
// Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input.getShapeInfo(), dimsToExclude);
|
||||||
|
|
||||||
|
// NDArray outSubArr0 = output(outIdx[0], true);
|
||||||
|
|
||||||
|
// PRAGMA_OMP_PARALLEL_FOR
|
||||||
|
// for(Nd4jLong j = 0; j < numOfSubArrs; ++j) {
|
||||||
|
|
||||||
|
// NDArray outSubArr1 = outSubArr0(j, dimsToExclude);
|
||||||
|
// NDArray inSubArr = input(j, dimsToExclude);
|
||||||
|
// NDArray outSubArrMid = outSubArr1(outIdx[1]);
|
||||||
|
|
||||||
|
// outSubArrMid.assign(inSubArr); // assign middle
|
||||||
|
|
||||||
|
// if(mode == 0) { // CONSTANT
|
||||||
|
// if(numLeft != 0) {
|
||||||
|
// NDArray temp = outSubArr1(outIdx[2]);
|
||||||
|
// temp.assign(padValue); // assign left
|
||||||
|
// }
|
||||||
|
// if(numRight != 0) {
|
||||||
|
// NDArray temp = outSubArr1(outIdx[3]);
|
||||||
|
// temp.assign(padValue); // assign right
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// else { // REFLECT or SYMMETRIC
|
||||||
|
|
||||||
|
// for(Nd4jLong k = numLeft-1, e = startL; k >= 0; --k, ++e) // fill left side
|
||||||
|
// outSubArr1.t<T>(k) = inSubArr.t<T>(e);
|
||||||
|
|
||||||
|
// for(Nd4jLong k = numLeft + inDimSize, e = startR; k < outDimSize; ++k, --e) // fill right side
|
||||||
|
// outSubArr1.t<T>(k) = inSubArr.t<T>(e);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // ***** fill rest of outer sub-arrays ***** //
|
||||||
|
// std::vector<Nd4jLong> outIdxInner(2, 0);
|
||||||
|
// std::vector<Nd4jLong> outIdxOuter(2, 0);
|
||||||
|
|
||||||
|
// for(int i = rankBorder - 1; i >= 0; --i) {
|
||||||
|
|
||||||
|
// dimsToExclude.pop_back();
|
||||||
|
|
||||||
|
// outIdxInner.push_back(0), outIdxInner.push_back(0);
|
||||||
|
// outIdxOuter.push_back(0), outIdxOuter.push_back(0);
|
||||||
|
|
||||||
|
// Nd4jLong numLeft = paddings.e<Nd4jLong>(i, 0);
|
||||||
|
// Nd4jLong numRight = paddings.e<Nd4jLong>(i, 1);
|
||||||
|
|
||||||
|
// if(numLeft == 0 && numRight == 0)
|
||||||
|
// continue;
|
||||||
|
|
||||||
|
// Nd4jLong inDimSize = input.sizeAt(i);
|
||||||
|
// Nd4jLong outDimSize = output.sizeAt(i);
|
||||||
|
|
||||||
|
// if(mode == 0) {
|
||||||
|
// outIdxOuter[0] = 0; outIdxOuter[1] = numLeft;
|
||||||
|
// outIdxInner[0] = numLeft + inDimSize; outIdxInner[1] = outDimSize;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// startL = mode == 1 ? numLeft + 1 : numLeft; // REFLECT or SYMMETRIC
|
||||||
|
// startR = mode == 1 ? numLeft + inDimSize - 2 : numLeft + inDimSize-1; // REFLECT or SYMMETRIC
|
||||||
|
|
||||||
|
// numOfSubArrs = ShapeUtils::getNumOfSubArrs(output.getShapeInfo(), dimsToExclude);
|
||||||
|
|
||||||
|
// PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(outIdxOuter, outIdxInner))
|
||||||
|
// for(Nd4jLong j = 0; j < numOfSubArrs; ++j) {
|
||||||
|
|
||||||
|
// NDArray outSubArr = output(j, dimsToExclude);
|
||||||
|
|
||||||
|
// if(mode == 0) { // CONSTANT
|
||||||
|
|
||||||
|
// if(numLeft != 0) {
|
||||||
|
// NDArray tempO = outSubArr(outIdxOuter);
|
||||||
|
// tempO.assign(padValue); // assign left
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if(numRight != 0) {
|
||||||
|
// NDArray tempI = outSubArr(outIdxInner);
|
||||||
|
// tempI.assign(padValue); // assign right
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// else { // REFLECT or SYMMETRIC
|
||||||
|
|
||||||
|
// for(Nd4jLong k = numLeft-1, e = startL; k >= 0; --k, ++e) { // fill left side
|
||||||
|
// outIdxOuter[0] = k;
|
||||||
|
// outIdxOuter[1] = k+1;
|
||||||
|
// outIdxInner[0] = e;
|
||||||
|
// outIdxInner[1] = e+1;
|
||||||
|
// NDArray outSubArrInner = outSubArr(outIdxInner);
|
||||||
|
// NDArray outSubArrOuter = outSubArr(outIdxOuter);
|
||||||
|
// outSubArrOuter.assign(outSubArrInner);
|
||||||
|
// }
|
||||||
|
|
||||||
|
// for(Nd4jLong k = numLeft + inDimSize, e = startR; k < outDimSize; ++k, --e) { // fill right side
|
||||||
|
// outIdxOuter[0] = k;
|
||||||
|
// outIdxOuter[1] = k+1;
|
||||||
|
// outIdxInner[0] = e;
|
||||||
|
// outIdxInner[1] = e+1;
|
||||||
|
// NDArray outSubArrInner = outSubArr(outIdxInner);
|
||||||
|
// NDArray outSubArrOuter = outSubArr(outIdxOuter);
|
||||||
|
// outSubArrOuter.assign(outSubArrInner);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
void pad(sd::LaunchContext * context, const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, NDArray const& padValue) {
|
||||||
|
BUILD_SINGLE_SELECTOR(input.dataType(), pad_, (mode, input, paddings, output, padValue), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
static void mirrorPad_(const NDArray& input, const NDArray& paddings, NDArray& output, const int mode) {
|
||||||
|
|
||||||
|
// mode: 0 - REFLECT, else - SYMMETRIC
|
||||||
|
const int reflBorder = (bool)mode ? 1 : 0;
|
||||||
|
const int rank = input.rankOf();
|
||||||
|
const Nd4jLong outLen = output.lengthOf();
|
||||||
|
|
||||||
|
if(rank <= 1) {
|
||||||
|
|
||||||
|
const Nd4jLong inLen = input.lengthOf();
|
||||||
|
const auto leftSide = paddings.e<Nd4jLong>(0);
|
||||||
|
const auto leftSideCorrected = leftSide - reflBorder;
|
||||||
|
const Nd4jLong len = 2*(inLen-1) + leftSide + reflBorder;
|
||||||
|
|
||||||
|
for(int i = 0; i < outLen; ++i) {
|
||||||
|
|
||||||
|
if (i < leftSide) // left side
|
||||||
|
output.p(i, input.e<T>(leftSideCorrected - i));
|
||||||
|
|
||||||
|
else if(i >= leftSide && i < leftSide + inLen) // middle
|
||||||
|
output.p(i, input.e<T>(i - leftSide));
|
||||||
|
|
||||||
|
else // right side
|
||||||
|
output.p(i, input.e<T>(len - i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
|
||||||
|
int inIdx[MAX_RANK], outIdx[MAX_RANK];
|
||||||
|
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
|
||||||
|
shape::index2coordsCPU(start, i, output.getShapeInfo(), outIdx);
|
||||||
|
|
||||||
|
for (int j = 0; j < rank; ++j) {
|
||||||
|
const Nd4jLong inLen = input.sizeAt(j);
|
||||||
|
const auto leftSide = paddings.e<T>(j, 0);
|
||||||
|
const auto leftSideCorrected = leftSide - reflBorder;
|
||||||
|
const Nd4jLong len = 2 * (inLen - 1) + leftSide + reflBorder;
|
||||||
|
|
||||||
|
if (outIdx[j] < leftSide) // left side
|
||||||
|
inIdx[j] = leftSideCorrected - outIdx[j];
|
||||||
|
|
||||||
|
else if (outIdx[j] >= leftSide && outIdx[j] < leftSide + inLen) // middle
|
||||||
|
inIdx[j] = outIdx[j] - leftSide;
|
||||||
|
|
||||||
|
else // right side
|
||||||
|
inIdx[j] = len - outIdx[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
auto outOffset = shape::getOffset(output.getShapeInfo(), outIdx);
|
||||||
|
auto inOffset = shape::getOffset(input.getShapeInfo(), inIdx);
|
||||||
|
reinterpret_cast<T *>(output.buffer())[outOffset] = reinterpret_cast<T *>(input.getBuffer())[inOffset];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, outLen);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void mirrorPad(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode) {
|
||||||
|
BUILD_SINGLE_SELECTOR(input.dataType(), mirrorPad_, (input, paddings, output, mode), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void mirrorPad_, (const NDArray& input, const NDArray& paddings, NDArray& output, const int mode), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
/*// initial values of inIdx, outIdx, dim must be equal to zero
|
||||||
|
template<typename T>
|
||||||
|
static void recursiveLoopForPad_(const int mode, NDArray& input, const NDArray& paddings, NDArray& output, std::vector<int> dimensions, int dim, int inIdx, int outIdx, NDArray& padValue ) {
|
||||||
|
|
||||||
|
int leftOffset;
|
||||||
|
// dimensions are array of input dimensions, it is sorted in increasing order
|
||||||
|
// every time at the beginning we erase first element from it (not good idea to use vector for this purpose, but luckily it is small enough)
|
||||||
|
// then we use this array for tads building, every time while recursion the number of built tads becomes bigger
|
||||||
|
dimensions.erase(dimensions.begin());
|
||||||
|
// build tad basing on output array, also create auxiliary arrays pointing on required output array ranges
|
||||||
|
shape::TAD tadOut(output.getShapeInfo(), dimensions.data(), dimensions.size());
|
||||||
|
tadOut.createTadOnlyShapeInfo();
|
||||||
|
tadOut.createOffsets();
|
||||||
|
auto subArrOut = NDArray(output.getBuffer(), tadOut.tadOnlyShapeInfo, output.getContext());
|
||||||
|
auto subArr = NDArray(output.getBuffer(), tadOut.tadOnlyShapeInfo, output.getContext());
|
||||||
|
// build tad basing on input array, also create auxiliary array pointing on required input array range
|
||||||
|
shape::TAD tadIn(input.getShapeInfo(), dimensions.data(), dimensions.size());
|
||||||
|
tadIn.createTadOnlyShapeInfo();
|
||||||
|
tadIn.createOffsets();
|
||||||
|
auto subArrIn = NDArray(input.getBuffer(), tadIn.tadOnlyShapeInfo, output.getContext());
|
||||||
|
// these indices take into account recursion and always point to actual tads numbers
|
||||||
|
if (input.rankOf() > 1 && output.rankOf() > 1) {// only for non-vector cases
|
||||||
|
outIdx = outIdx * output.sizeAt(dim + 1);
|
||||||
|
inIdx = inIdx * input.sizeAt(dim + 1);
|
||||||
|
}
|
||||||
|
// current input tad number, we add to it unity in a loop
|
||||||
|
int k = -1;
|
||||||
|
// loop through current dimension
|
||||||
|
for(int i = 0; i < output.sizeAt(dim); ++i) {
|
||||||
|
// corresponds to outer range (relevant indices are absent in input)
|
||||||
|
leftOffset = paddings.e<int>(dim, 0);
|
||||||
|
if(i < leftOffset || i >= (input.sizeAt(dim) + leftOffset))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// increase input tads number
|
||||||
|
++k;
|
||||||
|
// recursion condition allows for the fact that tad can't reduce to scalar
|
||||||
|
if(dim < input.rankOf() - 2)
|
||||||
|
recursiveLoopForPad(mode, input, paddings, output, dimensions, dim + 1, inIdx + k, outIdx + i, padValue);
|
||||||
|
else if (paddings.sizeAt(0) > dim + 1){
|
||||||
|
leftOffset = paddings.e<int>(dim + 1, 0);
|
||||||
|
// shift buffers pointers to actual element position
|
||||||
|
if (output.rankOf() > 1) {
|
||||||
|
subArrOut.setBuffer(reinterpret_cast<T*>(output.getBuffer()) + tadOut.tadOffsets[outIdx + i]);
|
||||||
|
subArrIn.setBuffer(reinterpret_cast<T*>(input.getBuffer()) + tadIn.tadOffsets[inIdx + i - paddings.e<int>(dim, 0)]);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
subArrOut.p(i, subArrIn.e<T>(i - leftOffset));
|
||||||
|
}
|
||||||
|
// most inner loop, corresponds to last dim = rank-1
|
||||||
|
switch (mode) {
|
||||||
|
case 0: // CONSTANT mode
|
||||||
|
for(int j = 0; j < subArrOut.lengthOf(); ++j)
|
||||||
|
if(j < leftOffset || j >= (subArrIn.lengthOf() + leftOffset) ) // firstly fill with zeros outer ranges
|
||||||
|
subArrOut.p(j, (T)0.f);
|
||||||
|
else
|
||||||
|
subArrOut.p(j, subArrIn.e<T>(j - leftOffset)); // fill middle with elements of input array
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 1: // REFLECT mode
|
||||||
|
for(int j = 1; j <= leftOffset; ++j) // fill firstly left side
|
||||||
|
subArrOut.p(leftOffset - j, subArrIn.e<T>(j));
|
||||||
|
for(int j = 0; j < subArrIn.lengthOf(); ++j) // fill middle
|
||||||
|
subArrOut.p(leftOffset + j, subArrIn.e<T>(j));
|
||||||
|
for(int j = (subArrOut.lengthOf() - leftOffset); j < subArrOut.lengthOf(); ++j) // fill right side
|
||||||
|
subArrOut.p(j, subArrIn.e<T>(subArrOut.lengthOf() - j - 1));
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 2: // SYMMETRIC mode
|
||||||
|
for(int j = 1; j <= leftOffset; ++j) // fill firstly left side
|
||||||
|
subArrOut.p(leftOffset - j, subArrIn.e<T>(j-1));
|
||||||
|
for(int j = 0; j < subArrIn.lengthOf(); ++j) // fill middle
|
||||||
|
subArrOut.p(leftOffset + j, subArrIn.e<T>(j));
|
||||||
|
for(int j = (subArrOut.lengthOf() - leftOffset); j < subArrOut.lengthOf(); ++j) // fill right side
|
||||||
|
subArrOut.p(j, subArrIn.e<T>(subArrOut.lengthOf() - j));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
|
if (mode == 0 && input.rankOf() < 2)
|
||||||
|
subArrOut.p(i, subArrIn.e<T>(i - leftOffset)); // fill middle with elements of input array
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// populate sub-array formed previously
|
||||||
|
leftOffset = paddings.e<int>(dim,0);
|
||||||
|
switch (mode) {
|
||||||
|
case 0: // CONSTANT mode
|
||||||
|
for(int j = 1; j <= leftOffset; ++j) {
|
||||||
|
// fill left side with padValue
|
||||||
|
if (output.rankOf() > 1) {
|
||||||
|
subArrOut.setBuffer(
|
||||||
|
reinterpret_cast<T*>(output.getBuffer()) + tadOut.tadOffsets[outIdx + leftOffset - j]);
|
||||||
|
subArrOut.assign(padValue);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
subArrOut.p(j - 1, padValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// output.printIndexedBuffer("Output at");
|
||||||
|
for(int j = (output.sizeAt(dim) - leftOffset); j < output.sizeAt(dim); ++j) { // fill left side with zeros
|
||||||
|
if (output.rankOf() > 1) {
|
||||||
|
subArrOut.setBuffer(reinterpret_cast<T*>(output.getBuffer()) + tadOut.tadOffsets[outIdx + j]);
|
||||||
|
subArrOut.assign(padValue);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
subArrOut.p(j, padValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 1: // REFLECT mode
|
||||||
|
for(int j = 1; j <= leftOffset; ++j) { // fill left side
|
||||||
|
subArr.setBuffer(reinterpret_cast<T*>(output.getBuffer()) + tadOut.tadOffsets[outIdx + leftOffset + j]);
|
||||||
|
subArrOut.setBuffer(reinterpret_cast<T*>(output.getBuffer()) + tadOut.tadOffsets[outIdx + leftOffset - j]);
|
||||||
|
subArrOut.assign(&subArr);
|
||||||
|
}
|
||||||
|
for(int j = (output.sizeAt(dim) - leftOffset); j < output.sizeAt(dim); ++j) { // fill right side
|
||||||
|
subArr.setBuffer(reinterpret_cast<T*>(output.getBuffer()) + tadOut.tadOffsets[outIdx + output.sizeAt(dim) + leftOffset - 1 - j]);
|
||||||
|
subArrOut.setBuffer(reinterpret_cast<T*>(output.getBuffer()) + tadOut.tadOffsets[outIdx + j]);
|
||||||
|
subArrOut.assign(&subArr);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 2: // SYMMETRIC mode
|
||||||
|
for(int j = 1; j <= leftOffset; ++j) { // fill left side
|
||||||
|
subArr.setBuffer(reinterpret_cast<T*>(output.getBuffer()) + tadOut.tadOffsets[outIdx + leftOffset + j - 1]);
|
||||||
|
subArrOut.setBuffer(reinterpret_cast<T*>(output.getBuffer()) + tadOut.tadOffsets[outIdx + leftOffset - j]);
|
||||||
|
subArrOut.assign(&subArr);
|
||||||
|
}
|
||||||
|
for(int j = (output.sizeAt(dim) - leftOffset); j < output.sizeAt(dim); ++j) { // fill right side
|
||||||
|
subArr.setBuffer(reinterpret_cast<T*>(output.getBuffer()) + tadOut.tadOffsets[outIdx + output.sizeAt(dim) + leftOffset - j]);
|
||||||
|
subArrOut.setBuffer(reinterpret_cast<T*>(output.getBuffer()) + tadOut.tadOffsets[outIdx + j]);
|
||||||
|
subArrOut.assign(&subArr);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
/*
|
||||||
|
void recursiveLoopForPad(const int mode, NDArray& input, const NDArray& paddings, NDArray& output, std::vector<int> dimensions, int dim, int inIdx, int outIdx, NDArray& padValue ) {
|
||||||
|
BUILD_SINGLE_SELECTOR(input.dataType(), recursiveLoopForPad_, (mode, input, paddings, output, dimensions, dim, inIdx, outIdx, padValue), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void recursiveLoopForPad_, (const int mode, NDArray& input, const NDArray& paddings, NDArray& output, std::vector<int> dimensions, int dim, int inIdx, int outIdx, NDArray& padValue), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,126 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/transforms.h>
|
||||||
|
#include <helpers/Loops.h>
|
||||||
|
#include <graph/RandomGenerator.h>
|
||||||
|
#include <numeric>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
|
||||||
|
|
||||||
|
// check edge cases first
|
||||||
|
int temp;
|
||||||
|
const int firstDim = input.sizeAt(0);
|
||||||
|
if(input.lengthOf() == 1 || firstDim == 1) {
|
||||||
|
|
||||||
|
if(!isInplace)
|
||||||
|
output.assign(input);
|
||||||
|
}
|
||||||
|
else if (input.isVector() || shape::isLikeVector(input.getShapeInfo(), temp)) {
|
||||||
|
|
||||||
|
// apply Fisher-Yates shuffle
|
||||||
|
if(isInplace) {
|
||||||
|
//PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
|
||||||
|
for(int i = firstDim-1; i > 0; --i) {
|
||||||
|
int r = rng.relativeInt(i) % i;
|
||||||
|
if(i == r)
|
||||||
|
continue;
|
||||||
|
T t0 = input.t<T>(i);
|
||||||
|
T t1 = input.t<T>(r);
|
||||||
|
//math::nd4j_swap<T>(input(i), input(r));
|
||||||
|
input.t<T>(i) = t1;
|
||||||
|
input.t<T>(r) = t0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> indices(firstDim);
|
||||||
|
std::iota(indices.begin(), indices.end(), 0);
|
||||||
|
output.p<T>(Nd4jLong(0), input.e<T>(0));
|
||||||
|
|
||||||
|
// FIXME: parallelism!!
|
||||||
|
for(int i = firstDim-1; i > 0; --i) {
|
||||||
|
int r = rng.relativeInt(i) % i;
|
||||||
|
output.t<T>(i) = input.t<T>(indices[r]);
|
||||||
|
if(i == r)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
output.t<T>(r) = input.t<T>(indices[i]);
|
||||||
|
math::nd4j_swap<int>(indices[i], indices[r]);
|
||||||
|
}
|
||||||
|
rng.rewindH(firstDim-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
|
// evaluate sub-arrays list of input array through all dimensions excluding first one
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
|
||||||
|
auto subArrsListIn = input.allTensorsAlongDimension(dimensions);
|
||||||
|
|
||||||
|
// apply Fisher-Yates shuffle
|
||||||
|
if(isInplace) {
|
||||||
|
//PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->elementwiseThreshold())
|
||||||
|
for(int i = firstDim - 1; i > 0; --i) {
|
||||||
|
int r = rng.relativeInt(i) % i;
|
||||||
|
|
||||||
|
if(i == r)
|
||||||
|
continue;
|
||||||
|
subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// evaluate sub-arrays list of output array through all dimensions excluding first one
|
||||||
|
auto subArrsListOut = output.allTensorsAlongDimension(dimensions);
|
||||||
|
std::vector<int> indices(firstDim);
|
||||||
|
std::iota(indices.begin(), indices.end(), 0);
|
||||||
|
bool isZeroShuffled = false;
|
||||||
|
//PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
|
||||||
|
for(int i = firstDim - 1; i > 0; --i) {
|
||||||
|
int r = rng.relativeInt(i) % i;
|
||||||
|
subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r]));
|
||||||
|
if(r == 0)
|
||||||
|
isZeroShuffled = true;
|
||||||
|
if(i == r)
|
||||||
|
continue;
|
||||||
|
subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i]));
|
||||||
|
math::nd4j_swap<int>(indices[i], indices[r]);
|
||||||
|
}
|
||||||
|
if(!isZeroShuffled)
|
||||||
|
subArrsListOut.at(0)->assign(subArrsListIn.at(0));
|
||||||
|
}
|
||||||
|
rng.rewindH(firstDim-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
|
||||||
|
BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,115 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/transforms.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
#include <helpers/Loops.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void scatterUpdate(sd::LaunchContext * context, NDArray& input, NDArray& updates, const std::vector<int>* intArgs) {
|
||||||
|
|
||||||
|
int opCode = (*intArgs)[0];
|
||||||
|
int dimSize = (*intArgs)[1];
|
||||||
|
Nd4jLong e;
|
||||||
|
Nd4jLong limg = 2 + dimSize;
|
||||||
|
std::vector<int> tadDimensions(dimSize);
|
||||||
|
for (e = 2; e < limg; e++)
|
||||||
|
tadDimensions[e-2] = (*intArgs)[e];
|
||||||
|
|
||||||
|
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(input.rankOf(), tadDimensions);
|
||||||
|
|
||||||
|
// increasing counter to skip numIndices
|
||||||
|
e++;
|
||||||
|
std::vector<int> indices;
|
||||||
|
for (; e < static_cast<Nd4jLong>(intArgs->size()); e++)
|
||||||
|
indices.push_back((*intArgs)[e]);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
auto inSubArr = input(indices[i], dimsToExclude, true);
|
||||||
|
auto updSubArr = updates(i, dimsToExclude, true);
|
||||||
|
|
||||||
|
if (inSubArr.lengthOf() != updSubArr.lengthOf())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
switch (opCode) {
|
||||||
|
case 0:
|
||||||
|
inSubArr.applyPairwiseTransform(pairwise::Add, updSubArr, inSubArr);
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
inSubArr.applyPairwiseTransform(pairwise::Subtract, updSubArr, inSubArr);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
inSubArr.applyPairwiseTransform(pairwise::Multiply, updSubArr, inSubArr);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
inSubArr.applyPairwiseTransform(pairwise::Divide, updSubArr, inSubArr);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, updSubArr, inSubArr);
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, updSubArr, inSubArr);
|
||||||
|
break;
|
||||||
|
case 6:
|
||||||
|
inSubArr.applyPairwiseTransform(pairwise::CopyPws, updSubArr, inSubArr);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, indices.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void scatterSimple(sd::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector<int>& dimensions) {
|
||||||
|
|
||||||
|
// updates and indices have same length
|
||||||
|
const Nd4jLong len = indices.lengthOf();
|
||||||
|
|
||||||
|
switch (opId) {
|
||||||
|
|
||||||
|
case 6: { // copy
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
auto inSubArr = input(i, dimensions);
|
||||||
|
inSubArr.p(indices.t<Nd4jLong>(i), updates.e(i));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, len);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument("helpers::scatterSimple: operation is not implemented for given id !");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,91 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/transforms.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
#include <helpers/Loops.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static void tileBP_(const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps) {
|
||||||
|
|
||||||
|
T* gradIBuff = reinterpret_cast<T*>(gradI.getBuffer());
|
||||||
|
const T* gradOBuff = reinterpret_cast<T*>(gradO.getBuffer());
|
||||||
|
const Nd4jLong gradILen = gradI.lengthOf();
|
||||||
|
const Nd4jLong gradOLen = gradO.lengthOf(); // gradOLen >= gradILen
|
||||||
|
const Nd4jLong gradIEWS = sd::math::nd4j_abs<Nd4jLong>(gradI.ews());
|
||||||
|
const Nd4jLong gradOEWS = gradO.ews();
|
||||||
|
|
||||||
|
// initial zeroing of gradI content
|
||||||
|
if(gradIEWS == 1)
|
||||||
|
memset(gradIBuff, 0, gradILen * sizeof(T));
|
||||||
|
else {
|
||||||
|
//PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (Nd4jLong i = 0; i < gradILen * gradIEWS; i += gradIEWS)
|
||||||
|
gradIBuff[i] = static_cast<T>(0.f);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if(gradO.ordering() == 'c' && gradOEWS == 1) {
|
||||||
|
|
||||||
|
//PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for(Nd4jLong i=0; i<gradOLen; ++i) {
|
||||||
|
auto idx = shape::subArrayIndex(i, gradO.getShapeInfo(), gradI.getShapeInfo());
|
||||||
|
gradI.p(idx, gradI.e<T>(idx) + gradOBuff[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if(gradO.ordering() == 'c' && gradOEWS > 1) {
|
||||||
|
|
||||||
|
//PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for(Nd4jLong i=0; i<gradOLen; ++i) {
|
||||||
|
auto idx = shape::subArrayIndex(i, gradO.getShapeInfo(), gradI.getShapeInfo());
|
||||||
|
gradI.p(idx, gradI.e<T>(idx) + gradOBuff[i * gradOEWS]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
|
//PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for(Nd4jLong i=0; i<gradOLen; ++i) {
|
||||||
|
|
||||||
|
auto fidx = shape::subArrayIndex(i, gradO.getShapeInfo(), gradI.getShapeInfo());
|
||||||
|
gradI.p(fidx, gradI.e<T>(fidx) + gradOBuff[shape::getIndexOffset(i, gradO.getShapeInfo())]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps) {
|
||||||
|
BUILD_SINGLE_SELECTOR(gradI.dataType(), tileBP_, (gradO, gradI, reps), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void tileBP_, (const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps), FLOAT_TYPES);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/transforms.h>
|
||||||
|
#include <helpers/Loops.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static void trace_(const NDArray& input, NDArray& output) {
|
||||||
|
const int inRank = input.rankOf();
|
||||||
|
auto setOfSubArrs = input.allTensorsAlongDimension({inRank-2, inRank-1});
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto i = start; i < stop; i++)
|
||||||
|
output.p(i, setOfSubArrs.at(i)->getTrace());
|
||||||
|
};
|
||||||
|
samediff::Threads::parallel_for(func, 0, setOfSubArrs.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
void trace(sd::LaunchContext * context, const NDArray& input, NDArray& output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(input.dataType(), trace_, (input, output), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,56 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/transforms.h>
|
||||||
|
#include <helpers/Loops.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static void triuBP_(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal) {
|
||||||
|
|
||||||
|
auto dOdI = NDArray(&gradO); // dO/dI
|
||||||
|
const_cast<NDArray&>(input).fillAsTriangular<T>(0, diagonal, dOdI.sizeAt(-1), dOdI, 'b');
|
||||||
|
int dLen = dOdI.lengthOf();
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
if (dOdI.t<T>(i) != static_cast<T>(0.f))
|
||||||
|
dOdI.t<T>(i) = static_cast<T>(1.f);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
samediff::Threads::parallel_for(func, 0, dLen);
|
||||||
|
|
||||||
|
// FIXME: !!!
|
||||||
|
gradI.assign(dOdI * gradO); // chain rule: dLoss/dI = dO/dI * dLoss/dO
|
||||||
|
}
|
||||||
|
|
||||||
|
void triuBP(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal) {
|
||||||
|
BUILD_SINGLE_SELECTOR(gradO.dataType(), triuBP_, (context, input, gradO, gradI, diagonal), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -14,9 +14,9 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
||||||
//
|
//
|
||||||
|
|
||||||
|
|
||||||
#include<ops/declarable/helpers/transforms.h>
|
#include<ops/declarable/helpers/transforms.h>
|
||||||
|
@ -34,7 +34,7 @@ namespace sd {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T, typename Z>
|
template <typename T, typename Z>
|
||||||
static __global__ void global_mergeMaxIndex_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
static __global__ void mergeMaxIndexCudaLauncher(void** inArrs, void** inShapes, const int numArrays, void* voutput, Nd4jLong* outputShape, Nd4jLong length) {
|
||||||
auto output = reinterpret_cast<Z*>(voutput);
|
auto output = reinterpret_cast<Z*>(voutput);
|
||||||
|
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
@ -46,54 +46,56 @@ namespace sd {
|
||||||
|
|
||||||
for (int i = 0; i < numArrays; i++) {
|
for (int i = 0; i < numArrays; i++) {
|
||||||
auto x = reinterpret_cast<T*>(inArrs[i]);
|
auto x = reinterpret_cast<T*>(inArrs[i]);
|
||||||
auto xShape = reinterpret_cast<Nd4jLong *>(inShapes[i]);
|
auto xShape = reinterpret_cast<Nd4jLong*>(inShapes[i]);
|
||||||
auto val = x[shape::getIndexOffset(e, xShape)];;
|
auto val = x[shape::getIndexOffset(e, xShape)];;
|
||||||
if (mVal < val) {
|
if (mVal < val) {
|
||||||
mIdx = static_cast<Z>(i);
|
mIdx = static_cast<Z>(i);
|
||||||
mVal = val;
|
mVal = val;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
output[shape::getIndexOffset(e, outputShape)] = mIdx;
|
output[shape::getIndexOffset(e, outputShape)] = mIdx;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename Z>
|
template <typename T, typename Z>
|
||||||
static void mergeMaxIndex_(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
static void mergeMaxIndex_(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
std::vector<void *> inBuffers(inArrs.size());
|
|
||||||
std::vector<void *> inShapes(inArrs.size());
|
|
||||||
|
|
||||||
for (int e = 0; e < inArrs.size(); e++) {
|
int nArrSize = static_cast<int>(inArrs.size());
|
||||||
|
std::vector<void*> inBuffers(nArrSize), inShapes(nArrSize);
|
||||||
|
|
||||||
|
for (int e = 0; e < nArrSize; e++) {
|
||||||
inBuffers[e] = inArrs[e]->getSpecialBuffer();
|
inBuffers[e] = inArrs[e]->getSpecialBuffer();
|
||||||
inShapes[e] = inArrs[e]->getSpecialShapeInfo();
|
inShapes[e] = inArrs[e]->getSpecialShapeInfo();
|
||||||
}
|
}
|
||||||
|
|
||||||
PointersManager manager(context, "mergeMaxIndex");
|
PointersManager manager(context, "mergeMaxIndex");
|
||||||
|
|
||||||
auto pInBuffers = reinterpret_cast<void **>(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void *)));
|
auto pInBuffers = reinterpret_cast<void**>(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void*)));
|
||||||
auto pInShapes = reinterpret_cast<void **>(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void *)));
|
auto pInShapes = reinterpret_cast<void**>(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void*)));
|
||||||
auto length = output.lengthOf();
|
auto length = output.lengthOf();
|
||||||
|
|
||||||
global_mergeMaxIndex_<T,Z><<<512, 512, 512, *context->getCudaStream()>>>(pInBuffers, pInShapes, (int) inArrs.size(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), length);
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
|
const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
mergeMaxIndexCudaLauncher<T, Z> << <blocksPerGrid, threadsPerBlock, 512, *context->getCudaStream() >> > (pInBuffers, pInShapes, nArrSize, output.getSpecialBuffer(), output.getSpecialShapeInfo(), length);
|
||||||
|
|
||||||
manager.synchronize();
|
manager.synchronize();
|
||||||
}
|
}
|
||||||
|
|
||||||
void mergeMaxIndex(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
void mergeMaxIndex(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
NDArray::prepareSpecialUse({&output}, {});
|
|
||||||
for (auto v:inArrs)
|
NDArray::prepareSpecialUse({ &output }, inArrs);
|
||||||
v->syncToDevice();
|
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (context, inArrs, output), LIBND4J_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (context, inArrs, output), LIBND4J_TYPES, INDEXING_TYPES);
|
||||||
|
|
||||||
NDArray::registerSpecialUse({&output}, {});
|
NDArray::registerSpecialUse({ &output }, inArrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static __global__ void global_mergeMax_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
static __global__ void mergeMaxCudaLauncher(void** inArrs, void** inShapes, const int numArrays, void* voutput, Nd4jLong* outputShape, Nd4jLong length) {
|
||||||
auto output = reinterpret_cast<T*>(voutput);
|
auto output = reinterpret_cast<T*>(voutput);
|
||||||
|
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
@ -103,51 +105,163 @@ namespace sd {
|
||||||
T mVal = -DataTypeUtils::max<T>();
|
T mVal = -DataTypeUtils::max<T>();
|
||||||
|
|
||||||
for (int i = 0; i < numArrays; i++) {
|
for (int i = 0; i < numArrays; i++) {
|
||||||
auto x = reinterpret_cast<T*>(inArrs[i]);
|
auto x = reinterpret_cast<const T*>(inArrs[i]);
|
||||||
auto xShape = reinterpret_cast<Nd4jLong *>(inShapes[i]);
|
auto xShape = reinterpret_cast<const Nd4jLong*>(inShapes[i]);
|
||||||
auto val = x[shape::getIndexOffset(e, xShape)];;
|
auto val = x[shape::getIndexOffset(e, xShape)];;
|
||||||
if (mVal < val)
|
if (mVal < val)
|
||||||
mVal = val;
|
mVal = val;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
output[shape::getIndexOffset(e, outputShape)] = mVal;
|
output[shape::getIndexOffset(e, outputShape)] = mVal;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static void mergeMax_(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
static void mergeMax_(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
std::vector<void *> inBuffers(inArrs.size());
|
|
||||||
std::vector<void *> inShapes(inArrs.size());
|
|
||||||
|
|
||||||
for (int e = 0; e < inArrs.size(); e++) {
|
int nArrsSize = static_cast<int>(inArrs.size());
|
||||||
|
|
||||||
|
std::vector<void*> inBuffers(nArrsSize), inShapes(nArrsSize);
|
||||||
|
|
||||||
|
for (int e = 0; e < nArrsSize; e++) {
|
||||||
inBuffers[e] = inArrs[e]->getSpecialBuffer();
|
inBuffers[e] = inArrs[e]->getSpecialBuffer();
|
||||||
inShapes[e] = inArrs[e]->getSpecialShapeInfo();
|
inShapes[e] = inArrs[e]->getSpecialShapeInfo();
|
||||||
}
|
}
|
||||||
|
|
||||||
PointersManager manager(context, "mergeMax");
|
PointersManager manager(context, "mergeMax");
|
||||||
|
|
||||||
auto pInBuffers = reinterpret_cast<void **>(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void *)));
|
auto pInBuffers = reinterpret_cast<void**>(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void*)));
|
||||||
auto pInShapes = reinterpret_cast<void **>(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void *)));
|
auto pInShapes = reinterpret_cast<void**>(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void*)));
|
||||||
auto length = output.lengthOf();
|
auto length = output.lengthOf();
|
||||||
|
|
||||||
global_mergeMax_<T><<<512, 512, 512, *context->getCudaStream()>>>(pInBuffers, pInShapes, (int) inArrs.size(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), length);
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
|
const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
mergeMaxCudaLauncher<T> << <blocksPerGrid, threadsPerBlock, 512, *context->getCudaStream() >> > (pInBuffers, pInShapes, nArrsSize, output.getSpecialBuffer(), output.getSpecialShapeInfo(), length);
|
||||||
|
|
||||||
manager.synchronize();
|
manager.synchronize();
|
||||||
}
|
}
|
||||||
|
|
||||||
void mergeMax(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
void mergeMax(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
NDArray::prepareSpecialUse({&output}, {});
|
|
||||||
for (auto v:inArrs)
|
NDArray::prepareSpecialUse({ &output }, inArrs);
|
||||||
v->syncToDevice();
|
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (context, inArrs, output), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (context, inArrs, output), LIBND4J_TYPES);
|
||||||
NDArray::registerSpecialUse({&output}, {});
|
|
||||||
|
NDArray::registerSpecialUse({ &output }, inArrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static __global__ void global_mergeAvg_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
static __global__ void mergeMaxBpCudaLauncher(void** inArrs, void** inShapes, void* vgradient, Nd4jLong* gradientShape, const int numArrays,
|
||||||
|
void** outArrs, void** outShapes, Nd4jLong length, bool bSameOrderAndEws1) {
|
||||||
|
|
||||||
|
auto grad = reinterpret_cast<T*>(vgradient);
|
||||||
|
|
||||||
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
|
||||||
|
for (Nd4jLong e = tid; e < length; e += step) {
|
||||||
|
|
||||||
|
T mVal = -DataTypeUtils::max<T>();
|
||||||
|
int nMaxIndex = 0;
|
||||||
|
auto xOffset = e, zOffset = e, gradOffset = e;
|
||||||
|
|
||||||
|
if (!bSameOrderAndEws1) {
|
||||||
|
shape::index2coords(e, gradientShape, coords);
|
||||||
|
gradOffset = shape::getOffset(gradientShape, coords);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < numArrays; i++) {
|
||||||
|
auto x = reinterpret_cast<T*>(inArrs[i]);
|
||||||
|
|
||||||
|
if (!bSameOrderAndEws1) {
|
||||||
|
auto xShape = reinterpret_cast<Nd4jLong*>(inShapes[i]);
|
||||||
|
xOffset = shape::getOffset(xShape, coords);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto val = x[xOffset];
|
||||||
|
if (mVal < val) {
|
||||||
|
mVal = val;
|
||||||
|
nMaxIndex = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// outputs have to be pre-nullify
|
||||||
|
if (!bSameOrderAndEws1) {
|
||||||
|
auto outShape = reinterpret_cast<Nd4jLong*>(outShapes[nMaxIndex]);
|
||||||
|
zOffset = shape::getOffset(outShape, coords);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto output = reinterpret_cast<T*>(outArrs[nMaxIndex]);
|
||||||
|
|
||||||
|
output[zOffset] = grad[gradOffset];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
static void mergeMaxBp_(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, std::vector<NDArray*>& outArrs, int nArrSize, bool bSameOrderAndEws1) {
|
||||||
|
|
||||||
|
std::vector<void*> inBuffers(nArrSize), inShapes(nArrSize), outBuffers(nArrSize), outShapes(nArrSize);
|
||||||
|
|
||||||
|
for (int e = 0; e < nArrSize; e++) {
|
||||||
|
inBuffers[e] = inArrs[e]->getSpecialBuffer();
|
||||||
|
inShapes[e] = inArrs[e]->getSpecialShapeInfo();
|
||||||
|
outBuffers[e] = outArrs[e]->getSpecialBuffer();
|
||||||
|
outShapes[e] = outArrs[e]->getSpecialShapeInfo();
|
||||||
|
}
|
||||||
|
|
||||||
|
PointersManager manager(context, "mergeMaxBp");
|
||||||
|
|
||||||
|
auto pInBuffers = reinterpret_cast<void**>(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void*)));
|
||||||
|
auto pInShapes = reinterpret_cast<void**>(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void*)));
|
||||||
|
|
||||||
|
auto pOutBuffers = reinterpret_cast<void**>(manager.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void*)));
|
||||||
|
auto pOutShapes = reinterpret_cast<void**>(manager.replicatePointer(outShapes.data(), outShapes.size() * sizeof(void*)));
|
||||||
|
|
||||||
|
auto length = inArrs[nArrSize]->lengthOf();
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
|
const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
mergeMaxBpCudaLauncher<T> << <blocksPerGrid, threadsPerBlock, 512, *context->getCudaStream() >> > (pInBuffers, pInShapes, inArrs[nArrSize]->getSpecialBuffer(),
|
||||||
|
inArrs[nArrSize]->getSpecialShapeInfo(), nArrSize, pOutBuffers, pOutShapes,
|
||||||
|
length, bSameOrderAndEws1);
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
void mergeMaxBp(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, std::vector<NDArray*>& outArrs) {
|
||||||
|
|
||||||
|
// not use gradient
|
||||||
|
int nArrSize = static_cast<int>(inArrs.size() - 1);
|
||||||
|
|
||||||
|
const std::vector<const NDArray*>& out = reinterpret_cast<const std::vector<const NDArray*>&>(outArrs);
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse(out, inArrs);
|
||||||
|
|
||||||
|
bool bSameOrderAndEws1 = (1 == inArrs[nArrSize]->ews());
|
||||||
|
auto ordering = inArrs[nArrSize]->ordering();
|
||||||
|
|
||||||
|
for (int i = 0; i < nArrSize; ++i) {
|
||||||
|
bSameOrderAndEws1 &= (ordering == inArrs[i]->ordering());
|
||||||
|
bSameOrderAndEws1 &= (1 == inArrs[i]->ews());
|
||||||
|
|
||||||
|
bSameOrderAndEws1 &= (ordering == outArrs[i]->ordering());
|
||||||
|
bSameOrderAndEws1 &= (1 == outArrs[i]->ews());
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_SELECTOR(inArrs[nArrSize]->dataType(), mergeMaxBp_, (context, inArrs, outArrs, nArrSize, bSameOrderAndEws1), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse( out, inArrs );
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void mergeAvgCudaLauncher(void** inArrs, void** inShapes, const int numArrays, void* voutput, Nd4jLong* outputShape, Nd4jLong length) {
|
||||||
auto output = reinterpret_cast<T*>(voutput);
|
auto output = reinterpret_cast<T*>(voutput);
|
||||||
|
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
@ -158,7 +272,7 @@ namespace sd {
|
||||||
|
|
||||||
for (int i = 0; i < numArrays; i++) {
|
for (int i = 0; i < numArrays; i++) {
|
||||||
auto x = reinterpret_cast<T*>(inArrs[i]);
|
auto x = reinterpret_cast<T*>(inArrs[i]);
|
||||||
auto xShape = reinterpret_cast<Nd4jLong *>(inShapes[i]);
|
auto xShape = reinterpret_cast<Nd4jLong*>(inShapes[i]);
|
||||||
|
|
||||||
sum += x[shape::getIndexOffset(e, xShape)];
|
sum += x[shape::getIndexOffset(e, xShape)];
|
||||||
}
|
}
|
||||||
|
@ -168,9 +282,9 @@ namespace sd {
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static void mergeAvg_(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
static void mergeAvg_(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
std::vector<void *> inBuffers(inArrs.size());
|
|
||||||
std::vector<void *> inShapes(inArrs.size());
|
std::vector<void*> inBuffers(inArrs.size()), inShapes(inArrs.size());
|
||||||
|
|
||||||
for (int e = 0; e < inArrs.size(); e++) {
|
for (int e = 0; e < inArrs.size(); e++) {
|
||||||
inBuffers[e] = inArrs[e]->getSpecialBuffer();
|
inBuffers[e] = inArrs[e]->getSpecialBuffer();
|
||||||
|
@ -179,28 +293,111 @@ namespace sd {
|
||||||
|
|
||||||
PointersManager manager(context, "mergeAvg");
|
PointersManager manager(context, "mergeAvg");
|
||||||
|
|
||||||
auto pInBuffers = reinterpret_cast<void **>(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void *)));
|
auto pInBuffers = reinterpret_cast<void**>(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void*)));
|
||||||
auto pInShapes = reinterpret_cast<void **>(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void *)));
|
auto pInShapes = reinterpret_cast<void**>(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void*)));
|
||||||
auto length = output.lengthOf();
|
auto length = output.lengthOf();
|
||||||
|
|
||||||
global_mergeAvg_<T><<<512, 512, 512, *context->getCudaStream()>>>(pInBuffers, pInShapes, (int) inArrs.size(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), length);
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
|
const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
mergeAvgCudaLauncher<T> << <blocksPerGrid, threadsPerBlock, 512, *context->getCudaStream() >> > (pInBuffers, pInShapes, (int)inArrs.size(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), length);
|
||||||
|
|
||||||
manager.synchronize();
|
manager.synchronize();
|
||||||
}
|
}
|
||||||
|
|
||||||
void mergeAvg(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
void mergeAvg(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
NDArray::prepareSpecialUse({&output}, {});
|
|
||||||
for (auto v:inArrs)
|
NDArray::prepareSpecialUse({ &output }, inArrs);
|
||||||
v->syncToDevice();
|
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (context, inArrs, output), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (context, inArrs, output), FLOAT_TYPES);
|
||||||
|
|
||||||
NDArray::registerSpecialUse({&output}, {});
|
NDArray::registerSpecialUse({ &output }, inArrs);
|
||||||
|
}
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void mergeAvgBpCudaLauncher(void* vgradient, Nd4jLong* gradientShape, void** outArrs, void** outShapes,
|
||||||
|
const int numArrays, Nd4jLong length, bool bSameOrderAndEws1) {
|
||||||
|
|
||||||
|
auto grad = reinterpret_cast<T*>(vgradient);
|
||||||
|
|
||||||
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
|
||||||
|
for (Nd4jLong e = tid; e < length; e += step) {
|
||||||
|
|
||||||
|
auto zOffset = e, gradOffset = e;
|
||||||
|
if (!bSameOrderAndEws1) {
|
||||||
|
shape::index2coords(e, gradientShape, coords);
|
||||||
|
gradOffset = shape::getOffset(gradientShape, coords);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < numArrays; i++) {
|
||||||
|
|
||||||
|
if (!bSameOrderAndEws1) {
|
||||||
|
auto outShape = reinterpret_cast<Nd4jLong*>(outShapes[i]);
|
||||||
|
zOffset = shape::getOffset(outShape, coords);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto output = reinterpret_cast<T*>(outArrs[i]);
|
||||||
|
|
||||||
|
output[zOffset] = grad[gradOffset] / numArrays;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
static void mergeAvgBp_(sd::LaunchContext* context, const NDArray& gradient, std::vector<NDArray*>& outArrs, bool bSameOrderAndEws1) {
|
||||||
|
|
||||||
|
int nArrSize = static_cast<int>(outArrs.size());
|
||||||
|
|
||||||
|
std::vector<void*> outBuffers(nArrSize), outShapes(nArrSize);
|
||||||
|
|
||||||
|
for (int e = 0; e < nArrSize; e++) {
|
||||||
|
outBuffers[e] = outArrs[e]->getSpecialBuffer();
|
||||||
|
outShapes[e] = outArrs[e]->getSpecialShapeInfo();
|
||||||
|
}
|
||||||
|
|
||||||
|
PointersManager manager(context, "mergeAvgBp");
|
||||||
|
|
||||||
|
auto pOutBuffers = reinterpret_cast<void**>(manager.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void*)));
|
||||||
|
auto pOutShapes = reinterpret_cast<void**>(manager.replicatePointer(outShapes.data(), outShapes.size() * sizeof(void*)));
|
||||||
|
|
||||||
|
auto length = gradient.lengthOf();
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
|
const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
mergeAvgBpCudaLauncher<T> << <blocksPerGrid, threadsPerBlock, 512, *context->getCudaStream() >> > (gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(),
|
||||||
|
pOutBuffers, pOutShapes, nArrSize, length, bSameOrderAndEws1);
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
void mergeAvgBp(sd::LaunchContext* context, const NDArray& gradient, std::vector<NDArray*>& outArrs) {
|
||||||
|
|
||||||
|
const std::vector<const NDArray*>& out = reinterpret_cast<const std::vector<const NDArray*>&>(outArrs);
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse( out, { &gradient });
|
||||||
|
|
||||||
|
bool bSameOrderAndEws1 = (1 == gradient.ews());
|
||||||
|
auto ordering = gradient.ordering();
|
||||||
|
|
||||||
|
for (const auto& v : outArrs) {
|
||||||
|
bSameOrderAndEws1 &= (ordering == v->ordering());
|
||||||
|
bSameOrderAndEws1 &= (1 == v->ews());
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAvgBp_, (context, gradient, outArrs, bSameOrderAndEws1), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse(out, { &gradient });
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static __global__ void global_mergeAdd_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
static __global__ void mergeAddCudaLauncher(void** inArrs, void** inShapes, const int numArrays, void* voutput, Nd4jLong* outputShape, Nd4jLong length) {
|
||||||
|
|
||||||
auto output = reinterpret_cast<T*>(voutput);
|
auto output = reinterpret_cast<T*>(voutput);
|
||||||
|
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
@ -211,7 +408,7 @@ namespace sd {
|
||||||
|
|
||||||
for (int i = 0; i < numArrays; i++) {
|
for (int i = 0; i < numArrays; i++) {
|
||||||
auto x = reinterpret_cast<T*>(inArrs[i]);
|
auto x = reinterpret_cast<T*>(inArrs[i]);
|
||||||
auto xShape = reinterpret_cast<Nd4jLong *>(inShapes[i]);
|
auto xShape = reinterpret_cast<Nd4jLong*>(inShapes[i]);
|
||||||
|
|
||||||
sum += x[shape::getIndexOffset(e, xShape)];
|
sum += x[shape::getIndexOffset(e, xShape)];
|
||||||
}
|
}
|
||||||
|
@ -221,36 +418,120 @@ namespace sd {
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static void mergeAdd_(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
static void mergeAdd_(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
std::vector<void *> inBuffers(inArrs.size());
|
|
||||||
std::vector<void *> inShapes(inArrs.size());
|
|
||||||
|
|
||||||
for (int e = 0; e < inArrs.size(); e++) {
|
int nArrSize = static_cast<int>(inArrs.size());
|
||||||
|
std::vector<void*> inBuffers(nArrSize), inShapes(nArrSize);
|
||||||
|
|
||||||
|
for (int e = 0; e < nArrSize; e++) {
|
||||||
inBuffers[e] = inArrs[e]->getSpecialBuffer();
|
inBuffers[e] = inArrs[e]->getSpecialBuffer();
|
||||||
inShapes[e] = inArrs[e]->getSpecialShapeInfo();
|
inShapes[e] = inArrs[e]->getSpecialShapeInfo();
|
||||||
}
|
}
|
||||||
|
|
||||||
PointersManager manager(context, "mergeAdd");
|
PointersManager manager(context, "mergeAdd");
|
||||||
|
|
||||||
auto pInBuffers = reinterpret_cast<void **>(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void *)));
|
auto pInBuffers = reinterpret_cast<void**>(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void*)));
|
||||||
auto pInShapes = reinterpret_cast<void **>(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void *)));
|
auto pInShapes = reinterpret_cast<void**>(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void*)));
|
||||||
auto length = output.lengthOf();
|
auto length = output.lengthOf();
|
||||||
|
|
||||||
global_mergeAdd_<T><<<512, 512, 512, *context->getCudaStream()>>>(pInBuffers, pInShapes, (int) inArrs.size(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), length);
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
|
const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
mergeAddCudaLauncher<T> << <blocksPerGrid, threadsPerBlock, 512, *context->getCudaStream() >> > (pInBuffers, pInShapes, nArrSize, output.getSpecialBuffer(), output.getSpecialShapeInfo(), length);
|
||||||
|
|
||||||
manager.synchronize();
|
manager.synchronize();
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template void mergeAdd_, (sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output), NUMERIC_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void mergeAdd_, (sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output), NUMERIC_TYPES);
|
||||||
|
|
||||||
void mergeAdd(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
void mergeAdd(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
NDArray::prepareSpecialUse({&output}, {});
|
|
||||||
for (auto v:inArrs)
|
NDArray::prepareSpecialUse({ &output }, inArrs);
|
||||||
v->syncToDevice();
|
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (context, inArrs, output), NUMERIC_TYPES);
|
BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (context, inArrs, output), NUMERIC_TYPES);
|
||||||
|
|
||||||
NDArray::registerSpecialUse({&output}, {});
|
NDArray::registerSpecialUse({ &output }, inArrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void mergeAddBpCudaLauncher(void* vgradient, Nd4jLong* gradientShape, void** outArrs, void** outShapes,
|
||||||
|
const int numArrays, Nd4jLong length, bool bSameOrderAndEws1) {
|
||||||
|
|
||||||
|
auto grad = reinterpret_cast<T*>(vgradient);
|
||||||
|
|
||||||
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
|
||||||
|
for (Nd4jLong e = tid; e < length; e += step) {
|
||||||
|
|
||||||
|
auto zOffset = e, gradOffset = e;
|
||||||
|
if (!bSameOrderAndEws1) {
|
||||||
|
shape::index2coords(e, gradientShape, coords);
|
||||||
|
gradOffset = shape::getOffset(gradientShape, coords);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < numArrays; i++) {
|
||||||
|
|
||||||
|
if (!bSameOrderAndEws1) {
|
||||||
|
auto outShape = reinterpret_cast<Nd4jLong*>(outShapes[i]);
|
||||||
|
zOffset = shape::getOffset(outShape, coords);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto output = reinterpret_cast<T*>(outArrs[i]);
|
||||||
|
|
||||||
|
output[zOffset] = grad[gradOffset];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
static void mergeAddBp_(sd::LaunchContext* context, const NDArray& gradient, std::vector<NDArray*>& outArrs, bool bSameOrderAndEws1) {
|
||||||
|
|
||||||
|
int nArrSize = static_cast<int>(outArrs.size());
|
||||||
|
|
||||||
|
std::vector<void*> outBuffers(nArrSize), outShapes(nArrSize);
|
||||||
|
|
||||||
|
for (int e = 0; e < nArrSize; e++) {
|
||||||
|
outBuffers[e] = outArrs[e]->getSpecialBuffer();
|
||||||
|
outShapes[e] = outArrs[e]->getSpecialShapeInfo();
|
||||||
|
}
|
||||||
|
|
||||||
|
PointersManager manager(context, "mergeAddBp");
|
||||||
|
|
||||||
|
auto pOutBuffers = reinterpret_cast<void**>(manager.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void*)));
|
||||||
|
auto pOutShapes = reinterpret_cast<void**>(manager.replicatePointer(outShapes.data(), outShapes.size() * sizeof(void*)));
|
||||||
|
|
||||||
|
auto length = gradient.lengthOf();
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
|
const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
mergeAddBpCudaLauncher<T> << <blocksPerGrid, threadsPerBlock, 512, *context->getCudaStream() >> > (gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(),
|
||||||
|
pOutBuffers, pOutShapes, nArrSize, length, bSameOrderAndEws1);
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, std::vector<NDArray*>& outArrs) {
|
||||||
|
|
||||||
|
const std::vector<const NDArray*>& out = reinterpret_cast<const std::vector<const NDArray*>& >(outArrs);
|
||||||
|
NDArray::prepareSpecialUse( out, { &gradient });
|
||||||
|
|
||||||
|
bool bSameOrderAndEws1 = (1 == gradient.ews());
|
||||||
|
auto ordering = gradient.ordering();
|
||||||
|
|
||||||
|
for (const auto& v : outArrs) {
|
||||||
|
bSameOrderAndEws1 &= (ordering == v->ordering());
|
||||||
|
bSameOrderAndEws1 &= (1 == v->ews());
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAddBp_, (context, gradient, outArrs, bSameOrderAndEws1), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse( out, { &gradient });
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -52,13 +52,16 @@ namespace helpers {
|
||||||
|
|
||||||
void scatterSimple(sd::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector<int>& dimensions);
|
void scatterSimple(sd::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector<int>& dimensions);
|
||||||
|
|
||||||
void mergeMaxIndex(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output);
|
void mergeMaxIndex(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output);
|
||||||
|
|
||||||
void mergeMax(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output);
|
void mergeMax(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output);
|
||||||
|
void mergeMaxBp(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, std::vector<NDArray*>& outArrs);
|
||||||
|
|
||||||
void mergeAvg(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output);
|
void mergeAvg(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output);
|
||||||
|
void mergeAvgBp(sd::LaunchContext* context, const NDArray& gradient, std::vector<NDArray*>& outArrs);
|
||||||
|
|
||||||
void mergeAdd(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output);
|
void mergeAdd(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output);
|
||||||
|
void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, std::vector<NDArray*>& outArrs);
|
||||||
|
|
||||||
void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace);
|
void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace);
|
||||||
void clipByGlobalNorm(sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace);
|
void clipByGlobalNorm(sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace);
|
||||||
|
|
|
@ -955,7 +955,160 @@ TEST_F(DeclarableOpsTests13, mergemax_2) {
|
||||||
|
|
||||||
ASSERT_EQ(20, status);
|
ASSERT_EQ(20, status);
|
||||||
}
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests13, mergemax_bp_1) {
|
||||||
|
|
||||||
|
NDArray x1('c', { 5, 5 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray x2('c', { 5, 5 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray x3('c', { 5, 5 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray grad('c', { 5, 5 }, sd::DataType::FLOAT32);
|
||||||
|
|
||||||
|
x1.assign(3);
|
||||||
|
x2.assign(1);
|
||||||
|
x3.assign(2);
|
||||||
|
grad.linspace(.1, .1);
|
||||||
|
|
||||||
|
|
||||||
|
sd::ops::mergemax_bp op;
|
||||||
|
auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
ASSERT_EQ(3, result.size());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(grad.isSameShape(z));
|
||||||
|
ASSERT_TRUE(grad.equalsTo(z));
|
||||||
|
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests13, mergemax_bp_2) {
|
||||||
|
|
||||||
|
NDArray x1('c', { 2, 5 }, { 1,2,3,4,5,4,3,2,1,0 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray x2('c', { 2, 5 }, { 0,1,2,3,4,5,6,7,8,9 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray x3('c', { 2, 5 }, { 0,1,1,2,3,4,7,5,8,10 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray grad('c', { 2, 5 }, sd::DataType::FLOAT32);
|
||||||
|
|
||||||
|
grad.linspace(.1, .1);
|
||||||
|
|
||||||
|
NDArray exp1('c', { 2, 5 }, { 0.1, 0.2, 0.3, 0.4, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray exp2('c', { 2, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0, 0.6, 0.0, 0.8, 0.9, 0.0 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray exp3('c', { 2, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0, 1.0 }, sd::DataType::FLOAT32);
|
||||||
|
|
||||||
|
sd::ops::mergemax_bp op;
|
||||||
|
auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
ASSERT_EQ(3, result.size());
|
||||||
|
|
||||||
|
auto z1 = result.at(0);
|
||||||
|
auto z2 = result.at(1);
|
||||||
|
auto z3 = result.at(2);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp1.isSameShape(z1));
|
||||||
|
ASSERT_TRUE(exp1.equalsTo(z1));
|
||||||
|
ASSERT_TRUE(exp2.isSameShape(z2));
|
||||||
|
ASSERT_TRUE(exp2.equalsTo(z2));
|
||||||
|
ASSERT_TRUE(exp3.isSameShape(z3));
|
||||||
|
ASSERT_TRUE(exp3.equalsTo(z3));
|
||||||
|
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests13, mergemax_bp_3) {
|
||||||
|
|
||||||
|
NDArray x1C('c', { 2, 5 }, { 1,2,3,4,5,4,3,2,1,0 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray x2C('c', { 2, 5 }, { 0,1,2,3,4,5,6,7,8,9 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray x3C('c', { 2, 5 }, { 0,1,1,2,3,4,7,5,8,10 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray grad('c', { 2, 5 }, sd::DataType::FLOAT32);
|
||||||
|
|
||||||
|
grad.linspace(.1, .1);
|
||||||
|
|
||||||
|
NDArray x1('f', { 2, 5 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray x2('f', { 2, 5 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray x3('f', { 2, 5 }, sd::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray exp1C('c', { 2, 5 }, { 0.1, 0.2, 0.3, 0.4, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray exp2C('c', { 2, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0, 0.6, 0.0, 0.8, 0.9, 0.0 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray exp3C('c', { 2, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0, 1.0 }, sd::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray exp1('f', { 2, 5 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray exp2('f', { 2, 5 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray exp3('f', { 2, 5 }, sd::DataType::FLOAT32);
|
||||||
|
|
||||||
|
x1.assign(x1C);
|
||||||
|
x2.assign(x2C);
|
||||||
|
x3.assign(x3C);
|
||||||
|
|
||||||
|
exp1.assign(exp1C);
|
||||||
|
exp2.assign(exp2C);
|
||||||
|
exp3.assign(exp3C);
|
||||||
|
|
||||||
|
sd::ops::mergemax_bp op;
|
||||||
|
auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
ASSERT_EQ(3, result.size());
|
||||||
|
|
||||||
|
auto z1 = result.at(0);
|
||||||
|
auto z2 = result.at(1);
|
||||||
|
auto z3 = result.at(2);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp1.isSameShape(z1));
|
||||||
|
ASSERT_TRUE(exp1.equalsTo(z1));
|
||||||
|
ASSERT_TRUE(exp2.isSameShape(z2));
|
||||||
|
ASSERT_TRUE(exp2.equalsTo(z2));
|
||||||
|
ASSERT_TRUE(exp3.isSameShape(z3));
|
||||||
|
ASSERT_TRUE(exp3.equalsTo(z3));
|
||||||
|
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests13, mergeadd_bp_1) {
|
||||||
|
|
||||||
|
NDArray x1('c', { 5, 5 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray x2('c', { 5, 5 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray x3('c', { 5, 5 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray grad('c', { 5, 5 }, sd::DataType::FLOAT32);
|
||||||
|
|
||||||
|
x1.assign(3);
|
||||||
|
x2.assign(1);
|
||||||
|
x3.assign(2);
|
||||||
|
grad.linspace(.1, .1);
|
||||||
|
|
||||||
|
sd::ops::mergeadd_bp op;
|
||||||
|
auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
ASSERT_EQ(3, result.size());
|
||||||
|
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
auto z = result.at(0);
|
||||||
|
ASSERT_TRUE(grad.isSameShape(z));
|
||||||
|
ASSERT_TRUE(grad.equalsTo(z));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests13, mergeavg_bp_1) {
|
||||||
|
|
||||||
|
NDArray x1('c', { 5, 5 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray x2('c', { 5, 5 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray x3('c', { 5, 5 }, sd::DataType::FLOAT32);
|
||||||
|
NDArray grad('c', { 5, 5 }, sd::DataType::FLOAT32);
|
||||||
|
|
||||||
|
x1.assign(3);
|
||||||
|
x2.assign(1);
|
||||||
|
x3.assign(2);
|
||||||
|
grad.linspace(.1, .1);
|
||||||
|
|
||||||
|
sd::ops::mergeavg_bp op;
|
||||||
|
auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
ASSERT_EQ(3, result.size());
|
||||||
|
|
||||||
|
grad.applyScalar(sd::scalar::Divide, 3, grad);
|
||||||
|
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
auto z = result.at(i);
|
||||||
|
ASSERT_TRUE(grad.isSameShape(z));
|
||||||
|
ASSERT_TRUE(grad.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests13, lstmLayer_1) {
|
TEST_F(DeclarableOpsTests13, lstmLayer_1) {
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,37 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
|
||||||
|
package org.nd4j;
|
||||||
|
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public interface TFGraphRunnerService{
|
||||||
|
TFGraphRunnerService init(
|
||||||
|
List<String> inputNames,
|
||||||
|
List<String> outputNames,
|
||||||
|
byte[] graphBytes,
|
||||||
|
Map<String, INDArray> constants,
|
||||||
|
Map<String, String> inputDataTypes
|
||||||
|
);
|
||||||
|
|
||||||
|
Map<String,INDArray> run(Map<String,INDArray> inputs);
|
||||||
|
}
|
|
@ -1654,29 +1654,6 @@ public class SDVariable implements Serializable {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean equals(Object o) {
|
|
||||||
if (this == o) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (!(o instanceof SDVariable)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
SDVariable that = (SDVariable) o;
|
|
||||||
|
|
||||||
if (!Objects.equals(varName, that.varName)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (variableType != that.variableType) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if(sameDiff != that.sameDiff){
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return dataType == that.dataType;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
int result = super.hashCode();
|
int result = super.hashCode();
|
||||||
|
@ -1695,4 +1672,26 @@ public class SDVariable implements Serializable {
|
||||||
v.sameDiff = sd;
|
v.sameDiff = sd;
|
||||||
return v;
|
return v;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o){
|
||||||
|
if(o == this) return true;
|
||||||
|
if(!(o instanceof SDVariable))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
SDVariable s = (SDVariable)o;
|
||||||
|
if(!varName.equals(s.varName))
|
||||||
|
return false;
|
||||||
|
if(variableType != s.variableType)
|
||||||
|
return false;
|
||||||
|
if(dataType != s.dataType)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if(variableType == VariableType.VARIABLE || variableType == VariableType.CONSTANT){
|
||||||
|
INDArray a1 = getArr();
|
||||||
|
INDArray a2 = s.getArr();
|
||||||
|
return a1.equals(a2);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1234,13 +1234,14 @@ public class SameDiff extends SDBaseOps {
|
||||||
@Override
|
@Override
|
||||||
public boolean equals(Object o) {
|
public boolean equals(Object o) {
|
||||||
if (this == o) return true;
|
if (this == o) return true;
|
||||||
if (o == null || getClass() != o.getClass()) return false;
|
if (o == null || getClass() != o.getClass())
|
||||||
|
return false;
|
||||||
|
|
||||||
SameDiff sameDiff = (SameDiff) o;
|
SameDiff sameDiff = (SameDiff) o;
|
||||||
|
|
||||||
if (variables != null ? !variables.equals(sameDiff.variables) : sameDiff.variables != null)
|
boolean eqVars = variables.equals(sameDiff.variables);
|
||||||
return false;
|
boolean eqOps = ops.equals(sameDiff.ops);
|
||||||
return sameDiffFunctionInstances != null ? sameDiffFunctionInstances.equals(sameDiff.sameDiffFunctionInstances) : sameDiff.sameDiffFunctionInstances == null;
|
return eqVars && eqOps;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -5843,4 +5844,10 @@ public class SameDiff extends SDBaseOps {
|
||||||
|
|
||||||
return base + "_" + inc;
|
return base + "_" + inc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(){
|
||||||
|
return "SameDiff(nVars=" + variables.size() + ",nOps=" + ops.size() + ")";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,10 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff.internal;
|
package org.nd4j.autodiff.samediff.internal;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.*;
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -28,6 +25,7 @@ import java.util.List;
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@Data //TODO immutable?
|
@Data //TODO immutable?
|
||||||
@Builder
|
@Builder
|
||||||
|
@EqualsAndHashCode(exclude = {"gradient", "variableIndex"})
|
||||||
public class Variable {
|
public class Variable {
|
||||||
protected String name;
|
protected String name;
|
||||||
protected SDVariable variable;
|
protected SDVariable variable;
|
||||||
|
|
|
@ -173,9 +173,6 @@ public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
|
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
|
||||||
if(recordMetaData != null){
|
|
||||||
throw new UnsupportedOperationException("Evaluation with record metadata not yet implemented for EvaluationBinary");
|
|
||||||
}
|
|
||||||
eval(labels, networkPredictions, maskArray);
|
eval(labels, networkPredictions, maskArray);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -325,7 +325,7 @@ public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration>
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
|
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
|
||||||
throw new UnsupportedOperationException("Not yet implemented");
|
eval(labels, networkPredictions, maskArray);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -229,7 +229,7 @@ public class RegressionEvaluation extends BaseEvaluation<RegressionEvaluation> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
|
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
|
||||||
throw new UnsupportedOperationException("Not yet implemented");
|
eval(labels, networkPredictions, maskArray);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -76,7 +76,7 @@ public class RandomOpValidation extends BaseOpValidation {
|
||||||
double min = in.minNumber().doubleValue();
|
double min = in.minNumber().doubleValue();
|
||||||
double max = in.maxNumber().doubleValue();
|
double max = in.maxNumber().doubleValue();
|
||||||
double mean = in.meanNumber().doubleValue();
|
double mean = in.meanNumber().doubleValue();
|
||||||
if (min >= 1 && max <= 2 && (in.length() == 1 || Math.abs(mean - 1.5) < 0.1))
|
if (min >= 1 && max <= 2 && (in.length() == 1 || Math.abs(mean - 1.5) < 0.2))
|
||||||
return null;
|
return null;
|
||||||
return "Failed: min = " + min + ", max = " + max + ", mean = " + mean;
|
return "Failed: min = " + min + ", max = " + max + ", mean = " + mean;
|
||||||
};
|
};
|
||||||
|
@ -87,7 +87,7 @@ public class RandomOpValidation extends BaseOpValidation {
|
||||||
checkFn = in -> {
|
checkFn = in -> {
|
||||||
double mean = in.meanNumber().doubleValue();
|
double mean = in.meanNumber().doubleValue();
|
||||||
double stdev = in.std(true).getDouble(0);
|
double stdev = in.std(true).getDouble(0);
|
||||||
if (in.length() == 1 || (Math.abs(mean - 1) < 0.1 && Math.abs(stdev - 1) < 0.1))
|
if (in.length() == 1 || (Math.abs(mean - 1) < 0.2 && Math.abs(stdev - 1) < 0.2))
|
||||||
return null;
|
return null;
|
||||||
return "Failed: mean = " + mean + ", stdev = " + stdev;
|
return "Failed: mean = " + mean + ", stdev = " + stdev;
|
||||||
};
|
};
|
||||||
|
|
|
@ -3556,4 +3556,52 @@ public class SameDiffTests extends BaseNd4jTest {
|
||||||
assertTrue(msg, msg.contains("\"labels\"") && msg.contains("No array was provided"));
|
assertTrue(msg, msg.contains("\"labels\"") && msg.contains("No array was provided"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testEquals1(){
|
||||||
|
|
||||||
|
SameDiff sd1 = SameDiff.create();
|
||||||
|
SameDiff sd2 = SameDiff.create();
|
||||||
|
|
||||||
|
assertEquals(sd1, sd2);
|
||||||
|
|
||||||
|
SDVariable p1 = sd1.placeHolder("ph", DataType.FLOAT, -1, 10);
|
||||||
|
SDVariable p2 = sd2.placeHolder("ph", DataType.FLOAT, -1, 10);
|
||||||
|
|
||||||
|
assertEquals(sd1, sd2);
|
||||||
|
|
||||||
|
SDVariable w1 = sd1.constant("c1",1.0f);
|
||||||
|
SDVariable w2 = sd2.constant("c1",1.0f);
|
||||||
|
|
||||||
|
assertEquals(sd1, sd2);
|
||||||
|
|
||||||
|
SDVariable a1 = p1.add("add", w1);
|
||||||
|
SDVariable a2 = p2.add("add", w2);
|
||||||
|
|
||||||
|
assertEquals(sd1, sd2);
|
||||||
|
|
||||||
|
SDVariable w1a = sd1.constant("c2", 2.0f);
|
||||||
|
SDVariable w2a = sd2.constant("cX", 2.0f);
|
||||||
|
|
||||||
|
assertNotEquals(sd1, sd2);
|
||||||
|
w2a.rename("c2");
|
||||||
|
|
||||||
|
assertEquals(sd1, sd2);
|
||||||
|
|
||||||
|
sd2.createGradFunction("ph");
|
||||||
|
|
||||||
|
assertEquals(sd1, sd2);
|
||||||
|
|
||||||
|
w2a.getArr().assign(3.0f);
|
||||||
|
|
||||||
|
assertNotEquals(sd1, sd2);
|
||||||
|
|
||||||
|
w1a.getArr().assign(3.0f);
|
||||||
|
assertEquals(sd1, sd2);
|
||||||
|
|
||||||
|
SDVariable s1 = p1.sub("op", w1);
|
||||||
|
SDVariable s2 = p2.add("op", w1);
|
||||||
|
assertNotEquals(sd1, sd2);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,7 +61,7 @@ public class OpsMappingTests extends BaseNd4jTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 90000L;
|
return 180000L; //Can be slow on some CI machines such as PPC
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -95,7 +95,7 @@ public class Downloader {
|
||||||
}
|
}
|
||||||
// try extracting
|
// try extracting
|
||||||
try{
|
try{
|
||||||
ArchiveUtils.unzipFileTo(f.getAbsolutePath(), extractToDir.getAbsolutePath());
|
ArchiveUtils.unzipFileTo(f.getAbsolutePath(), extractToDir.getAbsolutePath(), false);
|
||||||
} catch (Throwable t){
|
} catch (Throwable t){
|
||||||
log.warn("Error extracting {} files from file {} - retrying...", name, f.getAbsolutePath(), t);
|
log.warn("Error extracting {} files from file {} - retrying...", name, f.getAbsolutePath(), t);
|
||||||
f.delete();
|
f.delete();
|
||||||
|
|
|
@ -51,6 +51,10 @@ public class ArchiveUtils {
|
||||||
* @throws IOException
|
* @throws IOException
|
||||||
*/
|
*/
|
||||||
public static void unzipFileTo(String file, String dest) throws IOException {
|
public static void unzipFileTo(String file, String dest) throws IOException {
|
||||||
|
unzipFileTo(file, dest, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void unzipFileTo(String file, String dest, boolean logFiles) throws IOException {
|
||||||
File target = new File(file);
|
File target = new File(file);
|
||||||
if (!target.exists())
|
if (!target.exists())
|
||||||
throw new IllegalArgumentException("Archive doesnt exist");
|
throw new IllegalArgumentException("Archive doesnt exist");
|
||||||
|
@ -93,7 +97,9 @@ public class ArchiveUtils {
|
||||||
|
|
||||||
fos.close();
|
fos.close();
|
||||||
ze = zis.getNextEntry();
|
ze = zis.getNextEntry();
|
||||||
log.debug("File extracted: " + newFile.getAbsoluteFile());
|
if(logFiles) {
|
||||||
|
log.info("File extracted: " + newFile.getAbsoluteFile());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
zis.closeEntry();
|
zis.closeEntry();
|
||||||
|
@ -112,7 +118,9 @@ public class ArchiveUtils {
|
||||||
TarArchiveEntry entry;
|
TarArchiveEntry entry;
|
||||||
/* Read the tar entries using the getNextEntry method **/
|
/* Read the tar entries using the getNextEntry method **/
|
||||||
while ((entry = (TarArchiveEntry) tarIn.getNextEntry()) != null) {
|
while ((entry = (TarArchiveEntry) tarIn.getNextEntry()) != null) {
|
||||||
|
if(logFiles) {
|
||||||
log.info("Extracting: " + entry.getName());
|
log.info("Extracting: " + entry.getName());
|
||||||
|
}
|
||||||
/* If the entry is a directory, create the directory. */
|
/* If the entry is a directory, create the directory. */
|
||||||
|
|
||||||
if (entry.isDirectory()) {
|
if (entry.isDirectory()) {
|
||||||
|
|
|
@ -16,18 +16,16 @@
|
||||||
|
|
||||||
package org.nd4j.tensorflow.conversion.graphrunner;
|
package org.nd4j.tensorflow.conversion.graphrunner;
|
||||||
|
|
||||||
import lombok.Builder;
|
import lombok.*;
|
||||||
import lombok.Singular;
|
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
import org.nd4j.linalg.io.ClassPathResource;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.shade.protobuf.ByteString;
|
import org.nd4j.shade.protobuf.ByteString;
|
||||||
import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
|
import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
|
||||||
import org.nd4j.shade.protobuf.util.JsonFormat;
|
import org.nd4j.shade.protobuf.util.JsonFormat;
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.nd4j.tensorflow.conversion.TensorDataType;
|
import org.nd4j.tensorflow.conversion.TensorDataType;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
|
@ -56,6 +54,7 @@ import static org.bytedeco.tensorflow.global.tensorflow.*;
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@NoArgsConstructor
|
||||||
public class GraphRunner implements Closeable {
|
public class GraphRunner implements Closeable {
|
||||||
|
|
||||||
private static boolean isTfWarmedUp = false;
|
private static boolean isTfWarmedUp = false;
|
||||||
|
@ -103,6 +102,9 @@ public class GraphRunner implements Closeable {
|
||||||
* @param inputDataTypes the expected input data types
|
* @param inputDataTypes the expected input data types
|
||||||
* @param outputDataTypes the expected output data types
|
* @param outputDataTypes the expected output data types
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
public GraphRunner(List<String> inputNames,
|
public GraphRunner(List<String> inputNames,
|
||||||
List<String> outputNames,
|
List<String> outputNames,
|
||||||
|
@ -440,6 +442,7 @@ public class GraphRunner implements Closeable {
|
||||||
* @return a map of the output names to the
|
* @return a map of the output names to the
|
||||||
* ndarrays matching each output specified in the graph
|
* ndarrays matching each output specified in the graph
|
||||||
*/
|
*/
|
||||||
|
|
||||||
public Map<String,INDArray> run(Map<String,INDArray> inputs) {
|
public Map<String,INDArray> run(Map<String,INDArray> inputs) {
|
||||||
if (!isTfWarmedUp && !isTfWarmingUp){
|
if (!isTfWarmedUp && !isTfWarmingUp){
|
||||||
isTfWarmingUp = true;
|
isTfWarmingUp = true;
|
||||||
|
@ -683,4 +686,7 @@ public class GraphRunner implements Closeable {
|
||||||
|
|
||||||
return builder1.build();
|
return builder1.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,52 @@
|
||||||
|
package org.nd4j.tensorflow.conversion.graphrunner;
|
||||||
|
|
||||||
|
import org.nd4j.TFGraphRunnerService;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.tensorflow.conversion.TensorDataType;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class GraphRunnerServiceProvider implements TFGraphRunnerService {
|
||||||
|
|
||||||
|
private GraphRunner graphRunner;
|
||||||
|
Map<String, INDArray> inputs;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TFGraphRunnerService init(
|
||||||
|
List<String> inputNames,
|
||||||
|
List<String> outputNames,
|
||||||
|
byte[] graphBytes,
|
||||||
|
Map<String, INDArray> constants,
|
||||||
|
Map<String, String> inputDataTypes){
|
||||||
|
if (inputNames.size() != inputDataTypes.size()){
|
||||||
|
throw new IllegalArgumentException("inputNames.size() != inputDataTypes.size()");
|
||||||
|
}
|
||||||
|
Map<String, TensorDataType> convertedDataTypes = new HashMap<>();
|
||||||
|
for (int i = 0; i < inputNames.size(); i++){
|
||||||
|
convertedDataTypes.put(inputNames.get(i), TensorDataType.fromProtoValue(inputDataTypes.get(inputNames.get(i))));
|
||||||
|
}
|
||||||
|
Map<String, INDArray> castConstants = new HashMap<>();
|
||||||
|
for (Map.Entry<String, INDArray> e: constants.entrySet()) {
|
||||||
|
DataType requiredDtype = TensorDataType.toNd4jType(TensorDataType.fromProtoValue(inputDataTypes.get(e.getKey())));
|
||||||
|
castConstants.put(e.getKey(), e.getValue().castTo(requiredDtype));
|
||||||
|
}
|
||||||
|
this.inputs = castConstants;
|
||||||
|
graphRunner = GraphRunner.builder().inputNames(inputNames)
|
||||||
|
.outputNames(outputNames).graphBytes(graphBytes)
|
||||||
|
.inputDataTypes(convertedDataTypes).build();
|
||||||
|
return this;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<String, INDArray> run(Map<String, INDArray> inputs){
|
||||||
|
if (graphRunner == null){
|
||||||
|
throw new RuntimeException("GraphRunner not initialized.");
|
||||||
|
}
|
||||||
|
this.inputs.putAll(inputs);
|
||||||
|
return graphRunner.run(this.inputs);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,17 @@
|
||||||
|
################################################################################
|
||||||
|
# Copyright (c) 2020 Konduit K.K..
|
||||||
|
#
|
||||||
|
# This program and the accompanying materials are made available under the
|
||||||
|
# terms of the Apache License, Version 2.0 which is available at
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
# License for the specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
org.nd4j.tensorflow.conversion.graphrunner.GraphRunnerServiceProvider
|
4
pom.xml
4
pom.xml
|
@ -292,9 +292,9 @@
|
||||||
<javacpp-presets.version>1.5.3-SNAPSHOT</javacpp-presets.version>
|
<javacpp-presets.version>1.5.3-SNAPSHOT</javacpp-presets.version>
|
||||||
<javacv.version>1.5.3-SNAPSHOT</javacv.version>
|
<javacv.version>1.5.3-SNAPSHOT</javacv.version>
|
||||||
|
|
||||||
<python.version>3.7.6</python.version>
|
<python.version>3.7.7</python.version>
|
||||||
<cpython-platform.version>${python.version}-${javacpp-presets.version}</cpython-platform.version>
|
<cpython-platform.version>${python.version}-${javacpp-presets.version}</cpython-platform.version>
|
||||||
<numpy.version>1.18.1</numpy.version>
|
<numpy.version>1.18.2</numpy.version>
|
||||||
<numpy.javacpp.version>${numpy.version}-${javacpp-presets.version}</numpy.javacpp.version>
|
<numpy.javacpp.version>${numpy.version}-${javacpp-presets.version}</numpy.javacpp.version>
|
||||||
|
|
||||||
<openblas.version>0.3.9</openblas.version>
|
<openblas.version>0.3.9</openblas.version>
|
||||||
|
|
Loading…
Reference in New Issue