diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/TestUtils.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/TestUtils.java index 6fcffd083..4300b3b32 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/TestUtils.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/TestUtils.java @@ -124,7 +124,6 @@ public class TestUtils { public static INDArray randomOneHot(long examples, long nOut, Random rng){ INDArray arr = Nd4j.create(examples, nOut); for( int i=0; i p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces()); + + h.preOutput(in, true, new long[]{1,3}, gamma, beta, mean, var, 0.5, e, LayerWorkspaceMgr.noWorkspaces()); + Pair pmkl = h.backpropGradient(in, eps, new long[]{1,3}, gamma, beta, dLdg, dLdb, e, LayerWorkspaceMgr.noWorkspaces()); + + INDArray dldin_dl4j = p.getSecond(); + + System.out.println("dl4j == mkldnn: " + p.getSecond().equals(pmkl.getSecond())); + } + +} diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java index 1fc3cc391..df072b64f 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java @@ -132,7 +132,6 @@ public class TestUtils { public static INDArray randomOneHot(long examples, long nOut, Random rng){ INDArray arr = Nd4j.create(examples, nOut); for( int i=0; i pcudnn = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces()); + + Field f = bn.getClass().getDeclaredField("helper"); + f.setAccessible(true); + f.set(bn, null); + assertNull(bn.getHelper()); + + net.output(in, true); + bn.setInput(in, LayerWorkspaceMgr.noWorkspaces()); + Pair p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces()); + + INDArray dldin_dl4j = p.getSecond(); + INDArray dldin_helper = pcudnn.getSecond(); + + assertTrue(dldin_dl4j.equalsWithEps(dldin_helper, 1e-5)); + } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java index 95c04d154..af15f3b45 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java @@ -36,6 +36,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative; import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; @@ -340,7 +341,8 @@ public class BackPropMLPTest extends BaseDL4JTest { public static float[] asFloat(INDArray arr) { long len = arr.length(); - // FIXME: int cast + if (len > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); float[] f = new float[(int) len]; NdIndexIterator iterator = new NdIndexIterator('c', arr.shape()); for (int i = 0; i < len; i++) { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java index a1512139f..73ebf1ccd 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java @@ -320,7 +320,6 @@ public class MultiLayerTest extends BaseDL4JTest { public static float[] asFloat(INDArray arr) { long len = arr.length(); - // FIXME: int cast float[] f = new float[(int) len]; for (int i = 0; i < len; i++) f[i] = arr.getFloat(i); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java index f5049d211..0a17441bc 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java @@ -331,7 +331,6 @@ public class TestUpdaters extends BaseDL4JTest { double calculatedByHandMScalar = 0.2; double[] expectedM = Nd4j.ones(1, numParams).mul(calculatedByHandMScalar).data().asDouble(); - // FIXME: int cast double[] actualM = Arrays.copyOfRange(nadamUpdater.getM().data().asDouble(), 0, (int) numParams); for (int i = 0; i < actualM.length; i++) { actualM[i] = Math.round(actualM[i] * 1e2) / 1e2; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java index 7aa86c0a2..a7ce1622f 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java @@ -48,6 +48,7 @@ import org.nd4j.linalg.api.rng.DefaultRandom; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.conditions.Condition; import org.nd4j.linalg.learning.config.AdaGrad; @@ -664,8 +665,10 @@ public class TestOptimizers extends BaseDL4JTest { double xlm1 = parameters.getDouble(nDims - 2); double gl = 200 * (xl - xlm1 * xlm1); - // FIXME: int cast - gradient.put(0, (int) nDims - 1, gl); + if (nDims - 1 > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + gradient.put(0, (int)nDims - 1, gl); Gradient g = new DefaultGradient(); g.gradientForVariable().put("W", gradient); this.gradient = g; @@ -865,8 +868,7 @@ public class TestOptimizers extends BaseDL4JTest { @Override public long numParams() { - // FIXME: int cast - return (int) parameters.length(); + return parameters.length(); } @Override diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java index b78a06093..12564f01a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java @@ -86,10 +86,10 @@ public class CompareTrainingImplementations extends BaseDL4JTest { SDVariable label = sd.placeHolder("label", DataType.DOUBLE, -1, 3); SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 4, 10), DataType.DOUBLE, 4, 10); - SDVariable b0 = sd.zero("b0", 1, 10); + SDVariable b0 = sd.var("b0", Nd4j.create(DataType.DOUBLE, 1, 10)); SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 10, 3), DataType.DOUBLE, 10, 3); - SDVariable b1 = sd.zero("b1", 1, 3); + SDVariable b1 = sd.var("b1", Nd4j.create(DataType.DOUBLE, 1, 3)); SDVariable z0 = in.mmul(w0).add(b0); SDVariable a0 = sd.nn().tanh(z0); @@ -172,8 +172,8 @@ public class CompareTrainingImplementations extends BaseDL4JTest { Map placeholders = new HashMap<>(); placeholders.put("input", f); placeholders.put("label", l); - sd.exec(placeholders, lossMse.getVarName()); - INDArray outSd = a1.getArr(); + Map map = sd.output(placeholders, lossMse.name(), a1.name()); + INDArray outSd = map.get(a1.name()); INDArray outDl4j = net.output(f); assertEquals(testName, outDl4j, outSd); @@ -187,7 +187,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest { //Check score double scoreDl4j = net.score(); - double scoreSd = lossMse.getArr().getDouble(0) + sd.calcRegularizationScore(); + double scoreSd = map.get(lossMse.name()).getDouble(0) + sd.calcRegularizationScore(); assertEquals(testName, scoreDl4j, scoreSd, 1e-6); double lossRegScoreSD = sd.calcRegularizationScore(); @@ -197,15 +197,15 @@ public class CompareTrainingImplementations extends BaseDL4JTest { //Check gradients (before updater applied) Map grads = net.gradient().gradientForVariable(); - sd.execBackwards(placeholders); + Map gm = sd.calculateGradients(placeholders, b1.name(), w1.name(), b0.name(), w0.name()); //Note that the SameDiff gradients don't include the L1/L2 terms at present just from execBackwards()... these are added in fitting only //We can check correctness though with training param checks later if(l1Val == 0 && l2Val == 0 && wdVal == 0) { - assertEquals(testName, grads.get("1_b"), b1.getGradient().getArr()); - assertEquals(testName, grads.get("1_W"), w1.getGradient().getArr()); - assertEquals(testName, grads.get("0_b"), b0.getGradient().getArr()); - assertEquals(testName, grads.get("0_W"), w0.getGradient().getArr()); + assertEquals(testName, grads.get("1_b"), gm.get(b1.name())); + assertEquals(testName, grads.get("1_W"), gm.get(w1.name())); + assertEquals(testName, grads.get("0_b"), gm.get(b0.name())); + assertEquals(testName, grads.get("0_W"), gm.get(w0.name())); } diff --git a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java index d296e78a4..6d826c5eb 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java +++ b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java @@ -123,7 +123,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba } @Override - public Pair backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, + public Pair backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta, INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) { this.eps = eps; val miniBatch = (int) input.size(0); @@ -173,8 +173,8 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW, dstStride[0], dstStride[1], dstStride[2], dstStride[3])); - checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(gamma.data().dataType()), shape[0], - shape[1], shape.length > 2 ? shape[2] : 1, shape.length > 3 ? shape[3] : 1)); + checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(gamma.data().dataType()), (int)shape[0], + (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1)); Allocator allocator = AtomicAllocator.getInstance(); CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, epsilon, nextEpsilon, gamma, @@ -189,7 +189,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba Pointer varCacheData = allocator.getPointer(varCache, context); checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()))); - checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, beta, alpha, alpha, + checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, this.beta, alpha, alpha, cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData, cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData, dBetaData, eps, meanCacheData, varCacheData)); @@ -214,7 +214,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba @Override - public INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray gamma, INDArray beta, INDArray mean, + public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr) { this.eps = eps; final boolean isHalf = (x.dataType() == DataType.HALF); @@ -252,8 +252,8 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW, dstStride[0], dstStride[1], dstStride[2], dstStride[3])); - checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(mean.data().dataType()), shape[0], - shape[1], shape.length > 2 ? shape[2] : 1, shape.length > 3 ? shape[3] : 1)); + checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(mean.data().dataType()), (int)shape[0], + (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1)); Allocator allocator = AtomicAllocator.getInstance(); CudaContext context = diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java index 89f975293..2f47c2c8b 100644 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java @@ -286,8 +286,8 @@ public class RecordReaderMultiDataSetIterator implements MultiDataSetIterator, S for (INDArray w : exampleData) { val n = w.size(0); - // FIXME: int cast - minExamples = (int) Math.min(minExamples, n); + if (Math.min(minExamples, n) < Integer.MAX_VALUE) + minExamples = (int) Math.min(minExamples, n); } } } diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/SequenceRecordReaderDataSetIterator.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/SequenceRecordReaderDataSetIterator.java index 50401898c..d04ca652e 100644 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/SequenceRecordReaderDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/SequenceRecordReaderDataSetIterator.java @@ -366,7 +366,6 @@ public class SequenceRecordReaderDataSetIterator implements DataSetIterator { DataSet ds = mdsToDataSet(mds); if (totalOutcomes == -1) { - // FIXME: int cast inputColumns = (int) ds.getFeatures().size(1); totalOutcomes = ds.getLabels() == null ? -1 : (int) ds.getLabels().size(1); } @@ -394,7 +393,6 @@ public class SequenceRecordReaderDataSetIterator implements DataSetIterator { stored = next(); useStored = true; - // FIXME: int cast inputColumns = (int) stored.getFeatures().size(1); totalOutcomes = (int) stored.getLabels().size(1); } diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/AbstractDataSetIterator.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/AbstractDataSetIterator.java index d94ee8c97..619b31bdf 100644 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/AbstractDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/AbstractDataSetIterator.java @@ -172,7 +172,6 @@ public abstract class AbstractDataSetIterator implements DataSetIterator { Pair pair = iterator.next(); if (numFeatures < 1) { if (pair.getFirst() instanceof INDArray) { - // FIXME: int cast numFeatures = (int) ((INDArray) pair.getFirst()).length(); numLabels = (int) ((INDArray) pair.getSecond()).length(); } else if (pair.getFirst() instanceof float[]) { diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/IteratorDataSetIterator.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/IteratorDataSetIterator.java index 2d90db817..23429037b 100644 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/IteratorDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/IteratorDataSetIterator.java @@ -95,7 +95,6 @@ public class IteratorDataSetIterator implements DataSetIterator { //Set columns etc for later use DataSet temp = list.get(0); - // FIXME: int cast inputColumns = (int) temp.getFeatures().size(1); totalOutcomes = temp.getLabels() == null ? 0 : (int) temp.getLabels().size(1); //May be null for layerwise pretraining } diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/IteratorMultiDataSetIterator.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/IteratorMultiDataSetIterator.java index d27fbbc98..822701d83 100644 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/IteratorMultiDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/IteratorMultiDataSetIterator.java @@ -73,8 +73,7 @@ public class IteratorMultiDataSetIterator implements MultiDataSetIterator { next = iterator.next(); } - // FIXME: int cast - int nExamples = (int) next.getFeatures(0).size(0); + long nExamples = next.getFeatures(0).size(0); if (countSoFar + nExamples <= batchSize) { //Add the entire MultiDataSet as-is list.add(next); @@ -140,7 +139,7 @@ public class IteratorMultiDataSetIterator implements MultiDataSetIterator { return out; } - private static INDArray getRange(INDArray arr, int exampleFrom, int exampleToExclusive) { + private static INDArray getRange(INDArray arr, long exampleFrom, long exampleToExclusive) { if (arr == null) return null; diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/file/BaseFileIterator.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/file/BaseFileIterator.java index ea16f8a18..01bd0c2a9 100644 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/file/BaseFileIterator.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/file/BaseFileIterator.java @@ -134,7 +134,7 @@ public abstract class BaseFileIterator implements Iterator { List remainder = new ArrayList<>(); int soFar = 0; for (T t : toMerge) { - int size = sizeOf(t); + long size = sizeOf(t); if (soFar + size <= batchSize) { correctNum.add(t); @@ -190,7 +190,7 @@ public abstract class BaseFileIterator implements Iterator { protected abstract T load(File f); - protected abstract int sizeOf(T of); + protected abstract long sizeOf(T of); protected abstract List split(T toSplit); diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/file/FileDataSetIterator.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/file/FileDataSetIterator.java index 8e6da3b0e..714f1a22c 100644 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/file/FileDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/file/FileDataSetIterator.java @@ -151,7 +151,7 @@ public class FileDataSetIterator extends BaseFileIterator list) { + long[] retVal = new long[list.size()]; + for (int i = 0; i < list.size(); ++i) { + retVal[i] = list.get(i); + } + return retVal; + } /** * Constructor from parsed Keras layer configuration dictionary. * @@ -67,9 +75,7 @@ public class KerasReshape extends KerasLayer { if (innerConfig.containsKey(targetShape)) { @SuppressWarnings("unchecked") List targetShapeList = (List) innerConfig.get(targetShape); - - // FIXME: int cast - this.targetShape = ArrayUtil.toLongArray(ArrayUtil.toArray(targetShapeList)); + this.targetShape = listToLongArray(targetShapeList); } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index e5caa3e3a..874931262 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -690,13 +690,11 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { INDArray testLabels = Nd4j.create(predictionsDl4j.shape()); if (testLabels.rank() == 2) { for (int i = 0; i < testLabels.size(0); i++) { - // FIXME: int cast testLabels.putScalar(i, r.nextInt((int) testLabels.size(1)), 1.0); } } else if (testLabels.rank() == 3) { for (int i = 0; i < testLabels.size(0); i++) { for (int j = 0; j < testLabels.size(1); j++) { - // FIXME: int cast testLabels.putScalar(i, j, r.nextInt((int) testLabels.size(1)), 1.0); } } diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java index 74574ffb5..b9523f30b 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java @@ -18,6 +18,9 @@ package org.deeplearning4j.clustering.kdtree; import lombok.val; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.custom.KnnMinDistance; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; import java.io.Serializable; import java.util.ArrayList; @@ -28,79 +31,103 @@ import java.util.List; */ public class HyperRect implements Serializable { - private List points; + //private List points; + private float[] lowerEnds; + private float[] higherEnds; + private INDArray lowerEndsIND; + private INDArray higherEndsIND; - public HyperRect(List points) { - //this.points = points; - this.points = new ArrayList<>(points.size()); - for (int i = 0; i < points.size(); ++i) { - Interval newInterval = new Interval(points.get(i).lower, points.get(i).higher); - this.points.add(newInterval); - } + public HyperRect(float[] lowerEndsIn, float[] higherEndsIn) { + this.lowerEnds = new float[lowerEndsIn.length]; + this.higherEnds = new float[lowerEndsIn.length]; + System.arraycopy(lowerEndsIn, 0 , this.lowerEnds, 0, lowerEndsIn.length); + System.arraycopy(higherEndsIn, 0 , this.higherEnds, 0, higherEndsIn.length); + lowerEndsIND = Nd4j.createFromArray(lowerEnds); + higherEndsIND = Nd4j.createFromArray(higherEnds); + } + + public HyperRect(float[] point) { + this(point, point); + } + + public HyperRect(Pair ends) { + this(ends.getFirst(), ends.getSecond()); } public void enlargeTo(INDArray point) { - for (int i = 0; i < points.size(); i++) - points.get(i).enlarge(point.getDouble(i)); + float[] pointAsArray = point.toFloatVector(); + for (int i = 0; i < lowerEnds.length; i++) { + float p = pointAsArray[i]; + if (lowerEnds[i] > p) + lowerEnds[i] = p; + else if (higherEnds[i] < p) + higherEnds[i] = p; + } } - - public static List point(INDArray vector) { - List ret = new ArrayList<>(); + public static Pair point(INDArray vector) { + Pair ret = new Pair<>(); + float[] curr = new float[(int)vector.length()]; for (int i = 0; i < vector.length(); i++) { - double curr = vector.getDouble(i); - ret.add(new Interval(curr, curr)); + curr[i] = vector.getFloat(i); } + ret.setFirst(curr); + ret.setSecond(curr); return ret; } - public List contains(INDArray hPoint) { + /*public List contains(INDArray hPoint) { List ret = new ArrayList<>(); - for (int i = 0; i < hPoint.length(); i++) - ret.add(points.get(i).contains(hPoint.getDouble(i))); - return ret; - } - - public double minDistance(INDArray hPoint) { - double ret = 0.0; for (int i = 0; i < hPoint.length(); i++) { - double p = hPoint.getDouble(i); - Interval interval = points.get(i); - if (!interval.contains(p)) { - if (p < interval.lower) - ret += Math.pow((p - interval.lower), 2); - else - ret += Math.pow((p - interval.higher), 2); - } + ret.add(lowerEnds[i] <= hPoint.getDouble(i) && + higherEnds[i] >= hPoint.getDouble(i)); } - - ret = Math.pow(ret, 0.5); return ret; + }*/ + + public double minDistance(INDArray hPoint, INDArray output) { + Nd4j.exec(new KnnMinDistance(hPoint, lowerEndsIND, higherEndsIND, output)); + return output.getFloat(0); + + /*double ret = 0.0; + double[] pointAsArray = hPoint.toDoubleVector(); + for (int i = 0; i < pointAsArray.length; i++) { + double p = pointAsArray[i]; + if (!(lowerEnds[i] <= p || higherEnds[i] <= p)) { + if (p < lowerEnds[i]) + ret += Math.pow((p - lowerEnds[i]), 2); + else + ret += Math.pow((p - higherEnds[i]), 2); + } + } + ret = Math.pow(ret, 0.5); + return ret;*/ } public HyperRect getUpper(INDArray hPoint, int desc) { - Interval interval = points.get(desc); - double d = hPoint.getDouble(desc); - if (interval.higher < d) + //Interval interval = points.get(desc); + float higher = higherEnds[desc]; + float d = hPoint.getFloat(desc); + if (higher < d) return null; - HyperRect ret = new HyperRect(new ArrayList<>(points)); - Interval i2 = ret.points.get(desc); - if (i2.lower < d) - i2.lower = d; + HyperRect ret = new HyperRect(lowerEnds,higherEnds); + if (ret.lowerEnds[desc] < d) + ret.lowerEnds[desc] = d; return ret; } public HyperRect getLower(INDArray hPoint, int desc) { - Interval interval = points.get(desc); - double d = hPoint.getDouble(desc); - if (interval.lower > d) + //Interval interval = points.get(desc); + float lower = lowerEnds[desc]; + float d = hPoint.getFloat(desc); + if (lower > d) return null; - HyperRect ret = new HyperRect(new ArrayList<>(points)); - Interval i2 = ret.points.get(desc); - if (i2.higher > d) - i2.higher = d; + HyperRect ret = new HyperRect(lowerEnds,higherEnds); + //Interval i2 = ret.points.get(desc); + if (ret.higherEnds[desc] > d) + ret.higherEnds[desc] = d; return ret; } @@ -108,33 +135,10 @@ public class HyperRect implements Serializable { public String toString() { String retVal = ""; retVal += "["; - for (val point : points) { - retVal += "(" + point.lower + " - " + point.higher + ") "; + for (int i = 0; i < lowerEnds.length; ++i) { + retVal += "(" + lowerEnds[i] + " - " + higherEnds[i] + ") "; } retVal += "]"; return retVal; } - - public static class Interval { - private double lower, higher; - - public Interval(double lower, double higher) { - this.lower = lower; - this.higher = higher; - } - - public boolean contains(double point) { - return lower <= point || point <= higher; - - } - - public void enlarge(double p) { - if (lower > p) - lower = p; - else if (higher < p) - higher = p; - } - - } - } diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java index c5e2452f3..3e0b90119 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java @@ -56,7 +56,7 @@ public class KDTree implements Serializable { if (root == null) { root = new KDNode(point); - rect = new HyperRect(HyperRect.point(point)); + rect = new HyperRect(/*HyperRect.point(point)*/ point.toFloatVector()); } else { int disc = 0; KDNode node = root; @@ -125,15 +125,21 @@ public class KDTree implements Serializable { return node.getPoint(); } + // Share this data for recursive calls of "knn" + private float currentDistance; + private INDArray currentPoint; + private INDArray minDistance = Nd4j.scalar(0.f); - public List> knn(INDArray point, double distance) { - List> best = new ArrayList<>(); - knn(root, point, rect, distance, best, 0); - Collections.sort(best, new Comparator>() { + public List> knn(INDArray point, float distance) { + List> best = new ArrayList<>(); + currentDistance = distance; + currentPoint = point; + knn(root, rect, best, 0); + Collections.sort(best, new Comparator>() { @Override - public int compare(Pair o1, Pair o2) { - return Double.compare(o1.getKey(), o2.getKey()); + public int compare(Pair o1, Pair o2) { + return Float.compare(o1.getKey(), o2.getKey()); } }); @@ -141,22 +147,21 @@ public class KDTree implements Serializable { } - private void knn(KDNode node, INDArray point, HyperRect rect, double dist, List> best, - int _disc) { - if (node == null || rect == null || rect.minDistance(point) > dist) + private void knn(KDNode node, HyperRect rect, List> best, int _disc) { + if (node == null || rect == null || rect.minDistance(currentPoint, minDistance) > currentDistance) return; int _discNext = (_disc + 1) % dims; - double distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(point,node.point)).getFinalResult() - .doubleValue(); + float distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(currentPoint,node.point, minDistance)).getFinalResult() + .floatValue(); - if (distance <= dist) { + if (distance <= currentDistance) { best.add(Pair.of(distance, node.getPoint())); } HyperRect lower = rect.getLower(node.point, _disc); HyperRect upper = rect.getUpper(node.point, _disc); - knn(node.getLeft(), point, lower, dist, best, _discNext); - knn(node.getRight(), point, upper, dist, best, _discNext); + knn(node.getLeft(), lower, best, _discNext); + knn(node.getRight(), upper, best, _discNext); } /** @@ -171,7 +176,7 @@ public class KDTree implements Serializable { private Pair nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best, int _disc) { - if (node == null || rect.minDistance(point) > dist) + if (node == null || rect.minDistance(point, minDistance) > dist) return Pair.of(Double.POSITIVE_INFINITY, null); int _discNext = (_disc + 1) % dims; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java index 1de7a379b..618ee0c94 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java @@ -16,6 +16,8 @@ package org.deeplearning4j.clustering.kdtree; +import org.joda.time.Instant; +import org.nd4j.shade.guava.base.Stopwatch; import org.nd4j.shade.guava.primitives.Doubles; import lombok.val; import org.deeplearning4j.clustering.BaseDL4JTest; @@ -28,6 +30,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; +import org.nd4j.shade.guava.primitives.Floats; import org.opencv.ml.KNearest; import java.util.ArrayList; @@ -35,6 +38,8 @@ import java.util.Arrays; import java.util.List; import java.util.Random; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -53,17 +58,17 @@ public class KDTreeTest extends BaseDL4JTest { @Before public void setUp() { kdTree = new KDTree(2); - double[] data = new double[]{7,2}; + float[] data = new float[]{7,2}; kdTree.insert(Nd4j.createFromArray(data)); - data = new double[]{5,4}; + data = new float[]{5,4}; kdTree.insert(Nd4j.createFromArray(data)); - data = new double[]{2,3}; + data = new float[]{2,3}; kdTree.insert(Nd4j.createFromArray(data)); - data = new double[]{4,7}; + data = new float[]{4,7}; kdTree.insert(Nd4j.createFromArray(data)); - data = new double[]{9,6}; + data = new float[]{9,6}; kdTree.insert(Nd4j.createFromArray(data)); - data = new double[]{8,1}; + data = new float[]{8,1}; kdTree.insert(Nd4j.createFromArray(data)); } @@ -168,26 +173,30 @@ public class KDTreeTest extends BaseDL4JTest { @Test public void testKNN() { - int n = 10; - // make a KD-tree of dimension {#n} - KDTree kdTree = new KDTree(n); - for (int i = -1; i < n; i++) { + int dimensions = 512; + int vectorsNo = 50000; + // make a KD-tree of dimension {#dimensions} + Stopwatch stopwatch = Stopwatch.createStarted(); + KDTree kdTree = new KDTree(dimensions); + for (int i = -1; i < vectorsNo; i++) { // Insert a unit vector along each dimension - List vec = new ArrayList<>(n); - // i = -1 ensures the origin is in the Tree - for (int k = 0; k < n; k++) { - vec.add((k == i) ? 1.0 : 0.0); - } - INDArray indVec = Nd4j.create(Nd4j.createBuffer(Doubles.toArray(vec))); + INDArray indVec = Nd4j.rand(DataType.FLOAT, 1,dimensions); kdTree.insert(indVec); } + stopwatch.stop(); + System.out.println("Time elapsed for " + kdTree.size() + " nodes construction is "+ stopwatch.elapsed(SECONDS)); + Random rand = new Random(); // random point in the Hypercube - List pt = new ArrayList(n); - for (int k = 0; k < n; k++) { - pt.add(rand.nextDouble() * 10.0); + List pt = new ArrayList(dimensions); + for (int k = 0; k < dimensions; k++) { + pt.add(rand.nextFloat() * 10.0); } - List> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0); + stopwatch.reset(); + stopwatch.start(); + List> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Floats.toArray(pt))), 20.0f); + stopwatch.stop(); + System.out.println("Time elapsed for Search is "+ stopwatch.elapsed(MILLISECONDS)); } @Test @@ -195,15 +204,15 @@ public class KDTreeTest extends BaseDL4JTest { int n = 2; KDTree kdTree = new KDTree(n); - double[] data = new double[]{3,3}; + float[] data = new float[]{3,3}; kdTree.insert(Nd4j.createFromArray(data)); - data = new double[]{1,1}; + data = new float[]{1,1}; kdTree.insert(Nd4j.createFromArray(data)); - data = new double[]{2,2}; + data = new float[]{2,2}; kdTree.insert(Nd4j.createFromArray(data)); - data = new double[]{0,0}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 4.5); + data = new float[]{0,0}; + List> result = kdTree.knn(Nd4j.createFromArray(data), 4.5f); assertEquals(1.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5); @@ -220,88 +229,88 @@ public class KDTreeTest extends BaseDL4JTest { assertEquals(6, kdTree.size()); - double[] data = new double[]{8,1}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 10.0); - assertEquals(8.0, result.get(0).getSecond().getDouble(0), 1e-5); - assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5); - assertEquals(7.0, result.get(1).getSecond().getDouble(0), 1e-5); - assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5); - assertEquals(5.0, result.get(2).getSecond().getDouble(0), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getDouble(1), 1e-5); - assertEquals(9.0, result.get(3).getSecond().getDouble(0), 1e-5); - assertEquals(6.0, result.get(3).getSecond().getDouble(1), 1e-5); - assertEquals(2.0, result.get(4).getSecond().getDouble(0), 1e-5); - assertEquals(3.0, result.get(4).getSecond().getDouble(1), 1e-5); - assertEquals(4.0, result.get(5).getSecond().getDouble(0), 1e-5); - assertEquals(7.0, result.get(5).getSecond().getDouble(1), 1e-5); + float[] data = new float[]{8,1}; + List> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f); + assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5); + assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5); + assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5); + assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5); + assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5); + assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5); + assertEquals(9.0, result.get(3).getSecond().getFloat(0), 1e-5); + assertEquals(6.0, result.get(3).getSecond().getFloat(1), 1e-5); + assertEquals(2.0, result.get(4).getSecond().getFloat(0), 1e-5); + assertEquals(3.0, result.get(4).getSecond().getFloat(1), 1e-5); + assertEquals(4.0, result.get(5).getSecond().getFloat(0), 1e-5); + assertEquals(7.0, result.get(5).getSecond().getFloat(1), 1e-5); } @Test public void testKNN_2() { - double[] data = new double[]{8, 1}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 5.0); - assertEquals(8.0, result.get(0).getSecond().getDouble(0), 1e-5); - assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5); - assertEquals(7.0, result.get(1).getSecond().getDouble(0), 1e-5); - assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5); - assertEquals(5.0, result.get(2).getSecond().getDouble(0), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getDouble(1), 1e-5); + float[] data = new float[]{8, 1}; + List> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f); + assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5); + assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5); + assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5); + assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5); + assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5); + assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5); } @Test public void testKNN_3() { - double[] data = new double[]{2, 3}; - val result = kdTree.knn(Nd4j.createFromArray(data), 10.0); - assertEquals(2.0, result.get(0).getSecond().getDouble(0), 1e-5); - assertEquals(3.0, result.get(0).getSecond().getDouble(1), 1e-5); - assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); - assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getDouble(0), 1e-5); - assertEquals(7.0, result.get(2).getSecond().getDouble(1), 1e-5); - assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); - assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); - assertEquals(8.0, result.get(4).getSecond().getDouble(0), 1e-5); - assertEquals(1.0, result.get(4).getSecond().getDouble(1), 1e-5); - assertEquals(9.0, result.get(5).getSecond().getDouble(0), 1e-5); - assertEquals(6.0, result.get(5).getSecond().getDouble(1), 1e-5); + float[] data = new float[]{2, 3}; + List> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f); + assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5); + assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5); + assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5); + assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5); + assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5); + assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5); + assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5); + assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5); + assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5); + assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5); + assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5); + assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5); } @Test public void testKNN_4() { - double[] data = new double[]{2, 3}; - val result = kdTree.knn(Nd4j.createFromArray(data), 5.0); - assertEquals(2.0, result.get(0).getSecond().getDouble(0), 1e-5); - assertEquals(3.0, result.get(0).getSecond().getDouble(1), 1e-5); - assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); - assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getDouble(0), 1e-5); - assertEquals(7.0, result.get(2).getSecond().getDouble(1), 1e-5); + float[] data = new float[]{2, 3}; + List> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f); + assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5); + assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5); + assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5); + assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5); + assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5); + assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5); } @Test public void testKNN_5() { - double[] data = new double[]{2, 3}; - val result = kdTree.knn(Nd4j.createFromArray(data), 20.0); - assertEquals(2.0, result.get(0).getSecond().getDouble(0), 1e-5); - assertEquals(3.0, result.get(0).getSecond().getDouble(1), 1e-5); - assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); - assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getDouble(0), 1e-5); - assertEquals(7.0, result.get(2).getSecond().getDouble(1), 1e-5); - assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); - assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); - assertEquals(8.0, result.get(4).getSecond().getDouble(0), 1e-5); - assertEquals(1.0, result.get(4).getSecond().getDouble(1), 1e-5); - assertEquals(9.0, result.get(5).getSecond().getDouble(0), 1e-5); - assertEquals(6.0, result.get(5).getSecond().getDouble(1), 1e-5); + float[] data = new float[]{2, 3}; + List> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f); + assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5); + assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5); + assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5); + assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5); + assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5); + assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5); + assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5); + assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5); + assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5); + assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5); + assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5); + assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5); } @Test public void test_KNN_6() { - double[] data = new double[]{4, 6}; - val result = kdTree.knn(Nd4j.createFromArray(data), 10.0); + float[] data = new float[]{4, 6}; + List> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f); assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); @@ -318,8 +327,8 @@ public class KDTreeTest extends BaseDL4JTest { @Test public void test_KNN_7() { - double[] data = new double[]{4, 6}; - val result = kdTree.knn(Nd4j.createFromArray(data), 5.0); + float[] data = new float[]{4, 6}; + List> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f); assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); @@ -334,8 +343,8 @@ public class KDTreeTest extends BaseDL4JTest { @Test public void test_KNN_8() { - double[] data = new double[]{4, 6}; - val result = kdTree.knn(Nd4j.createFromArray(data), 20.0); + float[] data = new float[]{4, 6}; + List> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f); assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); @@ -392,12 +401,12 @@ public class KDTreeTest extends BaseDL4JTest { Duration duration = new Duration(start, end); System.out.println("Elapsed time for tree construction " + duration.getStandardSeconds() + " " + duration.getMillis()); - List pt = new ArrayList(num); + List pt = new ArrayList(num); for (int k = 0; k < n; k++) { - pt.add((double)(num / 2)); + pt.add((float)(num / 2)); } start = System.currentTimeMillis(); - List> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0); + List> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0f); end = System.currentTimeMillis(); duration = new Duration(start, end); long elapsed = end - start; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java index 01b38a644..736998484 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java @@ -50,6 +50,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; +import java.io.IOException; import java.util.*; import static org.junit.Assert.*; @@ -816,6 +817,37 @@ public class Word2VecTests extends BaseDL4JTest { assertEquals(vec1.getWordVectorMatrix("money"), vec2.getWordVectorMatrix("money")); } + @Test + public void testWordsNearestSum() throws IOException { + log.info("Load & Vectorize Sentences...."); + SentenceIterator iter = new BasicLineIterator(inputFile); + TokenizerFactory t = new DefaultTokenizerFactory(); + t.setTokenPreProcessor(new CommonPreprocessor()); + + log.info("Building model...."); + Word2Vec vec = new Word2Vec.Builder() + .minWordFrequency(5) + .iterations(1) + .layerSize(100) + .seed(42) + .windowSize(5) + .iterate(iter) + .tokenizerFactory(t) + .build(); + + log.info("Fitting Word2Vec model...."); + vec.fit(); + log.info("Writing word vectors to text file...."); + log.info("Closest Words:"); + Collection lst = vec.wordsNearestSum("day", 10); + log.info("10 Words closest to 'day': {}", lst); + assertTrue(lst.contains("week")); + assertTrue(lst.contains("night")); + assertTrue(lst.contains("year")); + assertTrue(lst.contains("years")); + assertTrue(lst.contains("time")); + } + private static void printWords(String target, Collection list, Word2Vec vec) { System.out.println("Words close to [" + target + "]:"); for (String word : list) { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java index 579caa0a3..086540090 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java @@ -104,7 +104,7 @@ public class InMemoryLookupTable implements WeightLoo } protected void initAdaGrad() { - int[] shape = new int[] {vocab.numWords() + 1, vectorLength}; + long[] shape = new long[] {vocab.numWords() + 1, vectorLength}; int length = ArrayUtil.prod(shape); adaGrad = new AdaGrad(shape, lr.get()); adaGrad.setStateViewArray(Nd4j.zeros(shape).reshape(1, length), shape, Nd4j.order(), true); @@ -124,8 +124,7 @@ public class InMemoryLookupTable implements WeightLoo if (adaGrad == null) initAdaGrad(); - // FIXME: int cast - return adaGrad.getGradient(gradient, column, ArrayUtil.toInts(syn0.shape())); + return adaGrad.getGradient(gradient, column, syn0.shape()); } @Override @@ -370,7 +369,6 @@ public class InMemoryLookupTable implements WeightLoo else { nextRandom.set(nextRandom.get() * 25214903917L + 11); - // FIXME: int cast int idx = (int) Math.abs((int) (nextRandom.get() >> 16) % table.length()); target = table.getInt(idx); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/CBOW.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/CBOW.java index 80a2b6565..fdfd91926 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/CBOW.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/CBOW.java @@ -33,7 +33,6 @@ import org.deeplearning4j.models.word2vec.wordstore.VocabCache; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.aggregates.Aggregate; -import org.nd4j.linalg.api.ops.aggregates.impl.AggregateCBOW; import org.nd4j.linalg.api.ops.impl.nlp.CbowRound; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.util.DeviceLocalNDArray; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/GloVe.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/GloVe.java index 71cf2c693..01bf7affd 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/GloVe.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/GloVe.java @@ -104,11 +104,10 @@ public class GloVe implements ElementsLearningAlgorit - weightAdaGrad = new AdaGrad(new int[] {this.vocabCache.numWords() + 1, vectorLength}, learningRate); + weightAdaGrad = new AdaGrad(new long[] {this.vocabCache.numWords() + 1, vectorLength}, learningRate); bias = Nd4j.create(syn0.rows()); - // FIXME: int cast - biasAdaGrad = new AdaGrad(ArrayUtil.toInts(bias.shape()), this.learningRate); + biasAdaGrad = new AdaGrad(bias.shape(), this.learningRate); // maxmemory = Runtime.getRuntime().maxMemory() - (vocabCache.numWords() * vectorLength * 2 * 8); @@ -237,15 +236,13 @@ public class GloVe implements ElementsLearningAlgorit private void update(T element1, INDArray wordVector, INDArray contextVector, double gradient) { //gradient for word vectors INDArray grad1 = contextVector.mul(gradient); - // FIXME: int cast - INDArray update = weightAdaGrad.getGradient(grad1, element1.getIndex(), ArrayUtil.toInts(syn0.shape())); + INDArray update = weightAdaGrad.getGradient(grad1, element1.getIndex(), syn0.shape()); //update vector wordVector.subi(update); double w1Bias = bias.getDouble(element1.getIndex()); - // FIXME: int cast - double biasGradient = biasAdaGrad.getGradient(gradient, element1.getIndex(), ArrayUtil.toInts(bias.shape())); + double biasGradient = biasAdaGrad.getGradient(gradient, element1.getIndex(), bias.shape()); double update2 = w1Bias - biasGradient; bias.putScalar(element1.getIndex(), update2); } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java index 84fc17b7e..bc404ac14 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java @@ -351,13 +351,13 @@ public class BasicModelUtils implements ModelUtils if (lookupTable instanceof InMemoryLookupTable) { InMemoryLookupTable l = (InMemoryLookupTable) lookupTable; INDArray syn0 = l.getSyn0(); - INDArray weights = syn0.norm2(0).rdivi(1).muli(words); + INDArray temp = syn0.norm2(0).rdivi(1).reshape(words.shape()); + INDArray weights = temp.muli(words); INDArray distances = syn0.mulRowVector(weights).sum(1); INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false); INDArray sort = sorted[0]; List ret = new ArrayList<>(); - // FIXME: int cast if (top > sort.length()) top = (int) sort.length(); //there will be a redundant word diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/GloveWeightLookupTable.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/GloveWeightLookupTable.java index cb6c48872..1cda50100 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/GloveWeightLookupTable.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/GloveWeightLookupTable.java @@ -72,7 +72,7 @@ public class GloveWeightLookupTable extends InMemoryL putVector(Word2Vec.DEFAULT_UNK, randUnk); } if (weightAdaGrad == null || reset) { - weightAdaGrad = new AdaGrad(new int[] {vocab.numWords() + 1, vectorLength}, lr.get()); + weightAdaGrad = new AdaGrad(new long[]{vocab.numWords() + 1, vectorLength}, lr.get()); } @@ -81,7 +81,7 @@ public class GloveWeightLookupTable extends InMemoryL bias = Nd4j.create(syn0.rows()); if (biasAdaGrad == null || reset) { - biasAdaGrad = new AdaGrad(ArrayUtil.toInts(bias.shape()), lr.get()); + biasAdaGrad = new AdaGrad(bias.shape(), lr.get()); } @@ -140,13 +140,13 @@ public class GloveWeightLookupTable extends InMemoryL private void update(T w1, INDArray wordVector, INDArray contextVector, double gradient) { //gradient for word vectors INDArray grad1 = contextVector.mul(gradient); - INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), ArrayUtil.toInts(syn0.shape())); + INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), syn0.shape()); //update vector wordVector.subi(update); double w1Bias = bias.getDouble(w1.getIndex()); - double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), ArrayUtil.toInts(bias.shape())); + double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), bias.shape()); double update2 = w1Bias - biasGradient; bias.putScalar(w1.getIndex(), update2); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java index 70d106a59..d15e961b7 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java @@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.nn.api.Model; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.function.Consumer; import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT; import org.nd4j.linalg.primitives.Pair; @@ -293,7 +294,8 @@ public class GradientCheckUtil { ss = n; } - // FIXME: int cast + if (ss > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); stepSizeForParam.put(paramNames.get(i), (int) ss); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertex.java index 6033f0030..86c0cdf76 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertex.java @@ -140,10 +140,9 @@ public class ElementWiseVertex extends GraphVertex { //CNN inputs... also check that the channels, width and heights match: InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first; - // FIXME: int cast - val fd = (int) firstConv.getChannels(); - val fw = (int) firstConv.getWidth(); - val fh = (int) firstConv.getHeight(); + val fd = firstConv.getChannels(); + val fw = firstConv.getWidth(); + val fh = firstConv.getHeight(); for (int i = 1; i < vertexInputs.length; i++) { if (vertexInputs[i].getType() != InputType.Type.CNN) { @@ -155,10 +154,9 @@ public class ElementWiseVertex extends GraphVertex { InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i]; - // FIXME: int cast - val od = (int) otherConv.getChannels(); - val ow = (int) otherConv.getWidth(); - val oh = (int) otherConv.getHeight(); + val od = otherConv.getChannels(); + val ow = otherConv.getWidth(); + val oh = otherConv.getHeight(); if (fd != od || fw != ow || fh != oh) { throw new InvalidInputTypeException( diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java index c76df66f6..77dd41c3a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java @@ -94,13 +94,12 @@ public class MergeVertex extends GraphVertex { // CNN3D inputs: check that the channels, width and height match: InputType.InputTypeConvolutional3D firstConv = (InputType.InputTypeConvolutional3D) first; - // FIXME: int cast - val fd = (int) firstConv.getDepth(); - val fw = (int) firstConv.getWidth(); - val fh = (int) firstConv.getHeight(); - val fc = (int) firstConv.getChannels(); + val fd = firstConv.getDepth(); + val fw = firstConv.getWidth(); + val fh = firstConv.getHeight(); + val fc = firstConv.getChannels(); - int depthSum = fc; + long depthSum = fc; InputType.InputTypeConvolutional3D otherConv = null; for (int i = 1; i < vertexInputs.length; i++) { if (vertexInputs[i].getType() != InputType.Type.CNN3D) { @@ -109,10 +108,10 @@ public class MergeVertex extends GraphVertex { } otherConv = (InputType.InputTypeConvolutional3D) vertexInputs[i]; - val od = (int) otherConv.getDepth(); - val ow = (int) otherConv.getWidth(); - val oh = (int) otherConv.getHeight(); - val oc = (int) otherConv.getChannels(); + val od = otherConv.getDepth(); + val ow = otherConv.getWidth(); + val oh = otherConv.getHeight(); + val oc = otherConv.getChannels(); if (fd != od || fw != ow || fh != oh) { throw new InvalidInputTypeException("Invalid input: MergeVertex cannot merge CNN3D activations of different width/heights:" + "first [channels,width,height] = [" + fd + "," + fw + "," + fh @@ -177,12 +176,11 @@ public class MergeVertex extends GraphVertex { //CNN inputs... also check that the channels, width and heights match: InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first; - // FIXME: int cast - val fd = (int) firstConv.getChannels(); - val fw = (int) firstConv.getWidth(); - val fh = (int) firstConv.getHeight(); + val fd = firstConv.getChannels(); + val fw = firstConv.getWidth(); + val fh = firstConv.getHeight(); - int depthSum = fd; + long depthSum = fd; for (int i = 1; i < vertexInputs.length; i++) { if (vertexInputs[i].getType() != InputType.Type.CNN) { @@ -194,10 +192,9 @@ public class MergeVertex extends GraphVertex { InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i]; - // FIXME: int cast - val od = (int) otherConv.getChannels(); - val ow = (int) otherConv.getWidth(); - val oh = (int) otherConv.getHeight(); + val od = otherConv.getChannels(); + val ow = otherConv.getWidth(); + val oh = otherConv.getHeight(); if (fw != ow || fh != oh) { throw new InvalidInputTypeException( diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/PoolHelperVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/PoolHelperVertex.java index 6e2213b4e..c5034129c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/PoolHelperVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/PoolHelperVertex.java @@ -131,12 +131,11 @@ public class PoolHelperVertex extends GraphVertex { //CNN inputs... also check that the channels, width and heights match: InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first; - // FIXME: int cast - val fd = (int) firstConv.getChannels(); - val fw = (int) firstConv.getWidth(); - val fh = (int) firstConv.getHeight(); + val fd = firstConv.getChannels(); + val fw = firstConv.getWidth(); + val fh = firstConv.getHeight(); - int depthSum = fd; + long depthSum = fd; for (int i = 1; i < vertexInputs.length; i++) { if (vertexInputs[i].getType() != InputType.Type.CNN) { @@ -148,10 +147,9 @@ public class PoolHelperVertex extends GraphVertex { InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i]; - // FIXME: int cast - int od = (int) otherConv.getChannels(); - int ow = (int) otherConv.getWidth(); - int oh = (int) otherConv.getHeight(); + long od = otherConv.getChannels(); + long ow = otherConv.getWidth(); + long oh = otherConv.getHeight(); if (fw != ow || fh != oh) { throw new InvalidInputTypeException( diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/UnstackVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/UnstackVertex.java index a5d6c72f4..910db350c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/UnstackVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/UnstackVertex.java @@ -150,12 +150,11 @@ public class UnstackVertex extends GraphVertex { //CNN inputs... also check that the channels, width and heights match: InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first; - // FIXME: int cast - val fd = (int) firstConv.getChannels(); - val fw = (int) firstConv.getWidth(); - val fh = (int) firstConv.getHeight(); + val fd = firstConv.getChannels(); + val fw = firstConv.getWidth(); + val fh = firstConv.getHeight(); - int depthSum = fd; + long depthSum = fd; for (int i = 1; i < vertexInputs.length; i++) { if (vertexInputs[i].getType() != InputType.Type.CNN) { @@ -167,10 +166,9 @@ public class UnstackVertex extends GraphVertex { InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i]; - // FIXME: int cast - val od = (int) otherConv.getChannels(); - val ow = (int) otherConv.getWidth(); - val oh = (int) otherConv.getHeight(); + val od = otherConv.getChannels(); + val ow = otherConv.getWidth(); + val oh = otherConv.getHeight(); if (fw != ow || fh != oh) { throw new InvalidInputTypeException( diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java index 85da86fa2..047618661 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java @@ -402,18 +402,17 @@ public abstract class InputType implements Serializable { //Note: ConvolutionalFlat and FeedForward look identical... but either should work OK if using something // like FeedForwardToCnnPreProcessor - // FIXME: int cast switch (inputArray.rank()) { case 2: - return InputType.feedForward((int) inputArray.size(1)); + return InputType.feedForward(inputArray.size(1)); case 3: - return InputType.recurrent((int) inputArray.size(1), (int) inputArray.size(2)); + return InputType.recurrent(inputArray.size(1), (int) inputArray.size(2)); case 4: //Order: [minibatch, channels, height, width] -> [h, w, c] - return InputType.convolutional((int) inputArray.size(2), (int) inputArray.size(3), (int) inputArray.size(1)); + return InputType.convolutional(inputArray.size(2), (int) inputArray.size(3), (int) inputArray.size(1)); case 5: //Order: [minibatch, channels, depth, height, width] -> [d, h, w, c] - return InputType.convolutional3D((int) inputArray.size(2), (int) inputArray.size(3), + return InputType.convolutional3D(inputArray.size(2), (int) inputArray.size(3), (int) inputArray.size(4), (int) inputArray.size(1)); default: throw new IllegalArgumentException( diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java index b73265763..1bde3d912 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java @@ -152,17 +152,18 @@ public class Cnn3DLossLayer extends FeedForwardLayer { } @Override - public void setNIn(int nIn){ + public void setNIn(long nIn){ throw new UnsupportedOperationException( "Cnn3DLossLayer has no parameters, thus nIn will always equal nOut."); } @Override - public void setNOut(int nOut){ + public void setNOut(long nOut){ throw new UnsupportedOperationException( "Cnn3DLossLayer has no parameters, thus nIn will always equal nOut."); } + @Override @SuppressWarnings("unchecked") public Cnn3DLossLayer build() { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java index 7b25ff797..3bcae0357 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java @@ -145,13 +145,13 @@ public class CnnLossLayer extends FeedForwardLayer { } @Override - public void setNIn(int nIn){ + public void setNIn(long nIn){ throw new UnsupportedOperationException( "This layer has no parameters, thus nIn will always equal nOut."); } @Override - public void setNOut(int nOut){ + public void setNOut(long nOut){ throw new UnsupportedOperationException( "This layer has no parameters, thus nIn will always equal nOut."); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java index b65f94d00..d4ccc4811 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java @@ -88,7 +88,7 @@ public class Convolution1DLayer extends ConvolutionLayer { //Probably: user did InputType.recurrent(x) without specifying sequence length outLength = -1; } else { - outLength = Convolution1DUtils.getOutputSize((int) inputTsLength, kernelSize[0], stride[0], padding[0], + outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0], convolutionMode, dilation[0]); } return InputType.recurrent(nOut, outLength); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java index b1bd5eacf..026f0d350 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java @@ -117,14 +117,14 @@ public abstract class FeedForwardLayer extends BaseLayer { * this is the input channels, otherwise is the previous layer size. * */ - protected int nIn = 0; + protected long nIn = 0; /** * Number of inputs for the layer (usually the size of the last layer).
Note that for Convolutional layers, * this is the input channels, otherwise is the previous layer size. * */ - protected int nOut = 0; + protected long nOut = 0; /** * Number of inputs for the layer (usually the size of the last layer).
Note that for Convolutional layers, @@ -144,8 +144,7 @@ public abstract class FeedForwardLayer extends BaseLayer { * @param nIn Number of inputs for the layer */ public T nIn(long nIn) { - // FIXME: int cast - this.setNIn((int) nIn); + this.setNIn(nIn); return (T) this; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java index c8ce1ffe0..7c97930ae 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java @@ -41,12 +41,9 @@ public class InputTypeUtil { Class layerClass) { InputType.InputTypeConvolutional i = (InputType.InputTypeConvolutional) inputType; - // FIXME: int cast - val hIn = (int) i.getHeight(); - val wIn = (int) i.getWidth(); + val hIn = i.getHeight(); + val wIn = i.getWidth(); - val inHeight = (int) i.getHeight(); - val inWidth = (int) i.getWidth(); int padH = (padding == null ? 0 : padding[0]); //May be null for ConvolutionMode.Same int padW = (padding == null ? 0 : padding[1]); int kH = kernelSize[0]; @@ -69,13 +66,13 @@ public class InputTypeUtil { } if (convolutionMode == ConvolutionMode.Same) { - int hOut = stride[0] * hIn; - int wOut = stride[1] * wIn; + long hOut = stride[0] * hIn; + long wOut = stride[1] * wIn; return InputType.convolutional(hOut, wOut, outputDepth); } - int hOut = sH * (hIn - 1) + kH - 2 * padH; - int wOut = sW * (wIn - 1) + kW - 2 * padW; + long hOut = sH * (hIn - 1) + kH - 2 * padH; + long wOut = sW * (wIn - 1) + kW - 2 * padW; return InputType.convolutional(hOut, wOut, outputDepth); } @@ -91,10 +88,9 @@ public class InputTypeUtil { InputType.InputTypeConvolutional3D i = (InputType.InputTypeConvolutional3D) inputType; - // FIXME: int cast - val inDepth = (int) i.getDepth(); - val inHeight = (int) i.getHeight(); - val inWidth = (int) i.getWidth(); + long inDepth = i.getDepth(); + long inHeight = i.getHeight(); + long inWidth = i.getWidth(); int padD = (padding == null ? 0 : padding[0]); int padH = (padding == null ? 0 : padding[1]); @@ -211,9 +207,9 @@ public class InputTypeUtil { return InputType.convolutional3D(outD, outH, outW, outputChannels); } - int dOut = (inDepth - kD + 2 * padD) / sD + 1; - int hOut = (inHeight - kH + 2 * padH) / sH + 1; - int wOut = (inWidth - kW + 2 * padW) / sW + 1; + long dOut = (inDepth - kD + 2 * padD) / sD + 1; + long hOut = (inHeight - kH + 2 * padH) / sH + 1; + long wOut = (inWidth - kW + 2 * padW) / sW + 1; return InputType.convolutional3D(dOut, hOut, wOut, outputChannels); } @@ -296,9 +292,8 @@ public class InputTypeUtil { InputType.InputTypeConvolutional i = (InputType.InputTypeConvolutional) inputType; - // FIXME: int cast - val inHeight = (int) i.getHeight(); - val inWidth = (int) i.getWidth(); + long inHeight = i.getHeight(); + long inWidth = i.getWidth(); int padH = (padding == null ? 0 : padding[0]); //May be null for ConvolutionMode.Same int padW = (padding == null ? 0 : padding[1]); int kH = kernelSize[0]; @@ -379,8 +374,8 @@ public class InputTypeUtil { return InputType.convolutional(outH, outW, outputDepth); } - int hOut = (inHeight - kH + 2 * padH) / sH + 1; - int wOut = (inWidth - kW + 2 * padW) / sW + 1; + long hOut = (inHeight - kH + 2 * padH) / sH + 1; + long wOut = (inWidth - kW + 2 * padW) / sW + 1; return InputType.convolutional(hOut, wOut, outputDepth); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java index 4787d1082..fc805f0ca 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java @@ -145,7 +145,7 @@ public class LocallyConnected1D extends SameDiffLayer { val weightsShape = new long[] {outputSize, featureDim, nOut}; params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape); if (hasBias) { - val biasShape = new long[] {1, nOut}; + val biasShape = new long[] {nOut}; params.addBiasParam(ConvolutionParamInitializer.BIAS_KEY, biasShape); } } @@ -200,7 +200,7 @@ public class LocallyConnected1D extends SameDiffLayer { if (hasBias) { SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY); - SDVariable biasAddedResult = sameDiff.nn().biasAdd(result, b); + SDVariable biasAddedResult = sameDiff.nn().biasAdd(result, b, true); return activation.asSameDiff("out", sameDiff, biasAddedResult); } else { return activation.asSameDiff("out", sameDiff, result); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java index 5426fda9b..ef07c9dc5 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java @@ -145,7 +145,7 @@ public class LocallyConnected2D extends SameDiffLayer { val weightsShape = new long[] {outputSize[0] * outputSize[1], featureDim, nOut}; params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape); if (hasBias) { - val biasShape = new long[] {1, nOut}; + val biasShape = new long[] {nOut}; params.addBiasParam(ConvolutionParamInitializer.BIAS_KEY, biasShape); } } @@ -211,7 +211,7 @@ public class LocallyConnected2D extends SameDiffLayer { if (hasBias) { SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY); - SDVariable biasAddedResult = sameDiff.nn().biasAdd(permutedResult, b); + SDVariable biasAddedResult = sameDiff.nn().biasAdd(permutedResult, b, true); return activation.asSameDiff("out", sameDiff, biasAddedResult); } else { return activation.asSameDiff("out", sameDiff, permutedResult); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java index 209c61bca..df0b16e6c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java @@ -142,13 +142,13 @@ public class RnnLossLayer extends FeedForwardLayer { } @Override - public void setNIn(int nIn){ + public void setNIn(long nIn){ throw new UnsupportedOperationException( "This layer has no parameters, thus nIn will always equal nOut."); } @Override - public void setNOut(int nOut){ + public void setNOut(long nOut){ throw new UnsupportedOperationException( "This layer has no parameters, thus nIn will always equal nOut."); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java index de491290f..4da7ff011 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java @@ -82,12 +82,12 @@ public class Subsampling1DLayer extends SubsamplingLayer { } InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; long inputTsLength = r.getTimeSeriesLength(); - int outLength; + long outLength; if (inputTsLength < 0) { //Probably: user did InputType.recurrent(x) without specifying sequence length outLength = -1; } else { - outLength = Convolution1DUtils.getOutputSize((int) inputTsLength, kernelSize[0], stride[0], padding[0], + outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0], convolutionMode, dilation[0]); } return InputType.recurrent(r.getSize(), outLength); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java index 0d0ccba9b..2fcc345a1 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java @@ -32,6 +32,7 @@ import org.deeplearning4j.util.ValidationUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.learning.regularization.Regularization; import java.util.Collection; @@ -138,9 +139,11 @@ public class Subsampling3DLayer extends NoParamLayer { + "\"): Expected CNN input, got " + inputType); } - // FIXME: int cast + long inChannels = ((InputType.InputTypeConvolutional3D) inputType).getChannels(); + if (inChannels > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); return InputTypeUtil.getOutputTypeCnn3DLayers(inputType, kernelSize, stride, padding, new int[] {1, 1, 1}, // no dilation - convolutionMode, (int) ((InputType.InputTypeConvolutional3D) inputType).getChannels(), + convolutionMode, (int) inChannels, layerIndex, getLayerName(), Subsampling3DLayer.class); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java index d142d52a9..24db83bd3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java @@ -83,11 +83,10 @@ public class Upsampling3D extends BaseUpsamplingLayer { } InputType.InputTypeConvolutional3D i = (InputType.InputTypeConvolutional3D) inputType; - // FIXME: int cast - int inHeight = (int) i.getHeight(); - int inWidth = (int) i.getWidth(); - int inDepth = (int) i.getDepth(); - int inChannels = (int) i.getChannels(); + long inHeight = (int) i.getHeight(); + long inWidth = (int) i.getWidth(); + long inDepth = (int) i.getDepth(); + long inChannels = (int) i.getChannels(); return InputType.convolutional3D(size[0] * inDepth, size[1] * inHeight, size[2] * inWidth, inChannels); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java index f10ae5bad..6890f83e8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java @@ -65,7 +65,7 @@ public abstract class SameDiffLambdaVertex extends SameDiffVertex { defineVertex(temp, tempInputs); List list = new ArrayList<>(); for (Integer i : tempInputs.map.keySet()) { - list.add(tempInputs.map.get(i).getVarName()); + list.add(tempInputs.map.get(i).name()); } params.defineInputs(list.toArray(new String[list.size()])); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java index f76cb0dad..539289eca 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java @@ -259,7 +259,7 @@ public class OCNNOutputLayer extends BaseOutputLayer { } @Override - public void setNOut(int nOut){ + public void setNOut(long nOut){ throw new UnsupportedOperationException( "Unable to specify number of outputs with ocnn. Outputs are fixed to 1."); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java index 7c292bafa..32d7bfb73 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java @@ -79,6 +79,7 @@ import org.nd4j.linalg.dataset.api.DataSetUtil; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.heartbeat.Heartbeat; import org.nd4j.linalg.heartbeat.reports.Environment; @@ -3329,7 +3330,6 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { //In 99+% of cases, the input and labels dimension 0 size should be identical //The only real exceptions: space to batch, and batch to space layers //In those cases, we should base it on the labels size, as this impacts gradient calculation - // FIXME: int cast return labels == null || labels[0] == null ? (int) inputs[0].size(0) : (int)labels[0].size(0); } @@ -3653,7 +3653,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { if (endTimeIdx > timeSeriesLength) endTimeIdx = timeSeriesLength; - // FIXME: int cast + if (startTimeIdx > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); List list = getSubsetsForTbptt((int) startTimeIdx, endTimeIdx, inputs, labels, featureMasks, labelMasks); setInputs(list.get(0)); @@ -3799,9 +3800,10 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } } - // FIXME: int cast + if (minibatchSize > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); Pair outPair = - current.feedForwardMaskArrays(inputMasks, maskState, (int) minibatchSize); + current.feedForwardMaskArrays(inputMasks, maskState, (int)minibatchSize); map.put(topologicalOrder[i], outPair); } } @@ -4664,7 +4666,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { * @param layer Index of the layer to get the size of. Must be in range 0 to nLayers-1 inclusive * @return Size of the layer */ - public int layerSize(int layer) { + public long layerSize(int layer) { if (layer < 0 || layer > layers.length) { throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and " + (layers.length - 1) + " inclusive"); @@ -4683,7 +4685,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { * @param layer Index of the layer to get the size of. Must be in range 0 to nLayers-1 inclusive * @return Size of the layer */ - public int layerInputSize(int layer) { + public long layerInputSize(int layer) { if (layer < 0 || layer > layers.length) { throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and " + (layers.length - 1) + " inclusive"); @@ -4701,7 +4703,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { * @param layerName Name of the layer to get the size of * @return Size of the layer */ - public int layerSize(String layerName) { + public long layerSize(String layerName) { Layer l = getLayer(layerName); if(l == null){ throw new IllegalArgumentException("No layer with name \"" + layerName + "\" exists"); @@ -4712,8 +4714,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } FeedForwardLayer ffl = (FeedForwardLayer) conf; - // FIXME: int cast - return (int) ffl.getNOut(); + return ffl.getNOut(); } /** @@ -4727,7 +4728,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { * @param layerName Name of the layer to get the size of * @return Size of the layer */ - public int layerInputSize(String layerName) { + public long layerInputSize(String layerName) { Layer l = getLayer(layerName); if(l == null){ throw new IllegalArgumentException("No layer with name \"" + layerName + "\" exists"); @@ -4738,8 +4739,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } FeedForwardLayer ffl = (FeedForwardLayer) conf; - // FIXME: int cast - return (int) ffl.getNIn(); + return ffl.getNIn(); } /** diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java index 73cd7db4d..cb58a9813 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java @@ -114,7 +114,7 @@ public class MergeVertex extends BaseGraphVertex { } try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)){ - return Nd4j.hstack(in); + return Nd4j.concat(1, in); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java index 4ca04c418..009505057 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java @@ -43,10 +43,10 @@ import java.util.Arrays; * @author Justin Long (crockpotveggies) */ public class UnstackVertex extends BaseGraphVertex { - private int from; + private long from; private int stackSize; private long forwardShape[]; - private int step; + private long step; public UnstackVertex(ComputationGraph graph, String name, int vertexIndex, int from, int stackSize, DataType dataType) { this(graph, name, vertexIndex, null, null, from, stackSize, dataType); @@ -77,10 +77,9 @@ public class UnstackVertex extends BaseGraphVertex { // once we know the inputs, save the shape and interval size for doBackward this.forwardShape = Arrays.copyOf(inputs[0].shape(), inputs[0].rank()); - // FIXME: int cast - this.step = (int) inputs[0].size(0) / stackSize; - int start = from * step; - int end = (from + 1) * step; + this.step = inputs[0].size(0) / stackSize; + long start = from * step; + long end = (from + 1) * step; INDArray ret; switch (inputs[0].rank()) { //TODO remove the dups here if/when possible (gradient checks must pass) @@ -108,8 +107,8 @@ public class UnstackVertex extends BaseGraphVertex { throw new IllegalStateException("Cannot do backward pass: error not set"); INDArray out = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, inputs[0].dataType(), forwardShape); - int start = from * step; - int end = (from + 1) * step; + long start = from * step; + long end = (from + 1) * step; switch (forwardShape.length) { case 2: @@ -154,8 +153,8 @@ public class UnstackVertex extends BaseGraphVertex { } //Mask arrays are either 1d (column vector) or 2d... - int start = from * minibatchSize; - int end = (from + 1) * minibatchSize; + long start = from * minibatchSize; + long end = (from + 1) * minibatchSize; INDArray outMask = maskArrays[0].get(NDArrayIndex.interval(start, end), NDArrayIndex.all()); return new Pair<>(outMask, currentMaskState); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java index a1dd47469..797676441 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java @@ -87,9 +87,8 @@ public class LastTimeStepVertex extends BaseGraphVertex { INDArray out; if (mask == null) { - // FIXME: int cast //No mask array -> extract same (last) column for all - int lastTS = (int) inputs[0].size(2) - 1; + long lastTS = inputs[0].size(2) - 1; out = inputs[0].get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(lastTS)); out = workspaceMgr.dup(ArrayType.ACTIVATIONS, out); fwdPassTimeSteps = null; //Null -> last time step for all examples @@ -99,8 +98,7 @@ public class LastTimeStepVertex extends BaseGraphVertex { //Want the index of the last non-zero entry in the mask array. //Check a little here by using mulRowVector([0,1,2,3,...]) and argmax - // FIXME: int cast - int maxTsLength = (int) fwdPassShape[2]; + long maxTsLength = fwdPassShape[2]; INDArray row = Nd4j.linspace(0, maxTsLength - 1, maxTsLength, mask.dataType()); INDArray temp = mask.mulRowVector(row); INDArray lastElementIdx = Nd4j.argMax(temp, 1); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java index fdf12d2f8..750bca77d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java @@ -346,7 +346,6 @@ public abstract class AbstractLayer f mmul here, then reshape to 6d in f order INDArray epsNext2d = w2d.mmul(delta2d); //TODO can we reuse im2col array instead of allocating new result array? - INDArray eps6d = Shape.newShapeNoCopy(epsNext2d, new int[] {kW, kH, inDepth, outW, outH, miniBatch}, true); + INDArray eps6d = Shape.newShapeNoCopy(epsNext2d, new long[] {kW, kH, inDepth, outW, outH, miniBatch}, true); //Calculate epsilonNext by doing im2col reduction. //Current col2im implementation expects input with order: [miniBatch,channels,kH,kW,outH,outW] @@ -282,7 +282,7 @@ public class ConvolutionLayer extends BaseLayer Integer.MAX_VALUE || input.size(3) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) input.size(2), (int) input.size(3)}, kernel, strides, dilation ); } else { @@ -397,10 +397,12 @@ public class ConvolutionLayer extends BaseLayer Integer.MAX_VALUE || kW > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); + Convolution.im2col(im2ColIn, (int)kH, (int)kW, strides[0], strides[1], pad[0], pad[1], dilation[0], dilation[1], convolutionMode == ConvolutionMode.Same, col2); - INDArray im2col2d = Shape.newShapeNoCopy(col, new int[] {miniBatch * outH * outW, inDepth * kH * kW}, false); + INDArray im2col2d = Shape.newShapeNoCopy(col, new long[] {miniBatch * outH * outW, inDepth * kH * kW}, false); //Current order of weights: [depthOut,depthIn,kH,kW], c order //Permute to give [kW,kH,depthIn,depthOut], f order @@ -418,7 +420,7 @@ public class ConvolutionLayer extends BaseLayer Integer.MAX_VALUE || input.size(3) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } pad = ConvolutionUtils.getSameModeTopLeftPadding( outSize, new int[]{(int) input.size(2), (int) input.size(3)}, kernel, strides, dilation); } else { @@ -205,8 +206,8 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); } - int outH = outSize[0]; - int outW = outSize[1]; + long outH = outSize[0]; + long outW = outSize[1]; val miniBatch = input.size(0); INDArray output = workspaceMgr.create( diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java index 422a253d2..9808b3a24 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java @@ -33,6 +33,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -90,10 +91,9 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { INDArray input = this.input.castTo(dataType); - // FIXME: int cast - int miniBatch = (int) input.size(0); - int inH = (int) input.size(2); - int inW = (int) input.size(3); + long miniBatch = input.size(0); + int inH = (int)input.size(2); + int inW = (int)input.size(3); int inDepth = (int) depthWiseWeights.size(1); int kH = (int) depthWiseWeights.size(2); @@ -194,9 +194,8 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { + " " + layerId()); } - // FIXME: int cast - int inDepth = (int) depthWiseWeights.size(1); - int outDepth = (int) pointWiseWeights.size(0); + long inDepth = depthWiseWeights.size(1); + long outDepth = pointWiseWeights.size(0); if (input.size(1) != inDepth) { String layerName = conf.getLayer().getLayerName(); @@ -220,7 +219,9 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { if (convolutionMode == ConvolutionMode.Same) { outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation - // FIXME: int cast + if (input.size(2) > Integer.MAX_VALUE || input.size(3) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) input.size(2), (int) input.size(3)}, kernel, strides, dilation ); } else { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepth.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepth.java index b726ea87c..50ea9c9e3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepth.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepth.java @@ -75,11 +75,10 @@ public class SpaceToDepth extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); - // FIXME: int cast - int miniBatch = (int) input.size(0); - int inDepth = (int) input.size(1); - int inH = (int) input.size(2); - int inW = (int) input.size(3); + long miniBatch = input.size(0); + long inDepth = input.size(1); + long inH = input.size(2); + long inW = input.size(3); INDArray input = this.input.castTo(dataType); //No-op if already correct type @@ -122,17 +121,16 @@ public class SpaceToDepth extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); - // FIXME: int cast - int miniBatch = (int) input.size(0); - int inDepth = (int) input.size(1); - int inH = (int) input.size(2); - int inW = (int) input.size(3); + long miniBatch = (int) input.size(0); + long inDepth = (int) input.size(1); + long inH = (int) input.size(2); + long inW = (int) input.size(3); INDArray reshapedEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), new long[]{miniBatch, inDepth, inH, inW}, 'c'); @@ -106,15 +105,14 @@ public class Upsampling2D extends AbstractLayer Integer.MAX_VALUE) + throw new ND4JArraySizeException(); int[] indexes = new int[(int) input.length()]; for (int i = 0; i < indexes.length; i++) { indexes[i] = input.getInt(i, 0); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java index 1dcc556b6..0d9ae18e7 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java @@ -16,21 +16,28 @@ package org.deeplearning4j.nn.layers.mkldnn; +import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper; +import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm; +import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative; import org.nd4j.linalg.api.ops.impl.summarystats.Variance; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; import java.util.Map; /** @@ -56,33 +63,59 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper { } @Override - public Pair backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, - INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr) { - //2019-02-14: Backprop disabled pending fixes. https://github.com/deeplearning4j/deeplearning4j/issues/7166 - //Also no MKL-DNN implemented for backprop anyway + public Pair backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, + INDArray beta, INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr) { + if(input.dataType() != DataType.FLOAT) + return null; //MKL-DNN only supports float /* - INDArray[] in = gamma == null ? new INDArray[]{input, mean, var, epsilon} : new INDArray[]{input, mean, var, gamma, beta, epsilon}; + //TODO FIXME - AB 2019/11/01 - https://github.com/eclipse/deeplearning4j/issues/8335 + List args = new ArrayList<>(); + args.add(input); + args.add(meanCache); + args.add(varCache); + args.add(epsilon); + if(gamma != null) + args.add(gamma.reshape(gamma.length())); + if(beta != null) + args.add(beta.reshape(beta.length())); - INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), input.shape()); - INDArray[] out = gamma == null ? new INDArray[]{gradAtInput, } - - BatchNormDerivative bn = BatchNormDerivative.derivativeBuilder() - .applyBeta(gamma != null) - .applyGamma(gamma != null) - .axis(new int[]{1}) //4d: is channels: NCHW; 2d: is nIn - axis 1 in both cases - .epsilon(eps) - .inputArrays(in) - .outputArrays(new INDArray[]{out}) + DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp") + .addInputs(args.toArray(new INDArray[0])) + .addIntegerArguments( + gamma == null ? 0 : 1, //Apply scale + beta == null ? 0 : 1, //Apply beta + 1) //Axis (NCHW) + .addFloatingPointArguments(eps) .build(); - Nd4j.exec(bn); - */ + INDArray epsAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape()); + INDArray dLdm = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, meanCache.dataType(), meanCache.shape()); + INDArray dLdv = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, meanCache.dataType(), meanCache.shape()); + + op.setOutputArgument(0, epsAtInput); + op.setOutputArgument(1, dLdm); + op.setOutputArgument(2, dLdv); + if(dGammaView != null) { + //Both are always null/not null simultaneously + op.setOutputArgument(3, dGammaView.reshape(dGammaView.length())); + op.setOutputArgument(4, dBetaView.reshape(dBetaView.length())); + } + + + Nd4j.exec(op); + + Gradient g = new DefaultGradient(); + g.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView); + g.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView); + + return new Pair<>(g, epsAtInput); + */ return null; } @Override - public INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var, + public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr) { if(x.dataType() != DataType.FLOAT) return null; //MKL-DNN only supports float diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLSTMHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLSTMHelper.java new file mode 100644 index 000000000..ed9a2ef15 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLSTMHelper.java @@ -0,0 +1,168 @@ +package org.deeplearning4j.nn.layers.mkldnn; + +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LSTM; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.layers.recurrent.FwdPassReturn; +import org.deeplearning4j.nn.layers.recurrent.LSTMHelper; +import org.deeplearning4j.nn.workspace.ArrayType; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.activations.impl.*; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.BooleanIndexing; +import org.nd4j.linalg.indexing.conditions.Conditions; +import org.nd4j.linalg.primitives.Pair; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class MKLDNNLSTMHelper implements LSTMHelper { + @Override + public boolean checkSupported(IActivation gateActivationFn, IActivation activationFn, boolean hasPeepholeConnections) { + //TODO check other activation functions for MKLDNN + return gateActivationFn instanceof ActivationSigmoid && activationFn instanceof ActivationTanH && BaseMKLDNNHelper.mklDnnEnabled(); + } + + @Override + public Pair backpropGradient(NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input, + INDArray recurrentWeights, INDArray inputWeights, INDArray epsilon, boolean truncatedBPTT, + int tbpttBackwardLength, FwdPassReturn fwdPass, boolean forwards, String inputWeightKey, + String recurrentWeightKey, String biasWeightKey, Map gradientViews, + INDArray maskArray, boolean hasPeepholeConnections, LayerWorkspaceMgr workspaceMgr) { + //Not yet implemented/supported + return null; + } + + @Override + public FwdPassReturn activate(Layer layer, NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input, + INDArray recurrentWeights, INDArray inputWeights, INDArray biases, boolean training, + INDArray prevOutputActivations, INDArray prevMemCellState, boolean forBackprop, boolean forwards, + String inputWeightKey, INDArray maskArray, boolean hasPeepholeConnections, LayerWorkspaceMgr workspaceMgr) { + + /* + DL4J data format: [bS, nIn, sL] - dataFormat == 2, directionMode == 0 (forward) + Inputs: + x = [bS, nIn, sL] + Wx = [nIn, 4*nOut] + Wr = [nOut, 4*nOut] + Wp = [3*nOut] Optional peephole weights + b = [4*nOut] + seqLen = [bS] + initialOut = [bs, nOut] + initialCell = [bs, nOut] + + Outputs: + out = [bS, nOut, sL] + outLast = [bs, nOut] + cellLast = [bs,nOut] + + Gates order: input, forget, input modulation, output + + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = B_ARG(4); // indicates whether peephole connections are present + const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} + const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + */ + + INDArray b1d = biases.reshape(biases.length()); + INDArray seqLen = null; + if(maskArray != null){ + seqLen = BooleanIndexing.firstIndex(maskArray, Conditions.equals(0), 1); //First 0 along dimension 1 (for [mb, seqLen]) + } + + List args = new ArrayList<>(); + args.add(input); + args.add(inputWeights); + args.add(recurrentWeights); + if(hasPeepholeConnections){ + throw new IllegalStateException("Not yet implemented"); + } + args.add(b1d); + if(seqLen != null) + args.add(seqLen); + if(prevOutputActivations != null) + args.add(prevOutputActivations); + if(prevMemCellState != null) + args.add(prevMemCellState); + + IActivation a = ((LSTM)conf.getLayer()).getActivationFn(); + + DynamicCustomOp op = DynamicCustomOp.builder("lstmLayer") + .addInputs(args.toArray(new INDArray[0])) + .addBooleanArguments( + true, //hasBiases + seqLen != null, //hasSeqLen + prevOutputActivations != null, //hasInitH + prevMemCellState != null, //hasInitC + hasPeepholeConnections, //hasPh + true, //retFullSeq + true, //retLastH + true //retLastC + ) + .addIntegerArguments( + 2, //data format: 2 = [bS, nIn, sL] + 0, //direction: 0 = forward + activationToArg(gateActivationFn), //Gate activation + activationToArg(a), //Cell state activation + activationToArg(a) //Output activation (same as cell in DL4J) + ) + .build(); + + List outShapes = op.calculateOutputShape(); + + for(LongShapeDescriptor lsd : outShapes){ + INDArray arr = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, lsd.dataType(), lsd.getShape(), lsd.getOrder()); + op.addOutputArgument(arr); + } + + FwdPassReturn f = new FwdPassReturn(); + f.fwdPassOutput = op.getOutputArgument(0); + f.lastAct = op.getOutputArgument(1); + f.lastMemCell = op.getOutputArgument(2); + + return f; + } + + @Override + public Map helperMemoryUse() { + return Collections.emptyMap(); + } + + private int activationToArg(IActivation a){ + //0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus + if(a instanceof ActivationTanH) + return 0; + if(a instanceof ActivationReLU) + return 1; + if(a instanceof ActivationSigmoid) + return 2; + if(a instanceof ActivationIdentity) + return 3; + if(a instanceof ActivationLReLU) + return 4; + if(a instanceof ActivationThresholdedReLU) + return 5; + if(a instanceof ActivationHardSigmoid) + return 7; + if(a instanceof ActivationELU) + return 8; + if(a instanceof ActivationSoftSign) + return 9; + if(a instanceof ActivationSoftPlus) + return 10; + throw new IllegalStateException("Unknown or not supported activation function: " + a); + } +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java index d5a4c75af..8c8f329ea 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java @@ -118,6 +118,7 @@ public class BatchNormalization extends BaseLayer ret = null; try { - ret = helper.backpropGradient(in, eps, ArrayUtil.toInts(shape), gamma, dGammaView, dBetaView, + ret = helper.backpropGradient(in, eps, shape, gamma, beta, dGammaView, dBetaView, layerConf.getEps(), workspaceMgr); } catch (ND4JOpProfilerException e){ throw e; //NaN panic etc for debugging } catch (Throwable t){ - if(t.getMessage().contains("Failed to allocate")){ + if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){ //This is a memory exception - don't fallback to built-in implementation throw t; } @@ -438,7 +439,6 @@ public class BatchNormalization extends BaseLayer backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, - INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr); + Pair backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta, + INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr); - INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray gamma, INDArray beta, INDArray mean, + INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr); INDArray getMeanCache(DataType dataType); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java index 3191d1dda..fe482ad62 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java @@ -144,7 +144,7 @@ public class LocalResponseNormalization } catch (ND4JOpProfilerException e){ throw e; //NaN panic etc for debugging } catch (Throwable t){ - if(t.getMessage().contains("Failed to allocate")){ + if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){ //This is a memory exception - don't fallback to built-in implementation throw t; } @@ -211,7 +211,7 @@ public class LocalResponseNormalization } catch (ND4JOpProfilerException e){ throw e; //NaN panic etc for debugging } catch (Throwable t){ - if(t.getMessage().contains("Failed to allocate")){ + if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){ //This is a memory exception - don't fallback to built-in implementation throw t; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java index bb551cd3f..3e46fd044 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java @@ -114,10 +114,9 @@ public class Yolo2OutputLayer extends AbstractLayer C = (input.size(1)/b) - 5 + long mb = input.size(0); + long h = input.size(2); + long w = input.size(3); + long b = boundingBoxPriors.size(0); + long c = input.size(1)/b-5; //input.size(1) == b * (5 + C) -> C = (input.size(1)/b) - 5 INDArray output = layerWorkspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), input.shape(), 'c'); INDArray output5 = output.reshape('c', mb, b, 5+c, h, w); @@ -77,7 +76,7 @@ public class YoloUtils { //TODO OPTIMIZE? INDArray inputClassesPreSoftmax = input5.get(all(), all(), interval(5, 5+c), all(), all()); //Shape: [minibatch, C, H, W] INDArray classPredictionsPreSoftmax2d = inputClassesPreSoftmax.permute(0,1,3,4,2) //[minibatch, b, c, h, w] To [mb, b, h, w, c] - .dup('c').reshape('c', new int[]{mb*b*h*w, c}); + .dup('c').reshape('c', new long[]{mb*b*h*w, c}); Transforms.softmax(classPredictionsPreSoftmax2d, false); INDArray postSoftmax5d = classPredictionsPreSoftmax2d.reshape('c', mb, b, h, w, c ).permute(0, 1, 4, 2, 3); @@ -173,13 +172,12 @@ public class YoloUtils { throw new IllegalStateException("Invalid confidence threshold: must be in range [0,1]. Got: " + confThreshold); } - // FIXME: int cast //Activations format: [mb, 5b+c, h, w] - int mb = (int) networkOutput.size(0); - int h = (int) networkOutput.size(2); - int w = (int) networkOutput.size(3); - int b = (int) boundingBoxPriors.size(0); - int c = (int) (networkOutput.size(1)/b)-5; //input.size(1) == b * (5 + C) -> C = (input.size(1)/b) - 5 + long mb = networkOutput.size(0); + long h = networkOutput.size(2); + long w = networkOutput.size(3); + long b = boundingBoxPriors.size(0); + long c = (networkOutput.size(1)/b)-5; //input.size(1) == b * (5 + C) -> C = (input.size(1)/b) - 5 //Reshape from [minibatch, B*(5+C), H, W] to [minibatch, B, 5+C, H, W] to [minibatch, B, 5, H, W] INDArray output5 = networkOutput.dup('c').reshape(mb, b, 5+c, h, w); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTM.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTM.java index c3e77bb99..692713f6e 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTM.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTM.java @@ -22,6 +22,8 @@ import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.LayerHelper; +import org.deeplearning4j.nn.layers.mkldnn.BaseMKLDNNHelper; +import org.deeplearning4j.nn.layers.mkldnn.MKLDNNLSTMHelper; import org.deeplearning4j.nn.params.LSTMParamInitializer; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; @@ -73,6 +75,16 @@ public class LSTM extends BaseRecurrentLayer Integer.MAX_VALUE)) || + recurrentWeights.size(0) > Integer.MAX_VALUE || input.size(0) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); int timeSeriesLength = (int) (is2dInput ? 1 : input.size(2)); int hiddenLayerSize = (int) recurrentWeights.size(0); int miniBatchSize = (int) input.size(0); @@ -550,7 +553,8 @@ public class LSTMHelpers { for (long iTimeIndex = timeSeriesLength - 1; iTimeIndex >= endIdx; iTimeIndex--) { try(MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.RNN_BP_LOOP_WORKING_MEM)) { - // FIXME: int cast + if (iTimeIndex > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); int time = (int) iTimeIndex; int inext = 1; @@ -574,8 +578,6 @@ public class LSTMHelpers { (iTimeIndex == 0 ? fwdPass.prevAct : fwdPass.fwdPassOutputAsArrays[(int) (time - inext)]); INDArray currMemCellState = fwdPass.memCellState[(int) time]; - - // FIXME: int cast //LSTM unit output errors (dL/d(a_out)); not to be confused with \delta=dL/d(z_out) INDArray epsilonSlice = (is2dInput ? epsilon : epsilon.tensorAlongDimension((int) time, 1, 0)); //(w^{L+1}*(delta^{(L+1)t})^T)^T or equiv. diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java index 82a599acb..dd1b03d63 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java @@ -89,8 +89,7 @@ public class RnnLossLayer extends BaseLayer detached memory + public DL4JSameDiffMemoryMgr(String workingMemoryWs, String outputWs, WorkspaceConfiguration confWorking, + WorkspaceConfiguration confOutput){ + this.workingMemoryWs = workingMemoryWs; + this.outputWs = outputWs; + this.confWorking = confWorking; + this.confOutput = confOutput; + } + + + @Override + public INDArray allocate(boolean detached, DataType dataType, long... shape) { + String wsName = detached ? outputWs : workingMemoryWs; + WorkspaceConfiguration wsConf = detached ? confOutput : confWorking; + + if(wsName == null){ + //Scoped out + INDArray ret = Nd4j.createUninitializedDetached(dataType, shape); + Preconditions.checkState(!ret.isAttached(), "Returned array should be detached"); + return ret; + } else { + MemoryWorkspace ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(wsConf, wsName); + try (MemoryWorkspace mw = ws.notifyScopeBorrowed()) { + return Nd4j.createUninitialized(dataType, shape); + } + } + } + + @Override + public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) { + return allocate(detached, descriptor.dataType(), descriptor.getShape()); + } + + @Override + public void release(INDArray array) { + //No-op - DL4J workspaces handles this + } + + @Override + public void close() { + //No-op - DL4J workspaces handles this + } +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java index f1f4b536d..1d2abe2b6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java @@ -31,9 +31,12 @@ import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.InferenceSession; +import org.nd4j.autodiff.samediff.internal.SessionMemMgr; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; +import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; import org.nd4j.linalg.factory.Nd4j; @@ -95,113 +98,160 @@ public class SameDiffGraphVertex extends BaseGraphVertex { @Override public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) { try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { - if(sameDiff == null){ + if (sameDiff == null) { doInit(); } - - Map phMap = new HashMap<>(); - config.validateInput(inputs); - for(int i=0; i 0) { - //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration - //TODO Find a more efficient solution for this - for (Map.Entry e : paramTable.entrySet()) { - INDArray arr = e.getValue(); - sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); - } - } - INDArray result = sameDiff.outputSingle(phMap, outputKey); - - //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere - sameDiff.clearPlaceholders(true); - sameDiff.clearOpInputs(); - return workspaceMgr.dup(ArrayType.ACTIVATIONS, result); } - } - @Override - public Pair doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) { - Gradient g = new DefaultGradient(); - - INDArray[] dLdIns; - try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){ - if(sameDiff == null){ - doInit(); + Map phMap = new HashMap<>(); + config.validateInput(inputs); + for(int i=0; i inputNames = config.getVertexParams().getInputs(); - if(!sameDiff.hasGradientFunction()) { - //Create when scoped out, to ensure any arrays are not in WS - String[] inArr = inputNames.toArray(new String[inputNames.size()]); - sameDiff.createGradFunction(inArr); - } - config.validateInput(inputs); - Map phMap = new HashMap<>(); - List inputs = config.getVertexParams().getInputs(); - int i=0; - for(String s : inputs){ - phMap.put(s, this.inputs[i++]); - } - for( int j=0; j 0) { //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration //TODO Find a more efficient solution for this for (Map.Entry e : paramTable.entrySet()) { INDArray arr = e.getValue(); sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); } + } + INDArray result = sameDiff.outputSingle(phMap, outputKey); - List required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated - for(String s : inputNames){ - required.add(sameDiff.getVariable(s).gradient().getVarName()); - } - sameDiff.execBackwards(phMap, required); - for(String s : paramTable.keySet() ){ - INDArray sdGrad = sameDiff.grad(s).getArr(); - INDArray dl4jGrad = gradTable.get(s); - dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS - g.gradientForVariable().put(s, dl4jGrad); - } + //Edge case: "vertex" is just an identity activation, for example + //TODO there may be a cleaner way to do this... + if(!actScopedOut && !result.data().getParentWorkspace().getId().equals(wsNameOutput)){ + result = workspaceMgr.dup(ArrayType.ACTIVATIONS, result); + } else if(actScopedOut && result.isAttached()){ + result = result.detach(); + } - dLdIns = new INDArray[inputs.size()]; - String fnName = fn.getGradPlaceholderName(); - for(int j=0; j doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) { + Gradient g = new DefaultGradient(); + + try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { + if (sameDiff == null) { + doInit(); } } - //TODO optimize - for( int i=0; i inputNames = config.getVertexParams().getInputs(); + if(!sameDiff.hasGradientFunction()) { + //Create when scoped out, to ensure any arrays are not in WS + String[] inArr = inputNames.toArray(new String[inputNames.size()]); + sameDiff.createGradFunction(inArr); + } + config.validateInput(inputs); + + //Configure memory management for SameDiff instance - use DL4J workspaces + Map sessionMap = sameDiff.getFunction("grad").getSessions(); + if(!sessionMap.containsKey(Thread.currentThread().getId())){ + sessionMap.put(Thread.currentThread().getId(), new InferenceSession(sameDiff.getFunction("grad"))); + } + String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.BP_WORKING_MEM); + String wsNameActGrad = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATION_GRAD); + WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.BP_WORKING_MEM); + WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATION_GRAD); + + boolean actGradScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATION_GRAD); + Preconditions.checkState(actGradScopedOut || wsNameActGrad != null, "Activation gradients must have a workspace or be scoped out"); + SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameActGrad, confWorking, confOutput); + sessionMap.get(Thread.currentThread().getId()).setMmgr(mmgr); + + + + Map phMap = new HashMap<>(); + List inputs = config.getVertexParams().getInputs(); + int i=0; + for(String s : inputs){ + phMap.put(s, this.inputs[i++]); + } + for( int j=0; j required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated + for (Map.Entry e : paramTable.entrySet()) { + INDArray arr = e.getValue(); + sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); + } + + required.addAll(paramTable.keySet()); + required.addAll(inputNames); + + Map gradsMap = sameDiff.calculateGradients(phMap, required); + for(String s : paramTable.keySet() ){ + INDArray sdGrad = gradsMap.get(s); + INDArray dl4jGrad = gradTable.get(s); + dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS + g.gradientForVariable().put(s, dl4jGrad); + } + + INDArray[] dLdIns = new INDArray[inputs.size()]; + String fnName = fn.getGradPlaceholderName(); + for(int j=0; j { assertInputSet(false); try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { - if(sameDiff == null){ + if (sameDiff == null) { doInit(); } - - org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf(); - bl.validateInput(input); - - Map phMap = new HashMap<>(); - phMap.put(INPUT_KEY, input); - if(maskArray != null){ - phMap.put(MASK_KEY, maskArray); - } else { - phMap.put(MASK_KEY, layerConf().onesMaskForInput(input)); - } - - //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration - //TODO Find a more efficient solution for this - for (Map.Entry e : paramTable.entrySet()) { - INDArray arr = e.getValue(); - sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); - } - - Map out = sameDiff.output(phMap, outputKey); - INDArray result = out.get(outputKey); - - //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere - sameDiff.clearPlaceholders(true); - sameDiff.clearOpInputs(); - - return workspaceMgr.dup(ArrayType.ACTIVATIONS, result); } + + org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf(); + bl.validateInput(input); + + Map phMap = new HashMap<>(); + phMap.put(INPUT_KEY, input); + if(maskArray != null){ + phMap.put(MASK_KEY, maskArray); + } else { + phMap.put(MASK_KEY, layerConf().onesMaskForInput(input)); + } + + //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration + //TODO Find a more efficient solution for this + for (Map.Entry e : paramTable.entrySet()) { + INDArray arr = e.getValue(); + sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); + } + + //Configure memory management for SameDiff instance - use DL4J workspaces + String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM); + String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS); + WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.FF_WORKING_MEM); + WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS); + boolean actScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATIONS); + Preconditions.checkState(actScopedOut || wsNameOutput != null, "Activations must have a workspace or must be scoped out"); + SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameOutput, confWorking, confOutput); + + InferenceSession is = sameDiff.getSessions().get(Thread.currentThread().getId()); + if(is == null){ + is = new InferenceSession(sameDiff); + sameDiff.getSessions().put(Thread.currentThread().getId(), is); + } + is.setMmgr(mmgr); + + Map out = sameDiff.output(phMap, outputKey); + INDArray result = out.get(outputKey); + + //Edge case - identity activation + //TODO there may be a cleaner way to do this... + if(!actScopedOut && !result.data().getParentWorkspace().getId().equals(wsNameOutput)){ + result = workspaceMgr.dup(ArrayType.ACTIVATIONS, result); + } else if(actScopedOut && result.isAttached()){ + result = result.detach(); + } + + + //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere + sameDiff.clearPlaceholders(true); + sameDiff.clearOpInputs(); + + return result; } @@ -122,63 +150,72 @@ public class SameDiffLayer extends AbstractLayer { Gradient g = new DefaultGradient(); INDArray dLdIn; - try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){ - if(sameDiff == null){ + + try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { + if (sameDiff == null) { doInit(); } - if(!sameDiff.hasGradientFunction()) { + if (!sameDiff.hasGradientFunction()) { //Create when scoped out, to ensure any arrays are not in WS sameDiff.createGradFunction(INPUT_KEY); } - - org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf(); - bl.validateInput(input); - - //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration - //TODO Find a more efficient solution for this - for (Map.Entry e : paramTable.entrySet()) { - INDArray arr = e.getValue(); - sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); - } - - Map phMap = new HashMap<>(); - phMap.put(INPUT_KEY, input); - phMap.put(fn.getGradPlaceholderName(), epsilon); - if(maskArray != null){ - phMap.put(MASK_KEY, maskArray); - } else { - phMap.put(MASK_KEY, layerConf().onesMaskForInput(input)); - } - - List requiredGrads = new ArrayList<>(paramTable.size() + 1); - requiredGrads.add(sameDiff.grad(INPUT_KEY).getVarName()); - for(String s : paramTable.keySet()){ - requiredGrads.add(sameDiff.grad(s).getVarName()); - } - - sameDiff.execBackwards(phMap, requiredGrads); - for(String s : paramTable.keySet() ){ - INDArray sdGrad = sameDiff.grad(s).getArr(); - INDArray dl4jGrad = gradTable.get(s); - dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS - g.gradientForVariable().put(s, dl4jGrad); - } - - SDVariable v = sameDiff.grad(INPUT_KEY); - dLdIn = v.getArr(); - - if(dLdIn == null && fn.getGradPlaceholderName().equals(v.getVarName())){ - //Edge case with lambda layers like identity: SameDiff doesn't store the placeholders - // So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here - dLdIn = epsilon; - } } + //Configure memory management for SameDiff instance - use DL4J workspaces + Map sessionMap = sameDiff.getFunction("grad").getSessions(); + if(!sessionMap.containsKey(Thread.currentThread().getId())){ + sessionMap.put(Thread.currentThread().getId(), new InferenceSession(sameDiff.getFunction("grad"))); + } + String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.BP_WORKING_MEM); + String wsNameActGrad = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATION_GRAD); + WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.BP_WORKING_MEM); + WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATION_GRAD); + + boolean actGradScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATION_GRAD); + Preconditions.checkState(actGradScopedOut || wsNameActGrad != null, "Activation gradients must have a workspace or be scoped out"); + SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameActGrad, confWorking, confOutput); + sessionMap.get(Thread.currentThread().getId()).setMmgr(mmgr); + + + org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf(); + bl.validateInput(input); + + //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration + //TODO Find a more efficient solution for this + for (Map.Entry e : paramTable.entrySet()) { + INDArray arr = e.getValue(); + sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); + } + + Map phMap = new HashMap<>(); + phMap.put(INPUT_KEY, input); + phMap.put(fn.getGradPlaceholderName(), epsilon); + if(maskArray != null){ + phMap.put(MASK_KEY, maskArray); + } else { + phMap.put(MASK_KEY, layerConf().onesMaskForInput(input)); + } + + List requiredGrads = new ArrayList<>(paramTable.size() + 1); + requiredGrads.add(INPUT_KEY); + requiredGrads.addAll(paramTable.keySet()); + + Map m = sameDiff.calculateGradients(phMap, requiredGrads); + for(String s : paramTable.keySet() ){ + INDArray sdGrad = m.get(s); + INDArray dl4jGrad = gradTable.get(s); + dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS + g.gradientForVariable().put(s, dl4jGrad); + } + + dLdIn = m.get(INPUT_KEY); + //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere sameDiff.clearPlaceholders(true); sameDiff.clearOpInputs(); - return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS + Pair ret = new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS + return ret; } /**Returns the parameters of the neural network as a flattened row vector @@ -291,7 +328,7 @@ public class SameDiffLayer extends AbstractLayer { fn = sameDiff.f().externalErrors(layerOutput); fn.outputVariable(); - this.outputKey = outputVar.getVarName(); + this.outputKey = outputVar.name(); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java index e5ca125cd..35c44d17d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java @@ -29,9 +29,12 @@ import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.InferenceSession; +import org.nd4j.autodiff.samediff.internal.SessionMemMgr; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; +import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; import org.nd4j.linalg.dataset.api.DataSet; @@ -95,40 +98,59 @@ public class SameDiffOutputLayer extends AbstractLayer e : paramTable.entrySet()) { - INDArray arr = e.getValue(); - sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); - } - - Map phMap = new HashMap<>(); - phMap.put(INPUT_KEY, input); - if(!activations && layerConf().labelsRequired() && labels != null) { - phMap.put(LABELS_KEY, labels); - } - - String s = activations ? layerConf().activationsVertexName() : outputVar.getVarName(); - - INDArray out = sameDiff.outputSingle(phMap, s); - - //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere - sameDiff.clearPlaceholders(true); - sameDiff.clearOpInputs(); - - if(activations) { - Preconditions.checkNotNull(out, "Activations (result) array for variable \"%s\" was " + - "null - error during execution or this variable (as defined by method activationsVertexName()) " + - "does not exist", layerConf().activationsVertexName()); - return workspaceMgr.dup(ArrayType.ACTIVATIONS, out); - } else { - return out; - } } + + //Configure memory management for SameDiff instance - use DL4J workspaces + String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM); + String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS); + WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.FF_WORKING_MEM); + WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS); + boolean actScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATIONS); + Preconditions.checkState(actScopedOut || wsNameOutput != null, "Activations must have a workspace or must be scoped out"); + SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameOutput, confWorking, confOutput); + + InferenceSession is = sameDiff.getSessions().get(Thread.currentThread().getId()); + if(is == null){ + is = new InferenceSession(sameDiff); + sameDiff.getSessions().put(Thread.currentThread().getId(), is); + } + is.setMmgr(mmgr); + + + + //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration + //TODO Find a more efficient solution for this + for (Map.Entry e : paramTable.entrySet()) { + INDArray arr = e.getValue(); + sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); + } + + Map phMap = new HashMap<>(); + phMap.put(INPUT_KEY, input); + if(!activations && layerConf().labelsRequired() && labels != null) { + phMap.put(LABELS_KEY, labels); + } + + String s = activations ? layerConf().activationsVertexName() : outputVar.name(); + + INDArray out = sameDiff.outputSingle(phMap, s); + + //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere + sameDiff.clearPlaceholders(true); + sameDiff.clearOpInputs(); + + //Edge case: vertex is just an Identity function, for example + //TODO there may be a cleaner way to do this... + if(!actScopedOut && !out.data().getParentWorkspace().getId().equals(wsNameOutput)){ + out = workspaceMgr.dup(ArrayType.ACTIVATIONS, out); + } else if(actScopedOut && out.isAttached()){ + out = out.detach(); + } + + return out; } @@ -141,50 +163,76 @@ public class SameDiffOutputLayer extends AbstractLayer e : paramTable.entrySet()) { - INDArray arr = e.getValue(); - sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); - } - - List gradVarNames = new ArrayList<>(); - for(String s : paramTable.keySet()){ - gradVarNames.add(sameDiff.getVariable(s).getGradient().getVarName()); - } - gradVarNames.add(sameDiff.grad(INPUT_KEY).getVarName()); - - Map phMap = new HashMap<>(); - phMap.put(INPUT_KEY, input); - phMap.put(LABELS_KEY, labels); - - sameDiff.execBackwards(phMap, gradVarNames); - for(String s : paramTable.keySet() ){ - INDArray sdGrad = sameDiff.grad(s).getArr(); - INDArray dl4jGrad = gradTable.get(s); - dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS - g.gradientForVariable().put(s, dl4jGrad); - } - - dLdIn = sameDiff.grad(INPUT_KEY).getArr(); } + //Configure memory management for SameDiff instance - use DL4J workspaces + Map sessionMap = sameDiff.getFunction("grad").getSessions(); + if(!sessionMap.containsKey(Thread.currentThread().getId())){ + sessionMap.put(Thread.currentThread().getId(), new InferenceSession(sameDiff.getFunction("grad"))); + } + String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.BP_WORKING_MEM); + String wsNameActGrad = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATION_GRAD); + WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.BP_WORKING_MEM); + WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATION_GRAD); + + boolean actGradScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATION_GRAD); + Preconditions.checkState(actGradScopedOut || wsNameActGrad != null, "Activation gradients must have a workspace or be scoped out"); + SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameActGrad, confWorking, confOutput); + sessionMap.get(Thread.currentThread().getId()).setMmgr(mmgr); + + if(!sameDiff.hasGradientFunction()) { + //Create when scoped out, to ensure any arrays are not in WS + sameDiff.createGradFunction(INPUT_KEY); + } + + //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration + //TODO Find a more efficient solution for this + for (Map.Entry e : paramTable.entrySet()) { + INDArray arr = e.getValue(); + sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); + } + + List gradVarNames = new ArrayList<>(); + gradVarNames.addAll(paramTable.keySet()); + gradVarNames.add(INPUT_KEY); + + Map phMap = new HashMap<>(); + phMap.put(INPUT_KEY, input); + phMap.put(LABELS_KEY, labels); + + Map grads = sameDiff.calculateGradients(phMap, gradVarNames); + for(String s : paramTable.keySet() ){ + INDArray sdGrad = grads.get(s); + INDArray dl4jGrad = gradTable.get(s); + dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS + g.gradientForVariable().put(s, dl4jGrad); + if(sdGrad.closeable()){ + sdGrad.close(); + } + } + + dLdIn = grads.get(INPUT_KEY); + //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere sameDiff.clearPlaceholders(true); sameDiff.clearOpInputs(); - return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS + //TODO there may be a cleaner way to do this... + if(!actGradScopedOut && !dLdIn.data().getParentWorkspace().getId().equals(wsNameActGrad)){ + dLdIn = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn); + } else if(actGradScopedOut && dLdIn.isAttached()){ + dLdIn = dLdIn.detach(); + } + + return new Pair<>(g, dLdIn); } /**Returns the parameters of the neural network as a flattened row vector @@ -297,7 +345,7 @@ public class SameDiffOutputLayer extends AbstractLayer Integer.MAX_VALUE) + throw new ND4JArraySizeException(); return (int) input.size(0); } @@ -862,7 +864,8 @@ public class VariationalAutoencoder implements Layer { @Override public int getInputMiniBatchSize() { - // FIXME: int cast + if (input.size(0) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); return (int) input.size(0); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java index 24de38f56..dcf82a20a 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java @@ -75,6 +75,7 @@ import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.heartbeat.Heartbeat; import org.nd4j.linalg.heartbeat.reports.Environment; @@ -425,7 +426,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura try(MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) { if (layerWiseConfigurations.getInputPreProcess(layerIdx) != null) { - // FIXME: int cast + if (input.size(0) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); outputOfPrevLayer = layerWiseConfigurations.getInputPreProcess(layerIdx).preProcess(outputOfPrevLayer, (int) input.size(0), LayerWorkspaceMgr.noWorkspaces(helperWorkspaces)); } @@ -439,7 +441,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura //In 99+% of cases, the input and labels dimension 0 size should be identical //The only real exceptions: space to batch, and batch to space layers //In those cases, we should base it on the labels size, as this impacts gradient calculation - // FIXME: int cast + if (input.size(0) > Integer.MAX_VALUE || labels.size(0) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); return labels == null ? (int) input.size(0) : (int)labels.size(0); } @@ -2074,7 +2077,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura if (endTimeIdx > timeSeriesLength) endTimeIdx = timeSeriesLength; - // FIXME: int cast + if (startTimeIdx > Integer.MAX_VALUE || endTimeIdx > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); INDArray[] subsets = getSubsetsForTbptt((int) startTimeIdx, (int) endTimeIdx, input, labels, featuresMaskArray, labelsMaskArray); @@ -2211,7 +2215,9 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura public int[] predict(INDArray d) { INDArray output = output(d, Layer.TrainingMode.TEST); - // FIXME: int cast + if (d.size(0) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); + int[] ret = new int[(int) d.size(0)]; if (d.isRowVectorOrScalar()) ret[0] = Nd4j.getBlasWrapper().iamax(output); @@ -2335,7 +2341,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura org.deeplearning4j.nn.conf.layers.OutputLayer layerConf = (org.deeplearning4j.nn.conf.layers.OutputLayer) getOutputLayer().conf().getLayer(); - // FIXME: int cast + if (layerConf.getNOut() > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); fit(examples, FeatureUtil.toOutcomeMatrix(labels, (int) layerConf.getNOut())); } @@ -2584,7 +2591,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura INDArray inputToOutputLayer = outputOfLayerDetached(training, FwdPassType.STANDARD,layers.length-2, data.getFeatures(), data.getFeaturesMaskArray(), data.getLabelsMaskArray(), null); - // FIXME: int cast + if (data.getFeatures().size(0) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); IOutputLayer ol = (IOutputLayer) getOutputLayer(); if (getLayerWiseConfigurations().getInputPreProcess(layers.length - 1) != null) { inputToOutputLayer = getLayerWiseConfigurations().getInputPreProcess(layers.length - 1) @@ -2647,7 +2655,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura IOutputLayer ol = (IOutputLayer) getOutputLayer(); if(layerWiseConfigurations.getInputPreProcess(layers.length-1) != null){ - // FIXME: int cast + if (data.getFeatures().size(0) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); inputLast = layerWiseConfigurations.getInputPreProcess(layers.length-1).preProcess(inputLast, (int) data.getFeatures().size(0), mgr); } @@ -2811,7 +2820,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura throw new IllegalArgumentException( "Invalid input: length 0 (shape: " + Arrays.toString(input.shape()) + ")"); - // FIXME: int cast + if (input.size(0) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); setInputMiniBatchSize((int) input.size(0)); } } @@ -3086,7 +3096,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura if(!conf().isMiniBatch()) return 1; - // FIXME: int cast + if (input.size(0) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); return (int) input.size(0); } @@ -3256,7 +3267,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura public void setLayerMaskArrays(INDArray featuresMaskArray, INDArray labelsMaskArray) { if (featuresMaskArray != null) { - // FIXME: int cast + if (featuresMaskArray.size(0) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); //New approach: use feedForwardMaskArray method feedForwardMaskArray(featuresMaskArray, MaskState.Active, (int) featuresMaskArray.size(0)); @@ -3438,7 +3450,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura val startTimeIdx = i * fwdLen; val endTimeIdx = Math.min(startTimeIdx + fwdLen, tsLength); - // FIXME: int cast + if (endTimeIdx > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); INDArray[] subsets = getSubsetsForTbptt(startTimeIdx, (int) endTimeIdx, features, labels, fMask, lMask); setLayerMaskArrays(subsets[2], subsets[3]); @@ -3943,7 +3956,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura } FeedForwardLayer ffl = (FeedForwardLayer) conf; - // FIXME: int cast + if (ffl.getNOut() > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); return (int) ffl.getNOut(); } @@ -3969,7 +3983,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura } FeedForwardLayer ffl = (FeedForwardLayer) conf; - // FIXME: int cast + if (ffl.getNIn() > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); return (int) ffl.getNIn(); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/VariationalAutoencoderParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/VariationalAutoencoderParamInitializer.java index f30ec84ae..59ecf39f0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/VariationalAutoencoderParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/VariationalAutoencoderParamInitializer.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.weights.IWeightInit; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.ArrayList; @@ -108,7 +109,8 @@ public class VariationalAutoencoderParamInitializer extends DefaultParamInitiali } //Between last decoder layer and parameters for p(x|z): - // FIXME: int cast + if (nIn > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); val nDistributionParams = layer.getOutputDistribution().distributionInputSize((int) nIn); val lastDecLayerSize = decoderLayerSizes[decoderLayerSizes.length - 1]; paramCount += (lastDecLayerSize + 1) * nDistributionParams; @@ -294,7 +296,8 @@ public class VariationalAutoencoderParamInitializer extends DefaultParamInitiali } //Finally, p(x|z): - // FIXME: int cast + if (nIn > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); int nDistributionParams = layer.getOutputDistribution().distributionInputSize((int) nIn); int pxzWeightCount = decoderLayerSizes[decoderLayerSizes.length - 1] * nDistributionParams; INDArray pxzWeightView = @@ -402,7 +405,8 @@ public class VariationalAutoencoderParamInitializer extends DefaultParamInitiali } //Finally, p(x|z): - // FIXME: int cast + if (nIn > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); int nDistributionParams = layer.getOutputDistribution().distributionInputSize((int) nIn); int pxzWeightCount = decoderLayerSizes[decoderLayerSizes.length - 1] * nDistributionParams; INDArray pxzWeightView = diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java index 10195b59d..c991f41c9 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java @@ -30,6 +30,7 @@ import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.deeplearning4j.nn.workspace.ArrayType; @@ -111,7 +112,8 @@ public abstract class BaseMultiLayerUpdater implements Updater if (currentBlock == null || !UpdaterUtils.updaterConfigurationsEquals(lastLayer, lastVariable, layers[i], var)) { - // FIXME: int cast + if (paramsViewSoFar + paramSizeThisVariable > Integer.MAX_VALUE || paramsViewSoFar + paramSizeThisVariable > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); //Create a new block List list = new ArrayList<>(); list.add(new UpdaterBlock.ParamState(layers[i], var, paramsViewSoFar, @@ -122,9 +124,11 @@ public abstract class BaseMultiLayerUpdater implements Updater updaterBlocks.add(currentBlock); } else { - // FIXME: int cast + long newOffset = currentBlock.getParamOffsetEnd() + paramSizeThisVariable; + if (newOffset > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); //Add to existing updater block - currentBlock.setParamOffsetEnd((int) (currentBlock.getParamOffsetEnd() + paramSizeThisVariable)); + currentBlock.setParamOffsetEnd((int) newOffset); currentBlock.setUpdaterViewOffsetEnd( currentBlock.getUpdaterViewOffsetEnd() + updaterStateSizeThisVariable); currentBlock.getLayersAndVariablesInBlock() diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresIterationListener.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresIterationListener.java index cc3ec16fb..f6c2e269c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresIterationListener.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresIterationListener.java @@ -25,6 +25,7 @@ import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; /** @@ -37,7 +38,83 @@ public class CollectScoresIterationListener extends BaseTrainingListener { private int frequency; private int iterationCount = 0; - private List> scoreVsIter = new ArrayList<>(); + //private List> scoreVsIter = new ArrayList<>(); + + public static class ScoreStat { + public static final int BUCKET_LENGTH = 10000; + + private int position = 0; + private int bucketNumber = 1; + private List indexes; + private List scores; + + public ScoreStat() { + indexes = new ArrayList<>(1); + indexes.add(new long[BUCKET_LENGTH]); + scores = new ArrayList<>(1); + scores.add(new double[BUCKET_LENGTH]); + } + + public List getIndexes() { + return indexes; + } + + public List getScores() { + return scores; + } + + public long[] getEffectiveIndexes() { + return Arrays.copyOfRange(indexes.get(0), 0, position); + } + + public double[] getEffectiveScores() { + return Arrays.copyOfRange(scores.get(0), 0, position); + } + + + /* + Originally scores array is initialized with BUCKET_LENGTH size. + When data doesn't fit there - arrays size is increased for BUCKET_LENGTH, + old data is copied and bucketNumber (counter of reallocations) being incremented. + + If we got more score points than MAX_VALUE - they are put to another item of scores list. + */ + private void reallocateGuard() { + if (position >= BUCKET_LENGTH * bucketNumber) { + + long fullLength = (long)BUCKET_LENGTH * bucketNumber; + + if (position == Integer.MAX_VALUE || fullLength >= Integer.MAX_VALUE) { + position = 0; + long[] newIndexes = new long[BUCKET_LENGTH]; + double[] newScores = new double[BUCKET_LENGTH]; + indexes.add(newIndexes); + scores.add(newScores); + } + else { + long[] newIndexes = new long[(int)fullLength + BUCKET_LENGTH]; + double[] newScores = new double[(int)fullLength + BUCKET_LENGTH]; + System.arraycopy(indexes.get(indexes.size()-1), 0, newIndexes, 0, (int)fullLength); + System.arraycopy(scores.get(scores.size()-1), 0, newScores, 0, (int)fullLength); + scores.remove(scores.size()-1); + indexes.remove(indexes.size()-1); + int lastIndex = scores.size() == 0 ? 0 : scores.size()-1; + scores.add(lastIndex, newScores); + indexes.add(lastIndex, newIndexes); + } + bucketNumber += 1; + } + } + + public void addScore(long index, double score) { + reallocateGuard(); + scores.get(scores.size() - 1)[position] = score; + indexes.get(scores.size() - 1)[position] = index; + position += 1; + } + } + + ScoreStat scoreVsIter = new ScoreStat(); /** * Constructor for collecting scores with default saving frequency of 1 @@ -60,11 +137,12 @@ public class CollectScoresIterationListener extends BaseTrainingListener { public void iterationDone(Model model, int iteration, int epoch) { if (++iterationCount % frequency == 0) { double score = model.score(); - scoreVsIter.add(new Pair<>(iterationCount, score)); + scoreVsIter.reallocateGuard(); + scoreVsIter.addScore(iteration, score); } } - public List> getScoreVsIter() { + public ScoreStat getScoreVsIter() { return scoreVsIter; } @@ -84,8 +162,16 @@ public class CollectScoresIterationListener extends BaseTrainingListener { public void exportScores(OutputStream outputStream, String delimiter) throws IOException { StringBuilder sb = new StringBuilder(); sb.append("Iteration").append(delimiter).append("Score"); - for (Pair p : scoreVsIter) { - sb.append("\n").append(p.getFirst()).append(delimiter).append(p.getSecond()); + int largeBuckets = scoreVsIter.indexes.size(); + for (int j = 0; j < largeBuckets; ++j) { + long[] indexes = scoreVsIter.indexes.get(j); + double[] scores = scoreVsIter.scores.get(j); + + int effectiveLength = (j < largeBuckets -1) ? indexes.length : scoreVsIter.position; + + for (int i = 0; i < effectiveLength; ++i) { + sb.append("\n").append(indexes[i]).append(delimiter).append(scores[i]); + } } outputStream.write(sb.toString().getBytes("UTF-8")); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java index 1e5ca2675..f0c8d76c9 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java @@ -29,6 +29,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp; import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; @@ -62,10 +63,9 @@ public class Convolution1DUtils { * @param dilation Kernel dilation * @return Output size (width) */ - public static int getOutputSize(int inH, int kernel, int strides, int padding, + public static long getOutputSize(long inH, int kernel, int strides, int padding, ConvolutionMode convolutionMode, int dilation) { - // FIXME: int cast - int eKernel = effectiveKernelSize(kernel, dilation); + long eKernel = effectiveKernelSize(kernel, dilation); if (convolutionMode == ConvolutionMode.Same) { return (int) Math.ceil(inH / ((double) strides)); } @@ -85,7 +85,8 @@ public class Convolution1DUtils { */ public static int getOutputSize(INDArray inputData, int kernel, int strides, int padding, ConvolutionMode convolutionMode, int dilation) { - // FIXME: int cast + if (inputData.size(2) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); int inH = (int) inputData.size(2); int eKernel = effectiveKernelSize(kernel, dilation); boolean atrous = (eKernel == kernel); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/Convolution3DUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/Convolution3DUtils.java index 809ffde45..7d844db27 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/Convolution3DUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/Convolution3DUtils.java @@ -61,15 +61,14 @@ public class Convolution3DUtils { ConvolutionMode convolutionMode, int[] dilation, boolean isNCDHW) { // NCDHW vs. NDHWC - int inD = (int) (isNCDHW ? inputData.size(2) : inputData.size(1)); - int inH = (int) (isNCDHW ? inputData.size(3) : inputData.size(2)); - int inW = (int) (isNCDHW ? inputData.size(4) : inputData.size(3)); + long inD = (isNCDHW ? inputData.size(2) : inputData.size(1)); + long inH = (isNCDHW ? inputData.size(3) : inputData.size(2)); + long inW = (isNCDHW ? inputData.size(4) : inputData.size(3)); int[] eKernel = effectiveKernelSize(kernel, dilation); boolean atrous = (eKernel == kernel); - // FIXME: int cast - val inShape = new int[]{inD, inH, inW}; + val inShape = new long[]{inD, inH, inW}; validateShapes(ArrayUtil.toInts(inputData.shape()), eKernel, strides, padding, convolutionMode, dilation, inShape, atrous); if (convolutionMode == ConvolutionMode.Same) { @@ -80,16 +79,16 @@ public class Convolution3DUtils { return new int[]{outD, outH, outW}; } - int outD = (inD - eKernel[0] + 2 * padding[0]) / strides[0] + 1; - int outH = (inH - eKernel[1] + 2 * padding[1]) / strides[1] + 1; - int outW = (inW - eKernel[2] + 2 * padding[2]) / strides[2] + 1; + int outD = ((int)inD - eKernel[0] + 2 * padding[0]) / strides[0] + 1; + int outH = ((int)inH - eKernel[1] + 2 * padding[1]) / strides[1] + 1; + int outW = ((int)inW - eKernel[2] + 2 * padding[2]) / strides[2] + 1; return new int[]{outD, outH, outW}; } private static void validateShapes(int[] inputDataShape, int[] eKernel, int[] strides, int[] padding, - ConvolutionMode convolutionMode, int[] dilation, int[] inShape, + ConvolutionMode convolutionMode, int[] dilation, long[] inShape, boolean atrous) { String[] dims = new String[]{"depth", "height", "width"}; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java index 4c1207d32..d5c8ee1f6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java @@ -36,6 +36,8 @@ import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp; import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.exception.ND4JArraySizeException; +import org.nd4j.linalg.factory.NDArrayFactory; import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; @@ -73,7 +75,8 @@ public class ConvolutionUtils { public static int[] getDeconvolutionOutputSize(INDArray inputData, int[] kernel, int[] strides, int[] padding, ConvolutionMode convolutionMode, int[] dilation) { - // FIXME: int cast + if (inputData.size(2) > Integer.MAX_VALUE || inputData.size(3) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); int hIn = (int) inputData.size(2); int wIn = (int) inputData.size(3); int[] eKernel = effectiveKernelSize(kernel, dilation); @@ -104,7 +107,8 @@ public class ConvolutionUtils { */ public static int[] getOutputSize(INDArray inputData, int[] kernel, int[] strides, int[] padding, ConvolutionMode convolutionMode, int[] dilation) { - // FIXME: int cast + if (inputData.size(2) > Integer.MAX_VALUE || inputData.size(3) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); int inH = (int) inputData.size(2); int inW = (int) inputData.size(3); @@ -499,7 +503,7 @@ public class ConvolutionUtils { } } - public static INDArray reshape2dTo4d(INDArray in2d, int[] toShape, LayerWorkspaceMgr workspaceMgr, ArrayType type){ + public static INDArray reshape2dTo4d(INDArray in2d, long[] toShape, LayerWorkspaceMgr workspaceMgr, ArrayType type){ if(in2d.rank() != 2) throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2"); if (toShape.length != 4) @@ -513,7 +517,7 @@ public class ConvolutionUtils { return workspaceMgr.leverageTo(type, out.permute(0, 3, 1, 2)); } - public static INDArray reshape2dTo5d(Convolution3D.DataFormat format, INDArray in2d, int n, int d, int h, int w, int ch, LayerWorkspaceMgr workspaceMgr, ArrayType type){ + public static INDArray reshape2dTo5d(Convolution3D.DataFormat format, INDArray in2d, long n, long d, long h, long w, long ch, LayerWorkspaceMgr workspaceMgr, ArrayType type){ if(in2d.rank() != 2) throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2"); @@ -580,14 +584,21 @@ public class ConvolutionUtils { int inW; int inDepth; - // FIXME: int cast if (inputType instanceof InputType.InputTypeConvolutional) { InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) inputType; + if (conv.getHeight() > Integer.MAX_VALUE || conv.getWidth() > Integer.MAX_VALUE || + conv.getChannels() > Integer.MAX_VALUE){ + throw new ND4JArraySizeException(); + } inH = (int) conv.getHeight(); inW = (int) conv.getWidth(); inDepth = (int) conv.getChannels(); } else if (inputType instanceof InputType.InputTypeConvolutionalFlat) { InputType.InputTypeConvolutionalFlat conv = (InputType.InputTypeConvolutionalFlat) inputType; + if (conv.getHeight() > Integer.MAX_VALUE || conv.getWidth() > Integer.MAX_VALUE || + conv.getDepth() > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } inH = (int) conv.getHeight(); inW = (int) conv.getWidth(); inDepth = (int) conv.getDepth(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java index f356fab71..80383698b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java @@ -20,6 +20,7 @@ import lombok.val; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.INDArrayIndex; @@ -193,7 +194,7 @@ public class TimeSeriesUtils { } - public static INDArray reshape2dTo3d(INDArray in, int miniBatchSize, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) { + public static INDArray reshape2dTo3d(INDArray in, long miniBatchSize, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) { if (in.rank() != 2) throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2"); //Based on: RnnToFeedForwardPreProcessor @@ -220,7 +221,6 @@ public class TimeSeriesUtils { in = in.dup('f'); } - // FIXME: int cast int[] idxs = new int[(int) in.size(2)]; int j=0; for( int i=idxs.length-1; i>=0; i--){ @@ -248,7 +248,8 @@ public class TimeSeriesUtils { in = workspaceMgr.dup(arrayType, in, 'f'); } - // FIXME: int cast + if (in.size(2) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); int[] idxs = new int[(int) in.size(2)]; int j=0; for( int i=idxs.length-1; i>=0; i--){ @@ -291,7 +292,8 @@ public class TimeSeriesUtils { + " with shape " + Arrays.toString(mask.shape())); } - // FIXME: int cast + if (mask.size(1) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); int[] idxs = new int[(int) mask.size(1)]; int j=0; for( int i=idxs.length-1; i>=0; i--){ @@ -319,7 +321,8 @@ public class TimeSeriesUtils { + " with shape " + Arrays.toString(mask.shape())); } - // FIXME: int cast + if (mask.size(1) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); int[] idxs = new int[(int) mask.size(1)]; int j=0; for( int i=idxs.length-1; i>=0; i--){ @@ -358,9 +361,8 @@ public class TimeSeriesUtils { INDArray out; if (mask == null) { - // FIXME: int cast //No mask array -> extract same (last) column for all - int lastTS = (int) pullFrom.size(2) - 1; + long lastTS = pullFrom.size(2) - 1; out = pullFrom.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(lastTS)); fwdPassTimeSteps = null; //Null -> last time step for all examples } else { @@ -396,9 +398,8 @@ public class TimeSeriesUtils { INDArray out; if (mask == null) { - // FIXME: int cast //No mask array -> extract same (last) column for all - int lastTS = (int) pullFrom.size(2) - 1; + long lastTS = pullFrom.size(2) - 1; out = pullFrom.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(lastTS)); fwdPassTimeSteps = null; //Null -> last time step for all examples } else { diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/TestUtils.java b/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/TestUtils.java index 251849f3e..c60822ef7 100644 --- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/TestUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/TestUtils.java @@ -116,7 +116,6 @@ public class TestUtils { public static INDArray randomOneHot(long examples, long nOut, Random rng){ INDArray arr = Nd4j.create(examples, nOut); for( int i=0; i indexes = statTest.getIndexes(); + List scores = statTest.getScores(); + + assertTrue(indexes.size() == 1); + assertTrue(scores.size() == 1); + + assertTrue(indexes.get(0).length == CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH); + assertTrue(scores.get(0).length == CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH); + assertEquals(indexes.get(0)[indexes.get(0).length-1], CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH-1); + assertEquals(scores.get(0)[scores.get(0).length-1], CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH-1, 1e-4); + } + + @Test + public void testScoreStatAverage() { + int dataSize = 1000000; + long[] indexes = new long[dataSize]; + double[] scores = new double[dataSize]; + + for (int i = 0; i < dataSize; ++i) { + indexes[i] = i; + scores[i] = i; + } + + CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat(); + for (int i = 0; i < dataSize; ++i) { + statTest.addScore(indexes[i], scores[i]); + } + + long[] indexesStored = statTest.getIndexes().get(0); + double[] scoresStored = statTest.getScores().get(0); + + assertArrayEquals(indexes, indexesStored); + assertArrayEquals(scores, scoresStored, 1e-4); + } + + @Test + public void testScoresClean() { + int dataSize = 10256; // expected to be placed in 2 buckets of 10k elements size + long[] indexes = new long[dataSize]; + double[] scores = new double[dataSize]; + + for (int i = 0; i < dataSize; ++i) { + indexes[i] = i; + scores[i] = i; + } + + CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat(); + for (int i = 0; i < dataSize; ++i) { + statTest.addScore(indexes[i], scores[i]); + } + + long[] indexesEffective = statTest.getEffectiveIndexes(); + double[] scoresEffective = statTest.getEffectiveScores(); + + assertArrayEquals(indexes, indexesEffective); + assertArrayEquals(scores, scoresEffective, 1e-4); + } + + @Ignore + @Test + public void testScoreStatBig() { + CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat(); + long bigLength = (long)Integer.MAX_VALUE + 5; + for (long i = 0; i < bigLength; ++i) { + double score = (double)i; + statTest.addScore(i, score); + } + + List indexes = statTest.getIndexes(); + List scores = statTest.getScores(); + + assertTrue(indexes.size() == 2); + assertTrue(scores.size() == 2); + + for (int i = 0; i < 5; ++i) { + assertTrue(indexes.get(1)[i] == Integer.MAX_VALUE + i); + assertTrue(scores.get(1)[i] == Integer.MAX_VALUE + i); + + } + } +} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/Glove.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/Glove.java index 297e6d9e1..3dde1ae9b 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/Glove.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/Glove.java @@ -93,11 +93,11 @@ public class Glove implements Serializable { VocabWord w1, INDArray wordVector, INDArray contextVector, double gradient) { //gradient for word vectors INDArray grad1 = contextVector.mul(gradient); - INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), ArrayUtil.toInts(syn0.shape())); + INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), syn0.shape()); wordVector.subi(update); double w1Bias = bias.getDouble(w1.getIndex()); - double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), ArrayUtil.toInts(bias.shape())); + double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), bias.shape()); double update2 = w1Bias - biasGradient; bias.putScalar(w1.getIndex(), bias.getDouble(w1.getIndex()) - update2); return new Pair<>(update, (float) update2); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java index bbad9f2e3..20035688b 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java @@ -214,7 +214,6 @@ public class FirstIterationFunction implements else { nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11)); - // FIXME: int cast int idx = Math.abs((int) (nextRandom.get() >> 16) % (int) negativeHolder.getTable().length()); target = negativeHolder.getTable().getInt(idx); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java index c34156484..e540c23f0 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java @@ -222,7 +222,6 @@ public class SecondIterationFunction implements FlatMapFunction> 16) % negativeHolder.getTable().length()); target = negativeHolder.getTable().getInt(idx); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SentenceBatch.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SentenceBatch.java index 0d97f9bc4..78799b22b 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SentenceBatch.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SentenceBatch.java @@ -162,7 +162,6 @@ public class SentenceBatch implements Function label = 1; } else { nextRandom.set(nextRandom.get() * 25214903917L + 11); - // FIXME: int cast target = table.getInt((int) (nextRandom.get() >> 16) % (int) table.length()); if (target == 0) target = (int) nextRandom.get() % (numWords - 1) + 1; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java index 45eca7327..10e2a3050 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java @@ -187,7 +187,6 @@ public class Word2VecPerformer implements VoidFunction, Ato } else { nextRandom.set(nextRandom.get() * 25214903917L + 11); - // FIXME: int cast target = table.getInt((int) (nextRandom.get() >> 16) % (int) table.length()); if (target == 0) target = (int) nextRandom.get() % (numWords - 1) + 1; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java index 539755ee6..0e346b622 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java @@ -337,7 +337,6 @@ public class Word2VecPerformerVoid implements VoidFunction, label = 1; } else { nextRandom.set(nextRandom.get() * 25214903917L + 11); - // FIXME: int cast target = table.getInt((int) (nextRandom.get() >> 16) % (int) table.length()); if (target == 0) target = (int) nextRandom.get() % (numWords - 1) + 1; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/StatsCalculationHelper.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/StatsCalculationHelper.java index d9b14651e..3476c5dd2 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/StatsCalculationHelper.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/StatsCalculationHelper.java @@ -39,7 +39,7 @@ public class StatsCalculationHelper { private long initialModelAfter; private long lastDataSetBefore; private long lastProcessBefore; - private int totalExampleCount; + private long totalExampleCount; private List dataSetGetTimes = new ArrayList<>(); private List processMiniBatchTimes = new ArrayList<>(); @@ -65,7 +65,7 @@ public class StatsCalculationHelper { lastDataSetBefore = timeSource.currentTimeMillis(); } - public void logNextDataSetAfter(int numExamples) { + public void logNextDataSetAfter(long numExamples) { long now = timeSource.currentTimeMillis(); long duration = now - lastDataSetBefore; dataSetGetTimes.add(new BaseEventStats(lastDataSetBefore, duration)); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java index 6fa148394..15ce0eb32 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java @@ -84,9 +84,8 @@ public class ExecuteWorkerMultiDataSetFlatMap implemen s.logNextDataSetBefore(); MultiDataSet next = batchedIterator.next(); - // FIXME: int cast if (stats) - s.logNextDataSetAfter((int) next.getFeatures(0).size(0)); + s.logNextDataSetAfter(next.getFeatures(0).size(0)); if (stats) { s.logProcessMinibatchBefore(); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/reduce/LongDoubleReduceFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/reduce/LongDoubleReduceFunction.java new file mode 100644 index 000000000..1092ff02b --- /dev/null +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/reduce/LongDoubleReduceFunction.java @@ -0,0 +1,31 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.spark.impl.common.reduce; + +import org.apache.spark.api.java.function.Function2; +import scala.Tuple2; + +/** + * Add both elements of a {@code Tuple2} + */ +public class LongDoubleReduceFunction + implements Function2, Tuple2, Tuple2> { + @Override + public Tuple2 call(Tuple2 f, Tuple2 s) throws Exception { + return new Tuple2<>(f._1() + s._1(), f._2() + s._2()); + } +} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java index e8ad74ba7..0e639a462 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java @@ -38,6 +38,7 @@ import org.deeplearning4j.spark.api.TrainingMaster; import org.deeplearning4j.spark.api.stats.SparkTrainingStats; import org.deeplearning4j.spark.impl.SparkListenable; import org.deeplearning4j.spark.impl.common.reduce.IntDoubleReduceFunction; +import org.deeplearning4j.spark.impl.common.reduce.LongDoubleReduceFunction; import org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn; import org.deeplearning4j.spark.impl.graph.dataset.PairDataSetToMultiDataSetFn; import org.deeplearning4j.spark.impl.graph.evaluation.IEvaluateMDSFlatMapFunction; @@ -374,11 +375,11 @@ public class SparkComputationGraph extends SparkListenable { * in one go) */ public double calculateScore(JavaRDD data, boolean average, int minibatchSize) { - JavaRDD> rdd = data.mapPartitions(new ScoreFlatMapFunctionCGDataSet(conf.toJson(), - sc.broadcast(network.params(false)), minibatchSize)); + JavaRDD> rdd = data.mapPartitions(new ScoreFlatMapFunctionCGDataSet(conf.toJson(), + sc.broadcast(network.params()), minibatchSize)); //Reduce to a single tuple, with example count + sum of scores - Tuple2 countAndSumScores = rdd.reduce(new IntDoubleReduceFunction()); + Tuple2 countAndSumScores = rdd.reduce(new LongDoubleReduceFunction()); if (average) { return countAndSumScores._2() / countAndSumScores._1(); } else { @@ -409,10 +410,10 @@ public class SparkComputationGraph extends SparkListenable { * in one go) */ public double calculateScoreMultiDataSet(JavaRDD data, boolean average, int minibatchSize) { - JavaRDD> rdd = data.mapPartitions(new ScoreFlatMapFunctionCGMultiDataSet(conf.toJson(), - sc.broadcast(network.params(false)), minibatchSize)); + JavaRDD> rdd = data.mapPartitions(new ScoreFlatMapFunctionCGMultiDataSet(conf.toJson(), + sc.broadcast(network.params()), minibatchSize)); //Reduce to a single tuple, with example count + sum of scores - Tuple2 countAndSumScores = rdd.reduce(new IntDoubleReduceFunction()); + Tuple2 countAndSumScores = rdd.reduce(new LongDoubleReduceFunction()); if (average) { return countAndSumScores._2() / countAndSumScores._1(); } else { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java index cec2f5b17..6d730b60b 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java @@ -71,7 +71,7 @@ public class GraphFeedForwardWithKeyFunction implements PairFlatMapFunction featuresList = new ArrayList<>(batchSize); List keyList = new ArrayList<>(batchSize); - List origSizeList = new ArrayList<>(); + List origSizeList = new ArrayList<>(); long[][] firstShapes = null; boolean sizesDiffer = false; @@ -96,8 +96,7 @@ public class GraphFeedForwardWithKeyFunction implements PairFlatMapFunction implements PairFlatMapFunction implements PairFlatMapFunction, Tuple2> { +public class ScoreFlatMapFunctionCGDataSet implements FlatMapFunction, Tuple2> { private static final Logger log = LoggerFactory.getLogger(ScoreFlatMapFunctionCGDataSet.class); private String json; private Broadcast params; @@ -50,9 +50,9 @@ public class ScoreFlatMapFunctionCGDataSet implements FlatMapFunction> call(Iterator dataSetIterator) throws Exception { + public Iterator> call(Iterator dataSetIterator) throws Exception { if (!dataSetIterator.hasNext()) { - return Collections.singletonList(new Tuple2<>(0, 0.0)).iterator(); + return Collections.singletonList(new Tuple2<>(0L, 0.0)).iterator(); } DataSetIterator iter = new IteratorDataSetIterator(dataSetIterator, minibatchSize); //Does batching where appropriate @@ -65,13 +65,12 @@ public class ScoreFlatMapFunctionCGDataSet implements FlatMapFunction> out = new ArrayList<>(); + List> out = new ArrayList<>(); while (iter.hasNext()) { DataSet ds = iter.next(); double score = network.score(ds, false); - // FIXME: int cast - int numExamples = (int) ds.getFeatures().size(0); + long numExamples = ds.getFeatures().size(0); out.add(new Tuple2<>(numExamples, score * numExamples)); } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java index bf9e3f596..f72fdbb34 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java @@ -35,7 +35,7 @@ import java.util.Iterator; import java.util.List; /** Function used to score a MultiDataSet using a given ComputationGraph */ -public class ScoreFlatMapFunctionCGMultiDataSet implements FlatMapFunction, Tuple2> { +public class ScoreFlatMapFunctionCGMultiDataSet implements FlatMapFunction, Tuple2> { private static final Logger log = LoggerFactory.getLogger(ScoreFlatMapFunctionCGMultiDataSet.class); private String json; @@ -50,9 +50,9 @@ public class ScoreFlatMapFunctionCGMultiDataSet implements FlatMapFunction> call(Iterator dataSetIterator) throws Exception { + public Iterator> call(Iterator dataSetIterator) throws Exception { if (!dataSetIterator.hasNext()) { - return Collections.singletonList(new Tuple2<>(0, 0.0)).iterator(); + return Collections.singletonList(new Tuple2<>(0L, 0.0)).iterator(); } MultiDataSetIterator iter = new IteratorMultiDataSetIterator(dataSetIterator, minibatchSize); //Does batching where appropriate @@ -66,13 +66,12 @@ public class ScoreFlatMapFunctionCGMultiDataSet implements FlatMapFunction> out = new ArrayList<>(); + List> out = new ArrayList<>(); while (iter.hasNext()) { MultiDataSet ds = iter.next(); double score = network.score(ds, false); - // FIXME: int cast - int numExamples = (int) ds.getFeatures(0).size(0); + long numExamples = ds.getFeatures(0).size(0); out.add(new Tuple2<>(numExamples, score * numExamples)); } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java index 03e4e55cf..0672b158a 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java @@ -105,7 +105,6 @@ public class FeedForwardWithKeyFunction fMaskList.add(t2._2()._2()); keyList.add(t2._1()); - // FIXME: int cast origSizeList.add((int) t2._2()._1().size(0)); tupleCount++; } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java index 8063ba8e3..98a2639ef 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java @@ -64,7 +64,6 @@ public class ScoreFlatMapFunction implements FlatMapFunction, DataSet ds = iter.next(); double score = network.score(ds, false); - // FIXME: int cast val numExamples = (int) ds.getFeatures().size(0); out.add(new Tuple2<>(numExamples, score * numExamples)); } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java index 3896ae61d..71fcd1680 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java @@ -247,10 +247,8 @@ public class ParameterAveragingTrainingWorker extends BaseTrainingWorker implements DataSetIterator { protected Collection dataSetStreams; protected DataSetPreProcessor preprocessor; protected Iterator iter; - protected int totalOutcomes = -1; - protected int inputColumns = -1; + protected long totalOutcomes = -1; + protected long inputColumns = -1; protected int batch = -1; protected DataSet preloadedDataSet; protected int cursor = 0; @@ -46,7 +47,7 @@ public abstract class BaseDataSetIterator implements DataSetIterator { public int inputColumns() { if (inputColumns == -1) preloadDataSet(); - return inputColumns; + return (int)inputColumns; } @Override @@ -112,7 +113,9 @@ public abstract class BaseDataSetIterator implements DataSetIterator { private void preloadDataSet() { preloadedDataSet = load(iter.next()); - // FIXME: int cast + if (preloadedDataSet.getLabels().size(1) > Integer.MAX_VALUE || + preloadedDataSet.getFeatures().size(1) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); totalOutcomes = (int) preloadedDataSet.getLabels().size(1); inputColumns = (int) preloadedDataSet.getFeatures().size(1); } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PathSparkDataSetIterator.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PathSparkDataSetIterator.java index a8a91ed3d..992b13c38 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PathSparkDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PathSparkDataSetIterator.java @@ -67,7 +67,6 @@ public class PathSparkDataSetIterator extends BaseDataSetIterator { ds = load(iter.next()); } - // FIXME: int cast totalOutcomes = ds.getLabels() == null ? 0 : (int) ds.getLabels().size(1); //May be null for layerwise pretraining inputColumns = (int) ds.getFeatures().size(1); batch = ds.numExamples(); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PortableDataStreamDataSetIterator.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PortableDataStreamDataSetIterator.java index 6285778a6..53af6aa21 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PortableDataStreamDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PortableDataStreamDataSetIterator.java @@ -18,6 +18,7 @@ package org.deeplearning4j.spark.iterator; import org.apache.spark.input.PortableDataStream; import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.exception.ND4JArraySizeException; import java.io.IOException; import java.io.InputStream; @@ -53,7 +54,9 @@ public class PortableDataStreamDataSetIterator extends BaseDataSetIterator Integer.MAX_VALUE || + ds.getFeatures().size(1) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); totalOutcomes = (int) ds.getLabels().size(1); inputColumns = (int) ds.getFeatures().size(1); batch = ds.numExamples(); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/ExampleCountEventStats.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/ExampleCountEventStats.java index bd4bfac5a..a0792b659 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/ExampleCountEventStats.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/ExampleCountEventStats.java @@ -26,9 +26,9 @@ import lombok.Getter; public class ExampleCountEventStats extends BaseEventStats { @Getter - private final int totalExampleCount; + private final long totalExampleCount; - public ExampleCountEventStats(long startTime, long durationMs, int totalExampleCount) { + public ExampleCountEventStats(long startTime, long durationMs, long totalExampleCount) { super(startTime, durationMs); this.totalExampleCount = totalExampleCount; } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java index b501b834e..cfa081710 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java @@ -31,6 +31,7 @@ import org.datavec.api.split.InputStreamInputSplit; import org.datavec.api.writable.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.util.FeatureUtil; import scala.Tuple2; @@ -122,7 +123,8 @@ public class MLLibUtil { if (!arr.isVector()) { throw new IllegalArgumentException("passed in array must be a vector"); } - // FIXME: int cast + if (arr.length() > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); double[] ret = new double[(int) arr.length()]; for (int i = 0; i < arr.length(); i++) { ret[i] = arr.getDouble(i); diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java index 3fd23bbd3..934ccc110 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java @@ -33,6 +33,7 @@ import org.deeplearning4j.ui.api.UIServer; import org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage; import org.deeplearning4j.util.UIDProvider; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.io.ClassPathResource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -205,7 +206,8 @@ public class ConvolutionalIterationListener extends BaseTrainingListener { if(layers[i].type() == Layer.Type.CONVOLUTIONAL){ INDArray output = activations.get(i+1); //Offset by 1 - activations list includes input - // FIXME: int cast + if (output.shape()[0] - 1 > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); int sampleDim = output.shape()[0] == 1 ? 0 : rnd.nextInt((int) output.shape()[0] - 1) + 1; if (cnt == 0) { INDArray inputs = layers[i].input(); @@ -426,7 +428,8 @@ public class ConvolutionalIterationListener extends BaseTrainingListener { val height = (numRows * (tShape[1] + border + padding_col)) + padding_col + zoomPadding + zoomWidth; - // FIXME: int cast + if (height > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); BufferedImage outputImage = new BufferedImage(maxWidth, (int) height, BufferedImage.TYPE_BYTE_GRAY); Graphics2D graphics2D = outputImage.createGraphics(); @@ -571,7 +574,8 @@ public class ConvolutionalIterationListener extends BaseTrainingListener { */ graphics2D.setPaint(borderColor); - // FIXME: int cast + if (tad2D.shape()[0] > Integer.MAX_VALUE || tad2D.shape()[1] > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); graphics2D.drawRect(columnOffset, rowOffset, (int) tad2D.shape()[0], (int) tad2D.shape()[1]); diff --git a/deeplearning4j/deeplearning4j-zoo/src/main/java/org/deeplearning4j/zoo/model/UNet.java b/deeplearning4j/deeplearning4j-zoo/src/main/java/org/deeplearning4j/zoo/model/UNet.java index 5aad00939..4e481655c 100644 --- a/deeplearning4j/deeplearning4j-zoo/src/main/java/org/deeplearning4j/zoo/model/UNet.java +++ b/deeplearning4j/deeplearning4j-zoo/src/main/java/org/deeplearning4j/zoo/model/UNet.java @@ -215,7 +215,7 @@ public class UNet extends ZooModel { .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "conv9-2") - .addLayer("conv10", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(1) + .addLayer("conv10", new ConvolutionLayer.Builder(1,1).stride(1,1).nOut(1) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.IDENTITY).build(), "conv9-3") .addLayer("output", new CnnLossLayer.Builder(LossFunctions.LossFunction.XENT) diff --git a/deeplearning4j/deeplearning4j-zoo/src/main/java/org/deeplearning4j/zoo/util/BaseLabels.java b/deeplearning4j/deeplearning4j-zoo/src/main/java/org/deeplearning4j/zoo/util/BaseLabels.java index 771b5f461..d2e86fe94 100644 --- a/deeplearning4j/deeplearning4j-zoo/src/main/java/org/deeplearning4j/zoo/util/BaseLabels.java +++ b/deeplearning4j/deeplearning4j-zoo/src/main/java/org/deeplearning4j/zoo/util/BaseLabels.java @@ -90,9 +90,8 @@ public abstract class BaseLabels implements Labels { Preconditions.checkState(predictions.size(1) == labels.size(), "Invalid input array:" + " expected array with size(1) equal to numLabels (%s), got array with shape %s", labels.size(), predictions.shape()); - // FIXME: int cast - int rows = (int) predictions.size(0); - int cols = (int) predictions.size(1); + long rows = predictions.size(0); + long cols = predictions.size(1); if (predictions.isColumnVectorOrScalar()) { predictions = predictions.ravel(); rows = (int) predictions.size(0); diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java index 683914119..72e51cd92 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java @@ -116,7 +116,6 @@ public class TestUtils { public static INDArray randomOneHot(long examples, long nOut, Random rng){ INDArray arr = Nd4j.create(examples, nOut); for( int i=0; i& capabilities(); }; } diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index 917630bce..9c1b44818 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -1004,7 +1004,7 @@ NDArray NDArray::reduceNumber(nd4j::reduce::LongOps op, void *extraParams) const void NDArray::reduceNumber(nd4j::reduce::FloatOps op, NDArray& target, void *extraParams) const { if (isS()) throw std::runtime_error("NDArray::reduceNumber FloatOps: you can't use this method on String array!"); - if(!target.isScalar() || target.dataType() != DataTypeUtils::pickFloatingType(dataType())) + if(target.lengthOf() != 1 || target.dataType() != DataTypeUtils::pickFloatingType(dataType())) throw std::invalid_argument("NDArray::reduceNumber FloatOps: target array should be scalar and have corresponding float type!"); NDArray::prepareSpecialUse({&target}, {this}); @@ -1017,7 +1017,7 @@ void NDArray::reduceNumber(nd4j::reduce::SameOps op, NDArray& target, void *extr if (isS()) throw std::runtime_error("NDArray::reduceNumber SameOps: you can't use this method on String array!"); - if(!target.isScalar() || target.dataType() != dataType()) + if(target.lengthOf() != 1 || target.dataType() != dataType()) throw std::invalid_argument("NDArray::reduceNumber SameOps: target array should be scalar and have same type as this array!"); NDArray::prepareSpecialUse({&target}, {this}); @@ -1030,7 +1030,7 @@ void NDArray::reduceNumber(nd4j::reduce::BoolOps op, NDArray& target, void *extr if (isS()) throw std::runtime_error("NDArray::reduceNumber BoolOps: you can't use this method on String array!"); - if(!target.isScalar() || target.dataType() != DataType::BOOL) + if(target.lengthOf() != 1 || target.dataType() != DataType::BOOL) throw std::invalid_argument("NDArray::reduceNumber BoolOps: target array should be scalar and have bool type!"); NDArray::prepareSpecialUse({&target}, {this}); @@ -1043,7 +1043,7 @@ void NDArray::reduceNumber(nd4j::reduce::LongOps op, NDArray& target, void *extr if (isS()) throw std::runtime_error("NDArray::reduceNumber LongOps: you can't use this method on String array!"); - if(!target.isScalar() || target.dataType() != DataType::INT64) + if(target.lengthOf() != 1 || target.dataType() != DataType::INT64) throw std::invalid_argument("NDArray::reduceNumber LongOps: target array should be scalar and have long type!"); NDArray::prepareSpecialUse({&target}, {this}); @@ -2104,7 +2104,7 @@ void NDArray::operator+=(const NDArray& other) { if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) throw nd4j::datatype_exception::build("NDArray operator+=: Cannot add different types", this->dataType(), other.dataType()); - if (!this->isScalar() && other.isScalar()) { + if (this->lengthOf() != 1 && other.lengthOf() == 1) { NDArray::prepareSpecialUse({this}, {this, &other}); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Add, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr); NDArray::registerSpecialUse({this}, {this, &other}); @@ -2138,7 +2138,7 @@ void NDArray::operator-=(const NDArray& other) { if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) throw nd4j::datatype_exception::build("NDArray operator-=: Cannot subtract different types", this->dataType(), other.dataType()); - if (!this->isScalar() && other.isScalar()) { + if (lengthOf() != 1 && other.lengthOf() == 1) { NDArray::prepareSpecialUse({this}, {this, &other}); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Subtract, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr); NDArray::registerSpecialUse({this}, {this, &other}); @@ -2171,7 +2171,7 @@ void NDArray::operator*=(const NDArray& other) { if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) throw nd4j::datatype_exception::build("NDArray operator*=: Cannot multiply different types", this->dataType(), other.dataType()); - if (!this->isScalar() && other.isScalar()) { + if (lengthOf() != 1 && other.lengthOf() == 1) { NDArray::prepareSpecialUse({this}, {this, &other}); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr); NDArray::registerSpecialUse({this}, {this, &other}); @@ -2208,7 +2208,7 @@ void NDArray::operator/=(const NDArray& other) { throw nd4j::datatype_exception::build("NDArray operator/=: Cannot divide different types", this->dataType(), other.dataType()); } - if (!this->isScalar() && other.isScalar()) { + if (lengthOf() != 1 && other.lengthOf() == 1) { NDArray::prepareSpecialUse({this}, {this, &other}); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr); NDArray::registerSpecialUse({this}, {this, &other}); @@ -2520,12 +2520,12 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* othe if (isEmpty() || other->isEmpty()) return; - if (isScalar()) { + if (lengthOf() == 1) { target->assign(this); target->applyPairwiseTransform(op.p, *other, extraArgs); return; } - if (other->isScalar()) { + if (other->lengthOf() == 1) { const_cast(this)->applyScalarArr(op.s, other, target, extraArgs); return; } @@ -2560,13 +2560,13 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray* if (isEmpty() || other->isEmpty()) return; - if (isScalar()) { + if (lengthOf() == 1) { NDArray temp(target->_shapeInfo, dataType(), false, getContext()); temp.assign(this); temp.applyPairwiseTransform(op.p, other, target, extraArgs); return; } - if (other->isScalar()) { + if (other->lengthOf() == 1) { this->applyScalarArr(op.s, other, target, extraArgs); return; } @@ -2599,13 +2599,13 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* o if (isEmpty() || other->isEmpty()) return; - if (isScalar()) { + if (lengthOf() == 1) { NDArray temp(target->_shapeInfo, dataType(), false, getContext()); temp.assign(this); temp.applyPairwiseTransform(op.p, other, target, extraArgs); return; } - if (other->isScalar()) { + if (other->lengthOf() == 1) { this->applyScalarArr(op.s, other, target, extraArgs); return; } @@ -3178,9 +3178,9 @@ void NDArray::assign(const NDArray& other, bool allowParallelism) { return; } - if (other.isScalar()) { + if (other.lengthOf() == 1) { - if(this->isScalar()) { + if(lengthOf() == 1) { NDArray::preparePrimaryUse({this}, {&other}); BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.getBuffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES); NDArray::registerPrimaryUse({this}, {&other}); @@ -3559,7 +3559,7 @@ NDArray NDArray::transform(nd4j::transform::BoolOps op, void *extraParams) const void NDArray::applyScalarArr(nd4j::scalar::Ops op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams) { if (isS()) throw std::runtime_error("NDArray::applyScalarArr: you can't use this method on String array!"); - if (!scalar->isScalar()) + if (scalar->lengthOf() != 1) throw std::invalid_argument("NDArray::applyScalarArr method: operand is not a scalar!"); if(target == nullptr) target = this; @@ -3678,7 +3678,7 @@ void NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, NDArray* target, const NDArray::prepareSpecialUse({target}, {this}); - if (target->isScalar()) { + if (target->lengthOf() == 1) { NativeOpExecutioner::execIndexReduceScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo()); } else { @@ -4060,7 +4060,7 @@ template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, c //////////////////////////////////////////////////////////////////////// void NDArray::p(const Nd4jLong i, const NDArray& scalar) { - if(!scalar.isScalar()) + if(scalar.lengthOf() != 1) throw std::invalid_argument("NDArray::p method: input array must be scalar!"); if (i >= _length) throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !"); @@ -4074,7 +4074,7 @@ void NDArray::p(const Nd4jLong i, const NDArray& scalar) { //////////////////////////////////////////////////////////////////////// void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const NDArray& scalar) { - if(!scalar.isScalar()) + if(scalar.lengthOf() != 1) throw std::invalid_argument("NDArray::p method: input array must be scalar!"); if (i >= _length) throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !"); @@ -4313,17 +4313,16 @@ Nd4jLong NDArray::getOffset(const Nd4jLong i) const { return shape::getIndexOffset(i, _shapeInfo); } +//////////////////////////////////////////////////////////////////////// NDArray NDArray::like() { - NDArray res(this->shapeInfo(), this->dataType(), false, this->getContext()); - return res; + return NDArray(shapeInfo(), this->dataType(), false, getContext()); } +//////////////////////////////////////////////////////////////////////// NDArray NDArray::ulike() { - // FIXME: it should be non-memset array - NDArray res(this->shapeInfo(), this->dataType(), false, this->getContext()); - return res; + return NDArray(this, false, getContext()); } //////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h index 06be5be04..b2679f537 100755 --- a/libnd4j/blas/NativeOps.h +++ b/libnd4j/blas/NativeOps.h @@ -1704,7 +1704,7 @@ ND4J_EXPORT void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOf void* dX, Nd4jLong* dXShapeInfo, Nd4jLong* dXOffsets, void* hY, Nd4jLong* hYShapeInfo, Nd4jLong* hYOffsets, void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets, - int* hIindexes, int* dIindexes); + void* hIindexes, Nd4jLong* hIndicesShapeInfo, void* dIindexes, Nd4jLong* dIndicesShapeInfo); ND4J_EXPORT void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo); diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index 243bd96f7..7449bb022 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ b/libnd4j/blas/cpu/NativeOps.cpp @@ -2767,6 +2767,68 @@ void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) { delete(reinterpret_cast(ptr)); } +template +static void _scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, + void* hX, Nd4jLong* hXShapeInfo, Nd4jLong* hXOffsets, + void* dX, Nd4jLong* dXShapeInfo, Nd4jLong* dXOffsets, + void* hY, Nd4jLong* hYShapeInfo, Nd4jLong* hYOffsets, + void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets, + void* vIindexes, Nd4jLong* hIndicesShapeInfo, void* dIindexes, Nd4jLong* dIndicesShapeInfo) { + + auto hIindexes = reinterpret_cast(vIindexes); + + int numThreads = omp_get_max_threads(); + + PRAGMA_OMP_PARALLEL_THREADS(numThreads) + { + for (int i = 0; i < numOfSubArrs; ++i) { + + int threadIndex = omp_get_thread_num(); + const auto xIndex = hIindexes[i]; + const bool isOwner = xIndex < numThreads ? threadIndex == xIndex : threadIndex == xIndex % numThreads; + + if (!isOwner) + continue; + + NDArray inSubArr( + reinterpret_cast(hX) + (hXOffsets[hIindexes[i]] * DataTypeUtils::sizeOf(hXShapeInfo)), + hXShapeInfo); + NDArray updSubArr(reinterpret_cast(hY) + (hYOffsets[i] * DataTypeUtils::sizeOf(hXShapeInfo)), + hYShapeInfo); + + if (inSubArr.lengthOf() != updSubArr.lengthOf()) { + continue; + } + + switch (opCode) { + case 0: + inSubArr.applyPairwiseTransform(pairwise::Add, &updSubArr, &inSubArr, nullptr); + break; + case 1: + inSubArr.applyPairwiseTransform(pairwise::Subtract, &updSubArr, &inSubArr, nullptr); + break; + case 2: + inSubArr.applyPairwiseTransform(pairwise::Multiply, &updSubArr, &inSubArr, nullptr); + break; + case 3: + inSubArr.applyPairwiseTransform(pairwise::Divide, &updSubArr, &inSubArr, nullptr); + break; + case 4: + inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, &updSubArr, &inSubArr, nullptr); + break; + case 5: + inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, &updSubArr, &inSubArr, nullptr); + break; + case 6: + inSubArr.applyPairwiseTransform(pairwise::CopyPws, &updSubArr, &inSubArr, nullptr); + break; + default: + continue; + } + } + } + +} //////////////////////////////////////////////////////////////////////// void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, @@ -2774,60 +2836,11 @@ void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, void* dX, Nd4jLong* dXShapeInfo, Nd4jLong* dXOffsets, void* hY, Nd4jLong* hYShapeInfo, Nd4jLong* hYOffsets, void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets, - int* hIindexes, int* dIindexes) { + void* hIindexes, Nd4jLong* hIndicesShapeInfo, void* dIindexes, Nd4jLong* dIndicesShapeInfo) { + auto iType = ArrayOptions::dataType(hIndicesShapeInfo); try { - - int numThreads = omp_get_max_threads(); - - PRAGMA_OMP_PARALLEL_THREADS(numThreads) - { - for (int i = 0; i < numOfSubArrs; ++i) { - - int threadIndex = omp_get_thread_num(); - const auto xIndex = hIindexes[i]; - const bool isOwner = xIndex < numThreads ? threadIndex == xIndex : threadIndex == xIndex % numThreads; - - if (!isOwner) - continue; - - NDArray inSubArr( - reinterpret_cast(hX) + (hXOffsets[hIindexes[i]] * DataTypeUtils::sizeOf(hXShapeInfo)), - hXShapeInfo); - NDArray updSubArr(reinterpret_cast(hY) + (hYOffsets[i] * DataTypeUtils::sizeOf(hXShapeInfo)), - hYShapeInfo); - - if (inSubArr.lengthOf() != updSubArr.lengthOf()) { - continue; - } - - switch (opCode) { - case 0: - inSubArr.applyPairwiseTransform(pairwise::Add, &updSubArr, &inSubArr, nullptr); - break; - case 1: - inSubArr.applyPairwiseTransform(pairwise::Subtract, &updSubArr, &inSubArr, nullptr); - break; - case 2: - inSubArr.applyPairwiseTransform(pairwise::Multiply, &updSubArr, &inSubArr, nullptr); - break; - case 3: - inSubArr.applyPairwiseTransform(pairwise::Divide, &updSubArr, &inSubArr, nullptr); - break; - case 4: - inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, &updSubArr, &inSubArr, nullptr); - break; - case 5: - inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, &updSubArr, &inSubArr, nullptr); - break; - case 6: - inSubArr.applyPairwiseTransform(pairwise::CopyPws, &updSubArr, &inSubArr, nullptr); - break; - default: - continue; - } - } - } + BUILD_SINGLE_SELECTOR(iType, _scatterUpdate, (extraPointers, opCode, numOfSubArrs, hX, hXShapeInfo, hXOffsets, dX, dXShapeInfo, dXOffsets, hY, hYShapeInfo, hYOffsets, dY, dYShapeInfo, dYOffsets, hIindexes, hIndicesShapeInfo, dIindexes, dIndicesShapeInfo), INDEXING_TYPES); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index dc9d37b03..2db1aa128 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -3198,14 +3198,15 @@ void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) { } /////////////////////////////////////////////////////////////////// -template +template __global__ static void scatterUpdateCuda(const int opCode, const int numOfSubArrs, void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, - const int* indexes) { + const void* vindexes) { __shared__ T *x, *y; __shared__ Nd4jLong arrLenX, arrLenY; + auto indexes = reinterpret_cast(vindexes); for (int e = 0; e < numOfSubArrs; e++ ) { @@ -3261,10 +3262,10 @@ __global__ static void scatterUpdateCuda(const int opCode, const int numOfSubArr } } -template -__host__ static void scatterUpdateCudaLauncher(const cudaStream_t* stream, const int opCode, const int numOfSubArrs, void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, const int* indexes) { +template +__host__ static void scatterUpdateCudaLauncher(const cudaStream_t* stream, const int opCode, const int numOfSubArrs, void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, const void* indexes) { - scatterUpdateCuda<<<512, 256, MAX_NUM_THREADS, *stream>>>(opCode, numOfSubArrs, vx, xShapeInfo, xOffsets, vy, yShapeInfo, yOffsets, indexes); + scatterUpdateCuda<<<512, 256, MAX_NUM_THREADS, *stream>>>(opCode, numOfSubArrs, vx, xShapeInfo, xOffsets, vy, yShapeInfo, yOffsets, indexes); } @@ -3274,15 +3275,17 @@ void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, void* dX, Nd4jLong* dXShapeInfo, Nd4jLong* dXOffsets, void* hY, Nd4jLong* hYShapeInfo, Nd4jLong* hYOffsets, void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets, - int* hIindexes, int* dIndexes) { + void* hIindexes, Nd4jLong* hIndicesShapeInfo, void* dIindexes, Nd4jLong* dIndicesShapeInfo) { try { auto stream = reinterpret_cast(extraPointers[1]); - nd4j::DataType type = ArrayOptions::dataType(hXShapeInfo); + auto type = ArrayOptions::dataType(hXShapeInfo); + auto iType = ArrayOptions::dataType(hIndicesShapeInfo); + + BUILD_DOUBLE_SELECTOR(type, iType, scatterUpdateCudaLauncher, + (stream, opCode, numOfSubArrs, dX, dXShapeInfo, dXOffsets, dY, dYShapeInfo, dYOffsets, dIindexes), + LIBND4J_TYPES, INDEXING_TYPES); - BUILD_SINGLE_SELECTOR(type, scatterUpdateCudaLauncher, - (stream, opCode, numOfSubArrs, dX, dXShapeInfo, dXOffsets, dY, dYShapeInfo, dYOffsets, dIndexes), - LIBND4J_TYPES); nd4j::DebugHelper::checkErrorCode(stream, "scatterUpdate(...) failed"); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs index 0810d2e6e..7fa9722db 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs @@ -106,6 +106,12 @@ public struct FlatNode : IFlatbufferObject #endif public DType[] GetOutputTypesArray() { return __p.__vector_as_array(38); } public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } } + public string ControlDeps(int j) { int o = __p.__offset(42); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; } + public int ControlDepsLength { get { int o = __p.__offset(42); return o != 0 ? __p.__vector_len(o) : 0; } } + public string VarControlDeps(int j) { int o = __p.__offset(44); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; } + public int VarControlDepsLength { get { int o = __p.__offset(44); return o != 0 ? __p.__vector_len(o) : 0; } } + public string ControlDepFor(int j) { int o = __p.__offset(46); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; } + public int ControlDepForLength { get { int o = __p.__offset(46); return o != 0 ? __p.__vector_len(o) : 0; } } public static Offset CreateFlatNode(FlatBufferBuilder builder, int id = 0, @@ -126,9 +132,15 @@ public struct FlatNode : IFlatbufferObject VectorOffset outputNamesOffset = default(VectorOffset), StringOffset opNameOffset = default(StringOffset), VectorOffset outputTypesOffset = default(VectorOffset), - Offset scalarOffset = default(Offset)) { - builder.StartObject(19); + Offset scalarOffset = default(Offset), + VectorOffset controlDepsOffset = default(VectorOffset), + VectorOffset varControlDepsOffset = default(VectorOffset), + VectorOffset controlDepForOffset = default(VectorOffset)) { + builder.StartObject(22); FlatNode.AddOpNum(builder, opNum); + FlatNode.AddControlDepFor(builder, controlDepForOffset); + FlatNode.AddVarControlDeps(builder, varControlDepsOffset); + FlatNode.AddControlDeps(builder, controlDepsOffset); FlatNode.AddScalar(builder, scalarOffset); FlatNode.AddOutputTypes(builder, outputTypesOffset); FlatNode.AddOpName(builder, opNameOffset); @@ -150,7 +162,7 @@ public struct FlatNode : IFlatbufferObject return FlatNode.EndFlatNode(builder); } - public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(19); } + public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(22); } public static void AddId(FlatBufferBuilder builder, int id) { builder.AddInt(0, id, 0); } public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); } public static void AddOpType(FlatBufferBuilder builder, OpType opType) { builder.AddSbyte(2, (sbyte)opType, 0); } @@ -200,6 +212,18 @@ public struct FlatNode : IFlatbufferObject public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); } public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); } public static void AddScalar(FlatBufferBuilder builder, Offset scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); } + public static void AddControlDeps(FlatBufferBuilder builder, VectorOffset controlDepsOffset) { builder.AddOffset(19, controlDepsOffset.Value, 0); } + public static VectorOffset CreateControlDepsVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); } + public static VectorOffset CreateControlDepsVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); } + public static void StartControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } + public static void AddVarControlDeps(FlatBufferBuilder builder, VectorOffset varControlDepsOffset) { builder.AddOffset(20, varControlDepsOffset.Value, 0); } + public static VectorOffset CreateVarControlDepsVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); } + public static VectorOffset CreateVarControlDepsVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); } + public static void StartVarControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } + public static void AddControlDepFor(FlatBufferBuilder builder, VectorOffset controlDepForOffset) { builder.AddOffset(21, controlDepForOffset.Value, 0); } + public static VectorOffset CreateControlDepForVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); } + public static VectorOffset CreateControlDepForVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); } + public static void StartControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } public static Offset EndFlatNode(FlatBufferBuilder builder) { int o = builder.EndObject(); return new Offset(o); diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.java b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.java index f739551f1..8a72cc00a 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.java +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.java @@ -66,6 +66,12 @@ public final class FlatNode extends Table { public ByteBuffer outputTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 38, 1); } public FlatArray scalar() { return scalar(new FlatArray()); } public FlatArray scalar(FlatArray obj) { int o = __offset(40); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; } + public String controlDeps(int j) { int o = __offset(42); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int controlDepsLength() { int o = __offset(42); return o != 0 ? __vector_len(o) : 0; } + public String varControlDeps(int j) { int o = __offset(44); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; } + public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; } public static int createFlatNode(FlatBufferBuilder builder, int id, @@ -86,9 +92,15 @@ public final class FlatNode extends Table { int outputNamesOffset, int opNameOffset, int outputTypesOffset, - int scalarOffset) { - builder.startObject(19); + int scalarOffset, + int controlDepsOffset, + int varControlDepsOffset, + int controlDepForOffset) { + builder.startObject(22); FlatNode.addOpNum(builder, opNum); + FlatNode.addControlDepFor(builder, controlDepForOffset); + FlatNode.addVarControlDeps(builder, varControlDepsOffset); + FlatNode.addControlDeps(builder, controlDepsOffset); FlatNode.addScalar(builder, scalarOffset); FlatNode.addOutputTypes(builder, outputTypesOffset); FlatNode.addOpName(builder, opNameOffset); @@ -110,7 +122,7 @@ public final class FlatNode extends Table { return FlatNode.endFlatNode(builder); } - public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(19); } + public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); } public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); } public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); } public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); } @@ -150,6 +162,15 @@ public final class FlatNode extends Table { public static int createOutputTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); } public static void startOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); } public static void addScalar(FlatBufferBuilder builder, int scalarOffset) { builder.addOffset(18, scalarOffset, 0); } + public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(19, controlDepsOffset, 0); } + public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addVarControlDeps(FlatBufferBuilder builder, int varControlDepsOffset) { builder.addOffset(20, varControlDepsOffset, 0); } + public static int createVarControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startVarControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); } + public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } public static int endFlatNode(FlatBufferBuilder builder) { int o = builder.endObject(); return o; diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.py b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.py index 520fe1aad..889eca62f 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.py +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.py @@ -294,7 +294,52 @@ class FlatNode(object): return obj return None -def FlatNodeStart(builder): builder.StartObject(19) + # FlatNode + def ControlDeps(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(42)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # FlatNode + def ControlDepsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(42)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # FlatNode + def VarControlDeps(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(44)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # FlatNode + def VarControlDepsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(44)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # FlatNode + def ControlDepFor(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(46)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # FlatNode + def ControlDepForLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(46)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + +def FlatNodeStart(builder): builder.StartObject(22) def FlatNodeAddId(builder, id): builder.PrependInt32Slot(0, id, 0) def FlatNodeAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) def FlatNodeAddOpType(builder, opType): builder.PrependInt8Slot(2, opType, 0) @@ -324,4 +369,10 @@ def FlatNodeAddOpName(builder, opName): builder.PrependUOffsetTRelativeSlot(16, def FlatNodeAddOutputTypes(builder, outputTypes): builder.PrependUOffsetTRelativeSlot(17, flatbuffers.number_types.UOffsetTFlags.py_type(outputTypes), 0) def FlatNodeStartOutputTypesVector(builder, numElems): return builder.StartVector(1, numElems, 1) def FlatNodeAddScalar(builder, scalar): builder.PrependUOffsetTRelativeSlot(18, flatbuffers.number_types.UOffsetTFlags.py_type(scalar), 0) +def FlatNodeAddControlDeps(builder, controlDeps): builder.PrependUOffsetTRelativeSlot(19, flatbuffers.number_types.UOffsetTFlags.py_type(controlDeps), 0) +def FlatNodeStartControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def FlatNodeAddVarControlDeps(builder, varControlDeps): builder.PrependUOffsetTRelativeSlot(20, flatbuffers.number_types.UOffsetTFlags.py_type(varControlDeps), 0) +def FlatNodeStartVarControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def FlatNodeAddControlDepFor(builder, controlDepFor): builder.PrependUOffsetTRelativeSlot(21, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepFor), 0) +def FlatNodeStartControlDepForVector(builder, numElems): return builder.StartVector(4, numElems, 4) def FlatNodeEnd(builder): return builder.EndObject() diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs b/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs index 9764668a0..325094654 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs @@ -37,6 +37,12 @@ public struct FlatVariable : IFlatbufferObject public FlatArray? Ndarray { get { int o = __p.__offset(12); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } } public int Device { get { int o = __p.__offset(14); return o != 0 ? __p.bb.GetInt(o + __p.bb_pos) : (int)0; } } public VarType Variabletype { get { int o = __p.__offset(16); return o != 0 ? (VarType)__p.bb.GetSbyte(o + __p.bb_pos) : VarType.VARIABLE; } } + public string ControlDeps(int j) { int o = __p.__offset(18); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; } + public int ControlDepsLength { get { int o = __p.__offset(18); return o != 0 ? __p.__vector_len(o) : 0; } } + public string ControlDepForOp(int j) { int o = __p.__offset(20); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; } + public int ControlDepForOpLength { get { int o = __p.__offset(20); return o != 0 ? __p.__vector_len(o) : 0; } } + public string ControlDepsForVar(int j) { int o = __p.__offset(22); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; } + public int ControlDepsForVarLength { get { int o = __p.__offset(22); return o != 0 ? __p.__vector_len(o) : 0; } } public static Offset CreateFlatVariable(FlatBufferBuilder builder, Offset idOffset = default(Offset), @@ -45,8 +51,14 @@ public struct FlatVariable : IFlatbufferObject VectorOffset shapeOffset = default(VectorOffset), Offset ndarrayOffset = default(Offset), int device = 0, - VarType variabletype = VarType.VARIABLE) { - builder.StartObject(7); + VarType variabletype = VarType.VARIABLE, + VectorOffset controlDepsOffset = default(VectorOffset), + VectorOffset controlDepForOpOffset = default(VectorOffset), + VectorOffset controlDepsForVarOffset = default(VectorOffset)) { + builder.StartObject(10); + FlatVariable.AddControlDepsForVar(builder, controlDepsForVarOffset); + FlatVariable.AddControlDepForOp(builder, controlDepForOpOffset); + FlatVariable.AddControlDeps(builder, controlDepsOffset); FlatVariable.AddDevice(builder, device); FlatVariable.AddNdarray(builder, ndarrayOffset); FlatVariable.AddShape(builder, shapeOffset); @@ -57,7 +69,7 @@ public struct FlatVariable : IFlatbufferObject return FlatVariable.EndFlatVariable(builder); } - public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); } + public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(10); } public static void AddId(FlatBufferBuilder builder, Offset idOffset) { builder.AddOffset(0, idOffset.Value, 0); } public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); } public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } @@ -68,6 +80,18 @@ public struct FlatVariable : IFlatbufferObject public static void AddNdarray(FlatBufferBuilder builder, Offset ndarrayOffset) { builder.AddOffset(4, ndarrayOffset.Value, 0); } public static void AddDevice(FlatBufferBuilder builder, int device) { builder.AddInt(5, device, 0); } public static void AddVariabletype(FlatBufferBuilder builder, VarType variabletype) { builder.AddSbyte(6, (sbyte)variabletype, 0); } + public static void AddControlDeps(FlatBufferBuilder builder, VectorOffset controlDepsOffset) { builder.AddOffset(7, controlDepsOffset.Value, 0); } + public static VectorOffset CreateControlDepsVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); } + public static VectorOffset CreateControlDepsVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); } + public static void StartControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } + public static void AddControlDepForOp(FlatBufferBuilder builder, VectorOffset controlDepForOpOffset) { builder.AddOffset(8, controlDepForOpOffset.Value, 0); } + public static VectorOffset CreateControlDepForOpVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); } + public static VectorOffset CreateControlDepForOpVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); } + public static void StartControlDepForOpVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } + public static void AddControlDepsForVar(FlatBufferBuilder builder, VectorOffset controlDepsForVarOffset) { builder.AddOffset(9, controlDepsForVarOffset.Value, 0); } + public static VectorOffset CreateControlDepsForVarVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); } + public static VectorOffset CreateControlDepsForVarVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); } + public static void StartControlDepsForVarVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } public static Offset EndFlatVariable(FlatBufferBuilder builder) { int o = builder.EndObject(); return new Offset(o); diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.java b/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.java index 37e2053c2..d73c990bb 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.java +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.java @@ -28,6 +28,12 @@ public final class FlatVariable extends Table { public FlatArray ndarray(FlatArray obj) { int o = __offset(12); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; } public int device() { int o = __offset(14); return o != 0 ? bb.getInt(o + bb_pos) : 0; } public byte variabletype() { int o = __offset(16); return o != 0 ? bb.get(o + bb_pos) : 0; } + public String controlDeps(int j) { int o = __offset(18); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int controlDepsLength() { int o = __offset(18); return o != 0 ? __vector_len(o) : 0; } + public String controlDepForOp(int j) { int o = __offset(20); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int controlDepForOpLength() { int o = __offset(20); return o != 0 ? __vector_len(o) : 0; } + public String controlDepsForVar(int j) { int o = __offset(22); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int controlDepsForVarLength() { int o = __offset(22); return o != 0 ? __vector_len(o) : 0; } public static int createFlatVariable(FlatBufferBuilder builder, int idOffset, @@ -36,8 +42,14 @@ public final class FlatVariable extends Table { int shapeOffset, int ndarrayOffset, int device, - byte variabletype) { - builder.startObject(7); + byte variabletype, + int controlDepsOffset, + int controlDepForOpOffset, + int controlDepsForVarOffset) { + builder.startObject(10); + FlatVariable.addControlDepsForVar(builder, controlDepsForVarOffset); + FlatVariable.addControlDepForOp(builder, controlDepForOpOffset); + FlatVariable.addControlDeps(builder, controlDepsOffset); FlatVariable.addDevice(builder, device); FlatVariable.addNdarray(builder, ndarrayOffset); FlatVariable.addShape(builder, shapeOffset); @@ -48,7 +60,7 @@ public final class FlatVariable extends Table { return FlatVariable.endFlatVariable(builder); } - public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(7); } + public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(10); } public static void addId(FlatBufferBuilder builder, int idOffset) { builder.addOffset(0, idOffset, 0); } public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); } public static void addDtype(FlatBufferBuilder builder, byte dtype) { builder.addByte(2, dtype, 0); } @@ -58,6 +70,15 @@ public final class FlatVariable extends Table { public static void addNdarray(FlatBufferBuilder builder, int ndarrayOffset) { builder.addOffset(4, ndarrayOffset, 0); } public static void addDevice(FlatBufferBuilder builder, int device) { builder.addInt(5, device, 0); } public static void addVariabletype(FlatBufferBuilder builder, byte variabletype) { builder.addByte(6, variabletype, 0); } + public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(7, controlDepsOffset, 0); } + public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addControlDepForOp(FlatBufferBuilder builder, int controlDepForOpOffset) { builder.addOffset(8, controlDepForOpOffset, 0); } + public static int createControlDepForOpVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepForOpVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addControlDepsForVar(FlatBufferBuilder builder, int controlDepsForVarOffset) { builder.addOffset(9, controlDepsForVarOffset, 0); } + public static int createControlDepsForVarVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepsForVarVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } public static int endFlatVariable(FlatBufferBuilder builder) { int o = builder.endObject(); return o; diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.py b/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.py index e2679c6cd..d0036c247 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.py +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.py @@ -90,7 +90,52 @@ class FlatVariable(object): return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) return 0 -def FlatVariableStart(builder): builder.StartObject(7) + # FlatVariable + def ControlDeps(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # FlatVariable + def ControlDepsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # FlatVariable + def ControlDepForOp(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # FlatVariable + def ControlDepForOpLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # FlatVariable + def ControlDepsForVar(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # FlatVariable + def ControlDepsForVarLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + +def FlatVariableStart(builder): builder.StartObject(10) def FlatVariableAddId(builder, id): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(id), 0) def FlatVariableAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) def FlatVariableAddDtype(builder, dtype): builder.PrependInt8Slot(2, dtype, 0) @@ -99,4 +144,10 @@ def FlatVariableStartShapeVector(builder, numElems): return builder.StartVector( def FlatVariableAddNdarray(builder, ndarray): builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(ndarray), 0) def FlatVariableAddDevice(builder, device): builder.PrependInt32Slot(5, device, 0) def FlatVariableAddVariabletype(builder, variabletype): builder.PrependInt8Slot(6, variabletype, 0) +def FlatVariableAddControlDeps(builder, controlDeps): builder.PrependUOffsetTRelativeSlot(7, flatbuffers.number_types.UOffsetTFlags.py_type(controlDeps), 0) +def FlatVariableStartControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def FlatVariableAddControlDepForOp(builder, controlDepForOp): builder.PrependUOffsetTRelativeSlot(8, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepForOp), 0) +def FlatVariableStartControlDepForOpVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def FlatVariableAddControlDepsForVar(builder, controlDepsForVar): builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepsForVar), 0) +def FlatVariableStartControlDepsForVarVector(builder, numElems): return builder.StartVector(4, numElems, 4) def FlatVariableEnd(builder): return builder.EndObject() diff --git a/libnd4j/include/graph/generated/node_generated.h b/libnd4j/include/graph/generated/node_generated.h index 286547552..6ca85f7b0 100644 --- a/libnd4j/include/graph/generated/node_generated.h +++ b/libnd4j/include/graph/generated/node_generated.h @@ -35,7 +35,10 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_OUTPUTNAMES = 34, VT_OPNAME = 36, VT_OUTPUTTYPES = 38, - VT_SCALAR = 40 + VT_SCALAR = 40, + VT_CONTROLDEPS = 42, + VT_VARCONTROLDEPS = 44, + VT_CONTROLDEPFOR = 46 }; int32_t id() const { return GetField(VT_ID, 0); @@ -94,6 +97,15 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const FlatArray *scalar() const { return GetPointer(VT_SCALAR); } + const flatbuffers::Vector> *controlDeps() const { + return GetPointer> *>(VT_CONTROLDEPS); + } + const flatbuffers::Vector> *varControlDeps() const { + return GetPointer> *>(VT_VARCONTROLDEPS); + } + const flatbuffers::Vector> *controlDepFor() const { + return GetPointer> *>(VT_CONTROLDEPFOR); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_ID) && @@ -132,6 +144,15 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyVector(outputTypes()) && VerifyOffset(verifier, VT_SCALAR) && verifier.VerifyTable(scalar()) && + VerifyOffset(verifier, VT_CONTROLDEPS) && + verifier.VerifyVector(controlDeps()) && + verifier.VerifyVectorOfStrings(controlDeps()) && + VerifyOffset(verifier, VT_VARCONTROLDEPS) && + verifier.VerifyVector(varControlDeps()) && + verifier.VerifyVectorOfStrings(varControlDeps()) && + VerifyOffset(verifier, VT_CONTROLDEPFOR) && + verifier.VerifyVector(controlDepFor()) && + verifier.VerifyVectorOfStrings(controlDepFor()) && verifier.EndTable(); } }; @@ -196,6 +217,15 @@ struct FlatNodeBuilder { void add_scalar(flatbuffers::Offset scalar) { fbb_.AddOffset(FlatNode::VT_SCALAR, scalar); } + void add_controlDeps(flatbuffers::Offset>> controlDeps) { + fbb_.AddOffset(FlatNode::VT_CONTROLDEPS, controlDeps); + } + void add_varControlDeps(flatbuffers::Offset>> varControlDeps) { + fbb_.AddOffset(FlatNode::VT_VARCONTROLDEPS, varControlDeps); + } + void add_controlDepFor(flatbuffers::Offset>> controlDepFor) { + fbb_.AddOffset(FlatNode::VT_CONTROLDEPFOR, controlDepFor); + } explicit FlatNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -228,9 +258,15 @@ inline flatbuffers::Offset CreateFlatNode( flatbuffers::Offset>> outputNames = 0, flatbuffers::Offset opName = 0, flatbuffers::Offset> outputTypes = 0, - flatbuffers::Offset scalar = 0) { + flatbuffers::Offset scalar = 0, + flatbuffers::Offset>> controlDeps = 0, + flatbuffers::Offset>> varControlDeps = 0, + flatbuffers::Offset>> controlDepFor = 0) { FlatNodeBuilder builder_(_fbb); builder_.add_opNum(opNum); + builder_.add_controlDepFor(controlDepFor); + builder_.add_varControlDeps(varControlDeps); + builder_.add_controlDeps(controlDeps); builder_.add_scalar(scalar); builder_.add_outputTypes(outputTypes); builder_.add_opName(opName); @@ -272,7 +308,10 @@ inline flatbuffers::Offset CreateFlatNodeDirect( const std::vector> *outputNames = nullptr, const char *opName = nullptr, const std::vector *outputTypes = nullptr, - flatbuffers::Offset scalar = 0) { + flatbuffers::Offset scalar = 0, + const std::vector> *controlDeps = nullptr, + const std::vector> *varControlDeps = nullptr, + const std::vector> *controlDepFor = nullptr) { return nd4j::graph::CreateFlatNode( _fbb, id, @@ -293,7 +332,10 @@ inline flatbuffers::Offset CreateFlatNodeDirect( outputNames ? _fbb.CreateVector>(*outputNames) : 0, opName ? _fbb.CreateString(opName) : 0, outputTypes ? _fbb.CreateVector(*outputTypes) : 0, - scalar); + scalar, + controlDeps ? _fbb.CreateVector>(*controlDeps) : 0, + varControlDeps ? _fbb.CreateVector>(*varControlDeps) : 0, + controlDepFor ? _fbb.CreateVector>(*controlDepFor) : 0); } inline const nd4j::graph::FlatNode *GetFlatNode(const void *buf) { diff --git a/libnd4j/include/graph/generated/node_generated.js b/libnd4j/include/graph/generated/node_generated.js index bd2274dad..dd83c4356 100644 --- a/libnd4j/include/graph/generated/node_generated.js +++ b/libnd4j/include/graph/generated/node_generated.js @@ -344,11 +344,65 @@ nd4j.graph.FlatNode.prototype.scalar = function(obj) { return offset ? (obj || new nd4j.graph.FlatArray).__init(this.bb.__indirect(this.bb_pos + offset), this.bb) : null; }; +/** + * @param {number} index + * @param {flatbuffers.Encoding=} optionalEncoding + * @returns {string|Uint8Array} + */ +nd4j.graph.FlatNode.prototype.controlDeps = function(index, optionalEncoding) { + var offset = this.bb.__offset(this.bb_pos, 42); + return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; +}; + +/** + * @returns {number} + */ +nd4j.graph.FlatNode.prototype.controlDepsLength = function() { + var offset = this.bb.__offset(this.bb_pos, 42); + return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0; +}; + +/** + * @param {number} index + * @param {flatbuffers.Encoding=} optionalEncoding + * @returns {string|Uint8Array} + */ +nd4j.graph.FlatNode.prototype.varControlDeps = function(index, optionalEncoding) { + var offset = this.bb.__offset(this.bb_pos, 44); + return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; +}; + +/** + * @returns {number} + */ +nd4j.graph.FlatNode.prototype.varControlDepsLength = function() { + var offset = this.bb.__offset(this.bb_pos, 44); + return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0; +}; + +/** + * @param {number} index + * @param {flatbuffers.Encoding=} optionalEncoding + * @returns {string|Uint8Array} + */ +nd4j.graph.FlatNode.prototype.controlDepFor = function(index, optionalEncoding) { + var offset = this.bb.__offset(this.bb_pos, 46); + return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; +}; + +/** + * @returns {number} + */ +nd4j.graph.FlatNode.prototype.controlDepForLength = function() { + var offset = this.bb.__offset(this.bb_pos, 46); + return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0; +}; + /** * @param {flatbuffers.Builder} builder */ nd4j.graph.FlatNode.startFlatNode = function(builder) { - builder.startObject(19); + builder.startObject(22); }; /** @@ -713,6 +767,93 @@ nd4j.graph.FlatNode.addScalar = function(builder, scalarOffset) { builder.addFieldOffset(18, scalarOffset, 0); }; +/** + * @param {flatbuffers.Builder} builder + * @param {flatbuffers.Offset} controlDepsOffset + */ +nd4j.graph.FlatNode.addControlDeps = function(builder, controlDepsOffset) { + builder.addFieldOffset(19, controlDepsOffset, 0); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {Array.} data + * @returns {flatbuffers.Offset} + */ +nd4j.graph.FlatNode.createControlDepsVector = function(builder, data) { + builder.startVector(4, data.length, 4); + for (var i = data.length - 1; i >= 0; i--) { + builder.addOffset(data[i]); + } + return builder.endVector(); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {number} numElems + */ +nd4j.graph.FlatNode.startControlDepsVector = function(builder, numElems) { + builder.startVector(4, numElems, 4); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {flatbuffers.Offset} varControlDepsOffset + */ +nd4j.graph.FlatNode.addVarControlDeps = function(builder, varControlDepsOffset) { + builder.addFieldOffset(20, varControlDepsOffset, 0); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {Array.} data + * @returns {flatbuffers.Offset} + */ +nd4j.graph.FlatNode.createVarControlDepsVector = function(builder, data) { + builder.startVector(4, data.length, 4); + for (var i = data.length - 1; i >= 0; i--) { + builder.addOffset(data[i]); + } + return builder.endVector(); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {number} numElems + */ +nd4j.graph.FlatNode.startVarControlDepsVector = function(builder, numElems) { + builder.startVector(4, numElems, 4); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {flatbuffers.Offset} controlDepForOffset + */ +nd4j.graph.FlatNode.addControlDepFor = function(builder, controlDepForOffset) { + builder.addFieldOffset(21, controlDepForOffset, 0); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {Array.} data + * @returns {flatbuffers.Offset} + */ +nd4j.graph.FlatNode.createControlDepForVector = function(builder, data) { + builder.startVector(4, data.length, 4); + for (var i = data.length - 1; i >= 0; i--) { + builder.addOffset(data[i]); + } + return builder.endVector(); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {number} numElems + */ +nd4j.graph.FlatNode.startControlDepForVector = function(builder, numElems) { + builder.startVector(4, numElems, 4); +}; + /** * @param {flatbuffers.Builder} builder * @returns {flatbuffers.Offset} diff --git a/libnd4j/include/graph/generated/variable_generated.h b/libnd4j/include/graph/generated/variable_generated.h index ca1a705a0..465490722 100644 --- a/libnd4j/include/graph/generated/variable_generated.h +++ b/libnd4j/include/graph/generated/variable_generated.h @@ -57,7 +57,10 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_SHAPE = 10, VT_NDARRAY = 12, VT_DEVICE = 14, - VT_VARIABLETYPE = 16 + VT_VARIABLETYPE = 16, + VT_CONTROLDEPS = 18, + VT_CONTROLDEPFOROP = 20, + VT_CONTROLDEPSFORVAR = 22 }; const IntPair *id() const { return GetPointer(VT_ID); @@ -80,6 +83,15 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VarType variabletype() const { return static_cast(GetField(VT_VARIABLETYPE, 0)); } + const flatbuffers::Vector> *controlDeps() const { + return GetPointer> *>(VT_CONTROLDEPS); + } + const flatbuffers::Vector> *controlDepForOp() const { + return GetPointer> *>(VT_CONTROLDEPFOROP); + } + const flatbuffers::Vector> *controlDepsForVar() const { + return GetPointer> *>(VT_CONTROLDEPSFORVAR); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_ID) && @@ -93,6 +105,15 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyTable(ndarray()) && VerifyField(verifier, VT_DEVICE) && VerifyField(verifier, VT_VARIABLETYPE) && + VerifyOffset(verifier, VT_CONTROLDEPS) && + verifier.VerifyVector(controlDeps()) && + verifier.VerifyVectorOfStrings(controlDeps()) && + VerifyOffset(verifier, VT_CONTROLDEPFOROP) && + verifier.VerifyVector(controlDepForOp()) && + verifier.VerifyVectorOfStrings(controlDepForOp()) && + VerifyOffset(verifier, VT_CONTROLDEPSFORVAR) && + verifier.VerifyVector(controlDepsForVar()) && + verifier.VerifyVectorOfStrings(controlDepsForVar()) && verifier.EndTable(); } }; @@ -121,6 +142,15 @@ struct FlatVariableBuilder { void add_variabletype(VarType variabletype) { fbb_.AddElement(FlatVariable::VT_VARIABLETYPE, static_cast(variabletype), 0); } + void add_controlDeps(flatbuffers::Offset>> controlDeps) { + fbb_.AddOffset(FlatVariable::VT_CONTROLDEPS, controlDeps); + } + void add_controlDepForOp(flatbuffers::Offset>> controlDepForOp) { + fbb_.AddOffset(FlatVariable::VT_CONTROLDEPFOROP, controlDepForOp); + } + void add_controlDepsForVar(flatbuffers::Offset>> controlDepsForVar) { + fbb_.AddOffset(FlatVariable::VT_CONTROLDEPSFORVAR, controlDepsForVar); + } explicit FlatVariableBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -141,8 +171,14 @@ inline flatbuffers::Offset CreateFlatVariable( flatbuffers::Offset> shape = 0, flatbuffers::Offset ndarray = 0, int32_t device = 0, - VarType variabletype = VarType_VARIABLE) { + VarType variabletype = VarType_VARIABLE, + flatbuffers::Offset>> controlDeps = 0, + flatbuffers::Offset>> controlDepForOp = 0, + flatbuffers::Offset>> controlDepsForVar = 0) { FlatVariableBuilder builder_(_fbb); + builder_.add_controlDepsForVar(controlDepsForVar); + builder_.add_controlDepForOp(controlDepForOp); + builder_.add_controlDeps(controlDeps); builder_.add_device(device); builder_.add_ndarray(ndarray); builder_.add_shape(shape); @@ -161,7 +197,10 @@ inline flatbuffers::Offset CreateFlatVariableDirect( const std::vector *shape = nullptr, flatbuffers::Offset ndarray = 0, int32_t device = 0, - VarType variabletype = VarType_VARIABLE) { + VarType variabletype = VarType_VARIABLE, + const std::vector> *controlDeps = nullptr, + const std::vector> *controlDepForOp = nullptr, + const std::vector> *controlDepsForVar = nullptr) { return nd4j::graph::CreateFlatVariable( _fbb, id, @@ -170,7 +209,10 @@ inline flatbuffers::Offset CreateFlatVariableDirect( shape ? _fbb.CreateVector(*shape) : 0, ndarray, device, - variabletype); + variabletype, + controlDeps ? _fbb.CreateVector>(*controlDeps) : 0, + controlDepForOp ? _fbb.CreateVector>(*controlDepForOp) : 0, + controlDepsForVar ? _fbb.CreateVector>(*controlDepsForVar) : 0); } inline const nd4j::graph::FlatVariable *GetFlatVariable(const void *buf) { diff --git a/libnd4j/include/graph/generated/variable_generated.js b/libnd4j/include/graph/generated/variable_generated.js index 9012af2de..4bcdcd741 100644 --- a/libnd4j/include/graph/generated/variable_generated.js +++ b/libnd4j/include/graph/generated/variable_generated.js @@ -125,11 +125,65 @@ nd4j.graph.FlatVariable.prototype.variabletype = function() { return offset ? /** @type {nd4j.graph.VarType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.VarType.VARIABLE; }; +/** + * @param {number} index + * @param {flatbuffers.Encoding=} optionalEncoding + * @returns {string|Uint8Array} + */ +nd4j.graph.FlatVariable.prototype.controlDeps = function(index, optionalEncoding) { + var offset = this.bb.__offset(this.bb_pos, 18); + return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; +}; + +/** + * @returns {number} + */ +nd4j.graph.FlatVariable.prototype.controlDepsLength = function() { + var offset = this.bb.__offset(this.bb_pos, 18); + return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0; +}; + +/** + * @param {number} index + * @param {flatbuffers.Encoding=} optionalEncoding + * @returns {string|Uint8Array} + */ +nd4j.graph.FlatVariable.prototype.controlDepForOp = function(index, optionalEncoding) { + var offset = this.bb.__offset(this.bb_pos, 20); + return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; +}; + +/** + * @returns {number} + */ +nd4j.graph.FlatVariable.prototype.controlDepForOpLength = function() { + var offset = this.bb.__offset(this.bb_pos, 20); + return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0; +}; + +/** + * @param {number} index + * @param {flatbuffers.Encoding=} optionalEncoding + * @returns {string|Uint8Array} + */ +nd4j.graph.FlatVariable.prototype.controlDepsForVar = function(index, optionalEncoding) { + var offset = this.bb.__offset(this.bb_pos, 22); + return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; +}; + +/** + * @returns {number} + */ +nd4j.graph.FlatVariable.prototype.controlDepsForVarLength = function() { + var offset = this.bb.__offset(this.bb_pos, 22); + return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0; +}; + /** * @param {flatbuffers.Builder} builder */ nd4j.graph.FlatVariable.startFlatVariable = function(builder) { - builder.startObject(7); + builder.startObject(10); }; /** @@ -209,6 +263,93 @@ nd4j.graph.FlatVariable.addVariabletype = function(builder, variabletype) { builder.addFieldInt8(6, variabletype, nd4j.graph.VarType.VARIABLE); }; +/** + * @param {flatbuffers.Builder} builder + * @param {flatbuffers.Offset} controlDepsOffset + */ +nd4j.graph.FlatVariable.addControlDeps = function(builder, controlDepsOffset) { + builder.addFieldOffset(7, controlDepsOffset, 0); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {Array.} data + * @returns {flatbuffers.Offset} + */ +nd4j.graph.FlatVariable.createControlDepsVector = function(builder, data) { + builder.startVector(4, data.length, 4); + for (var i = data.length - 1; i >= 0; i--) { + builder.addOffset(data[i]); + } + return builder.endVector(); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {number} numElems + */ +nd4j.graph.FlatVariable.startControlDepsVector = function(builder, numElems) { + builder.startVector(4, numElems, 4); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {flatbuffers.Offset} controlDepForOpOffset + */ +nd4j.graph.FlatVariable.addControlDepForOp = function(builder, controlDepForOpOffset) { + builder.addFieldOffset(8, controlDepForOpOffset, 0); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {Array.} data + * @returns {flatbuffers.Offset} + */ +nd4j.graph.FlatVariable.createControlDepForOpVector = function(builder, data) { + builder.startVector(4, data.length, 4); + for (var i = data.length - 1; i >= 0; i--) { + builder.addOffset(data[i]); + } + return builder.endVector(); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {number} numElems + */ +nd4j.graph.FlatVariable.startControlDepForOpVector = function(builder, numElems) { + builder.startVector(4, numElems, 4); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {flatbuffers.Offset} controlDepsForVarOffset + */ +nd4j.graph.FlatVariable.addControlDepsForVar = function(builder, controlDepsForVarOffset) { + builder.addFieldOffset(9, controlDepsForVarOffset, 0); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {Array.} data + * @returns {flatbuffers.Offset} + */ +nd4j.graph.FlatVariable.createControlDepsForVarVector = function(builder, data) { + builder.startVector(4, data.length, 4); + for (var i = data.length - 1; i >= 0; i--) { + builder.addOffset(data[i]); + } + return builder.endVector(); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {number} numElems + */ +nd4j.graph.FlatVariable.startControlDepsForVarVector = function(builder, numElems) { + builder.startVector(4, numElems, 4); +}; + /** * @param {flatbuffers.Builder} builder * @returns {flatbuffers.Offset} diff --git a/libnd4j/include/graph/scheme/node.fbs b/libnd4j/include/graph/scheme/node.fbs index 930702f6d..92975e216 100644 --- a/libnd4j/include/graph/scheme/node.fbs +++ b/libnd4j/include/graph/scheme/node.fbs @@ -52,6 +52,12 @@ table FlatNode { //Scalar value - used for scalar ops. Should be single value only. scalar:FlatArray; + + //Control dependencies + controlDeps:[string]; + varControlDeps:[string]; + controlDepFor:[string]; + } root_type FlatNode; \ No newline at end of file diff --git a/libnd4j/include/graph/scheme/variable.fbs b/libnd4j/include/graph/scheme/variable.fbs index 31eafafa7..1e8010d43 100644 --- a/libnd4j/include/graph/scheme/variable.fbs +++ b/libnd4j/include/graph/scheme/variable.fbs @@ -37,6 +37,10 @@ table FlatVariable { device:int; // default is -1, which means _auto_ variabletype:VarType; + + controlDeps:[string]; + controlDepForOp:[string]; + controlDepsForVar:[string]; } root_type FlatVariable; \ No newline at end of file diff --git a/libnd4j/include/helpers/ConstantShapeHelper.h b/libnd4j/include/helpers/ConstantShapeHelper.h index 585db0198..d5ea9abe9 100644 --- a/libnd4j/include/helpers/ConstantShapeHelper.h +++ b/libnd4j/include/helpers/ConstantShapeHelper.h @@ -60,6 +60,7 @@ namespace nd4j { Nd4jLong* createShapeInfo(const ShapeDescriptor &descriptor); Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const std::vector &shape); Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape); + Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const Nd4jLong* shapeInfo); Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, nd4j::memory::Workspace *workspace); Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal = true); diff --git a/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp b/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp index 531b68004..bcedd727e 100644 --- a/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp @@ -99,6 +99,10 @@ namespace nd4j { return bufferForShapeInfo(descriptor).primaryAsT(); } + Nd4jLong* ConstantShapeHelper::createShapeInfo(const nd4j::DataType dataType, const Nd4jLong* shapeInfo) { + return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast(shapeInfo))); + } + Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const nd4j::DataType dataType) { auto descriptor = ShapeDescriptor::emptyDescriptor(dataType); return bufferForShapeInfo(descriptor).primaryAsT(); diff --git a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu index 4004b9895..aae62594c 100644 --- a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu @@ -102,6 +102,10 @@ namespace nd4j { return bufferForShapeInfo(descriptor).primaryAsT(); } + Nd4jLong* ConstantShapeHelper::createShapeInfo(const nd4j::DataType dataType, const Nd4jLong* shapeInfo) { + return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast(shapeInfo))); + } + Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const nd4j::DataType dataType) { auto descriptor = ShapeDescriptor::emptyDescriptor(dataType); return bufferForShapeInfo(descriptor).primaryAsT(); diff --git a/libnd4j/include/helpers/impl/MmulHelper.cpp b/libnd4j/include/helpers/impl/MmulHelper.cpp index ef84cc077..b50104bee 100644 --- a/libnd4j/include/helpers/impl/MmulHelper.cpp +++ b/libnd4j/include/helpers/impl/MmulHelper.cpp @@ -268,6 +268,21 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B, if(aRank == 2 && isBVector) return mmulMxV(A, B, C, alpha, beta, outOrder); + // vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm + if(isAVector && bRank == 2) { + NDArray* A2 = new NDArray(A->reshape(A->ordering(), {1, A->lengthOf()})); // A{M} -> A2{1,M} + NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()})) : nullptr; // C{N} -> C2{1,N} + auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder); // result{1,N} + delete A2; + delete C2; + + if(!C) { + result->reshapei({result->lengthOf()}); // result{1,N} -> result{N} + return result; + } + return C; + } + // batched matrix multiplication return mmulNxN(A, B, C, alpha, beta, outOrder); } diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index b803bdb8d..4b1f3448f 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -78,7 +78,9 @@ (28, LogicalXor) ,\ (29, LogicalNot) ,\ (30, LogicalAnd), \ - (31, DivideNoNan) + (31, DivideNoNan), \ + (32, IGamma), \ + (33, IGammac) // these ops return same data type as input #define TRANSFORM_SAME_OPS \ @@ -117,6 +119,8 @@ #define TRANSFORM_STRICT_OPS \ + (2, ScaledTanh), \ + (3, Affine), \ (4, TanhDerivative), \ (5, HardTanhDerivative), \ (6, SigmoidDerivative), \ @@ -245,7 +249,9 @@ (43, TruncateMod) ,\ (44, SquaredReverseSubtract) ,\ (45, ReversePow), \ - (46, DivideNoNan) + (46, DivideNoNan), \ + (47, IGamma), \ + (48, IGammac) @@ -380,7 +386,9 @@ (35, AMinPairwise) ,\ (36, TruncateMod), \ (37, ReplaceNans), \ - (38, DivideNoNan) + (38, DivideNoNan), \ + (39, IGamma), \ + (40, IGammac) diff --git a/libnd4j/include/ops/BroadcastOpsTuple.h b/libnd4j/include/ops/BroadcastOpsTuple.h index c665a0abc..0450e50ab 100644 --- a/libnd4j/include/ops/BroadcastOpsTuple.h +++ b/libnd4j/include/ops/BroadcastOpsTuple.h @@ -49,6 +49,8 @@ namespace nd4j { static BroadcastOpsTuple DivideNoNan(); static BroadcastOpsTuple Multiply(); static BroadcastOpsTuple Subtract(); + static BroadcastOpsTuple IGamma(); + static BroadcastOpsTuple IGammac(); }; } diff --git a/libnd4j/include/ops/declarable/CustomOperations.h b/libnd4j/include/ops/declarable/CustomOperations.h index 9162d89bf..5aea215c1 100644 --- a/libnd4j/include/ops/declarable/CustomOperations.h +++ b/libnd4j/include/ops/declarable/CustomOperations.h @@ -39,6 +39,7 @@ #include #include #include +#include #include #include #include diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/igamma.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/igamma.cpp new file mode 100644 index 000000000..6bd1c88ed --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/broadcastable/igamma.cpp @@ -0,0 +1,59 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author sgazeos@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_igamma) + +#include +#include + +namespace nd4j { + namespace ops { + BROADCASTABLE_OP_IMPL(igamma, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x,y,z); + + //REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!"); + +// auto tZ = BroadcastHelper::broadcastApply({scalar::IGamma, pairwise::IGamma, broadcast::IGamma}, x, y, z); + auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGamma(), x, y, z); + + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); + } + + DECLARE_TYPES(igamma) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/igammac.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/igammac.cpp new file mode 100644 index 000000000..89494dc4b --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/broadcastable/igammac.cpp @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author sgazeos@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_igammac) + +#include +#include + +namespace nd4j { + namespace ops { + BROADCASTABLE_OP_IMPL(igammac, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x,y,z); + + //REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!"); + +// auto tZ = BroadcastHelper::broadcastApply({scalar::IGammac, pairwise::IGammac, broadcast::IGammac}, x, y, z); + auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGammac(), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); + } + + DECLARE_TYPES(igammac) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/kernels/knn_mindistance.cpp b/libnd4j/include/ops/declarable/generic/kernels/knn_mindistance.cpp new file mode 100644 index 000000000..a7e825a9c --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/kernels/knn_mindistance.cpp @@ -0,0 +1,59 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_knn_mindistance) + +#include +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(knn_mindistance, 3, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto lowest = INPUT_VARIABLE(1); + auto highest = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->lengthOf() == lowest->lengthOf() && input->lengthOf() == highest->lengthOf(), 0, "knn_mindistance: all input arrays must have same length"); + REQUIRE_TRUE(input->dataType() == lowest->dataType() && input->dataType() == highest->dataType() && input->dataType() == output->dataType(), 0, "knn_mindistance: all inputs must have the same data type"); + + helpers::knn_mindistance(*input, *lowest, *highest, *output); + + return Status::OK(); + } + + DECLARE_SHAPE_FN(knn_mindistance) { + auto input = inputShape->at(0); + + // always return scalar here + return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(input))); + } + + DECLARE_TYPES(knn_mindistance) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp index 6ef4a49d5..5641bab43 100644 --- a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp @@ -29,84 +29,8 @@ namespace nd4j { namespace ops { -CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) { - auto input = INPUT_VARIABLE(0); - auto mean = INPUT_VARIABLE(1); - auto variance = INPUT_VARIABLE(2); - NDArray *gamma = nullptr; - NDArray *beta = nullptr; - - auto output = OUTPUT_VARIABLE(0); - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - - // FIXME: double? - const double epsilon = T_ARG(0); - - if(applyScale) - gamma = INPUT_VARIABLE(3); - if(applyOffset) - beta = INPUT_VARIABLE(3 + static_cast(applyScale)); - - std::vector inArrs(block.width()); - for(int i = 0; i < block.width(); ++i) - inArrs[i] = INPUT_VARIABLE(i); - - // check whether all input shapes are mutually broadcastable - Nd4jLong* outShapeInfo = nullptr; - const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace()); - REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM op: the shapes of input arrays are not mutually broadcastable !"); - - // normalized output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta - - auto sigmaInvGam = (*variance + epsilon).transform(transform::RSqrt); - if(applyScale) - sigmaInvGam *= *gamma; - - NDArray inputMinusMean; - if(!input->isSameShape(output) && !mean->isSameShape(output)) { - auto inputTiled = NDArray(output, false, block.launchContext()); - input->tile(inputTiled); - inputMinusMean = inputTiled - *mean; - } - else - inputMinusMean = *input - *mean; - - if (applyOffset) - output->assign(inputMinusMean * sigmaInvGam + *beta); - else - output->assign(inputMinusMean * sigmaInvGam); - - return Status::OK(); -} - - DECLARE_TYPES(batchnorm) { - getOpDescriptor() - ->setAllowedInputTypes(nd4j::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - ////////////////////////////////////////////////////////////////////////// -DECLARE_SHAPE_FN(batchnorm) { - - std::vector inArrs(block.width()); - auto in = inputShape->at(0); - for(int i = 0; i < block.width(); ++i) - inArrs[i] = INPUT_VARIABLE(i); - - // check whether all input shapes are mutually broadcastable - Nd4jLong* outShapeInfo = nullptr; - const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace()); - REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM op: the shapes of input arrays are not mutually broadcastable !"); - - auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo, DataTypeUtils::pickFloatingType(ArrayOptions::dataType(in)))); - return SHAPELIST(result); -} - -////////////////////////////////////////////////////////////////////////// -CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) { +CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) { auto input = INPUT_VARIABLE(0); auto mean = INPUT_VARIABLE(1); @@ -123,7 +47,7 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) { if(applyScale) gamma = INPUT_VARIABLE(3); if(applyOffset) - beta = INPUT_VARIABLE(3 + static_cast(applyScale)); + beta = INPUT_VARIABLE(3 + (int)applyScale); const int numOfIntArgs = block.getIArguments()->size(); const int inRank = input->rankOf(); @@ -137,30 +61,31 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) { axes.push_back(inRank-1); // default dimension to reduce along is last dimension const int numOfAxes = axes.size(); - REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_NEW op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); - - // get, for example, something like {1, inDim1, 1, inDim3, 1} if axes = {1, 3} - std::vector expShapeWithUnities(inRank, 1); - for(int i = 0; i < numOfAxes; ++i) - expShapeWithUnities[axes[i]] = input->sizeAt(axes[i]); + REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5} - std::vector expShape = numOfAxes == 1 ? std::vector(1, input->sizeAt(axes[0])) : expShapeWithUnities; - std::string expShapeStr = ShapeUtils::shapeAsString(expShape); + std::vector expShape; + if(numOfAxes == 1) + expShape.push_back(input->sizeAt(axes[0])); + else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3} + expShape = std::vector(inRank, 1); + for(uint i = 0; i < numOfAxes; ++i) + expShape[axes[i]] = input->sizeAt(axes[i]); + } - REQUIRE_TRUE(ShapeUtils::shapeAsString(mean) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of mean array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(mean).c_str()); - REQUIRE_TRUE(ShapeUtils::shapeAsString(variance) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of variance array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(variance).c_str()); + REQUIRE_TRUE(mean->isSameShape(expShape) , 0, "BATCHNORM op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str()); if(gamma) - REQUIRE_TRUE(ShapeUtils::shapeAsString(gamma) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of gamma array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(gamma).c_str()); + REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str()); if(beta) - REQUIRE_TRUE(ShapeUtils::shapeAsString(beta) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of beta array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(beta).c_str()); + REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str()); // types of all input arrays should be the same for(int i = 1; i < block.width(); ++i) - REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_NEW op: types of all input arrays should be the same !"); + REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM op: types of all input arrays should be the same !"); - nd4j_debug("MKL-DNN is not used for batchnorm_new!\n", 0); + nd4j_debug("MKL-DNN is not used for batchnorm!\n", 0); // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta helpers::batchnorm(input, mean, variance, gamma, beta, output, axes, epsilon); @@ -168,15 +93,15 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) { return Status::OK(); } -DECLARE_TYPES(batchnorm_new) { +DECLARE_TYPES(batchnorm) { getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } -DECLARE_SHAPE_FN(batchnorm_new) { +DECLARE_SHAPE_FN(batchnorm) { auto inShapeInfo = inputShape->at(0); DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo)); - + auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(inShapeInfo, outType, false, block.getWorkspace()); // output shape is identical to input shape return SHAPELIST(CONSTANT(outShapeInfo)); @@ -184,290 +109,177 @@ DECLARE_SHAPE_FN(batchnorm_new) { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) { - auto input = INPUT_VARIABLE(0); - auto mean = INPUT_VARIABLE(1); - auto variance = INPUT_VARIABLE(2); - NDArray *gamma = nullptr; - NDArray *beta = nullptr; - NDArray *dLdO = nullptr; // next epsilon - auto dLdI = OUTPUT_VARIABLE(0); - auto dLdM = OUTPUT_VARIABLE(1); - auto dLdV = OUTPUT_VARIABLE(2); - NDArray *dLdG = nullptr; - NDArray *dLdB = nullptr; + NDArray* input = INPUT_VARIABLE(0); + NDArray* mean = INPUT_VARIABLE(1); + NDArray* variance = INPUT_VARIABLE(2); + NDArray* dLdO = INPUT_VARIABLE(3); // next epsilon + NDArray* gamma = nullptr; + NDArray* beta = nullptr; - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - // FIXME: double? - const double epsilon = T_ARG(0); + NDArray* dLdI = OUTPUT_VARIABLE(0); + NDArray* dLdM = OUTPUT_VARIABLE(1); + NDArray* dLdV = OUTPUT_VARIABLE(2); + NDArray* dLdG = nullptr; + NDArray* dLdB = nullptr; - const int dLdONum = static_cast(applyScale) + static_cast(applyOffset); + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const float epsilon = T_ARG(0); if(applyScale) { - gamma = INPUT_VARIABLE(3); + gamma = INPUT_VARIABLE(4); dLdG = OUTPUT_VARIABLE(3); } if(applyOffset) { - beta = INPUT_VARIABLE(3 + static_cast(applyScale)); - dLdB = OUTPUT_VARIABLE(3 + static_cast(applyScale)); + beta = INPUT_VARIABLE(4 + (int)applyScale); + dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); } - - dLdO = INPUT_VARIABLE(3 + dLdONum); - - std::vector inArrs(block.width()); - for(int i = 0; i < 4 + dLdONum; ++i) - inArrs[i] = INPUT_VARIABLE(i); - // check whether all input shapes are mutually broadcastable - Nd4jLong* outShapeInfo = nullptr; - const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace()); - REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM_BP op: the shapes of input arrays are not mutually broadcastable !"); + const int numOfIntArgs = block.getIArguments()->size(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if(numOfIntArgs > 2) + for(int i = 2; i < numOfIntArgs; ++i) + axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank-1); // default dimension to reduce along is last dimension + + const int numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_BP op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); + + // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes + // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5} + std::vector expShape; + if(numOfAxes == 1) + expShape.push_back(input->sizeAt(axes[0])); + else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3} + expShape = std::vector(inRank, 1); + for(uint i = 0; i < numOfAxes; ++i) + expShape[axes[i]] = input->sizeAt(axes[i]); + } + + REQUIRE_TRUE(mean->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str()); + if(gamma) + REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str()); + if(beta) + REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str()); + + REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP op: wrong shape of output gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str()); + + // types of all input arrays should be the same (except dLdO) + for(int i = 1; i < block.width() - 1; ++i) + if(i != 3) + REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP op: types of arrays (input, mean, variance, gamma, beta) should be the same !"); // ***** calculations ***** // - auto sigmaInv = (*variance + epsilon).transform(transform::RSqrt); - - NDArray sigmaInvGamdLdO = -sigmaInv * *dLdO; - if(applyScale) - sigmaInvGamdLdO *= *gamma; + // formula for forward step: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta - NDArray inputMinusMean; - if(!input->isSameShape(dLdO) && !mean->isSameShape(dLdO)) { - auto inputTiled = NDArray(dLdO, false, block.launchContext()); - input->tile(inputTiled); - inputMinusMean = inputTiled - *mean; - } - else - inputMinusMean = *input - *mean; + // consider mean and variance as constants (since we get them as inputs and don't calculate them) + // dLdI = (dLdO * gamma) / (variance + epsilon)^0.5 + // dLdV = (-0.5 * gamma * (dLdO * (x - mean))_sum) / (variance + epsilon)^1.5 + // dLdM = - (dLdO_sum * gamma) / (variance + epsilon)^0.5 + // dLdG = (dLdO * (x - mean))_sum / (variance + epsilon)^0.5 + // dLdB = dLdO_sum + + const auto excludedAxes = ShapeUtils::evalDimsToExclude(inRank, axes); + + NDArray temp1 = *variance + epsilon; + temp1.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon) + auto temp2 = temp1.transform(transform::Sqrt); // 1 / (variance + epsilon)^0.5 + if(applyScale) + temp2 *= *gamma; // gamma / (variance + epsilon)^0.5 + + NDArray temp3(input); // empty array with same shape as input + input->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &temp3); // input - mean + temp3 *= *dLdO; // (input - mean) * dLdO + + const bool keepUnitiesInShape = inRank == mean->rankOf(); // dLdI - if(!dLdI->isSameShape(dLdO)) - dLdI->assign( (-sigmaInvGamdLdO).reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdI->getShapeInfo(), dLdO->getShapeInfo())) ); - else - dLdI->assign(-sigmaInvGamdLdO); + dLdO->applyBroadcast(nd4j::broadcast::Multiply, axes, &temp2, dLdI); // dLdM - if(!dLdM->isSameShape(dLdO)) - dLdM->assign( sigmaInvGamdLdO.reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdM->getShapeInfo(), dLdO->getShapeInfo())) ); - else - dLdM->assign(sigmaInvGamdLdO); + dLdO->reduceAlongDimension(reduce::Sum, dLdM, excludedAxes, keepUnitiesInShape); // dLdO sum over excluded axes - // dLdV - if(!dLdV->isSameShape(dLdO)) { - dLdV->assign( (sigmaInv * sigmaInv * sigmaInvGamdLdO * inputMinusMean * 0.5f).reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdV->getShapeInfo(), dLdO->getShapeInfo())) ); - } - else - dLdV->assign(sigmaInv * sigmaInv * sigmaInvGamdLdO * inputMinusMean * 0.5f); + // dLdB + if(applyOffset) + dLdB->assign(dLdM); + + // dLdM + // dLdM->applyPairwiseTransform(nd4j::pairwise::Multiply, temp2); + // dLdM->applyTransform(nd4j::transform::Neg); + *dLdM = 0; // put zeros so far + + //dLdV + temp3.reduceAlongDimension(reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape); // ((input - mean) * dLdO)_sum // dLdG if(applyScale) { - if(!dLdG->isSameShape(dLdO)) - dLdG->assign( (sigmaInv * inputMinusMean * *dLdO).reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdG->getShapeInfo(), dLdO->getShapeInfo())) ); - else - dLdG->assign(sigmaInv * inputMinusMean * *dLdO); + dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, &temp2, dLdG); + // dLdV->assign(dLdG); + dLdG->applyPairwiseTransform(nd4j::pairwise::Divide, *gamma); } + else + // dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, temp2); - // dLdB - if(applyOffset) { - if(!dLdB->isSameShape(dLdO)) - dLdB->assign(dLdO->reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdB->getShapeInfo(), dLdO->getShapeInfo())) ); - else - dLdB->assign(dLdO); - } + // dLdV + // dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, temp1); + // *dLdV *= -0.5; + *dLdV = 0; // put zeros so far return Status::OK(); } - DECLARE_TYPES(batchnorm_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, nd4j::DataType::ANY) - ->setAllowedInputTypes(1, nd4j::DataType::ANY) - ->setAllowedInputTypes(2, nd4j::DataType::ANY) - ->setAllowedInputTypes(3, nd4j::DataType::ANY) - ->setAllowedInputTypes(4, nd4j::DataType::ANY) - ->setAllowedInputTypes(5, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(batchnorm_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, nd4j::DataType::ANY) + ->setAllowedInputTypes(1, nd4j::DataType::ANY) + ->setAllowedInputTypes(2, nd4j::DataType::ANY) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedInputTypes(4, nd4j::DataType::ANY) + ->setAllowedInputTypes(5, nd4j::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(batchnorm_bp) { + Nd4jLong* inShapeInfo = inputShape->at(0); + Nd4jLong* meanShapeInfo = inputShape->at(1); + const bool applyScale = (bool)INT_ARG(0); const bool applyOffset = (bool)INT_ARG(1); - const int dLdONum = static_cast(applyScale) + static_cast(applyOffset); + DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo)); - std::vector inArrs(block.width()); - for(int i = 0; i < 4 + dLdONum; ++i) - inArrs[i] = INPUT_VARIABLE(i); + auto shapes = SHAPELIST(); - // check whether all input shapes are mutually broadcastable - Nd4jLong* outShapeInfo = nullptr; - const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace()); - REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM_BP op: the shapes of input arrays are not mutually broadcastable !"); + // dLdI shapeInfo + shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(outType, inShapeInfo)); - Nd4jLong* dLdIShapeInfo(nullptr), *dLdMShapeInfo(nullptr), *dLdVShapeInfo(nullptr), *dLdGShapeInfo(nullptr), *dLdBShapeInfo(nullptr); - COPY_SHAPE(inputShape->at(0), dLdIShapeInfo); - COPY_SHAPE(inputShape->at(1), dLdMShapeInfo); - COPY_SHAPE(inputShape->at(2), dLdVShapeInfo); + // dLdM shapeInfo + shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(outType, meanShapeInfo)); - if(applyScale) { - COPY_SHAPE(inputShape->at(3), dLdGShapeInfo); - } - if(applyOffset){ - COPY_SHAPE(inputShape->at(3 + static_cast(applyScale)), dLdBShapeInfo); - } + // dLdV shapeInfo (same as dLdM) + shapes->push_back(shapes->at(shapes->size()-1)); - if(!applyScale && !applyOffset) - return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo)); + // dLdG shapeInfo (same as dLdM) + if(applyScale) + shapes->push_back(shapes->at(shapes->size()-1)); - if(applyScale && !applyOffset) - return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo), CONSTANT(dLdGShapeInfo)); + // dLdB shapeInfo (same as dLdM) + if(applyOffset) + shapes->push_back(shapes->at(shapes->size()-1)); - if(!applyScale && applyOffset) - return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo), CONSTANT(dLdBShapeInfo)); - - return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo), CONSTANT(dLdGShapeInfo), CONSTANT(dLdBShapeInfo)); + return shapes; } - // ////////////////////////////////////////////////////////////////////////// - // CONFIGURABLE_OP_IMPL(batchnorm_bp, 5, 1, true, 0, 1) { - - // NDArray* input = INPUT_VARIABLE(0); - // NDArray* epsilon = INPUT_VARIABLE(1); - // NDArray* gamma = INPUT_VARIABLE(2); - // NDArray* dGlobalMeanView = INPUT_VARIABLE(3); - // NDArray* dGlobalVarView = INPUT_VARIABLE(4); - // NDArray* outEpsilon = this->getZ(block); - // std::vector argI = *(block.getIArguments()); - // const int bS = epsilon->sizeAt(0); - // bool isLockGammaBeta = (bool)argI[0]; - // const int* epsilonShape = epsilon->getShapeInfo() + 1; - // const T eps = (T)1e-5; - - // int rank = epsilon->rankOf(); - // std::initializer_list dimensions; - // int effectiveBatchSize; - // if (rank == 2) { - // dimensions = {0}; - // effectiveBatchSize = bS; - // } - // else if (rank == 4) { - // dimensions = {0, 2, 3}; - // effectiveBatchSize = input->sizeAt(0)*input->sizeAt(2)*input->sizeAt(3); - // } - // else - // throw "Graph operation batchnorm_bp: the epsilon rank must be equal to 2 or 4 !"; - - // NDArray *mean(nullptr), *var(nullptr), *dBeta(nullptr), *dGamma(nullptr), *dLdVar(nullptr), *dxmu1(nullptr), *dxmu2(nullptr); - // mean = input->template reduceAlongDimension>(dimensions); - // var = input->template varianceAlongDimension>(false, dimensions); - // var->template applyScalar>(eps, nullptr); - // auto std = new NDArray(var->getShapeInfo(), block.getWorkspace()); - // var->template applyTransform>(std, nullptr); - - // auto xMu = new NDArray(input->getShapeInfo(), block.getWorkspace()); - // auto xHat = new NDArray(input->getShapeInfo(), block.getWorkspace()); - // auto temp1 = new NDArray(epsilon->getShapeInfo(), block.getWorkspace()); - // auto temp2 = new NDArray(std->getShapeInfo(), block.getWorkspace()); - // auto dGammaView = new NDArray('c', {1, epsilonShape[1]}, block.getWorkspace()); - // auto dBetaView = new NDArray('c', {1, epsilonShape[1]}, block.getWorkspace()); - // auto dxhat = new NDArray(epsilon->getShapeInfo(), block.getWorkspace()); - - // if (rank == 2) { - // input->subRowVector(mean, xMu); - // xMu->divRowVector(std, xHat); - // } - // else { - // input->template applyBroadcast>({1}, mean, xMu, nullptr); - // xMu->template applyBroadcast>({1}, std, xHat, nullptr); - // } - - // dBeta = epsilon->sum(dimensions); // dL/dBeta = sum_examples dL/dOut - // epsilon->template applyPairwiseTransform>(xHat, temp1, nullptr); //dL/dGamma = sum_examples dL/dOut .* xHat - // dGamma = temp1->sum(dimensions); //dL/dGamma = sum_examples dL/dOut .* xHat - - // if (isLockGammaBeta) - // epsilon->template applyPairwiseTransform>(gamma, dxhat, nullptr); - // else {// Standard case - // if(rank == 2) - // epsilon->mulRowVector(gamma, dxhat); //dL/dxHat = dL/dOut . gamma Shape: [minibatchSize, nOut] - // else - // epsilon->template applyBroadcast>({1}, gamma, dxhat, nullptr); - // } - - // // dLdVar - dL/dVariance, shape: [1, miniBatch] - // dxhat->template applyPairwiseTransform>(xMu, temp1, nullptr); - // dLdVar = temp1->sum(dimensions); - // dLdVar->template applyScalar>((T)-0.5, nullptr); - // T powParams[] = {(T)(-3.)}; - // std->template applyTransform>(temp2, powParams); - // dLdVar->template applyPairwiseTransform>(temp2, nullptr); - - // //dL/dmu - // dxmu1 = dxhat->sum(dimensions); - // dxmu1->template applyPairwiseTransform>(std, nullptr); - // dxmu1->template applyTransform>(); - // dxmu2 = xMu->sum(dimensions); - // dxmu2->template applyScalar>((T)(-2.)/effectiveBatchSize); - // dxmu2->template applyPairwiseTransform>(dLdVar, nullptr); - - // dxmu1->template applyPairwiseTransform>(dxmu2, nullptr); - // NDArray* dLdmu = dxmu1; // = dL/dmu Shape: [1, nOut] - - // //Note the array reuse here: dxhat, xMu, dLdVar, dLdmu - all are invalid after this line (but aren't used later anyway) - // NDArray* dLdx = dxhat; - // dLdVar->template applyScalar>((T)(2.)/effectiveBatchSize); - // dLdmu->template applyScalar>((T)(1.)/effectiveBatchSize); - // if(rank == 2) { - // dLdx->divRowVector(std, dLdx); - // xMu->mulRowVector(dLdVar, xMu); - // } - // else { - // dLdx->template applyBroadcast>({1}, std, dLdx, nullptr); - // xMu->template applyBroadcast>({1}, dLdVar, xMu, nullptr); - // } - // dLdx->template applyPairwiseTransform>(xMu, nullptr); - // if(rank == 2) - // dLdx->addRowVector(dLdmu, dLdx); - // else - // dLdx->template applyBroadcast>({1}, dLdmu, dLdx, nullptr); - - // *outEpsilon = *dLdx; - - // //TODO rework this to avoid the assign here - // // dGammaView->assign(dGamma); - // // dBetaView->assign(dBeta); - // // dGlobalMeanView->assign((T)0.); - // // dGlobalVarView->assign((T)0.); - // // retGradient.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView); - // // retGradient.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView); - // // retGradient.setGradientFor(BatchNormalizationParamInitializer.GLOBAL_MEAN, dGlobalMeanView); - // // retGradient.setGradientFor(BatchNormalizationParamInitializer.GLOBAL_VAR, dGlobalVarView); - - // delete std; - // delete xMu; - // delete xHat; - // delete mean; - // delete var; - // delete dBeta; - // delete dGamma; - // delete dLdVar; - // delete dxmu1; - // delete dxmu2; - // delete temp1; - // delete temp2; - // delete dxhat; - // delete dGammaView; - // delete dBetaView; - - // return ND4J_STATUS_OK; - // } - - - } diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp index f69b6c0f9..caff807f8 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp @@ -66,8 +66,10 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) { if(!isNCHW) output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] - if(isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + if(isSameMode){ // SAME + //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass + ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); + } NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext()); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp index 4011d5e32..538214b14 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp @@ -47,7 +47,7 @@ CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 1, 0) { NDArray factorT(output->dataType(), block.launchContext()); // = NDArrayFactory::create(factor, block.launchContext()); factorT.p(0, factor); // this is contrast calculation - *output = (*input - mean) * factorT + mean; + output->assign((*input - mean) * factorT + mean); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp index 1e0330294..e2fe58b7a 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp @@ -38,15 +38,19 @@ namespace nd4j { REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression: The rank of boxes array should be 2, but %i is given", boxes->rankOf()); - REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression: The rank of boxes array should be 2, but %i is given", boxes->rankOf()); + REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should be 4, but %i is given", boxes->sizeAt(1)); + REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf()); if (scales->lengthOf() < maxOutputSize) maxOutputSize = scales->lengthOf(); - double threshold = 0.5; + double overlayThreshold = 0.5; + double scoreThreshold = - DataTypeUtils::infOrMax(); if (block.getTArguments()->size() > 0) - threshold = T_ARG(0); + overlayThreshold = T_ARG(0); + if (block.getTArguments()->size() > 1) + scoreThreshold = T_ARG(1); - helpers::nonMaxSuppressionV2(block.launchContext(), boxes, scales, maxOutputSize, threshold, output); + helpers::nonMaxSuppression(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold, scoreThreshold, output); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression_overlaps.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression_overlaps.cpp new file mode 100644 index 000000000..4f405d8c8 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression_overlaps.cpp @@ -0,0 +1,93 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by GS at 10/17/2019 +// + +#include +#include + +#if NOT_EXCLUDED(OP_image_non_max_suppression_overlaps) + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(non_max_suppression_overlaps, 2, 1, false, 0, 0) { + auto boxes = INPUT_VARIABLE(0); + auto scales = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + int maxOutputSize; // = INT_ARG(0); + if (block.width() > 2) + maxOutputSize = INPUT_VARIABLE(2)->e(0); + else if (block.getIArguments()->size() == 1) + maxOutputSize = INT_ARG(0); + else + REQUIRE_TRUE(false, 0, "image.non_max_suppression_overlaps: Max output size argument cannot be retrieved."); + REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression_overlaps: The rank of boxes array should be 2, but %i is given", boxes->rankOf()); + REQUIRE_TRUE(boxes->sizeAt(0) == boxes->sizeAt(1), 0, "image.non_max_suppression_overlaps: The boxes array should be square, but {%lld, %lld} is given", boxes->sizeAt(0), boxes->sizeAt(1)); + REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression_overlaps: The rank of scales array should be 1, but %i is given", boxes->rankOf()); + +// if (scales->lengthOf() < maxOutputSize) +// maxOutputSize = scales->lengthOf(); + double overlapThreshold = 0.5; + double scoreThreshold = -DataTypeUtils::infOrMax(); + if (block.getTArguments()->size() > 0) + overlapThreshold = T_ARG(0); + if (block.getTArguments()->size() > 1) + scoreThreshold = T_ARG(1); + + // TODO: refactor helpers to multithreaded facility + helpers::nonMaxSuppressionGeneric(block.launchContext(), boxes, scales, maxOutputSize, overlapThreshold, + scoreThreshold, output); + return Status::OK(); + } + + DECLARE_SHAPE_FN(non_max_suppression_overlaps) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong *outputShape = nullptr; + + int maxOutputSize; + if (block.width() > 2) + maxOutputSize = INPUT_VARIABLE(2)->e(0); + else if (block.getIArguments()->size() == 1) + maxOutputSize = INT_ARG(0); + else + REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); + + double overlapThreshold = 0.5; + double scoreThreshold = 0.; + + Nd4jLong boxSize = helpers::nonMaxSuppressionGeneric(block.launchContext(), INPUT_VARIABLE(0), + INPUT_VARIABLE(1), maxOutputSize, overlapThreshold, scoreThreshold, nullptr); //shape::sizeAt(in, 0); + if (boxSize < maxOutputSize) + maxOutputSize = boxSize; + + outputShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(maxOutputSize, DataType::INT32); + + return SHAPELIST(outputShape); + } + DECLARE_TYPES(non_max_suppression_overlaps) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_INDICES}); + } + + } +} +#endif diff --git a/libnd4j/include/ops/declarable/generic/recurrent/lstmLayer.cpp b/libnd4j/include/ops/declarable/generic/recurrent/lstmLayer.cpp new file mode 100644 index 000000000..ed1e9e0f3 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/recurrent/lstmLayer.cpp @@ -0,0 +1,404 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#if NOT_EXCLUDED(OP_lstmLayer) + +#include +#include + +namespace nd4j { +namespace ops { + + +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) { + + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // sL - sequence length, number of time steps + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + + // ******* + // input x: + // 1) [sL, bS, nIn] when dataFormat == 0 + // 2) [bS, sL, nIn] when dataFormat == 1 + // 3) [bS, nIn, sL] when dataFormat == 2 + + // ******* + // input weights Wx: + // 1) [nIn, 4*nOut] when directionMode < 2 + // 2) [2, nIn, 4*nOut] when directionMode >= 2 + + // ******* + // recurrent weights Wr: + // 1) [nOut, 4*nOut] when directionMode < 2 + // 2) [2, nOut, 4*nOut] when directionMode >= 2 + + // ******* + // peephole weights Wp: + // 1) [3*nOut] when directionMode < 2 + // 2) [2, 3*nOut] when directionMode >= 2 + + // ******* + // biases b: + // 1) [4*nOut] when directionMode < 2 + // 2) [2, 4*nOut] when directionMode >= 2 + + // ******* + // sequence length array seqLen: + // 1) [bS] always + + // ******* + // initial output hI: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // initial cell state cI (same shape as in hI): + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + + // OUTPUTS: + + // ******* + // output h: + // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 + // 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1 + // 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2 + // 4) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 + // 5) [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1 + // 6) [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2 + // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 + + // ******* + // output at last step hL: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // cell state at last step cL (same shape as in hL): + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) + const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) + + // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus + const auto gateAct = INT_ARG(2); // activation for input (i), forget (f) and output (o) gates + const auto cellAct = INT_ARG(3); // activation for cell state (c) + const auto outAct = INT_ARG(4); // activation for output (h) + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = B_ARG(4); // indicates whether peephole connections are present + const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} + const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + + const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; + const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; + const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; + const auto gateActHasBeta = gateAct == 3 || gateAct == 6; + const auto cellActHasBeta = cellAct == 3 || cellAct == 6; + const auto outActHasBeta = outAct == 3 || outAct == 6; + + uint count = 1; + const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; + const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; + const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; + const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; + const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; + const auto outBeta = outActHasBeta ? T_ARG(count++) : 0; + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + count = 3; + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector + const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + + REQUIRE_TRUE(dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, "LSTM_LAYER operation: if argument dataFormat = 3, then directionMode = 4, but got dataFormat = %i and directionMode = %i instead !", dataFormat, directionMode); + REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER operation: cell clipping value should be nonnegative (>=0) !"); + REQUIRE_TRUE(retFullSeq || retLastH || retLastC, 0, "LSTM_LAYER operation: please specify what output arrays to produce !"); + + count = 0; + auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output + auto hL = retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step + auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step + + // evaluate dimensions + const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2); + const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + if(directionMode < 2) { // no bidirectional + + // Wx validation + if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx)); + // Wr validation + if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr)); + // biases validation + if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b)); + // initial output validation + if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI)); + // initial cell validation + if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI)); + // peephole weights validation + if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp)); + } + else { // bidirectional + // Wx validation + if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx)); + // Wr validation + if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr)); + // biases validation + if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b)); + // initial output validation + if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI)); + // initial cell validation + if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI)); + // peephole weights validation + if(Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp)); + } + + std::vector params = {static_cast(dataFormat), static_cast(directionMode), static_cast(cellClip), + static_cast(gateAct), static_cast(gateAlpha), static_cast(gateBeta), + static_cast(cellAct), static_cast(cellAlpha), static_cast(cellBeta), + static_cast(outAct), static_cast(outAlpha), static_cast(outBeta)}; + + if(directionMode == 0) { // forward + + helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, true, h, hL, cL); + } + else if(directionMode == 1) { // backward + + helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, false, h, hL, cL); + } + else { // bidirectional + + NDArray WxFwd = (*Wx)({0,1, 0,0, 0,0}); + NDArray WxBwd = (*Wx)({1,2, 0,0, 0,0}); + NDArray WrFwd = (*Wr)({0,1, 0,0, 0,0}); + NDArray WrBwd = (*Wr)({1,2, 0,0, 0,0}); + + NDArray *WpFwd(nullptr), *WpBwd(nullptr), *bFwd(nullptr), *bBwd(nullptr), *hIFwd(nullptr), *hIBwd(nullptr), *cIFwd(nullptr), *cIBwd(nullptr), + *hLFwd(nullptr), *hLBwd(nullptr), *cLFwd(nullptr), *cLBwd(nullptr), *hFwd(nullptr), *hBwd(nullptr); + + if(Wp) { + WpFwd = new NDArray((*Wp)({0,1, 0,0})); + WpBwd = new NDArray((*Wp)({1,2, 0,0})); + } + if(b) { + bFwd = new NDArray((*b)({0,1, 0,0})); + bBwd = new NDArray((*b)({1,2, 0,0})); + } + if(hI) { + hIFwd = new NDArray((*hI)({0,1, 0,0, 0,0})); + hIBwd = new NDArray((*hI)({1,2, 0,0, 0,0})); + } + if(cI) { + cIFwd = new NDArray((*cI)({0,1, 0,0, 0,0})); + cIBwd = new NDArray((*cI)({1,2, 0,0, 0,0})); + } + if(hL) { + hLFwd = new NDArray((*hL)({0,1, 0,0, 0,0})); + hLBwd = new NDArray((*hL)({1,2, 0,0, 0,0})); + } + if(cL) { + cLFwd = new NDArray((*cL)({0,1, 0,0, 0,0})); + cLBwd = new NDArray((*cL)({1,2, 0,0, 0,0})); + } + + if(h) { + if(directionMode == 2) { // sum + hFwd = h; + hBwd = new NDArray(h, false, h->getContext()); + } + else if(directionMode == 3) { // concat + hFwd = new NDArray(dataFormat <= 1 ? (*h)({0,0, 0,0, 0,nOut}) : (*h)({0,0, 0,nOut, 0,0})); + hBwd = new NDArray(dataFormat <= 1 ? (*h)({0,0, 0,0, nOut,2*nOut}) : (*h)({0,0, nOut,2*nOut, 0,0})); + } + else { // directionMode == 4 + hFwd = new NDArray((*h)({0,0, 0,1, 0,0, 0,0})); + hBwd = new NDArray((*h)({0,0, 1,2, 0,0, 0,0})); + } + } + + // FIXME - following two calls are independent and may run in different streams + helpers::lstmLayerTimeLoop(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, WpFwd, params, true, hFwd, hLFwd, cLFwd); + helpers::lstmLayerTimeLoop(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, WpBwd, params, false, hBwd, hLBwd, cLBwd); + + if(h && directionMode == 2) + *h += *hBwd; + + delete WpFwd; delete WpBwd; delete bFwd; delete bBwd; delete hIFwd; delete hIBwd; delete cIFwd; + delete cIBwd; delete hLFwd; delete hLBwd; delete cLFwd; delete cLBwd; delete hBwd; + if(hFwd != h) + delete hFwd; + } + + return Status::OK(); +} + +DECLARE_TYPES(lstmLayer) { + getOpDescriptor() + ->setAllowedInputTypes(nd4j::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + + +DECLARE_SHAPE_FN(lstmLayer) { + + const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nIn] (for ONNX) + const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim + + const auto retFullSeq = B_ARG(5); // indicates whether to return whole h {h_0, h_1, ... , h_sL-1}, if true, format would be [sL,bS,nOut] (exact shape depends on dataFormat argument) + const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + // evaluate dimensions + const Nd4jLong sL = dataFormat == 0 || dataFormat == 3 ? x->sizeAt(0) : ( dataFormat == 1 ? x->sizeAt(1) : x->sizeAt(2) ); + const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2); + const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + DataType type; + if(x->isR()) + type = x->dataType(); + else + type = nd4j::DataType::FLOAT32; + + std::vector shapes; + + // evaluate h shape (output) + if(retFullSeq) { + + std::vector hShape; + + if(directionMode <= 2) { // single direction or bidirectional with sum + if(dataFormat == 0) + hShape = {sL, bS, nOut}; + else if(dataFormat == 1) + hShape = {bS, sL, nOut}; + else if(dataFormat == 2) + hShape = {bS, nOut, sL}; + } + else if(directionMode == 3) { // bidirectional with concat + + if(dataFormat == 0) + hShape = {sL, bS, 2*nOut}; + else if(dataFormat == 1) + hShape = {bS, sL, 2*nOut}; + else if(dataFormat == 2) + hShape = {bS, 2*nOut, sL}; + } + else { // bidirectional with extra output dimension equal to 2 + hShape = {sL, 2, bS, nOut}; + } + + shapes.push_back(ConstantShapeHelper::getInstance()->createShapeInfo(type, x->ordering(), hShape)); + } + + // evaluate hL shape (output at last step) + if(retLastH) { + + std::vector hLShape; + + if(directionMode < 2) + hLShape = {bS, nOut}; + else + hLShape = {2, bS, nOut}; + + shapes.push_back(ConstantShapeHelper::getInstance()->createShapeInfo(type, x->ordering(), hLShape)); + + if(retLastC) // cL and hL have same shapes + shapes.push_back(shapes.back()); + } + + // evaluate cL shape (cell state at last step) + if(retLastC && !retLastH) { + + std::vector cLShape; + + if(directionMode < 2) + cLShape = {bS, nOut}; + else + cLShape = {2, bS, nOut}; + + shapes.push_back(ConstantShapeHelper::getInstance()->createShapeInfo(type, x->ordering(), cLShape)); + } + + return new ShapeList(shapes); +} + + +} +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/broadcastable.h b/libnd4j/include/ops/declarable/headers/broadcastable.h index b3b2463cd..7ee53b52a 100644 --- a/libnd4j/include/ops/declarable/headers/broadcastable.h +++ b/libnd4j/include/ops/declarable/headers/broadcastable.h @@ -357,6 +357,28 @@ namespace nd4j { #if NOT_EXCLUDED(OP_Pow) DECLARE_BROADCASTABLE_OP(Pow, 0, 0); #endif + + /** + * Broadcastable igamma implementation + * + * igamma(a, x) = gamma(а, x) / Gamma(a) - Gamma distribution function P(a,x) + * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } + * gamma(a, x) = int from 0 to x { t ^ {a - 1} e^{-t}dt } + * @tparam T + */ + #if NOT_EXCLUDED(OP_igamma) + DECLARE_BROADCASTABLE_OP(igamma, 0, 0); + #endif + /** + * Broadcastable igammac implementation + * igammac(a, x) = Gamma(a,x)/Gamma(а) - Gamma distribution function Q(a,x) + * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } + * Gamma(a, x) = int from x to infinity { t ^ {a - 1} e^{-t}dt } + * @tparam T + */ + #if NOT_EXCLUDED(OP_igammac) + DECLARE_BROADCASTABLE_OP(igammac, 0, 0); + #endif } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/ImportState.java b/libnd4j/include/ops/declarable/headers/kernels.h similarity index 70% rename from nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/ImportState.java rename to libnd4j/include/ops/declarable/headers/kernels.h index 1246f66fa..8fb2bab62 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/ImportState.java +++ b/libnd4j/include/ops/declarable/headers/kernels.h @@ -14,18 +14,21 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.imports.graphmapper; +// +// @author raver119@gmail.com +// -import lombok.Data; -import org.nd4j.autodiff.samediff.SameDiff; - -import java.util.Map; - -@Data -public class ImportState { - private SameDiff sameDiff; - private GRAPH_TYPE graph; - private Map variables; +#ifndef LIBND4J_KERNELS_H +#define LIBND4J_KERNELS_H +#include +namespace nd4j { + namespace ops { + #if NOT_EXCLUDED(OP_knn_mindistance) + DECLARE_CUSTOM_OP(knn_mindistance, 3, 1, false, 0, 0); + #endif + } } + +#endif //LIBND4J_KERNELS_H diff --git a/libnd4j/include/ops/declarable/headers/nn.h b/libnd4j/include/ops/declarable/headers/nn.h index 313707869..9f9b0e40a 100644 --- a/libnd4j/include/ops/declarable/headers/nn.h +++ b/libnd4j/include/ops/declarable/headers/nn.h @@ -29,12 +29,12 @@ namespace nd4j { #if NOT_EXCLUDED(OP_softmax) DECLARE_CONFIGURABLE_OP(softmax, 1, 1, true, 0, 0); DECLARE_CONFIGURABLE_OP(softmax_bp, 2, 1, true, 0, 0); - #endif + #endif /** * Local response normalization implementation as TF. * input: 4D array - * + * * T args: * * 0: bias @@ -42,8 +42,8 @@ namespace nd4j { * 2: beta * * Int arg: depth - optional local radius - * - * output - 4D array + * + * output - 4D array */ #if NOT_EXCLUDED(OP_lrn) DECLARE_CONFIGURABLE_OP(lrn, 1, 1, true, 3, 0); @@ -51,10 +51,10 @@ namespace nd4j { /** * Local response normalization - backprop variant. - * input: + * input: * 0 - 4D array of data * 1 - epsilon - 4D array of approximation - * + * * T args: * * 0: bias @@ -70,34 +70,31 @@ namespace nd4j { #endif /** - * Batch normalization implementation. + * Batch normalization implementation. * Reference: https://arxiv.org/abs/1502.03167v3 - * + * * Expected arguments: * input: input array (any number of dimensions) * mean: * variance: * gamma: * beta: - * + * * Int args: * 0: apply scale * 1: apply offset - * - * + * + * * T args: * 0: epsilon */ #if NOT_EXCLUDED(OP_batchnorm) DECLARE_CUSTOM_OP(batchnorm, 3, 1, false, 1, 2); #endif - #if NOT_EXCLUDED(OP_batchnorm_new) - DECLARE_CUSTOM_OP(batchnorm_new, 3, 1, false, 1, 2); - #endif /** * back prop in batch normalization - * + * * Expected arguments: * input: input array (any number of dimensions) * mean: @@ -105,11 +102,11 @@ namespace nd4j { * gamma: optional * beta: optional * dLdOut: next epsilon - * + * * Int args: * 0: apply scale - * 1: apply offset - * + * 1: apply offset + * * T args: * 0: epsilon * @@ -117,8 +114,8 @@ namespace nd4j { * dL/dInput * dL/dMean * dL/dVariance - * dL/dGamma - * dL/dBeta + * dL/dGamma, optional + * dL/dBeta, optional */ #if NOT_EXCLUDED(OP_batchnorm) DECLARE_CUSTOM_OP(batchnorm_bp, 4, 3, false, 1, 2); @@ -131,30 +128,30 @@ namespace nd4j { * x: parameters, any shape * y: gradients. same shape as x * lr: optional, learning rate - * + * * T args: * 0: optional, learning rate */ #if NOT_EXCLUDED(OP_apply_sgd) - DECLARE_CONFIGURABLE_OP(apply_sgd, 2, 1, true, -2, 0); + DECLARE_CONFIGURABLE_OP(apply_sgd, 2, 1, true, -2, 0); #endif /** * This operation performs batch normalization of layer, it is based on following article http://arxiv.org/abs/1502.03167. * Expected arguments: * x: input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] (data format = NCHW), where - * bS - batch size - * iH - input height - * iW - input width + * bS - batch size + * iH - input height + * iW - input width * iD - input depth (or number of channels) * scale: 1D input array of scale factors, shape [iD] * offset: 1D input array of offsets (shifts), shape [iD] * mean: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false * variance: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false - * + * * T input arguments: * 0: epsilon, it is optional argument, default value is 0.001, this is small number to be added to the variance of x - * + * * integer input arguments: * 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW * 1: isTraining, may have two values: zero -> inference, unity -> training diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index cbc7e56da..3660ee229 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -1691,15 +1691,38 @@ namespace nd4j { * 1 - scales - 1D-tensor with shape (num_boxes) by float type * 2 - output_size - 0D-tensor by int type (optional) * float args: - * 0 - threshold - threshold value for overlap checks (optional, by default 0.5) + * 0 - overlap_threshold - threshold value for overlap checks (optional, by default 0.5) + * 1 - score_threshold - the threshold for deciding when to remove boxes based on score (optional, by default -inf) * int args: * 0 - output_size - as arg 2 used for same target. Eigher this or arg 2 should be provided. * + * output: + * - vector with size M, where M <= output_size by int type + * * */ #if NOT_EXCLUDED(OP_image_non_max_suppression) DECLARE_CUSTOM_OP(non_max_suppression, 2, 1, false, 0, 0); #endif + /* + * image.non_max_suppression_overlaps op. + * input: + * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type + * 1 - scales - 1D-tensor with shape (num_boxes) by float type + * 2 - output_size - 0D-tensor by int type (optional) + * float args: + * 0 - overlap_threshold - threshold value for overlap checks (optional, by default 0.5) + * 1 - score_threshold - the threshold for deciding when to remove boxes based on score (optional, by default -inf) + * int args: + * 0 - output_size - as arg 2 used for same target. Eigher this or arg 2 should be provided. + * + * output: + * 0 - 1D integer tensor with shape [M], epresenting the selected indices from the overlaps tensor, where M <= max_output_size + * */ + #if NOT_EXCLUDED(OP_image_non_max_suppression_overlaps) + DECLARE_CUSTOM_OP(non_max_suppression_overlaps, 2, 1, false, 0, 0); + #endif + /* * cholesky op - decomposite positive square symetric matrix (or matricies when rank > 2). * input: diff --git a/libnd4j/include/ops/declarable/headers/recurrent.h b/libnd4j/include/ops/declarable/headers/recurrent.h index 4b2eddc57..bf6aaa6bc 100644 --- a/libnd4j/include/ops/declarable/headers/recurrent.h +++ b/libnd4j/include/ops/declarable/headers/recurrent.h @@ -231,6 +231,12 @@ namespace ops { DECLARE_CUSTOM_OP(lstmBlock, 9, 7, false, 2, 2); #endif + ////////////////////////////////////////////////////////////////////////// + #if NOT_EXCLUDED(OP_lstmLayer) + DECLARE_CUSTOM_OP(lstmLayer, 3, 1, false, 1, 5); + #endif + + ////////////////////////////////////////////////////////////////////////// /** * Implementation of operations for Simple Recurrent Unit cell: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi diff --git a/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp b/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp index d6c4da4a1..a0847f704 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp @@ -32,6 +32,8 @@ namespace helpers { template static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon) { + // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta + NDArray sigmaInvGam(mean); // do not copy mean's buffer, take only its shapeInfo T eps = epsilon; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_suppression.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_suppression.cpp index f90974a9f..f4fb98b2a 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_suppression.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_suppression.cpp @@ -22,21 +22,26 @@ //#include #include #include +#include namespace nd4j { namespace ops { namespace helpers { template - static void nonMaxSuppressionV2_(NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { - std::vector indices(scales->lengthOf()); + static void nonMaxSuppressionV2_(NDArray* boxes, NDArray* scales, int maxSize, double overlapThreshold, + double scoreThreshold, NDArray* output) { + std::vector indices(scales->lengthOf()); std::iota(indices.begin(), indices.end(), 0); - + for (auto e = 0; e < scales->lengthOf(); e++) { + if (scales->e(e) < scoreThreshold) indices[e] = -1; + } std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e(i) > scales->e(j);}); // std::vector selected(output->lengthOf()); std::vector selectedIndices(output->lengthOf(), 0); auto needToSuppressWithThreshold = [] (NDArray& boxes, int previousIndex, int nextIndex, T threshold) -> bool { + if (previousIndex < 0 || nextIndex < 0) return true; T minYPrev = nd4j::math::nd4j_min(boxes.e(previousIndex, 0), boxes.e(previousIndex, 2)); T minXPrev = nd4j::math::nd4j_min(boxes.e(previousIndex, 1), boxes.e(previousIndex, 3)); T maxYPrev = nd4j::math::nd4j_max(boxes.e(previousIndex, 0), boxes.e(previousIndex, 2)); @@ -70,7 +75,7 @@ namespace helpers { PRAGMA_OMP_PARALLEL_FOR //_ARGS(firstprivate(numSelected)) for (int j = numSelected - 1; j >= 0; --j) { if (shouldSelect) - if (needToSuppressWithThreshold(*boxes, indices[i], indices[selectedIndices[j]], T(threshold))) { + if (needToSuppressWithThreshold(*boxes, indices[i], indices[selectedIndices[j]], T(overlapThreshold))) { shouldSelect = false; } } @@ -80,11 +85,119 @@ namespace helpers { } } } +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + template + static Nd4jLong + nonMaxSuppressionGeneric_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int outputSize, + double overlapThreshold, double scoreThreshold, NDArray* output) { - void nonMaxSuppressionV2(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { - BUILD_SINGLE_SELECTOR(boxes->dataType(), nonMaxSuppressionV2_, (boxes, scales, maxSize, threshold, output), NUMERIC_TYPES); +// const int outputSize = maxSize->e(0); + auto numBoxes = boxes->sizeAt(0); + //std::vector scoresData(numBoxes); + T* scoresData = scores->dataBuffer()->primaryAsT(); + //std::copy_n(scores->getDataBuffer()->primaryAsT(), numBoxes, scoresData.begin()); + + // Data structure for a selection candidate in NMS. + struct Candidate { + int _boxIndex; + T _score; + int _suppressBeginIndex; + }; + + auto cmp = [](const Candidate& bsI, const Candidate& bsJ) -> bool{ + return ((bsI._score == bsJ._score) && (bsI._boxIndex > bsJ._boxIndex)) || + (bsI._score < bsJ._score); + }; + std::priority_queue, decltype(cmp)> candidatePriorityQueue(cmp); + for (auto i = 0; i < scores->lengthOf(); ++i) { + if (scoresData[i] > scoreThreshold) { + candidatePriorityQueue.emplace(Candidate({i, scoresData[i], 0})); + } + } + + std::vector selected; + T similarity, originalScore; + Candidate nextCandidate; + + while (selected.size() < outputSize && !candidatePriorityQueue.empty()) { + nextCandidate = candidatePriorityQueue.top(); + originalScore = nextCandidate._score; + candidatePriorityQueue.pop(); + + // Overlapping boxes are likely to have similar scores, therefore we + // iterate through the previously selected boxes backwards in order to + // see if `nextCandidate` should be suppressed. We also enforce a property + // that a candidate can be suppressed by another candidate no more than + // once via `suppress_begin_index` which tracks which previously selected + // boxes have already been compared against next_candidate prior to a given + // iteration. These previous selected boxes are then skipped over in the + // following loop. + bool shouldHardSuppress = false; + for (int j = static_cast(selected.size()) - 1; j >= nextCandidate._suppressBeginIndex; --j) { + similarity = boxes->t(nextCandidate._boxIndex, selected[j]); + nextCandidate._score *= T(similarity <= overlapThreshold?1.0:0.); //suppressWeightFunc(similarity); + + // First decide whether to perform hard suppression + if (similarity >= static_cast(overlapThreshold)) { + shouldHardSuppress = true; + break; + } + + // If next_candidate survives hard suppression, apply soft suppression + if (nextCandidate._score <= scoreThreshold) break; + } + // If `nextCandidate._score` has not dropped below `scoreThreshold` + // by this point, then we know that we went through all of the previous + // selections and can safely update `suppress_begin_index` to + // `selected.size()`. If on the other hand `next_candidate.score` + // *has* dropped below the score threshold, then since `suppressWeight` + // always returns values in [0, 1], further suppression by items that were + // not covered in the above for loop would not have caused the algorithm + // to select this item. We thus do the same update to + // `suppressBeginIndex`, but really, this element will not be added back + // into the priority queue in the following. + nextCandidate._suppressBeginIndex = selected.size(); + + if (!shouldHardSuppress) { + if (nextCandidate._score == originalScore) { + // Suppression has not occurred, so select next_candidate + selected.push_back(nextCandidate._boxIndex); +// selected_scores.push_back(nextCandidate._score); + } + if (nextCandidate._score > scoreThreshold) { + // Soft suppression has occurred and current score is still greater than + // score_threshold; add next_candidate back onto priority queue. + candidatePriorityQueue.push(nextCandidate); + } + } + } + + if (output) { + DataBuffer buf(selected.data(), selected.size() * sizeof(I), DataTypeUtils::fromT()); + output->dataBuffer()->copyBufferFrom(buf, buf.getLenInBytes()); + } + + return (Nd4jLong)selected.size(); } - BUILD_SINGLE_TEMPLATE(template void nonMaxSuppressionV2_, (NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output), NUMERIC_TYPES); + + Nd4jLong + nonMaxSuppressionGeneric(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, + double overlapThreshold, double scoreThreshold, NDArray* output) { + BUILD_DOUBLE_SELECTOR(boxes->dataType(), output == nullptr?DataType::INT32:output->dataType(), return nonMaxSuppressionGeneric_, (context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, output), FLOAT_TYPES, INTEGER_TYPES); + return 0; + } + + BUILD_DOUBLE_TEMPLATE(template Nd4jLong nonMaxSuppressionGeneric_, (nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, + double overlapThreshold, double scoreThreshold, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); + + void + nonMaxSuppression(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, + double overlapThreshold, double scoreThreshold, NDArray* output) { + BUILD_SINGLE_SELECTOR(boxes->dataType(), nonMaxSuppressionV2_, (boxes, scales, maxSize, + overlapThreshold, scoreThreshold, output), NUMERIC_TYPES); + } + BUILD_SINGLE_TEMPLATE(template void nonMaxSuppressionV2_, (NDArray* boxes, NDArray* scales, int maxSize, + double overlapThreshold, double scoreThreshold, NDArray* output), NUMERIC_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu index 3393a61e3..a96db0195 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu @@ -22,6 +22,7 @@ #include #include #include +#include namespace nd4j { namespace ops { @@ -121,24 +122,40 @@ namespace helpers { for (auto i = tid; i < len; i += step) indexBuf[i] = (I)srcBuf[i]; } -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + template + static __global__ void suppressScores(T* scores, I* indices, Nd4jLong length, T scoreThreshold) { + auto start = blockIdx.x * blockDim.x; + auto step = gridDim.x * blockDim.x; + + for (auto e = start + threadIdx.x; e < (int)length; e += step) { + if (scores[e] < scoreThreshold) { + scores[e] = scoreThreshold; + indices[e] = -1; + } + else { + indices[e] = I(e); + } + } + } + + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // nonMaxSuppressionV2 algorithm - given from TF NonMaxSuppressionV2 implementation // template - static void nonMaxSuppressionV2_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { + static void nonMaxSuppressionV2_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) { auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {boxes, scales}); std::unique_ptr indices(NDArrayFactory::create_('c', {scales->lengthOf()})); // - 1, scales->lengthOf()); //, scales->getContext()); - indices->linspace(0); - indices->syncToDevice(); // linspace only on CPU, so sync to Device as well NDArray scores(*scales); Nd4jPointer extras[2] = {nullptr, stream}; - + auto indexBuf = indices->dataBuffer()->specialAsT();///reinterpret_cast(indices->specialBuffer()); + auto scoreBuf = scores.dataBuffer()->specialAsT(); + suppressScores<<<128, 128, 128, *stream>>>(scoreBuf, indexBuf, scores.lengthOf(), T(scoreThreshold)); + indices->tickWriteDevice(); sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true); - - auto indexBuf = reinterpret_cast(indices->specialBuffer()); - + indices->tickWriteDevice(); NDArray selectedIndices = NDArrayFactory::create('c', {output->lengthOf()}); int numSelected = 0; int numBoxes = boxes->sizeAt(0); @@ -180,10 +197,156 @@ namespace helpers { } } + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + template + static __device__ bool checkOverlapBoxes(T* boxes, Nd4jLong* shape, T* scores, I* indices, I* selectedIndices, I* startIndices, I selectedSize, I nextCandidateIndex, T overlapThreshold, T scoreThreshold) { + bool shouldHardSuppress = false; + T& nextCandidateScore = scores[nextCandidateIndex]; + I selectedIndex = indices[nextCandidateIndex]; + I finish = startIndices[nextCandidateIndex]; + + for (int j = selectedSize; j > finish; --j) { + Nd4jLong xPos[] = {selectedIndex, selectedIndices[j - 1]}; + auto xShift = shape::getOffset(shape, xPos, 0); + nextCandidateScore *= (boxes[xShift] <= static_cast(overlapThreshold)?T(1.):T(0.));// + // First decide whether to perform hard suppression + if (boxes[xShift] >= overlapThreshold) { + shouldHardSuppress = true; + break; + } + + // If nextCandidate survives hard suppression, apply soft suppression + if (nextCandidateScore <= scoreThreshold) break; + } + + return shouldHardSuppress; + } +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + template + static __global__ void + suppressNonMaxOverlapKernel(T* boxes, Nd4jLong* boxesShape, T* scoresData, I* indices, I* startIndices, Nd4jLong length, I maxOutputLen, + T overlapThreshold, T scoreThreshold, I* output, Nd4jLong* outputShape, I* outputLength) { + + __shared__ I selectedSize; + __shared__ I* tempOutput; + + if (threadIdx.x == 0) { + selectedSize = outputLength?*outputLength:maxOutputLen; + extern __shared__ unsigned char shmem[]; + tempOutput = (I*)shmem; + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + + for (I nextCandidateIndex = start + threadIdx.x; selectedSize < maxOutputLen && nextCandidateIndex < (I)length; ) { + auto originalScore = scoresData[nextCandidateIndex];//nextCandidate._score; + I nextCandidateBoxIndex = indices[nextCandidateIndex]; + auto selectedSizeMark = selectedSize; + + // skip for cases when index is less than 0 (under score threshold) + if (nextCandidateBoxIndex < 0) { + nextCandidateIndex += step; + continue; + } + // check for overlaps + bool shouldHardSuppress = checkOverlapBoxes(boxes, boxesShape, scoresData, indices, tempOutput, startIndices, selectedSize, + nextCandidateIndex, overlapThreshold, scoreThreshold);//false; + T nextCandidateScore = scoresData[nextCandidateIndex]; + + startIndices[nextCandidateIndex] = selectedSize; + if (!shouldHardSuppress) { + if (nextCandidateScore == originalScore) { + // Suppression has not occurred, so select nextCandidate + if (output) + output[selectedSize] = nextCandidateBoxIndex; + tempOutput[selectedSize] = nextCandidateBoxIndex; + math::atomics::nd4j_atomicAdd(&selectedSize, (I)1); + } + + if (nextCandidateScore > scoreThreshold) { + // Soft suppression has occurred and current score is still greater than + // scoreThreshold; add nextCandidate back onto priority queue. + continue; // in some cases, this index not 0 + } + } + nextCandidateIndex += step; + } + + if (threadIdx.x == 0) { + if (outputLength) + *outputLength = selectedSize; + } + } + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + template + static Nd4jLong + nonMaxSuppressionGeneric_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int outputSize, + double overlapThreshold, double scoreThreshold, NDArray* output) { + auto stream = context->getCudaStream(); + if (output) + NDArray::prepareSpecialUse({output}, {boxes, scores}); + else { + if (!boxes->isActualOnDeviceSide()) + boxes->syncToDevice(); + if (!scores->isActualOnDeviceSide()) + scores->syncToDevice(); + } + + NDArray indices = NDArrayFactory::create('c', {scores->lengthOf()}); // - 1, scales->lengthOf()); //, scales->getContext()); + NDArray startPositions = NDArrayFactory::create('c', {scores->lengthOf()}); + NDArray selectedScores(*scores); + Nd4jPointer extras[2] = {nullptr, stream}; + auto indexBuf = indices.dataBuffer()->specialAsT();///reinterpret_cast(indices->specialBuffer()); + + suppressScores<<<128, 128, 128, *stream>>>(selectedScores.dataBuffer()->specialAsT(), indexBuf, selectedScores.lengthOf(), T(scoreThreshold)); + + sortByValue(extras, indices.buffer(), indices.shapeInfo(), indices.specialBuffer(), indices.specialShapeInfo(), selectedScores.buffer(), selectedScores.shapeInfo(), selectedScores.specialBuffer(), selectedScores.specialShapeInfo(), true); + indices.tickWriteDevice(); + selectedScores.tickWriteDevice(); + + auto scoresData = selectedScores.dataBuffer()->specialAsT();//, numBoxes, scoresData.begin()); + + auto startIndices = startPositions.dataBuffer()->specialAsT(); + I selectedSize = 0; + Nd4jLong res = 0; + if (output) { // this part used when output shape already calculated to fill up values on output + DataBuffer selectedSizeBuf(&selectedSize, sizeof(I), DataTypeUtils::fromT()); + suppressNonMaxOverlapKernel <<<1, 1, 1024, *stream >>> (boxes->dataBuffer()->specialAsT(), + boxes->specialShapeInfo(), scoresData, indexBuf, startIndices, scores->lengthOf(), (I) outputSize, + T(overlapThreshold), T(scoreThreshold), output->dataBuffer()->specialAsT(), output->specialShapeInfo(), + selectedSizeBuf.specialAsT()); + } + else { // this case used on calculation of output shape. Output and output shape shoulde be nullptr. + DataBuffer selectedSizeBuf(&selectedSize, sizeof(I), DataTypeUtils::fromT()); + suppressNonMaxOverlapKernel <<<1, 1, 1024, *stream >>> (boxes->dataBuffer()->specialAsT(), + boxes->specialShapeInfo(), scoresData, indexBuf, startIndices, scores->lengthOf(), (I)outputSize, + T(overlapThreshold), T(scoreThreshold), (I*)nullptr, (Nd4jLong*) nullptr, selectedSizeBuf.specialAsT()); + selectedSizeBuf.syncToPrimary(context, true); + res = *selectedSizeBuf.primaryAsT(); + } + + if (output) + NDArray::registerSpecialUse({output}, {boxes, scores}); + + return res; + } +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + void nonMaxSuppression(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) { + BUILD_DOUBLE_SELECTOR(boxes->dataType(), output->dataType(), nonMaxSuppressionV2_, + (context, boxes, scales, maxSize, threshold, scoreThreshold, output), + FLOAT_TYPES, INDEXING_TYPES); + } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void nonMaxSuppressionV2(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { - BUILD_DOUBLE_SELECTOR(boxes->dataType(), output->dataType(), nonMaxSuppressionV2_, (context, boxes, scales, maxSize, threshold, output), FLOAT_TYPES, INDEXING_TYPES); + Nd4jLong nonMaxSuppressionGeneric(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) { + BUILD_DOUBLE_SELECTOR(boxes->dataType(), output ? output->dataType():DataType::INT32, return nonMaxSuppressionGeneric_, + (context, boxes, scales, maxSize, threshold, scoreThreshold, output), + FLOAT_TYPES, INDEXING_TYPES); + return boxes->sizeAt(0); } } diff --git a/libnd4j/include/ops/declarable/helpers/image_suppression.h b/libnd4j/include/ops/declarable/helpers/image_suppression.h index afce399a6..85224e0f5 100644 --- a/libnd4j/include/ops/declarable/helpers/image_suppression.h +++ b/libnd4j/include/ops/declarable/helpers/image_suppression.h @@ -26,7 +26,10 @@ namespace nd4j { namespace ops { namespace helpers { - void nonMaxSuppressionV2(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output); + void nonMaxSuppression(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, + double overlapThreshold, double scoreThreshold, NDArray* output); + Nd4jLong nonMaxSuppressionGeneric(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, + double overlapThreshold, double scoreThreshold, NDArray* output); } } diff --git a/libnd4j/include/ops/declarable/helpers/impl/knn_mindistance.cpp b/libnd4j/include/ops/declarable/helpers/impl/knn_mindistance.cpp new file mode 100644 index 000000000..71711832d --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/impl/knn_mindistance.cpp @@ -0,0 +1,62 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace nd4j { + namespace ops { + namespace helpers { + template + void mindistance_(const void* vinput, const void *vlow, const void *vhigh, int32_t length, void *vout) { + auto input = reinterpret_cast(vinput); + auto low = reinterpret_cast(vlow); + auto high = reinterpret_cast(vhigh); + auto output = reinterpret_cast(vout); + + T res = 0.0f; + T po = 2.f; + T o = 1.f; + +#pragma omp simd reduction(sumT:res) + for (auto e = 0; e < length; e++) { + T p = input[e]; + T l = low[e]; + T h = high[e]; + if (!(l <= p || h <= p)) { + if (p < l) + res += nd4j::math::nd4j_pow((p - o), po); + else + res += nd4j::math::nd4j_pow((p - h), po); + } + } + + output[0] = nd4j::math::nd4j_pow(res, (T) 0.5f); + } + + void knn_mindistance(const NDArray &input, const NDArray &lowest, const NDArray &highest, NDArray &output) { + NDArray::preparePrimaryUse({&output}, {&input, &lowest, &highest}); + + BUILD_SINGLE_SELECTOR(input.dataType(), mindistance_, (input.getBuffer(), lowest.getBuffer(), highest.getBuffer(), input.lengthOf(), output.buffer()), FLOAT_TYPES); + + NDArray::registerPrimaryUse({&output}, {&input, &lowest, &highest}); + } + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp new file mode 100644 index 000000000..528642bb6 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp @@ -0,0 +1,460 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +// implementation of operation for LSTM cell with peep hole connections: +// http://www.bioinf.jku.at/publications/older/2604.pdf +// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. +// and +// https://research.google.com/pubs/archive/43905.pdf +// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. + + +#include +#include +// #include +// #include +// #include +// #include +// #include +// #include +// #include + +namespace nd4j { +namespace ops { +namespace helpers { + + +////////////////////////////////////////////////////////////////////////// +void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, + const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, + const std::vector& params, + NDArray* h, NDArray* c) { + + + /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ + /** the objective is to provide math-readable code **/ + + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) + // ht = ot ◦ tanh(ct) + + + // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus + + // params[0] - dataFormat, ignore + // params[1] - directionMode, ignore + // params[2] - cell clipping value, if it = 0 then do not apply clipping + + // params[3] - activation ID for input (i), forget (f) and output (o) gates + // params[4] - alpha value for gates activation + // params[5] - beta value for gates activation + + // params[6] - activation ID for cell state (c) + // params[7] - alpha value for cell state activation + // params[8] - beta value for cell state activation + + // params[9] - activation ID for output (h) + // params[10] - alpha value for output activation + // params[11] - beta value for output activation + + // INPUTS: + // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr + // Wx - input weights [nIn, 4*nOut] + // Wr - recurrent weights [nOut, 4*nOut] + // b - biases [4*nOut], optional, may be nullptr + // hI - previous (initial) output at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr + // cI - previous (initial) cell state at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr + // Wp - peephole weights [3*nOut], optional, may be nullptr + + // OUTPUTS: + // h - current output, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr + // c - current cell state, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + auto z = mmul(*x, *Wx) + mmul(*hI, *Wr); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut] + //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut] + + // add biases if they are given + if(b != nullptr) + z += *b; // broadcast [bS, 4*nOut] + [4*nOut] = [bS, 4*nOut] + + auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // input gate it, [bS, nOut] + auto zf = x->rankOf() == 1 ? z({nOut, 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut] + auto zc = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut] + auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut] + + // peephole connections for input and forget gates + if(Wp != nullptr) { + zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] ◦ [nOut] = [bS, nOut] + zf += *cI * (*Wp)({nOut, 2*nOut}); // broadcast: [bS, nOut] + [bS, nOut] ◦ [nOut] = [bS, nOut] + } + + applyActivation(zi, params[3], params[4], params[5], zi); // inplace + applyActivation(zf, params[3], params[4], params[5], zf); // inplace + applyActivation(zc, params[6], params[7], params[8], zc); // inplace + + c->assign(zf * *cI + zi * zc); // [bS, nOut] ◦ [bS, nOut] + [bS, nOut] ◦ [bS, nOut] = [bS, nOut] + + // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation + if(params[2] != 0) + c->applyScalar(scalar::LstmClip, params[2]); + + // peephole connections for output gate + if(Wp != nullptr) + zo += *c * (*Wp)({2*nOut, 3*nOut}); // broadcast: [bS, nOut] + [nOut] ◦ [bS, nOut] = [bS, nOut] + + applyActivation(zo, params[3], params[4], params[5], zo); + + applyActivation(*c, params[9], params[10], params[11], *h); + *h *= zo; // [bS, nOut] ◦ [bS, nOut] +} + + + +////////////////////////////////////////////////////////////////////////// +void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, + const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp, + const std::vector& params, + const bool forward, + NDArray* h, NDArray* hL, NDArray* cL) { + + // INPUTS: + // x - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL], + // Wx - input weights [nIn, 4*nOut] + // Wr - recurrent weights [nOut, 4*nOut] + // b - biases [4*nOut], optional, may be nullptr + // seqLen - [bS], optional, may be nullptr + // hI - initial output [bS, nOut], optional, may be nullptr + // cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr + // Wp - peephole weights [3*nOut], optional, may be nullptr + + // OUTPUTS: + // h - output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, sL], optional, may be nullptr + // hL - output at last step [bS, nOut], optional, may be nullptr + // cL - cell state at last step [bS, nOut], optional, may be nullptr + + // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + // dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL] + + const int dataFormat = params[0]; + const int directionMode = params[1]; + + const Nd4jLong sL = x->sizeAt(dataFormat); + const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + const std::vector shapeOut = {bS, nOut}; + + auto h0 = const_cast(hI); + if(!hI) { + h0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); + h0->nullify(); + } + + auto c0 = const_cast(cI); + if(!cI) { + c0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); + c0->nullify(); + } + + auto ct = cL; + if(!cL) + cL = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); + + auto ht = hL; + if(!h && !hL) + ht = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); + + // create sets of required (depends on seqLen presence) sub-arrays + std::vector dims; + ResultSet *xSet(nullptr), *hSet(nullptr), *h0Set(nullptr), *c0Set(nullptr), *htSet(nullptr), *ctSet(nullptr); + + if(!seqLen) { + + dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0}); // points on bS and nIn/nOut axes + + xSet = x->allTensorsAlongDimension(dims); // sub-arrays with shape [bS, nIn] + if(h) + hSet = h->allTensorsAlongDimension(dims); // sub-arrays with shape [bS, nOut] + } + else { + + dims = dataFormat == 2 ? std::vector({1}) : std::vector({2}); // points on nIn/nOut axis + + xSet = x->allTensorsAlongDimension(dims); // sub-arrays with shape [nIn] + h0Set = h0->allTensorsAlongDimension({1}); // sub-arrays with shape [nOut] + c0Set = c0->allTensorsAlongDimension({1}); // sub-arrays with shape [nOut] + ctSet = ct->allTensorsAlongDimension({1}); // sub-arrays with shape [nOut] + if(h) + hSet = h->allTensorsAlongDimension(dims); // sub-arrays with shape [nOut] + if(ht) + htSet = ht->allTensorsAlongDimension({1}); // sub-arrays with shape [nOut] + } + + // loops + if(forward) { + + if(!seqLen) { + + if(!h) { // seqLen and h are absent + + lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step + for (int t = 1; t < sL; ++t) + lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps + } + else { // seqLen is absent and h is present + + lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, hSet->at(0), ct); // first time step + for (int t = 1; t < sL; ++t) + lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t - 1), ct, Wp, params, hSet->at(t), ct); // rest time steps + + if(hL) + hL->assign(hSet->at(sL - 1)); // assign last output to hL if it is not nullptr + } + } + else { + + if(!h) { // seqLen is present and h is absent + + for (int e = 0; e < bS; ++e) { + + const int limit = seqLen->e(e); + + if(limit == 0) { + if(cL) + ctSet->at(e)->nullify(); + if(hL) + htSet->at(e)->nullify(); + continue; + } + + auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e); + lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step + + for (int t = 1; t < limit; ++t) { + ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps + } + } + } + else { // seqLen and h are present + + for (int e = 0; e < bS; ++e) { + + int limit = seqLen->e(e); + + if(limit == 0) { + + tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range + + if(cL) + ctSet->at(e)->nullify(); + if(hL) + htSet->at(e)->nullify(); + + continue; + } + + auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e); + lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step + + for (int t = 1; t < limit; ++t) { + auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps + indPrev = indCurr; + } + + if(hL) + htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if hL is not nullptr + + tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) + } + } + } + } + else { // backward + + if(!seqLen) { + + if(!h) { // seqLen and h are absent + + lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step + for (int t = sL - 2; t >= 0; --t) + lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps + } + else { // seqLen is absent and h is present + + lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, hSet->at(sL - 1), ct); // first time step + for (int t = sL - 2; t >= 0; --t) + lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t + 1), ct, Wp, params, hSet->at(t), ct); // rest time steps + + if(hL) + hL->assign(hSet->at(0)); // assign last output to hL if it is not nullptr + } + } + else if(directionMode == 1) { // only backward, no bidirectional mode + + if(!h) { // h is absent and seqLen is present + + for (int e = 0; e < bS; ++e) { + + const int limit = seqLen->e(e); + + if(limit == 0) { + if(cL) + ctSet->at(e)->nullify(); + if(hL) + htSet->at(e)->nullify(); + continue; + } + + auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e); + lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step + + for (int t = sL - 2; t >= sL - limit; --t) { + ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps + } + } + } + else { // seqLen and h are present + + for (int e = 0; e < bS; ++e) { + + int limit = seqLen->e(e); + + if(limit == 0) { + + tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range + + if(cL) + ctSet->at(e)->nullify(); + if(hL) + htSet->at(e)->nullify(); + + continue; + } + + auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e); + lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step + + for (int t = sL - 2; t >= sL - limit; --t) { + auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps + indPrev = indCurr; + } + + if(hL) + htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr + + tensorAlongTimeBatchDims(*h, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL) + } + } + } + else { // backward in bidirectional mode + + if(!h) { // h is absent and seqLen is present + + for (int e = 0; e < bS; ++e) { + + const int limit = seqLen->e(e); + + if(limit == 0) { + if(cL) + ctSet->at(e)->nullify(); + if(hL) + htSet->at(e)->nullify(); + continue; + } + + auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e); + lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step + + for (int t = limit - 2; t >= 0; --t) { + ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps + } + } + } + else { // seqLen and h are present + + for (int e = 0; e < bS; ++e) { + + int limit = seqLen->e(e); + + if(limit == 0) { + + tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range + + if(cL) + ctSet->at(e)->nullify(); + if(hL) + htSet->at(e)->nullify(); + + continue; + } + + auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e); + lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step + + for (int t = limit - 2; t >= 0; --t) { + auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps + indPrev = indCurr; + } + + if(hL) + htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr + + tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) + } + } + } + } + + delete xSet; + delete hSet; + delete h0Set; + delete c0Set; + delete htSet; + delete ctSet; +} + + + +} +} +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/IfImportState.java b/libnd4j/include/ops/declarable/helpers/knn.h similarity index 66% rename from nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/IfImportState.java rename to libnd4j/include/ops/declarable/helpers/knn.h index e81016426..a2de9c71c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/IfImportState.java +++ b/libnd4j/include/ops/declarable/helpers/knn.h @@ -14,19 +14,21 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.ops.impl.controlflow; +// +// @author raver119@gmail.com +// -import lombok.Builder; -import lombok.Data; -import org.tensorflow.framework.NodeDef; +#ifndef SAMEDIFF_KNN_H +#define SAMEDIFF_KNN_H -import java.util.List; +#include -@Builder -@Data -public class IfImportState { - private List condNodes; - private List trueNodes; - private List falseNodes; - private String falseBodyScopeName,trueBodyScopeName,conditionBodyScopeName; +namespace nd4j { + namespace ops { + namespace helpers { + void knn_mindistance(const NDArray &input, const NDArray &lowest, const NDArray &highest, NDArray &output); + } + } } + +#endif //SAMEDIFF_KNN_H diff --git a/libnd4j/include/ops/declarable/helpers/lstmLayer.h b/libnd4j/include/ops/declarable/helpers/lstmLayer.h new file mode 100644 index 000000000..7d94c32e0 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/lstmLayer.h @@ -0,0 +1,117 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#ifndef LIBND4J_LSTMLAYER_H +#define LIBND4J_LSTMLAYER_H + +#include +#include + +namespace nd4j { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////// +void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, + const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, + const std::vector& params, + NDArray* h, NDArray* c); + +////////////////////////////////////////////////////////////////////////// +void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, + const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp, + const std::vector& params, + const bool forward, + NDArray* h, NDArray* hL, NDArray* cL); + +////////////////////////////////////////////////////////////////////////// +static FORCEINLINE void applyActivation(NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) { + + switch (opId) { + case 0: + (const_cast(x)).applyTransform(transform::Tanh, &z); + break; + case 1: + (const_cast(x)).applyScalar(scalar::RELU, 0, &z); + break; + case 2: + (const_cast(x)).applyTransform(transform::Sigmoid, &z); + break; + case 3: { + ExtraArguments args({ static_cast(alpha), static_cast(beta)}); + (const_cast(x)).applyTransform(transform::Affine, &z, &args); + break; + } + case 4: + (const_cast(x)).applyScalar(scalar::LeakyRELU, alpha, &z); + break; + case 5: + helpers::thresholdRelu(x.getContext(), x, alpha, z); + break; + case 6: { + ExtraArguments args({ static_cast(alpha), static_cast(beta)}); + (const_cast(x)).applyTransform(transform::ScaledTanh, &z, &args); + break; + } + case 7: + (const_cast(x)).applyTransform(transform::HardSigmoid, &z); + break; + case 8: + (const_cast(x)).applyScalar(scalar::ELU, alpha, &z); + break; + case 9: + (const_cast(x)).applyTransform(transform::SoftSign, &z); + break; + case 10: + (const_cast(x)).applyTransform(transform::SoftPlus, &z); + break; + default: + throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !"); + } +} + +////////////////////////////////////////////////////////////////////////// +static FORCEINLINE NDArray tensorAlongTimeBatchDims(const NDArray& arr, const int dataFormat, const int t1, const int t2, const int b1, const int b2) { + + if(dataFormat == 0 || dataFormat == 3) + return arr({t1,t2, b1,b2, 0,0}); // TNS: [sL, bS, nIn] + + if(dataFormat == 1) + return arr({b1,b2, t1,t2, 0,0}); // NTS: [bS, sL ,nIn] + + return arr({b1,b2, 0,0, t1,t2}); // NST: [bS, nIn, sL] +} + +////////////////////////////////////////////////////////////////////////// +static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL, const int bS, const int t, const int b) { + + if(dataFormat == 0 || dataFormat == 3) + return t * bS + b; // TNS: shape [sL, bS, nIn] + + return b * sL + t; // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL] +} + + +} +} +} + + +#endif //LIBND4J_LSTMLAYER_H diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 3dea41a18..fe1574ea1 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -217,7 +217,7 @@ namespace nd4j { auto var = ctx.variable(pair); auto shape = var->getNDArray()->shapeInfo(); - if (!shape::equalsSoft(out, shape)) { + if (!shape::equalsSoft(out, shape) || shape::isEmpty(out) != shape::isEmpty(shape)) { auto eShape = ShapeUtils::shapeAsString(out); auto aShape = ShapeUtils::shapeAsString(shape); @@ -237,7 +237,7 @@ namespace nd4j { ctx.setOutputArray(idx, outArr, true); } else { auto array = fout[idx]; - if (!shape::equalsSoft(out, array->shapeInfo())) { + if (!shape::equalsSoft(out, array->shapeInfo()) || shape::isEmpty(out) != array->isEmpty()) { auto eShape = ShapeUtils::shapeAsString(out); auto aShape = ShapeUtils::shapeAsString(array->shapeInfo()); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp index 4947a39c0..1a2780d52 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp @@ -17,6 +17,7 @@ // // @author saudet // @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com) // #include @@ -28,139 +29,679 @@ #include #include -using namespace mkldnn; -namespace nd4j { - namespace ops { - namespace platforms { - PLATFORM_IMPL(batchnorm_new) { - auto input = INPUT_VARIABLE(0); - auto mean = INPUT_VARIABLE(1); - auto variance = INPUT_VARIABLE(2); - NDArray *gamma = nullptr; - NDArray *beta = nullptr; +namespace nd4j { +namespace ops { +namespace platforms { - auto output = OUTPUT_VARIABLE(0); - const bool applyScale = (bool) INT_ARG(0); - const bool applyOffset = (bool) INT_ARG(1); - const double epsilon = T_ARG(0); +////////////////////////////////////////////////////////////////////////// +static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, const float epsilon, NDArray* z) { - if (applyScale) - gamma = INPUT_VARIABLE(3); - if (applyOffset) - beta = INPUT_VARIABLE(3 + static_cast(applyScale)); + // unfortunately mkl dnn doesn't support any format (mkldnn::memory::format_tag::any) + // also it gives wrong results for formats nhwc and ndhwc - std::vector axes; - if (block.numI() > 2) - for (int i = 2; i < block.numI(); ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(input->rankOf() - 1); + // x -> 2D:nc, 4D:nchw, 5D:ncdhw + // mean -> 1D [c] + // variance -> 1D [c] + // weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta + // z(output) - same shape as x - std::vector shape({2, mean->lengthOf()}); - NDArray weights = NDArrayFactory::create('c', shape, block.launchContext()); - weights({0, 1, 0, 0}).assign(1.0f); - weights({1, 2, 0, 0}).assign(0.0f); + const int xRank = x->rankOf(); - mkldnn_memory_desc_t empty; - mkldnn::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md( - empty), user_dst_md(empty); + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - auto norm_flag = normalization_flags::use_global_stats; - if (applyScale || applyOffset) - norm_flag |= normalization_flags::use_scale_shift; + // input type + mkldnn::memory::data_type type = mkldnn::memory::data_type::f32; - mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output, - &batchnorm_src_md, nullptr, &batchnorm_dst_md, - &user_src_md, nullptr, &user_dst_md, axes[0]); + // indicate whether gamma or/and beta are given + auto flags = mkldnn::normalization_flags::use_global_stats; + if (weights != nullptr) + flags |= mkldnn::normalization_flags::use_scale_shift; - auto batchnorm_desc = batch_normalization_forward::desc(prop_kind::forward_inference, batchnorm_src_md, epsilon, norm_flag); + mkldnn::memory::dims dims; + mkldnn::memory::format_tag format; - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - mkldnn::stream stream(engine); - auto batchnorm_prim_desc = batch_normalization_forward::primitive_desc(batchnorm_desc, engine); - auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); - auto batchnorm_mean_memory = mkldnn::memory(batchnorm_prim_desc.mean_desc(), engine, - mean->buffer()); - auto batchnorm_variance_memory = mkldnn::memory(batchnorm_prim_desc.variance_desc(), engine, - variance->buffer()); - auto batchnorm_src_memory = user_src_memory; - mkldnn::memory m(batchnorm_src_md, engine); - if (m.get_desc() != user_src_memory.get_desc()) { - batchnorm_src_memory = mkldnn::memory(batchnorm_src_md, engine); - reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory, - batchnorm_src_memory); - } - auto batchnorm_dst_memory = user_dst_memory; - if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - batchnorm_dst_memory = mkldnn::memory(batchnorm_prim_desc.dst_desc(), engine); - } - if (applyScale || applyOffset) { - if (gamma != nullptr) { - weights({0, 1, 0, 0}).assign(gamma); - } - if (beta != nullptr) { - weights({1, 2, 0, 0}).assign(beta); - } - - auto batchnorm_weights_memory = mkldnn::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer()); - batch_normalization_forward(batchnorm_prim_desc).execute(stream, - {{MKLDNN_ARG_SRC, batchnorm_src_memory}, - {MKLDNN_ARG_MEAN, batchnorm_mean_memory}, - {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory}, - {MKLDNN_ARG_WEIGHTS, batchnorm_weights_memory}, - {MKLDNN_ARG_DST, batchnorm_dst_memory}}); - } else { - batch_normalization_forward(batchnorm_prim_desc).execute(stream, - {{MKLDNN_ARG_SRC, batchnorm_src_memory}, - {MKLDNN_ARG_MEAN, batchnorm_mean_memory}, - {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory}, - {MKLDNN_ARG_DST, batchnorm_dst_memory}}); - } - if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory, - user_dst_memory); - } - stream.wait(); - - return Status::OK(); - } - - PLATFORM_CHECK(batchnorm_new) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - if (::optimalLevel() < 2) - return false; - - auto input = INPUT_VARIABLE(0); - auto mean = INPUT_VARIABLE(1); - auto variance = INPUT_VARIABLE(2); - NDArray *gamma = nullptr; - NDArray *beta = nullptr; - - auto output = OUTPUT_VARIABLE(0); - - const bool applyScale = (bool) INT_ARG(0); - const bool applyOffset = (bool) INT_ARG(1); - const double epsilon = T_ARG(0); - - if (applyScale) - gamma = INPUT_VARIABLE(3); - if (applyOffset) - beta = INPUT_VARIABLE(3 + static_cast(applyScale)); - - std::vector axes; - if (block.numI() > 2) - for (int i = 2; i < block.numI(); ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(input->rankOf() - 1); - - return block.isUseMKLDNN() && - nd4j::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, output}) && - axes.size() == 1; - } - } + if(xRank == 2) { + dims = {x->sizeAt(0), x->sizeAt(1)}; + format = mkldnn::memory::format_tag::nc; } + else if(xRank == 4) { + dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)}; + format = mkldnn::memory::format_tag::nchw; + } + else { // xRank = 5 + dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)}; + format = mkldnn::memory::format_tag::ncdhw; + } + + // memory descriptors for arrays + + // x + mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(dims, type, format); + mkldnn::memory::desc x_user_md = mkldnn::memory::desc(dims, type, format); + x_user_md.data.format_kind = mkldnn_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0]; + x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1]; + if(xRank > 2) { + x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2]; + x_user_md.data.format_desc.blocking.strides[3] = x->stridesOf()[3]; + } + if(xRank > 4) + x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4]; + + // z, output + mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(dims, type, format); + mkldnn::memory::desc z_user_md = mkldnn::memory::desc(dims, type, format); + z_user_md.data.format_kind = mkldnn_blocked; // overrides format + z_user_md.data.format_desc.blocking.strides[0] = z->stridesOf()[0]; + z_user_md.data.format_desc.blocking.strides[1] = z->stridesOf()[1]; + if(xRank > 2) { + z_user_md.data.format_desc.blocking.strides[2] = z->stridesOf()[2]; + z_user_md.data.format_desc.blocking.strides[3] = z->stridesOf()[3]; + } + if(xRank > 4) + z_user_md.data.format_desc.blocking.strides[4] = z->stridesOf()[4]; + + + // batchnorm forward description + mkldnn::batch_normalization_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, x_mkl_md, epsilon, flags); + mkldnn::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + mkldnn::stream stream(engine); + + // provide memory and check whether reorder is required + + // x + auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer()); + const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc(); + auto x_mkl_mem = xReorder ? mkldnn::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem; + if (xReorder) + mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[MKLDNN_ARG_SRC] = x_mkl_mem; + + // z + auto z_user_mem = mkldnn::memory(z_user_md, engine, z->getBuffer()); + const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = zReorder ? mkldnn::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem; + if (zReorder) + mkldnn::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem); + args[MKLDNN_ARG_DST] = z_mkl_mem; + + // mean + auto mean_mkl_mem = mkldnn::memory(op_ff_prim_desc.mean_desc(), engine, mean->getBuffer()); + args[MKLDNN_ARG_MEAN] = mean_mkl_mem; + + // variance + auto var_mkl_mem = mkldnn::memory(op_ff_prim_desc.variance_desc(), engine, variance->getBuffer()); + args[MKLDNN_ARG_VARIANCE] = var_mkl_mem; + + // gamma and beta (and their gradients) if they are present + if(weights != nullptr) { + + auto w_mkl_mem = mkldnn::memory(op_ff_prim_desc.weights_desc(), engine, weights->getBuffer()); + args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; + } + + // run calculations + mkldnn::batch_normalization_forward(op_ff_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (zReorder) + mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); +} + + +////////////////////////////////////////////////////////////////////////// +static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights, + const float epsilon, NDArray* dLdI, NDArray* dLdW) { + + // unfortunately mkl dnn doesn't support any format (mkldnn::memory::format_tag::any) + // also it gives wrong results for formats nhwc and ndhwc + + // x -> 2D:nc, 4D:nchw, 5D:ncdhw + // mean -> 1D [c] + // variance -> 1D [c] + // dLdO - same shape as x + // weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta + // dLdI - same shape as x + // dLdW - same shape as weights, dLdW({0,1, 0,0}) contains grad_gamma and dLdW({1,2, 0,0}) contains grad_beta + + const int xRank = x->rankOf(); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // input type + mkldnn::memory::data_type type = mkldnn::memory::data_type::f32; + + // indicate whether gamma or/and beta are given + auto flags = mkldnn::normalization_flags::use_global_stats; + if (weights != nullptr) + flags |= mkldnn::normalization_flags::use_scale_shift; + + mkldnn::memory::dims dims; + mkldnn::memory::format_tag format; + + if(xRank == 2) { + dims = {x->sizeAt(0), x->sizeAt(1)}; + format = mkldnn::memory::format_tag::nc; + } + else if(xRank == 4) { + dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)}; + format = mkldnn::memory::format_tag::nchw; + } + else { // xRank = 5 + dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)}; + format = mkldnn::memory::format_tag::ncdhw; + } + + // memory descriptors for arrays + + // x + mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(dims, type, format); + mkldnn::memory::desc x_user_md = mkldnn::memory::desc(dims, type, format); + x_user_md.data.format_kind = mkldnn_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0]; + x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1]; + if(xRank > 2) { + x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2]; + x_user_md.data.format_desc.blocking.strides[3] = x->stridesOf()[3]; + } + if(xRank > 4) + x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4]; + + // dLdO + mkldnn::memory::desc dLdO_mkl_md = mkldnn::memory::desc(dims, type, format); + mkldnn::memory::desc dLdO_user_md = mkldnn::memory::desc(dims, type, format); + dLdO_user_md.data.format_kind = mkldnn_blocked; // overrides format + dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->stridesOf()[0]; + dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->stridesOf()[1]; + if(xRank > 2) { + dLdO_user_md.data.format_desc.blocking.strides[2] = dLdO->stridesOf()[2]; + dLdO_user_md.data.format_desc.blocking.strides[3] = dLdO->stridesOf()[3]; + } + if(xRank > 4) + dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->stridesOf()[4]; + + // dLdI + mkldnn::memory::desc dLdI_mkl_md = mkldnn::memory::desc(dims, type, format); + mkldnn::memory::desc dLdI_user_md = mkldnn::memory::desc(dims, type, format); + dLdI_user_md.data.format_kind = mkldnn_blocked; // overrides format + dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->stridesOf()[0]; + dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->stridesOf()[1]; + if(xRank > 2) { + dLdI_user_md.data.format_desc.blocking.strides[2] = dLdI->stridesOf()[2]; + dLdI_user_md.data.format_desc.blocking.strides[3] = dLdI->stridesOf()[3]; + } + if(xRank > 4) + dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->stridesOf()[4]; + + // batchnorm forward description + mkldnn::batch_normalization_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, x_mkl_md, epsilon, flags); + mkldnn::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // batchnorm backprop description + mkldnn::batch_normalization_backward::desc op_bp_desc(mkldnn::prop_kind::backward, dLdO_mkl_md, x_mkl_md, epsilon, flags); + mkldnn::batch_normalization_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + mkldnn::stream stream(engine); + + // provide memory and check whether reorder is required + + // x + auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer()); + const bool xReorder = op_bp_prim_desc.src_desc() != x_user_mem.get_desc(); + auto x_mkl_mem = xReorder ? mkldnn::memory(op_bp_prim_desc.src_desc(), engine) : x_user_mem; + if (xReorder) + mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[MKLDNN_ARG_SRC] = x_mkl_mem; + + // dLdO + auto dLdO_user_mem = mkldnn::memory(dLdO_user_md, engine, dLdO->getBuffer()); + const bool dLdOReorder = op_bp_prim_desc.diff_src_desc() != dLdO_user_mem.get_desc(); + auto dLdO_mkl_mem = dLdOReorder ? mkldnn::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdO_user_mem; + if (dLdOReorder) + mkldnn::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem); + args[MKLDNN_ARG_DIFF_DST] = dLdO_mkl_mem; + + // mean + auto mean_mkl_mem = mkldnn::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer()); + args[MKLDNN_ARG_MEAN] = mean_mkl_mem; + + // variance + auto var_mkl_mem = mkldnn::memory(op_bp_prim_desc.variance_desc(), engine, variance->getBuffer()); + args[MKLDNN_ARG_VARIANCE] = var_mkl_mem; + + // dLdI + auto dLdI_user_mem = mkldnn::memory(dLdI_user_md, engine, dLdI->getBuffer()); + const bool dLdIReorder = op_bp_prim_desc.diff_dst_desc() != dLdI_user_mem.get_desc(); + auto dLdI_mkl_mem = dLdIReorder ? mkldnn::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdI_user_mem; + args[MKLDNN_ARG_DIFF_SRC] = dLdI_mkl_mem; + + // gamma and beta (and their gradients) if they are present + if(weights != nullptr) { + + auto w_mkl_mem = mkldnn::memory(op_bp_prim_desc.weights_desc(), engine, weights->getBuffer()); + args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; + + auto dLdW_mkl_mem = mkldnn::memory(op_bp_prim_desc.weights_desc(), engine, dLdW->getBuffer()); + args[MKLDNN_ARG_DIFF_WEIGHTS] = dLdW_mkl_mem; + } + + // run calculations + mkldnn::batch_normalization_backward(op_bp_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (dLdIReorder) + mkldnn::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem); + + stream.wait(); + + // shape::printArray(dLdI_mkl_mem.map_data(),8); +} + +PLATFORM_IMPL(batchnorm) { + + auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw + auto mean = INPUT_VARIABLE(1); // [c] + auto variance = INPUT_VARIABLE(2); // [c] + NDArray* gamma = nullptr; // [c] + NDArray* beta = nullptr; // [c] + + auto output = OUTPUT_VARIABLE(0); // same shape as input + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const double epsilon = T_ARG(0); + + if(applyScale) + gamma = INPUT_VARIABLE(3); + if(applyOffset) + beta = INPUT_VARIABLE(3 + (int)applyScale); + + const int numOfIntArgs = block.getIArguments()->size(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if(numOfIntArgs > 2) + for(int i = 2; i < numOfIntArgs; ++i) + axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank-1); // default dimension to reduce along is last dimension + + const int numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes == 1, 0, "BATCHNORM_MKLDNN op: mkl dnn library supports only one axis which represents channel dimension, but got %i axes instead!", numOfAxes); + REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, "BATCHNORM_MKLDNN op: possible values for rank of input array are 2, 4 or 5, but got %i instead!", inRank); + REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of mean array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of variance array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str()); + if(gamma != nullptr) + REQUIRE_TRUE(gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of gamma array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str()); + if(beta != nullptr) + REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str()); + + // types of all input arrays should be the same (except dLdO) + for(int i = 1; i < block.width() - 1; ++i) + REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_MKLDNN op: types of all input arrays should be the same !"); + + + NDArray *weights = nullptr; + + if(applyScale || applyOffset) { + + weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType()); + + if(applyScale) + (*weights)({0,1, 0,0}).assign(gamma); + else + (*weights)({0,1, 0,0}).assign(1); + if(applyOffset) + (*weights)({1,2, 0,0}).assign(beta); + else + (*weights)({1,2, 0,0}).assign(0); + } + + batchnormMKLDNN(input, mean, variance, weights, epsilon, output); + + delete weights; + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(batchnorm) { + // we don't want to use mkldnn if cpu doesn't support avx/avx2 + // if (::optimalLevel() < 2) + // return false; + + auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw + auto mean = INPUT_VARIABLE(1); // [c] + auto variance = INPUT_VARIABLE(2); // [c] + NDArray* gamma = nullptr; // [c] + NDArray* beta = nullptr; // [c] + + auto output = OUTPUT_VARIABLE(0); // same shape as input + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + + if(applyScale) + gamma = INPUT_VARIABLE(3); + if(applyOffset) + beta = INPUT_VARIABLE(3 + (int)applyScale); + + + const int numOfIntArgs = block.getIArguments()->size(); + std::vector axes; + if(numOfIntArgs > 2) + for(int i = 2; i < numOfIntArgs; ++i) + axes.push_back(INT_ARG(i)); + else + axes.push_back(input->rankOf()-1); // default dimension to reduce along is last dimension + + DataType inputType = input->dataType(); + DataType meanType = mean->dataType(); + DataType varType = variance->dataType(); + DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32; + DataType betaType = beta != nullptr ? beta->dataType() : DataType::FLOAT32; + DataType outType = output->dataType(); + + const int inRank = input->rankOf(); + + return block.isUseMKLDNN() && axes.size() == 1 && axes[0] == 1 && (inRank == 2 || inRank == 4 || inRank == 5) && + (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 && + gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && outType == DataType::FLOAT32); +} + +////////////////////////////////////////////////////////////////////////// +// PLATFORM_IMPL(batchnorm) { + +// auto input = INPUT_VARIABLE(0); +// auto mean = INPUT_VARIABLE(1); +// auto variance = INPUT_VARIABLE(2); +// NDArray *gamma = nullptr; +// NDArray *beta = nullptr; + +// auto output = OUTPUT_VARIABLE(0); + +// const bool applyScale = (bool) INT_ARG(0); +// const bool applyOffset = (bool) INT_ARG(1); +// const double epsilon = T_ARG(0); + +// if (applyScale) +// gamma = INPUT_VARIABLE(3); +// if (applyOffset) +// beta = INPUT_VARIABLE(3 + static_cast(applyScale)); + +// std::vector axes; +// if (block.numI() > 2) +// for (int i = 2; i < block.numI(); ++i) +// axes.push_back(INT_ARG(i)); +// else +// axes.push_back(input->rankOf() - 1); + +// std::vector shape({2, mean->lengthOf()}); +// NDArray weights = NDArrayFactory::create('c', shape, block.launchContext()); +// weights({0, 1, 0, 0}).assign(1.0f); +// weights({1, 2, 0, 0}).assign(0.0f); + +// mkldnn_memory_desc_t empty; +// mkldnn::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(empty), user_dst_md(empty); + +// auto flag = mkldnn::normalization_flags::use_global_stats; +// if (applyScale || applyOffset) +// flag |= mkldnn::normalization_flags::use_scale_shift; + +// mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output, +// &batchnorm_src_md, nullptr, &batchnorm_dst_md, +// &user_src_md, nullptr, &user_dst_md, axes[0]); + +// auto batchnorm_desc = mkldnn::batch_normalization_forward::desc(mkldnn::prop_kind::forward_inference, batchnorm_src_md, epsilon, flag); + +// auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); +// mkldnn::stream stream(engine); +// auto batchnorm_prim_desc = mkldnn::batch_normalization_forward::primitive_desc(batchnorm_desc, engine); +// auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); +// auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); +// auto batchnorm_mean_memory = mkldnn::memory(batchnorm_prim_desc.mean_desc(), engine, +// mean->buffer()); +// auto batchnorm_variance_memory = mkldnn::memory(batchnorm_prim_desc.variance_desc(), engine, +// variance->buffer()); +// auto batchnorm_src_memory = user_src_memory; +// mkldnn::memory m(batchnorm_src_md, engine); +// if (m.get_desc() != user_src_memory.get_desc()) { +// batchnorm_src_memory = mkldnn::memory(batchnorm_src_md, engine); +// mkldnn::reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory, +// batchnorm_src_memory); +// } +// auto batchnorm_dst_memory = user_dst_memory; +// if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) { +// batchnorm_dst_memory = mkldnn::memory(batchnorm_prim_desc.dst_desc(), engine); +// } +// if (applyScale || applyOffset) { +// if (gamma != nullptr) { +// weights({0, 1, 0, 0}).assign(gamma); +// } +// if (beta != nullptr) { +// weights({1, 2, 0, 0}).assign(beta); +// } + +// auto batchnorm_weights_memory = mkldnn::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer()); +// mkldnn::batch_normalization_forward(batchnorm_prim_desc).execute(stream, +// {{MKLDNN_ARG_SRC, batchnorm_src_memory}, +// {MKLDNN_ARG_MEAN, batchnorm_mean_memory}, +// {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory}, +// {MKLDNN_ARG_WEIGHTS, batchnorm_weights_memory}, +// {MKLDNN_ARG_DST, batchnorm_dst_memory}}); +// } else { +// mkldnn::batch_normalization_forward(batchnorm_prim_desc).execute(stream, +// {{MKLDNN_ARG_SRC, batchnorm_src_memory}, +// {MKLDNN_ARG_MEAN, batchnorm_mean_memory}, +// {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory}, +// {MKLDNN_ARG_DST, batchnorm_dst_memory}}); +// } +// if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) { +// mkldnn::reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory, +// user_dst_memory); +// } +// stream.wait(); + +// return Status::OK(); +// } + +////////////////////////////////////////////////////////////////////////// +// PLATFORM_CHECK(batchnorm) { +// // we don't want to use mkldnn if cpu doesn't support avx/avx2 +// if (::optimalLevel() < 2) +// return false; + +// auto input = INPUT_VARIABLE(0); +// auto mean = INPUT_VARIABLE(1); +// auto variance = INPUT_VARIABLE(2); +// NDArray *gamma = nullptr; +// NDArray *beta = nullptr; + +// auto output = OUTPUT_VARIABLE(0); + +// const bool applyScale = (bool) INT_ARG(0); +// const bool applyOffset = (bool) INT_ARG(1); +// const double epsilon = T_ARG(0); + +// if (applyScale) +// gamma = INPUT_VARIABLE(3); +// if (applyOffset) +// beta = INPUT_VARIABLE(3 + static_cast(applyScale)); + +// std::vector axes; +// if (block.numI() > 2) +// for (int i = 2; i < block.numI(); ++i) +// axes.push_back(INT_ARG(i)); +// else +// axes.push_back(input->rankOf() - 1); + +// return block.isUseMKLDNN() && +// nd4j::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, output}) && +// axes.size() == 1; +// } + + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(batchnorm_bp) { + + NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw + NDArray* mean = INPUT_VARIABLE(1); // [c] + NDArray* variance = INPUT_VARIABLE(2); // [c] + NDArray* dLdO = INPUT_VARIABLE(3); // same as input + NDArray* gamma = nullptr; // [c] + NDArray* beta = nullptr; // [c] + + NDArray* dLdI = OUTPUT_VARIABLE(0); // same as input + NDArray* dLdM = OUTPUT_VARIABLE(1); // [c] + NDArray* dLdV = OUTPUT_VARIABLE(2); // [c] + NDArray* dLdG = nullptr; // [c] + NDArray* dLdB = nullptr; // [c] + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const float epsilon = T_ARG(0); + + if(applyScale) { + gamma = INPUT_VARIABLE(4); + dLdG = OUTPUT_VARIABLE(3); + } + if(applyOffset) { + beta = INPUT_VARIABLE(4 + (int)applyScale); + dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); + } + + const int numOfIntArgs = block.getIArguments()->size(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if(numOfIntArgs > 2) + for(int i = 2; i < numOfIntArgs; ++i) + axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank-1); // default dimension to reduce along is last dimension + + const int numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes == 1, 0, "BATCHNORM_BP_MKLDNN op: mkl dnn library supports only one axis which represents channel dimension, but got %i axes instead!", numOfAxes); + REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, "BATCHNORM_BP_MKLDNN op: possible values for rank of input array are 2, 4 or 5, but got %i instead!", inRank); + REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str()); + REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of mean array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of variance array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str()); + if(gamma != nullptr) + REQUIRE_TRUE(gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of gamma array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str()); + if(beta != nullptr) + REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str()); + + // types of all input arrays should be the same (except dLdO) + for(int i = 1; i < block.width() - 1; ++i) + REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP_MKLDNN op: types of all input arrays should be the same !"); + + + NDArray *weights = nullptr, *dLdW = nullptr; + + if(applyScale || applyOffset) { + weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType()); + dLdW = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType()); + if(applyScale) + (*weights)({0,1, 0,0}).assign(gamma); + else + (*weights)({0,1, 0,0}).assign(1); + if(applyOffset) + (*weights)({1,2, 0,0}).assign(beta); + else + (*weights)({1,2, 0,0}).assign(0); + } + + *dLdM = 0; + *dLdV = 0; + + batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, epsilon, dLdI, dLdW); + + if(applyScale || applyOffset) { + if(applyScale) + dLdG->assign((*dLdW)({0,1, 0,0})); + if(applyOffset) + dLdB->assign((*dLdW)({1,2, 0,0})); + + delete weights; + delete dLdW; + } + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(batchnorm_bp) { + // we don't want to use mkldnn if cpu doesn't support avx/avx2 + // if (::optimalLevel() < 2) + // return false; + + NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw + NDArray* mean = INPUT_VARIABLE(1); // [c] + NDArray* variance = INPUT_VARIABLE(2); // [c] + NDArray* dLdO = INPUT_VARIABLE(3); // same as input + NDArray* gamma = nullptr; // [c] + NDArray* beta = nullptr; // [c] + + NDArray* dLdI = OUTPUT_VARIABLE(0); // same as input + NDArray* dLdM = OUTPUT_VARIABLE(1); // [c] + NDArray* dLdV = OUTPUT_VARIABLE(2); // [c] + NDArray* dLdG = nullptr; // [c] + NDArray* dLdB = nullptr; // [c] + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + + if(applyScale) { + gamma = INPUT_VARIABLE(4); + dLdG = OUTPUT_VARIABLE(3); + } + if(applyOffset) { + beta = INPUT_VARIABLE(4 + (int)applyScale); + dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); + } + + const int numOfIntArgs = block.getIArguments()->size(); + std::vector axes; + if(numOfIntArgs > 2) + for(int i = 2; i < numOfIntArgs; ++i) + axes.push_back(INT_ARG(i)); + else + axes.push_back(input->rankOf()-1); // default dimension to reduce along is last dimension + + DataType inputType = input->dataType(); + DataType meanType = mean->dataType(); + DataType varType = variance->dataType(); + DataType dLdOType = dLdO->dataType(); + DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32; + DataType betaType = beta != nullptr ? beta->dataType() : DataType::FLOAT32; + + DataType dLdIType = dLdI->dataType(); + DataType dLdGType = gamma != nullptr ? dLdG->dataType() : DataType::FLOAT32; + DataType dLdBType = beta != nullptr ? dLdB->dataType() : DataType::FLOAT32; + + const int inRank = input->rankOf(); + + return block.isUseMKLDNN() && axes.size() == 1 && axes[0] == 1 && (inRank == 2 || inRank == 4 || inRank == 5) && + (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 && + dLdOType == DataType::FLOAT32 && gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && + dLdIType == DataType::FLOAT32 && dLdGType == DataType::FLOAT32 && dLdBType == DataType::FLOAT32); +} + +} +} } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp new file mode 100644 index 000000000..10b392465 --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp @@ -0,0 +1,544 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include "mkldnnUtils.h" + +using namespace mkldnn; + +namespace nd4j { +namespace ops { +namespace platforms { + +static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* Wr, + const NDArray* b, const NDArray* hI, const NDArray* cI, + const std::vector& params, + NDArray* h, NDArray* hL, NDArray* cL) { + + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // sL - sequence length, number of time steps + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + + // ******* + // input x: + // 1) [sL, bS, nIn] when dataFormat == 0 + + // ******* + // input weights Wx: + // 1) [1, 1, nIn, 4*nOut] when directionMode < 2 + // 2) [1, 2, nIn, 4*nOut] when directionMode >= 2 + + // ******* + // recurrent weights Wr: + // 1) [1, 1, nOut, 4*nOut] when directionMode < 2 + // 2) [1, 2, nOut, 4*nOut] when directionMode >= 2 + + // ******* + // biases b: + // 1) [1, 1, 4*nOut] when directionMode < 2 + // 2) [1, 2, 4*nOut] when directionMode >= 2 + + // ******* + // initial output hI: + // 1) [1, 1, bS, nOut] when directionMode < 2 + // 2) [1, 2, bS, nOut] when directionMode >= 2 + + // ******* + // initial cell state cI (same shape as in hI): + // 1) [1, 1, bS, nOut] when directionMode < 2 + // 2) [1, 2, bS, nOut] when directionMode >= 2 + + + // OUTPUTS: + + // ******* + // output h: + // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 + // 2) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 + + // ******* + // output at last step hL: + // 1) [1, 1, bS, nOut] when directionMode < 2 + // 2) [1, 2, bS, nOut] when directionMode >= 2 + + // ******* + // cell state at last step cL (same shape as in hL): + // 1) [1, 1, bS, nOut] when directionMode < 2 + // 2) [1, 2, bS, nOut] when directionMode >= 2 + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + + // dataFormat: 0 = [sL, bS, nIn] + // directionMode: 0 = forward, 1 = backward, 2 = bidirectional sum, 3 = bidirectional concat + + const int dataFormat = params[0]; + const int directionMode = params[1]; + + const int sL = x->sizeAt(0); // dataFormat == 0 ? x->sizeAt(0) : x->sizeAt(1); + const int bS = x->sizeAt(1); // dataFormat == 0 ? x->sizeAt(1) : x->sizeAt(0); + const int nIn = x->sizeAt(-1); + const int nOut = Wx->sizeAt(-1); + + const int dirDim = directionMode < 2 ? 1 : 2; // number of dimensionss, 1 unidirectional, 2 for bidirectional + const int hDirDim = directionMode <= 2 ? 1 : 2; // for h array, take into account bidirectional_sum mode (directionMode == 2) + + // evaluate direction + rnn_direction direction; + switch (directionMode) { + case 0: + direction = rnn_direction::unidirectional_left2right; + break; + case 1: + direction = rnn_direction::unidirectional_right2left; + break; + case 2: + direction = rnn_direction::bidirectional_sum; + break; + default: + direction = rnn_direction::bidirectional_concat; + } + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + mkldnn::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md, + x_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, hI_lstm_md, cI_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md; + + // input type + mkldnn::memory::data_type xType; + if(x->dataType() == DataType::FLOAT32) + xType = mkldnn::memory::data_type::f32; + else if(x->dataType() == DataType::HALF) + xType = mkldnn::memory::data_type::f16; + else + xType = mkldnn::memory::data_type::u8; + + // weights type + mkldnn::memory::data_type wType = xType; + if(xType == mkldnn::memory::data_type::u8) + wType = mkldnn::memory::data_type::s8; + + // bias type + mkldnn::memory::data_type bType = xType; + if(xType == mkldnn::memory::data_type::u8) + bType = mkldnn::memory::data_type::f32; + + // output type + mkldnn::memory::data_type hType; + if(h->dataType() == DataType::FLOAT32) + hType = mkldnn::memory::data_type::f32; + else if(h->dataType() == DataType::HALF) + hType = mkldnn::memory::data_type::f16; + else + hType = mkldnn::memory::data_type::u8; + + + // memory descriptors for arrays + // x + x_lstm_md = mkldnn::memory::desc({sL, bS, nIn}, xType, mkldnn::memory::format_tag::any); + // x_user_md = dataFormat == 0 ? mkldnn::memory::desc({sL, bS, nIn}, type, mkldnn::memory::format_tag::tnc) : mkldnn::memory::desc({bS, sL, nIn}, type, mkldnn::memory::format_tag::ntc); + x_user_md = mkldnn::memory::desc({sL, bS, nIn}, xType, mkldnn::memory::format_tag::tnc); + x_user_md.data.format_kind = mkldnn_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0]; + x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1]; + x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2]; + + // wx + wx_lstm_md = mkldnn::memory::desc({1,dirDim,nIn,4,nOut}, wType, mkldnn::memory::format_tag::any); + wx_user_md = mkldnn::memory::desc({1,dirDim,nIn,4,nOut}, wType, mkldnn::memory::format_tag::ldigo); + wx_user_md.data.format_kind = mkldnn_blocked; // overrides format + wx_user_md.data.format_desc.blocking.strides[0] = Wx->stridesOf()[0]; + wx_user_md.data.format_desc.blocking.strides[1] = Wx->stridesOf()[1]; + wx_user_md.data.format_desc.blocking.strides[2] = Wx->stridesOf()[2]; + wx_user_md.data.format_desc.blocking.strides[3] = Wx->stridesOf()[3]; + wx_user_md.data.format_desc.blocking.strides[4] = Wx->stridesOf()[4]; + + // wr + wr_lstm_md = mkldnn::memory::desc({1,dirDim,nOut,4,nOut}, wType, mkldnn::memory::format_tag::any); + wr_user_md = mkldnn::memory::desc({1,dirDim,nOut,4,nOut}, wType, mkldnn::memory::format_tag::ldigo); + wr_user_md.data.format_kind = mkldnn_blocked; // overrides format + wr_user_md.data.format_desc.blocking.strides[0] = Wr->stridesOf()[0]; + wr_user_md.data.format_desc.blocking.strides[1] = Wr->stridesOf()[1]; + wr_user_md.data.format_desc.blocking.strides[2] = Wr->stridesOf()[2]; + wr_user_md.data.format_desc.blocking.strides[3] = Wr->stridesOf()[3]; + wr_user_md.data.format_desc.blocking.strides[4] = Wr->stridesOf()[4]; + + // h + h_lstm_md = mkldnn::memory::desc({sL, bS, hDirDim*nOut}, hType, mkldnn::memory::format_tag::any); + // h_user_md = dataFormat == 0 ? mkldnn::memory::desc({sL, bS, hDirDim*nOut}, type, mkldnn::memory::format_tag::tnc) : mkldnn::memory::desc({bS, sL, hDirDim*nOut}, type, mkldnn::memory::format_tag::ntc); + h_user_md = mkldnn::memory::desc({sL, bS, hDirDim*nOut}, hType, mkldnn::memory::format_tag::tnc); + h_user_md.data.format_kind = mkldnn_blocked; // overrides format + h_user_md.data.format_desc.blocking.strides[0] = h->stridesOf()[0]; + h_user_md.data.format_desc.blocking.strides[1] = h->stridesOf()[1]; + h_user_md.data.format_desc.blocking.strides[2] = h->stridesOf()[2]; + + // b + if(b) { + b_lstm_md = mkldnn::memory::desc({1,dirDim,4,nOut}, bType, mkldnn::memory::format_tag::any); + b_user_md = mkldnn::memory::desc({1,dirDim,4,nOut}, bType, mkldnn::memory::format_tag::ldgo); + b_user_md.data.format_kind = mkldnn_blocked; // overrides format + b_user_md.data.format_desc.blocking.strides[0] = b->stridesOf()[0]; + b_user_md.data.format_desc.blocking.strides[1] = b->stridesOf()[1]; + b_user_md.data.format_desc.blocking.strides[2] = b->stridesOf()[2]; + b_user_md.data.format_desc.blocking.strides[3] = b->stridesOf()[3]; + } + + // hI + if(hI) { + hI_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::any); + hI_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::ldnc); + hI_user_md.data.format_kind = mkldnn_blocked; // overrides format + hI_user_md.data.format_desc.blocking.strides[0] = hI->stridesOf()[0]; + hI_user_md.data.format_desc.blocking.strides[1] = hI->stridesOf()[1]; + hI_user_md.data.format_desc.blocking.strides[2] = hI->stridesOf()[2]; + hI_user_md.data.format_desc.blocking.strides[3] = hI->stridesOf()[3]; + } + + // cI + if(cI) { + cI_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::any); + cI_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::ldnc); + cI_user_md.data.format_kind = mkldnn_blocked; // overrides format + cI_user_md.data.format_desc.blocking.strides[0] = cI->stridesOf()[0]; + cI_user_md.data.format_desc.blocking.strides[1] = cI->stridesOf()[1]; + cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[2]; + cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[3]; + } + + // hL + if(hL) { + hL_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::any); + hL_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc); + hL_user_md.data.format_kind = mkldnn_blocked; // overrides format + hL_user_md.data.format_desc.blocking.strides[0] = hL->stridesOf()[0]; + hL_user_md.data.format_desc.blocking.strides[1] = hL->stridesOf()[1]; + hL_user_md.data.format_desc.blocking.strides[2] = hL->stridesOf()[2]; + hL_user_md.data.format_desc.blocking.strides[3] = hL->stridesOf()[3]; + } + + if(cL) { + cL_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc); + cL_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc); + cL_user_md.data.format_kind = mkldnn_blocked; // overrides format + cL_user_md.data.format_desc.blocking.strides[0] = cL->stridesOf()[0]; + cL_user_md.data.format_desc.blocking.strides[1] = cL->stridesOf()[1]; + cL_user_md.data.format_desc.blocking.strides[2] = cL->stridesOf()[2]; + cL_user_md.data.format_desc.blocking.strides[3] = cL->stridesOf()[3]; + } + + // lstm memory description + lstm_forward::desc lstm_desc(prop_kind::forward_inference, direction, + x_lstm_md, hI_lstm_md, cI_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, + h_lstm_md, hL_lstm_md, cL_lstm_md); + + mkldnn::stream stream(engine); + + // lstm primitive description + lstm_forward::primitive_desc lstm_prim_desc(lstm_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + // provide memory and check whether reorder is required + // x + auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer()); + const bool xReorder = lstm_prim_desc.src_layer_desc() != x_user_mem.get_desc(); + auto x_lstm_mem = xReorder ? mkldnn::memory(lstm_prim_desc.src_layer_desc(), engine) : x_user_mem; + if (xReorder) + reorder(x_user_mem, x_lstm_mem).execute(stream, x_user_mem, x_lstm_mem); + args[MKLDNN_ARG_SRC_LAYER] = x_lstm_mem; + + // wx + auto wx_user_mem = mkldnn::memory(wx_user_md, engine, Wx->getBuffer()); + const bool wxReorder = lstm_prim_desc.weights_layer_desc()!= wx_user_mem.get_desc(); + auto wx_lstm_mem = wxReorder ? mkldnn::memory(lstm_prim_desc.weights_layer_desc(), engine) : wx_user_mem; + if (wxReorder) + reorder(wx_user_mem, wx_lstm_mem).execute(stream, wx_user_mem, wx_lstm_mem); + args[MKLDNN_ARG_WEIGHTS_LAYER] = wx_lstm_mem; + + // wr + auto wr_user_mem = mkldnn::memory(wr_user_md, engine, Wr->getBuffer()); + const bool wrReorder = lstm_prim_desc.weights_iter_desc() != wr_user_mem.get_desc(); + auto wr_lstm_mem = wxReorder ? mkldnn::memory(lstm_prim_desc.weights_iter_desc(), engine) : wr_user_mem; + if (wrReorder) + reorder(wr_user_mem, wr_lstm_mem).execute(stream, wr_user_mem, wr_lstm_mem); + args[MKLDNN_ARG_WEIGHTS_ITER] = wr_lstm_mem; + + // h + auto h_user_mem = mkldnn::memory(h_user_md, engine, h->getBuffer()); + const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc(); + auto h_lstm_mem = hReorder ? mkldnn::memory(lstm_prim_desc.dst_layer_desc(), engine) : h_user_mem; + args[MKLDNN_ARG_DST_LAYER] = h_lstm_mem; + + // b + if(b) { + auto b_user_mem = mkldnn::memory(b_user_md, engine, b->getBuffer()); + const bool bReorder = lstm_prim_desc.bias_desc() != b_user_mem.get_desc(); + auto b_lstm_mem = bReorder ? mkldnn::memory(lstm_prim_desc.bias_desc(), engine) : b_user_mem; + if (bReorder) + reorder(b_user_mem, b_lstm_mem).execute(stream, b_user_mem, b_lstm_mem); + args[MKLDNN_ARG_BIAS] = b_lstm_mem; + } + + // hI + if(hI) { + auto hI_user_mem = mkldnn::memory(hI_user_md, engine, hI->getBuffer()); + const bool hIReorder = lstm_prim_desc.src_iter_desc() != hI_user_mem.get_desc(); + auto hI_lstm_mem = hIReorder ? mkldnn::memory(lstm_prim_desc.src_iter_desc(), engine) : hI_user_mem; + if (hIReorder) + reorder(hI_user_mem, hI_lstm_mem).execute(stream, hI_user_mem, hI_lstm_mem); + args[MKLDNN_ARG_SRC_ITER] = hI_lstm_mem; + } + + // cI + if(cI) { + auto cI_user_mem = mkldnn::memory(cI_user_md, engine, cI->getBuffer()); + const bool cIReorder = lstm_prim_desc.src_iter_c_desc() != cI_user_mem.get_desc(); + auto cI_lstm_mem = cIReorder ? mkldnn::memory(lstm_prim_desc.src_iter_c_desc(), engine) : cI_user_mem; + if (cIReorder) + reorder(cI_user_mem, cI_lstm_mem).execute(stream, cI_user_mem, cI_lstm_mem); + args[MKLDNN_ARG_SRC_ITER_C] = cI_lstm_mem; + } + + bool hLReorder(false), cLReorder(false); + mkldnn::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem; + + // hL + if(hL) { + hL_user_mem = mkldnn::memory(hL_user_md, engine, hL->getBuffer()); + hLReorder = lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc(); + hL_lstm_mem = hLReorder ? mkldnn::memory(lstm_prim_desc.dst_iter_desc(), engine) : hL_user_mem; + args[MKLDNN_ARG_DST_ITER] = hL_lstm_mem; + } + + // cL + if(cL) { + cL_user_mem = mkldnn::memory(cL_user_md, engine, cL->getBuffer()); + cLReorder = lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc(); + cL_lstm_mem = cLReorder ? mkldnn::memory(lstm_prim_desc.dst_iter_c_desc(), engine) : cL_user_mem; + args[MKLDNN_ARG_DST_ITER_C] = cL_lstm_mem; + } + + // run calculations + lstm_forward(lstm_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (hReorder) + reorder(h_lstm_mem, h_user_mem).execute(stream, h_lstm_mem, h_user_mem); + if(hLReorder) + reorder(hL_lstm_mem, hL_user_mem).execute(stream, hL_lstm_mem, hL_user_mem); + if(cLReorder) + reorder(cL_lstm_mem, cL_user_mem).execute(stream, cL_lstm_mem, cL_user_mem); + + stream.wait(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(lstmLayer) { + + const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) + const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = B_ARG(4); // indicates whether peephole connections are present + const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} + const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + + const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + int count = 3; + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector + const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + + REQUIRE_TRUE(cellClip == 0 , 0, "LSTM_LAYER_MKLDNN operation: cell clipping is not supported currently !"); + REQUIRE_TRUE(retFullSeq, 0, "LSTM_LAYER_MKLDNN operation: option to calculate full time sequence output h should be always true in case of mkl dnn library !"); + REQUIRE_TRUE(hasPH == false , 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support peephole connections !"); + REQUIRE_TRUE(hasSeqLen == false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support array specifying max time step per each example in batch !"); + REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC!"); + REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !"); + REQUIRE_TRUE((retLastH && retLastC) || (!retLastH && !retLastC), 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !"); + + count = 0; + auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output + auto hL = retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step + auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step + + // evaluate dimensions + const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2); + const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + if(directionMode < 2) { // no bidirectional + + // Wx validation + if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx)); + // Wr validation + if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr)); + // biases validation + if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b)); + // initial output validation + if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI)); + // initial cell validation + if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI)); + } + else { // bidirectional + // Wx validation + if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx)); + // Wr validation + if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr)); + // biases validation + if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b)); + // initial output validation + if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI)); + // initial cell validation + if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI)); + } + + std::vector params = {static_cast(dataFormat), static_cast(directionMode), static_cast(cellClip)}; + + const int dirDim = directionMode < 2 ? 1 : 2; // number of dimensions, 1 unidirectional, 2 for bidirectional + + // permut x and h to tnc format if they have ntc format + NDArray* xP(const_cast(x)), *hP(h); + if(dataFormat == 1) { + xP = new NDArray(x->permute({1,0,2})); // [bS, sL, nIn] -> [sL, bS, nIn] + hP = new NDArray(h->permute({1,0,2})); // [bS, sL, dirDim*nOn] -> [sL, bS, dirDim*nOn] + } + + // reshape arrays in accordance to mkl allowed formats + NDArray *WxR(nullptr), *WrR(nullptr), *bR(nullptr), *hIR(nullptr), *cIR(nullptr), *hLR(nullptr), *cLR(nullptr); + + WxR = new NDArray(Wx->reshape(Wx->ordering(), {1,dirDim,nIn,4,nOut})); + WrR = new NDArray(Wr->reshape(Wr->ordering(), {1,dirDim,nOut,4,nOut})); + if(b) + bR = new NDArray(b->reshape(b->ordering(), {1,dirDim,4,nOut})); + if(hI) + hIR = new NDArray(hI->reshape(hI->ordering(), {1,dirDim,bS,nOut})); + if(cI) + cIR = new NDArray(cI->reshape(cI->ordering(), {1,dirDim,bS,nOut})); + if(hL) + hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut})); + if(cL) + cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut})); + + lstmLayerMKLDNN(xP, WxR, WrR, bR, hIR, cIR, params, hP, hLR, cLR); + + delete WxR; + delete WrR; + delete bR; + delete hIR; + delete cIR; + delete hLR; + delete cLR; + + if(dataFormat == 1) { + delete xP; + delete hP; + } + + return Status::OK(); +} + +PLATFORM_CHECK(lstmLayer) { + // we don't want to use mkldnn if cpu doesn't support avx/avx2 + // if (::optimalLevel() < 2) { + // return false; + // } + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasInitH = B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided + const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} + const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + int count = 3; + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + + count = 0; + auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output + auto hL = retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step + auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step + + DataType xType = x->dataType(); + DataType WxType = Wx->dataType(); + DataType WrType = Wr->dataType(); + DataType bType = b != nullptr ? b->dataType() : (xType == DataType::HALF ? xType : DataType::FLOAT32); + DataType hIType = hI != nullptr ? hI->dataType() : xType; + DataType cIType = cI != nullptr ? hI->dataType() : xType; + DataType hType = h != nullptr ? h->dataType() : xType; + DataType hLType = hL != nullptr ? hL->dataType() : xType; + DataType cLType = cL != nullptr ? cL->dataType() : xType; + + return block.isUseMKLDNN() && ( + (xType==DataType::FLOAT32 && WxType==DataType::FLOAT32 && WrType==DataType::FLOAT32 && bType==DataType::FLOAT32 && hIType==DataType::FLOAT32 && cIType==DataType::FLOAT32 && hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) || + (xType==DataType::HALF && WxType==DataType::HALF && WrType==DataType::HALF && bType==DataType::HALF && hIType==DataType::HALF && cIType==DataType::HALF && hType==DataType::HALF && hLType==DataType::HALF && cLType==DataType::HALF) || + (xType==DataType::UINT8 && WxType==DataType::INT8 && WrType==DataType::INT8 && bType==DataType::FLOAT32 && hIType==DataType::UINT8 && cIType==DataType::UINT8 && (hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32 || hType==DataType::UINT8 && hLType==DataType::UINT8 && cLType==DataType::UINT8)) + ); +} + + + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp index 4fac4a1b7..b84506c3b 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp @@ -305,50 +305,50 @@ namespace nd4j { }; - void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md, - mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) { - const Nd4jLong* shape = src->getShapeInfo(); - Nd4jLong rank = shape[0]; - Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one - Nd4jLong dim2 = axis >= 2 ? 1 : 2; - Nd4jLong dim3 = axis >= 3 ? 2 : 3; - mkldnn::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; + // void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, + // mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md, + // mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) { + // const Nd4jLong* shape = src->getShapeInfo(); + // Nd4jLong rank = shape[0]; + // Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one + // Nd4jLong dim2 = axis >= 2 ? 1 : 2; + // Nd4jLong dim3 = axis >= 3 ? 2 : 3; + // mkldnn::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; - auto type = mkldnn::memory::data_type::f32; - auto format = mkldnn::memory::format_tag::nchw; - auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any" + // auto type = mkldnn::memory::data_type::f32; + // auto format = mkldnn::memory::format_tag::nchw; + // auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any" - if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) { - *batchnorm_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - *user_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); - user_src_md->data.format_kind = mkldnn_blocked; // overrides format - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; - user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; - user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; - } + // if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) { + // *batchnorm_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); + // *user_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); + // user_src_md->data.format_kind = mkldnn_blocked; // overrides format + // user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; + // user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; + // user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; + // user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; + // } - if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) { - *batchnorm_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); - user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides format - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; - user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; - user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1; - } + // if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) { + // *batchnorm_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); + // *user_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); + // user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides format + // user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; + // user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; + // user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; + // user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1; + // } - if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) { - *batchnorm_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - *user_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); - user_dst_md->data.format_kind = mkldnn_blocked; // overrides format - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; - user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; - user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1; - } - }; + // if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) { + // *batchnorm_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); + // *user_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); + // user_dst_md->data.format_kind = mkldnn_blocked; // overrides format + // user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; + // user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; + // user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; + // user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1; + // } + // }; void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h index 4e79974a5..14cc41a96 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h @@ -62,7 +62,11 @@ namespace nd4j{ DECLARE_PLATFORM(lrn); - DECLARE_PLATFORM(batchnorm_new); + DECLARE_PLATFORM(batchnorm); + + DECLARE_PLATFORM(batchnorm_bp); + + DECLARE_PLATFORM(lstmLayer); } } diff --git a/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp b/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp index ca408e8dc..0e9c99636 100644 --- a/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp +++ b/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp @@ -48,4 +48,11 @@ namespace nd4j { BroadcastOpsTuple BroadcastOpsTuple::Subtract() { return custom(nd4j::scalar::Subtract, nd4j::pairwise::Subtract, nd4j::broadcast::Subtract); } + BroadcastOpsTuple BroadcastOpsTuple::IGamma() { + return custom(nd4j::scalar::IGamma, nd4j::pairwise::IGamma, nd4j::broadcast::IGamma); + } + BroadcastOpsTuple BroadcastOpsTuple::IGammac() { + return custom(nd4j::scalar::IGammac, nd4j::pairwise::IGammac, nd4j::broadcast::IGammac); + } + } diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index a738f0bdc..601481b21 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -1482,6 +1482,52 @@ namespace simdOps { }; + template + class IGamma { + public: + no_op_exec_special + no_op_exec_special_cuda + + op_def static Z op(X d1, Z *params) { + return nd4j::math::nd4j_igamma(d1, params[0]); + } + + op_def static Z op(X d1, Y d2) { + return nd4j::math::nd4j_igamma(d1, d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + return nd4j::math::nd4j_igamma(d1, d2); + } + + op_def static Z op(X d1) { + return d1; + } + }; + + template + class IGammac { + public: + no_op_exec_special + no_op_exec_special_cuda + + op_def static Z op(X d1, Z *params) { + return nd4j::math::nd4j_igammac(d1, params[0]); + } + + op_def static Z op(X d1, Y d2) { + return nd4j::math::nd4j_igammac(d1, d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + return nd4j::math::nd4j_igammac(d1, d2); + } + + op_def static Z op(X d1) { + return d1; + } + }; + template class Round { public: @@ -1811,6 +1857,17 @@ namespace simdOps { } }; + template + class Affine { + public: + no_op_exec_special_same + no_op_exec_special_same_cuda + + op_def static X op(X d1, X *params) { + return params[0] * d1 + params[1]; + } + }; + template class SigmoidDerivative { public: @@ -2005,6 +2062,17 @@ namespace simdOps { } }; + template + class ScaledTanh { + public: + no_op_exec_special_same + no_op_exec_special_same_cuda + + op_def static X op(X d1, X *params) { + return params[0] * nd4j::math::nd4j_tanh(params[1] * d1); + } + }; + template class RectifiedTanh { public: diff --git a/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp b/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp index 2c6de814a..d35346e2b 100644 --- a/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp +++ b/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp @@ -413,7 +413,7 @@ namespace nd4j { return ctx; }; - nd4j::ops::batchnorm_new batchnorm; + nd4j::ops::batchnorm batchnorm; DeclarableBenchmark benchmark(batchnorm, "batchnorm"); output += helper.runOperationSuit(&benchmark, generator, batch, "Batch Normalization"); @@ -1822,7 +1822,7 @@ namespace nd4j { std::string result; long start = nowMs(); - + // set 1 nd4j_printf("Running FullBenchmarkSuite.fastScalarBenchmark\n", ""); result += fastScalarBenchmark(); diff --git a/libnd4j/include/templatemath.h b/libnd4j/include/templatemath.h index 908323369..d0af6c8ed 100644 --- a/libnd4j/include/templatemath.h +++ b/libnd4j/include/templatemath.h @@ -223,6 +223,12 @@ namespace nd4j { return nd4j_sgn(val); } + template + math_def inline Z nd4j_gamma(X a); + + template + math_def inline Z nd4j_lgamma(X x); + //#ifndef __CUDACC__ /* template<> @@ -656,9 +662,56 @@ namespace nd4j { return p_pow(static_cast(val), static_cast(val2)); } + /** + * LogGamma(a) - float point extension of ln(n!) + **/ + template + math_def inline Z nd4j_lgamma(X x) { +// if (x <= X(0.0)) +// { +// std::stringstream os; +// os << "Logarithm of Gamma has sence only for positive values, but " << x << " was given."; +// throw std::invalid_argument( os.str() ); +// } + + if (x < X(12.0)) { + return nd4j_log(nd4j_gamma(x)); + } + + // Abramowitz and Stegun 6.1.41 + // Asymptotic series should be good to at least 11 or 12 figures + // For error analysis, see Whittiker and Watson + // A Course in Modern Analysis (1927), page 252 + + static const double c[8] = { + 1.0/12.0, + -1.0/360.0, + 1.0/1260.0, + -1.0/1680.0, + 1.0/1188.0, + -691.0/360360.0, + 1.0/156.0, + -3617.0/122400.0 + }; + + double z = Z(1.0 / Z(x * x)); + double sum = c[7]; + + for (int i = 6; i >= 0; i--) { + sum *= z; + sum += c[i]; + } + + double series = sum / Z(x); + + static const double halfLogTwoPi = 0.91893853320467274178032973640562; + + return Z((double(x) - 0.5) * nd4j_log(x) - double(x) + halfLogTwoPi + series); + } - template + + template math_def inline T nd4j_re(T val1, T val2) { if (val1 == (T) 0.0f && val2 == (T) 0.0f) return (T) 0.0f; @@ -731,7 +784,127 @@ namespace nd4j { template math_def inline void nd4j_swap(T &val1, T &val2) { T temp = val1; val1=val2; val2=temp; - }; + }; + + template + math_def inline Z nd4j_gamma(X a) { +// nd4j_lgamma(a); +// return (Z)std::tgamma(a); + // Split the function domain into three intervals: + // (0, 0.001), [0.001, 12), and (12, infinity) + + /////////////////////////////////////////////////////////////////////////// + // First interval: (0, 0.001) + // + // For small a, 1/Gamma(a) has power series a + gamma a^2 - ... + // So in this range, 1/Gamma(a) = a + gamma a^2 with error on the order of a^3. + // The relative error over this interval is less than 6e-7. + + const double eulerGamma = 0.577215664901532860606512090; // Euler's gamma constant + + if (a < X(0.001)) + return Z(1.0 / ((double)a * (1.0 + eulerGamma * (double)a))); + + /////////////////////////////////////////////////////////////////////////// + // Second interval: [0.001, 12) + + if (a < X(12.0)) { + // The algorithm directly approximates gamma over (1,2) and uses + // reduction identities to reduce other arguments to this interval. + + double y = (double)a; + int n = 0; + bool argWasLessThanOne = y < 1.0; + + // Add or subtract integers as necessary to bring y into (1,2) + // Will correct for this below + if (argWasLessThanOne) { + y += 1.0; + } + else { + n = static_cast(floor(y)) - 1; // will use n later + y -= n; + } + + // numerator coefficients for approximation over the interval (1,2) + static const double p[] = { + -1.71618513886549492533811E+0, + 2.47656508055759199108314E+1, + -3.79804256470945635097577E+2, + 6.29331155312818442661052E+2, + 8.66966202790413211295064E+2, + -3.14512729688483675254357E+4, + -3.61444134186911729807069E+4, + 6.64561438202405440627855E+4 + }; + + // denominator coefficients for approximation over the interval (1,2) + static const double q[] = { + -3.08402300119738975254353E+1, + 3.15350626979604161529144E+2, + -1.01515636749021914166146E+3, + -3.10777167157231109440444E+3, + 2.25381184209801510330112E+4, + 4.75584627752788110767815E+3, + -1.34659959864969306392456E+5, + -1.15132259675553483497211E+5 + }; + + double num = 0.0; + double den = 1.0; + + + double z = y - 1; + for (auto i = 0; i < 8; i++) { + num = (num + p[i]) * z; + den = den * z + q[i]; + } + double result = num / den + 1.0; + + // Apply correction if argument was not initially in (1,2) + if (argWasLessThanOne) { + // Use identity gamma(z) = gamma(z+1)/z + // The variable "result" now holds gamma of the original y + 1 + // Thus we use y-1 to get back the orginal y. + result /= (y - 1.0); + } + else { + // Use the identity gamma(z+n) = z*(z+1)* ... *(z+n-1)*gamma(z) + for (auto i = 0; i < n; i++) + result *= y++; + } + + return Z(result); + } + + /////////////////////////////////////////////////////////////////////////// + // Third interval: [12, infinity) + + if (a > 171.624) { + // Correct answer too large to display. Force +infinity. + return Z(DOUBLE_MAX_VALUE); + //DataTypeUtils::infOrMax(); + } + + return nd4j::math::nd4j_exp(nd4j::math::nd4j_lgamma(a)); + } + + template + math_def inline Z nd4j_igamma(X a, Y x) { + Z aim = nd4j_pow(x, a) / (nd4j_exp(x) * nd4j_gamma(a)); + auto sum = Z(0.); + auto denom = Z(1.); + for (int i = 0; Z(1./denom) > Z(1.0e-12); i++) { + denom *= (a + i); + sum += nd4j_pow(x, i) / denom; + } + return aim * sum; + } + + template + math_def inline Z nd4j_igammac(X a, Y x) { + return Z(1.) - nd4j_igamma(a, x); + } #ifdef __CUDACC__ namespace atomics { @@ -1473,4 +1646,4 @@ inline __device__ bfloat16 nd4j_atomicDiv(bfloat16* address, bfloat16 } -#endif /* TEMPLATEMATH_H_ */ \ No newline at end of file +#endif /* TEMPLATEMATH_H_ */ diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index b6f5f125d..458858c57 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -2385,129 +2385,6 @@ TEST_F(DeclarableOpsTests1, CompactLaunchTests2) { ASSERT_TRUE(exp.equalsTo(&z)); } -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests1, batchnorm_test1) { - - auto input = NDArrayFactory::create('c', {2,3,2,3,2}); - auto mean = NDArrayFactory::create('c', {2,3,2,3,2}); - auto variance = NDArrayFactory::create('c', {2,3,2,3,2}); - auto gamma = NDArrayFactory::create('c', {2,3,2,3,2}); - auto beta = NDArrayFactory::create('c', {2,3,2,3,2}); - - auto expected = NDArrayFactory::create('c', {2,3,2,3,2}, {-0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, 0.49088821, 0.66059214, 0.83029607, 1., 1.16970393, 1.33940786, 1.50911179, 1.67881572, 1.84851965, 2.01822358, 2.18792751, 2.35763144, 2.52733537, 2.6970393 , 2.86674323, 3.03644717, 3.2061511 , 3.37585503, 3.54555896, 3.71526289, 3.88496682, 4.05467075, 4.22437468, 4.39407861, 4.56378254, 4.73348647, 4.9031904 , 5.07289433, 5.24259826, 5.41230219, 5.58200612, 5.75171005, 5.92141398, 6.09111791, 6.26082184, 6.43052577, 6.6002297 , 6.76993364, 6.93963757, 7.1093415 , 7.27904543, 7.44874936, 7.61845329, 7.78815722, 7.95786115, 8.12756508, 8.29726901, 8.46697294, 8.63667687, 8.8063808 , 8.97608473, 9.14578866, 9.31549259, 9.48519652, 9.65490045, 9.82460438, 9.99430831,10.16401224,10.33371617,10.50342011,10.67312404,10.84282797,11.0125319 ,11.18223583,11.35193976,11.52164369}); - - input.linspace(0.1, 0.1); - mean.assign(1.); - variance.assign(0.5); - gamma.assign(1.2); - beta.assign(1.); - - nd4j::ops::batchnorm op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto output = results->at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - delete results; -} - - -TEST_F(DeclarableOpsTests1, batchnorm_test2) { - - auto input = NDArrayFactory::create('c', {2,3,1,3,1}); - auto mean = NDArrayFactory::create('c', {1,3,2,1,2}); - auto variance = NDArrayFactory::create('c', {2,1,2,3,2}); - auto gamma = NDArrayFactory::create('c', {2,3,2,3,1}); - auto beta = NDArrayFactory::create('c', {1,3,2,1,2}); - - auto expected = NDArrayFactory::create('c', {2,3,2,3,2}, {-0.52733537,-0.52733537,-0.35763144,-0.35763144,-0.18792751,-0.18792751, -0.52733537,-0.52733537,-0.35763144,-0.35763144,-0.18792751,-0.18792751, -0.01822358,-0.01822358, 0.15148035, 0.15148035, 0.32118428, 0.32118428, -0.01822358,-0.01822358, 0.15148035, 0.15148035, 0.32118428, 0.32118428, 0.49088821, 0.49088821, 0.66059214, 0.66059214, 0.83029607, 0.83029607, 0.49088821, 0.49088821, 0.66059214, 0.66059214, 0.83029607, 0.83029607, 1. , 1. , 1.16970393, 1.16970393, 1.33940786, 1.33940786, 1. , 1. , 1.16970393, 1.16970393, 1.33940786, 1.33940786, 1.50911179, 1.50911179, 1.67881572, 1.67881572, 1.84851965, 1.84851965, 1.50911179, 1.50911179, 1.67881572, 1.67881572, 1.84851965, 1.84851965, 2.01822358, 2.01822358, 2.18792751, 2.18792751, 2.35763144, 2.35763144, 2.01822358, 2.01822358, 2.18792751, 2.18792751, 2.35763144, 2.35763144}); - - input.linspace(0.1, 0.1); - mean.assign(1.); - variance.assign(0.5); - gamma.assign(1.2); - beta.assign(1.); - - nd4j::ops::batchnorm op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto output = results->at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests1, batchnorm_test3) { - - auto input = NDArrayFactory::create('c', {2,3,2,3,2}); - auto mean = NDArrayFactory::create('c', {2,3,2}); - auto variance = NDArrayFactory::create('c', {2,3,1,3,1}); - auto gamma = NDArrayFactory::create('c', {1,1}); - auto beta = NDArrayFactory::create('c', {1,2}); - - auto expected = NDArrayFactory::create('c', {2,3,2,3,2}, {-0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, 0.49088821, 0.66059214, 0.83029607, 1., 1.16970393, 1.33940786, 1.50911179, 1.67881572, 1.84851965, 2.01822358, 2.18792751, 2.35763144, 2.52733537, 2.6970393 , 2.86674323, 3.03644717, 3.2061511 , 3.37585503, 3.54555896, 3.71526289, 3.88496682, 4.05467075, 4.22437468, 4.39407861, 4.56378254, 4.73348647, 4.9031904 , 5.07289433, 5.24259826, 5.41230219, 5.58200612, 5.75171005, 5.92141398, 6.09111791, 6.26082184, 6.43052577, 6.6002297 , 6.76993364, 6.93963757, 7.1093415 , 7.27904543, 7.44874936, 7.61845329, 7.78815722, 7.95786115, 8.12756508, 8.29726901, 8.46697294, 8.63667687, 8.8063808 , 8.97608473, 9.14578866, 9.31549259, 9.48519652, 9.65490045, 9.82460438, 9.99430831,10.16401224,10.33371617,10.50342011, 10.67312404,10.84282797,11.0125319 ,11.18223583,11.35193976,11.52164369}); - - input.linspace(0.1, 0.1); - mean.assign(1.); - variance.assign(0.5); - gamma.assign(1.2); - beta.assign(1.); - - nd4j::ops::batchnorm op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto output = results->at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests1, batchnorm_test4) { - - auto input = NDArrayFactory::create('c', {3,2}); - auto mean = NDArrayFactory::create('c', {2,3,2}); - auto variance= NDArrayFactory::create('c', {2,3,1,3,2}); - auto gamma = NDArrayFactory::create('c', {1,1}); - auto beta = NDArrayFactory::create('c', {1,2}); - - auto expected= NDArrayFactory::create('c', {2,3,2,3,2}, {-0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428}); - - input.linspace(0.1, 0.1); - mean.assign(1.); - variance.assign(0.5); - gamma.assign(1.2); - beta.assign(1.); - - nd4j::ops::batchnorm op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto output = results->at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - delete results; -} //////////////////////////////////////////////////////////////////// // TEST_F(DeclarableOpsTests1, sru_old_test1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 0652a398e..84e3b4e8f 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -537,6 +537,50 @@ TEST_F(DeclarableOpsTests10, atan2_test6) { delete result; } +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, IGamma_Test1) { + + auto y = NDArrayFactory::create('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 ,7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1}); + auto x = NDArrayFactory::create('c', { 4}, {1.2, 2.2, 3.2, 4.2}); + + auto exp = NDArrayFactory::create('c', {1,3,4}, { + 0.659917, 0.61757898, 0.59726304, 0.58478117, + 0.0066205109, 0.022211598, 0.040677428, 0.059117373, + 0.0000039433403, 0.000086064574, 0.000436067, 0.0012273735}); + + nd4j::ops::igamma op; + auto result = op.execute({&y, &x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + auto z = result->at(0); +// z->printBuffer("OUtput"); +// exp.printBuffer("EXpect"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, IGamma_Test2) { + + auto y = NDArrayFactory::create('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 , + 7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1}); + auto x = NDArrayFactory::create('c', { 4}, {1.2, 2.2, 3.2, 4.2}); + auto exp = NDArrayFactory::create('c', {1,3,4}, {0.340083, 0.382421, 0.402737, 0.415221, + 0.993379, 0.977788, 0.959323, 0.940883, + 0.999996, 0.999914, 0.999564, 0.998773}); + + nd4j::ops::igammac op; + auto result = op.execute({&y, &x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + auto z = result->at(0); +// z->printBuffer("OUtput"); +// exp.printBuffer("EXpect"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, range_test10) { @@ -1916,7 +1960,82 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) { ASSERT_EQ(ND4J_STATUS_OK, results->status()); NDArray* result = results->at(0); - result->printBuffer("NonMaxSuppression OUtput2"); +// result->printBuffer("NonMaxSuppression OUtput2"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) { + + NDArray boxes = NDArrayFactory::create('c', {4,4}, { + 0, 0, 1, 1, + 0, 0.1, 1, 1.1, + 0, -0.1, 1, 0.9, + 0, 10, 1, 11}); + NDArray scores = NDArrayFactory::create('c', {4}, {0.9, .75, .6, .95}); //3 + NDArray max_num = NDArrayFactory::create(3); + NDArray expected = NDArrayFactory::create('c', {1,}, {3}); + + nd4j::ops::non_max_suppression_overlaps op; + auto results = op.execute({&boxes, &scores, &max_num}, {0.5, 0.}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); +// result->printBuffer("NonMaxSuppressionOverlap1 Output"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_2) { + + NDArray boxes = NDArrayFactory::create('c', {4,4}, { + 0, 0, 1, 1, + 0, 0.1, 1, 1.1, + 0, -0.1, 1, 0.9, + 0, 10, 1, 11}); + NDArray scores = NDArrayFactory::create('c', {4}, {0.9, .95, .6, .75}); //3 + NDArray max_num = NDArrayFactory::create(3); + NDArray expected = NDArrayFactory::create('c', {3,}, {1,1,1}); + + nd4j::ops::non_max_suppression_overlaps op; + auto results = op.execute({&boxes, &scores, &max_num}, {0.5, 0.}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); +// result->printBuffer("NonMaxSuppressionOverlap Output"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_3) { + + NDArray boxes = NDArrayFactory::create('c', {4,4}, { + 0, 0, 1, 1, + 0, 0.1, 1, 1.1, + 0, -0.1, 1, 0.9, + 0, 10, 1, 11}); + NDArray scores = NDArrayFactory::create('c', {4}, {0.5, .95, -.6, .75}); //3 + NDArray max_num = NDArrayFactory::create(5); + NDArray expected = NDArrayFactory::create('c', {5,}, {1,1,1,1,1}); + + nd4j::ops::non_max_suppression_overlaps op; + auto results = op.execute({&boxes, &scores, &max_num}, {0.5, 0.}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); +// result->printBuffer("NonMaxSuppressionOverlap Output"); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1940,7 +2059,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) { ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); - result->printIndexedBuffer("Cropped and Resized"); +// result->printIndexedBuffer("Cropped and Resized"); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2269,7 +2388,35 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) { } //////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) { +TEST_F(DeclarableOpsTests10, batchnorm_test1) { + + NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32); + + NDArray expected('c', {2,4}, {11.61218734, 18.52390321, -8.67185076, -21.28716864, 10.93337162, 19.14541765, -9.26213931, -20.71509369}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + + nd4j::ops::batchnorm op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto output = results->at(0); + // output->printBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test2) { auto input = NDArrayFactory::create('c', {2,3,4}); auto mean = NDArrayFactory::create('c', {4}); @@ -2286,7 +2433,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) { gamma.assign(1.2); beta.assign(1.); - nd4j::ops::batchnorm_new op; + nd4j::ops::batchnorm op; auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); @@ -2302,7 +2449,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) { } //////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test2) { +TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test3) { auto input = NDArrayFactory::create('c', {2,3,4}); auto mean = NDArrayFactory::create('c', {3}, {1.05, 1.1, 1.15}); @@ -2315,7 +2462,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test2) { input.linspace(0.1, 0.1); - nd4j::ops::batchnorm_new op; + nd4j::ops::batchnorm op; auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1}); @@ -2330,7 +2477,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test2) { } //////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test3) { +TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test4) { auto input = NDArrayFactory::create('c', {2,3,4}); auto mean = NDArrayFactory::create('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}); @@ -2343,7 +2490,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test3) { input.linspace(0.1, 0.1); - nd4j::ops::batchnorm_new op; + nd4j::ops::batchnorm op; auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,0,2}); @@ -2357,6 +2504,63 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test3) { delete results; } +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, batchnorm_test5) { + + NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32); + + NDArray expected('c', {2,4,2,2}, {11.612187, 11.442483, 11.272779, 11.103076, 18.990039, 19.145418, 19.300796, 19.456175, -9.557284, -9.704856, -9.852428, -10., -20., + -19.856981, -19.713963, -19.570944, 8.896924, 8.727221, 8.557517, 8.387813, 21.476097, 21.631475, 21.786854, 21.942233, -11.918438, + -12.06601 , -12.213582, -12.361154, -17.7117, -17.568681, -17.425663, -17.282644}, nd4j::DataType::FLOAT32); + input.linspace(0.1, 0.1); + + nd4j::ops::batchnorm op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto output = results->at(0); + // output->printBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, batchnorm_test6) { + + NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32); + + NDArray expected('c', {2,2,2,4}, {11.612187, 18.523903, -8.671851, -21.287169, 10.933372, 19.145418, -9.262139, -20.715094, 10.254556, 19.766932, -9.852428, -20.143019, 9.57574 , + 20.388447, -10.442716, -19.570944,8.896924, 21.009961, -11.033005, -18.998869, 8.218109, 21.631475, -11.623294, -18.426794, 7.539293, 22.25299 , + -12.213582, -17.854719, 6.860477, 22.874504, -12.803871, -17.282644}, nd4j::DataType::FLOAT32); + input.linspace(0.1, 0.1); + + nd4j::ops::batchnorm op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,3}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto output = results->at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, bool_broadcast_test_1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 2ef9e2309..9d460f152 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -983,5 +983,952 @@ TEST_F(DeclarableOpsTests13, mergemax_2) { ASSERT_EQ(20, status); } +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_1) { + + const int sL = 5; + const int bS = 3; + const int nIn = 3; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 0; // forward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + auto expH = NDArrayFactory::create('c', {sL, bS, nOut}, {0.57574,0.57574,0.57574,0.58006,0.58006,0.58006,0.58434,0.58434,0.58434, + 0.55114,0.55114,0.55114,0.55732,0.55732,0.55732,0.56338,0.56338,0.56338, + 0.53763,0.53763,0.53763,0.54534,0.54534,0.54534,0.55287,0.55287,0.55287, + 0.53626,0.53626,0.53626,0.54487,0.54487,0.54487,0.55327,0.55327,0.55327, + 0.54484,0.54484,0.54484,0.55379,0.55379,0.55379,0.5625 ,0.5625 ,0.5625}); + + auto expClast = NDArrayFactory::create('c', {bS, nOut}, {1.1589154,1.1589154,1.1589154,1.1892855,1.1892855,1.1892855,1.219861 ,1.219861 ,1.219861}); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto *h = results->at(0); + auto *cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expClast.isSameShape(cL)); + ASSERT_TRUE(expClast.equalsTo(cL)); + + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_2) { + + const int sL = 5; + const int bS = 3; + const int nIn = 3; + const int nOut = 3; + + // input arguments + + const int dataFormat = 1; // [bS,sL,nIn] + const int directionMode = 0; // forward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {bS, sL, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + auto expH = NDArrayFactory::create('c', {bS, sL, nOut}, {0.575735, 0.575735, 0.575735, 0.541562, 0.541562, 0.541562, 0.514003, 0.514003, 0.514003, 0.495597, 0.495597, 0.495597, 0.485999, 0.485999, 0.485999, + 0.596965, 0.596965, 0.596965, 0.571978, 0.571978, 0.571978, 0.552888, 0.552888, 0.552888, 0.540606, 0.540606, 0.540606, 0.534764, 0.534764, 0.534764, + 0.61725 , 0.61725 , 0.61725 , 0.599828, 0.599828, 0.599828, 0.587627, 0.587627, 0.587627, 0.580408, 0.580408, 0.580408, 0.577735, 0.577735, 0.577735}); + + auto expClast = NDArrayFactory::create('c', {bS, nOut}, {0.996965, 0.996965, 0.996965, 1.146756, 1.146756, 1.146756, 1.301922, 1.301922, 1.301922}); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto *h = results->at(0); + auto *cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expClast.isSameShape(cL)); + ASSERT_TRUE(expClast.equalsTo(cL)); + + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_3) { + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 1; // backward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL,bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, {0.493883, 0.493883, 0.493883, 0.510990, 0.510990, 0.510990, 0.534701, 0.534701, 0.534701, 0.549139, + 0.549139, 0.549139, 0.571900, 0.571900, 0.571900, 0.583561, 0.583561, 0.583561, 0.605106, 0.605106, + 0.605106, 0.614114, 0.614114, 0.614114, 0.635354, 0.635354, 0.635354, 0.642045, 0.642045, 0.642045}, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, {0.493883, 0.493883, 0.493883, 0.510990, 0.510990, 0.510990}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {1.061274, 1.061274, 1.061274, 1.115888, 1.115888, 1.115888}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_4) { + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {2,nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {2,nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {2,4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0,1, 0,0, 0,0}) = 0.003; + Wx({1,2, 0,0, 0,0}) = -0.003; + Wr({0,1, 0,0, 0,0}) = 0.006; + Wr({1,2, 0,0, 0,0}) = -0.006; + b({0,1, 0,0}) = 0.5; + b({1,2, 0,0}) = -0.5; + hI({0,1, 0,0, 0,0}) = 1; + hI({1,2, 0,0, 0,0}) = -1; + cI({0,1, 0,0, 0,0}) = 2; + cI({1,2, 0,0, 0,0}) = -2; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, 2*nOut}, {0.577661, 0.577661, 0.577661, -0.107642, -0.107642, -0.107642, 0.585289, 0.585289, 0.585289, + -0.106937, -0.106937, -0.106937, 0.556517, 0.556517, 0.556517, -0.111647, -0.111647, -0.111647, + 0.567274, 0.567274, 0.567274, -0.110214, -0.110214, -0.110214, 0.547395, 0.547395, 0.547395, + -0.123305, -0.123305, -0.123305, 0.560640, 0.560640, 0.560640, -0.120862, -0.120862, -0.120862, + 0.550714, 0.550714, 0.550714, -0.156223, -0.156223, -0.156223, 0.565308, 0.565308, 0.565308, + -0.152313, -0.152313, -0.152313, 0.563741, 0.563741, 0.563741, -0.234128, -0.234128, -0.234128, + 0.578676, 0.578676, 0.578676, -0.228917, -0.228917, -0.228917}, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {2,bS, nOut}, {0.563741, 0.563741, 0.563741, 0.578676, 0.578676, 0.578676, -0.107642, + -0.107642, -0.107642, -0.106937, -0.106937, -0.106937}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.217757, 1.217757, 1.217757, 1.272398, 1.272398, 1.272398, -0.295768, + -0.295768, -0.295768, -0.298453, -0.298453, -0.298453}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_5) { + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 1; // [bS,sL,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {bS, sL, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {2,nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {2,nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {2,4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0,1, 0,0, 0,0}) = 0.003; + Wx({1,2, 0,0, 0,0}) = -0.003; + Wr({0,1, 0,0, 0,0}) = 0.006; + Wr({1,2, 0,0, 0,0}) = -0.006; + b({0,1, 0,0}) = 0.5; + b({1,2, 0,0}) = -0.5; + hI({0,1, 0,0, 0,0}) = 1; + hI({1,2, 0,0, 0,0}) = -1; + cI({0,1, 0,0, 0,0}) = 2; + cI({1,2, 0,0, 0,0}) = -2; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {bS, sL, 2*nOut}, {0.577661, 0.577661, 0.577661, -0.107659, -0.107659, -0.107659, 0.548099, 0.548099, 0.548099, -0.113406, -0.113406, -0.113406, + 0.526881, 0.526881, 0.526881, -0.12883 , -0.12883 , -0.12883 , 0.515882, 0.515882, 0.515882, -0.16868 , -0.16868 , -0.16868 , + 0.51409 , 0.51409 , 0.51409 , -0.255185, -0.255185, -0.255185, 0.614599, 0.614599, 0.614599, -0.102739, -0.102739, -0.102739, + 0.599572, 0.599572, 0.599572, -0.105802, -0.105802, -0.105802,0.591089, 0.591089, 0.591089, -0.116681, -0.116681, -0.116681, + 0.588694, 0.588694, 0.588694, -0.149201, -0.149201, -0.149201,0.591492, 0.591492, 0.591492, -0.228917, -0.228917, -0.228917}, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {2,bS, nOut}, {0.51409 , 0.51409 , 0.51409 , 0.591492, 0.591492, 0.591492, + -0.107659, -0.107659, -0.107659, -0.102739, -0.102739, -0.102739}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.07293 , 1.07293 , 1.07293,1.346609, 1.346609, 1.346609, + -0.295811, -0.295811, -0.295811,-0.305394, -0.305394, -0.305394}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + // h->printBuffer(); + // hL->printBuffer(); + // cL->printBuffer(); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_6) { + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 2; // bidirectional sum + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {2,nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {2,nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {2,4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0,1, 0,0, 0,0}) = 0.003; + Wx({1,2, 0,0, 0,0}) = -0.003; + Wr({0,1, 0,0, 0,0}) = 0.006; + Wr({1,2, 0,0, 0,0}) = -0.006; + b({0,1, 0,0}) = 0.5; + b({1,2, 0,0}) = -0.5; + hI({0,1, 0,0, 0,0}) = 1; + hI({1,2, 0,0, 0,0}) = -1; + cI({0,1, 0,0, 0,0}) = 2; + cI({1,2, 0,0, 0,0}) = -2; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, {0.470019, 0.470019, 0.470019, 0.478352, 0.478352, 0.478352, 0.444871, 0.444871, 0.444871, 0.457060, + 0.457060, 0.457060, 0.424090, 0.424090, 0.424090, 0.439778, 0.439778, 0.439778, 0.394491, 0.394491, + 0.394491, 0.412995, 0.412995, 0.412995, 0.329613, 0.329613, 0.329613, 0.349760, 0.349760, 0.349760}, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {2,bS, nOut}, {0.563741, 0.563741, 0.563741, 0.578676, 0.578676, 0.578676, -0.107642, + -0.107642, -0.107642, -0.106937, -0.106937, -0.106937}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.217757, 1.217757, 1.217757, 1.272398, 1.272398, 1.272398, -0.295768, + -0.295768, -0.295768, -0.298453, -0.298453, -0.298453}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_7) { + #ifndef HAVE_MKLDNN + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 0; // forward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + Wp = -0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, {0.55533 , 0.55533 , 0.55533 , 0.562925, 0.562925, 0.562925, 0.531795, 0.531795, 0.531795, 0.542556, + 0.542556, 0.542556, 0.521466, 0.521466, 0.521466, 0.534638, 0.534638, 0.534638, 0.524805, 0.524805, + 0.524805, 0.539187, 0.539187, 0.539187, 0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923}, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, {0.538309, 0.538309, 0.538309,0.552923, 0.552923, 0.552923}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {1.147089, 1.147089, 1.147089,1.197228, 1.197228, 1.197228}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; + #endif +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_8) { + #ifndef HAVE_MKLDNN + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 1; // backward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 1.; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + Wp = -0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, {0.436221, 0.436221, 0.436221,0.450573, 0.450573, 0.450573,0.463602, 0.463602, 0.463602, 0.474674, 0.474674, 0.474674, + 0.484039, 0.484039, 0.484039,0.490679, 0.490679, 0.490679, 0.494871, 0.494871, 0.494871, 0.499028, 0.499028, 0.499028, + 0.504649, 0.504649, 0.504649, 0.508719, 0.508719, 0.508719}, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, {0.436221, 0.436221, 0.436221, 0.450573, 0.450573, 0.450573}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {0.879804, 0.879804, 0.879804,0.914666, 0.914666, 0.914666}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; + #endif +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_9) { + #ifndef HAVE_MKLDNN + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {2,nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {2,nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {2,4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + NDArray Wp('c', {2,3*nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0,1, 0,0, 0,0}) = 0.003; + Wx({1,2, 0,0, 0,0}) = -0.003; + Wr({0,1, 0,0, 0,0}) = 0.006; + Wr({1,2, 0,0, 0,0}) = -0.006; + b({0,1, 0,0}) = 0.5; + b({1,2, 0,0}) = -0.5; + hI({0,1, 0,0, 0,0}) = 1; + hI({1,2, 0,0, 0,0}) = -1; + cI({0,1, 0,0, 0,0}) = 2; + cI({1,2, 0,0, 0,0}) = -2; + Wp({0,1, 0,0}) = -0.05; + Wp({1,2, 0,0}) = 0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, 2*nOut}, { 0.55533 , 0.55533 , 0.55533 , -0.104502, -0.104502, -0.104502, 0.562925, 0.562925, 0.562925, -0.103843, -0.103843, -0.103843, + 0.531795, 0.531795, 0.531795, -0.107456, -0.107456, -0.107456,0.542556, 0.542556, 0.542556, -0.106139, -0.106139, -0.106139, + 0.521466, 0.521466, 0.521466, -0.11681 , -0.11681 , -0.11681 , 0.534638, 0.534638, 0.534638, -0.11458 , -0.11458 , -0.11458 , + 0.524805, 0.524805, 0.524805, -0.145177, -0.145177, -0.145177,0.539187, 0.539187, 0.539187, -0.14157 , -0.14157 , -0.14157 , + 0.538309, 0.538309, 0.538309, -0.218056, -0.218056, -0.218056,0.552923, 0.552923, 0.552923, -0.213068, -0.213068, -0.213068}, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {2,bS, nOut}, {0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923, -0.104502, -0.104502, -0.104502, + -0.103843, -0.103843, -0.103843}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.147089, 1.147089, 1.147089, 1.197228, 1.197228, 1.197228, -0.289425, -0.289425, -0.289425, + -0.292174, -0.292174, -0.292174}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; + #endif +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_10) { + #ifndef HAVE_MKLDNN + + const int sL = 6; + const int bS = 5; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 0; // forward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray seqLen('c', {bS}, {0,1,2,3,5}, nd4j::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + Wp = -0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.570404, 0.570404, 0.570404, 0.57777 , 0.57777 , 0.57777 , 0.585023, 0.585023, 0.585023, + 0., 0., 0., 0., 0., 0., 0.576568, 0.576568, 0.576568, 0.586163, 0.586163, 0.586163, 0.595462, 0.595462, 0.595462, 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0.611224, 0.611224, 0.611224, 0.621298, 0.621298, 0.621298, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0.655858, 0.655858, 0.655858, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.576568, 0.576568, 0.576568, 0.611224, 0.611224, 0.611224, 0.692315, 0.692315, 0.692315}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {0., 0., 0., 1.534275, 1.534275, 1.534275, 1.40183, 1.40183, 1.40183, 1.449675, 1.449675, 1.449675, 1.767702, 1.767702, 1.767702}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; + #endif +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_11) { + #ifndef HAVE_MKLDNN + + const int sL = 6; + const int bS = 5; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 1; // backward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray seqLen('c', {bS}, {0,1,2,3,5}, nd4j::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + Wp = -0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.61209, + 0.61209, 0.61209,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.652042, 0.652042, 0.652042, 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0.677708, 0.677708, 0.677708, 0.684177, 0.684177, 0.684177, 0., 0., 0.,0., 0., 0.,0.699627, 0.699627, + 0.699627,0.705371, 0.705371, 0.705371,0.710989, 0.710989, 0.710989, 0., 0., 0., 0.719014, 0.719014, 0.719014, 0.724087, + 0.724087, 0.724087, 0.729084, 0.729084, 0.729084, 0.734004, 0.734004, 0.734004 }, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, {0., 0., 0., 0.719014, 0.719014, 0.719014, 0.699627, 0.699627, 0.699627, 0.677708, 0.677708, 0.677708, 0.61209, 0.61209, 0.61209}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {0., 0., 0., 2.092814, 2.092814, 2.092814, 2.08832, 2.08832, 2.08832, 2.009851, 2.009851, 2.009851, 1.646034, 1.646034, 1.646034}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; + #endif +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_12) { + #ifndef HAVE_MKLDNN + + const int sL = 6; + const int bS = 5; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {2,nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {2,nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {2,4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + NDArray seqLen('c', {bS}, {0,1,2,3,5}, nd4j::DataType::FLOAT32); + NDArray Wp('c', {2,3*nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0,1, 0,0, 0,0}) = 0.003; + Wx({1,2, 0,0, 0,0}) = -0.003; + Wr({0,1, 0,0, 0,0}) = 0.006; + Wr({1,2, 0,0, 0,0}) = -0.006; + b({0,1, 0,0}) = 0.5; + b({1,2, 0,0}) = -0.5; + hI({0,1, 0,0, 0,0}) = 1; + hI({1,2, 0,0, 0,0}) = -1; + cI({0,1, 0,0, 0,0}) = 2; + cI({1,2, 0,0, 0,0}) = -2; + Wp({0,1, 0,0}) = -0.05; + Wp({1,2, 0,0}) = 0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, 2*nOut}, {0., 0., 0., 0., 0., 0., 0.562925, 0.562925, 0.562925, -0.25361 , -0.25361 , -0.25361 , 0.570404, 0.570404, 0.570404, -0.157103, + -0.157103, -0.157103, 0.57777 , 0.57777 , 0.57777 , -0.116502, -0.116502, -0.116502,0.585023, 0.585023, 0.585023, -0.100025, + -0.100025, -0.100025, 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0.576568, 0.576568, 0.576568, -0.223072, -0.223072, -0.223072, + 0.586163, 0.586163, 0.586163, -0.135714, -0.135714, -0.135714,0.595462, 0.595462, 0.595462, -0.094438, -0.094438, -0.094438, + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.611224, 0.611224, 0.611224, -0.193473, -0.193473, -0.193473, + 0.621298, 0.621298, 0.621298, -0.090626, -0.090626, -0.090626, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0.655858, 0.655858, 0.655858, -0.098015, -0.098015, -0.098015, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, -0.143704, -0.143704, -0.143704, 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {2,bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.576568, 0.576568, 0.576568, 0.611224, 0.611224, 0.611224, 0.692315, 0.692315, 0.692315, + 0., 0., 0., -0.25361 , -0.25361 , -0.25361 , -0.157103, -0.157103, -0.157103,-0.116502, -0.116502, -0.116502, -0.100025, -0.100025, -0.100025}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {0., 0., 0.,1.534275, 1.534275, 1.534275,1.40183 , 1.40183 , 1.40183 ,1.449675, 1.449675, 1.449675,1.767702, 1.767702, 1.767702, + 0., 0., 0.,-0.86636 , -0.86636 , -0.86636 ,-0.470245, -0.470245, -0.470245,-0.341856, -0.341856, -0.341856,-0.294986, -0.294986, -0.294986}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; + #endif +} diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index fc7f29e3c..6eabc964a 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -505,9 +505,9 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_1) { } TEST_F(DeclarableOpsTests15, test_lstmBlock_2) { - int seqLen = 32; - int bS = 64; - int nIn = 32; + int seqLen = 8; + int bS = 16; + int nIn = 8; auto x0 = NDArrayFactory::create(5); auto x1 = NDArrayFactory::create('f', {bS, nIn, seqLen}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index adbff7f83..d95e86b1c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -133,4 +133,21 @@ TEST_F(DeclarableOpsTests16, test_hamming_distance_1) { ASSERT_EQ(e, *z); delete result; +} + +TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) { + auto input = NDArrayFactory::create('c', {512}); + auto low = NDArrayFactory::create('c', {512}); + auto high = NDArrayFactory::create('c', {512}); + + auto output = NDArrayFactory::create(0.0f); + + input.linspace(1.0); + low.linspace(1.0); + high.linspace(1.0); + + nd4j::ops::knn_mindistance op; + auto result = op.execute({&input, &low, &high}, {&output}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); + } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index 84a1f2dc9..e36b78a98 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -2883,78 +2883,336 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) { - auto input = NDArrayFactory::create('c', {3,2}); - auto mean = NDArrayFactory::create('c', {2,3,2}); - auto variance = NDArrayFactory::create('c', {2,3,1,3,2}); - auto gamma = NDArrayFactory::create('c', {1,1}); - auto beta = NDArrayFactory::create('c', {1,2}); - auto dLdO = NDArrayFactory::create('c', {2,3,2,3,2}); + NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.272779, -1.018224, -0.763668,-0.509112, -0.254556, 0., 0.254556,0.509112, 0.763668, 1.018224, 1.272779, + 1.527335, 1.781891, 2.036447, 2.291003,2.545559, 2.800115, 3.054671, 3.309227,3.563783, 3.818338, 4.072894, 4.32745}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {6.448749, 7.212417, 8.230641, 9.50342 }, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, nd4j::DataType::FLOAT32); input.linspace(0.1, 0.1); mean.assign(1.); variance.assign(0.5); gamma.assign(1.2); - beta.assign(1.); + // beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); - const OpArgsHolder argsHolderFF({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); - const OpArgsHolder argsHolderBP({&input, &mean, &variance, &gamma, &beta, &dLdO}, {1e-5}, {1,1}); + nd4j::ops::batchnorm_bp op; - nd4j::ops::batchnorm opFF; - nd4j::ops::batchnorm_bp opBP; + auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1}); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + ASSERT_EQ(ND4J_STATUS_OK, results->status()); - ASSERT_TRUE(isGradCorrect); + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) { - auto input = NDArrayFactory::create('c', {2,3,2,3,2}); - auto mean = NDArrayFactory::create('c', {2,3,2}); - auto variance = NDArrayFactory::create('c', {2,3,1,3,1}); - auto gamma = NDArrayFactory::create('c', {1,1}); - auto dLdO = NDArrayFactory::create('c', {2,3,2,3,2}); + NDArray input ('c', {2,3,4}, nd4j::DataType::DOUBLE); + NDArray mean ('c', {3}, {1.05, 1.1, 1.15}); + NDArray variance('c', {3}, {0.5, 0.6, 0.7}); + NDArray gamma ('c', {3}, {1.2, 1.3, 1.4}); + NDArray beta ('c', {3}, nd4j::DataType::DOUBLE); + NDArray gradO ('c', {2,3,4}, nd4j::DataType::DOUBLE); + + NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.272779, -1.018224, -0.763668,-0.503484, -0.251742, 0., 0.251742,0.501992, 0.752989, 1.003985, 1.254981, + 1.527335, 1.781891, 2.036447, 2.291003,2.517418, 2.76916 , 3.020902, 3.272644,3.513947, 3.764943, 4.015939, 4.266936}); + NDArray expdLdG('c', {3}, {5.81236 , 7.048771, 12.155388}); + NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}); input.linspace(0.1, 0.1); - mean.assign(1.); - variance.assign(0.5); - gamma.assign(1.2); + // beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); - const OpArgsHolder argsHolderFF({&input, &mean, &variance, &gamma}, {1e-5}, {1,0}); - const OpArgsHolder argsHolderBP({&input, &mean, &variance, &gamma, &dLdO}, {1e-5}, {1,0}); + nd4j::ops::batchnorm_bp op; - nd4j::ops::batchnorm opFF; - nd4j::ops::batchnorm_bp opBP; + auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1}); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + ASSERT_EQ(ND4J_STATUS_OK, results->status()); - ASSERT_TRUE(isGradCorrect); + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) { - auto input = NDArrayFactory::create('c', {2,3,1,3}); - auto mean = NDArrayFactory::create('c', {1,3,2,1}); - auto variance = NDArrayFactory::create('c', {2,1,2,3}); - auto dLdO = NDArrayFactory::create('c', {2,3,2,3}); + NDArray input ('c', {2,3,4}, nd4j::DataType::DOUBLE); + NDArray mean ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}); + NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}); + NDArray gamma ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}); + NDArray beta ('c', {2,1,4}, nd4j::DataType::DOUBLE); + NDArray gradO ('c', {2,3,4}, nd4j::DataType::DOUBLE); + + NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.258709, -1.003985, -0.754668,-0.509112, -0.251742, 0., 0.251556,0.509112, 0.755225, 1.003985, 1.25778 , + 1.517885, 1.784991, 2.05947 , 2.341504,2.529808, 2.804986, 3.089205, 3.382173,3.541731, 3.824981, 4.11894 , 4.422841}); + NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234 }); + NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45, 0. , 0.45, 4.5 , 4.95, 5.4 , 5.85}); input.linspace(0.1, 0.1); - mean.assign(1.); - variance.assign(0.5); + // beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); - const OpArgsHolder argsHolderFF({&input, &mean, &variance}, {1e-5}, {0,0}); - const OpArgsHolder argsHolderBP({&input, &mean, &variance, &dLdO}, {1e-5}, {0,0}); + nd4j::ops::batchnorm_bp op; - nd4j::ops::batchnorm opFF; - nd4j::ops::batchnorm_bp opBP; + auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,0,2}); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + ASSERT_EQ(ND4J_STATUS_OK, results->status()); - ASSERT_TRUE(isGradCorrect); + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; } +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, batchnorm_bp_test4) { + + NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,4}, {1.527335, -1.16534 , 0.885433, -0.643584, 0.509112, -0.233068, -0., 0.214528}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {1.442483, 0.9502 , 0.569207, 0.314641}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, batchnorm_bp_test5) { + + NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,4,2,2}, {1.527335, 1.272779,1.018224, 0.763668,-0.466136, -0.233068,0., 0.233068,-0.442716, -0.664075,-0.885433, -1.106791,1.287169, 1.501697,1.716225, 1.930753, + -2.545559, -2.800115,-3.054671, -3.309227,3.262951, 3.496019,3.729087, 3.962155,-3.984448, -4.205806,-4.427164, -4.648522,4.719618, 4.934146,5.148675, 5.363203}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, batchnorm_bp_test6) { + + NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,2,2,4}, {1.527335, -1.16534 , 0.885433, -0.643584, 0.509112, -0.233068, -0., 0.214528, -0.509112, 0.699204, -0.885433, 1.072641, -1.527335, 1.631475, -1.770866, 1.930753, + -2.545559, 2.563747, -2.656298, 2.788865, -3.563783, 3.496019, -3.541731, 3.646978, -4.582006, 4.42829 , -4.427164, 4.50509 , -5.60023 , 5.360562, -5.312597, 5.363203}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {9.6, 10.8, 12. , 13.2}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,3}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, batchnorm_bp_test7) { + + NDArray input ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,2,2,2,4}, {1.527335, -1.16534 , 0.885433, -0.643584,0.509112, -0.233068, -0., 0.214528,-0.509112, 0.699204, -0.885433, 1.072641,-1.527335, 1.631475, -1.770866, + 1.930753,-2.545559, 2.563747, -2.656298, 2.788865,-3.563783, 3.496019, -3.541731, 3.646978,-4.582006, 4.42829 , -4.427164, + 4.50509 ,-5.60023 , 5.360562, -5.312597, 5.363203, -6.618453, 6.292834, -6.19803 , 6.221315,-7.636677, 7.225105, -7.083463, + 7.079428,-8.6549 , 8.157377, -7.968895, 7.93754 ,-9.673124, 9.089649, -8.854328, 8.795652, -10.691348, 10.02192 , -9.739761, + 9.653765,-11.709571, 10.954192, -10.625194, 10.511877,-12.727795, 11.886464, -11.510627, 11.36999 ,-13.746018, 12.818735, -12.39606 , 12.228102}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {282.38734 , 244.542027, 224.140995, 207.548793}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {57.6, 60. , 62.4, 64.8}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,4}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + // dLdI->printBuffer(); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, batchnorm_bp_test8) { + + NDArray input ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,4,2,2,2}, {1.527335, 1.272779, 1.018224, 0.763668, 0.509112, 0.254556, -0. , -0.254556, 0.466136, 0.699204, 0.932272, 1.16534 , 1.398407, 1.631475, 1.864543, 2.097611, + -2.213582, -2.43494 , -2.656298, -2.877657, -3.099015, -3.320373, -3.541731, -3.76309 , 3.861506, 4.076034, 4.290562, 4.50509 , 4.719618, 4.934146, 5.148675, 5.363203, + -6.618453, -6.873009, -7.127565, -7.382121, -7.636677, -7.891233, -8.145789, -8.400345, 7.924309, 8.157377, 8.390445, 8.623513, 8.856581, 9.089649, 9.322717, 9.555784, + -9.297045, -9.518403, -9.739761, -9.961119, -10.182477, -10.403836, -10.625194, -10.846552, 10.726405, 10.940933, 11.155462, 11.36999 , 11.584518, 11.799046, 12.013574, 12.228102}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + // dLdI->printBuffer(); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} /* //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) { diff --git a/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp b/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp index 9db8a5f06..2ed43d08a 100644 --- a/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp @@ -28,6 +28,7 @@ #include #include #include +#include using namespace nd4j; @@ -2342,5 +2343,155 @@ TEST_F(HelpersTests1, softmaxDerivative_3) { ASSERT_TRUE(expOutput.equalsTo(output)); } +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, lstmLayerCell_1) { + const int bS = 2; + const int nIn = 10; + const int nOut = 4; + + const float dataFormat = 0; // is ignored in cell op + const float cellClip = 5; // clipping value + const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid + const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid + const float cellAct = 0; // tanh activation for cell state + const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh + const float cellBeta = 0; // beta value for cell state activation, not required for tanh + const float outAct = 0; // tanh activation for output + const float outAlpha = 0; // alpha value for output activation, not required for tanh + const float outBeta = 0; // beta value for output activation, not required for tanh + + NDArray x ('c', {bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b ('c', {4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32); + + NDArray h('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray c('c', {bS, nOut}, nd4j::DataType::FLOAT32); + + NDArray expH('c', {bS, nOut}, {0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288}, nd4j::DataType::FLOAT32); + NDArray expC('c', {bS, nOut}, {3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778}, nd4j::DataType::FLOAT32); + + std::vector params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + + x = 1.; + hI = 2.; + cI = 3.; + Wx = 0.5; + Wr = 0.4; + Wp = 0.3; + b = 0.7; + + nd4j::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expC.isSameShape(c)); + ASSERT_TRUE(expC.equalsTo(c)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, lstmLayerCell_2) { + + const int bS = 2; + const int nIn = 10; + const int nOut = 4; + + const float dataFormat = 0; // is ignored in cell op + const float cellClip = 3; // clipping value + const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid + const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid + const float cellAct = 0; // tanh activation for cell state + const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh + const float cellBeta = 0; // beta value for cell state activation, not required for tanh + const float outAct = 0; // tanh activation for output + const float outAlpha = 0; // alpha value for output activation, not required for tanh + const float outBeta = 0; // beta value for output activation, not required for tanh + + NDArray x ('c', {bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b ('c', {4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32); + + NDArray h('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray c('c', {bS, nOut}, nd4j::DataType::FLOAT32); + + NDArray expH('c', {bS, nOut}, {0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995}, nd4j::DataType::FLOAT32); + NDArray expC('c', {bS, nOut}, {3., 3., 3., 3., 3., 3., 3., 3.}, nd4j::DataType::FLOAT32); + + std::vector params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + + x = 1.; + hI = 2.; + cI = 3.; + Wx = 0.5; + Wr = 0.4; + Wp = 0.3; + b = 0.7; + + nd4j::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expC.isSameShape(c)); + ASSERT_TRUE(expC.equalsTo(c)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, lstmLayerCell_3) { + + const int nIn = 10; + const int nOut = 4; + + const float dataFormat = 0; // is ignored in cell op + const float cellClip = 5; // clipping value + const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid + const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid + const float cellAct = 0; // tanh activation for cell state + const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh + const float cellBeta = 0; // beta value for cell state activation, not required for tanh + const float outAct = 0; // tanh activation for output + const float outAlpha = 0; // alpha value for output activation, not required for tanh + const float outBeta = 0; // beta value for output activation, not required for tanh + + NDArray x ('c', {nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b ('c', {4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {nOut}, nd4j::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32); + + NDArray h('c', {nOut}, nd4j::DataType::FLOAT32); + NDArray c('c', {nOut}, nd4j::DataType::FLOAT32); + + NDArray expH('c', {nOut}, {0.999288, 0.999288, 0.999288, 0.999288}, nd4j::DataType::FLOAT32); + NDArray expC('c', {nOut}, {3.999778, 3.999778, 3.999778, 3.999778}, nd4j::DataType::FLOAT32); + + std::vector params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + + x = 1.; + hI = 2.; + cI = 3.; + Wx = 0.5; + Wr = 0.4; + Wp = 0.3; + b = 0.7; + + nd4j::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expC.isSameShape(c)); + ASSERT_TRUE(expC.equalsTo(c)); +} diff --git a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp index c95fc11e3..829117bed 100644 --- a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp @@ -64,7 +64,7 @@ TEST_F(MklDnnTests, helpers_includer) { nd4j::ops::platforms::PLATFORM_maxpool3dnew_bp maxpool3d_bp; nd4j::ops::platforms::PLATFORM_lrn lrn; - nd4j::ops::platforms::PLATFORM_batchnorm_new batchnorm; + nd4j::ops::platforms::PLATFORM_batchnorm batchnorm; printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm}); #endif diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index eb3424007..32df3e69d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -442,22 +442,12 @@ public abstract class DifferentialFunction { setInstanceId(); if(sameDiff != null) { sameDiff.addArgsFor(args, this); - for (int i = 0; i < args.length; i++) { - if (args[i].isPlaceHolder()) { - sameDiff.addPropertyToResolve(this, args[i].getVarName()); - } - } } } public void replaceArg(int i, SDVariable newArg){ if(sameDiff != null){ sameDiff.replaceArgFor(i, newArg, this); - if(args()[i].isPlaceHolder() && !newArg.isPlaceHolder()){ - sameDiff.removePropertyToResolve(this, args()[i].getVarName()); - } else if(!args()[i].isPlaceHolder() && newArg.isPlaceHolder()){ - sameDiff.addPropertyToResolve(this, newArg.getVarName()); - } } } @@ -483,7 +473,7 @@ public abstract class DifferentialFunction { SDVariable[] outputVars = outputVariables(); String[] out = new String[outputVars.length]; for( int i=0; i 1) { - INDArray ret = sameDiff.getArrForVarName(args()[1].getVarName()); - return ret; - } - return null; - } - - @JsonIgnore - private INDArray getZ() { - if(isInPlace()) - return getX(); - SDVariable opId = outputVariables()[0]; - INDArray ret = opId.getArr(); - return ret; - } - - /** @@ -860,4 +773,8 @@ public abstract class DifferentialFunction { public int getNumOutputs(){return -1;} + /** + * Clear the input and output INDArrays, if any are set + */ + public abstract void clearArrays(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 1a40fbd11..0bc395803 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -33,6 +33,7 @@ import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.NoOp; +import org.nd4j.linalg.api.ops.custom.*; import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd; import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad; import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter; @@ -183,7 +184,6 @@ import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp; import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp; import org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation; import org.nd4j.linalg.api.ops.impl.summarystats.Variance; -import org.nd4j.linalg.api.ops.impl.transforms.Constant; import org.nd4j.linalg.api.ops.impl.transforms.Pad; import org.nd4j.linalg.api.ops.impl.transforms.ReluLayer; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; @@ -351,12 +351,6 @@ public class DifferentialFunctionFactory { } } - - public Constant val(SDVariable iX) { - return new Constant(sameDiff(), iX, - iX.getShape()); - } - public ExternalErrorsFunction externalErrors(SDVariable... inputs) { return externalErrors(null, inputs); } @@ -383,10 +377,6 @@ public class DifferentialFunctionFactory { return new OnesLike(name, sameDiff(), input, dataType).outputVariable(); } - public SDVariable constant(SDVariable input, long... shape) { - return new Constant(sameDiff(), input, (shape != null && shape.length > 0 ? shape : null)).outputVariable(); - } - public SDVariable linspace(SDVariable lower, SDVariable upper, SDVariable count, DataType dt) { return new org.nd4j.linalg.api.ops.impl.shape.Linspace(sameDiff(), lower, upper, count, dt).outputVariable(); } @@ -981,8 +971,8 @@ public class DifferentialFunctionFactory { return new CumProdBp(sameDiff(), in, grad, exclusive, reverse, axis).outputVariable(); } - public SDVariable biasAdd(SDVariable input, SDVariable bias) { - return new BiasAdd(sameDiff(), input, bias).outputVariable(); + public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) { + return new BiasAdd(sameDiff(), input, bias, nchw).outputVariable(); } public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad) { @@ -1055,7 +1045,7 @@ public class DifferentialFunctionFactory { public SDVariable gradientBackwardsMarker(SDVariable iX) { - return new GradientBackwardsMarker(sameDiff(), iX, sameDiff.scalar(iX.getVarName() + "-pairgrad", 1.0)).outputVariable(); + return new GradientBackwardsMarker(sameDiff(), iX, sameDiff.scalar(iX.name() + "-pairgrad", 1.0)).outputVariable(); } public SDVariable abs(SDVariable iX) { @@ -2629,7 +2619,6 @@ public class DifferentialFunctionFactory { validateDifferentialFunctionsameDiff(func); validateDifferentialFunctionsameDiff(input); - // FIXME: int cast! return tile(func, ArrayUtil.toInts(input.getShape())); } @@ -2649,6 +2638,33 @@ public class DifferentialFunctionFactory { return new NextIteration(sameDiff, x).outputVariable(); } + public SDVariable adjustContrast(SDVariable in, SDVariable factor) { + return new AdjustContrast(sameDiff, in, factor).outputVariable(); + } + + public SDVariable adjustContrastV2(SDVariable in, SDVariable factor) { + return new AdjustContrastV2(sameDiff, in, factor).outputVariable(); + } + + public SDVariable bitCast(SDVariable in, SDVariable dataType) { + return new BitCast(sameDiff, in, dataType).outputVariable(); + } + + public SDVariable compareAndBitpack(SDVariable threshold) { + return new CompareAndBitpack(sameDiff, threshold).outputVariable(); + } + + public SDVariable divideNoNan(SDVariable in1, SDVariable in2) { + return new DivideNoNan(sameDiff, in1, in2).outputVariable(); + } + + public SDVariable drawBoundingBoxes(SDVariable boxes, SDVariable colors) { + return new DrawBoundingBoxes(sameDiff, boxes, colors).outputVariable(); + } + + public SDVariable fakeQuantWithMinMaxVarsPerChannel(SDVariable x, SDVariable min, SDVariable max) { + return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max).outputVariable(); + } public String toString() { return "DifferentialFunctionFactory{methodNames=" + methodNames + "}"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java index 8d6a051df..18e3b934b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java @@ -122,7 +122,7 @@ public interface Listener { /** * Called when any activation becomes available. *

- * The activation will most likely be freed later, use detach() if you need to save it.
+ * The activation will most likely be freed later, use dup() if you need to save it.
*
* Note that this method will be called when any activation becomes available, not just ones from {@link #requiredVariables(SameDiff)}
* It is guaranteed to be called for variables from requiredVariables().
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerEvaluations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerEvaluations.java index 08722b06e..9bdc54dd7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerEvaluations.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerEvaluations.java @@ -178,7 +178,7 @@ public class ListenerEvaluations { * @param evaluations The evaluations to run */ public Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations) { - return trainEvaluation(variable.getVarName(), labelIndex, evaluations); + return trainEvaluation(variable.name(), labelIndex, evaluations); } /** @@ -202,7 +202,7 @@ public class ListenerEvaluations { * @param evaluations The evaluations to run */ public Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations) { - return validationEvaluation(variable.getVarName(), labelIndex, evaluations); + return validationEvaluation(variable.name(), labelIndex, evaluations); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java index 34b305001..33baf7099 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java @@ -167,7 +167,7 @@ public class ListenerVariables { String[] names = new String[variables.length]; for (int i = 0; i < variables.length; i++) - names[i] = variables[i].getVarName(); + names[i] = variables[i].name(); return requireVariables(op, names); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java index 452636d57..6c38c6c9c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java @@ -226,7 +226,7 @@ public class UIListener extends BaseListener { List sdVars = sd.variables(); List varNames = new ArrayList<>(sdVars.size()); for(SDVariable v : sdVars){ - varNames.add(v.getVarName()); + varNames.add(v.name()); } if(varNames.size() != vars.size() || !varNames.containsAll(vars)){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java index b063e18a2..149ea5f2c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java @@ -91,7 +91,7 @@ public class EvaluationRecord { * @param param The target param/variable */ public List evaluations(SDVariable param) { - return evaluations(param.getVarName()); + return evaluations(param.name()); } /** @@ -105,7 +105,7 @@ public class EvaluationRecord { * Get the evaluation for param at the specified index */ public IEvaluation evaluation(SDVariable param, int index) { - return evaluation(param.getVarName(), index); + return evaluation(param.name(), index); } /** @@ -132,7 +132,7 @@ public class EvaluationRecord { * @param param The target param/variable */ public T evaluation(SDVariable param) { - return evaluation(param.getVarName()); + return evaluation(param.name()); } /** @@ -174,7 +174,7 @@ public class EvaluationRecord { * @param evalClass The type of evaluation to look for */ public > T evaluation(SDVariable param, Class evalClass) { - return evaluation(param.getVarName(), evalClass); + return evaluation(param.name(), evalClass); } /** @@ -209,7 +209,7 @@ public class EvaluationRecord { * @param metric The metric to calculate */ public double getValue(SDVariable param, IMetric metric) { - return getValue(param.getVarName(), metric); + return getValue(param.name(), metric); } /** @@ -235,7 +235,7 @@ public class EvaluationRecord { * @param metric The metric to calculate */ public double getValue(SDVariable param, int index, IMetric metric) { - return getValue(param.getVarName(), index, metric); + return getValue(param.name(), index, metric); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java index f0dcecb49..809400c0b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java @@ -24,6 +24,7 @@ import lombok.Getter; import org.nd4j.autodiff.listeners.Listener; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.IMetric; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -124,7 +125,7 @@ public class History { * Only works if there is only one evaluation with the given metric for param */ public List trainingEval(SDVariable param, IMetric metric){ - return trainingEval(param.getVarName(), metric); + return trainingEval(param.name(), metric); } /** @@ -148,7 +149,7 @@ public class History { * Index determines the evaluation used not the epoch's results to return. */ public List trainingEval(SDVariable param, int index, IMetric metric){ - return trainingEval(param.getVarName(), index, metric); + return trainingEval(param.name(), index, metric); } /** @@ -183,7 +184,7 @@ public class History { * Only works if there is only one evaluation for param. */ public List trainingEval(SDVariable param){ - return trainingEval(param.getVarName()); + return trainingEval(param.name()); } /** @@ -207,7 +208,7 @@ public class History { * Index determines the evaluation used not the epoch's results to return. */ public List trainingEval(SDVariable param, int index){ - return trainingEval(param.getVarName(), index); + return trainingEval(param.name(), index); } /** @@ -229,7 +230,7 @@ public class History { * Only works if there is only one evaluation with the given metric for param */ public List validationEval(SDVariable param, IMetric metric){ - return validationEval(param.getVarName(), metric); + return validationEval(param.name(), metric); } /** @@ -253,7 +254,7 @@ public class History { * Index determines the evaluation used not the epoch's results to return. */ public List validationEval(SDVariable param, int index, IMetric metric){ - return validationEval(param.getVarName(), index, metric); + return validationEval(param.name(), index, metric); } /** @@ -288,7 +289,7 @@ public class History { * Only works if there is only one evaluation for param. */ public List validationEval(SDVariable param){ - return validationEval(param.getVarName()); + return validationEval(param.name()); } /** @@ -312,13 +313,14 @@ public class History { * Index determines the evaluation used not the epoch's results to return. */ public List validationEval(SDVariable param, int index){ - return validationEval(param.getVarName(), index); + return validationEval(param.name(), index); } /** * Gets the training evaluations ran during the last epoch */ public EvaluationRecord finalTrainingEvaluations(){ + Preconditions.checkState(!trainingHistory.isEmpty(), "Cannot get final training evaluation - history is empty"); return trainingHistory.get(trainingHistory.size() - 1); } @@ -326,6 +328,7 @@ public class History { * Gets the validation evaluations ran during the last epoch */ public EvaluationRecord finalValidationEvaluations(){ + Preconditions.checkState(!validationHistory.isEmpty(), "Cannot get final validation evaluation - history is empty"); return validationHistory.get(validationHistory.size() - 1); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java index 493950bbf..f3f24b98a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java @@ -116,7 +116,7 @@ public class LossCurve { * Return all mean loss values for a given variable */ public float[] meanLoss(@NonNull SDVariable loss){ - return meanLoss(loss.getVarName()); + return meanLoss(loss.name()); } /** @@ -143,7 +143,7 @@ public class LossCurve { * See {@link #meanLoss(int)} */ public float meanLoss(@NonNull SDVariable loss, int epoch){ - return meanLoss(loss.getVarName(), epoch); + return meanLoss(loss.name(), epoch); } /** @@ -162,7 +162,7 @@ public class LossCurve { * Return the mean loss value for a given variable on the last epoch. */ public float lastMeanLoss(@NonNull SDVariable loss){ - return lastMeanLoss(loss.getVarName()); + return lastMeanLoss(loss.name()); } /** @@ -189,7 +189,7 @@ public class LossCurve { * A positive delta means the loss is increasing, and a negative delta means it is decreasing. */ public double lastMeanDelta(SDVariable loss){ - return lastMeanDelta(loss.getVarName()); + return lastMeanDelta(loss.name()); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index a97668e8e..6d9e34ed0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -16,34 +16,23 @@ package org.nd4j.autodiff.samediff; -import java.util.Objects; import lombok.*; import lombok.extern.slf4j.Slf4j; -import onnx.Onnx; import org.nd4j.autodiff.functions.DifferentialFunction; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.base.Preconditions; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.blas.params.MMulTranspose; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.Op; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*; import org.nd4j.linalg.api.shape.LongShapeDescriptor; -import org.nd4j.linalg.exception.ND4JIllegalStateException; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.weightinit.WeightInitScheme; -import org.nd4j.weightinit.impl.ZeroInitScheme; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; -import java.util.List; import java.util.Map; +import java.util.Objects; /** * @@ -70,10 +59,6 @@ public class SDVariable implements Serializable { @Setter protected VariableType variableType; - @Getter - @Setter - protected WeightInitScheme weightInitScheme; - @Setter(AccessLevel.NONE) protected long[] shape; @@ -86,9 +71,7 @@ public class SDVariable implements Serializable { // autogen_tag::sdvars::start - public SDVariable(@NonNull String varName, @NonNull VariableType varType, @NonNull SameDiff sameDiff, long[] shape, DataType dataType, WeightInitScheme weightInitScheme){ - Preconditions.checkState(weightInitScheme == null || varType == VariableType.VARIABLE, "Weight initalization schemes can only be applied to VARIABLE type" + - " SDVariables - variable \"%s\" is of type %s but was provided a weight initialization scheme %s", varName, varType, weightInitScheme); + public SDVariable(@NonNull String varName, @NonNull VariableType varType, @NonNull SameDiff sameDiff, long[] shape, DataType dataType){ Preconditions.checkState(dataType != DataType.UNKNOWN, "Unknown datatype is not allowed for SDVariables (variable name: %s)", varName); varName = sameDiff.generateNewVarName(varName, 0, true); @@ -97,10 +80,25 @@ public class SDVariable implements Serializable { this.varName = varName; this.variableType = varType; this.dataType = dataType; - this.weightInitScheme = weightInitScheme; this.shape = shape; } + /** + * Get the name of the SDVariable + * @return Name of the variable + */ + public String name(){ + return varName; + } + + /** + * @deprecated Use {@link #name()} + */ + @Deprecated + public String getVarName(){ + return name(); + } + /** * Returns true if this variable is a place holder * @return @@ -113,30 +111,6 @@ public class SDVariable implements Serializable { return variableType == VariableType.CONSTANT; } - /** - * Allocate and return a new array - * based on the vertex id and weight initialization. - * @return the allocated array - */ - public INDArray storeAndAllocateNewArray() { - Preconditions.checkState(variableType == VariableType.VARIABLE, "Unable to allocate and store array for variable of type %s: only" + - " VARIABLE type variables can be initialized using this method", variableType); - - if(!sameDiff.arrayAlreadyExistsForVarName(varName)){ - long[] shape = getShape(); - INDArray arr = getWeightInitScheme().create(dataType(), shape); - sameDiff.associateArrayWithVariable(arr, this); - if(log.isTraceEnabled()){ - log.trace("Generated and stored new array for variable \"{}\": shape {}", getVarName(), Arrays.toString(arr.shape())); - } - return arr; - } - - //Variable type SDVariables: shape should never change (i.e., these are params in the net!) - INDArray ret = getArr(); - return ret; - } - /** * A getter for the allocated ndarray with this {@link SDVariable}. * @@ -166,26 +140,14 @@ public class SDVariable implements Serializable { public INDArray getArr(boolean enforceExistence){ if(sameDiff.arrayAlreadyExistsForVarName(getVarName())) return sameDiff.getArrForVarName(getVarName()); - - //initialize value if it's actually a scalar constant (zero or 1 typically...) - if(variableType == VariableType.VARIABLE && weightInitScheme != null && shape != null){ - INDArray arr = weightInitScheme.create(dataType, shape); - sameDiff.associateArrayWithVariable(arr, this); - if(log.isTraceEnabled()){ - log.trace("getArr() for variable \"{}\" allocated new array: shape {}", getVarName(), Arrays.toString(getShape())); - } - return arr; - } else if(sameDiff.getShapeForVarName(getVarName()) == null) { - if (enforceExistence) { - throw new IllegalStateException("Cannot get array for SDVariable \"" + getVarName() + "\": no array has" + - " been defined, and array shape cannot be calculated"); - } - if(log.isTraceEnabled()){ - log.trace("SDVariable.getArr(): could not get array for variable {}: shape is null", getVarName()); - } - return null; + if(variableType == VariableType.ARRAY){ + throw new UnsupportedOperationException("Cannot get array for ARRAY type SDVariable - use SDVariable.exec or SameDiff.output instead"); } - return sameDiff.getArrForVarName(getVarName()); + INDArray ret = sameDiff.getArrForVarName(getVarName()); + if(enforceExistence && ret == null){ + throw new IllegalStateException("No array exists for variable \"" + name() + "\""); + } + return ret; } @@ -211,8 +173,8 @@ public class SDVariable implements Serializable { * created automatically when training is performed. */ public SDVariable getGradient() { - Preconditions.checkState(dataType().isFPType(), "Cannot get gradient of %s variable \"%s\": only floating" + - " point variables have gradients", getVarName(), dataType()); + Preconditions.checkState(dataType().isFPType(), "Cannot get gradient of %s datatype variable \"%s\": only floating" + + " point variables have gradients", dataType(), getVarName()); return sameDiff.getGradForVariable(getVarName()); } @@ -222,21 +184,13 @@ public class SDVariable implements Serializable { * @return Shape of the variable */ public long[] getShape() { - if (variableType == VariableType.PLACEHOLDER && getArr() == null) { - if (shape != null) + if (variableType == VariableType.PLACEHOLDER ) { return shape; - else - return new long[0]; + } else if(variableType == VariableType.VARIABLE || variableType == VariableType.CONSTANT){ + return getArr().shape(); } - long[] initialShape = sameDiff.getShapeForVarName(getVarName()); - if(initialShape == null) { - val arr = getArr(); - if(arr != null) - return arr.shape(); - } - - return initialShape; + return null; } public void setShape(long... shape){ @@ -254,7 +208,7 @@ public class SDVariable implements Serializable { public DataType dataType() { if(this.dataType == null){ //Try to infer datatype instead of returning null - if(getArr() != null){ + if(variableType != VariableType.ARRAY && getArr() != null){ this.dataType = getArr().dataType(); } } @@ -1495,8 +1449,8 @@ public class SDVariable implements Serializable { * @return */ public INDArray eval() { - sameDiff.exec(null, getVarName()); - return getArr(); + Map m = sameDiff.output((Map)null, name()); + return m.get(name()); } @@ -1505,8 +1459,8 @@ public class SDVariable implements Serializable { * @return */ public INDArray eval(Map placeholders) { - sameDiff.exec(placeholders, getVarName()); - return getArr(); + Map m = sameDiff.output(placeholders, name()); + return m.get(name()); } @@ -1518,26 +1472,59 @@ public class SDVariable implements Serializable { /** * Add a control dependency for this variable on the specified variable.
- * Control depnedencies can be used to enforce the execution order. + * Control dependencies can be used to enforce the execution order. * For example, if a control dependency X->Y exists, then Y will only be executed after X is executed - even * if Y wouldn't normally depend on the result/values of X. * * @param controlDependency Control dependency to add for this variable */ public void addControlDependency(SDVariable controlDependency){ - String cdN = controlDependency.getVarName(); - String n = this.getVarName(); - Variable v = sameDiff.getVariables().get(n); - if(v.getControlDeps() == null) - v.setControlDeps(new ArrayList()); - if(!v.getControlDeps().contains(cdN)) - v.getControlDeps().add(cdN); + Variable vThis = sameDiff.getVariables().get(getVarName()); + Variable vCD = sameDiff.getVariables().get(controlDependency.name()); - Variable v2 = sameDiff.getVariables().get(cdN); - if(v2.getControlDepsForVar() == null) - v2.setControlDepsForVar(new ArrayList()); - if(!v2.getControlDepsForVar().contains(n)) - v2.getControlDepsForVar().add(n); + //If possible: add control dependency on ops + if(vThis.getOutputOfOp() != null && vCD.getOutputOfOp() != null ){ + //Op -> Op case + SameDiffOp oThis = sameDiff.getOps().get(vThis.getOutputOfOp()); + SameDiffOp oCD = sameDiff.getOps().get(vCD.getOutputOfOp()); + + if(oThis.getControlDeps() == null) + oThis.setControlDeps(new ArrayList()); + if(!oThis.getControlDeps().contains(oCD.getName())) + oThis.getControlDeps().add(oCD.getName()); + + if(oCD.getControlDepFor() == null) + oCD.setControlDepFor(new ArrayList()); + if(!oCD.getControlDepFor().contains(oThis.getName())) + oCD.getControlDepFor().add(oThis.getName()); + } else { + if(vThis.getOutputOfOp() != null){ + //const/ph -> op case + SameDiffOp oThis = sameDiff.getOps().get(vThis.getOutputOfOp()); + + if(oThis.getVarControlDeps() == null) + oThis.setVarControlDeps(new ArrayList()); + + if(!oThis.getVarControlDeps().contains(vCD.getName())) + oThis.getVarControlDeps().add(vCD.getName()); + + if(vCD.getControlDepsForOp() == null) + vCD.setControlDepsForOp(new ArrayList()); + if(!vCD.getControlDepsForOp().contains(oThis.getName())) + vCD.getControlDepsForOp().add(oThis.getName()); + } else { + //const/ph -> const/ph case + if(vThis.getControlDeps() == null) + vThis.setControlDeps(new ArrayList()); + if(!vThis.getControlDeps().contains(vCD.getName())) + vThis.getControlDeps().add(vCD.getName()); + + if(vCD.getControlDepsForVar() == null) + vCD.setControlDepsForVar(new ArrayList()); + if(!vCD.getControlDepsForVar().contains(vThis.getName())) + vCD.getControlDepsForVar().add(vThis.getName()); + } + } } /** @@ -1703,7 +1690,6 @@ public class SDVariable implements Serializable { SDVariable v = new SDVariable(); v.varName = varName; v.variableType = variableType; - v.weightInitScheme = weightInitScheme; v.shape = shape == null ? null : shape.clone(); v.dataType = dataType; v.sameDiff = sd; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index ddd9ecbb2..c79677c1e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -16,93 +16,32 @@ package org.nd4j.autodiff.samediff; -import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs; - import com.google.flatbuffers.FlatBufferBuilder; -import java.io.BufferedInputStream; -import java.io.BufferedOutputStream; -import java.io.DataOutputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.lang.reflect.Method; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.IdentityHashMap; -import java.util.LinkedHashMap; -import java.util.LinkedHashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Queue; -import java.util.Set; -import java.util.Stack; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Getter; -import lombok.NonNull; -import lombok.Setter; +import lombok.*; import lombok.extern.slf4j.Slf4j; -import lombok.val; import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.ArrayUtils; import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import org.nd4j.autodiff.execution.conf.OutputMode; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunctionFactory; -import org.nd4j.autodiff.listeners.At; -import org.nd4j.autodiff.listeners.Listener; -import org.nd4j.autodiff.listeners.ListenerResponse; -import org.nd4j.autodiff.listeners.Loss; -import org.nd4j.autodiff.listeners.Operation; +import org.nd4j.autodiff.listeners.*; import org.nd4j.autodiff.listeners.impl.HistoryListener; import org.nd4j.autodiff.listeners.records.History; import org.nd4j.autodiff.listeners.records.LossCurve; +import org.nd4j.autodiff.samediff.api.OutAndGrad; import org.nd4j.autodiff.samediff.config.BatchOutputConfig; import org.nd4j.autodiff.samediff.config.EvaluationConfig; import org.nd4j.autodiff.samediff.config.FitConfig; import org.nd4j.autodiff.samediff.config.OutputConfig; -import org.nd4j.autodiff.samediff.internal.AbstractSession; -import org.nd4j.autodiff.samediff.internal.DataTypesSession; -import org.nd4j.autodiff.samediff.internal.InferenceSession; -import org.nd4j.autodiff.samediff.internal.SameDiffOp; -import org.nd4j.autodiff.samediff.internal.Variable; -import org.nd4j.autodiff.samediff.ops.SDBaseOps; -import org.nd4j.autodiff.samediff.ops.SDBitwise; -import org.nd4j.autodiff.samediff.ops.SDCNN; -import org.nd4j.autodiff.samediff.ops.SDImage; -import org.nd4j.autodiff.samediff.ops.SDLoss; -import org.nd4j.autodiff.samediff.ops.SDMath; -import org.nd4j.autodiff.samediff.ops.SDNN; -import org.nd4j.autodiff.samediff.ops.SDRNN; -import org.nd4j.autodiff.samediff.ops.SDRandom; +import org.nd4j.autodiff.samediff.internal.*; +import org.nd4j.autodiff.samediff.ops.*; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; import org.nd4j.base.Preconditions; import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.ROC; -import org.nd4j.graph.ExecutionMode; -import org.nd4j.graph.FlatArray; -import org.nd4j.graph.FlatConfiguration; -import org.nd4j.graph.FlatGraph; -import org.nd4j.graph.FlatNode; -import org.nd4j.graph.FlatVariable; -import org.nd4j.graph.IntPair; -import org.nd4j.graph.OpType; -import org.nd4j.graph.UpdaterState; +import org.nd4j.graph.*; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -112,8 +51,6 @@ import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.api.ops.impl.controlflow.If; -import org.nd4j.linalg.api.ops.impl.controlflow.While; import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray; @@ -136,7 +73,6 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.GradientUpdater; import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.primitives.AtomicBoolean; -import org.nd4j.linalg.primitives.AtomicDouble; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.DeviceLocalNDArray; @@ -152,6 +88,17 @@ import org.nd4j.weightinit.impl.NDArraySupplierInitScheme; import org.nd4j.weightinit.impl.ZeroInitScheme; import org.tensorflow.framework.GraphDef; +import java.io.*; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs; + /** * SameDiff is the entrypoint for ND4J's automatic differentiation functionality. *

@@ -168,7 +115,7 @@ public class SameDiff extends SDBaseOps { protected static final String GRAD_FN_KEY = "grad"; //Fields for graph structure and execution - @Getter //TODO use package private instead of public getters? + @Getter private final Map variables = new LinkedHashMap<>(); //Use linked hash map to guarantee iteration order based on order they were added. Used in inputs() and flatbuffers serde @Getter private final Map ops = new LinkedHashMap<>(); @@ -185,6 +132,8 @@ public class SameDiff extends SDBaseOps { private final List nameScopes = new ArrayList<>(); //Used as a stack + private List outputs; //Names of the output variables, set by the user. + /////////////////////////////////////// //Fields related to training @Getter @@ -195,15 +144,8 @@ public class SameDiff extends SDBaseOps { private Map updaterMap; //GradientUpdater instance for each trainable parameter //////////////////////////////////////// - //map a function's instance id to a base name, used for propagating variable names - //for output during import - private Map baseNameForFunctionInstanceId; private DifferentialFunctionFactory functionFactory; - @Deprecated //TO BE REMOVED - to ShapeSession - private Map variableNameToShape; //Key: SDVariable name. Value: shape for that variable - @Deprecated //TO BE REMOVED - to Variable - private Map forwardVarForGrad; // counter for auto-naming variables private int variableId = 0; @@ -300,38 +242,7 @@ public class SameDiff extends SDBaseOps { return bitwise; } - - /** - * For import, many times we have variables - * that map to properties. Most common - * we will have an input to a function that is mapped to an ndarray. - * That ndarray is usually a scalar shape. - *

- * That array with a scalar shape can be something like an axis. - *

- * We often don't know that array's value till run time. - * This map stores variable names that we should resolve - * from samediff. We use the value of that array - * to update the properties. - */ - private Map> propertiesToResolve; - - /** - * A map of own name to - * the properties of the function (things like execution axes etc) - * The valid values can be: - * int - * long - * INDArray - */ - private Map> propertiesForFunction; - - @Deprecated //TO BE REMOVED - to Variable - private Map placeHolderOriginalShapes; - private Map sameDiffFunctionDefinitionMap; private Map sameDiffFunctionInstances; - private Set placeHolderFunctions; - private static Map opMethods; private Table fieldVariableResolutionMapping; @@ -342,9 +253,6 @@ public class SameDiff extends SDBaseOps { //debug mode variables @Getter private boolean debugMode; - private Map opsForResult; - private boolean resolvedVariables = false; - @Getter private Stack argumentInterceptors = new Stack<>(); @@ -363,110 +271,6 @@ public class SameDiff extends SDBaseOps { @Getter private SameDiff child; - static { - opMethods = new HashMap<>(); - Method[] methods = SameDiff.class.getDeclaredMethods(); - for (Method method : methods) { - if (method.getReturnType().equals(SDVariable.class)) { - opMethods.put(method.getName(), method); - } - } - } - - - /** - * Update the opName for the variable with the given vertex id - * - * @param varName the vertex id to update - * @param withName thew new opName - */ - public void updateVariableName(String varName, String withName) { - SDVariable oldVarNameRef = getVariable(varName); - Variable v = variables.remove(varName); - String oldVarName = varName; - oldVarNameRef.setVarName(withName); - v.setName(withName); - variables.put(withName, v); - - for (SameDiffOp op : ops.values()) { - List outputsOfOp = op.getOutputsOfOp(); - if (outputsOfOp != null && !outputsOfOp.isEmpty()) { - for (int i = 0; i < outputsOfOp.size(); i++) { - if (outputsOfOp.get(i).equals(oldVarName)) { - outputsOfOp.set(i, withName); - } - } - } - - List inputsToOp = op.getInputsToOp(); - if (inputsToOp != null && !inputsToOp.isEmpty()) { - for (int i = 0; i < inputsToOp.size(); i++) { - if (inputsToOp.get(i).equals(oldVarName)) { - inputsToOp.set(i, withName); - } - } - } - } - -// if (variableNameToArr.containsKey(oldVarName)) { -// val arr = variableNameToArr.remove(oldVarName); -// variableNameToArr.put(withName, arr); -// } - - - if (variableNameToShape.containsKey(oldVarName)) { - val shape = variableNameToShape.remove(oldVarName); - variableNameToShape.put(withName, shape); - } - - if (forwardVarForGrad.containsKey(oldVarName)) { - val forwardGrad = forwardVarForGrad.remove(oldVarName); - forwardVarForGrad.put(withName, forwardGrad); - } - - - if (v.getInputsForOp() != null) { - List funcNames = v.getInputsForOp(); - for (String s : funcNames) { - DifferentialFunction func = ops.get(s).getOp(); - if (func instanceof BaseOp) { - BaseOp baseOp = (BaseOp) func; - if (baseOp.getXVertexId() != null && baseOp.getXVertexId().equals(oldVarName)) { - baseOp.setXVertexId(withName); - } - - if (baseOp.getYVertexId() != null && baseOp.getYVertexId().equals(oldVarName)) { - baseOp.setYVertexId(withName); - } - - if (baseOp.getZVertexId() != null && baseOp.getZVertexId().equals(oldVarName)) { - baseOp.setZVertexId(withName); - } - - } - } - } - - - if (v.getOutputOfOp() != null) { - DifferentialFunction func = ops.get(v.getOutputOfOp()).getOp(); - if (func instanceof BaseOp) { - BaseOp baseOp = (BaseOp) func; - if (baseOp.getXVertexId() != null && baseOp.getXVertexId().equals(oldVarName)) { - baseOp.setXVertexId(withName); - } - - if (baseOp.getYVertexId() != null && baseOp.getYVertexId().equals(oldVarName)) { - baseOp.setYVertexId(withName); - } - - if (baseOp.getZVertexId() != null && baseOp.getZVertexId().equals(oldVarName)) { - baseOp.setZVertexId(withName); - } - } - } - } - /** * Clears debugging state and disables debug mode. @@ -604,9 +408,9 @@ public class SameDiff extends SDBaseOps { * } * SDVariable z = sd.var("z", DataType.FLOAT, 5); * - * String xName = x.getVarName(); //RESULT: "x" - * String yName = y.getVarName(); //RESULT: "myScope/y" - * String zName = z.getVarName(); //RESULT: "z" + * String xName = x.name(); //RESULT: "x" + * String yName = y.name(); //RESULT: "myScope/y" + * String zName = z.name(); //RESULT: "z" * } * *

@@ -620,7 +424,7 @@ public class SameDiff extends SDBaseOps { * x = sd.var("x", DataType.FLOAT, 5); * } * } - * String xName = x.getVarName(); //RESULT: "first/second/x" + * String xName = x.name(); //RESULT: "first/second/x" * } * * @@ -659,7 +463,7 @@ public class SameDiff extends SDBaseOps { public List getVariablesInScope(NameScope scope) { ArrayList vars = new ArrayList<>(); for (SDVariable v : variables()) { - if (v.getVarName().startsWith(scope.getName())) + if (v.name().startsWith(scope.getName())) vars.add(v); } return vars; @@ -683,7 +487,7 @@ public class SameDiff extends SDBaseOps { for (val var : variables()) { SDVariable clone = var.clone(this); SDVariable newVar = sameDiff.var(clone); - if (var.getArr() != null && var.getVariableType() != VariableType.ARRAY) { //ARRAY type = "activations" - are overwritten anyway + if (var.getVariableType() != VariableType.ARRAY && var.getArr() != null ) { //ARRAY type = "activations" - are overwritten anyway sameDiff.associateArrayWithVariable(var.getArr(), newVar); } @@ -795,9 +599,9 @@ public class SameDiff extends SDBaseOps { * @param function the function to get the inputs for * @return the input ids for a given function */ - public String[] getInputsForOp(DifferentialFunction function) { + public String[] getInputsForOp(@NonNull DifferentialFunction function) { if (!ops.containsKey(function.getOwnName())) - throw new ND4JIllegalStateException("Illegal function instance id found " + function.getOwnName()); + throw new ND4JIllegalStateException("Unknown function instance id found: \"" + function.getOwnName() + "\""); List inputs = ops.get(function.getOwnName()).getInputsToOp(); return inputs == null ? null : inputs.toArray(new String[inputs.size()]); } @@ -885,99 +689,6 @@ public class SameDiff extends SDBaseOps { } - /** - * Get the shape for the given vertex id. - * Note that if an array is defined, it will use the shape of the array instead. - *

- * A shape *and* an array should not be defined at the same time. - * This wastes memory. The internal map used for tracking shapes for particular - * vertex ids should also delete redundant shapes stored to avoid redundant sources of information. - * - * @param varName the vertex id to get the shape for - * @return the shape for the given vertex if any. - */ - public long[] getShapeForVarName(String varName) { - if (arrayAlreadyExistsForVarName(varName)) { - return getVariable(varName).getArr().shape(); - } - return variableNameToShape.get(varName); - } - - /** - * See {@link #getShapeForVarName(String)}, but returns the shape descriptor. - */ - public LongShapeDescriptor getShapeDescriptorForVarName(String varName) { - if (getVariable(varName).getArr() != null) { - return getVariable(varName).getArr().shapeDescriptor(); - } - // FIXME: do we really want this Nd4j.dataType() here? - return LongShapeDescriptor.fromShape(variableNameToShape.get(varName), Nd4j.dataType()); - } - - - /** - * Associate a vertex id with the given shape. - * - * @param varName the vertex id to associate - * @param shape the shape to associate with - * @see #putShapeForVarName(String, long[]) - * @see #putOrUpdateShapeForVarName(String, long[], boolean) - */ - @Deprecated - public void putShapeForVarName(String varName, long[] shape) { - if (shape == null) { - throw new ND4JIllegalStateException("Shape must not be null!"); - } - - if (variableNameToShape.containsKey(varName)) { - throw new ND4JIllegalStateException("Shape for " + varName + " already exists!"); - } - - variableNameToShape.put(varName, shape); - } - - - /** - * Sets the shape descriptor for a variable. - */ - public void putShapeForVarName(String varName, LongShapeDescriptor shape) { - val v = getVariable(varName); - putShapeForVarName(varName, shape.getShape()); - v.setDataType(shape.dataType()); - } - - /** - * Put or update the shape for the given variable name. Optionally supports clearing the specified variable's - * INDArray if it's shape does not match the new shape - * - * @param varName Variable name - * @param shape Shape to put - * @param clearArrayOnShapeMismatch If false: no change to arrays. If true: if an INDArray is defined for the specified - * variable name, it will be removed from the graph (to be later re-generated) if - * its shape does not match the specified shape - */ - @Deprecated - public void putOrUpdateShapeForVarName(String varName, long[] shape, boolean clearArrayOnShapeMismatch) { - Preconditions.checkNotNull(shape, "Cannot put null shape for variable: %s", varName); - if (variableNameToShape.containsKey(varName)) { -// updateShapeForVarName(varName, shape, clearArrayOnShapeMismatch); - //TODO - } else { - putShapeForVarName(varName, shape); - } - } - - /** - * Returns true if the given vertex id and shape already exist. - * - * @param varName the vertex id - * @return true if the ndarray and vertex id already exist - */ - public boolean shapeAlreadyExistsForVarName(String varName) { - return variableNameToShape.containsKey(varName) || arrayAlreadyExistsForVarName(varName); - } - - /** * Returns true if the given vertex id and {@link INDArray} already exist. * @@ -1013,11 +724,6 @@ public class SameDiff extends SDBaseOps { SDVariable v = variables.get(varName).getVariable(); switch (v.getVariableType()) { case VARIABLE: - if (!variablesArrays.containsKey(varName)) { - //VARIBALE type arrays should have a parameter initializer... - // we should use this to azy init the array if none is present - v.storeAndAllocateNewArray(); - } return variablesArrays.get(varName).get(); case CONSTANT: if (!constantArrays.containsKey(varName)) @@ -1069,7 +775,7 @@ public class SameDiff extends SDBaseOps { arr = arr.castTo(variable.dataType()); Preconditions.checkState(variable.dataType() == arr.dataType(), "Variable \"%s\" has datatype %s: cannot associate array with type %s with this variable", - variable.getVarName(), variable.dataType(), arr.dataType()); + variable.name(), variable.dataType(), arr.dataType()); if (sessions.get(Thread.currentThread().getId()) == null) { sessions.put(Thread.currentThread().getId(), new InferenceSession(this)); @@ -1096,18 +802,14 @@ public class SameDiff extends SDBaseOps { switch (variable.getVariableType()) { case VARIABLE: - variablesArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads + variablesArrays.put(variable.name(), new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads break; case CONSTANT: - constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true)); + constantArrays.put(variable.name(), new DeviceLocalNDArray(arr, true)); break; case ARRAY: - // FIXME: remove this before release - val session = sessions.get(Thread.currentThread().getId()); - val varId = session.newVarId(variable.getVarName(), AbstractSession.OUTER_FRAME, 0, null); - session.getNodeOutputs().put(varId, arr); - //throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY"); - break; + throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY - arrays for" + + " this type of variable is calculated "); case PLACEHOLDER: //Validate placeholder shapes: long[] phShape = variable.placeholderShape(); @@ -1120,19 +822,19 @@ public class SameDiff extends SDBaseOps { if (!placeholdersPerThread.containsKey(tid)) { placeholdersPerThread.put(tid, new HashMap()); } - placeholdersPerThread.get(tid).put(variable.getVarName(), arr); + placeholdersPerThread.get(tid).put(variable.name(), arr); break; default: throw new IllegalStateException("Unknown variable type: " + variable.getVariableType()); } - //putOrUpdateShapeForVarName(variable.getVarName(), arr.shape(), true); + //putOrUpdateShapeForVarName(variable.name(), arr.shape(), true); //Also update nested SameDiff instances (such as gradient function) if (sameDiffFunctionInstances != null && sameDiffFunctionInstances.size() > 0) { for (Map.Entry e : sameDiffFunctionInstances.entrySet()) { SameDiff sd = e.getValue(); - SDVariable v = sd.getVariable(variable.getVarName()); + SDVariable v = sd.getVariable(variable.name()); if (v != null) { sd.associateArrayWithVariable(arr, v); } @@ -1150,16 +852,16 @@ public class SameDiff extends SDBaseOps { */ public void assignArray(@NonNull INDArray arr, @NonNull SDVariable variable){ Preconditions.checkState(variable.getVariableType() == VariableType.VARIABLE || variable.getVariableType() == VariableType.CONSTANT, - "assignArray method can only be used with VARIBLE or CONSTANT type SDVariables, variable \"%s\" has type %s", variable.getVarName(), variable.getVariableType()); + "assignArray method can only be used with VARIBLE or CONSTANT type SDVariables, variable \"%s\" has type %s", variable.name(), variable.getVariableType()); //DeviceLocal doesn't work with views if(arr.isView()) arr = arr.dup(); if(variable.getVariableType() == VariableType.VARIABLE ){ - variablesArrays.get(variable.getVarName()).update(arr); + variablesArrays.get(variable.name()).update(arr); } else { - constantArrays.get(variable.getVarName()).update(arr); + constantArrays.get(variable.name()).update(arr); } } @@ -1192,38 +894,6 @@ public class SameDiff extends SDBaseOps { return ret; } - - /** - * Invoke an op by opName - * - * @param op the op - * @param x the first input - * @param y the second input - * @return the result variable - */ - @Deprecated //TO BE REMOVED - should not be part of public API - public SDVariable invoke(Op op, SDVariable x, SDVariable y) { - if (!opMethods.containsKey(op.opName())) { - throw new ND4JIllegalStateException("Illegal method opName " + op.opName()); - } - - if (x != null && y != null) { - try { - return (SDVariable) opMethods.get(op.opName()).invoke(this, x, y); - } catch (Exception e) { - - } - } else { - try { - return (SDVariable) opMethods.get(op.opName()).invoke(this, x); - } catch (Exception e) { - - } - } - - throw new ND4JIllegalStateException("Illegal method opName " + op.opName()); - } - /** * The set of defined SameDiff function names. SameDiff function instances should not be confused * with DifferentialFunction ops; an example of a SameDiff function instance is the gradient "grad" function @@ -1234,155 +904,10 @@ public class SameDiff extends SDBaseOps { return this.sameDiffFunctionInstances.keySet(); } - /** - * Invoke an op by opName - * - * @param op the op - * @param x the first input - * @return the result variable - */ - public SDVariable invoke(Op op, SDVariable x) { - return invoke(op, x, null); - } - private SameDiff() { functionFactory = new DifferentialFunctionFactory(this); - sameDiffFunctionDefinitionMap = new LinkedHashMap<>(); sameDiffFunctionInstances = new LinkedHashMap<>(); - forwardVarForGrad = new LinkedHashMap<>(); - opsForResult = new IntArrayKeyMap<>(); - variableNameToShape = new LinkedHashMap<>(); - placeHolderOriginalShapes = new LinkedHashMap<>(); - placeHolderFunctions = new LinkedHashSet<>(); - baseNameForFunctionInstanceId = new LinkedHashMap<>(); - propertiesToResolve = new LinkedHashMap<>(); - propertiesForFunction = new LinkedHashMap<>(); fieldVariableResolutionMapping = HashBasedTable.create(); - - } - - /** - * Adds a property that needs to be resolve for later. - * These variables are typically values that are arrays - * that are named but have an unknown value till execution time. - *

- * This is very common for model import. - * - * @param forFunction the function to add the property to resolve for - * @param arrayName the array name - */ - public void addPropertyToResolve(DifferentialFunction forFunction, String arrayName) { - if (!propertiesToResolve.containsKey(forFunction.getOwnName())) { - List newVal = new ArrayList<>(); - newVal.add(arrayName); - propertiesToResolve.put(forFunction.getOwnName(), newVal); - } else { - List newVal = propertiesToResolve.get(forFunction.getOwnName()); - newVal.add(arrayName); - } - } - - /** - * Remove a property to resolve added with {@link #addPropertyToResolve(DifferentialFunction, String)} - * - * @param forFunction the function to add the property to resolve for - * @param arrayName the array name - */ - public void removePropertyToResolve(DifferentialFunction forFunction, String arrayName) { - if (propertiesToResolve.containsKey(forFunction.getOwnName())) { - List newVal = propertiesToResolve.get(forFunction.getOwnName()); - newVal.remove(arrayName); - } - } - - /** - * Return the properties to resolve for the given function. - * This is typically used right before execution in model import in - * {@link DifferentialFunction#resolvePropertiesFromSameDiffBeforeExecution()} - * - * @param function the function get the properties to resolve for - * @return the properties to resolve for the given function - */ - public List propertiesToResolveForFunction(DifferentialFunction function) { - if (!propertiesToResolve.containsKey(function.getOwnName())) - return Collections.emptyList(); - - return propertiesToResolve.get(function.getOwnName()); - } - - - private void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, Object propertyValue) { - if (!propertiesForFunction.containsKey(functionFor.getOwnName())) { - Map fields = new LinkedHashMap<>(); - fields.put(propertyName, propertyValue); - propertiesForFunction.put(functionFor.getOwnName(), fields); - } else { - val fieldMap = propertiesForFunction.get(functionFor.getOwnName()); - if (fieldMap.containsKey(propertyName)) { - throw new ND4JIllegalStateException("Attempting to override property " + propertyName); - } - - fieldMap.put(propertyName, propertyValue); - } - } - - - /** - * Adds a field name -> variable name mapping for a given function.
- * This is used for model import where there is an unresolved variable at the time of calling any - * {@link org.nd4j.imports.graphmapper.GraphMapper#importGraph(File)} - * . - *

- * This data structure is typically accessed during {@link DifferentialFunction#resolvePropertiesFromSameDiffBeforeExecution()} - *

- * When a function attempts to resolve variables right before execution, there needs to be a way of knowing - * which variable in a samediff graph should map to a function's particular field name - * - * @param function the function to map - * @param fieldName the field name for the function to map - * @param varName the variable name of the array to get from samediff - */ - public void addVariableMappingForField(DifferentialFunction function, String fieldName, String varName) { - fieldVariableResolutionMapping.put(function.getOwnName(), fieldName, varName); - } - - /** - * Get the variable name to use - * for resolving a given field - * for a given function during import time. - * This method is u sed during {@link DifferentialFunction#resolvePropertiesFromSameDiffBeforeExecution()} - * - * @param function the function to get the variable name for - * @param fieldName the field name to resolve for - * @return the resolve variable name if any - */ - public String getVarNameForFieldAndFunction(DifferentialFunction function, String fieldName) { - return fieldVariableResolutionMapping.get(function.getOwnName(), fieldName); - } - - /** - * Sets a base name for the function id. - * This is used for when calling {@link #generateOutputVariableForOp(DifferentialFunction, String)} - * for ensuring original names for model import map to current samediff names - * when names are generated. - * - * @param baseName the base name to add - * @param function the function to declare a base name for. - */ - public void setBaseNameForFunctionInstanceId(String baseName, DifferentialFunction function) { - baseNameForFunctionInstanceId.put(function.getOwnName(), baseName); - } - - /** - * Returns the base name for the given function - * if any (may return null) - * - * @param function the function to get the base name for - * @return the base name for the given function (if any) based - * on the function's instance id. - */ - public String getBaseNameForFunction(DifferentialFunction function) { - return baseNameForFunctionInstanceId.get(function.getOwnName()); } @@ -1418,7 +943,7 @@ public class SameDiff extends SDBaseOps { public void addOutgoingFor(SDVariable[] variables, DifferentialFunction function) { String[] varNames = new String[variables.length]; for (int i = 0; i < varNames.length; i++) { - varNames[i] = variables[i].getVarName(); + varNames[i] = variables[i].name(); } addOutgoingFor(varNames, function); @@ -1557,7 +1082,7 @@ public class SameDiff extends SDBaseOps { if (interceptor != null) { pauseArgumentInterceptor(interceptor); for (int i = 0; i < variables.length; i++) { - variables[i] = interceptor.intercept(getVariable(variables[i])).getVarName(); + variables[i] = interceptor.intercept(getVariable(variables[i])).name(); } unpauseArgumentInterceptor(interceptor); } @@ -1565,13 +1090,6 @@ public class SameDiff extends SDBaseOps { if (function.getOwnName() == null) throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly"); - //double check if function contains placeholder args - for (val varName : variables) { - if (isPlaceHolder(varName)) { - placeHolderFunctions.add(function.getOwnName()); - } - } - //Add function if it doesn't exist //TODO could "not existing" be a bug sometimes? if (!ops.containsKey(function.getOwnName())) { @@ -1604,7 +1122,7 @@ public class SameDiff extends SDBaseOps { for (int i = 0; i < varNames.length; i++) { if (variables[i] == null) throw new ND4JIllegalStateException("Found null variable at index " + i); - varNames[i] = variables[i].getVarName(); + varNames[i] = variables[i].name(); } addArgsFor(varNames, function); } @@ -1619,25 +1137,8 @@ public class SameDiff extends SDBaseOps { function.getOwnName() + " only has " + function.args().length + " args but you are trying" + "to replace the argument at " + i); - String oldName = function.arg(i).getVarName(); - String newName = newArg.getVarName(); - - if (function.arg(i).isPlaceHolder() && !newArg.isPlaceHolder()) { - boolean otherPlaceholders = false; - for (int j = 0; j < function.argNames().length; j++) { - if (j == i) - continue; - - if (function.arg(j).isPlaceHolder()) - otherPlaceholders = true; - } - - if (!otherPlaceholders) - placeHolderFunctions.remove(function.getOwnName()); - } else if (!function.arg(i).isPlaceHolder() && newArg.isPlaceHolder()) { - if (!placeHolderFunctions.contains(function.getOwnName())) - placeHolderFunctions.add(function.getOwnName()); - } + String oldName = function.arg(i).name(); + String newName = newArg.name(); List oldArgs = ops.get(function.getOwnName()).getInputsToOp(); oldArgs = new ArrayList<>(oldArgs); @@ -1761,8 +1262,6 @@ public class SameDiff extends SDBaseOps { if (variables != null ? !variables.equals(sameDiff.variables) : sameDiff.variables != null) return false; - if (sameDiffFunctionDefinitionMap != null ? !sameDiffFunctionDefinitionMap.equals(sameDiff.sameDiffFunctionDefinitionMap) : sameDiff.sameDiffFunctionDefinitionMap != null) - return false; return sameDiffFunctionInstances != null ? sameDiffFunctionInstances.equals(sameDiff.sameDiffFunctionInstances) : sameDiff.sameDiffFunctionInstances == null; } @@ -1822,42 +1321,37 @@ public class SameDiff extends SDBaseOps { } /** - * Outputs are those variables (not placeholders, constants, etc) that are the output of a function that aren't the - * input to any other ops. - * Usually these are the output of the last function(s) in the SameDiff instance. + * Outputs are the names of the predictions of the network. + * Note that the outputs must be set using {@link #setOutputs(List)} first * - * @return The (inferred) outputs of the SameDiff instance, in no particular order + * @return The outputs of the SameDiff instance, or null if no outputs have been set */ public List outputs() { - List out = new ArrayList<>(); - for (Variable v : variables.values()) { - if (v.getVariable().isConstant() || v.getVariable().isPlaceHolder() || //Exclude constants and placeholders - (v.getInputsForOp() != null && !v.getInputsForOp().isEmpty()) || //Exclude variables that are inputs to ops - (v.getControlDepsForOp() != null && !v.getControlDepsForOp().isEmpty()) || //Exclude variables that are control dependency inputs to ops - (v.getControlDepsForVar() != null && !v.getControlDepsForVar().isEmpty())) { //Exclude variables that are control dependency inputs to other variables (mainly for import of cond etc ops) - continue; + return this.outputs; + } + + /** + * See {@link #setOutputs(List)} + */ + public void setOutputs(String... outputs){ + setOutputs(outputs == null ? null : Arrays.asList(outputs)); + } + + + /** + * Set the outputs of the SameDiff instance. + * Outputs are the names of the variables that are the predictions of the neural network. + * Note that this is merely a convenience, and does not impact execution at all. Outputs can be retrieved (after + * setting here) using {@link #outputs()} + * @param outputs Outputs to set. Must be valid variable names in this SameDiff instance + */ + public void setOutputs(List outputs){ + if(outputs != null){ + for(String s : outputs){ + Preconditions.checkArgument(variables.containsKey(s), "Cannot set variable \"%s\" as an output: SameDiff instance does not contain a variable with this name"); } - - //Also exclude assert etc ops - doesn't make sense to return these "outputs" to user - if (v.getOutputOfOp() != null) { - String opName = v.getOutputOfOp(); - SameDiffOp o = ops.get(opName); - if (o.getOp() instanceof Assert) { - continue; - } - - //A bit of a hack for TF import: some TF graphs have Switch ops, where the output of one branch isn't consumed - // by any ops. Consequently, during execution this "output" might never be available. So we'll exclude the output of execution here - // This applies to SameDiff while loops as well - if (o.getOp() instanceof Switch) { - continue; - } - } - - - out.add(v.getName()); } - return out; + this.outputs = outputs; } /** @@ -1903,7 +1397,7 @@ public class SameDiff extends SDBaseOps { public void setLossVariables(@NonNull SDVariable... lossVariables) { String[] varNames = new String[lossVariables.length]; for (int i = 0; i < lossVariables.length; i++) - varNames[i] = lossVariables[i].getVarName(); + varNames[i] = lossVariables[i].name(); setLossVariables(varNames); } @@ -1932,7 +1426,7 @@ public class SameDiff extends SDBaseOps { * See {@link #addLossVariable(String)} */ public void addLossVariable(@NonNull SDVariable variable) { - addLossVariable(variable.getVarName()); + addLossVariable(variable.name()); } /** @@ -2149,14 +1643,45 @@ public class SameDiff extends SDBaseOps { Set requiredVars = new HashSet<>(); for (Listener l : activeListeners) { - requiredVars.addAll(l.requiredVariables(this).trainingVariables()); + ListenerVariables lv = l.requiredVariables(this); + if(lv != null) { + Set s = lv.trainingVariables(); + if(s != null) { + requiredVars.addAll(s); + } + } } - ArrayList listenersWitHistory = new ArrayList<>(listeners); + List listenersWitHistory = new ArrayList<>(listeners); + for(Listener l : this.listeners){ + if(!listenersWitHistory.contains(l)) + listenersWitHistory.add(l); + } listenersWitHistory.add(history); - for (int i = 0; i < numEpochs; i++) { + SameDiff gradInstance = getFunction("grad"); + if(gradInstance == null){ + createGradFunction(); + gradInstance = getFunction("grad"); + } + TrainingSession ts = new TrainingSession(gradInstance); + gradInstance.setTrainingConfig(trainingConfig); //In case any listeners want to use it + + for(Listener l : activeListeners){ + l.operationStart(gradInstance, Operation.TRAINING); + } + + Set paramsToTrain = new LinkedHashSet<>(); + for(Variable v : variables.values()){ + if(v.getVariable().getVariableType() == VariableType.VARIABLE){ + //TODO not all variable type are needed - i.e., variable that doesn't impact loss should be skipped + paramsToTrain.add(v.getName()); + } + } + + Loss lastLoss = null; + for (int i = 0; i < numEpochs; i++) { if (incrementEpochCount && hasListeners) { at.setEpoch(trainingConfig.getEpochCount()); for (Listener l : activeListeners) { @@ -2198,155 +1723,39 @@ public class SameDiff extends SDBaseOps { Map placeholders = toPlaceholderMap(ds); Preconditions.checkState(placeholders.size() > 0, "No placeholder variables were set for training"); - resolveVariablesWith(placeholders); - //Calculate gradients: - execBackwards(placeholders, at.operation(), ds, requiredVars, activeListeners); - - - //Apply updater: + //Call TrainingSession to perform training if (!initializedTraining) initializeTraining(); - Map, AtomicDouble> regScore = null; //Holds regularization scores for later reporting to listeners - if (hasListeners) { - regScore = new HashMap<>(); - } + lastLoss = ts.trainingIteration( + trainingConfig, + placeholders, + paramsToTrain, + updaterMap, + ds, + getLossVariables(), + listenersWitHistory, + at); - int iteration = trainingConfig.getIterationCount(); - int e = trainingConfig.getEpochCount(); - for (Variable v : variables.values()) { - //Only update trainable params - float type parameters (variable type vars) - SDVariable sdv = v.getVariable(); - if (sdv.getVariableType() != VariableType.VARIABLE || !sdv.dataType().isFPType()) - continue; - - - INDArray param = sdv.getArr(); - SDVariable gradVar = sdv.getGradient(); - if (gradVar == null) { - //Not all trainable parameters have gradients defined. - //Consider graph: in1->loss1; in2->loss2, where we optimize only loss1. - //No gradient will be present for in2, because in2 doesn't impact loss1 at all - continue; - } - INDArray grad = gradVar.getArr(); - //Note: don't need to divide by minibatch - that should be handled in loss function and hence loss function gradients, - // which should flow through to here - - //Pre-apply regularization (L1, L2) - List r = trainingConfig.getRegularization(); - int iterCount = trainingConfig.getIterationCount(); - int epochCount = trainingConfig.getEpochCount(); - double lr = trainingConfig.getUpdater().hasLearningRate() ? trainingConfig.getUpdater().getLearningRate(iteration, epochCount) : 1.0; - if (r != null && r.size() > 0) { - for (Regularization reg : r) { - if (reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER) { - reg.apply(param, grad, lr, iterCount, epochCount); - } - } - } - - //Apply updater. Note that we need to reshape to [1,length] for updater - INDArray reshapedView = Shape.newShapeNoCopy(grad, new long[]{1, grad.length()}, grad.ordering() == 'f'); //TODO make sure we always reshape in same order! - Preconditions.checkState(reshapedView != null, "Error reshaping array for parameter \"%s\": array is a view?", sdv); - GradientUpdater u = updaterMap.get(sdv.getVarName()); - try { - u.applyUpdater(reshapedView, iteration, e); - } catch (Throwable t) { - throw new RuntimeException("Error applying updater " + u.getClass().getSimpleName() + " to parameter \"" + sdv.getVarName() - + "\": either parameter size is inconsistent between iterations, or \"" + sdv.getVarName() + "\" should not be a trainable parameter?", t); - } - - //Post-apply regularization (weight decay) - if (r != null && r.size() > 0) { - for (Regularization reg : r) { - if (reg.applyStep() == Regularization.ApplyStep.POST_UPDATER) { - reg.apply(param, grad, lr, iterCount, epochCount); - if (hasListeners) { - double score = reg.score(param, iterCount, epochCount); - if (!regScore.containsKey(reg.getClass())) { - regScore.put(reg.getClass(), new AtomicDouble()); - } - regScore.get(reg.getClass()).addAndGet(score); - } - } - } - } - - if (hasListeners) { - for (Listener l : activeListeners) { - if (l.isActive(at.operation())) - l.preUpdate(this, at, v, reshapedView); - } - } - - - if (trainingConfig.isMinimize()) { - param.subi(grad); - } else { - param.addi(grad); - } - } - - double[] d = new double[lossVariables.size() + regScore.size()]; - List lossVars; - if (regScore.size() > 0) { - lossVars = new ArrayList<>(lossVariables.size() + regScore.size()); - lossVars.addAll(lossVariables); - int s = regScore.size(); - //Collect regularization losses - for (Map.Entry, AtomicDouble> entry : regScore.entrySet()) { - lossVars.add(entry.getKey().getSimpleName()); - d[s] = entry.getValue().get(); - } - } else { - lossVars = lossVariables; - } - - //Collect the losses... - SameDiff gradFn = sameDiffFunctionInstances.get(GRAD_FN_KEY); - int count = 0; - for (String s : lossVariables) { - INDArray arr = gradFn.getArrForVarName(s); - double l = arr.isScalar() ? arr.getDouble(0) : arr.sumNumber().doubleValue(); - d[count++] = l; - } - - Loss loss = new Loss(lossVars, d); - - if (lossNames == null) { - lossNames = lossVars; - } else { - Preconditions.checkState(lossNames.equals(lossVars), - "Loss names mismatch, expected: %s, got: %s", lossNames, lossVars); - } if (lossSums == null) { - lossSums = d; + lossSums = lastLoss.getLosses().clone(); } else { - Preconditions.checkState(lossNames.equals(lossVars), - "Loss size mismatch, expected: %s, got: %s", lossSums.length, d.length); - for (int j = 0; j < lossSums.length; j++) { - lossSums[j] += d[j]; + lossSums[j] += lastLoss.getLosses()[j]; } } lossCount++; - if (hasListeners) { - for (Listener l : activeListeners) { - l.iterationDone(this, at, ds, loss); - } - - } - trainingConfig.incrementIterationCount(); } long epochTime = System.currentTimeMillis() - epochStartTime; if (incrementEpochCount) { + lossNames = lastLoss.getLossNames(); + for (int j = 0; j < lossSums.length; j++) lossSums[j] /= lossCount; @@ -2356,14 +1765,13 @@ public class SameDiff extends SDBaseOps { lossCurve = new LossCurve(lossSums, lossNames); } + if (incrementEpochCount) { if (hasListeners) { - boolean doStop = false; Listener stopped = null; for (Listener l : activeListeners) { - ListenerResponse res = l.epochEnd(this, at, lossCurve, epochTime); if (res == ListenerResponse.STOP && (i < numEpochs - 1)) { @@ -2431,7 +1839,6 @@ public class SameDiff extends SDBaseOps { trainingConfig.incrementEpochCount(); } - if (i < numEpochs - 1) { iter.reset(); } @@ -2448,9 +1855,12 @@ public class SameDiff extends SDBaseOps { */ private void validateListenerActivations(List listeners, Operation op) { for (Listener l : listeners) { - for (String s : l.requiredVariables(this).requiredVariables(op)) { - if (!variables.containsKey(s)) { - Preconditions.checkState(false, "Listener %s requested variable %s that is not defined in this SameDiff graph", l, s); + ListenerVariables lv = l.requiredVariables(this); + if(lv != null) { + for (String s : lv.requiredVariables(op)) { + if (!variables.containsKey(s)) { + Preconditions.checkState(false, "Listener %s requested variable %s that is not defined in this SameDiff graph", l, s); + } } } } @@ -2507,7 +1917,9 @@ public class SameDiff extends SDBaseOps { INDArray arr = v.getVariable().getArr(); long stateSize = trainingConfig.getUpdater().stateSize(arr.length()); INDArray view = stateSize == 0 ? null : Nd4j.createUninitialized(arr.dataType(), 1, stateSize); - updaterMap.put(v.getName(), trainingConfig.getUpdater().instantiate(view, true)); + GradientUpdater gu = trainingConfig.getUpdater().instantiate(view, false); + gu.setStateViewArray(view, arr.shape(), arr.ordering(), true); + updaterMap.put(v.getName(), gu); } initializedTraining = true; @@ -2753,31 +2165,20 @@ public class SameDiff extends SDBaseOps { if (hasListeners) { for (Listener l : activeListeners) { - requiredVars.addAll(l.requiredVariables(this).evaluationVariables()); + ListenerVariables v = l.requiredVariables(this); + if(v != null) { + requiredVars.addAll(v.evaluationVariables()); + } } } String[] requiredVarsArr = requiredVars.toArray(new String[0]); while (iterator.hasNext()) { - long dataStart = hasListeners ? System.currentTimeMillis() : 0; MultiDataSet ds = iterator.next(); - long dataEnd = hasListeners ? System.currentTimeMillis() : 0; Map placeholderMap = toPlaceholderMap(ds); - Map m; - Map outs = null; - if (hasListeners) { - - for (Listener l : activeListeners) { - l.iterationStart(this, at, ds, (dataEnd - dataStart)); - } - - m = directExecHelper(placeholderMap, at, ds, Collections.emptyList(), activeListeners, requiredVarsArr); - } else { - m = directExecHelper(placeholderMap, at, ds, Collections.emptyList(), activeListeners, requiredVarsArr); - } - + Map m = directExecHelper(placeholderMap, at, ds, Collections.emptyList(), activeListeners, requiredVarsArr); for (Map.Entry> e : variableEvals.entrySet()) { INDArray prediction = m.get(e.getKey()); @@ -2790,15 +2191,6 @@ public class SameDiff extends SDBaseOps { } } - if (hasListeners) { - for (Listener l : activeListeners) { - Map outVars = Maps.newHashMap( - Maps.filterKeys(outs, - Predicates.in(l.requiredVariables(this).evaluationVariables()))); - l.iterationDone(this, at, ds, null); - } - } - at.setIteration(at.iteration() + 1); } @@ -2977,7 +2369,7 @@ public class SameDiff extends SDBaseOps { * INDArray out = sd.output() * .data(data) * .output("pred") - * .execSingle(); + * .outputSingle(); * } * */ @@ -3013,7 +2405,7 @@ public class SameDiff extends SDBaseOps { if (outputs != null && outputs.length != 0) { neededOutputs = Arrays.asList(outputs); } else { - neededOutputs = outputs(); + neededOutputs = getLossVariables(); } String[] neededOutputsArr = neededOutputs.toArray(new String[0]); @@ -3083,7 +2475,7 @@ public class SameDiff extends SDBaseOps { * .output("out") * .input("x", xValue) * .input(y, yValue) - * .execSingle(); + * .outputSingle(); * } * */ @@ -3091,14 +2483,6 @@ public class SameDiff extends SDBaseOps { return new BatchOutputConfig(this); } - /** - * @deprecated See {@link #outputAll(Map)} and {@link #batchOutput()} - */ - @Deprecated - public Map execAll(Map placeholders) { - return outputAll(placeholders); - } - /** * Do inference for all variables for a single batch. *

@@ -3109,15 +2493,6 @@ public class SameDiff extends SDBaseOps { public Map outputAll(Map placeholders) { return batchOutput().outputAll().inputs(placeholders).exec(); } - - /** - * @deprecated See {@link #outputSingle(Map, String)} and {@link #batchOutput()} - */ - @Deprecated - public INDArray execSingle(Map placeholders, String output) { - return outputSingle(placeholders, output); - } - /** * Do inference for a single variable for a single batch. *

@@ -3129,14 +2504,6 @@ public class SameDiff extends SDBaseOps { return batchOutput().output(output).inputs(placeholders).execSingle(); } - /** - * @deprecated See {@link #output(Map, List)} and {@link #batchOutput()} - */ - @Deprecated - public Map exec(Map placeholders, List outputs) { - return output(placeholders, outputs); - } - /** * Do inference for the given variables for a single batch. *

@@ -3144,16 +2511,8 @@ public class SameDiff extends SDBaseOps { *

* Special case of {@link #batchOutput()}. */ - public Map output(Map placeholders, List outputs) { - return batchOutput().output(outputs.toArray(new String[0])).inputs(placeholders).exec(); - } - - /** - * @deprecated See {@link #output(Map, String...)} and {@link #batchOutput()} - */ - @Deprecated - public Map exec(Map placeholders, String... outputs) { - return output(placeholders, outputs); + public Map output(Map placeholders, @NonNull List outputs) { + return batchOutput().output(outputs.toArray(new String[0])).inputs(placeholders).output(); } /** @@ -3164,7 +2523,7 @@ public class SameDiff extends SDBaseOps { * Special case of {@link #batchOutput()}. */ public Map output(Map placeholders, String... outputs) { - return batchOutput().output(outputs).inputs(placeholders).exec(); + return batchOutput().output(outputs).inputs(placeholders).output(); } @@ -3177,31 +2536,36 @@ public class SameDiff extends SDBaseOps { * @param listeners Additional listeners to use during this operation. * @param outputs The variables to output and return. */ - public Map output(Map placeholders, @NonNull List listeners, String... outputs) { - return batchOutputHelper(placeholders, listeners, outputs); + public Map output(Map placeholders, List listeners, String... outputs) { + return batchOutputHelper(placeholders, listeners, Operation.INFERENCE, outputs); } - protected Map batchOutputHelper(Map placeholders, @NonNull List listeners, String... outputs) { + protected Map batchOutputHelper(Map placeholders, List listeners, Operation operation, String... outputs) { List activeListeners = new ArrayList<>(); + if(operation == null) + operation = Operation.INFERENCE; + for (Listener l : this.listeners) - if (l.isActive(Operation.INFERENCE)) + if (l.isActive(operation)) activeListeners.add(l); - for (Listener l : listeners) - if (l.isActive(Operation.INFERENCE)) - activeListeners.add(l); - - for (Listener l : activeListeners) { - l.operationStart(this, Operation.INFERENCE); + if(listeners != null) { + for (Listener l : listeners) + if (l.isActive(operation)) + activeListeners.add(l); } - validateListenerActivations(activeListeners, Operation.INFERENCE); + for (Listener l : activeListeners) { + l.operationStart(this, operation); + } - Map ret = directExecHelper(placeholders, At.defaultAt(Operation.INFERENCE), null, Collections.emptyList(), activeListeners, outputs); + validateListenerActivations(activeListeners, operation); + + Map ret = directExecHelper(placeholders, At.defaultAt(operation), null, Collections.emptyList(), activeListeners, outputs); for (Listener l : activeListeners) { - l.operationEnd(this, Operation.INFERENCE); + l.operationEnd(this, operation); } return ret; } @@ -3236,7 +2600,7 @@ public class SameDiff extends SDBaseOps { /** * See {@link #one(String, DataType, int...)}. - * Creates a VARIABLE type SDVariable. + * Creates a constant - i.e., CONSTANT type SDVariable. * Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}). */ public SDVariable one(String name, int... shape) { @@ -3245,7 +2609,7 @@ public class SameDiff extends SDBaseOps { /** * See {@link #one(String, DataType, long...)}. - * Creates a VARIABLE type SDVariable. + * Creates a constant - i.e., CONSTANT type SDVariable. * Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}). */ public SDVariable one(String name, long... shape) { @@ -3255,31 +2619,31 @@ public class SameDiff extends SDBaseOps { /** * Create a new variable with the specified shape, with all values initialized to 1.0. - * Creates a VARIABLE type SDVariable. + * Creates a constant - i.e., CONSTANT type SDVariable. * * @param name the name of the variable to create * @param shape the shape of the array to be created * @return the created variable */ public SDVariable one(String name, org.nd4j.linalg.api.buffer.DataType dataType, int... shape) { - return var(name, new ConstantInitScheme('f', 1.0), dataType, ArrayUtil.toLongArray(shape)); + return one(name, dataType, ArrayUtil.toLongArray(shape)); } /** * Create a new variable with the specified shape, with all values initialized to 1.0. - * Creates a VARIABLE type SDVariable. + * Creates a constant - i.e., CONSTANT type SDVariable. * * @param name the name of the variable to create * @param shape the shape of the array to be created * @return the created variable */ public SDVariable one(String name, org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { - return var(name, new ConstantInitScheme('f', 1.0), dataType, shape); + return constant(name, Nd4j.ones(dataType, shape)); } /** * See {@link #zero(String, DataType, long...)}. - * Creates a VARIABLE type SDVariable. + * Creates a constant - i.e., CONSTANT type SDVariable. * Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}). */ public SDVariable zero(String name, long... shape) { @@ -3288,7 +2652,7 @@ public class SameDiff extends SDBaseOps { /** * See {@link #zero(String, DataType, int...)}. - * Creates a VARIABLE type SDVariable. + * Creates a constant - i.e., CONSTANT type SDVariable. * Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}). */ public SDVariable zero(String name, int... shape) { @@ -3297,26 +2661,26 @@ public class SameDiff extends SDBaseOps { /** * Create a new variable with the specified shape, with all values initialized to 0. - * Creates a VARIABLE type SDVariable. + * Creates a constant - i.e., CONSTANT type SDVariable. * * @param name the name of the variable to create * @param shape the shape of the array to be created * @return the created variable */ public SDVariable zero(String name, org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { - return var(name, new ZeroInitScheme(), dataType, shape); + return constant(name, Nd4j.zeros(dataType, shape)); } /** * Create a new variable with the specified shape, with all values initialized to 0. - * Creates a VARIABLE type SDVariable. + * Creates a constant - i.e., CONSTANT type SDVariable. * * @param name the name of the variable to create * @param shape the shape of the array to be created * @return the created variable */ public SDVariable zero(String name, org.nd4j.linalg.api.buffer.DataType dataType, int... shape) { - return var(name, new ZeroInitScheme(), dataType, ArrayUtil.toLongArray(shape)); + return zero(name, dataType, ArrayUtil.toLongArray(shape)); } /** @@ -3348,39 +2712,13 @@ public class SameDiff extends SDBaseOps { } } - SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType(), null); - name = v.getVarName(); + SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType()); + name = v.name(); variables.put(name, Variable.builder().name(name).variable(v).build()); constantArrays.put(name, new DeviceLocalNDArray(constant, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads return v; } - /** - * Return a variable of given shape in which all values have a given constant value. - * - * @param value constant to set for each value - * @param shape shape of the variable as long array - * @return A new SDVariable of provided shape with constant value. - */ - @Deprecated - public SDVariable constant(SDVariable value, long... shape) { - return constant(null, value, shape); - } - - /** - * Return a variable of given shape in which all values have a given constant value. - * - * @param name Name of the new SDVariable - * @param value constant to set for each value - * @param shape shape of the variable as long array - * @return A new SDVariable of provided shape with constant value. - */ - @Deprecated - public SDVariable constant(String name, SDVariable value, long... shape) { - SDVariable ret = f().constant(value, shape); - return updateVariableNameAndReference(ret, name); - } - /** * Create a a placeholder variable. Placeholders are variables that expect an array to be provided during training * and inference.
@@ -3394,7 +2732,7 @@ public class SameDiff extends SDBaseOps { */ public SDVariable placeHolder(@NonNull String name, org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { Preconditions.checkState(!variables.containsKey(name), "Variable already exists with name %s", name); - SDVariable ret = new SDVariable(name, VariableType.PLACEHOLDER, this, shape, dataType, null); + SDVariable ret = new SDVariable(name, VariableType.PLACEHOLDER, this, shape, dataType); variables.put(name, Variable.builder().name(name).variable(ret).build()); return ret; } @@ -3412,8 +2750,6 @@ public class SameDiff extends SDBaseOps { return var(name, VariableType.VARIABLE, weightInitScheme, dataType, shape); } - //TODO only allowing null datatype for TF import (it's fixed in a later step) - don't want this in the public API! - /** * Variable initialization with a specified {@link WeightInitScheme} * This method creates VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details. @@ -3447,14 +2783,19 @@ public class SameDiff extends SDBaseOps { } } + Preconditions.checkState(variableType != VariableType.VARIABLE || weightInitScheme != null, "A weight initalization scheme must be provided" + + " when creating a VARIABLE type SDVariables - variable name: \"%s\"", name); - SDVariable ret = new SDVariable(name, variableType, this, shape, dataType, weightInitScheme); + SDVariable ret = new SDVariable(name, variableType, this, shape, dataType); addVariable(ret); - if (variableType == VariableType.PLACEHOLDER) { - setOriginalPlaceHolderShape(name, shape); - putShapeForVarName(name, shape); + if(variableType == VariableType.VARIABLE){ + try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { + INDArray vArr = weightInitScheme.create(dataType, shape); + variablesArrays.put(name, new DeviceLocalNDArray(vArr, true)); + } } + return ret; } @@ -3570,25 +2911,29 @@ public class SameDiff extends SDBaseOps { * @return */ public SDVariable var(@NonNull final SDVariable v) { - if (variables.containsKey(v.getVarName()) && variables.get(v.getVarName()).getVariable().getArr() != null) - return variables.get(v.getVarName()).getVariable(); + if (variables.containsKey(v.name()) && variables.get(v.name()).getVariable().getArr() != null) + return variables.get(v.name()).getVariable(); - if (v.getVarName() == null || v.getVarName().length() < 1) + if (v.name() == null || v.name().length() < 1) throw new IllegalArgumentException("Name for variable must be defined"); VariableType vt = v.getVariableType(); NDArraySupplierInitScheme s = null; switch (vt) { case VARIABLE: - s = new NDArraySupplierInitScheme(v.getArr()); - //Intentional fallthrough + SDVariable r = new SDVariable(v.name(), v.getVariableType(), this, v.getShape(), v.dataType()); + addVariable(r); + try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){ + variablesArrays.put(v.name(), new DeviceLocalNDArray(v.getArr().dup(), true)); + } + return r; case ARRAY: - SDVariable ret = new SDVariable(v.getVarName(), v.getVariableType(), this, v.getShape(), v.dataType(), s); + SDVariable ret = new SDVariable(v.name(), v.getVariableType(), this, v.getShape(), v.dataType()); return addVariable(ret); case CONSTANT: - return constant(v.getVarName(), v.getArr()); + return constant(v.name(), v.getArr()); case PLACEHOLDER: - return placeHolder(v.getVarName(), v.dataType(), v.placeholderShape()); + return placeHolder(v.name(), v.dataType(), v.placeholderShape()); default: throw new RuntimeException("Unknown/not supported variable type: " + vt); } @@ -3683,12 +3028,10 @@ public class SameDiff extends SDBaseOps { } } - SDVariable ret = new SDVariable(name, VariableType.VARIABLE, this, arr.shape(), arr.dataType(), new NDArraySupplierInitScheme(arr)); + SDVariable ret = new SDVariable(name, VariableType.VARIABLE, this, arr.shape(), arr.dataType()); associateArrayWithVariable(arr, ret); addVariable(ret); - if (getShapeForVarName(name) == null) - putShapeForVarName(name, arr.shape()); return ret; } @@ -3738,7 +3081,7 @@ public class SameDiff extends SDBaseOps { sameDiffFunctionInstances.remove(GRAD_FN_KEY); for (SDVariable variable : variables) { - String n = variable.getVarName(); + String n = variable.name(); INDArray arr = variable.getArr(); Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable); @@ -3757,7 +3100,7 @@ public class SameDiff extends SDBaseOps { if (trainingConfig != null && initializedTraining) { //Remove updater state for now constant variables for (SDVariable v : variables) { - GradientUpdater gu = updaterMap.remove(v.getVarName()); + GradientUpdater gu = updaterMap.remove(v.name()); Map m = gu == null ? null : gu.getState(); if (m != null) { for (INDArray arr : m.values()) { @@ -3767,27 +3110,27 @@ public class SameDiff extends SDBaseOps { } //Also check dataset feature/label mapping - remove any placeholders here... - if (trainingConfig.getDataSetFeatureMapping() != null && trainingConfig.getDataSetFeatureMapping().contains(v.getVarName())) { + if (trainingConfig.getDataSetFeatureMapping() != null && trainingConfig.getDataSetFeatureMapping().contains(v.name())) { List newFM = new ArrayList<>(trainingConfig.getDataSetFeatureMapping()); //New list in case of immutable list - newFM.remove(v.getVarName()); + newFM.remove(v.name()); trainingConfig.setDataSetFeatureMapping(newFM); } - if (trainingConfig.getDataSetLabelMapping() != null && trainingConfig.getDataSetLabelMapping().contains(v.getVarName())) { + if (trainingConfig.getDataSetLabelMapping() != null && trainingConfig.getDataSetLabelMapping().contains(v.name())) { List newLM = new ArrayList<>(trainingConfig.getDataSetLabelMapping()); - newLM.remove(v.getVarName()); + newLM.remove(v.name()); trainingConfig.setDataSetLabelMapping(newLM); } - if (trainingConfig.getDataSetFeatureMaskMapping() != null && trainingConfig.getDataSetFeatureMaskMapping().contains(v.getVarName())) { + if (trainingConfig.getDataSetFeatureMaskMapping() != null && trainingConfig.getDataSetFeatureMaskMapping().contains(v.name())) { List newFMM = new ArrayList<>(trainingConfig.getDataSetFeatureMaskMapping()); - newFMM.remove(v.getVarName()); + newFMM.remove(v.name()); trainingConfig.setDataSetFeatureMaskMapping(newFMM); } - if (trainingConfig.getDataSetLabelMaskMapping() != null && trainingConfig.getDataSetLabelMaskMapping().contains(v.getVarName())) { + if (trainingConfig.getDataSetLabelMaskMapping() != null && trainingConfig.getDataSetLabelMaskMapping().contains(v.name())) { List newLMM = new ArrayList<>(trainingConfig.getDataSetLabelMaskMapping()); - newLMM.remove(v.getVarName()); + newLMM.remove(v.name()); trainingConfig.setDataSetLabelMaskMapping(newLMM); } } @@ -3804,7 +3147,7 @@ public class SameDiff extends SDBaseOps { */ public SDVariable convertToVariable(@NonNull SDVariable constant) { Preconditions.checkState(constant.dataType().isFPType(), "Only floating point SDVariables can be converted to variables," + - " datatype of %s is %s", constant.getVarName(), constant.dataType()); + " datatype of %s is %s", constant.name(), constant.dataType()); convertToVariables(Collections.singletonList(constant)); return constant; } @@ -3836,7 +3179,7 @@ public class SameDiff extends SDBaseOps { sameDiffFunctionInstances.remove(GRAD_FN_KEY); for (SDVariable variable : constants) { - String n = variable.getVarName(); + String n = variable.name(); INDArray arr = variable.getArr(); Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable); @@ -3856,17 +3199,18 @@ public class SameDiff extends SDBaseOps { if (trainingConfig != null && initializedTraining) { //Add updater state for this variable: updaterState, updaterViews, updaterMap for (SDVariable v : constants) { - if (!updaterMap.containsKey(v.getVarName())) { + if (!updaterMap.containsKey(v.name())) { //Create new updater state INDArray arr = v.getArr(); long thisSize = trainingConfig.getUpdater().stateSize(arr.length()); if (thisSize > 0) { INDArray stateArr = Nd4j.create(arr.dataType(), 1, thisSize); - GradientUpdater u = trainingConfig.getUpdater().instantiate(stateArr, true); - updaterMap.put(v.getVarName(), u); + GradientUpdater u = trainingConfig.getUpdater().instantiate(stateArr, false); + u.setStateViewArray(stateArr, arr.shape(), arr.ordering(), true); //TODO eventually this should be 1 call... + updaterMap.put(v.name(), u); } else { GradientUpdater u = trainingConfig.getUpdater().instantiate((INDArray) null, true); - updaterMap.put(v.getVarName(), u); + updaterMap.put(v.name(), u); } } } @@ -3946,7 +3290,53 @@ public class SameDiff extends SDBaseOps { sessions.clear(); //Recalculate datatypes of outputs, and dynamically update them - calculateOutputDataTypes(true); + Set allSeenOps = new HashSet<>(); + Queue queueOps = new LinkedList<>(); + + for(String s : dataTypeMap.keySet()){ + Variable v = variables.get(s); + v.getVariable().setDataType(dataTypeMap.get(s)); + List inToOp = v.getInputsForOp(); + if(inToOp != null){ + for(String op : inToOp) { + if (!allSeenOps.contains(op)) { + allSeenOps.add(op); + queueOps.add(op); + } + } + } + } + + while(!queueOps.isEmpty()){ + String op = queueOps.remove(); + SameDiffOp o = ops.get(op); + List inVars = o.getInputsToOp(); + List inDTypes = new ArrayList<>(); + if(inVars != null) { + for (String s : inVars) { + SDVariable v = variables.get(s).getVariable(); + inDTypes.add(v.dataType()); + } + } + List outDtypes = o.getOp().calculateOutputDataTypes(inDTypes); + List outVars = o.getOutputsOfOp(); + for( int i=0; i e : placeholdersPerThread.values()){ + //Not really thread safe - but renaming variables during execution in other threads can never be thread safe :) + if(e != null && e.containsKey(from)){ + INDArray arr = e.remove(from); + e.put(to, arr); + } + } + } + if (trainingConfig != null) { if (trainingConfig.getDataSetFeatureMapping() != null && trainingConfig.getDataSetFeatureMapping().contains(from)) { List l = new ArrayList<>(trainingConfig.getDataSetFeatureMapping()); @@ -4079,7 +3489,7 @@ public class SameDiff extends SDBaseOps { val args = function.args(); for (int i = 0; i < args.length; i++) { - if (args[i].getVarName().equals(varName)) { + if (args[i].name().equals(varName)) { /** * Since we are removing the variable reference * from the arguments we need to update both @@ -4097,6 +3507,8 @@ public class SameDiff extends SDBaseOps { break; } } + + variables.get(varName).getInputsForOp().remove(function.getOwnName()); } /** @@ -4178,15 +3590,6 @@ public class SameDiff extends SDBaseOps { variables.get(variableName).setGradient(variable); } - - /** - * @param varName - * @param forwardVariable - */ - public void setForwardVariableForVarName(String varName, SDVariable forwardVariable) { - forwardVarForGrad.put(varName, forwardVariable); - } - /** * Get the gradient for the variable with the specified variable name. * Note that in order to run this function, {@link #execBackwards(Map, Operation, MultiDataSet, Collection, List)} must be executed first. @@ -4197,12 +3600,12 @@ public class SameDiff extends SDBaseOps { */ public SDVariable grad(String varName) { if (!sameDiffFunctionInstances.containsKey(GRAD_FN_KEY)) { - throw new IllegalStateException("Unable to obtain gradient. Please run execBackwards() first."); + createGradFunction(); } SameDiff grad = getFunction(GRAD_FN_KEY); SDVariable var = grad.getVariable(varName); - return getFunction(GRAD_FN_KEY).getGradForVariable(var.getVarName()); + return getFunction(GRAD_FN_KEY).getGradForVariable(var.name()); } @@ -4384,12 +3787,12 @@ public class SameDiff extends SDBaseOps { public SDVariable addVariable(SDVariable variable) { Preconditions.checkState(variable.getSameDiff() == this, "Samediff instance must be the same."); - if (variables.containsKey(variable.getVarName()) && !variables.get(variable.getVarName()).getVariable().equals(variable)) { - throw new IllegalArgumentException("Variable with name \"" + variable.getVarName() + "\" already exists"); + if (variables.containsKey(variable.name()) && !variables.get(variable.name()).getVariable().equals(variable)) { + throw new IllegalArgumentException("Variable with name \"" + variable.name() + "\" already exists"); } Preconditions.checkState(variable.getSameDiff() == this, "Same diff instance for variable must be the same!"); - variables.put(variable.getVarName(), Variable.builder().name(variable.getVarName()).variable(variable).build()); + variables.put(variable.name(), Variable.builder().name(variable.name()).variable(variable).build()); return variable; } @@ -4402,11 +3805,6 @@ public class SameDiff extends SDBaseOps { * @return the set of names generated for each output of the function. */ public SDVariable[] generateOutputVariableForOp(DifferentialFunction function, String baseName, boolean isImport) { - //xyz ops only have 1 output - //if there is already a base name defined, use that - if (baseName == null || baseName.isEmpty() && getBaseNameForFunction(function) != null) - baseName = getBaseNameForFunction(function); - if (baseName == null) baseName = function.getOwnName(); @@ -4476,11 +3874,7 @@ public class SameDiff extends SDBaseOps { else if (function instanceof BaseOp) { SDVariable[] ret = new SDVariable[1]; SDVariable checkGet = getVariable(baseName); - char ordering = 'c'; SDVariable[] args = function.args(); - if (args != null && args.length > 0 && function.args()[0].getArr() != null) { //Args may be null or length 0 for some ops, like eye - ordering = function.args()[0].getArr().ordering(); - } if (checkGet == null) { //Note: output of an op is ARRAY type - activations, not a trainable parameter. Thus has no weight init scheme org.nd4j.linalg.api.buffer.DataType dataType = outputDataTypes.get(0); @@ -4530,45 +3924,6 @@ public class SameDiff extends SDBaseOps { return sameDiffFunctionInstances.get(functionName); } - - /** - * @deprecated Use {@link SDBaseOps#whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)} - */ - @Deprecated - public While whileStatement(SameDiffConditional sameDiffConditional, - SameDiffFunctionDefinition conditionBody, - SameDiffFunctionDefinition loopBody - , SDVariable[] inputVars) { - return While.builder() - .inputVars(inputVars) - .condition(conditionBody) - .predicate(sameDiffConditional) - .trueBody(loopBody) - .parent(this) - .blockName("while-" + UUID.randomUUID().toString()) - .build(); - } - - /** - * @deprecated Use {@link SDBaseOps#ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)} - */ - @Deprecated - public If ifStatement(SameDiffConditional conditional, - SameDiffFunctionDefinition conditionBody, - SameDiffFunctionDefinition trueBody, - SameDiffFunctionDefinition falseBody - , SDVariable[] inputVars) { - return If.builder() - .conditionBody(conditionBody) - .falseBody(falseBody) - .trueBody(trueBody) - .predicate(conditional) - .inputVars(inputVars) - .parent(this) - .blockName("if-" + UUID.randomUUID().toString()) - .build(); - } - /** * Create a new TensorArray. */ @@ -4636,138 +3991,83 @@ public class SameDiff extends SDBaseOps { sameDiffFunctionInstances.put(function, sub); } - - } - - @Deprecated - public INDArray execAndEndResult() { - List outputs = outputs(); - Preconditions.checkState(outputs.size() == 1, "Method can only be used with SameDiff instances with a single output"); - long tid = Thread.currentThread().getId(); - Map placeholders = placeholdersPerThread.get(tid); - return execSingle(placeholders, outputs.get(0)); } /** - * Create (if required) and then calculate the variable gradients (backward pass) for this graph.
- * After execution, the gradient arrays can be accessed using {@code myVariable.getGradient().getArr()}
- * Note: This method by default calculates VARIABLE type SDVariable gradients only (as well as any other - * gradients needed to calculate the variable gradients). That is, placeholder, constant, etc gradients are not - * calculated. If these gradients are required, they can be calculated using {@link #execBackwards(Map, List, Operation, MultiDataSet, Collection, List)} instead, - * which allows specifying the set of SDVariables to calculate the gradients for. For example, - * {@code execBackwards(placeholders, Arrays.asList(myPlaceholder.gradient().getVarName())}. In some cases, - * {@link #createGradFunction()} may need to be called first + * See {@link #calculateGradients(Map, Collection)} + */ + public Map calculateGradients(Map placeholderVals, @NonNull String... variables) { + Preconditions.checkArgument(variables.length > 0, "No variables were specified"); + return calculateGradients(placeholderVals, Arrays.asList(variables)); + } + + /** + * Calculate and return the gradients for the specified variables * - * @param placeholders Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map + * @param placeholderVals Placeholders. May be null + * @param variables Names of the variables that you want the gradient arrays for + * @return Gradients as a map, keyed by the variable name */ - public void execBackwards(Map placeholders, Operation op) { - execBackwards(placeholders, op, null, Collections.emptyList(), Collections.emptyList()); + public Map calculateGradients(Map placeholderVals, @NonNull Collection variables) { + Preconditions.checkArgument(!variables.isEmpty(), "No variables were specified"); + OutAndGrad oag = calculateGradientsAndOutputs(placeholderVals, null, variables); + return oag.getGradients(); } /** - * See {@link #execBackwards(Map, Operation)}. - *

- * Uses {@link Operation#INFERENCE}. + * Calculate the activations and the gradients for the specified variables, in one execution call. + * This is equivalent to calling {@link #output(Map, List)} and {@link #calculateGradients(Map, Collection)}, but + * is more efficient than calling both separately. + * + * @param placeholderVals Placeholders. May be null + * @param outputVars Names of the variables that you want the activations/outputs for. May be null + * @param gradientVars Names of the variables that you want the gradient arrays for. May be null + * @return Activations and gradients, keyed by variable name */ - public void execBackwards(Map placeholders) { - execBackwards(placeholders, Operation.INFERENCE); - } - - protected void execBackwards(Map placeholders, Operation op, MultiDataSet batch, Collection requiredActivations, List activeListeners) { + public OutAndGrad calculateGradientsAndOutputs(Map placeholderVals, Collection outputVars, Collection gradientVars){ + Preconditions.checkArgument((outputVars != null && !outputVars.isEmpty()) || (gradientVars != null && !gradientVars.isEmpty()), + "No variables were specified for either output or gradients"); if (getFunction(GRAD_FN_KEY) == null) { createGradFunction(); } - //Collect (unique) list of gradient names... - Set varGradNames = new HashSet<>(); - for (Variable v : variables.values()) { - if (v.getVariable().getVariableType() == VariableType.VARIABLE) { - SDVariable g = v.getVariable().gradient(); - if (g != null) { - //Not all variables can have gradients... for example: suppose graph has 2 independent loss functions, - // optimizing only 1 might not require changing all variables - varGradNames.add(g.getVarName()); + List varNames = new ArrayList<>(); + if(outputVars != null){ + varNames.addAll(outputVars); + } + if(gradientVars != null) { + for (String s : gradientVars) { + Preconditions.checkState(this.variables.containsKey(s), "No variable with name \"%s\" exists in the SameDiff instance", s); + SDVariable v = getVariable(s).getGradient(); + if (v != null) { + //In a few cases (like loss not depending on trainable parameters) we won't have gradient array for parameter variable + varNames.add(v.name()); } } } - //Also add loss values - we need these so we can report them to listeners or loss curves... - if (!activeListeners.isEmpty() || op == Operation.TRAINING) { - varGradNames.addAll(lossVariables); + //Key is gradient variable name + SameDiff gradFn = getFunction(GRAD_FN_KEY); + gradFn.setListeners(listeners); + Map grads = gradFn.batchOutputHelper(placeholderVals, null, Operation.TRAINING, varNames.toArray(new String[0])); + + Map outOutputs = outputVars == null ? null : new HashMap(); + Map outGrads = gradientVars == null ? null : new HashMap(); + if(outputVars != null){ + for(String s : outputVars){ + outOutputs.put(s, grads.get(s)); + } + } + if(gradientVars != null) { + for (String s : gradientVars) { + if (getVariable(s).getGradient() != null) { + String gradVar = getVariable(s).getGradient().name(); + outGrads.put(s, grads.get(gradVar)); + } + } } - //Edge case: if no variables, no variable gradients to calculate... - if (varGradNames.isEmpty()) { - log.warn("Skipping gradient execution (backward pass) - no variables to be calculated (graph does not contain any VARIABLE type SDVariables).\n" + - "If gradients for other variables (such as placeholders) are required, use execBackwards(Map, List) instead"); - } - - List vargradNamesList = new ArrayList<>(varGradNames); - execBackwards(placeholders, vargradNamesList, op, batch, requiredActivations, activeListeners); - } - - /** - * See {@link #execBackwards(Map, List, Operation)} - */ - public Map execBackwards(Map placeholders, Operation op, String... variableGradNamesList) { - return execBackwards(placeholders, Arrays.asList(variableGradNamesList), op, null, Collections.emptyList(), Collections.emptyList()); - } - - /** - * See {@link #execBackwards(Map, Operation, String...)}. - *

- * Uses {@link Operation#INFERENCE}. - */ - public Map execBackwards(Map placeholders, String... variableGradNamesList) { - return execBackwards(placeholders, Operation.INFERENCE, variableGradNamesList); - } - - /** - * As per {@link #execBackwards(Map, Operation, MultiDataSet, Collection, List)}, but the set of gradients to calculate can be specified manually.
- * For example, to calculate the gradient for placeholder variable "myPlaceholder", use - * {@code execBackwards(placeholders, Arrays.asList(myPlaceholder.gradient().getVarName())}. - * - * @param placeholders Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map - * @param variableGradNamesList Names of the gradient variables to calculate - */ - public Map execBackwards(Map placeholders, List variableGradNamesList, Operation operation) { - return execBackwards(placeholders, variableGradNamesList, operation, null, Collections.emptyList(), Collections.emptyList()); - } - - /** - * See {@link #execBackwards(Map, List, Operation)}. - *

- * Uses {@link Operation#INFERENCE}. - */ - public Map execBackwards(Map placeholders, List variableGradNamesList) { - return execBackwards(placeholders, variableGradNamesList, Operation.INFERENCE); - } - - protected Map execBackwards(Map placeholders, List variableGradNamesList, Operation operation, - MultiDataSet batch, Collection requiredActivations, List activeListeners) { - if (getFunction(GRAD_FN_KEY) == null) { - createGradFunction(); - } - - log.trace("About to execute backward function"); - - //Edge case: if no variables, no variable gradients to calculate... - if (variableGradNamesList.isEmpty()) { - log.warn("Skipping gradient calculation (backward pass) - no variables to be calculated (variableGradNamesList is empty)"); - return Collections.emptyMap(); - } - - SameDiff sd = sameDiffFunctionInstances.get(GRAD_FN_KEY); - sd.listeners.clear(); - sd.listeners.addAll(activeListeners); - - At at = new At(0, 0, 0, Thread.currentThread().getId(), operation); - if (trainingConfig != null) { - at.setIteration(trainingConfig.getIterationCount()); - at.setEpoch(trainingConfig.getEpochCount()); - } - - return sd.directExecHelper(placeholders, at, batch, requiredActivations, activeListeners, variableGradNamesList.toArray(new String[0])); + return new OutAndGrad(outOutputs, outGrads); } /** @@ -4781,7 +4081,7 @@ public class SameDiff extends SDBaseOps { } /** - * Create the gradient function (for calculating gradients via {@link #execBackwards(Map, Operation, String[])}) if it is not already defined. + * Create the gradient function (for calculating gradients via {@link #calculateGradients(Map, Collection)}) if it is not already defined. * Users do not usually need to call this function manually, as it is called as required in the aforementioned method. *

* If the gradient function already exists, this method is a no-op.
@@ -4808,14 +4108,23 @@ public class SameDiff extends SDBaseOps { if (trainingConfig != null && trainingConfig.getLossVariables() != null && !trainingConfig.getLossVariables().isEmpty()) { lossVariables.addAll(trainingConfig.getLossVariables()); } else { - List outputs = outputs(); - if (outputs.size() == 1) { - String outName = outputs.get(0); + List lossInferred = bestGuessLossVariables(); + if (lossInferred.size() == 1) { + String outName = lossInferred.get(0); String opName = variables.get(outName).getOutputOfOp(); if (opName == null || !(ops.get(opName).getOp() instanceof ExternalErrorsFunction)) { - log.info("Inferring output \"{}\" as loss variable as none were previously set. Use SameDiff.setLossVariables() to override", outputs.get(0)); + log.info("Inferring output \"{}\" as loss variable as none were previously set." + + "Use SameDiff.setLossVariables() or SDVariable.markAsLoss() to override", lossInferred.get(0)); + } + lossVariables.add(lossInferred.get(0)); + } else if(lossInferred.isEmpty()){ + //Check for external errors function + for(SameDiffOp o : ops.values()){ + if(o.getOp() instanceof ExternalErrorsFunction){ + List l = o.getOutputsOfOp(); + lossVariables.add(l.get(0)); + } } - lossVariables.add(outputs.get(0)); } } } @@ -4917,9 +4226,9 @@ public class SameDiff extends SDBaseOps { "point variable (datatype: %s). Only floating point variables may be used as loss function variable", s, v.dataType()); v = v.sum(); //If output is not a scalar: we'll use loss = v.sum(), same as adding loss for multiple outputs. We don't always know for sure if output is scalar at this point if (v.dataType() == initialGrad.dataType()) { - sameDiff.setGradientForVariableName(v.getVarName(), initialGrad); + sameDiff.setGradientForVariableName(v.name(), initialGrad); } else { - sameDiff.setGradientForVariableName(v.getVarName(), initialGrad.castTo(v.dataType())); + sameDiff.setGradientForVariableName(v.name(), initialGrad.castTo(v.dataType())); } if (finalOutputs.contains(v)) { log.warn("Loss function variable \"{}\" appears multiple times in list of loss variables - using only first instance", s); @@ -5047,7 +4356,7 @@ public class SameDiff extends SDBaseOps { //At this point: we know the set of variables that are connected to the loss - these all (and only) need gradients Queue availableForDiff = new LinkedList<>(); for (SDVariable lossVar : finalOutputs) { - Variable v = sameDiff.variables.get(lossVar.getVarName()); + Variable v = sameDiff.variables.get(lossVar.name()); if (v.getOutputOfOp() != null) { String opName = v.getOutputOfOp(); availableForDiff.add(opName); @@ -5229,52 +4538,39 @@ public class SameDiff extends SDBaseOps { associateSameDiffWithOpsAndVariables(); } - /** - * Set the original shape for a given place holder.
- * This is used to track original shapes of place holder variables.
- * The reason we track original shapes is to validate possible candidate arrays coming in (especially with -1 - * as the expected shapes). - *

- * Note that if {@link #isPlaceHolder(String)} - * returns false for the passed in vertex id, - * a {@link ND4JIllegalStateException} is thrown. - *

- * - * @param variableName the vertex id for the original shape - * @param shape the shape of the place holder + * Try to infer the loss variable/s (usually loss variables). Note that this is not reliable in general. */ - public void setOriginalPlaceHolderShape(String variableName, @NonNull long... shape) { - if (!isPlaceHolder(variableName)) { - throw new ND4JIllegalStateException("Vertex id " + variableName + " does not appear to be a place holder. Did you forget to call addPlaceHolder?"); + protected List bestGuessLossVariables() { + List out = new ArrayList<>(); + for (Variable v : variables.values()) { + if (v.getVariable().isConstant() || v.getVariable().isPlaceHolder() || //Exclude constants and placeholders + (v.getInputsForOp() != null && !v.getInputsForOp().isEmpty()) || //Exclude variables that are inputs to ops + (v.getControlDepsForOp() != null && !v.getControlDepsForOp().isEmpty()) || //Exclude variables that are control dependency inputs to ops + (v.getControlDepsForVar() != null && !v.getControlDepsForVar().isEmpty())) { //Exclude variables that are control dependency inputs to other variables (mainly for import of cond etc ops) + continue; + } + + //Also exclude assert etc ops - doesn't make sense to return these "outputs" to user + if (v.getOutputOfOp() != null) { + String opName = v.getOutputOfOp(); + SameDiffOp o = ops.get(opName); + if (o.getOp() instanceof Assert) { + continue; + } + + //A bit of a hack for TF import: some TF graphs have Switch ops, where the output of one branch isn't consumed + // by any ops. Consequently, during execution this "output" might never be available. So we'll exclude the output of execution here + // This applies to SameDiff while loops as well + if (o.getOp() instanceof Switch) { + continue; + } + } + + + out.add(v.getName()); } - - if (shape == null) { - throw new ND4JIllegalStateException("Null and 0 length shape arrays not allowed"); - } - - - if (placeHolderOriginalShapes.containsKey(variableName) && !Arrays.equals(placeHolderOriginalShapes.get(variableName), shape)) { - throw new ND4JIllegalStateException("Unable to add a new shape for vertex id " + variableName); - } - - //after validation now only set once - placeHolderOriginalShapes.put(variableName, shape); - - } - - - /** - * Get the original shape for the vertex id if one was set (other wise returns null).
- * This is mainly for use in validating passed in arrays as arguments to {@link #resolveVariablesWith(Map)} - * usually when executing using {@link #execAll(Map)} - * - * @param varName the vertex id to get the original shape for. - * @return the set vertex - */ - @Deprecated - public long[] getOriginalShapeForPlaceHolder(String varName) { - return placeHolderOriginalShapes.get(varName); + return out; } /** @@ -5289,53 +4585,6 @@ public class SameDiff extends SDBaseOps { return variables.get(varName).getVariable().isPlaceHolder(); } - - /** - * Resolve all ndarrays by updating the variables for each array specified in the given map. - * An {@link IllegalStateException} will be thrown if not all arrays are specified for resolution. - * - * @param arrays the arrays to resolve. - */ - public void resolveVariablesWith(Map arrays) { - for (Map.Entry e : arrays.entrySet()) { - SDVariable varForName = getVariable(e.getKey()); - if (varForName == null) { - throw new ND4JIllegalStateException("A placeholder array was provided for variable with name \"" + e.getKey() + - "\" but no variable with this name exists"); - } - - Variable v = variables.get(e.getKey()); - if (varForName.getVariableType() == VariableType.PLACEHOLDER) { - //Check shape: - long[] shape = varForName.placeholderShape(); - long[] newShape = e.getValue().shape(); - Preconditions.checkState(shape.length == newShape.length, "Placeholder shape not compatible (mismatched rank): placeholder \"%s\" " + - "shape %s, got incompatible shape %s", e.getKey(), shape, newShape); - } - } - - - for (val entry : arrays.entrySet()) { - if (!variables.get(entry.getKey()).getVariable().isPlaceHolder()) { - throw new ND4JIllegalStateException("Illegal variable " + entry.getKey() + " passed in. Variable found not to be a place holder variable"); - } - - val specifiedShape = getOriginalShapeForPlaceHolder(entry.getKey()); - //whole shape was specified: validate whether the input array shape is equal - if (!Shape.isPlaceholderShape(specifiedShape)) { - if (!Shape.shapeEquals(specifiedShape, entry.getValue().shape())) { - throw new ND4JIllegalStateException("Place holder shape specified was " + Arrays.toString(specifiedShape) + " but array shape was " + Arrays.toString(entry.getValue().shape())); - } - } - - associateArrayWithVariable(entry.getValue(), getVariable(entry.getKey())); - setArrayForVariable(entry.getKey(), entry.getValue()); - } - - //declare resolved - resolvedVariables = true; - } - /** * Updates the variable name property on the passed in variable, the reference in samediff, and returns the variable. *

@@ -5363,20 +4612,20 @@ public class SameDiff extends SDBaseOps { throw new IllegalStateException("Variable name \"" + newVarName + "\" already exists for a different SDVariable"); } - if (newVarName == null && variables.containsKey(varToUpdate.getVarName()) - && variables.get(varToUpdate.getVarName()).getVariable() != varToUpdate) { + if (newVarName == null && variables.containsKey(varToUpdate.name()) + && variables.get(varToUpdate.name()).getVariable() != varToUpdate) { //Edge case: suppose we do m1=sd.mean(in), m2=sd.mean(m1) -> both initially have the name // "mean" and consequently a new variable name needs to be generated - newVarName = generateNewVarName(varToUpdate.getVarName(), 0); + newVarName = generateNewVarName(varToUpdate.name(), 0); } - if (newVarName == null || varToUpdate.getVarName().equals(newVarName)) { + if (newVarName == null || varToUpdate.name().equals(newVarName)) { return varToUpdate; } - val oldVarName = varToUpdate.getVarName(); + val oldVarName = varToUpdate.name(); varToUpdate.setVarName(newVarName); - updateVariableName(oldVarName, newVarName); + renameVariable(oldVarName, newVarName); return varToUpdate; } @@ -5462,7 +4711,7 @@ public class SameDiff extends SDBaseOps { 0, 0, -1, - 0, 0, 0, 0, 0, 0); + 0, 0, 0, 0, 0, 0, 0, 0, 0); return flatNode; } @@ -5538,12 +4787,12 @@ public class SameDiff extends SDBaseOps { val idxForOps = new IdentityHashMap(); List allVars = variables(); for (SDVariable variable : allVars) { - INDArray arr = variable.getArr(); - log.trace("Exporting variable: [{}]", variable.getVarName()); + INDArray arr = variable.getVariableType() == VariableType.ARRAY ? null : variable.getArr(); + log.trace("Exporting variable: [{}]", variable.name()); //If variable is the output of some op - let's use the ONE index for exporting, and properly track the output // numbers. For example, unstack(x) -> y0, y1, y2 -> the y's should be say (3,0), (3,1), (3,2) NOT (4,0), (5,0), (6,0) - String varName = variable.getVarName(); + String varName = variable.name(); int varIdx; int outputNum; if (variables.get(varName).getOutputOfOp() != null) { @@ -5564,11 +4813,11 @@ public class SameDiff extends SDBaseOps { } - reverseMap.put(variable.getVarName(), varIdx); + reverseMap.put(variable.name(), varIdx); - log.trace("Adding [{}] as [{}]", variable.getVarName(), varIdx); + log.trace("Adding [{}] as [{}]", variable.name(), varIdx); int shape = 0; - int name = bufferBuilder.createString(variable.getVarName()); + int name = bufferBuilder.createString(variable.name()); int array = 0; int id = IntPair.createIntPair(bufferBuilder, varIdx, outputNum); byte varType = (byte) variable.getVariableType().ordinal(); @@ -5579,10 +4828,32 @@ public class SameDiff extends SDBaseOps { if (variable.getVariableType() == VariableType.PLACEHOLDER) { val shp = variable.getShape(); - shape = FlatVariable.createShapeVector(bufferBuilder, shp); + if(shp != null) { + //Some models may have no shape defined, not ever a placeholder type shape + shape = FlatVariable.createShapeVector(bufferBuilder, shp); + } } - int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(variable.dataType()), shape, array, -1, varType); + int controlDeps = 0; + int controlDepsForOp = 0; + int controlDepsForVar = 0; + Variable v = variables.get(varName); + + int[] cds = FlatBuffersMapper.mapOrNull(v.getControlDeps(), bufferBuilder); + if(cds != null) + controlDeps = FlatVariable.createControlDepsVector(bufferBuilder, cds); + + int[] cdsForOp = FlatBuffersMapper.mapOrNull(v.getControlDepsForOp(), bufferBuilder); + if(cdsForOp != null) + controlDepsForOp = FlatVariable.createControlDepForOpVector(bufferBuilder, cdsForOp); + + int[] cdsForVar = FlatBuffersMapper.mapOrNull(v.getControlDepsForVar(), bufferBuilder); + if(cdsForVar != null) + controlDepsForVar = FlatVariable.createControlDepsForVarVector(bufferBuilder, cdsForVar); + + + int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(variable.dataType()), shape, + array, -1, varType, controlDeps, controlDepsForOp, controlDepsForVar); flatVariables.add(flatVariable); } @@ -5593,43 +4864,6 @@ public class SameDiff extends SDBaseOps { flatNodes.add(FlatBuffersMapper.asFlatNode(this, func, bufferBuilder, variableList, reverseMap, forwardMap, framesMap, idCounter, fnId)); } - // we're dumping scopes now - for (Map.Entry scope : sameDiffFunctionInstances.entrySet()) { - if (scope.getKey().equalsIgnoreCase(GRAD_FN_KEY)) { - //Skip the gradient function for export - continue; - } - - flatNodes.add(asFlatNode(scope.getKey(), scope.getValue(), bufferBuilder)); - val currVarList = new ArrayList(scope.getValue().variables()); - // converting all ops from node - for (val node : scope.getValue().variables()) { - INDArray arr = node.getArr(); - if (arr == null) { - continue; - } - - int name = bufferBuilder.createString(node.getVarName()); - int array = arr.toFlatArray(bufferBuilder); - int id = IntPair.createIntPair(bufferBuilder, ++idx, 0); - - val pair = parseVariable(node.getVarName()); - reverseMap.put(pair.getFirst(), idx); - - log.trace("Adding [{}] as [{}]", pair.getFirst(), idx); - - byte varType = (byte) node.getVariableType().ordinal(); - int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(arr.dataType()), 0, array, -1, varType); - flatVariables.add(flatVariable); - } - - //add functions - for (SameDiffOp op : scope.getValue().ops.values()) { - DifferentialFunction func = op.getOp(); - flatNodes.add(FlatBuffersMapper.asFlatNode(this, func, bufferBuilder, currVarList, reverseMap, forwardMap, framesMap, idCounter, null)); - } - } - int outputsOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatOffsets)); int variablesOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatVariables)); int nodesOffset = FlatGraph.createNodesVector(bufferBuilder, Ints.toArray(flatNodes)); @@ -5647,7 +4881,7 @@ public class SameDiff extends SDBaseOps { for (SDVariable v : variables()) { if (!v.isPlaceHolder()) continue; - placeholderOffsets[i++] = bufferBuilder.createString(v.getVarName()); + placeholderOffsets[i++] = bufferBuilder.createString(v.name()); } } int placeholdersOffset = FlatGraph.createPlaceholdersVector(bufferBuilder, placeholderOffsets); @@ -5958,7 +5192,7 @@ public class SameDiff extends SDBaseOps { vars.add(fg.variables(i)); } - FlatConfiguration conf = fg.configuration(); +// FlatConfiguration conf = fg.configuration(); /* Reconstruct the graph We'll do the reconstruction manually here, rather than using sd.var(...), so that we have more control @@ -5992,9 +5226,37 @@ public class SameDiff extends SDBaseOps { //TODO Infer this properly! Could be constant, etc. VariableType vt = VariableType.values()[v.variabletype()]; - SDVariable var = new SDVariable(n, vt, sd, shape, dtype, null); + SDVariable var = new SDVariable(n, vt, sd, shape, dtype); sd.variables.put(n, Variable.builder().name(n).variable(var).build()); - sd.variableNameToShape.put(n, shape); + Variable v2 = sd.variables.get(n); + + //Reconstruct control dependencies + if(v.controlDepsLength() > 0){ + int num = v.controlDepsLength(); + List l = new ArrayList<>(num); + for( int i=0; i 0){ + int num = v.controlDepForOpLength(); + List l = new ArrayList<>(num); + for( int i=0; i 0){ + int num = v.controlDepsForVarLength(); + List l = new ArrayList<>(num); + for( int i=0; i 0) { + int l = fn.controlDepsLength(); + List list = new ArrayList<>(l); + for( int i=0; i 0) { + int l = fn.varControlDepsLength(); + List list = new ArrayList<>(l); + for( int i=0; i 0) { + int l = fn.controlDepForLength(); + List list = new ArrayList<>(l); + for( int i=0; i()); } if (!v.getInputsForOp().contains(df.getOwnName())) { - v.getInputsForOp( - - ).add(df.getOwnName()); + v.getInputsForOp().add(df.getOwnName()); } } @@ -6092,7 +5382,7 @@ public class SameDiff extends SDBaseOps { if (varsForOp != null && varsForOp.size() == numOutputs) { varNames = new String[varsForOp.size()]; for (int i = 0; i < varNames.length; i++) { - varNames[i] = varsForOp.get(i).getVarName(); + varNames[i] = varsForOp.get(i).name(); sd.getVariables().get(varNames[i]).setOutputOfOp(df.getOwnName()); } sd.ops.get(df.getOwnName()).setOutputsOfOp(Arrays.asList(varNames)); @@ -6105,7 +5395,7 @@ public class SameDiff extends SDBaseOps { varNames[i] = n; if (!sd.variables.containsKey(n)) { //Need to create the variable - perhaps it wasn't exported. Note output of node -> can only be VARIABLE type - SDVariable var = new SDVariable(n, VariableType.VARIABLE, sd, null, null, null); + SDVariable var = new SDVariable(n, VariableType.VARIABLE, sd, null, null); sd.variables.put(n, Variable.builder().name(n).variable(var).build()); variablesByNodeAndOutNum.put(new Pair<>(opId, i), var); } @@ -6414,32 +5704,6 @@ public class SameDiff extends SDBaseOps { return sb.toString(); } - /** - * Calculate data types for the variables in the graph - */ - public Map calculateOutputDataTypes() { - return calculateOutputDataTypes(false); - } - - /** - * Calculate data types for the variables in the graph - */ - public Map calculateOutputDataTypes(boolean dynamicUpdate) { - List allVars = new ArrayList<>(variables.keySet()); - DataTypesSession session = new DataTypesSession(this, dynamicUpdate); - Map phValues = new HashMap<>(); - for (Variable v : variables.values()) { - if (v.getVariable().isPlaceHolder()) { - org.nd4j.linalg.api.buffer.DataType dt = v.getVariable().dataType(); - Preconditions.checkNotNull(dt, "Placeholder variable %s has null datatype", v.getName()); - phValues.put(v.getName(), dt); - } - } - Map out = session.output(allVars, phValues, null, - Collections.emptyList(), Collections.emptyList(), At.defaultAt(Operation.INFERENCE)); - return out; - } - /** * For internal use only. * Creates a new discinct block name from baseName. @@ -6470,14 +5734,14 @@ public class SameDiff extends SDBaseOps { * @return The imported graph */ public static SameDiff importFrozenTF(File graphFile) { - return TFGraphMapper.getInstance().importGraph(graphFile); + return TFGraphMapper.importGraph(graphFile); } /** * See {@link #importFrozenTF(File)} */ public static SameDiff importFrozenTF(GraphDef graphDef) { - return TFGraphMapper.getInstance().importGraph(graphDef); + return TFGraphMapper.importGraph(graphDef); } @@ -6487,7 +5751,7 @@ public class SameDiff extends SDBaseOps { * Again, the input can be text or binary. */ public static SameDiff importFrozenTF(InputStream graph) { - return TFGraphMapper.getInstance().importGraph(graph); + return TFGraphMapper.importGraph(graph); } @@ -6511,7 +5775,7 @@ public class SameDiff extends SDBaseOps { int start = 1; // if we already have a name like "op_2", start from trying "op_3" - if (base.contains("_")) { + if (base.contains("_") && base.matches(".*_\\d+")) { // extract number used to generate base Matcher num = Pattern.compile("(.*)_(\\d+)").matcher(base); // extract argIndex used to generate base diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java index 353c2d1e1..d50daddb8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java @@ -440,7 +440,7 @@ public class TrainingConfig { * @param evaluations The evaluations to run */ public Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){ - return trainEvaluation(variable.getVarName(), labelIndex, evaluations); + return trainEvaluation(variable.name(), labelIndex, evaluations); } /** @@ -468,7 +468,7 @@ public class TrainingConfig { * @param evaluations The evaluations to run */ public Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){ - return validationEvaluation(variable.getVarName(), labelIndex, evaluations); + return validationEvaluation(variable.name(), labelIndex, evaluations); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/api/OutAndGrad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/api/OutAndGrad.java new file mode 100644 index 000000000..d0bc4b8b6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/api/OutAndGrad.java @@ -0,0 +1,19 @@ +package org.nd4j.autodiff.samediff.api; + +import lombok.AllArgsConstructor; +import lombok.Data; +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.util.Map; + +/** + * A simple object holding two maps - one of output arrays, another of gradient arrays + */ +@AllArgsConstructor +@Data +public class OutAndGrad { + + private final Map outputs; + private final Map gradients; + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/BatchOutputConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/BatchOutputConfig.java index ba5aaa234..ffa0c5f85 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/BatchOutputConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/BatchOutputConfig.java @@ -73,7 +73,7 @@ public class BatchOutputConfig { public BatchOutputConfig output(@NonNull SDVariable... outputs){ String[] outNames = new String[outputs.length]; for(int i = 0 ; i < outputs.length ; i++){ - outNames[i] = outputs[i].getVarName(); + outNames[i] = outputs[i].name(); } return output(outNames); @@ -104,7 +104,7 @@ public class BatchOutputConfig { * See {@link #input(String, INDArray)} */ public BatchOutputConfig input(@NonNull SDVariable variable, @NonNull INDArray placeholder){ - return input(variable.getVarName(), placeholder); + return input(variable.name(), placeholder); } /** @@ -132,19 +132,35 @@ public class BatchOutputConfig { return this; } + /** + * @deprecated Use {@link #output()} + */ + @Deprecated + public Map exec() { + return output(); + } + /** * Do inference and return the results */ - public Map exec(){ + public Map output(){ return sd.output(placeholders, listeners, outputs.toArray(new String[0])); } + /** + * @deprecated Use {@link #outputSingle()} + */ + @Deprecated + public INDArray execSingle() { + return outputSingle(); + } + /** * Do inference and return the results for the single output * * Only works if exactly one output is specified */ - public INDArray execSingle(){ + public INDArray outputSingle(){ Preconditions.checkState(outputs.size() == 1, "Can only use execSingle() when exactly one output is specified, there were %s", outputs.size()); return exec().get(outputs.get(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/EvaluationConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/EvaluationConfig.java index f4477adad..389e14306 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/EvaluationConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/EvaluationConfig.java @@ -81,7 +81,7 @@ public class EvaluationConfig { * See {@link #evaluate(String, int, IEvaluation[])} */ public EvaluationConfig evaluate(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){ - return evaluate(variable.getVarName(), labelIndex, evaluations); + return evaluate(variable.name(), labelIndex, evaluations); } /** @@ -106,7 +106,7 @@ public class EvaluationConfig { * See {@link #evaluate(String, IEvaluation[])} */ public EvaluationConfig evaluate(@NonNull SDVariable variable, @NonNull IEvaluation... evaluations){ - return evaluate(variable.getVarName(), evaluations); + return evaluate(variable.name(), evaluations); } /** @@ -129,7 +129,7 @@ public class EvaluationConfig { * See {@link #labelIndex(String, int)} */ public EvaluationConfig labelIndex(@NonNull SDVariable variable, int labelIndex){ - return labelIndex(variable.getVarName(), labelIndex); + return labelIndex(variable.name(), labelIndex); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java index 376430f54..5cafa09aa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java @@ -75,7 +75,7 @@ public class OutputConfig { public OutputConfig output(@NonNull SDVariable... outputs) { String[] outNames = new String[outputs.length]; for (int i = 0; i < outputs.length; i++) { - outNames[i] = outputs[i].getVarName(); + outNames[i] = outputs[i].name(); } return output(outNames); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java new file mode 100644 index 000000000..776d26794 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java @@ -0,0 +1,444 @@ +package org.nd4j.autodiff.samediff.internal; + +import lombok.Getter; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.function.Predicate; +import org.nd4j.linalg.primitives.Pair; + +import java.util.*; + +/** + * Object dependency tracker. + *
+ * Dependency are denoted by: X -> Y, which means "Y depends on X"
+ * In this implementation:
+ * - Dependencies may be satisfied, or not satisfied
+ * - The implementation tracks when the dependency for an object Y are fully satisfied. This occurs when:
+ * 1. No dependencies X->Y exist
+ * 2. All dependencies of the form X->Y have been marked as satisfied, via markSatisfied(x)
+ * - When a dependency is satisfied, any dependent (Ys) are checked to see if all their dependencies are satisfied
+ * - If a dependent has all dependencies satisfied, it is added to the "new all satisfied" queue for processing, + * which can be accessed via {@link #hasNewAllSatisfied()}, {@link #getNewAllSatisfied()} and {@link #getNewAllSatisfiedList()}
+ *
+ * Note: Two types of dependencies exist
+ * 1. Standard dependencies - i.e., "Y depends on X"
+ * 2. "Or" dependencies - i.e., "Y depends on (A or B)".
+ * For Or dependencies of the form "(A or B) -> Y", Y will be marked as "all dependencies satisfied" if either A or B is marked as satisfied. + * + * @param For a dependency X -> Y, Y has type T + * @param For a dependency X -> Y, X has type D + */ +@Slf4j +public abstract class AbstractDependencyTracker { + @Getter + private final Map> dependencies; //Key: the dependent. Value: all things that the key depends on + @Getter + private final Map>> orDependencies; //Key: the dependent. Value: the set of OR dependencies + private final Map> reverseDependencies = new HashMap<>(); //Key: the dependee. Value: The set of all dependents that depend on this value + private final Map> reverseOrDependencies = new HashMap<>(); + private final Set satisfiedDependencies = new HashSet<>(); //Mark the dependency as satisfied. If not in set: assumed to not be satisfied + + private final Set allSatisfied; //Set of all dependent values (Ys) that have all dependencies satisfied + private final Queue allSatisfiedQueue = new LinkedList<>(); //Queue for *new* "all satisfied" values. Values are removed using the "new all satisfied" methods + + + protected AbstractDependencyTracker() { + dependencies = (Map>) newTMap(); + orDependencies = (Map>>) newTMap(); + allSatisfied = newTSet(); + } + + /** + * @return A new map where the dependents (i.e., Y in "X -> Y") are the key + */ + protected abstract Map newTMap(); + + /** + * @return A new set where the dependents (i.e., Y in "X -> Y") are the key + */ + protected abstract Set newTSet(); + + /** + * @return A String representation of the dependent object + */ + protected abstract String toStringT(T t); + + /** + * @return A String representation of the dependee object + */ + protected abstract String toStringD(D d); + + /** + * Clear all internal state for the dependency tracker + */ + public void clear() { + dependencies.clear(); + orDependencies.clear(); + reverseDependencies.clear(); + reverseOrDependencies.clear(); + satisfiedDependencies.clear(); + allSatisfied.clear(); + allSatisfiedQueue.clear(); + } + + /** + * @return True if no dependencies have been defined + */ + public boolean isEmpty() { + return dependencies.isEmpty() && orDependencies.isEmpty() && + allSatisfiedQueue.isEmpty(); + } + + /** + * @return True if the dependency has been marked as satisfied using {@link #markSatisfied(Object, boolean)} + */ + public boolean isSatisfied(@NonNull D x) { + return satisfiedDependencies.contains(x); + } + + /** + * Mark the specified value as satisfied. + * For example, if two dependencies have been previously added (X -> Y) and (X -> A) then after the markSatisfied(X, true) + * call, both of these dependencies are considered satisfied. + * + * @param x Value to mark + * @param satisfied Whether to mark as satisfied (true) or unsatisfied (false) + */ + public void markSatisfied(@NonNull D x, boolean satisfied) { + if (satisfied) { + boolean alreadySatisfied = satisfiedDependencies.contains(x); + + if (!alreadySatisfied) { + satisfiedDependencies.add(x); + + //Check if any Y's exist that have dependencies that are all satisfied, for X -> Y + Set s = reverseDependencies.get(x); + Set s2 = reverseOrDependencies.get(x); + + Set set; + if (s != null && s2 != null) { + set = newTSet(); + set.addAll(s); + set.addAll(s2); + } else if (s != null) { + set = s; + } else if (s2 != null) { + set = s2; + } else { + if (log.isTraceEnabled()) { + log.trace("No values depend on: {}", toStringD(x)); + } + return; + } + + for (T t : set) { + Set required = dependencies.get(t); + Set> requiredOr = orDependencies.get(t); + boolean allSatisfied = true; + if (required != null) { + for (D d : required) { + if (!isSatisfied(d)) { + allSatisfied = false; + break; + } + } + } + if (allSatisfied && requiredOr != null) { + for (Pair p : requiredOr) { + if (!isSatisfied(p.getFirst()) && !isSatisfied(p.getSecond())) { + allSatisfied = false; + break; + } + } + } + + if (allSatisfied) { + if (!this.allSatisfied.contains(t)) { + this.allSatisfied.add(t); + this.allSatisfiedQueue.add(t); + } + } + } + } + + } else { + satisfiedDependencies.remove(x); + if (!allSatisfied.isEmpty()) { + + Set reverse = reverseDependencies.get(x); + if (reverse != null) { + for (T y : reverse) { + if (allSatisfied.contains(y)) { + allSatisfied.remove(y); + allSatisfiedQueue.remove(y); + } + } + } + Set orReverse = reverseOrDependencies.get(x); + if (orReverse != null) { + for (T y : orReverse) { + if (allSatisfied.contains(y) && !isAllSatisfied(y)) { + allSatisfied.remove(y); + allSatisfiedQueue.remove(y); + } + } + } + } + } + } + + /** + * Check whether any dependencies x -> y exist, for y (i.e., anything previously added by {@link #addDependency(Object, Object)} + * or {@link #addOrDependency(Object, Object, Object)} + * + * @param y Dependent to check + * @return True if Y depends on any values + */ + public boolean hasDependency(@NonNull T y) { + Set s1 = dependencies.get(y); + if (s1 != null && !s1.isEmpty()) + return true; + + Set> s2 = orDependencies.get(y); + return s2 != null && !s2.isEmpty(); + } + + /** + * Get all dependencies x, for x -> y, and (x1 or x2) -> y + * + * @param y Dependent to get dependencies for + * @return List of dependencies + */ + public DependencyList getDependencies(@NonNull T y) { + Set s1 = dependencies.get(y); + Set> s2 = orDependencies.get(y); + + List l1 = (s1 == null ? null : new ArrayList<>(s1)); + List> l2 = (s2 == null ? null : new ArrayList<>(s2)); + + return new DependencyList<>(y, l1, l2); + } + + /** + * Add a dependency: y depends on x, as in x -> y + * + * @param y The dependent + * @param x The dependee that is required for Y + */ + public void addDependency(@NonNull T y, @NonNull D x) { + if (!dependencies.containsKey(y)) + dependencies.put(y, new HashSet()); + + if (!reverseDependencies.containsKey(x)) + reverseDependencies.put(x, newTSet()); + + dependencies.get(y).add(x); + reverseDependencies.get(x).add(y); + + checkAndUpdateIfAllSatisfied(y); + } + + protected void checkAndUpdateIfAllSatisfied(@NonNull T y) { + boolean allSat = isAllSatisfied(y); + if (allSat) { + //Case where "x is satisfied" happened before x->y added + if (!allSatisfied.contains(y)) { + allSatisfied.add(y); + allSatisfiedQueue.add(y); + } + } else if (allSatisfied.contains(y)) { + if (!allSatisfiedQueue.contains(y)) { + StringBuilder sb = new StringBuilder(); + sb.append("Dependent object \"").append(toStringT(y)).append("\" was previously processed after all dependencies") + .append(" were marked satisfied, but is now additional dependencies have been added.\n"); + DependencyList dl = getDependencies(y); + if (dl.getDependencies() != null) { + sb.append("Dependencies:\n"); + for (D d : dl.getDependencies()) { + sb.append(d).append(" - ").append(isSatisfied(d) ? "Satisfied" : "Not satisfied").append("\n"); + } + } + if (dl.getOrDependencies() != null) { + sb.append("Or dependencies:\n"); + for (Pair p : dl.getOrDependencies()) { + sb.append(p).append(" - satisfied=(").append(isSatisfied(p.getFirst())).append(",").append(isSatisfied(p.getSecond())).append(")"); + } + } + throw new IllegalStateException(sb.toString()); + } + + //Not satisfied, but is in the queue -> needs to be removed + allSatisfied.remove(y); + allSatisfiedQueue.remove(y); + } + } + + protected boolean isAllSatisfied(@NonNull T y) { + Set set1 = dependencies.get(y); + + boolean allSatisfied = true; + if (set1 != null) { + for (D d : set1) { + allSatisfied = isSatisfied(d); + if (!allSatisfied) + break; + } + } + if (allSatisfied) { + Set> set2 = orDependencies.get(y); + if (set2 != null) { + for (Pair p : set2) { + allSatisfied = isSatisfied(p.getFirst()) || isSatisfied(p.getSecond()); + if (!allSatisfied) + break; + } + } + } + return allSatisfied; + } + + + /** + * Remove a dependency (x -> y) + * + * @param y The dependent that currently requires X + * @param x The dependee that is no longer required for Y + */ + public void removeDependency(@NonNull T y, @NonNull D x) { + if (!dependencies.containsKey(y) && !orDependencies.containsKey(y)) + return; + + Set s = dependencies.get(y); + if (s != null) { + s.remove(x); + if (s.isEmpty()) + dependencies.remove(y); + } + + Set s2 = reverseDependencies.get(x); + if (s2 != null) { + s2.remove(y); + if (s2.isEmpty()) + reverseDependencies.remove(x); + } + + + Set> s3 = orDependencies.get(y); + if (s3 != null) { + boolean removedReverse = false; + Iterator> iter = s3.iterator(); + while (iter.hasNext()) { + Pair p = iter.next(); + if (x.equals(p.getFirst()) || x.equals(p.getSecond())) { + iter.remove(); + + if (!removedReverse) { + Set set1 = reverseOrDependencies.get(p.getFirst()); + Set set2 = reverseOrDependencies.get(p.getSecond()); + + set1.remove(y); + set2.remove(y); + + if (set1.isEmpty()) + reverseOrDependencies.remove(p.getFirst()); + if (set2.isEmpty()) + reverseOrDependencies.remove(p.getSecond()); + + removedReverse = true; + } + } + } + } + if (s3 != null && s3.isEmpty()) + orDependencies.remove(y); + } + + /** + * Add an "Or" dependency: Y requires either x1 OR x2 - i.e., (x1 or x2) -> Y
+ * If either x1 or x2 (or both) are marked satisfied via {@link #markSatisfied(Object, boolean)} then the + * dependency is considered satisfied + * + * @param y Dependent + * @param x1 Dependee 1 + * @param x2 Dependee 2 + */ + public void addOrDependency(@NonNull T y, @NonNull D x1, @NonNull D x2) { + if (!orDependencies.containsKey(y)) + orDependencies.put(y, new HashSet>()); + + if (!reverseOrDependencies.containsKey(x1)) + reverseOrDependencies.put(x1, newTSet()); + if (!reverseOrDependencies.containsKey(x2)) + reverseOrDependencies.put(x2, newTSet()); + + orDependencies.get(y).add(new Pair<>(x1, x2)); + reverseOrDependencies.get(x1).add(y); + reverseOrDependencies.get(x2).add(y); + + checkAndUpdateIfAllSatisfied(y); + } + + /** + * @return True if there are any new/unprocessed "all satisfied dependents" (Ys in X->Y) + */ + public boolean hasNewAllSatisfied() { + return !allSatisfiedQueue.isEmpty(); + } + + /** + * Returns the next new dependent (Y in X->Y) that has all dependees (Xs) marked as satisfied via {@link #markSatisfied(Object, boolean)} + * Throws an exception if {@link #hasNewAllSatisfied()} returns false.
+ * Note that once a value has been retrieved from here, no new dependencies of the form (X -> Y) can be added for this value; + * the value is considered "processed" at this point. + * + * @return The next new "all satisfied dependent" + */ + public T getNewAllSatisfied() { + Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied"); + return allSatisfiedQueue.remove(); + } + + /** + * @return As per {@link #getNewAllSatisfied()} but returns all values + */ + public List getNewAllSatisfiedList() { + Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied"); + List ret = new ArrayList<>(allSatisfiedQueue); + allSatisfiedQueue.clear(); + return ret; + } + + /** + * As per {@link #getNewAllSatisfied()} but instead of returning the first dependee, it returns the first that matches + * the provided predicate. If no value matches the predicate, null is returned + * + * @param predicate Predicate gor checking + * @return The first value matching the predicate, or null if no values match the predicate + */ + public T getFirstNewAllSatisfiedMatching(@NonNull Predicate predicate) { + Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied"); + + T t = allSatisfiedQueue.peek(); + if (predicate.test(t)) { + t = allSatisfiedQueue.remove(); + allSatisfied.remove(t); + return t; + } + + if (allSatisfiedQueue.size() > 1) { + Iterator iter = allSatisfiedQueue.iterator(); + while (iter.hasNext()) { + t = iter.next(); + if (predicate.test(t)) { + iter.remove(); + allSatisfied.remove(t); + return t; + } + } + } + + return null; //None match predicate + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java index 387e25f48..1f93dbe94 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java @@ -16,46 +16,59 @@ package org.nd4j.autodiff.samediff.internal; -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.Getter; -import lombok.NonNull; +import lombok.*; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.Listener; -import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ops.impl.controlflow.compat.*; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.function.Predicate; import java.util.*; -import org.nd4j.linalg.dataset.api.MultiDataSet; -import org.nd4j.linalg.factory.Nd4j; /** - * Additional functionality to add: - * - Workspaces support - * - Proper cache support + * AbstractSession is a SameDiff graph execution class that inference and training it built upon + * It walks through the graph, dynamically executing operations that can be executed next, but (again, dynamically) only + * executing the subset of the graph that is actually required to get the requested outputs.
+ * None of what AbstractSession implements is NDArray-specific.
+ * Note that most of the implementation complexity comes from dynamic graphs - i.e., nested loops, control ops, etc * * @param Node output type - for example, INDArray, shape, etc depending on what we're calculating * @param Op type + * @author Alex Black */ @Slf4j public abstract class AbstractSession { - //All execution happens in a frame... this is the name of the main/outer frame + /** + * All execution in Samediff happens in a frame... this is the name of the main/outer frame - i.e., the "default" frame + * Other frames (such as for loops) may be nested within this frame + */ public static final String OUTER_FRAME = "main"; protected final SameDiff sameDiff; @Getter - protected final Map nodeOutputs = new HashMap<>(); + protected final Map nodeOutputs = new HashMap<>(); //Key: variable (at a given frame + iteration). Value: the calculated output for that variable @Getter - protected final Map> tensorArrays = new HashMap<>(); //Stores the outputs for a TensorArray ops - protected final Queue availableForExec = new LinkedList<>(); - protected final Set availableForExecSet = new HashSet<>(); //Same content as the queue, but used for O(1) contains instead of ordered removal + protected final Map> tensorArrays = new HashMap<>(); //Stores the underlying arrays for TensorArray ops + /* + The dependency tracker is responsible for determining what ops (at what frame/iteration) can be executed next, given + what has been executed so far. + For static graphs, such as abstraction would not be necessary; for dynamic graphs (i.e., nested loops, of arbitary + number of iterations and depth - and also switch ops which can cause whole subgraphs to not be executed) this is necessary + Note: the ExecStep represents one step for execution - some steps are as simple as "execute an op (at the given frame/iter)" + It works by adding dependencies (X -> Y - such as "op Y depends on the output of op X") and then marking them as + satisfied ("op X has been calculated"). Once all dependencies for an execution step have been satisfied, the execution step + is added to a queue - outputs of which can be accessed with dt.getNewAllSatisfied() and dt.getNewAllSatisfiedList(), + at which point it is removed from the dependency tracker + */ + protected final DependencyTracker dt = new DependencyTracker<>(); + /** * Contains variables we *might* need to execute in process of getting outputs we want. * Variables not in this set are definitely not needed to get the requested output variables, but variables that are @@ -63,45 +76,22 @@ public abstract class AbstractSession { */ protected final Set subgraph = new HashSet<>(); /** - * Stores what variables are required to calculate the specific variable. These inputs could be inputs to an op that - * calculates the variable's value, or it could be a control dependenci - * Keys: variable (in specific frame/iteration) to be executed - * Values: inputs to that node (inc. frame and iteration), unordered - needed for execution of op giving variable + * As per subgraph set, but for ops instead */ - protected final Map> execInputs = new HashMap<>(); + protected final Set subgraphOps = new HashSet<>(); /** - * As per execInputs map - with the different that the iteration number should be ignored (i.e., always 0) - * Reason: Enter nodes - these are executed once - * Example: EnterOp(x) -> LoopCondition(less(x,y)): less op requires "X" on all iterations which is the output of the - * enter op, which is only executed for iteration 0 in a frame. + * Constains the names of ops that don't have any inputs. Kept because normally ops are triggered for execution when + * their all their inputs have been calculated; we'll trigger that step manually during execution initialization */ - protected final Map> execInputsAllIter = new HashMap<>(); - - /** - * Contains the set set of constant and placeholders inputs - * Essentially the same as the execInputs map, but the constants and placeholders are used for calculating all instances - * of a variable - i.e., the input (constant/placeholder) applies to all frames and iterations. - * Keys: variable (any/all frame/iteration) to be executed - * Values: constant or placeholder needed for execution of op giving variable - */ - protected final Map> execConstInputs = new HashMap<>(); - /** - * Map for exit ops. This is used to determine where an exit op should exit to. - * Values added on enter ops. Note that it's not sufficient to - * Key: frame name (for enter/exit nodes). - * Value: parent frame name + iteration - */ - @Getter - protected final Map frameParents = new HashMap<>(); - + protected final Set zeroInputOpsInSubgraph = new HashSet<>(); public AbstractSession(@NonNull SameDiff sameDiff) { this.sameDiff = sameDiff; } - public boolean contains(String variable, String frame, int iteration, FrameIter parentFrameIter){ - VarId varId = newVarId(variable, frame, iteration, parentFrameIter); + public boolean contains(String variable, String frame, int iteration, FrameIter parentFrameIter) { + VarId varId = new VarId(variable, frame, iteration, parentFrameIter); return nodeOutputs.containsKey(varId); } @@ -114,62 +104,36 @@ public abstract class AbstractSession { /** * Get a previously calculated output + * * @param enforceExistence If true: throw an exception if the array does not exist */ public T get(String variable, String frame, int iteration, FrameIter parentFrameIter, boolean enforceExistence) { //TODO eventually we'll cache and reuse VarId objects here to avoid garbage generation on lookup etc - VarId varId = newVarId(variable, frame, iteration, parentFrameIter); + VarId varId = new VarId(variable, frame, iteration, parentFrameIter); T out = nodeOutputs.get(varId); - if(enforceExistence) { + if (enforceExistence) { Preconditions.checkNotNull(out, "No output found for variable %s (frame %s, iteration %s)", variable, frame, iteration); } return out; } - public VarId newVarId(String variable, String frame, int iteration, FrameIter parentFrameIter) { - //TODO eventually we'll cache and reuse VarId objects here to avoid garbage generation on lookup - return new VarId(variable, frame, iteration, parentFrameIter); - } - - public VarId newVarId(String variable, FrameIter frameIter) { - return newVarId(variable, frameIter.getFrame(), frameIter.getIteration(), frameIter.getParentFrame()); - } - /** - * @deprecated Use {@link #output(List, Map, MultiDataSet, Collection, List, At)}. + * Get the output of the session - i.e., perform inference/forward pass and return the autputs for the specified variables * - * @param training Uses Operation.TRAINING if true, otherwise Operation.INFERENCE - */ - @Deprecated - public Map output(@NonNull List variables, Map placeholderValues, - MultiDataSet batch, Collection requiredActivations, boolean training, At at){ - if(at == null){ - if(training) - at = At.defaultAt(Operation.TRAINING); - else - at = At.defaultAt(Operation.INFERENCE); - } - return output(variables, placeholderValues, batch, requiredActivations, Collections.emptyList(), at); - } - - /** - * Get the output of the session - i.e., perform inference/forward pass - * - * @param variables Name of the variables we want the arrays/activations for - * @param placeholderValues The placeholder values (if any). - * @param batch The batch data, used to call Listener.opExecution - * @param requiredActivations Additional activations that are required. Won't be outputed, but opExecution will be called. May be null. + * @param variables Name of the variables we want the arrays/activations for + * @param placeholderValues The placeholder values (if any). May be null. + * @param batch The batch data, used to call Listener.opExecution + * @param requiredActivations Additional activations that are required. Won't be outputed, but opExecution will be called. May be null. * @return The specified variable values, optionally in the specified workspace */ public Map output(@NonNull List variables, Map placeholderValues, - MultiDataSet batch, Collection requiredActivations, List listeners, At at) { + MultiDataSet batch, Collection requiredActivations, List listeners, At at) { + Preconditions.checkState(!variables.isEmpty() || !requiredActivations.isEmpty(), "Variables to perform forward pass for must not be empty"); - Preconditions.checkState(!variables.isEmpty(), "Variables to perform forward pass for must not be empty"); - - if(requiredActivations == null) + if (requiredActivations == null) requiredActivations = Collections.emptyList(); - if(at == null) + if (at == null) at = At.defaultAt(); //Step 0: validation - that variables exist, placeholders have arrays, etc @@ -177,44 +141,46 @@ public abstract class AbstractSession { Preconditions.checkState(sameDiff.variableMap().containsKey(s), "Requested output variable %s does not exist in SameDiff instance", s); } - placeholderValues = preprocessPlaceholders(placeholderValues); + Set reqOutputVariablesSet = new HashSet<>(variables); - //Clear state from past - availableForExec.clear(); - availableForExecSet.clear(); + placeholderValues = preprocessPlaceholders(placeholderValues, at); + + //Clear state from past iterations, if any + dt.clear(); subgraph.clear(); - execInputs.clear(); - execInputsAllIter.clear(); - execConstInputs.clear(); - nodeOutputs.clear(); //TODO eventually we'll have cache here for later execs... main challenge is detecting in-place array modifications and invalidating old results + subgraphOps.clear(); + nodeOutputs.clear(); //TODO eventually we'll have (optional) cache here for later execs... main challenge is detecting in-place array modifications and invalidating old results. And overall memory use... tensorArrays.clear(); //Step 1: determine subgraph structure we actually need to execute //Basic plan: work backwards from the variables we want, based on the graph structure, to work out what // we actually need to execute - List allRequired = new ArrayList<>(requiredActivations); + //TODO we'll optimize this and cache the results, only recalculating if the graph structure changes + Set userRequestedUnique = new HashSet<>(variables); + Set allRequired = new HashSet<>(requiredActivations); allRequired.addAll(variables); initSubgraph(allRequired); - //Step 1a: Check that we have required placeholders + //Step 2: Check that we have required placeholders List phNames = sameDiff.inputs(); - if(placeholderValues == null || !placeholderValues.keySet().containsAll(phNames)){ + if (placeholderValues == null || !placeholderValues.keySet().containsAll(phNames)) { /* We only have a subset of all placeholders Validate that we have all *required* placeholder values. Some might not be needed to calculate the requested outputs A placeholder is required if: (a) It's one of the requested outputs (b) It's required to calculate any of the ops in the subgraph + For example, we might have a label placeholder, and we're doing inference not training */ - for(String s : phNames){ + for (String s : phNames) { boolean required = false; - if(variables.contains(s)){ //TODO List.contains - O(N) + if (variables.contains(s)) { required = true; } - if(!required){ + if (!required) { Variable v = sameDiff.getVariables().get(s); - if(v.getInputsForOp() != null){ - for(String s2 : v.getInputsForOp()){ - if(subgraph.contains(s2)){ + if (v.getInputsForOp() != null) { + for (String s2 : v.getInputsForOp()) { + if (subgraph.contains(s2)) { //Placeholder is required required = true; break; @@ -223,200 +189,562 @@ public abstract class AbstractSession { } } - if(required && (placeholderValues == null || !placeholderValues.containsKey(s))){ - - // Some Keras layers (like GRU) do different things depending on whether the model is training. - // We provide this value directly. - if(s.endsWith("keras_learning_phase")){ - placeholderValues.put(s, (T) Nd4j.scalar(at.operation().isTrainingPhase())); - } else { - throw new IllegalStateException( - "An input placeholder \"" + s + "\" is required to calculate the requested outputs," + - " but a placeholder value was not provided"); - } + if (required && (placeholderValues == null || !placeholderValues.containsKey(s))) { + throw new IllegalStateException( + "An input placeholder \"" + s + "\" is required to calculate the requested outputs," + + " but a placeholder value was not provided"); } } } - //Step 2: execute in any order, until we have all required nodeOutputs + //Step 3: Mark the (required) variables, constants and placeholders as available via dependency tracker + //And also any "zero dependency" ops - i.e., those without any inputs + ExecStep start = new ExecStep(ExecType.EXEC_START, "", null); //Dummy dependency to trigger the variables and constants + for (SDVariable v : sameDiff.variables()) { + VariableType vt = v.getVariableType(); + if (vt == VariableType.VARIABLE || vt == VariableType.CONSTANT) { + ExecType et = vt == VariableType.VARIABLE ? ExecType.VARIABLE : ExecType.CONSTANT; + ExecStep es = new ExecStep(et, v.name(), new FrameIter(OUTER_FRAME, 0, null)); + dt.addDependency(es, start); + + Variable var = sameDiff.getVariables().get(v.name()); + if (var.getControlDeps() != null) { + addVarControlDeps(es, var); //Before this variable can be considered available for use, we need specified op to be executed + } + } + } + for (String s : phNames) { + ExecStep es = new ExecStep(ExecType.PLACEHOLDER, s, new FrameIter(OUTER_FRAME, 0, null)); + dt.addDependency(es, start); + + Variable var = sameDiff.getVariables().get(s); + if (var.getControlDeps() != null) { + addVarControlDeps(es, var); //Before this variable can be considered available for use, we need specified op to be executed + } + } + for (String s : zeroInputOpsInSubgraph) { + ExecStep es = new ExecStep(ExecType.OP, s, new FrameIter(OUTER_FRAME, 0, null)); + dt.addDependency(es, start); + } + dt.markSatisfied(start, true); + + + //Step 4: execute in any order, but not switching to new frame/iteration until all from current frame/iter ops + // are done - until we have all required nodeOutputs /* - The idea is simple: we start off with a set of "available to execute" variables - just the placeholders and - constants at this point. + The idea is simple: we start off with a set of "available to execute" variables - just the placeholders, + constants and variables (assuming no control dependencies) at the start of execution. Then, we remove an "available to execute" node and execute it. Execution may be: - (a) For constants and placeholders: just looking up the value - (b) For variables as outputs of ops: actually executing the op + (a) For constants, variable type SDVariables, and placeholders: just look up the value + (b) For variables as outputs of ops: actually execute the op After execution, we look at the graph structure and determine what that now executed/calculated variable is an input to. If all inputs are available for the op, we mark all output variables of that op as available for execution. + Both parts of this (tracking dependencies, and also what's now available to execute) are handled in the dependency tracker We stop computation once all the required outputs are available. At this point, subgraph may NOT be empty - for example, switch ops may cause entire branches of the graph to be skipped. */ - Map out = new HashMap<>(); - int step = 0; - while (out.size() < variables.size()) { - if(availableForExec.size() == 0){ - int missingCount = variables.size() - out.size(); - StringBuilder sb = new StringBuilder(); - sb.append("No variable are available for execution at step ") - .append(step).append(": ").append(missingCount).append(" values remaining"); - Set missing = new HashSet<>(); - for(String s : variables){ - if(!out.containsKey(s)){ - missing.add(s); - } + Map out = new HashMap<>(); //Outputs, returned to the user + int step = 0; //Number of execution steps + //Next 3: current execution frame + String currentFrame = OUTER_FRAME; + int currentFrameIter = 0; + FrameIter currParentFrame = null; + ExecStepPredicate predicate = new ExecStepPredicate(); + while (out.size() < userRequestedUnique.size()) { + if (!dt.hasNewAllSatisfied()) { + //Haven't got all of the outputs the user requested, but there's nothing left that we can execute. Should not happen. + execFailed(userRequestedUnique, out, step); + } + + //Get variable in the current frame/iteration and execute it's corresponding op + //If no more ops exist for the current frame/iter, we'll switch to the next frame/iter + //The idea is to not mix the order of execution of ops in different frames/iters - i.e., finish the current + // frame/iter before starting the next one + predicate.setCurrentFrame(currentFrame); + predicate.setCurrentFrameIter(currentFrameIter); + predicate.setCurrParentFrame(currParentFrame); + + ExecStep es = dt.getFirstNewAllSatisfiedMatching(predicate); + if (es == null) { + //We must have finished the current frame/iter, and are switching to the next one + es = dt.getNewAllSatisfied(); + } + + currentFrame = es.getFrameIter().getFrame(); + currentFrameIter = es.getFrameIter().getIteration(); + currParentFrame = es.getFrameIter().getParentFrame(); + + log.trace("Beginning execution step {}: {}", step, es); + + FrameIter outFrameIter; + boolean skipDepUpdate = false; //Only used for Switch ops, which have slighly different handling... + boolean skipMarkSatisfied = false; //Only for enter ops, because of different frame/iter + if (es.getType() == ExecType.CONSTANT || es.getType() == ExecType.VARIABLE) { + VarId vid = new VarId(es.getName(), OUTER_FRAME, 0, null); + T arr = getConstantOrVariable(es.getName()); + Preconditions.checkNotNull(arr, "Encountered null placeholder array for constant: %s", vid); + nodeOutputs.put(vid, arr); + outFrameIter = new FrameIter(OUTER_FRAME, 0, null); + if (allRequired.contains(es.getName())) { + //User requested const/variable as one of the outputs + out.put(es.getName(), arr); } - if(missingCount <= 10){ - sb.append(". Missing variables: "); - sb.append(missing); + } else if (es.getType() == ExecType.PLACEHOLDER) { + VarId vid = new VarId(es.getName(), OUTER_FRAME, 0, null); + nodeOutputs.put(vid, placeholderValues.get(es.getName())); + outFrameIter = new FrameIter(OUTER_FRAME, 0, null); + if (allRequired.contains(es.getName())) { + //User requested placeholder value as one of the outputs + out.put(es.getName(), placeholderValues.get(es.getName())); + } + } else if (es.getType() == ExecType.OP) { + String opName = es.getName(); + SameDiffOp op = sameDiff.getOps().get(opName); + DifferentialFunction o = op.getOp(); + + if (o instanceof Enter) { + //Enter op: output is variable in a new (specified) frame, iteration 0. + //Parent is current (input) frame + String outFrame = ((Enter) o).getFrameName(); + outFrameIter = new FrameIter(outFrame, 0, es.getFrameIter()); + } else if (o instanceof Exit) { + //Exit node forwards input to parent frame + String outFrame = es.getFrameIter().getParentFrame().getFrame(); + int outIter = es.getFrameIter().getParentFrame().getIteration(); + FrameIter outParentFrame = es.getFrameIter().getParentFrame().getParentFrame(); + outFrameIter = new FrameIter(outFrame, outIter, outParentFrame); + } else if (o instanceof NextIteration) { + //NextIteration op: forwards its single input to its output varible in the current frame, but increments the iteration number + outFrameIter = es.getFrameIter().clone(); + outFrameIter.setIteration(outFrameIter.getIteration()); } else { - sb.append(". First 10 missing variables: "); - Iterator iter = missing.iterator(); - for( int i=0; i<10 && iter.hasNext(); i++ ){ - if(i > 0) - sb.append(","); - sb.append(iter.next()); + //Standard ops - output variable has same frame and iteration number as the input(s) + //Also loopCond, merge, while, etc + outFrameIter = es.getFrameIter(); + } + + + //Resolve the inputs to this execution step (op) to actual arrays + Set inputs = null; + Set allIterInputs = null; + Set constAndPhInputs = null; + DependencyList dl = dt.getDependencies(es); + + List inputNames = op.getInputsToOp(); + if (inputNames != null && !inputNames.isEmpty()) { + inputs = new HashSet<>(); + allIterInputs = new HashSet<>(); + constAndPhInputs = new HashSet<>(); + List deps = dl.getDependencies(); + if (deps != null && !deps.isEmpty()) { + for (ExecStep dep : deps) { + switch (dep.getType()) { + case OP: + case SWITCH_L: + case SWITCH_R: + //The current execution step depends on one output of the op "dep" + SameDiffOp toExecOp = sameDiff.getOps().get(es.getName()); + List inputsToExecOp = toExecOp.getInputsToOp(); + SameDiffOp inputOp = sameDiff.getOps().get(dep.getName()); + List inputOpOutNames = inputOp.getOutputsOfOp(); + for (String s : inputsToExecOp) { + if (inputOpOutNames.contains(s)) { + VarId vid = new VarId(s, dep.getFrameIter().getFrame(), dep.getFrameIter().getIteration(), dep.getFrameIter().getParentFrame()); + inputs.add(vid); + } + } + break; + case VARIABLE: + inputs.add(new VarId(dep.getName(), OUTER_FRAME, 0, null)); + break; + case CONSTANT: + case PLACEHOLDER: + constAndPhInputs.add(dep.getName()); + break; + default: + throw new UnsupportedOperationException("Not yet implemented: " + dep.getType()); + } + } } } - String s = sb.toString(); - throw new IllegalStateException(s); - } - - //Get any variable and execute it's corresponding op - VarId varToExec = availableForExec.remove(); - availableForExecSet.remove(varToExec); - if (nodeOutputs.containsKey(varToExec)) { - //Already processed this one. May occur if execution was triggered by a different output of a multi-output op - //But we'll still update its descendants to ensure they are marked as available - if (variables.contains(varToExec.getVariable())) { //Check if required output - out.put(varToExec.getVariable(), nodeOutputs.get(varToExec)); - } - updateDescendentsForExec(step, varToExec); - continue; - } - - //Get inputs to this variable. May be actual op inputs, or just control dependencies - Set inputsToVar = execInputs.get(varToExec); - VarId allIterInputVar = newVarId(varToExec.getVariable(), varToExec.getFrame(), 0, varToExec.getParentFrame()); - Set inputsToVarAllIter = execInputsAllIter.get(allIterInputVar); - Set constPhForVar = execConstInputs.get(varToExec.getVariable()); - - log.trace("Beginning execution step {}: variable {}", step, varToExec); - - if (sameDiff.getVariable(varToExec.getVariable()).isPlaceHolder()) { - //Variable is placeholder: do lookup - nodeOutputs.put(varToExec, placeholderValues.get(varToExec.getVariable())); - updateDescendentsForExec(step, varToExec); //Check + mark descendants as available for exec - if (variables.contains(varToExec.getVariable())) { //Check if required output - out.put(varToExec.getVariable(), placeholderValues.get(varToExec.getVariable())); - } - } else if (sameDiff.getVariable(varToExec.getVariable()).isConstant() || - sameDiff.getVariable(varToExec.getVariable()).getVariableType() == VariableType.VARIABLE) { - //Variable is constant: do lookup - //OR variable is VARIABLE type - i.e., a trainable parameter... - T phArr = getConstantOrVariable(varToExec.getVariable()); - Preconditions.checkNotNull(phArr, "Encountered null placeholder array for constant: %s", varToExec); - nodeOutputs.put(varToExec, phArr); - updateDescendentsForExec(step, varToExec); //Check + mark descendants as available for exec - if (variables.contains(varToExec.getVariable())) { //Check if required output - out.put(varToExec.getVariable(), phArr); - } - } else if (sameDiff.getVariableOutputOp(varToExec.getVariable()) != null) { - //Variable is the output of an op -> execute op - String opName = sameDiff.getVariables().get(varToExec.getVariable()).getOutputOfOp(); + // Do execution of the op, in 2 steps + // (a) "Parameterize" the op - i.e., find and set the arrays on the op, allocate outputs, etc ready for execution + // (b) actually execute the operation + O parameterizedOp = getAndParameterizeOp(opName, outFrameIter, inputs, allIterInputs, constAndPhInputs, placeholderValues, reqOutputVariablesSet); + T[] opOutputValues = getOutputs(parameterizedOp, outFrameIter, inputs, allIterInputs, constAndPhInputs, listeners, at, batch, reqOutputVariablesSet); + List opOutVarNames = op.getOutputsOfOp(); - //Execute op - FrameIter frameIter = varToExec.toFrameIter(); - O parameterizedOp = getAndParameterizeOp(opName, frameIter, inputsToVar, inputsToVarAllIter, constPhForVar, placeholderValues); - T[] opOutputValues = getOutputs(parameterizedOp, frameIter, inputsToVar, inputsToVarAllIter, constPhForVar, listeners, at, batch); - - - //Post execution: work out what is now available for exec - String[] opOutputVarNames = sameDiff.getOpById(opName).outputVariablesNames(); - - Preconditions.checkState(opOutputValues.length == opOutputVarNames.length, "Unexpected number of outputs from executed op %s:" + + Preconditions.checkState(opOutputValues.length == opOutVarNames.size(), "Unexpected number of outputs from executed op %s:" + " got %s outputs when %s outputs were expected (%s)", parameterizedOp.getClass().getSimpleName(), opOutputValues.length, - opOutputVarNames.length, opOutputVarNames); + opOutVarNames.size(), opOutVarNames); - for (int i = 0; i < opOutputVarNames.length; i++) { - if (opOutputValues[i] == null && parameterizedOp instanceof Switch) { - //Skip null - for switch op only. Switch op forwards input to only one of its outputs - //All other ops should not + //Store the op outputs + for (int i = 0; i < opOutputValues.length; i++) { + if (opOutputValues[i] == null && op.getOp() instanceof Switch) { + //Switch op only forwards the input to one of the outputs continue; } - Preconditions.checkNotNull(opOutputValues[i], "Encountered null output (output %s) for op %s at execution step %s", i, parameterizedOp.getClass().getSimpleName(), step); + String n = opOutVarNames.get(i); + VarId vid = new VarId(n, outFrameIter.getFrame(), outFrameIter.getIteration(), outFrameIter.getParentFrame()); + nodeOutputs.put(vid, opOutputValues[i]); - VarId outputVarId; - boolean addDummyOutput = false; - if (parameterizedOp instanceof Enter) { - //Enter op: output is variable in a new (specified) frame, iteration 0. - String frame = ((Enter) parameterizedOp).getFrameName(); - boolean isConstant = ((Enter) parameterizedOp).isConstant(); - FrameIter outParentFrame = varToExec.getParentFrame(); - if(isConstant && outParentFrame != null){ - //For enter nodes that are constants, we want iteration 0 in all frames in the heirarchy - //For example, const -> Enter(a) -> Enter(b) -> op; in this case, the input to Op (at any frame/iteration) should should - // be the constant value - which is recorded as (frame="a",iter=0,parent=(frame="b",iter=0)) - outParentFrame = outParentFrame.clone(); - FrameIter toZero = outParentFrame; - while(toZero != null){ - toZero.setIteration(0); - toZero = toZero.getParentFrame(); - } - } - outputVarId = newVarId(opOutputVarNames[i], frame, 0, outParentFrame); - addDummyOutput = true; - } else if (parameterizedOp instanceof Exit) { - //Exit node forwards input to parent frame (which is already reflected in varToExec) - outputVarId = newVarId(opOutputVarNames[i], varToExec.getFrame(), varToExec.getIteration(), varToExec.getParentFrame()); - addDummyOutput = true; - } else if (parameterizedOp instanceof NextIteration) { - //NextIteration op: forwards its single input to its output varible in the current frame, but increments the iteration number - //Note that varToExec has already had its iteration number incremented by 1 (relative to its input) in updateDescendentsForExec... so don't increment here - outputVarId = newVarId(opOutputVarNames[i], varToExec.getFrame(), varToExec.getIteration(), varToExec.getParentFrame()); - addDummyOutput = true; - } else if (parameterizedOp instanceof LoopCond) { - //LoopCond just forwards input to output - outputVarId = newVarId(opOutputVarNames[i], varToExec.getFrame(), varToExec.getIteration(), varToExec.getParentFrame()); - addDummyOutput = true; - } else { - //Standard ops - output variable has same frame and iteration number as the input(s) - outputVarId = newVarId(opOutputVarNames[i], varToExec.getFrame(), varToExec.getIteration(), varToExec.getParentFrame()); - } - - if(addDummyOutput){ - //For ops like enter/exit/nextiteration, these don't have a real output for that node - //But, we still want an entry in nodeOutputs, which we also use for checking if an op has already been executed - nodeOutputs.put(newVarId(opOutputVarNames[i], varToExec.getFrame(), varToExec.getIteration(), varToExec.getParentFrame()), null); - } - - nodeOutputs.put(outputVarId, opOutputValues[i]); - updateDescendentsForExec(step, outputVarId); //Check + mark descendants as available for exec - - if (variables.contains(opOutputVarNames[i])) { //Check if required output - out.put(opOutputVarNames[i], opOutputValues[i]); + if (allRequired.contains(n)) { + out.put(n, opOutputValues[i]); } } + + //Post execution: update dependency tracker so we know what is available to execute next, given we now + // have these new values + if (o instanceof Switch) { + /* + Switch is a special case: only one output/branch is considered to exist post execution. + Unlike every other type of op, only 1 of 2 output arrays is actually executed. + For dependency tracking purposes, this is why we have SWITCH_L and _R execution types. + If we just depended on the op, the dependency tracker would incorrectly conclude that ops relying on + both branches (i.e., including the unavailable one) can now be executed + */ + skipDepUpdate = true; + skipMarkSatisfied = true; + int nullCount = (opOutputValues[0] == null ? 1 : 0) + (opOutputValues[1] == null ? 1 : 0); + Preconditions.checkState(nullCount == 1, "Expected exactly one output to be present for switch ops, got %s", nullCount); + boolean left = opOutputValues[0] != null; + ExecStep branch; + if (left) { + branch = new ExecStep(ExecType.SWITCH_L, es.getName(), es.getFrameIter()); + } else { + branch = new ExecStep(ExecType.SWITCH_R, es.getName(), es.getFrameIter()); + } + updateDescendantDeps(branch, outFrameIter); + dt.markSatisfied(branch, true); + } else if (o instanceof Enter) { + //Enter op: we want to say that the inner frame is executed... + skipDepUpdate = true; + skipMarkSatisfied = true; + Enter e = (Enter) o; + FrameIter fi = new FrameIter(e.getFrameName(), 0, es.getFrameIter()); + ExecStep exec = new ExecStep(ExecType.OP, es.getName(), fi); + updateDescendantDeps(exec, fi); + dt.markSatisfied(exec, true); + } else if (o instanceof Exit) { + //Exit op: we want to say that the parent frame is executed... + skipDepUpdate = true; + skipMarkSatisfied = true; + FrameIter fi = es.getFrameIter().getParentFrame(); + ExecStep exec = new ExecStep(ExecType.OP, es.getName(), fi); + updateDescendantDeps(exec, fi); + dt.markSatisfied(exec, true); + } + + /* + Edge case for TensorFlow import control dependencies: for some reason, TF allows op control dependencies + like /while/x -> SomeConstant - i.e., a constant depending on something inside a scope. + This should be handled with an enter op, but TF doesn't always use this :/ + Note that this is equivalent to marking the control dependency as satisfied on the first iteration + TODO double check that this is exactly the same behaviour as TF - otherwise this approach might fail in + some rare cases that rely on the constant/variable not being available + */ + List cdFor = op.getControlDepFor(); + if (cdFor != null) { + ExecStep cdEs = new ExecStep(ExecType.CONTROL_DEP, opName, null); + if (!dt.isSatisfied(cdEs)) { + dt.markSatisfied(cdEs, true); + } + } + } else { - Variable v = sameDiff.getVariables().get(varToExec.getVariable()); - throw new IllegalStateException("Unable to execute variable " + varToExec + " of type " + v.getVariable().getVariableType()); + //Should never happen + throw new RuntimeException("Unknown ExecStep: " + es); } + + //Standard ops + if (!skipDepUpdate) { + updateDescendantDeps(es, outFrameIter); + } + if (!skipMarkSatisfied) { + dt.markSatisfied(es, true); + } + step++; } + //TODO we should clear the node outputs map to get rid of the invalid (closed, out of workspace, etc) arrays - //TODO under what circumstances should we clear the nodeOutputs map? - //TODO when should we close the workspace? (Might want to leave it open if we expect to re-use) - + out = postProcessOutput(out); //Hook-in for subclass sessions, if needed return out; } - protected void initSubgraph(List variables) { + /** + * Add the control dependency from Op -> variable + * + * @param es Execution step for the variable + * @param v Variable + */ + protected void addVarControlDeps(ExecStep es, Variable v) { + List cds = v.getControlDeps(); + if (cds != null) { + for (String s : cds) { + ExecStep controlES = new ExecStep(ExecType.CONTROL_DEP, s, null); + dt.addDependency(es, controlES); //Before this variable can be considered available for use, we need specified op to be executed + } + } + } + + /** + * Execution failed - can't calculate all requested outputs, and there's nothing left to calculate. + * Throws an exception with a useful message + * + * @param userRequestedUnique All outputs that the user requseted + * @param out Current outputs + * @param step Execution step + */ + protected void execFailed(Set userRequestedUnique, Map out, int step) { + int missingCount = userRequestedUnique.size() - out.size(); + StringBuilder sb = new StringBuilder(); + sb.append("No variable are available for execution at step ") + .append(step).append(": ").append(missingCount).append(" values remaining"); + Set missing = new HashSet<>(); + for (String s : userRequestedUnique) { + if (!out.containsKey(s)) { + missing.add(s); + } + } + if (missingCount <= 10) { + sb.append(". Missing variables: "); + sb.append(missing); + } else { + sb.append(". First 10 missing variables: "); + Iterator iter = missing.iterator(); + for (int i = 0; i < 10 && iter.hasNext(); i++) { + if (i > 0) + sb.append(","); + sb.append(iter.next()); + } + } + String s = sb.toString(); +// System.out.println(sameDiff.summary()); + throw new IllegalStateException(s); + } + + /** + * Update the descendant dependencies + * So if the graph structure is X -> A, then add all (X,Y,Z,...) -> A to the dependency tracker + * This is for a specific frame and iteration, for both sides of the dependency (in and out) + * + * @param justExecuted The execution step that has just completed + * @param outFrameIter The frame/iteration of the output + */ + protected void updateDescendantDeps(ExecStep justExecuted, FrameIter outFrameIter) { + ExecType t = justExecuted.getType(); + String n = justExecuted.getName(); + if (justExecuted.getType() == ExecType.OP) { + SameDiffOp op = sameDiff.getOps().get(n); + List outNames = op.getOutputsOfOp(); + for (String s : outNames) { + Variable v = sameDiff.getVariables().get(s); + List inputsToOps = v.getInputsForOp(); + if (inputsToOps != null) { + for (String opName : inputsToOps) { + if (subgraphOps.contains(opName)) { + //We've just executed X, and there's dependency X -> Y + //But, there also might be a Z -> Y that we should mark as needed for Y + addDependenciesForOp(opName, outFrameIter); + } + } + } + + + //Also add control dependencies (variable) + List cdForOps = v.getControlDepsForOp(); + if (cdForOps != null) { + for (String opName : cdForOps) { + if (subgraphOps.contains(opName)) { + //We've just executed X, and there's dependency X -> Y + //But, there also might be a Z -> Y that we should mark as needed for Y + addDependenciesForOp(opName, outFrameIter); + } + } + } + } + } else if (t == ExecType.VARIABLE || t == ExecType.CONSTANT || t == ExecType.PLACEHOLDER) { + Variable v = sameDiff.getVariables().get(n); + List inputsToOps = v.getInputsForOp(); + if (inputsToOps != null) { + for (String opName : inputsToOps) { + if (subgraphOps.contains(opName)) { + addDependenciesForOp(opName, outFrameIter); + } + } + } + } else if (justExecuted.getType() == ExecType.SWITCH_L || justExecuted.getType() == ExecType.SWITCH_R) { + SameDiffOp op = sameDiff.getOps().get(n); + List outNames = op.getOutputsOfOp(); + String branchVarName = (justExecuted.getType() == ExecType.SWITCH_L ? outNames.get(0) : outNames.get(1)); + Variable v = sameDiff.getVariables().get(branchVarName); + List inputsToOps = v.getInputsForOp(); + if (inputsToOps != null) { + for (String opName : inputsToOps) { + if (subgraphOps.contains(opName)) { + //We've just executed X, and there's dependency X -> Y + //But, there also might be a Z -> Y that we should mark as needed for Y + addDependenciesForOp(opName, outFrameIter); + } + } + } + } else { + throw new UnsupportedOperationException("Unknown or not yet implemented exec type: " + justExecuted); + } + } + + /** + * Suppose operation X has just been executed. + * For X -> someOp, add all dependencies for someOp, i.e., all Z -> someOp + * (which includes X, but may not only be X) + * + * @param opName Name of the op + * @param depFrameIter Frame/iteration of the op instance to be executed + */ + protected void addDependenciesForOp(String opName, FrameIter depFrameIter) { + SameDiffOp op = sameDiff.getOps().get(opName); + List inputs = op.getInputsToOp(); + List cdOps = op.getControlDeps(); + List cdVars = op.getVarControlDeps(); + + ExecStep es = new ExecStep(ExecType.OP, opName, depFrameIter); + if (!(op.getOp() instanceof NextIteration) && dt.hasDependency(es)) { + //Already processed this once. We only add dependencies once per op (for a given frame/iteration) + return; + } + + if (op.getOp() instanceof Merge) { + //Merge ops are a special case: they can be executed with EITHER ONE of the inputs available - unlike every + // other op, we don't need all inputs, just one, before it can be executed + Variable v0 = sameDiff.getVariables().get(inputs.get(0)); + Variable v1 = sameDiff.getVariables().get(inputs.get(1)); + + ExecStep or0 = getExecStepForVar(v0.getName(), depFrameIter); + ExecStep or1 = getExecStepForVar(v1.getName(), depFrameIter); + dt.addOrDependency(es, or0, or1); + } else if (op.getOp() instanceof NextIteration) { + //For NextIteration, dependencies should be of the form X(iter) -> NextIter(iter+1) + FrameIter fi = depFrameIter.clone(); + fi.setIteration(fi.getIteration() + 1); + es = new ExecStep(ExecType.OP, opName, fi); + for (String s : inputs) { + ExecStep req = getExecStepForVar(s, depFrameIter); + dt.addDependency(es, req); + } + } else { + for (String s : inputs) { + ExecStep req = getExecStepForVar(s, depFrameIter); + dt.addDependency(es, req); + } + } + + if (cdOps != null) { + for (String s : cdOps) { + ExecStep req = getExecStepForVar(s, depFrameIter); + dt.addDependency(es, req); + } + } + + if (cdVars != null) { + for (String s : cdVars) { + + } + } + } + + /** + * Get the ExecStep for the given variable, given execution is happening at the specified frame/iteration + */ + protected ExecStep getExecStepForVar(String varName, FrameIter frameIter) { + Variable v = sameDiff.getVariables().get(varName); + VariableType vt = v.getVariable().getVariableType(); + if (vt == VariableType.VARIABLE) { + return new ExecStep(ExecType.VARIABLE, v.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null)); + } else if (vt == VariableType.PLACEHOLDER) { + return new ExecStep(ExecType.PLACEHOLDER, v.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null)); + } else if (vt == VariableType.CONSTANT) { + return new ExecStep(ExecType.CONSTANT, v.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null)); + } else { + //Array type. Must be output of an op + String outOfOp = v.getOutputOfOp(); + SameDiffOp sdo = sameDiff.getOps().get(outOfOp); + if (sdo.getOp() instanceof Switch) { + //For dependency tracking purposes, we track left and right output branches of switch op separately + //Otherwise, ops depending both branches will be marked as available if we just rely on "op has been executed" + List opOutputs = sdo.getOutputsOfOp(); + int idx = opOutputs.indexOf(v.getName()); + if (idx == 0) { + //Left branch + return new ExecStep(ExecType.SWITCH_L, outOfOp, frameIter); + } else if (idx == 1) { + //Right branch + return new ExecStep(ExecType.SWITCH_R, outOfOp, frameIter); + } else { + //Should never happen + throw new IllegalStateException("Expected variable \"" + v.getName() + "\" to be an output of operation \"" + + outOfOp + "\", but op output variables are: " + opOutputs); + } + } else if (sdo.getOp() instanceof Enter) { + Enter e = (Enter) sdo.getOp(); + + //For enter ops, "constant=true" enter ops are available for ALL iterations, hence use iter=0 + //For constant=false, these are only available at iteration 0 - so use *current* iteration, same as all other ops + // (which is this case, won't be triggered on iter > 0 - as desired/expected) + if (e.isConstant()) { + FrameIter fi = frameIter.clone(); + fi.setIteration(0); + + //Nested constant enter case: Iteration 0 all the way down... + String inVarName = sdo.getInputsToOp().get(0); + FrameIter parentFrame = fi.getParentFrame(); + while (parentFrame != null) { + Variable var = sameDiff.getVariables().get(inVarName); + if (var.getOutputOfOp() != null) { + String opName = var.getOutputOfOp(); + SameDiffOp sdo2 = sameDiff.getOps().get(opName); + if (sdo2.getOp() instanceof Enter) { + Enter e2 = (Enter) sdo.getOp(); + if (e2.isConstant()) { + parentFrame.setIteration(0); + parentFrame = parentFrame.getParentFrame(); + inVarName = sdo2.getInputsToOp().get(0); + } else { + break; + } + } else { + break; + } + } else { + break; + } + } + + return new ExecStep(ExecType.OP, outOfOp, fi); + } + + //Intentional fall-through to default case + } + return new ExecStep(ExecType.OP, outOfOp, frameIter); + } + } + + /** + * Initialize the subgraph - the subgraph and subgraphOps sets + * This works our what ops and variables we might need to execute to get the requested outputs. + * In general, this is a subset of the graph. + * + * @param variables Set of output variables we need + */ + protected void initSubgraph(Set variables) { //Step 1: determine subgraph structure we actually need to execute Queue processingQueue = new LinkedList<>(variables); @@ -434,21 +762,20 @@ public abstract class AbstractSession { // until after execution of some other ops (for example, in conditional operations) numInputs += controlDeps.size(); } - if (numInputs == 0) { - VarId vid = newVarId(varName, OUTER_FRAME, 0, null); - if(!availableForExecSet.contains(vid)) { - availableForExec.add(vid); - availableForExecSet.add(vid); - } - execInputs.put(vid, new HashSet()); + if (numInputs == 0 && opName != null) { + zeroInputOpsInSubgraph.add(opName); } subgraph.add(varName); - if(controlDeps != null){ + if (opName != null) { + subgraphOps.add(opName); + } + + if (controlDeps != null) { //If variable has control dependencies, it's not available right away... to make it available, // we need the "inputs" to be available first. This is mainly used for TF import. - for(String s : controlDeps){ - if(!subgraph.contains(s)){ + for (String s : controlDeps) { + if (!subgraph.contains(s)) { processingQueue.add(s); } } @@ -477,359 +804,28 @@ public abstract class AbstractSession { } } - /** - * This method should be called for a variable once it's array is ready for use. - * For example, post op execution, etc - * - * @param execStep Current execution step (mainly for debugging) - * @param executedVar Variable that was just executed - */ - protected void updateDescendentsForExec(int execStep, VarId executedVar) { - String varName = executedVar.getVariable(); - Variable var = sameDiff.getVariables().get(executedVar.getVariable()); - //Find any ops (or variables with control dependencies) that this is required for execution of and check if now available for exec - List l = sameDiff.getVariables().get(executedVar.getVariable()).getInputsForOp(); - String[] inputForOps = l == null ? null : l.toArray(new String[l.size()]); //Just executed variable is input to these ops - List controlDepForVars = var.getControlDepsForVar(); //Just executed variable is a control dependency for these variables - List controlDepForOps = var.getControlDepsForOp(); //Just executed variable is a control dependency for these ops - - - SDVariable v = var.getVariable(); - boolean isConstOrPhInput = v.isPlaceHolder() || v.isConstant(); - - //After a variable becomes available, we should look at the ops this is an input to, and check if we can execute this op now... - if (inputForOps != null) { - for (String opName : inputForOps) { - - DifferentialFunction fn = sameDiff.getOpById(opName); - if (fn instanceof Merge) { - //Merge op: available for execution when *any* of its inputs are available. But only mark it for exec once... - List opOutputs = sameDiff.getOps().get(opName).getOutputsOfOp(); - Preconditions.checkState(opOutputs.size() == 1, "Expected only 1 output variable for merge op, got %s", opOutputs); - VarId outVarId = newVarId(opOutputs.get(0), executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - if (!nodeOutputs.containsKey(outVarId) && subgraph.contains(outVarId.getVariable()) && !availableForExecSet.contains(outVarId)) { - availableForExec.add(outVarId); - availableForExecSet.add(outVarId); - log.trace("Marked merge op ({}) variable {} as available for execution: input {} is now available", opName, outVarId, executedVar); - } - - //Mark that we need the specified input to calculate this output - addToExecInputs(isConstOrPhInput, executedVar, outVarId); - continue; - } else if (fn instanceof Enter) { - //Enter node: available for exec when any of its inputs are available for exec - // Note input feeds from one frame to another - List opOutputs = sameDiff.getOps().get(opName).getOutputsOfOp(); - Preconditions.checkState(opOutputs.size() == 1, "Expected only 1 output variable for enter op, got %s", opOutputs); - Enter e = (Enter) fn; - boolean isConstant = e.isConstant(); - VarId outVarId = newVarId(opOutputs.get(0), e.getFrameName(), 0, executedVar.toFrameIter()); //Note: parent frame of output op is enter var's *current* frame - - if(isConstant && executedVar.getParentFrame() != null){ - //For enter nodes that are constants, we want iteration 0 in all frames in the heirarchy - //For example, const -> Enter(a) -> Enter(b) -> op; in this case, the input to Op (at any frame/iteration) should should - // be the constant value - which is recorded as (frame="a",iter=0,parent=(frame="b",iter=0)) - outVarId.setParentFrame(outVarId.getParentFrame().clone()); - FrameIter fi = outVarId.getParentFrame(); - while(fi != null){ - fi.setIteration(0); - fi = fi.getParentFrame(); - } - } - - if (!nodeOutputs.containsKey(outVarId) && subgraph.contains(outVarId.getVariable()) && !availableForExecSet.contains(outVarId)) { - availableForExec.add(outVarId); - availableForExecSet.add(outVarId); - log.trace("Marked enter op ({}) variable {} as available for execution: input {} is now available", opName, outVarId, executedVar); - } - - //Also record the parent frame: we'll need this when we get to the corresponding exit ops - frameParents.put(e.getFrameName(), executedVar.toFrameIter()); - - //Mark that we need the specified input to calculate this output - addToExecInputs(isConstOrPhInput, executedVar, outVarId); - continue; - } else if (fn instanceof Exit) { - //Exit node forwards input to parent frame - List opOutputs = sameDiff.getOps().get(opName).getOutputsOfOp(); - FrameIter parentFrame = frameParents.get(executedVar.getFrame()); - Preconditions.checkNotNull(parentFrame, "Parent frame must not be null for exit op: variable to exec is %s", executedVar); - - VarId outVarId = new VarId(opOutputs.get(0), parentFrame.getFrame(), parentFrame.getIteration(), executedVar.getParentFrame().getParentFrame()); //Parent frame of output is parent of current parent - if (!nodeOutputs.containsKey(outVarId) && subgraph.contains(outVarId.getVariable()) && !availableForExecSet.contains(outVarId)) { - availableForExec.add(outVarId); - availableForExecSet.add(outVarId); - log.trace("Marked Exit op ({}) variable {} as available for execution: input {} is now available", opName, outVarId, executedVar); - } - - addToExecInputs(isConstOrPhInput, executedVar, outVarId); - continue; - } else if (fn instanceof NextIteration) { - //NextIteration is available for execution when its single input is available - //NextIteration op: forwards its single input to the output of the current frame, but increments the iteration number - List opOutputs = sameDiff.getOps().get(opName).getOutputsOfOp(); - Preconditions.checkState(opOutputs.size() == 1, "Expected exactly 1 output for NextIteration op: got %s", opOutputs); - VarId outVarId = newVarId(opOutputs.get(0), executedVar.getFrame(), executedVar.getIteration() + 1, executedVar.getParentFrame()); - - if (!nodeOutputs.containsKey(outVarId) && subgraph.contains(outVarId.getVariable()) && !availableForExecSet.contains(outVarId)) { - availableForExec.add(outVarId); - availableForExecSet.add(outVarId); - log.trace("Marked NextIteration op ({}) variable {} as available for execution: input {} is now available", opName, outVarId, executedVar); - } - - //Mark that we need the specified input to calculate this output - addToExecInputs(isConstOrPhInput, executedVar, outVarId); - continue; - } - //Note for LoopCond: just forwards input to output - so basically handle it the same as other ops here - - - //Can execute this op - and hence get it's output variables - if all inputs (and control deps) are available - String[] inputsThisOp = fn.argNames(); - boolean allInputsAvailable = true; - if (inputsThisOp != null) { - allInputsAvailable = allInputsAvailable(execStep, inputsThisOp, executedVar); - } - - //Check Op control dependencies - List opControlDeps = sameDiff.getOps().get(opName).getControlDeps(); - if (opControlDeps != null && allInputsAvailable) { - for (String cd : opControlDeps) { - VarId vcd = newVarId(cd, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - if (!nodeOutputs.containsKey(vcd)) { - allInputsAvailable = false; - break; - } - } - } - - List opOutputs = sameDiff.getOps().get(opName).getOutputsOfOp(); - if (opOutputs != null) { - - for (String s : opOutputs) { - //The input (for normal ops - not Enter/Exit/NextIteration) have the same frame and iteration number as the just executed var - //Exception 1 to this: constants. If variable is a constant, then it's always iteration 0 of the main frame (unless variable control dep exists) - //Exception 2 to this: placeholders. As above - SDVariable sdv = sameDiff.getVariable(s); - Variable variable = sameDiff.getVariables().get(s); - VarId outVarId; - if (sdv.isConstant() || sdv.isPlaceHolder()) { - //Constant - if(variable.getControlDeps() == null || var.getControlDeps().isEmpty()){ - //Standard case - do a lookup of placeholder/constant - outVarId = newVarId(s, OUTER_FRAME, 0, null); - } else { - //Edge case: control dependency x -> constant exists - //We should look up based on x's frame/iteration - outVarId = newVarId(s, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - } - } else { - //Normal (non-constant) - outVarId = newVarId(s, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - } - - //Mark that we need the specified input to calculate this output - addToExecInputs(isConstOrPhInput, executedVar, outVarId); - - //Check variable control dependencies, for each of the op outputs - if(allInputsAvailable && variable.getControlDeps() != null && !variable.getControlDeps().isEmpty()){ - //If one of the op outputs has a control dependency input, make sure this is available - // before executing the op - //For example, if z=add(x,y) and control dependency A->z exists, then don't execute op until A is available - for(String cd : variable.getControlDeps()){ - Variable cdVar = sameDiff.getVariables().get(cd); - VarId cdVarId = null; - if (cdVar.getVariable().isConstant() || cdVar.getVariable().isPlaceHolder()) { - //Constant - if(variable.getControlDeps() == null || var.getControlDeps().isEmpty()){ - //Standard case - do a lookup of placeholder/constant - cdVarId = newVarId(cd, OUTER_FRAME, 0, null); - } else { - //Edge case: control dependency x -> constant -> thisOutput exists - //We should look up based on x's frame/iteration - cdVarId = newVarId(cd, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - } - } else { - //Normal (non-constant) - cdVarId = newVarId(cd, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - } - allInputsAvailable &= nodeOutputs.containsKey(cdVarId); - if(!allInputsAvailable) - break; - } - } - } - - if (allInputsAvailable) { - //Op can be executed -> variables as output are available for exec - - for (String s : opOutputs) { - if (!subgraph.contains(s)) - continue; //Don't need this variable to calculate requested outputs - so don't mark as available for execution - VarId vid = newVarId(s, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - if(!availableForExecSet.contains(vid)) { - availableForExec.add(vid); - availableForExecSet.add(vid); - log.trace("Marked variable as available for execution: {} - output of op {} ({}) with op inputs {}", vid, opName, - fn.getClass().getSimpleName(), (inputsThisOp == null ? "" : Arrays.toString(inputsThisOp))); - } - } - } - } - - } - } - - //Also check variable control dependencies... if control dependency varX->varY exists and varY is a constant/placeholder/variable, - // then it's not going to be triggered by the op-based check above - if(controlDepForVars != null){ - for(String s : controlDepForVars){ - if (!subgraph.contains(s)) - continue; //Don't need this variable to calculate requested outputs - so don't mark as available for execution - - SDVariable depFor = sameDiff.getVariable(s); - if(depFor.getVariableType() != VariableType.ARRAY){ - //Control dependency executedVar -> s exists, where "s" is not the output of an op - //Even thought this is a constant, we'll inherit the frame and iteration from the control dependency - // otherwise, we lose this frame/iteration information for any downstream variables using the constant within a frame - VarId outVarId = newVarId(s, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - if(!availableForExecSet.contains(outVarId)) { - availableForExec.add(outVarId); - availableForExecSet.add(outVarId); - log.trace("Marked variable as available for execution: {} - control dependency {} -> {} exists", outVarId, executedVar.getVariable(), s); - } - } else { - //Another edge case: OpX has output varY (with no inputs), and control dependency executedVar -> varY exists - //We should check if OpX is now available for execution... - //Similarly, if we have OpX with inputs, but we're only waiting on a varible control dependency Z -> X - // then we might not get triggered as available for exec above either - String opName = sameDiff.getVariables().get(s).getOutputOfOp(); - if(opName != null){ - SameDiffOp op = sameDiff.getOps().get(opName); - boolean allInputsAvailable = true; - if(op.getInputsToOp() != null && !op.getInputsToOp().isEmpty()){ - List inputList = op.getInputsToOp(); - allInputsAvailable = allInputsAvailable(execStep, inputList.toArray(new String[inputList.size()]), executedVar); - } - - if(allInputsAvailable && op.getControlDeps() != null){ - for(String cd : op.getControlDeps()){ - VarId vid = newVarId(cd, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); //Note: is array type, therefore has same frame/iter as parent - allInputsAvailable &= nodeOutputs.containsKey(vid); - if(!allInputsAvailable) - break; - } - } - if(allInputsAvailable){ - for(String opOutput : op.getOutputsOfOp()){ - Variable v2 = sameDiff.getVariables().get(opOutput); - if(v2.getControlDeps() != null){ - for(String s2 : v2.getControlDeps()){ - VarId vid = newVarId(s2, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); //Note: is array type, therefore has same frame/iter as parent - allInputsAvailable &= nodeOutputs.containsKey(vid); - if(!allInputsAvailable) - break; - } - } - } - } - - if(allInputsAvailable){ - VarId outVarId = newVarId(s, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - if(!availableForExecSet.contains(outVarId)) { - availableForExec.add(outVarId); - log.trace("Marked variable as available for execution: {} - is output of op {} with no inputs (but has control dependencies)", outVarId, op.getName()); - } - } - } - } - } - } - - //Edge case: if control dependency varX->opY exists, and opY doesn't have any inputs, it also can't be triggeered - // (made available for execution) by any of the previous checks. For any ops that DO have inputs, they will - // be triggered already - if(controlDepForOps != null){ - for(String opName : controlDepForOps){ - SameDiffOp op = sameDiff.getOps().get(opName); - if(op.getInputsToOp() == null || op.getInputsToOp().isEmpty()){ - for(String out : op.getOutputsOfOp()){ - if (!subgraph.contains(out)) - continue; //Don't need this variable to calculate requested outputs - so don't mark as available for execution - - //TODO is it possible to have both variable and op control dependencies?? - VarId outVarId = newVarId(out, OUTER_FRAME, 0, null); - if(!availableForExecSet.contains(outVarId)) { - availableForExec.add(outVarId); - availableForExecSet.add(outVarId); - log.trace("Marked variable as available for execution: {} - op control dependency variable {} -> op {} exists", outVarId, executedVar.getVariable(), opName); - } - } - } - } - } - } - - protected boolean allInputsAvailable(int execStep, String[] inputsThisOp, VarId executedVar){ - for (String in : inputsThisOp) { - //The input (for normal ops - not Enter/Exit/NextIteration) have the same frame and iteration number as the just executed var - //Exception 1 to this: constants. If variable is a constant, then it's always iteration 0 of the main frame (unless variable control dep exists) - //Exception 2 to this: placeholders. As above - //TODO Add SameDiff.isConstant(String) method... or SDVariable.isConstant() (or both) - SDVariable sdv = sameDiff.getVariable(in); - Variable variable = sameDiff.getVariables().get(in); - VarId vid; - boolean nestedWhile = false; - if (sdv.isConstant() || sdv.isPlaceHolder()) { - //Constant - if(variable.getControlDeps() == null || variable.getControlDeps().isEmpty()){ - //Standard case - do a lookup of placeholder/constant - vid = newVarId(in, OUTER_FRAME, 0, null); - } else { - //Edge case: control dependency x -> constant exists - //We should look up based on x's frame/iteration - vid = newVarId(in, executedVar.getFrame(), executedVar.getIteration(), executedVar.getParentFrame()); - } - } else { - //Normal (non-constant) - //Edge case: "Enter" nodes always have iteration 0 by definition. In some TF graphs/loops, the enter node - // is used in multiple iterations (like, a constant in a loop condition) - not just the first iteration - int iter = executedVar.getIteration(); - FrameIter parentFrame = executedVar.getParentFrame(); - if(sdv.getVariableType() == VariableType.ARRAY && sameDiff.getOps().get(variable.getOutputOfOp()).getOp() instanceof Enter){ - iter = 0; - Enter e = (Enter)sameDiff.getOps().get(variable.getOutputOfOp()).getOp(); - if(e.isConstant()){ - //For enter nodes that are constants, we want iteration 0 in all frames in the heirarchy - //For example, const -> Enter(a) -> Enter(b) -> op; in this case, the input to Op (at any frame/iteration) should should - // be the constant value - which is recorded as (frame="a",iter=0,parent=(frame="b",iter=0)) - parentFrame = parentFrame.clone(); - FrameIter toZero = parentFrame; - while(toZero != null){ - toZero.setIteration(0); - toZero = toZero.getParentFrame(); - } - } - } - vid = newVarId(in, executedVar.getFrame(), iter, parentFrame); - } - if (!nodeOutputs.containsKey(vid)) { - return false; - } - } - return true; - } - /** * Preprocess the placeholder values, if required. * Mainly reserved for casting in the case of InferenceSession + * * @param placeholders Placeholders to preprocess. * @return Preprocessed placeholders */ - protected Map preprocessPlaceholders(Map placeholders){ + protected Map preprocessPlaceholders(Map placeholders, At at) { return placeholders; } + /** + * Post process the session output values, if required. + * Override if required in session subclasses + * + * @param output Output to be returned to the user + * @return Post processed output + */ + protected Map postProcessOutput(Map output) { + return output; + } + /** * Get the constant or variable output - for example, constant array or constant shape. * Note that both constants and variables (i.e., VariableType.CONSTANT and VariableType.VARIABLE) are the same @@ -848,9 +844,11 @@ public abstract class AbstractSession { * @param inputs The inputs to the op (excluding constants/placeholders) - for the specific frame + iteration * @param allIterInputs The inputs - those that are not iteration-specific (mainly Enter op vars, which might be used in all iterations but are only executed once on iter 0) * @param constAndPhInputs The constant and placeholder inputs - used for all frames/iterations + * @param allReqVariables All required variables requested for the current session execution (not just the current op outputs) * @return The parameterized op */ - public abstract O getAndParameterizeOp(String opName, FrameIter frameIter, Set inputs, Set allIterInputs, Set constAndPhInputs, Map placeholderValues); + public abstract O getAndParameterizeOp(String opName, FrameIter frameIter, Set inputs, Set allIterInputs, Set constAndPhInputs, + Map placeholderValues, Set allReqVariables); /** * Execute the op - calculate INDArrays, or shape info, etc @@ -858,88 +856,49 @@ public abstract class AbstractSession { * @param op Operation to exit. This should be parameterized (i.e., all inputs set) * @param outputFrameIter The frame and iteration of the outputs * @param inputs The specific input arrays for the op + * @param allReqVariables All required variables requested for the current session execution (not just the current op outputs) * @return The outputs of the op */ public abstract T[] getOutputs(O op, FrameIter outputFrameIter, Set inputs, Set allIterInputs, Set constAndPhInputs, - List listeners, At at, MultiDataSet batch); + List listeners, At at, MultiDataSet batch, Set allReqVariables); /** - * This method is used to record that the specified input is required for calculating the specified output. - * While the graph structure itself provides us with the (input vars) -> op -> (output vars) type structure, it - * doesn't tell us exactly which array copy (i.e., variable + frame + iteration) to use as which copy of the output - * variable (variable + frame + iteration). - *

- * This method is basically used to store information we need to parameterize ops for execution later - * - * @param isConstOrPh If true: inputVar is either a constant or a placeholder - * @param inputVar Input variable (i.e., the X in (X, ...) -> op -> (forVariable,...)) - * @param forVariable Output variable (i.e., the Y in (inputVar, ...) -> op -> (Y,...)) + * Get the VarId from the specified name. The VarId should be in one or the other of the collections, + * and only one VarId with that name should exist */ - protected void addToExecInputs(boolean isConstOrPh, VarId inputVar, VarId forVariable) { - if (!subgraph.contains(forVariable.getVariable())) - return; //Not needed to calculate requested outputs, so no need to record it's inputs + protected static VarId lookup(String name, Collection varIds, Collection varIds2, boolean exceptionOnNotFound) { + VarId vid = varIds == null ? null : lookup(name, varIds, false); + if (vid == null && varIds2 != null) + vid = lookup(name, varIds2, false); - if (isConstOrPh) { - //Mark that outVar needs to use placeholder/constant (same regardless of frame/iter) - if (!execConstInputs.containsKey(forVariable.getVariable())) - execConstInputs.put(forVariable.getVariable(), new HashSet()); - execConstInputs.get(forVariable.getVariable()).add(inputVar.getVariable()); - } else { - //Mark that outVar needs this specific executedVar (i.e., specific frame/iteration) - //However, in the case of enter nodes, they are available for ALL iterations (used in loop conditions, for example) - Variable v = sameDiff.getVariables().get(inputVar.getVariable()); - boolean isEnter = sameDiff.getVariableOutputOp(v.getVariable().getVarName()) instanceof Enter; - - if(isEnter){ - VarId iter0 = forVariable; - if(iter0.getIteration() != 0){ - iter0 = newVarId(iter0.getVariable(), iter0.getFrame(), 0, forVariable.getParentFrame()); - } - - Variable var = sameDiff.getVariables().get(inputVar.getVariable()); - Enter e = (Enter) sameDiff.getOps().get(var.getOutputOfOp()).getOp(); - if(e.isConstant()){ - //For enter nodes that are constants, we want iteration 0 in all frames in the heirarchy - //For example, const -> Enter(a) -> Enter(b) -> op; in this case, the input to Op (at any frame/iteration) should should - // be the constant value - which is recorded as (frame="a",iter=0,parent=(frame="b",iter=0)) - iter0.setParentFrame(iter0.getParentFrame().clone()); - FrameIter toZero = iter0.getParentFrame(); - while(toZero != null){ - toZero.setIteration(0); - toZero = toZero.getParentFrame(); - } - } - - if(!execInputsAllIter.containsKey(iter0)) - execInputsAllIter.put(iter0, new HashSet()); - execInputsAllIter.get(iter0).add(inputVar); - } else { - //Most variables - if (!execInputs.containsKey(forVariable)) - execInputs.put(forVariable, new HashSet()); - execInputs.get(forVariable).add(inputVar); - } + if (vid == null && exceptionOnNotFound) { + throw new RuntimeException("Could not find VarId for input \"" + name + "\""); } + return vid; } - - protected static VarId lookup(String name, Collection varIds, boolean exceptionOnNotFound){ - for(VarId vid : varIds){ - if(vid.getVariable().equals(name)){ + /** + * Get the VarId from the specified name. The VarId should be in the collection, + * and only one VarId with that name should exist + */ + protected static VarId lookup(String name, Collection varIds, boolean exceptionOnNotFound) { + for (VarId vid : varIds) { + if (vid.getVariable().equals(name)) { return vid; } } - if(exceptionOnNotFound) { + if (exceptionOnNotFound) { throw new RuntimeException("Could not find VarId to input " + name); } return null; } - /* - VarId: identifies a variable in a specific frame and frame iteration - Used for 2 places: - (a) to identify variables that are available for execution - (b) to store results + /** + * VarId: identifies the value of a variable in a specific frame and frame iteration
+ * Note that frames can be nested - which generally represents nested loop situations.
+ * Used for 2 places:
+ * (a) to identify variables that are available for execution
+ * (b) to store results
*/ @Data @AllArgsConstructor @@ -954,13 +913,17 @@ public abstract class AbstractSession { return "VarId(\"" + variable + "\",\"" + frame + "\"," + iteration + ",parent=" + parentFrame + ")"; } + /** + * @return FrameIter corresponding to the VarId + */ public FrameIter toFrameIter() { return new FrameIter(frame, iteration, parentFrame); } } - /* - FrameIter: Identifies frame + iteration. Used mainly for for exit nodes + /** + * FrameIter: Identifies a frame + iteration (but not a specific op or variable).
+ * Note that frames can be nested - which generally represents nested loop situations. */ @Data @AllArgsConstructor @@ -970,13 +933,82 @@ public abstract class AbstractSession { private FrameIter parentFrame; @Override - public String toString(){ + public String toString() { return "(\"" + frame + "\"," + iteration + (parentFrame == null ? "" : ",parent=" + parentFrame.toString()) + ")"; } @Override - public FrameIter clone(){ + public FrameIter clone() { return new FrameIter(frame, iteration, (parentFrame == null ? null : parentFrame.clone())); } + + public VarId toVarId(String name) { + return new VarId(name, frame, iteration, parentFrame); + } } + + /** + * ExecType: Execution type, as used in ExecStep
+ * OP: Operation execution
+ * VARIABLE: Variable "execution", mainly used to trigger ops that depend on the variable
+ * CONSTANT: As per variable
+ * PLACEHOLDER: As per variable
+ * SWITCH_L and SWITCH_R: This is a bit of a hack to account for the fact that only one of + * the switch branches (left or right) will ever be available; without this, once the switch op is executed, we'll + * (incorrectly) conclude that *both* branches can be executed
+ * EXEC_START: Start of execution
+ * CONTROL_DEP: Control dependency for op. Used for TF import, due to its odd "constant depends on op in a frame" behaviour + */ + protected enum ExecType {OP, VARIABLE, CONSTANT, PLACEHOLDER, SWITCH_L, SWITCH_R, EXEC_START, CONTROL_DEP} + + ; + + /** + * ExecStep represents a single execution step, for a single op (or variable/constant etc) at a specific frame/iteration + */ + @Getter + @EqualsAndHashCode + protected static class ExecStep { + protected final ExecType type; + protected final String name; + protected final FrameIter frameIter; + + protected ExecStep(@NonNull ExecType execType, @NonNull String name, FrameIter frameIter) { + this.type = execType; + this.name = name; + this.frameIter = frameIter; + } + + protected VarId toVarId() { + return new VarId(name, frameIter.getFrame(), frameIter.getIteration(), frameIter.getParentFrame()); + } + + @Override + public String toString() { + return "ExecStep(" + type + ",name=\"" + name + "\"," + frameIter + ")"; + } + } + + /** + * Used in getting the next ExecStep that matches the specified (current) frame/iteration + */ + @Data + @AllArgsConstructor + @NoArgsConstructor + protected class ExecStepPredicate implements Predicate { + + protected String currentFrame; + protected int currentFrameIter; + protected FrameIter currParentFrame; + + @Override + public boolean test(ExecStep execStep) { + return currentFrame.equals(execStep.getFrameIter().getFrame()) && + currentFrameIter == execStep.getFrameIter().getIteration() && + (currParentFrame == null && execStep.getFrameIter().getParentFrame() == null || + currParentFrame.equals(execStep.getFrameIter().getParentFrame())); + } + } + + ; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DataTypesSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DataTypesSession.java deleted file mode 100644 index 56a6a406e..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DataTypesSession.java +++ /dev/null @@ -1,107 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.autodiff.samediff.internal; - -import lombok.AllArgsConstructor; -import lombok.Data; -import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.listeners.At; -import org.nd4j.autodiff.listeners.Listener; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; -import org.nd4j.linalg.api.buffer.DataType; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Set; -import org.nd4j.linalg.dataset.api.MultiDataSet; - -/** - * Infer datatypes for all variables. - * Optionally update the datatypes of variables as we go - */ -public class DataTypesSession extends AbstractSession { - - protected boolean dynamicUpdate; - - /** - * @param sameDiff SameDiff instance - * @param dynamicUpdate If true: Dynamically update the datatypes as we go - */ - public DataTypesSession(SameDiff sameDiff, boolean dynamicUpdate) { - super(sameDiff); - this.dynamicUpdate = dynamicUpdate; - } - - @Override - public DataType getConstantOrVariable(String variableName) { - //Variables and constants should always have datatype available - DataType dt = sameDiff.getVariable(variableName).dataType(); - Preconditions.checkNotNull(dt, "No datatype available for variable %s", variableName); - return dt; - } - - @Override - public DataTypeCalc getAndParameterizeOp(String opName, FrameIter frameIter, Set inputs, Set allIterInputs, Set constAndPhInputs, Map placeholderValues) { - DifferentialFunction df = sameDiff.getOpById(opName); - List inputDataTypes = new ArrayList<>(); - for(SDVariable v : df.args()){ - DataType dt = v.dataType(); - if(dt != null){ - inputDataTypes.add(dt); - } else { - String s = v.getVarName(); - for(VarId vid : inputs){ - if(vid.getVariable().equals(s)){ - DataType dt2 = nodeOutputs.get(vid); - Preconditions.checkNotNull(dt2, "No datatype for %s", vid); - inputDataTypes.add(dt2); - } - } - } - } - return new DataTypeCalc(df, inputDataTypes); - } - - @Override - public DataType[] getOutputs(DataTypeCalc op, FrameIter outputFrameIter, Set inputs, Set allIterInputs, - Set constAndPhInputs, List listeners, At at, MultiDataSet batch) { - List outTypes = op.getFn().calculateOutputDataTypes(op.getInputTypes()); - - if(dynamicUpdate) { - SDVariable[] fnOutputs = op.getFn().outputVariables(); - for( int i=0; i inputTypes; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DependencyList.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DependencyList.java new file mode 100644 index 000000000..c718bf152 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DependencyList.java @@ -0,0 +1,20 @@ +package org.nd4j.autodiff.samediff.internal; + +import lombok.AllArgsConstructor; +import lombok.Data; +import org.nd4j.linalg.primitives.Pair; + +import java.util.List; + +/** + * A list of dependencies, used in {@link AbstractDependencyTracker} + * + * @author Alex Black + */ +@Data +@AllArgsConstructor +public class DependencyList { + private T dependencyFor; + private List dependencies; + private List> orDependencies; +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DependencyTracker.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DependencyTracker.java new file mode 100644 index 000000000..d172221ee --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DependencyTracker.java @@ -0,0 +1,38 @@ +package org.nd4j.autodiff.samediff.internal; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.primitives.Pair; + +import java.util.*; + +/** + * Dependenci tracker. See {@link AbstractDependencyTracker} for details + * + * @param For a dependency X -> Y, Y has type T + * @param For a dependency X -> Y, X has type D + */ +@Slf4j +public class DependencyTracker extends AbstractDependencyTracker { + + @Override + protected Map newTMap() { + return new HashMap<>(); + } + + @Override + protected Set newTSet() { + return new HashSet<>(); + } + + @Override + protected String toStringT(T t) { + return t.toString(); + } + + @Override + protected String toStringD(D d) { + return d.toString(); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/IdentityDependencyTracker.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/IdentityDependencyTracker.java new file mode 100644 index 000000000..5e7e46c80 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/IdentityDependencyTracker.java @@ -0,0 +1,44 @@ +package org.nd4j.autodiff.samediff.internal; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.primitives.Pair; + +import java.util.*; + +/** + * Object dependency tracker, using object identity (not object equality) for the Ys (of type T)
+ * See {@link AbstractDependencyTracker} for more details + * + * @author Alex Black + */ +@Slf4j +public class IdentityDependencyTracker extends AbstractDependencyTracker { + + @Override + protected Map newTMap() { + return new IdentityHashMap<>(); + } + + @Override + protected Set newTSet() { + return Collections.newSetFromMap(new IdentityHashMap()); + } + + @Override + protected String toStringT(T t) { + if(t instanceof INDArray){ + INDArray i = (INDArray)t; + return System.identityHashCode(t) + " - id=" + i.getId() + ", " + i.shapeInfoToString(); + } else { + return System.identityHashCode(t) + " - " + t.toString(); + } + } + + @Override + protected String toStringD(D d) { + return d.toString(); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java index e16dad580..55165b530 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java @@ -16,7 +16,7 @@ package org.nd4j.autodiff.samediff.internal; -import lombok.NonNull; +import lombok.*; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.listeners.At; @@ -24,15 +24,17 @@ import org.nd4j.autodiff.listeners.Listener; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.autodiff.samediff.internal.memory.ArrayCloseMemoryMgr; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.*; import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner; -import org.nd4j.linalg.api.ops.impl.controlflow.If; -import org.nd4j.linalg.api.ops.impl.controlflow.While; import org.nd4j.linalg.api.ops.impl.controlflow.compat.*; +import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; +import org.nd4j.linalg.api.ops.impl.shape.Concat; +import org.nd4j.linalg.api.ops.impl.shape.Stack; import org.nd4j.linalg.api.ops.impl.shape.tensorops.*; import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker; import org.nd4j.linalg.api.ops.impl.transforms.same.Identity; @@ -48,36 +50,92 @@ import org.nd4j.linalg.util.ArrayUtil; import java.util.*; /** - * InferenceSession: Performs inference (forward pass) on a SameDiff instance to get the outputs of the requested nodes. - * Dynamically (in AbstractSession) calculates the required subgraph to execute to get the required outputs. + * InferenceSession: Performs inference (forward pass) on a SameDiff instance to get the outputs of the requested nodes.
+ * Dynamically (in AbstractSession) calculates the required subgraph to execute to get the required outputs.
+ * Note that while AbstractSession handles the graph structure component, InferenceSession handles only op execution + * and memory management
+ *
+ * For INDArray memory management - i.e., tracking and releasing memory manually, as soon as possible, to + * minimize memory use - this is implemented using a {@link SessionMemMgr} instance (for allocations/deallocations) and + * also {@link IdentityDependencyTracker} to track where arrays are actually used. The IdentityDependencyTracker tells + * us when the array is no longer needed (i.e., has been "fully consumed" by all ops depending on it) accounting for the + * fact that some operations, such as identity, enter, exit, etc, are "zero copy" for performance reasons. * * @author Alex Black */ @Slf4j -public class InferenceSession extends AbstractSession { +public class InferenceSession extends AbstractSession { private static final String SCOPE_PANIC_MSG = "If required, arrays in workspaces can be detached using INDArray.detach() before being passed to the SameDiff instance.\n" + "Alternatively, arrays defined in a workspace must be replaced after the workspace has been closed."; + protected static final String KERAS_TRAIN_TEST = "keras_learning_phase"; + + @Getter + @Setter + private SessionMemMgr mmgr; //Used for allocating and deallocating memory + /** + * Array use tracker: What needs to happen before the array can be closed/released? + * As the name suggests, the INDArrays are tracked using qbject identity, not equality + */ + @Getter + @Setter + private IdentityDependencyTracker arrayUseTracker = new IdentityDependencyTracker<>(); + + public InferenceSession(@NonNull SameDiff sameDiff) { super(sameDiff); + + mmgr = new ArrayCloseMemoryMgr(); //TODO replace this with new (planned) array reuse memory manager } @Override - protected Map preprocessPlaceholders(Map placeholders){ - //Handle casting of the input array automatically. - //The idea here is to avoid unexpected errors if the user (for example) tries to perform inference with a double - // array for a float placeholder - if(placeholders == null || placeholders.isEmpty()){ + protected Map preprocessPlaceholders(Map placeholders, At at) { + arrayUseTracker.clear(); + + //We'll also use this method as a "pre execution" hook-in, to mark variables as something we should never deallocate + //This occurs by never marking these "ConstantDep" and "VariableDep" instances as satisfied, so there's always + // an unsatisfied dependency for them in the array use tracker + //TODO we shouldn't be clearing this on every single iteration, in 99.5% of cases variables will be same as last iteration... + for (SDVariable v : sameDiff.variables()) { + if (v.getVariableType() == VariableType.CONSTANT) { + arrayUseTracker.addDependency(v.getArr(), new ConstantDep(v.name())); + } else if (v.getVariableType() == VariableType.VARIABLE) { + arrayUseTracker.addDependency(v.getArr(), new VariableDep(v.name())); + } + } + + //Workaround for some TF/Keras based models that require explicit train/test as a placeholder + boolean kerasWorkaround = false; + List phs = sameDiff.inputs(); + if (phs != null && !phs.isEmpty()) { + for (String s : phs) { + if (s.endsWith(KERAS_TRAIN_TEST) && !placeholders.containsKey(s)) { + // The behaviour of some Keras layers (like GRU) differs depending on whether the model is training. + // We provide this value directly, unless the user has provided this manually + INDArray scalar = mmgr.allocate(false, DataType.BOOL).assign(at.operation().isTrainingPhase()); + placeholders = new HashMap<>(placeholders); //Array might be singleton, or otherwise unmodifiable + placeholders.put(s, scalar); + kerasWorkaround = true; + } + } + } + + + if (placeholders == null || placeholders.isEmpty()) { return placeholders; } - Map out = new HashMap<>(); - for(Map.Entry e : placeholders.entrySet()){ + //Handle casting of the input array automatically. + //The idea here is to avoid unexpected errors if the user (for example) tries to perform inference with a double + // array for a float placeholder + //TODO eventually we might have ops that support multiple input types, and hence won't need this casting + Map out = new HashMap<>(); + for (Map.Entry e : placeholders.entrySet()) { Preconditions.checkState(sameDiff.hasVariable(e.getKey()), "Invalid placeholder passed for execution: " + "No variable/placeholder with name %s exists", e.getKey()); INDArray arr = e.getValue(); //First: check workspaces - if(arr.isAttached()){ + if (arr.isAttached()) { MemoryWorkspace ws = arr.data() == null ? null : arr.data().getParentWorkspace(); if (ws != null && ws.getWorkspaceType() != MemoryWorkspace.Type.CIRCULAR) { if (!ws.isScopeActive()) { @@ -96,89 +154,234 @@ public class InferenceSession extends AbstractSession opInputs, Set allIterInputs, - Set constAndPhInputs, List listeners, At at, MultiDataSet batch) { - if(listeners != null && listeners.size() > 0){ - SameDiffOp sdOp = sameDiff.getOps().get(op.getOwnName()); - for(Listener l : listeners){ - if(l.isActive(at.operation())) + protected Map postProcessOutput(Map output) { + + //For any queued (not yet processed) ops - mark them as satisfied, so we can deallocate any arrays + // that are waiting on them + if (dt.hasNewAllSatisfied()) { + List execSteps = dt.getNewAllSatisfiedList(); + for (ExecStep es : execSteps) { + if (es.getType() == ExecType.OP) { + OpDep od = new OpDep(es.getName(), es.getFrameIter().getFrame(), es.getFrameIter().getIteration(), es.getFrameIter().getParentFrame()); + arrayUseTracker.markSatisfied(od, true); + } + } + } + + //Also mark "end of execution" for array dependency tracker. Mainly used for TensorArray arrays at present. + //TODO Optimize for reduced memory for some TensorArray operations - i.e., close/deallocate earlier + arrayUseTracker.markSatisfied(new ExecDoneDep(), true); + if (arrayUseTracker.hasNewAllSatisfied()) { + List l = arrayUseTracker.getNewAllSatisfiedList(); + for (INDArray arr : l) { + mmgr.release(arr); + } + } + + return output; + } + + @Override + public INDArray[] getOutputs(SameDiffOp op, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, + Set constAndPhInputs, List listeners, At at, MultiDataSet batch, Set allReqVariables) { + if (listeners != null && listeners.size() > 0) { + SameDiffOp sdOp = sameDiff.getOps().get(op.getOp().getOwnName()); + for (Listener l : listeners) { + if (l.isActive(at.operation())) l.preOpExecution(sameDiff, at, sdOp); } } - INDArray[] out = getOutputsHelper(op, outputFrameIter, opInputs, allIterInputs, constAndPhInputs); - if(listeners != null && listeners.size() > 0){ - SameDiffOp sdOp = sameDiff.getOps().get(op.getOwnName()); + INDArray[] out = doExec(op.getOp(), outputFrameIter, opInputs, allIterInputs, constAndPhInputs); + op.getOp().clearArrays(); - Map namedOutsBuilder = new HashMap<>(); + if (log.isTraceEnabled()) { + StringBuilder sb = new StringBuilder(); + sb.append(op.getName()).append(" - ").append(outputFrameIter).append(" outputs: "); + List opOutNames = op.getOutputsOfOp(); + for (int i = 0; i < out.length; i++) { + if (i > 0) + sb.append(", "); + sb.append("(").append(i).append(" - ").append(opOutNames.get(i)).append(" = ").append( + out[i] == null ? null : out[i].getId()).append(")"); + } + log.trace(sb.toString()); + } - for(int i = 0 ; i < out.length ; i++) - namedOutsBuilder.put(sdOp.outputsOfOp.get(i), out[i]); + //Call listeners, before we (maybe) deallocate input arrays + if (listeners != null && listeners.size() > 0) { + Map namedOuts = null; - Map namedOuts = Collections.unmodifiableMap(namedOutsBuilder); + for (Listener l : listeners) { + if (l.isActive(at.operation())) { + //Lazily create map, only if required + if (namedOuts == null) { + Map namedOutsBuilder = new HashMap<>(); - for(Listener l : listeners){ - if(l.isActive(at.operation())) { - l.opExecution(sameDiff, at, batch, sdOp, out); + for (int i = 0; i < out.length; i++) + namedOutsBuilder.put(op.outputsOfOp.get(i), out[i]); + namedOuts = Collections.unmodifiableMap(namedOutsBuilder); + } - for(String varName : namedOuts.keySet()){ - l.activationAvailable(sameDiff, at, batch, sdOp, varName, namedOuts.get(varName)); + + l.opExecution(sameDiff, at, batch, op, out); + + for (String varName : namedOuts.keySet()) { + l.activationAvailable(sameDiff, at, batch, op, varName, namedOuts.get(varName)); } } } } + + + //Record array uses for memory management/deallocation + SameDiffOp o = sameDiff.getOps().get(op.getName()); + List outVarNames = o.getOutputsOfOp(); + for (int i = 0; i < out.length; i++) { + if (out[i] == null && o.getOp() instanceof Switch) + continue; //Switch case: we only ever get one of 2 outputs, other is null (branch not executed) + + String name = outVarNames.get(i); + Variable v = sameDiff.getVariables().get(name); + List inputsForOps = v.getInputsForOp(); + if (inputsForOps != null) { + for (String opName : inputsForOps) { + //Only add dependencies if we actually need the op this feeds into, otherwise the dependency + // will will never be marked as satisfied + if (!subgraphOps.contains(opName)) + continue; + + SameDiffOp forOp = sameDiff.getOps().get(opName); + + //TODO do switch or merge need special handling also? + if (forOp.getOp() instanceof Enter) { + Enter e = (Enter) forOp.getOp(); + if (e.isConstant()) { + /* + Contant enter case: Need to keep this array around for the entire duration of the frame, including + any nested frames, and all iterations. + Unfortunately, we don't know exactly when we're done with a frame for good + This isn't a great solution, but other possibilities (frame close, trying to detect all exit ops, + detecting return to parent frame, etc all fail in certain circumstances, such as due to control dependencies + on variables). + */ + Dep d = new ExecDoneDep(); + arrayUseTracker.addDependency(out[i], d); + } else { + Dep d = new OpDep(opName, e.getFrameName(), 0, outputFrameIter); + arrayUseTracker.addDependency(out[i], d); //Op defined by "d" needs to be executed before specified array can be closed + } + } else if (forOp.getOp() instanceof NextIteration) { + //The array is needed by the NEXT iteration op, not the current one + Dep d = new OpDep(opName, outputFrameIter.getFrame(), outputFrameIter.getIteration() + 1, outputFrameIter.getParentFrame()); + arrayUseTracker.addDependency(out[i], d); + } else if (forOp.getOp() instanceof Exit) { + //The array is needed at the EXIT frame (i.e., parent frame), not the inner/just executed one + FrameIter fi = outputFrameIter.getParentFrame(); + Dep d = new OpDep(opName, fi.getFrame(), fi.getIteration(), fi.getParentFrame()); + arrayUseTracker.addDependency(out[i], d); //Op defined by "d" needs to be executed before specified array can be closed + } else { + //All other ops... + Dep d = new OpDep(opName, outputFrameIter.getFrame(), outputFrameIter.getIteration(), outputFrameIter.getParentFrame()); + arrayUseTracker.addDependency(out[i], d); //Op defined by "d" needs to be executed before specified array can be closed + } + } + } + + if (OUTER_FRAME.equals(outputFrameIter.getFrame()) && allReqVariables.contains(name)) { + //This variable is an output, record that in the array use tracker, so we don't deallocate it + arrayUseTracker.addDependency(out[i], new ReqOutputDep(name)); + } else if ((inputsForOps == null || inputsForOps.isEmpty()) && !arrayUseTracker.hasDependency(out[i])) { + //This particular array is not actually needed anywhere, so we can deallocate in immediately + //Possibly only a control dependency, or only one of the outputs of a multi-output op is used + if (log.isTraceEnabled()) { + log.trace("Found array id {} (output of {}) not required anywhere, deallocating", out[i].getId(), o.getName()); + } + mmgr.release(out[i]); + } + } + + //Mark current op dependency as satisfied... + Dep d = new OpDep(op.getName(), outputFrameIter.getFrame(), outputFrameIter.getIteration(), outputFrameIter.getParentFrame()); + arrayUseTracker.markSatisfied(d, true); + + + //Close any no longer required arrays + if (arrayUseTracker.hasNewAllSatisfied()) { + List canClose = arrayUseTracker.getNewAllSatisfiedList(); + for (INDArray arr : canClose) { + if (log.isTraceEnabled()) { + log.trace("Closing array... id={}, {}", arr.getId(), arr.shapeInfoToString()); + } + mmgr.release(arr); + } + } + return out; } - public INDArray[] getOutputsHelper(DifferentialFunction op, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, - Set constAndPhInputs){ + public INDArray[] doExec(DifferentialFunction op, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, + Set constAndPhInputs) { int totalInputs = (opInputs == null ? 0 : opInputs.size()) + (constAndPhInputs == null ? 0 : constAndPhInputs.size()) + (allIterInputs == null ? 0 : allIterInputs.size()); boolean constPhInput = (opInputs == null || opInputs.size() == 0) && (allIterInputs == null || allIterInputs.size() == 0); - if(op instanceof Identity ) { + if (op instanceof Identity) { Identity i = (Identity) op; String[] argNames = i.argNames(); - Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in identity op, got %s", argNames); - VarId vid = newVarId(argNames[0], outputFrameIter); - return new INDArray[]{nodeOutputs.get(vid)}; + Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in identity op, got %s", (Object) argNames); + VarId vid = outputFrameIter.toVarId(argNames[0]); - } else if(op instanceof Switch) { + INDArray orig = nodeOutputs.get(vid); + return new INDArray[]{orig}; + } else if (op instanceof Switch) { Switch s = (Switch) op; String[] argNames = s.argNames(); //Order: input, boolean array - VarId vidPredicate = newVarId(argNames[1], outputFrameIter); + VarId vidPredicate = outputFrameIter.toVarId(argNames[1]); INDArray predicate = this.nodeOutputs.get(vidPredicate); Preconditions.checkState(predicate.isScalar() && predicate.dataType() == DataType.BOOL, "Expected boolean predicate: got %ndSInfo", predicate); - VarId vid = newVarId(argNames[0], outputFrameIter); + VarId vid = outputFrameIter.toVarId(argNames[0]); if (predicate.getDouble(0) == 0.0) { return new INDArray[]{this.nodeOutputs.get(vid), null}; } else { return new INDArray[]{null, this.nodeOutputs.get(vid)}; } - } else if(op instanceof Enter) { + } else if (op instanceof Enter) { //Enter op: forwards input to specified execution frame - Enter e = (Enter)op; + Enter e = (Enter) op; String[] input = e.argNames(); - Preconditions.checkState(input.length == 1, "Expected only 1 arg name for enter op: got %s", input); + Preconditions.checkState(input.length == 1, "Expected only 1 arg name for enter op: got %s", (Object) input); Preconditions.checkState(totalInputs == 1, "Expected exactly 1 op input for Enter op \"%s\", got %s+%s", e.getOwnName(), opInputs, constAndPhInputs); VarId inputVarId; - if(constPhInput) { + if (constPhInput) { //Constant or placeholder inputVarId = new VarId(constAndPhInputs.iterator().next(), OUTER_FRAME, 0, null); - } else if(allIterInputs != null && allIterInputs.size() > 0){ + } else if (allIterInputs != null && allIterInputs.size() > 0) { inputVarId = allIterInputs.iterator().next(); } else { inputVarId = opInputs.iterator().next(); @@ -187,332 +390,356 @@ public class InferenceSession extends AbstractSession 0){ + } else if (allIterInputs != null && allIterInputs.size() > 0) { inputVarId = allIterInputs.iterator().next(); } else { inputVarId = opInputs.iterator().next(); } INDArray exitInput = this.nodeOutputs.get(inputVarId); return new INDArray[]{exitInput}; - } else if(op instanceof NextIteration){ + } else if (op instanceof NextIteration) { //NextIteration op: forwards its single input to the output of the current frame, but increments the iteration number Preconditions.checkState(totalInputs == 1, "Expected exactly 1 op input for NextIteration: got %s+%s", opInputs, constAndPhInputs); VarId in = (allIterInputs != null && !allIterInputs.isEmpty() ? allIterInputs.iterator().next() : opInputs.iterator().next()); Preconditions.checkState(outputFrameIter.getFrame().equals(in.getFrame()), "Expected same frame for NextIteration input vs. output:" + " got input %s, output %s", in, outputFrameIter); - Preconditions.checkState(outputFrameIter.getIteration() == in.getIteration()+1, "Expected output iteration for NextIteration output to" + + Preconditions.checkState(outputFrameIter.getIteration() == in.getIteration() + 1, "Expected output iteration for NextIteration output to" + " be 1 larger than the input iteration. Input: %s, output %s", in, outputFrameIter); INDArray inArr = this.nodeOutputs.get(in); + if (inArr == null) { + Preconditions.throwStateEx("Could not find array for NextIteration operation %s with output %s (frame=%s, iteration=%s)", + op.getOwnName(), sameDiff.getOps().get(op.getOwnName()).getOutputsOfOp().get(0), outputFrameIter.getFrame(), outputFrameIter.getIteration()); + } return new INDArray[]{inArr}; - } else if(op instanceof If) { - If i = (If) op; - String[] argNames = i.argNames(); //Order should be: [boolean], true, false - - - throw new UnsupportedOperationException("Execution not yet implemented for: " + op.getClass().getName()); - } else if(op instanceof Merge) { - //Merge avairable for forward pass when any of its inputs are available. When multiple are available, behaviour + } else if (op instanceof Merge) { + //Merge available for forward pass when any of its inputs are available. When multiple are available, behaviour // is undefined Merge m = (Merge) op; String[] in = sameDiff.getInputsForOp(op); for (String s : in) { - VarId vid = newVarId(s, outputFrameIter); + VarId vid = outputFrameIter.toVarId(s); if (nodeOutputs.containsKey(vid)) { log.trace("Returning input \"{}\" for merge node \"{}\"", m.getOwnName(), s); - return new INDArray[]{nodeOutputs.get(vid)}; + INDArray arr = nodeOutputs.get(vid); + Preconditions.checkState(arr != null, "Could not find output array for %s", vid); + return new INDArray[]{arr}; } } throw new IllegalStateException("Merge node " + m.getOwnName() + " has no available inputs (all inputs: " + Arrays.toString(in) + ") - should not be executed at this point"); - } else if(op instanceof LoopCond) { + } else if (op instanceof LoopCond) { //LoopCond just forwards scalar boolean to output LoopCond lc = (LoopCond) op; String[] argNames = lc.argNames(); - Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in LoopCond op, got %s", argNames); - VarId vid = newVarId(argNames[0], outputFrameIter); + Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in LoopCond op, got %s", (Object) argNames); + VarId vid = outputFrameIter.toVarId(argNames[0]); INDArray arr = nodeOutputs.get(vid); Preconditions.checkNotNull(arr, "Input to LoopCond op must not be null"); Preconditions.checkState(arr.isScalar() && arr.dataType() == DataType.BOOL, "LoopCond input must be a scalar boolean, got %ndShape"); return new INDArray[]{arr}; - } else if(op instanceof BaseTensorOp) { + } else if (op instanceof BaseTensorOp) { //TensorOps - special cases... - if (op instanceof TensorArray) { - //Create a TensorArray - VarId vid = newVarId(op.outputVariable().getVarName(), outputFrameIter); - Preconditions.checkState(!tensorArrays.containsKey(vid), "TensorArray already exists for %s when executing TensorArrayV3", vid); - tensorArrays.put(vid, new ArrayList()); - - // Note that TensorArray has 2 outputs - a 'dummy' SDVariable that represents it, and a second output (return a scalar 0.0) - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - //TODO Proper workspace support will be added to SameDiff later - return new INDArray[]{Nd4j.scalar(true), Nd4j.scalar(0.0f)}; - } - } else if (op instanceof TensorArrayRead) { - //Do lookup and return - //Input 0 is the TensorArray (or dummy variable that represents it). Sometimes (for import) this can be like (TensorArray -> Enter -> TensorArrayRead) - //Input 1 is the index - SDVariable idxSDV = op.arg(1); - INDArray idxArr = getArray(idxSDV, opInputs, allIterInputs); - Preconditions.checkState(idxArr.isScalar(), "TensorArrayRead input argument 1 should be scalar - has shape %ndShape", idxArr); - int i = idxArr.getInt(0); - - SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array - - //Work out the frame/iteration: - VarId v = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); - if(v == null && allIterInputs != null){ - v = lookup(inTensorArray.getVarName(), allIterInputs, false); - } - - Preconditions.checkState(v != null, "Could not find input %s", inTensorArray.getVarName()); - - while(sameDiff.getVariableOutputOp(inTensorArray.getVarName()) instanceof Enter){ - //Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayRead - //TODO also TensorArrayWrite, scatter, etc?? - inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg(); - v = newVarId(inTensorArray.getVarName(), v.getParentFrame()); - } - - List list = getTensorArrays().get(v); - Preconditions.checkState(list != null, "Could not find TensorList for %s", v); - Preconditions.checkState(list.size() > i, "Cannot get index %s from TensorList of size %s (array not present?) - VarId=%s", i, list.size(), v); - - INDArray out = list.get(i); - return new INDArray[]{out}; - } else if (op instanceof TensorArrayWrite) { - //TensorArrayWrite - also has a scalar 0.0 that it returns... - - SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array - //Work out the varid (frame/iteration) of the tensor array: - VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); - if(tArr == null && allIterInputs != null){ - tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); - } - - Preconditions.checkState(tArr != null, "Could not find input %s", inTensorArray.getVarName()); - - while(sameDiff.getVariableOutputOp(inTensorArray.getVarName()) instanceof Enter){ - //Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayWrite - //TODO also TensorArrayScatter, etc?? - inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg(); - tArr = newVarId(inTensorArray.getVarName(), tArr.getParentFrame()); - } - - //Input 0 is the TensorArray (or dummy variable that represents it) - but sometimes Enter, in TensorArray -> Enter -> TensorARrayRead - //Input 1 is the index - //Input 2 is the value to write - - String idxName = op.arg(1).getVarName(); - SDVariable idxSDV = sameDiff.getVariable(idxName); - INDArray idxArr = getArray(idxSDV, opInputs, allIterInputs); - Preconditions.checkState(idxArr.isScalar(), "Index variable ID for TensorArrayWrite should be a scalar, got %ndShape", idxArr); - int idx = idxArr.getInt(0); - - String inName = op.arg(2).getVarName(); - SDVariable inSDV = sameDiff.getVariable(inName); - INDArray arr = getArray(inSDV, opInputs, allIterInputs); - Preconditions.checkState(arr != null, "Could not find array for %s", inName); - - Preconditions.checkState(tensorArrays.containsKey(tArr), "Tensor array does not exist for %s", tArr); - //TODO is this always safe to insert by index for all execution orders? - List l = tensorArrays.get(tArr); //.set(idx, arr); - while (l.size() <= idx) { - //Can't use set(int, E) if index >= size - l.add(null); - } - l.set(idx, arr); - - //Return dummy array - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - //TODO Proper workspace support will be added to SameDiff later - return new INDArray[]{Nd4j.scalar(0.0f)}; - } - } else if (op instanceof TensorArraySize) { - //Index 0 is the TensorArray (or dummy variable that represents it) - SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array - //Work out the varid (frame/iteration) of the tensor array: - VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); - if(tArr == null && allIterInputs != null){ - tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); - } - List l = tensorArrays.get(tArr); - Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - //TODO Proper workspace support will be added to SameDiff later - return new INDArray[]{Nd4j.scalar(DataType.INT, l.size())}; - } - } else if (op instanceof TensorArrayConcat) { - SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array - VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); - if(tArr == null && allIterInputs != null){ - tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); - } - List l = tensorArrays.get(tArr); - //TODO - empty checks. But is size 0 OK? - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - //TODO Proper workspace support will be added to SameDiff later - INDArray concat = Nd4j.concat(0, l.toArray(new INDArray[l.size()])); - return new INDArray[]{concat}; - } - } else if (op instanceof TensorArrayGather) { - //Input 0: the TensorArray - //Input 1: the indices (1d integer vector) - - SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array - VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); - if(tArr == null && allIterInputs != null){ - tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); - } - List l = tensorArrays.get(tArr); - Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); - - String indicesName = op.arg(1).getVarName(); - SDVariable indicesSDV = sameDiff.getVariable(indicesName); - INDArray idxArr = getArray(indicesSDV, opInputs, allIterInputs); - Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayGather should be a vector, got %ndShape for %s", idxArr, indicesName); - Preconditions.checkState(idxArr.dataType().isIntType(), "Indices variable for TensorArrayGather should be an integer type, got %s for array %s", idxArr.dataType(), indicesName); - - int[] idxArrInt = idxArr.toIntVector(); - - //Edge case: -1 means "all" - ArrayList newList = new ArrayList<>(); - if(idxArrInt.length == 1 && idxArrInt[0] == -1){ - newList.addAll(l); - } else { - for (int id : idxArrInt) { - Preconditions.checkState(id >=0,"Index for TensorArrayGather must be >= 0, got %s", id); - newList.add(l.get(id)); - } - } - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - //TODO Proper workspace support will be added to SameDiff later - INDArray out = Nd4j.pile(newList); - return new INDArray[]{out}; - } - } else if (op instanceof TensorArrayScatter) { - //Scatter values from a rank (N+1)d tensor into specific indices of the TensorArray - //Input 0: the TensorArray - //Input 1: the indices (1d integer vector) - //Input 2: The values to scatter - - SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array - TensorArray ta = (TensorArray) sameDiff.getVariableOutputOp(inTensorArray.getVarName()); - VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); - if(tArr == null && allIterInputs != null){ - tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); - } - List l = tensorArrays.get(tArr); - Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); - - String indicesName = op.arg(1).getVarName(); - SDVariable indicesSDV = sameDiff.getVariable(indicesName); - INDArray idxArr = getArray(indicesSDV, opInputs, allIterInputs); - Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayScatter should be a vector, got %ndShape for %s", idxArr, indicesName); - Preconditions.checkState(idxArr.dataType().isIntType(), "Indices variable for TensorArrayScatter should be an integer type, got %s for array %s", idxArr.dataType(), indicesName); - int[] idxs = idxArr.toIntVector(); - - String valuesName = op.arg(2).getVarName(); - SDVariable valuesSDV = sameDiff.getVariable(valuesName); - INDArray valuesArr = getArray(valuesSDV, opInputs, allIterInputs); - - while (l.size() <= idxs.length) { //Can't use set(int, E) if index >= size - l.add(null); - } - - //Edge case: idxs being [-1] means "all sub arrays" (i.e., "unstack" case) - if(idxs.length == 1 && idxs[0] == -1){ - idxs = ArrayUtil.range(0, (int)valuesArr.size(0)); - } - - INDArrayIndex[] idx = ArrayUtil.nTimes(valuesArr.rank(), NDArrayIndex.all(), INDArrayIndex.class); - for (int i = 0; i < idxs.length; i++) { - idx[0] = NDArrayIndex.point(i); - INDArray get = valuesArr.get(idx).dup(); - int outIdx = idxs[i]; - if(valuesArr.rank() == 2 && get.rank() == 2){ - //Workaround for: https://github.com/deeplearning4j/deeplearning4j/issues/7092 - get = get.reshape(get.length()); - } - if(valuesArr.rank() == 1 && get.rank() > 0){ - get = get.reshape(new long[0]); - } - l.set(outIdx, get); - } - - //Return dummy array - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - //TODO Proper workspace support will be added to SameDiff later - return new INDArray[]{Nd4j.scalar(0.0f)}; - } - } else if (op instanceof TensorArraySplit) { - //Split values from a rank (N+1)d tensor into sequential indices of the TensorArray - //For example, orig=[8,2] sizearray with split (4,4) means TensorArray[0] = orig[0:4,:] and TensorArray[1] = orig[4:8,:] - //Input 0: the TensorArray - //Input 1: The values to split - //Input 2: the size of each split (1d integer vector) - - SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array - VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); - if(tArr == null && allIterInputs != null){ - tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); - } - List l = tensorArrays.get(tArr); - Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); - - String splitName = op.arg(1).getVarName(); - INDArray splitArr = getArray(sameDiff.getVariable(splitName), opInputs, allIterInputs); - - - String sizeName = op.arg(2).getVarName(); - SDVariable sizeSDV = sameDiff.getVariable(sizeName); - INDArray sizeArr = getArray(sizeSDV, opInputs, allIterInputs); - Preconditions.checkState(sizeArr.isVector(), "Indices variable for TensorArraySplit should be a vector, got %ndShape for %s", sizeArr, sizeName); - Preconditions.checkState(sizeArr.dataType().isIntType(), "Indices variable for TensorArraySplit should be an integer type, got %s for array %s", sizeArr.dataType(), sizeName); - int[] sizes = sizeArr.toIntVector(); - - while (l.size() <= sizes.length) { //Can't use set(int, E) if index >= size - l.add(null); - } - - INDArrayIndex[] idx = ArrayUtil.nTimes(splitArr.rank(), NDArrayIndex.all(), INDArrayIndex.class); - int soFar = 0; - for (int i = 0; i < sizes.length; i++) { - idx[0] = NDArrayIndex.interval(soFar, soFar + sizes[i]); - INDArray sub = splitArr.get(idx).dup(); - l.set(i, sub); - soFar += sizes[i]; - } - //Return dummy array - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - //TODO Proper workspace support will be added to SameDiff later - return new INDArray[]{Nd4j.scalar(0.0f)}; - } - } else { - throw new IllegalStateException("Execution support not yet implemented for: " + op.getClass().getName()); - } - } else if(op instanceof GradientBackwardsMarker){ - return new INDArray[]{Nd4j.scalar(1.0f)}; - } else if(op instanceof CustomOp){ - CustomOp c = (CustomOp)op; - Nd4j.getExecutioner().exec(c); + return getOutputsHelperTensorArrayOps(op, outputFrameIter, opInputs, allIterInputs); + } else if (op instanceof GradientBackwardsMarker) { + INDArray out = mmgr.allocate(false, DataType.FLOAT).assign(1.0f); + return new INDArray[]{out}; + } else if (op instanceof ExternalErrorsFunction) { + ExternalErrorsFunction fn = (ExternalErrorsFunction) op; + String n = fn.getGradPlaceholderName(); + INDArray arr = nodeOutputs.get(new VarId(n, OUTER_FRAME, 0, null)); + Preconditions.checkState(arr != null, "Could not find external errors placeholder array: %s", arr); + INDArray out = mmgr.allocate(false, arr.dataType(), arr.shape()); + out.assign(arr); + return new INDArray[]{out}; + } else if (op instanceof CustomOp) { + CustomOp c = (CustomOp) op; + Nd4j.exec(c); return c.outputArguments(); - } else if(op instanceof Op) { + } else if (op instanceof Op) { Op o = (Op) op; - Nd4j.getExecutioner().exec(o); + Nd4j.exec(o); return new INDArray[]{o.z()}; } else { throw new UnsupportedOperationException("Execution not yet implemented for: " + op.getClass().getName()); } } + /** + * Forward pass for TensorArray ops + */ + public INDArray[] getOutputsHelperTensorArrayOps(DifferentialFunction op, FrameIter outputFrameIter, Set opInputs, Set allIterInputs) { + /* + TODO: TensorArray memory management note: For now, we'll close any INDArrays stored in the TensorArray at the end of + graph execution. This uses more memory than necessary for an earlier close strategy, but simplifies memory management. + This should be revisited and optimized later + */ + + if (op instanceof TensorArray) { + //Create a TensorArray + VarId vid = outputFrameIter.toVarId(op.outputVariable().name()); + Preconditions.checkState(!tensorArrays.containsKey(vid), "TensorArray already exists for %s when executing TensorArrayV3", vid); + tensorArrays.put(vid, new ArrayList()); + + // Note that TensorArray has 2 outputs - a 'dummy' SDVariable that represents it, and a second output (return a scalar 0.0) + INDArray dummy = mmgr.allocate(false, DataType.BOOL).assign(true); + INDArray scalar = mmgr.allocate(false, DataType.FLOAT).assign(0.0); + return new INDArray[]{dummy, scalar}; + } else if (op instanceof TensorArrayRead) { + //Do lookup and return + //Input 0 is the TensorArray (or dummy variable that represents it). Sometimes (for import) this can be like (TensorArray -> Enter -> TensorArrayRead) + //Input 1 is the index + SDVariable idxSDV = op.arg(1); + INDArray idxArr = getArray(idxSDV, opInputs, allIterInputs); + Preconditions.checkState(idxArr.isScalar(), "TensorArrayRead input argument 1 should be scalar - has shape %ndShape", idxArr); + int i = idxArr.getInt(0); + + SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array + + //Work out the frame/iteration: + VarId v = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false)); + if (v == null && allIterInputs != null) { + v = lookup(inTensorArray.name(), allIterInputs, false); + } + + Preconditions.checkState(v != null, "Could not find input %s", inTensorArray.name()); + + while (sameDiff.getVariableOutputOp(inTensorArray.name()) instanceof Enter) { + //Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayRead + //TODO also TensorArrayWrite, scatter, etc?? + inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.name()).arg(); + v = v.getParentFrame().toVarId(inTensorArray.name()); + } + + List list = getTensorArrays().get(v); + Preconditions.checkState(list != null, "Could not find TensorList for %s", v); + Preconditions.checkState(list.size() > i, "Cannot get index %s from TensorList of size %s (array not present?) - VarId=%s", i, list.size(), v); + + INDArray out = list.get(i); + return new INDArray[]{out}; + } else if (op instanceof TensorArrayWrite) { + //TensorArrayWrite - also has a scalar 0.0 that it returns... + SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array + //Work out the varid (frame/iteration) of the tensor array: + VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false)); + if (tArr == null && allIterInputs != null) { + tArr = lookup(inTensorArray.name(), allIterInputs, false); + } + + Preconditions.checkState(tArr != null, "Could not find input %s", inTensorArray.name()); + + while (sameDiff.getVariableOutputOp(inTensorArray.name()) instanceof Enter) { + //Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayWrite + //TODO also TensorArrayScatter, etc?? + inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.name()).arg(); + tArr = tArr.getParentFrame().toVarId(inTensorArray.name()); + } + + //Input 0 is the TensorArray (or dummy variable that represents it) - but sometimes Enter, in TensorArray -> Enter -> TensorARrayRead + //Input 1 is the index + //Input 2 is the value to write + + String idxName = op.arg(1).name(); + SDVariable idxSDV = sameDiff.getVariable(idxName); + INDArray idxArr = getArray(idxSDV, opInputs, allIterInputs); + Preconditions.checkState(idxArr.isScalar(), "Index variable ID for TensorArrayWrite should be a scalar, got %ndShape", idxArr); + int idx = idxArr.getInt(0); + + String inName = op.arg(2).name(); + SDVariable inSDV = sameDiff.getVariable(inName); + INDArray arr = getArray(inSDV, opInputs, allIterInputs); + Preconditions.checkState(arr != null, "Could not find array for %s", inName); + + Preconditions.checkState(tensorArrays.containsKey(tArr), "Tensor array does not exist for %s", tArr); + //TODO is this always safe to insert by index for all execution orders? + List l = tensorArrays.get(tArr); //.set(idx, arr); + while (l.size() <= idx) { + //Can't use set(int, E) if index >= size + l.add(null); + } + l.set(idx, arr); + + //Add a dependency + Dep d = new ExecDoneDep(); + arrayUseTracker.addDependency(arr, d); + + //Return dummy array + INDArray scalar = mmgr.allocate(false, DataType.FLOAT).assign(0.0); + return new INDArray[]{scalar}; + } else if (op instanceof TensorArraySize) { + //Index 0 is the TensorArray (or dummy variable that represents it) + SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array + //Work out the varid (frame/iteration) of the tensor array: + VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false)); + if (tArr == null && allIterInputs != null) { + tArr = lookup(inTensorArray.name(), allIterInputs, false); + } + List l = tensorArrays.get(tArr); + Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); + + INDArray scalar = mmgr.allocate(false, DataType.INT).assign(l.size()); + return new INDArray[]{scalar}; + } else if (op instanceof TensorArrayConcat) { + SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array + VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false)); + if (tArr == null && allIterInputs != null) { + tArr = lookup(inTensorArray.name(), allIterInputs, false); + } + List l = tensorArrays.get(tArr); + + Concat c = new Concat(0, l.toArray(new INDArray[0])); + List shape = c.calculateOutputShape(); + INDArray out = mmgr.allocate(false, shape.get(0)); + c.setOutputArgument(0, out); + Nd4j.exec(c); + return new INDArray[]{out}; + } else if (op instanceof TensorArrayGather) { + //Input 0: the TensorArray + //Input 1: the indices (1d integer vector) + + SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array + VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false)); + if (tArr == null && allIterInputs != null) { + tArr = lookup(inTensorArray.name(), allIterInputs, false); + } + List l = tensorArrays.get(tArr); + Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); + + String indicesName = op.arg(1).name(); + SDVariable indicesSDV = sameDiff.getVariable(indicesName); + INDArray idxArr = getArray(indicesSDV, opInputs, allIterInputs); + Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayGather should be a vector, got %ndShape for %s", idxArr, indicesName); + Preconditions.checkState(idxArr.dataType().isIntType(), "Indices variable for TensorArrayGather should be an integer type, got %s for array %s", idxArr.dataType(), indicesName); + + int[] idxArrInt = idxArr.toIntVector(); + + //Edge case: -1 means "all" + List newList = new ArrayList<>(); + if (idxArrInt.length == 1 && idxArrInt[0] == -1) { + newList.addAll(l); + } else { + for (int id : idxArrInt) { + Preconditions.checkState(id >= 0, "Index for TensorArrayGather must be >= 0, got %s", id); + newList.add(l.get(id)); + } + } + + Stack s = new Stack(newList.toArray(new INDArray[0]), null, 0); + List shape = s.calculateOutputShape(); + INDArray out = mmgr.allocate(false, shape.get(0)); + s.setOutputArgument(0, out); + Nd4j.exec(s); + return new INDArray[]{out}; + } else if (op instanceof TensorArrayScatter) { + //Scatter values from a rank (N+1)d tensor into specific indices of the TensorArray + //Input 0: the TensorArray + //Input 1: the indices (1d integer vector) + //Input 2: The values to scatter + + SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array + TensorArray ta = (TensorArray) sameDiff.getVariableOutputOp(inTensorArray.name()); + VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false)); + if (tArr == null && allIterInputs != null) { + tArr = lookup(inTensorArray.name(), allIterInputs, false); + } + List l = tensorArrays.get(tArr); + Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); + + String indicesName = op.arg(1).name(); + SDVariable indicesSDV = sameDiff.getVariable(indicesName); + INDArray idxArr = getArray(indicesSDV, opInputs, allIterInputs); + Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayScatter should be a vector, got %ndShape for %s", idxArr, indicesName); + Preconditions.checkState(idxArr.dataType().isIntType(), "Indices variable for TensorArrayScatter should be an integer type, got %s for array %s", idxArr.dataType(), indicesName); + int[] idxs = idxArr.toIntVector(); + + String valuesName = op.arg(2).name(); + SDVariable valuesSDV = sameDiff.getVariable(valuesName); + INDArray valuesArr = getArray(valuesSDV, opInputs, allIterInputs); + + while (l.size() <= idxs.length) { //Can't use set(int, E) if index >= size + l.add(null); + } + + //Edge case: idxs being [-1] means "all sub arrays" (i.e., "unstack" case) + if (idxs.length == 1 && idxs[0] == -1) { + idxs = ArrayUtil.range(0, (int) valuesArr.size(0)); + } + + INDArrayIndex[] idx = ArrayUtil.nTimes(valuesArr.rank(), NDArrayIndex.all(), INDArrayIndex.class); + for (int i = 0; i < idxs.length; i++) { + idx[0] = NDArrayIndex.point(i); + INDArray get = mmgr.dup(valuesArr.get(idx)); + int outIdx = idxs[i]; + if (valuesArr.rank() == 1 && get.rank() > 0) { + get = get.reshape(); + } + l.set(outIdx, get); + + //Add dependency for values array until end of execution + arrayUseTracker.addDependency(get, new ExecDoneDep()); + } + + //Return dummy array + INDArray scalar = mmgr.allocate(false, DataType.FLOAT).assign(0.0); + return new INDArray[]{scalar}; + } else if (op instanceof TensorArraySplit) { + //Split values from a rank (N+1)d tensor into sequential indices of the TensorArray + //For example, orig=[8,2] sizearray with split (4,4) means TensorArray[0] = orig[0:4,:] and TensorArray[1] = orig[4:8,:] + //Input 0: the TensorArray + //Input 1: The values to split + //Input 2: the size of each split (1d integer vector) + + SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array + VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false)); + if (tArr == null && allIterInputs != null) { + tArr = lookup(inTensorArray.name(), allIterInputs, false); + } + List l = tensorArrays.get(tArr); + Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); + + String splitName = op.arg(1).name(); + INDArray splitArr = getArray(sameDiff.getVariable(splitName), opInputs, allIterInputs); + + + String sizeName = op.arg(2).name(); + SDVariable sizeSDV = sameDiff.getVariable(sizeName); + INDArray sizeArr = getArray(sizeSDV, opInputs, allIterInputs); + Preconditions.checkState(sizeArr.isVector(), "Indices variable for TensorArraySplit should be a vector, got %ndShape for %s", sizeArr, sizeName); + Preconditions.checkState(sizeArr.dataType().isIntType(), "Indices variable for TensorArraySplit should be an integer type, got %s for array %s", sizeArr.dataType(), sizeName); + int[] sizes = sizeArr.toIntVector(); + + while (l.size() <= sizes.length) { //Can't use set(int, E) if index >= size + l.add(null); + } + + INDArrayIndex[] idx = ArrayUtil.nTimes(splitArr.rank(), NDArrayIndex.all(), INDArrayIndex.class); + int soFar = 0; + for (int i = 0; i < sizes.length; i++) { + idx[0] = NDArrayIndex.interval(soFar, soFar + sizes[i]); + INDArray sub = mmgr.dup(splitArr.get(idx)); + l.set(i, sub); + soFar += sizes[i]; + + //Add dependency for values array until end of execution + arrayUseTracker.addDependency(sub, new ExecDoneDep()); + } + + //Return dummy array + INDArray scalar = mmgr.allocate(false, DataType.FLOAT).assign(0.0); + return new INDArray[]{scalar}; + } else { + throw new IllegalStateException("Execution support not yet implemented for: " + op.getClass().getName()); + } + } + + @Override public INDArray getConstantOrVariable(String variableName) { SDVariable v = sameDiff.getVariable(variableName); @@ -522,21 +749,19 @@ public class InferenceSession extends AbstractSession opInputs, Set allIterInputs, - Set constAndPhInputs, Map placeholderValues) { + public SameDiffOp getAndParameterizeOp(String opName, FrameIter frameIter, Set opInputs, Set allIterInputs, + Set constAndPhInputs, Map placeholderValues, Set allReqVariables) { + SameDiffOp sdo = sameDiff.getOps().get(opName); + DifferentialFunction df = sdo.getOp(); - DifferentialFunction df = sameDiff.getOpById(opName); + //TODO Switch to OpContext - and make sure executing like that is thread safe (i.e., array fields in ops are not used etc) - //TODO We should clone these ops - probably - as we don't want them shared between threads/sessions! - //But let's only clone them *once* and cache in inference session - not on every exec + Preconditions.checkNotNull(df, "No differential function found with name \"%s\"", opName); - Preconditions.checkNotNull(df, "No differential function fond with name %s", opName); - - if(df instanceof LoopCond || df instanceof Enter || df instanceof Exit || df instanceof NextIteration || - df instanceof Merge || df instanceof Switch || df instanceof If || df instanceof While || - df instanceof BaseTensorOp){ + if (df instanceof LoopCond || df instanceof Enter || df instanceof Exit || df instanceof NextIteration || + df instanceof Merge || df instanceof Switch || df instanceof BaseTensorOp) { //Control dependencies and tensor ops (like TensorArray, TensorArrayRead etc) don't need inputs set, execution is a special case - return df; + return sdo; } //Infer the args based on the inputs (variable + frame + iteration) @@ -546,123 +771,41 @@ public class InferenceSession extends AbstractSession constEnterInputs = null; - if(numArgs != (numNonConstIns + numConstPhIns + numNonConstInsAllIters)){ - boolean anyConstEnterInputs = false; - SDVariable[] args = df.args(); - for(SDVariable v : args){ - Variable var = sameDiff.getVariables().get(v.getVarName()); - //Nested enter case: - DifferentialFunction inputVarFn = (var.getOutputOfOp() == null ? null : sameDiff.getOps().get(var.getOutputOfOp()).getOp()); - if(inputVarFn instanceof Enter && ((Enter)inputVarFn).isConstant()){ - anyConstEnterInputs = true; - if(constEnterInputs == null) - constEnterInputs = new HashSet<>(); - constEnterInputs.add(v.getVarName()); - } - } - - int constEnterInputCount = 0; - if(anyConstEnterInputs){ - /* - 2019/01/26: AB - Resolve nested enter inputs (constants 2+ enters in) - Why this hack is necessary: consider the following (sub) graph: constX -> Enter(a) -> Enter(b) -> opY - On iterations (a=0, b=0) all is well, opY gets triggered as normal. - On iterations (a>0, b=*) the "opY is available for exec" won't be triggered. - This is because Enter(a) is only executed once, on iteration 0 of the outer loop. - Consequently, Enter(b) is not triggered as available on iteration 1+. - When we do the lookup for the actual array to use for op execution (i.e., get inputs for opY(a=1,b=0)) - it won't be found. - This is a bit of an ugly hack, though I've yet to find a cleaner solution. - It should only be required with the combination of: constants, 2 levels of enters, and more than 1 iteration in each loop. - */ - - //For example, const -> Enter(a) -> Enter(b) -> op; in this case, the input to Op (at any frame/iteration) should should - // be the constant value - which is recorded as (frame="a",iter=0,parent=(frame="b",iter=0)) - for(String s : constEnterInputs){ - //First: check if this has already been provided - if(constAndPhInputs != null && constAndPhInputs.contains(s)){ - //already resolved/provided - continue; - } - boolean found = false; - if(allIterInputs != null) { - for (VarId vid : allIterInputs) { - if (s.equals(vid.getVariable())) { - //Already resolved/provided - found = true; - break; - } - } - } - if(found) - continue; - - constEnterInputCount++; - } - } - - if(numArgs > 1){ + if (numArgs != (numNonConstIns + numConstPhIns + numNonConstInsAllIters)) { + if (numArgs > 1) { //Might be due to repeated inputs Set uniqueArgNames = new HashSet<>(); Collections.addAll(uniqueArgNames, argNames); - Preconditions.checkState(uniqueArgNames.size() == (numNonConstIns + numConstPhIns + numNonConstInsAllIters + constEnterInputCount), + Preconditions.checkState(uniqueArgNames.size() == (numNonConstIns + numConstPhIns + numNonConstInsAllIters), "Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", df.getClass().getSimpleName(), opName, uniqueArgNames, opInputs, constAndPhInputs); } else { - Preconditions.checkState(numArgs == (numNonConstIns + numConstPhIns + constEnterInputCount), + Preconditions.checkState(numArgs == (numNonConstIns + numConstPhIns), "Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", df.getClass().getSimpleName(), opName, argNames, opInputs, constAndPhInputs); } } INDArray[] args = null; - if(argNames != null && argNames.length > 0) { + if (argNames != null && argNames.length > 0) { args = new INDArray[argNames.length]; int i = 0; - for(String s : argNames){ + for (String s : argNames) { SDVariable v = sameDiff.getVariable(s); - if(v.isConstant()) { + if (v.isConstant()) { args[i] = v.getArr(); - } else if(v.isPlaceHolder()) { - Preconditions.checkState(placeholderValues != null && placeholderValues.containsKey(s), "No array provided for placeholder %s", s); + } else if (v.getVariableType() == VariableType.VARIABLE) { + args[i] = v.getArr(); + } else if (v.isPlaceHolder()) { + Preconditions.checkState(placeholderValues != null && placeholderValues.containsKey(s), "No array was provided for required placeholder variable \"%s\"", s); args[i] = placeholderValues.get(s); - } else if(constEnterInputs != null && constEnterInputs.contains(s)){ - //For enter nodes that are constants, we want iteration 0 in all frames in the heirarchy - //For example, const -> Enter(a) -> Enter(b) -> op; in this case, the input to Op (at any frame/iteration) should should - // be the constant value - which is recorded as (frame="a",iter=0,parent=(frame="b",iter=0)) - VarId vid = newVarId(s, frameIter.clone()); - vid.setIteration(0); - FrameIter toZero = vid.getParentFrame(); - while(toZero != null){ - toZero.setIteration(0); - toZero = toZero.getParentFrame(); - } - INDArray arr = this.nodeOutputs.get(vid); - args[i] = arr; } else { - if(opInputs != null) { - for (VarId vid : opInputs) { - if (vid.getVariable().equals(s)) { - args[i] = this.nodeOutputs.get(vid); - break; - } - } - } - if(args[i] == null && allIterInputs != null){ - for(VarId vid : allIterInputs){ - if(vid.getVariable().equals(s)){ - args[i] = this.nodeOutputs.get(vid); - break; - } - } - } + VarId vid = lookup(s, opInputs, allIterInputs, true); + args[i] = nodeOutputs.get(vid); } - Preconditions.checkNotNull(args[i], "Could not parameterize op %s: array %s (variable %s) is null", opName, i, v.getVarName()); + Preconditions.checkNotNull(args[i], "Could not parameterize op %s: array %s (variable %s) is null", opName, i, v.name()); i++; } - } //Set the op inputs and output arguments @@ -671,19 +814,23 @@ public class InferenceSession extends AbstractSession 0; - if(df instanceof CustomOp){ + if (df instanceof CustomOp) { DynamicCustomOp customOp = (DynamicCustomOp) df; - if(args != null) { + if (args != null) { customOp.setInputArguments(args); } - df.resolvePropertiesFromSameDiffBeforeExecution(); + if (df instanceof Identity) { + //We don't need to allocate an output array for Identity, we pass through the input array without copying + return sdo; + } + List outShape = customOp.calculateOutputShape(); Preconditions.checkState(outShape != null && outShape.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName()); String[] outNames = df.outputVariablesNames(); Preconditions.checkState(outNames.length == outShape.size(), "Error in operation shape calculation for op \"%s\": Got %s op output shapes for an operation" + " with %s outputs (number of shapes and outputs must be equal)", df.opName(), outShape.size(), outNames.length); - for( int i=0; i 0){ + if (args != null && args.length > 0) { op.setX(args[0]); if (args.length == 2 && !axisArg) op.setY(args[1]); @@ -749,51 +893,104 @@ public class InferenceSession extends AbstractSession outputShape = ((BaseOp) op).calculateOutputShape(); Preconditions.checkState(outputShape != null && outputShape.size() == 1, "Could not calculate output shape for op: %s", op.getClass()); INDArray z = op.z(); - if (z == null || !outputShape.get(0).equals(z.shapeDescriptor()) || isLoop) { + if (z == null || z.wasClosed() || !outputShape.get(0).equals(z.shapeDescriptor()) || isLoop) { if (log.isTraceEnabled()) { log.trace("Existing op result (z) array shape for op {} was {}, allocating new array of shape {}", op.getClass().getSimpleName(), (z == null ? null : Arrays.toString(z.shape())), outputShape.get(0).toString()); } LongShapeDescriptor lsd = outputShape.get(0); - try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - //TODO Proper workspace support will be added to SameDiff later - z = Nd4j.create(lsd, false); - } + + boolean isOutput = allReqVariables.contains(((BaseOp) op).outputVariablesNames()[0]); + z = mmgr.allocate(isOutput, lsd); op.setZ(z); } } - df.resolvePropertiesFromSameDiffBeforeExecution(); } - return df; + return sdo; } - protected INDArray getArray(SDVariable sdv, Collection opInputs, Collection allIterInputs){ - String n = sdv.getVarName(); - if(sdv.getVariableType() == VariableType.CONSTANT || sdv.getVariableType() == VariableType.VARIABLE){ + protected INDArray getArray(SDVariable sdv, Collection opInputs, Collection allIterInputs) { + String n = sdv.name(); + if (sdv.getVariableType() == VariableType.CONSTANT || sdv.getVariableType() == VariableType.VARIABLE) { return getConstantOrVariable(n); } else { - VarId inVarId = null; - if(opInputs != null){ - inVarId = lookup(n, opInputs, false); - } - if(inVarId == null && allIterInputs != null && !allIterInputs.isEmpty()){ - inVarId = lookup(n, allIterInputs, false); - } - Preconditions.checkState(inVarId != null,"Could not find array for variable %s", sdv.getVarName()); + VarId inVarId = lookup(n, opInputs, allIterInputs, false); + Preconditions.checkState(inVarId != null, "Could not find array for variable %s", sdv.name()); return nodeOutputs.get(inVarId); } } + + @Data + public abstract static class Dep { + protected String frame; + protected FrameIter parentFrame; + } + + @AllArgsConstructor + @Data + @EqualsAndHashCode(callSuper = true) + public static class OpDep extends Dep { + protected String opName; + protected int iter; + + protected OpDep(@NonNull String opName, @NonNull String frame, int iter, FrameIter parentFrame) { + this.opName = opName; + this.frame = frame; + this.iter = iter; + this.parentFrame = parentFrame; + } + + @Override + public String toString() { + return "OpDep(" + opName + ",frame=" + frame + ",iter=" + iter + (parentFrame == null ? "" : ",parent=" + parentFrame) + ")"; + } + } + + @Data + @EqualsAndHashCode(callSuper = true) + @AllArgsConstructor + protected static class PlaceholderDep extends Dep { + protected String phName; + } + + @Data + @EqualsAndHashCode(callSuper = true) + @AllArgsConstructor + protected static class VariableDep extends Dep { + protected String varName; + } + + @Data + @EqualsAndHashCode(callSuper = true) + @AllArgsConstructor + protected static class ConstantDep extends Dep { + protected String constName; + } + + @Data + @EqualsAndHashCode(callSuper = true) + @AllArgsConstructor + protected static class ReqOutputDep extends Dep { + protected String outputName; + } + + @Data + @EqualsAndHashCode(callSuper = true) + @NoArgsConstructor + protected static class ExecDoneDep extends Dep { + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/SameDiffOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/SameDiffOp.java index de3e96c2e..8e9b45067 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/SameDiffOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/SameDiffOp.java @@ -30,8 +30,10 @@ import java.util.List; @Builder public class SameDiffOp { protected String name; - protected DifferentialFunction op; //Actual op (note: should be mutable: i.e., cloneable, no arrays set) - protected List inputsToOp; //Name of SDVariables as input - protected List outputsOfOp; //Name of SDVariables as output - protected List controlDeps; //Name of SDVariables as control dependencies (not data inputs, but need to be available before exec) + protected DifferentialFunction op; //Actual op (note: should be mutable: i.e., cloneable, no arrays set) + protected List inputsToOp; //Name of SDVariables as input + protected List outputsOfOp; //Name of SDVariables as output + protected List controlDeps; //Name of SDVariables as control dependencies (not data inputs, but need to be available before exec) + protected List varControlDeps; //Variables (constants, placeholders, etc) that are control dependencies for this op + protected List controlDepFor; //Name of the variables that this op is a control dependency for } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/SessionMemMgr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/SessionMemMgr.java new file mode 100644 index 000000000..b54db548a --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/SessionMemMgr.java @@ -0,0 +1,60 @@ +package org.nd4j.autodiff.samediff.internal; + +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; + +import java.io.Closeable; + +/** + * SessionMemMgr - aka "Session Memory Manager" is responsible for allocating, managing, and deallocating memory used + * during SameDiff execution.
+ * This interface allows different memory management strategies to be used, abstracted away from the actual graph + * execution logic + * + * @author Alex Black + */ +public interface SessionMemMgr extends Closeable { + + /** + * Allocate an array with the specified datatype and shape.
+ * NOTE: This array should be assumed to be uninitialized - i.e., contains random values. + * + * @param detached If true: the array is safe to return outside of the SameDiff session run (for example, the array + * is one that may be returned to the user) + * @param dataType Datatype of the returned array + * @param shape Array shape + * @return The newly allocated (uninitialized) array + */ + INDArray allocate(boolean detached, DataType dataType, long... shape); + + /** + * As per {@link #allocate(boolean, DataType, long...)} but from a LongShapeDescriptor instead + */ + INDArray allocate(boolean detached, LongShapeDescriptor descriptor); + + /** + * Allocate an uninitialized array with the same datatype and shape as the specified array + */ + INDArray ulike(INDArray arr); + + /** + * Duplicate the specified array, to an array that is managed/allocated by the session memory manager + */ + INDArray dup(INDArray arr); + + /** + * Release the array. All arrays allocated via one of the allocate methods should be returned here once they are no + * longer used, and all references to them should be cleared. + * After calling release, anything could occur to the array - deallocated, workspace closed, reused, etc. + * + * @param array The array that can be released + */ + void release(INDArray array); + + /** + * Close the session memory manager and clean up any memory / resources, if any + */ + void close(); + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java new file mode 100644 index 000000000..992a747a0 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java @@ -0,0 +1,231 @@ +package org.nd4j.autodiff.samediff.internal; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.listeners.At; +import org.nd4j.autodiff.listeners.Listener; +import org.nd4j.autodiff.listeners.Loss; +import org.nd4j.autodiff.listeners.Operation; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.TrainingConfig; +import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.learning.GradientUpdater; +import org.nd4j.linalg.learning.regularization.Regularization; +import org.nd4j.linalg.primitives.AtomicDouble; + +import java.util.*; + +/** + * TrainingSession extends InferenceSession, to add training-specific functionality:
+ * - Application of regularization (L1, L2, weight decay etc)
+ * - Inline updating of variables, using updater/optimizer (Adam, Nesterov, SGD, etc)
+ * - Calculation of regularization scores (Score for L1, L2, etc) + * + * @author Alex Black + */ +@Slf4j +public class TrainingSession extends InferenceSession { + + protected TrainingConfig config; + protected Map gradVarToVarMap; + protected Map updaters; + protected Map lossVarsToLossIdx; + protected double[] currIterLoss; + protected Map, AtomicDouble> currIterRegLoss; + protected List listeners; + + + public TrainingSession(SameDiff sameDiff) { + super(sameDiff); + } + + /** + * Perform one iteration of training - i.e., do forward and backward passes, and update the parameters + * + * @param config Training configuration + * @param placeholders Current placeholders + * @param paramsToTrain Set of parameters that will be trained + * @param updaters Current updater state + * @param batch Current data/batch (mainly for listeners, should have already been converted to placeholders map) + * @param lossVariables Loss variables (names) + * @param listeners Listeners (if any) + * @param at Current epoch, iteration, etc + * @return The Loss at the current iteration + */ + public Loss trainingIteration(TrainingConfig config, Map placeholders, Set paramsToTrain, Map updaters, + MultiDataSet batch, List lossVariables, List listeners, At at) { + this.config = config; + this.updaters = updaters; + + //Preprocess listeners, get the relevant ones + if (listeners == null) { + this.listeners = null; + } else { + List filtered = new ArrayList<>(); + for (Listener l : listeners) { + if (l.isActive(at.operation())) { + filtered.add(l); + } + } + this.listeners = filtered.isEmpty() ? null : filtered; + } + + List requiredActivations = new ArrayList<>(); + gradVarToVarMap = new HashMap<>(); //Key: gradient variable. Value: variable that the key is gradient for + for (String s : paramsToTrain) { + Preconditions.checkState(sameDiff.hasVariable(s), "SameDiff instance does not have a variable with name \"%s\"", s); + SDVariable v = sameDiff.getVariable(s); + Preconditions.checkState(v.getVariableType() == VariableType.VARIABLE, "Can only train VARIABLE type variable - \"%s\" has type %s", + s, v.getVariableType()); + SDVariable grad = sameDiff.getVariable(s).getGradient(); + if (grad == null) { + //In some cases, a variable won't actually impact the loss value, and hence won't have a gradient associated with it + //For example: floatVar -> cast to integer -> cast to float -> sum -> loss + //In this case, the gradient of floatVar isn't defined (due to no floating point connection to the loss) + continue; + } + + requiredActivations.add(grad.name()); + + gradVarToVarMap.put(grad.name(), s); + } + + //Set up losses + lossVarsToLossIdx = new LinkedHashMap<>(); + List lossVars; + currIterLoss = new double[lossVariables.size()]; + currIterRegLoss = new HashMap<>(); + for (int i = 0; i < lossVariables.size(); i++) { + lossVarsToLossIdx.put(lossVariables.get(i), i); + } + + //Do training iteration + List outputVars = new ArrayList<>(gradVarToVarMap.keySet()); //TODO this should be empty, and grads calculated in requiredActivations + Map m = output(outputVars, placeholders, batch, requiredActivations, listeners, at); + + + double[] finalLoss = new double[currIterLoss.length + currIterRegLoss.size()]; + System.arraycopy(currIterLoss, 0, finalLoss, 0, currIterLoss.length); + if (currIterRegLoss.size() > 0) { + lossVars = new ArrayList<>(lossVariables.size() + currIterRegLoss.size()); + lossVars.addAll(lossVariables); + int s = currIterRegLoss.size(); + //Collect regularization losses + for (Map.Entry, AtomicDouble> entry : currIterRegLoss.entrySet()) { + lossVars.add(entry.getKey().getSimpleName()); + finalLoss[s] = entry.getValue().get(); + } + } else { + lossVars = lossVariables; + } + + Loss loss = new Loss(lossVars, finalLoss); + if (listeners != null) { + for (Listener l : listeners) { + if (l.isActive(Operation.TRAINING)) { + l.iterationDone(sameDiff, at, batch, loss); + } + } + } + + return loss; + } + + @Override + public INDArray[] getOutputs(SameDiffOp op, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, + Set constAndPhInputs, List listeners, At at, MultiDataSet batch, Set allReqVariables) { + //Get outputs from InferenceSession + INDArray[] out = super.getOutputs(op, outputFrameIter, opInputs, allIterInputs, constAndPhInputs, listeners, at, batch, allReqVariables); + + List outputs = op.getOutputsOfOp(); + int outIdx = 0; + for (String s : outputs) { + //If this is a loss variable - record it + if (lossVarsToLossIdx.containsKey(s)) { + int lossIdx = lossVarsToLossIdx.get(s); + INDArray arr = out[outIdx]; + double l = arr.isScalar() ? arr.getDouble(0) : arr.sumNumber().doubleValue(); + currIterLoss[lossIdx] += l; + } + + //If this is a gradient variable - apply the updater and update the parameter array in-line + if (gradVarToVarMap.containsKey(s)) { + String varName = gradVarToVarMap.get(s); + //log.info("Calculated gradient for variable \"{}\": (grad var name: \"{}\")", varName, s); + + Variable gradVar = sameDiff.getVariables().get(s); + if (gradVar.getInputsForOp() != null && gradVar.getInputsForOp().isEmpty()) { + //Should be rare, and we should handle this by tracking dependencies, and only update when safe + // (i.e., dependency tracking) + throw new IllegalStateException("Op depends on gradient variable: " + s + " for variable " + varName); + } + + GradientUpdater u = updaters.get(varName); + Preconditions.checkState(u != null, "No updater found for variable \"%s\"", varName); + + Variable var = sameDiff.getVariables().get(varName); + INDArray gradArr = out[outIdx]; + INDArray paramArr = var.getVariable().getArr(); + + //Pre-updater regularization (L1, L2) + List r = config.getRegularization(); + if (r != null && r.size() > 0) { + double lr = config.getUpdater().hasLearningRate() ? config.getUpdater().getLearningRate(at.iteration(), at.epoch()) : 1.0; + for (Regularization reg : r) { + if (reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER) { + if (this.listeners != null) { + double score = reg.score(paramArr, at.iteration(), at.epoch()); + if (!currIterRegLoss.containsKey(reg.getClass())) { + currIterRegLoss.put(reg.getClass(), new AtomicDouble()); + } + currIterRegLoss.get(reg.getClass()).addAndGet(score); + } + reg.apply(paramArr, gradArr, lr, at.iteration(), at.epoch()); + } + } + } + + u.applyUpdater(gradArr, at.iteration(), at.epoch()); + + //Post-apply regularization (weight decay) + if (r != null && r.size() > 0) { + double lr = config.getUpdater().hasLearningRate() ? config.getUpdater().getLearningRate(at.iteration(), at.epoch()) : 1.0; + for (Regularization reg : r) { + if (reg.applyStep() == Regularization.ApplyStep.POST_UPDATER) { + if (this.listeners != null) { + double score = reg.score(paramArr, at.iteration(), at.epoch()); + if (!currIterRegLoss.containsKey(reg.getClass())) { + currIterRegLoss.put(reg.getClass(), new AtomicDouble()); + } + currIterRegLoss.get(reg.getClass()).addAndGet(score); + } + reg.apply(paramArr, gradArr, lr, at.iteration(), at.epoch()); + } + } + } + + if (listeners != null) { + for (Listener l : listeners) { + if (l.isActive(at.operation())) + l.preUpdate(sameDiff, at, var, gradArr); + } + } + + //Update: + if (config.isMinimize()) { + paramArr.subi(gradArr); + } else { + paramArr.addi(gradArr); + } + log.trace("Applied updater to gradient and updated variable: {}", varName); + } + + outIdx++; + } + + return out; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java index 670b21dda..e8041955b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java @@ -35,8 +35,7 @@ public class Variable { protected List controlDepsForOp; //if a op control dependency (x -> opY) exists, then "opY" will be in this list protected List controlDepsForVar; //if a variable control dependency (x -> varY) exists, then "varY" will be in this list protected String outputOfOp; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of - protected List controlDeps; //Control dependencies: name of variables that must be available before this variable is considered available for execution - protected int outputOfOpIdx; //Index of the output for the op (say, variable is output number 2 of op "outputOfOp") + protected List controlDeps; //Control dependencies: name of ops that must be available before this variable is considered available for execution protected SDVariable gradient; //Variable corresponding to the gradient of this variable protected int variableIndex = -1; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/AbstractMemoryMgr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/AbstractMemoryMgr.java new file mode 100644 index 000000000..e498deaf5 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/AbstractMemoryMgr.java @@ -0,0 +1,25 @@ +package org.nd4j.autodiff.samediff.internal.memory; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.internal.SessionMemMgr; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * Abstract memory manager, that implements ulike and dup methods using the underlying allocate methods + * + * @author Alex Black + */ +public abstract class AbstractMemoryMgr implements SessionMemMgr { + + @Override + public INDArray ulike(@NonNull INDArray arr) { + return allocate(false, arr.dataType(), arr.shape()); + } + + @Override + public INDArray dup(@NonNull INDArray arr) { + INDArray out = ulike(arr); + out.assign(arr); + return out; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/ArrayCloseMemoryMgr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/ArrayCloseMemoryMgr.java new file mode 100644 index 000000000..24992c50b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/ArrayCloseMemoryMgr.java @@ -0,0 +1,43 @@ +package org.nd4j.autodiff.samediff.internal.memory; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.samediff.internal.SessionMemMgr; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.factory.Nd4j; + +/** + * A simple memory management strategy that deallocates memory as soon as it is no longer needed.
+ * This should result in a minimal amount of memory, but will have some overhead - notably, the cost of deallocating + * and reallocating memory all the time. + * + * @author Alex Black + */ +@Slf4j +public class ArrayCloseMemoryMgr extends AbstractMemoryMgr implements SessionMemMgr { + + @Override + public INDArray allocate(boolean detached, DataType dataType, long... shape) { + return Nd4j.createUninitialized(dataType, shape); + } + + @Override + public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) { + return Nd4j.create(descriptor, false); + } + + @Override + public void release(@NonNull INDArray array) { + if (!array.wasClosed() && array.closeable()) { + array.close(); + log.trace("Closed array (deallocated) - id={}", array.getId()); + } + } + + @Override + public void close() { + //No-op + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/CloseValidationMemoryMgr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/CloseValidationMemoryMgr.java new file mode 100644 index 000000000..8417bfb35 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/CloseValidationMemoryMgr.java @@ -0,0 +1,168 @@ +package org.nd4j.autodiff.samediff.internal.memory; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.autodiff.samediff.internal.DependencyList; +import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker; +import org.nd4j.autodiff.samediff.internal.InferenceSession; +import org.nd4j.autodiff.samediff.internal.SessionMemMgr; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.primitives.Pair; + +import java.util.*; + +/** + * A {@link SessionMemMgr} that wraps an existing memory manager, to ensure that:
+ * - All arrays that are supposed to be closed, have been closed
+ * - Arrays are only passed to the close method exactly one (unless they are requested outputs)
+ * - Arrays that are passed to the close method were originally allocated by the session memory manager
+ *
+ * How to use:
+ * 1. Perform an inference or training iteration, as normal
+ * 2. Call {@link #assertAllReleasedExcept(Collection)} with the output arrays
+ *

+ * NOTE: This is intended for debugging and testing only + * + * @author Alex Black + */ +@Slf4j +public class CloseValidationMemoryMgr extends AbstractMemoryMgr implements SessionMemMgr { + + private final SameDiff sd; + private final SessionMemMgr underlying; + private final Map released = new IdentityHashMap<>(); + + public CloseValidationMemoryMgr(SameDiff sd, SessionMemMgr underlying) { + this.sd = sd; + this.underlying = underlying; + } + + @Override + public INDArray allocate(boolean detached, DataType dataType, long... shape) { + INDArray out = underlying.allocate(detached, dataType, shape); + released.put(out, false); + return out; + } + + @Override + public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) { + INDArray out = underlying.allocate(detached, descriptor); + released.put(out, false); + return out; + } + + @Override + public void release(INDArray array) { + Preconditions.checkState(released.containsKey(array), "Attempting to release an array that was not allocated by" + + " this memory manager: id=%s", array.getId()); + if (released.get(array)) { + //Already released + InferenceSession is = sd.getSessions().get(Thread.currentThread().getId()); + IdentityDependencyTracker arrayUseTracker = is.getArrayUseTracker(); + DependencyList dl = arrayUseTracker.getDependencies(array); + System.out.println(dl); + if (dl.getDependencies() != null) { + for (InferenceSession.Dep d : dl.getDependencies()) { + System.out.println(d + ": " + arrayUseTracker.isSatisfied(d)); + } + } + if (dl.getOrDependencies() != null) { + for (Pair p : dl.getOrDependencies()) { + System.out.println(p + " - (" + arrayUseTracker.isSatisfied(p.getFirst()) + "," + arrayUseTracker.isSatisfied(p.getSecond())); + } + } + } + Preconditions.checkState(!released.get(array), "Attempting to release an array that was already deallocated by" + + " an earlier release call to this memory manager: id=%s", array.getId()); + log.trace("Released array: id = {}", array.getId()); + released.put(array, true); + } + + @Override + public void close() { + underlying.close(); + } + + /** + * Check that all arrays have been released (after an inference call) except for the specified arrays. + * + * @param except Arrays that should not have been closed (usually network outputs) + */ + public void assertAllReleasedExcept(@NonNull Collection except) { + Set allVarPhConst = null; + + for (INDArray arr : except) { + if (!released.containsKey(arr)) { + //Check if constant, variable or placeholder - maybe user requested that out + if (allVarPhConst == null) + allVarPhConst = identitySetAllConstPhVar(); + if (allVarPhConst.contains(arr)) + continue; //OK - output is a constant, variable or placeholder, hence it's fine it's not allocated by the memory manager + + throw new IllegalStateException("Array " + arr.getId() + " was not originally allocated by the memory manager"); + } + + boolean released = this.released.get(arr); + if (released) { + throw new IllegalStateException("Specified output array (id=" + arr.getId() + ") should not have been deallocated but was"); + } + } + + Set exceptSet = Collections.newSetFromMap(new IdentityHashMap()); + exceptSet.addAll(except); + + int numNotClosed = 0; + Set notReleased = Collections.newSetFromMap(new IdentityHashMap()); + InferenceSession is = sd.getSessions().get(Thread.currentThread().getId()); + IdentityDependencyTracker arrayUseTracker = is.getArrayUseTracker(); + for (Map.Entry e : released.entrySet()) { + INDArray a = e.getKey(); + if (!exceptSet.contains(a)) { + boolean b = e.getValue(); + if (!b) { + notReleased.add(a); + numNotClosed++; + log.info("Not released: array id {}", a.getId()); + DependencyList list = arrayUseTracker.getDependencies(a); + List l = list.getDependencies(); + List> l2 = list.getOrDependencies(); + if (l != null) { + for (InferenceSession.Dep d : l) { + if (!arrayUseTracker.isSatisfied(d)) { + log.info(" Not satisfied: {}", d); + } + } + } + if (l2 != null) { + for (Pair d : l2) { + if (!arrayUseTracker.isSatisfied(d.getFirst()) && !arrayUseTracker.isSatisfied(d.getSecond())) { + log.info(" Not satisfied: {}", d); + } + } + } + } + } + } + + if (numNotClosed > 0) { + System.out.println(sd.summary()); + throw new IllegalStateException(numNotClosed + " arrays were not released but should have been"); + } + } + + protected Set identitySetAllConstPhVar() { + Set set = Collections.newSetFromMap(new IdentityHashMap()); + for (SDVariable v : sd.variables()) { + if (v.getVariableType() == VariableType.VARIABLE || v.getVariableType() == VariableType.CONSTANT || v.getVariableType() == VariableType.PLACEHOLDER) { + set.add(v.getArr()); + } + } + return set; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/NoOpMemoryMgr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/NoOpMemoryMgr.java new file mode 100644 index 000000000..30b891c2f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/NoOpMemoryMgr.java @@ -0,0 +1,42 @@ +package org.nd4j.autodiff.samediff.internal.memory; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.internal.SessionMemMgr; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.factory.Nd4j; + +/** + * A simple "no-op" memory manager that relies on JVM garbage collector for memory management. + * Assuming other references have been cleared (they should have been) the arrays will be cleaned up by the + * garbage collector at some point. + * + * This memory management strategy is not recommended for performance or memory reasons, and should only be used + * for testing and debugging purposes + * + * @author Alex Black + */ +public class NoOpMemoryMgr extends AbstractMemoryMgr implements SessionMemMgr { + + @Override + public INDArray allocate(boolean detached, DataType dataType, long... shape) { + return Nd4j.createUninitialized(dataType, shape); + } + + @Override + public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) { + return Nd4j.create(descriptor, false); + } + + @Override + public void release(@NonNull INDArray array) { + //No-op, rely on GC to clear arrays + } + + @Override + public void close() { + //No-op + } + +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java index 7da89aa36..8d2e9f624 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java @@ -3266,7 +3266,7 @@ public abstract class SDBaseOps { if (cond_result.dataType() != DataType.BOOL) - throw new IllegalStateException("Can not use " + cond_result.getVarName() + " as the condition of an While loop, the condition must be a boolean."); + throw new IllegalStateException("Can not use " + cond_result.name() + " as the condition of an While loop, the condition must be a boolean."); final Set alreadyEntered = Sets.newHashSet(); @@ -3275,7 +3275,7 @@ public abstract class SDBaseOps { for(int i = 0 ; i < loopVars.length ; i++){ SDVariable[] s = f().switchOp(merged[i], cond_result); trueSwitches[i] = s[1]; - alreadyEntered.add(s[1].getVarName()); + alreadyEntered.add(s[1].name()); exits[i] = f().exit(s[0]); } @@ -3290,17 +3290,17 @@ public abstract class SDBaseOps { @Override public SDVariable intercept(SDVariable argument) { - if(!declared.contains(argument.getVarName())) + if(!declared.contains(argument.name())) return argument; - if(alreadyEntered.contains(argument.getVarName())) + if(alreadyEntered.contains(argument.name())) return argument; - if(done.containsKey(argument.getVarName())) - return done.get(argument.getVarName()); + if(done.containsKey(argument.name())) + return done.get(argument.name()); SDVariable e = f().enter(argument, frameName, true); - done.put(argument.getVarName(), e); + done.put(argument.name(), e); return e; } }); @@ -3371,7 +3371,7 @@ public abstract class SDBaseOps { //cleanup partially added block for(SDVariable v : sd().getVariablesInScope(ifScope)) - sd().getVariables().remove(v.getVarName()); + sd().getVariables().remove(v.name()); for(SameDiffOp op : sd().getOpsInScope(ifScope)) { for(String in : op.getInputsToOp()){ @@ -3381,7 +3381,7 @@ public abstract class SDBaseOps { } - throw new IllegalStateException("Can not use " + pred.getVarName() + throw new IllegalStateException("Can not use " + pred.name() + " as the condition of an If statement, the condition must be a boolean."); } @@ -3394,15 +3394,15 @@ public abstract class SDBaseOps { public SDVariable intercept(SDVariable argument) { // if its declared in the if, we don't care acout it - if(!declared.contains(argument.getVarName())) + if(!declared.contains(argument.name())) return argument; // if we've already added a switch, move on - if(switches.containsKey(argument.getVarName())) - return switches.get(argument.getVarName())[1]; + if(switches.containsKey(argument.name())) + return switches.get(argument.name())[1]; SDVariable[] s = f().switchOp(argument, pred); - switches.put(argument.getVarName(), s); + switches.put(argument.name(), s); return s[1]; } }); @@ -3410,9 +3410,9 @@ public abstract class SDBaseOps { SDVariable trueOut = trueBody.define(sd()); sd().removeArgumentInterceptor(); - if(declared.contains(trueOut.getVarName())) { + if(declared.contains(trueOut.name())) { SDVariable[] s = f().switchOp(trueOut, pred); - switches.put(trueOut.getVarName(), s); + switches.put(trueOut.name(), s); trueOut = s[1]; } @@ -3424,15 +3424,15 @@ public abstract class SDBaseOps { public SDVariable intercept(SDVariable argument) { // if its declared in the if, we don't care acout it - if(!declared2.contains(argument.getVarName())) + if(!declared2.contains(argument.name())) return argument; // if we've already added a switch, move on - if(switches.containsKey(argument.getVarName())) - return switches.get(argument.getVarName())[0]; + if(switches.containsKey(argument.name())) + return switches.get(argument.name())[0]; SDVariable[] s = f().switchOp(argument, pred); - switches.put(argument.getVarName(), s); + switches.put(argument.name(), s); return s[0]; } }); @@ -3440,9 +3440,9 @@ public abstract class SDBaseOps { SDVariable falseOut = falseBody.define(sd()); sd().removeArgumentInterceptor(); - if(declared2.contains(falseOut.getVarName())) { + if(declared2.contains(falseOut.name())) { SDVariable[] s = f().switchOp(falseOut, pred); - switches.put(falseOut.getVarName(), s); + switches.put(falseOut.name(), s); falseOut = s[0]; } falseScope.close(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java index a17cb41b1..668a7a4a9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java @@ -90,10 +90,10 @@ public class SDNN extends SDOps { } /** - * @see #biasAdd(String, SDVariable, SDVariable) + * @see #biasAdd(String, SDVariable, SDVariable, boolean) */ - public SDVariable biasAdd(SDVariable input, SDVariable bias) { - return biasAdd(null, input, bias); + public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) { + return biasAdd(null, input, bias, nchw); } /** @@ -102,12 +102,14 @@ public class SDNN extends SDOps { * @param name Name of the output variable * @param input 4d input variable * @param bias 1d bias + * @param nchw The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels]. + * Unused for 2d inputs * @return Output variable */ - public SDVariable biasAdd(String name, SDVariable input, SDVariable bias) { + public SDVariable biasAdd(String name, SDVariable input, SDVariable bias, boolean nchw) { validateFloatingPoint("biasAdd", "input", input); validateFloatingPoint("biasAdd", "bias", bias); - SDVariable ret = f().biasAdd(input, bias); + SDVariable ret = f().biasAdd(input, bias, nchw); return updateVariableNameAndReference(ret, name); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java index d752facd4..f6434a56f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java @@ -37,7 +37,7 @@ public class SDValidation { if (v == null) return; if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8) - throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-numerical data type " + v.dataType()); + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-numerical data type " + v.dataType()); } /** @@ -52,7 +52,7 @@ public class SDValidation { return; if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8) throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an numerical type type; got variable \"" + - v.getVarName() + "\" with non-integer data type " + v.dataType()); + v.name() + "\" with non-integer data type " + v.dataType()); } /** @@ -65,8 +65,8 @@ public class SDValidation { */ protected static void validateNumerical(String opName, SDVariable v1, SDVariable v2) { if (v1.dataType() == DataType.BOOL || v1.dataType() == DataType.UTF8 || v2.dataType() == DataType.BOOL || v2.dataType() == DataType.UTF8) - throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.getVarName() + "\" and \"" + - v2.getVarName() + "\" if one or both variables are non-numerical: " + v1.dataType() + " and " + v2.dataType()); + throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.name() + "\" and \"" + + v2.name() + "\" if one or both variables are non-numerical: " + v1.dataType() + " and " + v2.dataType()); } /** @@ -79,7 +79,7 @@ public class SDValidation { if (v == null) return; if (!v.dataType().isIntType()) - throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-integer data type " + v.dataType()); + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-integer data type " + v.dataType()); } /** @@ -94,7 +94,7 @@ public class SDValidation { return; if (!v.dataType().isIntType()) throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer type; got variable \"" + - v.getVarName() + "\" with non-integer data type " + v.dataType()); + v.name() + "\" with non-integer data type " + v.dataType()); } /** @@ -107,7 +107,7 @@ public class SDValidation { if (v == null) return; if (!v.dataType().isFPType()) - throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-floating point data type " + v.dataType()); + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-floating point data type " + v.dataType()); } /** @@ -122,7 +122,7 @@ public class SDValidation { return; if (!v.dataType().isFPType()) throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an floating point type; got variable \"" + - v.getVarName() + "\" with non-floating point data type " + v.dataType()); + v.name() + "\" with non-floating point data type " + v.dataType()); } /** @@ -135,7 +135,7 @@ public class SDValidation { if (v == null) return; if (v.dataType() != DataType.BOOL) - throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-boolean point data type " + v.dataType()); + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-boolean point data type " + v.dataType()); } /** @@ -150,7 +150,7 @@ public class SDValidation { return; if (v.dataType() != DataType.BOOL) throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an boolean variable; got variable \"" + - v.getVarName() + "\" with non-boolean data type " + v.dataType()); + v.name() + "\" with non-boolean data type " + v.dataType()); } /** @@ -162,8 +162,8 @@ public class SDValidation { */ protected static void validateBool(String opName, SDVariable v1, SDVariable v2) { if (v1.dataType() != DataType.BOOL || v2.dataType() != DataType.BOOL) - throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.getVarName() + "\" and \"" + - v2.getVarName() + "\" if one or both variables are non-boolean: " + v1.dataType() + " and " + v2.dataType()); + throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.name() + "\" and \"" + + v2.name() + "\" if one or both variables are non-boolean: " + v1.dataType() + " and " + v2.dataType()); } /** @@ -190,7 +190,7 @@ public class SDValidation { String[] names = new String[vars.length]; DataType[] dtypes = new DataType[vars.length]; for (int j = 0; j < vars.length; j++) { - names[j] = vars[j].getVarName(); + names[j] = vars[j].name(); dtypes[j] = vars[j].dataType(); } throw new IllegalStateException("Cannot perform operation \"" + opName + "\" to variables with different datatypes:" + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java index cce38cf24..e89047ee4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java @@ -16,6 +16,7 @@ package org.nd4j.autodiff.samediff.serde; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.shade.guava.primitives.Ints; import com.google.flatbuffers.FlatBufferBuilder; import java.nio.ByteOrder; @@ -762,7 +763,7 @@ public class FlatBuffersMapper { SDVariable[] inputs = node.args(); for (SDVariable input : inputs) { - String varName = input.getVarName(); + String varName = input.name(); int outIdx; if (sameDiff.getVariables().get(varName).getOutputOfOp() != null) { DifferentialFunction df = sameDiff.getOps().get(sameDiff.getVariables().get(varName).getOutputOfOp()).getOp(); @@ -847,6 +848,28 @@ public class FlatBuffersMapper { } int outTypesOffset = FlatNode.createOutputTypesVector(bufferBuilder, outTypes); + //Control dependencies: + SameDiffOp sdo = sameDiff.getOps().get(node.getOwnName()); + + int opCds = 0; + int[] opCdsArr = mapOrNull(sdo.getControlDeps(), bufferBuilder); + if(opCdsArr != null){ + opCds = FlatNode.createControlDepsVector(bufferBuilder, opCdsArr); + } + + int varCds = 0; + int[] varCdsArr = mapOrNull(sdo.getVarControlDeps(), bufferBuilder); + if(varCdsArr != null){ + varCds = FlatNode.createVarControlDepsVector(bufferBuilder, varCdsArr); + } + + int cdsFor = 0; + int[] cdsForArr = mapOrNull(sdo.getControlDepFor(), bufferBuilder); + if(cdsForArr != null){ + cdsFor = FlatNode.createControlDepForVector(bufferBuilder, cdsForArr); + } + + int flatNode = FlatNode.createFlatNode( bufferBuilder, ownId, @@ -867,12 +890,26 @@ public class FlatBuffersMapper { outVarNamesOffset, opNameOffset, outTypesOffset, //Output types - scalar + scalar, + opCds, + varCds, + cdsFor ); return flatNode; } + public static int[] mapOrNull(List list, FlatBufferBuilder fbb){ + if(list == null) + return null; + int[] out = new int[list.size()]; + int i=0; + for(String s : list){ + out[i++] = fbb.createString(s); + } + return out; + } + public static DifferentialFunction cloneViaSerialize(SameDiff sd, DifferentialFunction df ){ Map nameToIdxMap = new HashMap<>(); int count = 0; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java index eeb6b1b78..f76c42c50 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java @@ -133,16 +133,9 @@ public class LegacyOpMapper { public static Class aggregateOpClass(int opNum) { switch (opNum) { - case 0: - return HierarchicSoftmax.class; - case 1: - return AggregateDot.class; + case 2: return AggregateAxpy.class; - case 3: - return AggregateSkipGram.class; - case 4: - return AggregateCBOW.class; case 5: return AggregateGEMM.class; default: diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/GraphTransformUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/GraphTransformUtil.java index afe8551f3..44bd9b79e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/GraphTransformUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/GraphTransformUtil.java @@ -69,8 +69,8 @@ public class GraphTransformUtil { // we want to end up with (x -> A -> z) List allSubGraphFns = sg.allFunctionsInSubgraph(); for (int i = 0; i < oldOutputs.size(); i++) { - String oldOutVarName = oldOutputs.get(i).getVarName(); - String newOutVarName = newOutputs.get(i).getVarName(); + String oldOutVarName = oldOutputs.get(i).name(); + String newOutVarName = newOutputs.get(i).name(); Preconditions.checkState(!oldOutVarName.equals(newOutVarName), "Reusing old variables not yet implemented"); //Update inputs for ops: if X->opA, and now Y->opA, then X.inputsForOps contains "opA"; Y.inputsForOps should be updated @@ -133,7 +133,7 @@ public class GraphTransformUtil { //Step 2: Update input variables: if X -> (subgraph) exists, then X.inputsForOp needs to be updated List inputs = sg.inputs(); for (SDVariable v : inputs) { - Variable var = sd.getVariables().get(v.getVarName()); + Variable var = sd.getVariables().get(v.name()); if (var.getInputsForOp() != null) { List newInputsForOp = new ArrayList<>(var.getInputsForOp()); for (String opName : var.getInputsForOp()) { @@ -160,7 +160,7 @@ public class GraphTransformUtil { SDVariable[] outputs = df.outputVariables(); if (outputs != null) { for (SDVariable v : outputs) { - vars.remove(v.getVarName()); + vars.remove(v.name()); } } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraph.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraph.java index c9f1f52bf..6514ee49e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraph.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraph.java @@ -62,7 +62,7 @@ public class SubGraph { //But suppose same subgraph, but connection y -> a exists; then Y must be an output, because it's used somewhere else List filteredOutputs = new ArrayList<>(allOutputs.size()); for(SDVariable v : allOutputs){ - Variable var = sameDiff.getVariables().get(v.getVarName()); + Variable var = sameDiff.getVariables().get(v.name()); List inputsFor = var.getInputsForOp(); boolean allInSubgraph = true; if(inputsFor != null){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraphPredicate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraphPredicate.java index 5d7e117a2..bf6ba09fe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraphPredicate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraphPredicate.java @@ -77,7 +77,7 @@ public class SubGraphPredicate extends OpPredicate { } SDVariable in = inputs[inNum]; - DifferentialFunction df = sameDiff.getVariableOutputOp(in.getVarName()); + DifferentialFunction df = sameDiff.getVariableOutputOp(in.name()); if (df == null || !e.getValue().matches(sameDiff, df)) { return false; } @@ -103,7 +103,7 @@ public class SubGraphPredicate extends OpPredicate { for(Map.Entry entry : opInputSubgraphPredicates.entrySet()){ OpPredicate p2 = entry.getValue(); SDVariable arg = rootFn.arg(entry.getKey()); - DifferentialFunction df = sd.getVariableOutputOp(arg.getVarName()); + DifferentialFunction df = sd.getVariableOutputOp(arg.name()); if(df != null){ childNodes.add(df); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java index b8625afde..0f1e0bd52 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java @@ -107,7 +107,7 @@ public class GradCheckUtil { Set fnOutputs = new HashSet<>(); for(DifferentialFunction f : sd.ops()){ for(SDVariable s : f.outputVariables()){ - fnOutputs.add(s.getVarName()); + fnOutputs.add(s.name()); } } @@ -131,12 +131,12 @@ public class GradCheckUtil { // in this case, gradients of x and y are all 0 too //Collect variables to get gradients for - we want placeholders AND variables - Set gradVarNames = new HashSet<>(); + Set varsNeedingGrads = new HashSet<>(); for(Variable v : sd.getVariables().values()){ if(v.getVariable().dataType().isFPType() && (v.getVariable().getVariableType() == VariableType.VARIABLE || v.getVariable().getVariableType() == VariableType.PLACEHOLDER)){ SDVariable g = v.getVariable().getGradient(); Preconditions.checkNotNull(g, "No gradient variable found for variable %s", v.getVariable()); - gradVarNames.add(g.getVarName()); + varsNeedingGrads.add(v.getName()); } } @@ -164,14 +164,14 @@ public class GradCheckUtil { } - sd.execBackwards(placeholderValues, new ArrayList<>(gradVarNames)); + Map gm = sd.calculateGradients(placeholderValues, varsNeedingGrads); //Remove listener, to reduce overhead sd.getListeners().remove(listenerIdx); Map grad = new HashMap<>(); for(SDVariable v : sd.variables()){ - if (fnOutputs.contains(v.getVarName())) { + if (fnOutputs.contains(v.name())) { //This is not an input to the graph continue; } @@ -179,20 +179,20 @@ public class GradCheckUtil { //Skip non-fp variables, or variables that don't impact loss function value continue; } - SDVariable g = sd.grad(v.getVarName()); + SDVariable g = sd.grad(v.name()); if(g == null){ - throw new IllegalStateException("Null gradient variable for \"" + v.getVarName() + "\""); + throw new IllegalStateException("Null gradient variable for \"" + v.name() + "\""); } - INDArray ga = g.getArr(); + INDArray ga = gm.get(v.name()); if(ga == null){ - throw new IllegalStateException("Null gradient array encountered for variable: " + v.getVarName()); + throw new IllegalStateException("Null gradient array encountered for variable: " + v.name()); } - if(!Arrays.equals(v.getArr().shape(), g.getArr().shape())){ + if(!Arrays.equals(v.getArr().shape(), ga.shape())){ throw new IllegalStateException("Gradient shape does not match variable shape for variable \"" + - v.getVarName() + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " + + v.name() + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " + Arrays.toString(ga.shape())); } - grad.put(v.getVarName(), ga.dup()); + grad.put(v.name(), ga.dup()); } //Validate gradients for each variable: @@ -201,25 +201,25 @@ public class GradCheckUtil { double maxError = 0.0; Random r = new Random(12345); for(SDVariable s : sd.variables()){ - if (fnOutputs.contains(s.getVarName()) || !s.dataType().isFPType()) { + if (fnOutputs.contains(s.name()) || !s.dataType().isFPType()) { //This is not an input to the graph, or is not a floating point input (so can't be gradient checked) continue; } - if(skipVariables != null && skipVariables.contains(s.getVarName())){ - log.info("Grad check: skipping variable \"{}\"", s.getVarName()); + if(skipVariables != null && skipVariables.contains(s.name())){ + log.info("Grad check: skipping variable \"{}\"", s.name()); continue; } if(s.dataType() != DataType.DOUBLE){ - log.warn("DataType for variable {} is not double (is: {}) may cause precision issues in gradient checks", s.getVarName(), s.dataType()); + log.warn("DataType for variable {} is not double (is: {}) may cause precision issues in gradient checks", s.name(), s.dataType()); } - String name = s.getVarName(); + String name = s.name(); INDArray a = s.getArr(); long n = a.length(); if(print){ - log.info("Starting test for variable \"{}\" with {} values", s.getVarName(), n); + log.info("Starting test for variable \"{}\" with {} values", s.name(), n); } Iterator iter; @@ -256,11 +256,11 @@ public class GradCheckUtil { iter = new NdIndexIterator('c',a.shape()); } - INDArray varMask = (gradCheckMask == null ? null : gradCheckMask.get(s.getVarName())); + INDArray varMask = (gradCheckMask == null ? null : gradCheckMask.get(s.name())); if(varMask != null){ - Preconditions.checkState(a.equalShapes(varMask), "Variable \"%s\": Gradient check mask and array shapes must be equal: got %s vs. mask shape %s", s.getVarName(), a.shape(), varMask.shape()); - Preconditions.checkState(varMask.dataType() == DataType.BOOL, "Variable \"%s\": Gradient check mask must be BOOLEAN datatype, got %s", s.getVarName(), varMask.dataType()); + Preconditions.checkState(a.equalShapes(varMask), "Variable \"%s\": Gradient check mask and array shapes must be equal: got %s vs. mask shape %s", s.name(), a.shape(), varMask.shape()); + Preconditions.checkState(varMask.dataType() == DataType.BOOL, "Variable \"%s\": Gradient check mask must be BOOLEAN datatype, got %s", s.name(), varMask.dataType()); } int i=0; @@ -281,12 +281,12 @@ public class GradCheckUtil { double orig = a.getDouble(idx); a.putScalar(idx, orig+eps); double scorePlus = 0.0; - Map m = sd.exec(placeholderValues, lossFnVariables);//.get(outName).sumNumber().doubleValue(); + Map m = sd.output(placeholderValues, lossFnVariables);//.get(outName).sumNumber().doubleValue(); for(INDArray arr : m.values()){ scorePlus += arr.sumNumber().doubleValue(); } a.putScalar(idx, orig-eps); - m = sd.exec(placeholderValues, lossFnVariables); + m = sd.output(placeholderValues, lossFnVariables); double scoreMinus = 0.0; for(INDArray arr : m.values()){ scoreMinus += arr.sumNumber().doubleValue(); @@ -294,9 +294,9 @@ public class GradCheckUtil { a.putScalar(idx, orig); double numericalGrad = (scorePlus - scoreMinus) / (2 * eps); - INDArray aGrad = grad.get(s.getVarName()); + INDArray aGrad = grad.get(s.name()); if(aGrad == null){ - log.warn("No gradient array for variable \"{}\" was found, skipping variable...", s.getVarName()); + log.warn("No gradient array for variable \"{}\" was found, skipping variable...", s.name()); continue; } double analyticGrad = aGrad.getDouble(idx); @@ -408,18 +408,18 @@ public class GradCheckUtil { //Collect names of variables to get gradients for - i.e., the names of the GRADIENT variables for the specified activations sd.createGradFunction(); - Set gradVarNames = new HashSet<>(); + Set varsRequiringGrads = new HashSet<>(); for(String s : actGrads){ SDVariable grad = sd.getVariable(s).gradient(); Preconditions.checkState( grad != null,"Could not get gradient for activation \"%s\": gradient variable is null", s); - gradVarNames.add(grad.getVarName()); + varsRequiringGrads.add(s); } //Calculate analytical gradients - sd.execBackwards(config.getPlaceholderValues(), new ArrayList<>(gradVarNames)); + Map grads = sd.calculateGradients(config.getPlaceholderValues(), new ArrayList<>(varsRequiringGrads)); Map gradientsForAct = new HashMap<>(); for(String s : actGrads){ - INDArray arr = sd.getVariable(s).gradient().getArr(); + INDArray arr = grads.get(s); Preconditions.checkState(arr != null, "No activation gradient array for variable \"%s\"", s); gradientsForAct.put(s, arr.dup()); } @@ -497,12 +497,12 @@ public class GradCheckUtil { listener.setIdx(idx); listener.setEps(config.getEps()); double scorePlus = 0.0; - Map m = sd.exec(config.getPlaceholderValues(), lossFnVariables); + Map m = sd.output(config.getPlaceholderValues(), lossFnVariables); for(INDArray arr : m.values()){ scorePlus += arr.sumNumber().doubleValue(); } listener.setEps(-config.getEps()); - m = sd.exec(config.getPlaceholderValues(), lossFnVariables); + m = sd.output(config.getPlaceholderValues(), lossFnVariables); double scoreMinus = 0.0; for(INDArray arr : m.values()){ scoreMinus += arr.sumNumber().doubleValue(); @@ -597,10 +597,10 @@ public class GradCheckUtil { Set varSetStr = new HashSet<>(); for(SDVariable v : vars){ - if(varSetStr.contains(v.getVarName())){ - throw new IllegalStateException("Variable with name " + v.getVarName() + " already encountered"); + if(varSetStr.contains(v.name())){ + throw new IllegalStateException("Variable with name " + v.name() + " already encountered"); } - varSetStr.add(v.getVarName()); + varSetStr.add(v.name()); } Preconditions.checkState(vars.size() == varSetStr.size(), "Duplicate variables in variables() list"); @@ -645,7 +645,7 @@ public class GradCheckUtil { Map variableMap = sd.getVariables(); Preconditions.checkState(vars.size() == variableMap.size(), "Variable map size check failed"); for(Map.Entry e : variableMap.entrySet()){ - Preconditions.checkState(e.getKey().equals(e.getValue().getVariable().getVarName()), "Name not equal"); + Preconditions.checkState(e.getKey().equals(e.getValue().getVariable().name()), "Name not equal"); } if(generateAndCheckGradFn) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index 74c1d868d..fc7572180 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -190,11 +190,13 @@ public class OpValidation { //Check forward pass: if (testCase.fwdTestFns() != null && testCase.fwdTestFns().size() > 0) { SameDiff sd = testCase.sameDiff(); + + //Collect variables we need outputs for... + Set reqVars = testCase.fwdTestFns().keySet(); + + Map out; try { - if(testCase.placeholderValues() != null){ - sd.resolveVariablesWith(testCase.placeholderValues()); - } - sd.exec(null, sd.outputs()); + out = sd.output(testCase.placeholderValues(), new ArrayList<>(reqVars)); } catch (Exception e) { throw new RuntimeException("Error during forward pass testing" + testCase.testNameErrMsg(), e); } @@ -206,7 +208,7 @@ public class OpValidation { e.getKey() + "\" but SameDiff instance does not have a variable for this name" + testCase.testNameErrMsg()); } - INDArray actual = v.getArr(); + INDArray actual = out.get(v.name()); if (actual == null) { throw new IllegalStateException("Null INDArray after forward pass for variable \"" + e.getKey() + "\""); } @@ -269,8 +271,8 @@ public class OpValidation { for( int i=0; i varsBefore = original.getVariables(); Map varsAfter = deserialized.getVariables(); Preconditions.checkState(varsBefore.keySet().equals(varsAfter.keySet()), "Variable keysets do not match: %s vs %s", varsBefore.keySet(), varsAfter.keySet()); + +// System.out.println(original.summary()); +// System.out.println("\n\n\n\n"); +// System.out.println(deserialized.summary()); + for(String s : varsBefore.keySet()){ Variable vB = varsBefore.get(s); Variable vA = varsAfter.get(s); @@ -324,13 +337,15 @@ public class OpValidation { Preconditions.checkState(vB.getVariable().getVariableType() == vA.getVariable().getVariableType(), "Variable types do not match: %s - %s vs %s", s, vB.getVariable().getVariableType(), vA.getVariable().getVariableType()); - Preconditions.checkState((vB.getInputsForOp() == null) == (vA.getInputsForOp() == null), "Input to ops differ: %s vs. %s", vB.getInputsForOp(), vA.getInputsForOp()); - Preconditions.checkState(vB.getInputsForOp() == null || vB.getInputsForOp().equals(vA.getInputsForOp()), "Inputs differ: %s vs. %s", vB.getInputsForOp(), vA.getInputsForOp()); + equalConsideringNull(vB.getInputsForOp(), vA.getInputsForOp(), "%s - Input to ops differ: %s vs. %s", s, vB.getInputsForOp(), vA.getInputsForOp()); - Preconditions.checkState((vB.getOutputOfOp() == null && vA.getOutputOfOp() == null) || vB.getOutputOfOp().equals(vA.getOutputOfOp()), "Output of op differ: %s vs. %s", vB.getOutputOfOp(), vA.getOutputOfOp()); + Preconditions.checkState((vB.getOutputOfOp() == null && vA.getOutputOfOp() == null) || vB.getOutputOfOp().equals(vA.getOutputOfOp()), "%s - Output of op differ: %s vs. %s", s, vB.getOutputOfOp(), vA.getOutputOfOp()); - Preconditions.checkState((vB.getControlDeps() == null) == (vA.getControlDeps() == null), "Control dependencies differ: %s vs. %s", vB.getControlDeps(), vA.getControlDeps()); - Preconditions.checkState(vB.getControlDeps() == null || vB.getControlDeps().equals(vA.getControlDeps()), "Control dependencies differ: %s vs. %s", vB.getControlDeps(), vA.getControlDeps()); + equalConsideringNull(vB.getControlDeps(), vA.getControlDeps(), "%s - Control dependencies differ: %s vs. %s", s, vB.getControlDeps(), vA.getControlDeps()); + + equalConsideringNull(vB.getControlDepsForOp(), vA.getControlDepsForOp(), "%s - Control dependencies for ops differ: %s vs. %s", s, vB.getControlDepsForOp(), vA.getControlDepsForOp()); + + equalConsideringNull(vB.getControlDepsForVar(), vA.getControlDepsForVar(), "%s - Control dependencies for vars differ: %s vs. %s", s, vB.getControlDepsForVar(), vA.getControlDepsForVar()); } //Check loss variables: @@ -343,51 +358,62 @@ public class OpValidation { lossVarBefore, lossVarAfter); } + if(tc.fwdTestFns() != null && !tc.fwdTestFns().isEmpty()) { + //Finally: check execution/output + Map outOrig = original.outputAll(tc.placeholderValues()); + Map outDe = deserialized.outputAll(tc.placeholderValues()); + Preconditions.checkState(outOrig.keySet().equals(outDe.keySet()), "Keysets for execution after deserialization does not match key set for original model"); - //Finally: check execution/output - Map outOrig = original.outputAll(tc.placeholderValues()); - Map outDe = deserialized.outputAll(tc.placeholderValues()); - Preconditions.checkState(outOrig.keySet().equals(outDe.keySet()), "Keysets for execution after deserialization does not match key set for original model"); + for (String s : outOrig.keySet()) { + INDArray orig = outOrig.get(s); + INDArray deser = outDe.get(s); - for(String s : outOrig.keySet()){ - INDArray orig = outOrig.get(s); - INDArray deser = outDe.get(s); - - Function f = tc.fwdTestFns().get(s); - String err = null; - if(f != null){ - err = f.apply(deser); - } else { - if(!orig.equals(deser)){ - //Edge case: check for NaNs in original and deserialized... might be legitimate test (like replaceNaNs op) - long count = orig.dataType().isNumerical() ? Nd4j.getExecutioner().execAndReturn(new MatchCondition(orig, Conditions.isNan())).getFinalResult().longValue() : -1; - if(orig.dataType().isNumerical() && count > 0 && orig.equalShapes(deser)){ - long count2 = Nd4j.getExecutioner().execAndReturn(new MatchCondition(deser, Conditions.isNan())).getFinalResult().longValue(); - if(count != count2){ - err = "INDArray equality failed"; - } else { - //TODO is there a better way to do this? - NdIndexIterator iter = new NdIndexIterator(orig.shape()); - while(iter.hasNext()){ - long[] i = iter.next(); - double d1 = orig.getDouble(i); - double d2 = deser.getDouble(i); - if((Double.isNaN(d1) != Double.isNaN(d2)) || (Double.isInfinite(d1) != Double.isInfinite(d2)) || Math.abs(d1 - d2) > 1e-5 ){ - err = "INDArray equality failed"; - break; + Function f = tc.fwdTestFns().get(s); + String err = null; + if (f != null) { + err = f.apply(deser); + } else { + if (!orig.equals(deser)) { + //Edge case: check for NaNs in original and deserialized... might be legitimate test (like replaceNaNs op) + long count = orig.dataType().isNumerical() ? Nd4j.getExecutioner().execAndReturn(new MatchCondition(orig, Conditions.isNan())).getFinalResult().longValue() : -1; + if (orig.dataType().isNumerical() && count > 0 && orig.equalShapes(deser)) { + long count2 = Nd4j.getExecutioner().execAndReturn(new MatchCondition(deser, Conditions.isNan())).getFinalResult().longValue(); + if (count != count2) { + err = "INDArray equality failed"; + } else { + //TODO is there a better way to do this? + NdIndexIterator iter = new NdIndexIterator(orig.shape()); + while (iter.hasNext()) { + long[] i = iter.next(); + double d1 = orig.getDouble(i); + double d2 = deser.getDouble(i); + if ((Double.isNaN(d1) != Double.isNaN(d2)) || (Double.isInfinite(d1) != Double.isInfinite(d2)) || Math.abs(d1 - d2) > 1e-5) { + err = "INDArray equality failed"; + break; + } } } + } else { + err = "INDArray equality failed"; } - } else { - err = "INDArray equality failed"; } } - } - Preconditions.checkState(err == null, "Variable result (%s) failed check - \"%ndSInfo\" vs \"%ndSInfo\" - %nd10 vs %nd10\nError:%s", s, orig, deser, orig, deser, err); + Preconditions.checkState(err == null, "Variable result (%s) failed check - \"%ndSInfo\" vs \"%ndSInfo\" - %nd10 vs %nd10\nError:%s", s, orig, deser, orig, deser, err); + } } } + protected static void equalConsideringNull(List l1, List l2, String msg, Object... args){ + //Consider null and length 0 list to be equal (semantically they mean the same thing) + boolean empty1 = l1 == null || l1.isEmpty(); + boolean empty2 = l2 == null || l2.isEmpty(); + if(empty1 && empty2){ + return; + } + Preconditions.checkState(l1 == null || l1.equals(l2), msg, args); + } + /** * Validate the outputs of a single op * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/TestCase.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/TestCase.java index b10ed2bb3..fad760bb3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/TestCase.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/TestCase.java @@ -121,7 +121,7 @@ public class TestCase { * @param output Expected INDArray */ public TestCase expected(@NonNull SDVariable var, @NonNull INDArray output) { - return expected(var.getVarName(), output); + return expected(var.name(), output); } /** @@ -135,7 +135,7 @@ public class TestCase { } public TestCase expected(SDVariable var, Function validationFn){ - return expected(var.getVarName(), validationFn); + return expected(var.name(), validationFn); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java index f5ae0693d..7e7a50ab2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java @@ -25,6 +25,7 @@ public class NonInplaceValidationListener extends BaseListener { private static AtomicInteger failCounter = new AtomicInteger(); protected INDArray[] opInputs; + protected INDArray[] opInputsOrig; public NonInplaceValidationListener(){ useCounter.getAndIncrement(); @@ -42,14 +43,18 @@ public class NonInplaceValidationListener extends BaseListener { //No input op return; } else if(o.y() == null){ + opInputsOrig = new INDArray[]{o.x()}; opInputs = new INDArray[]{o.x().dup()}; } else { + opInputsOrig = new INDArray[]{o.x(), o.y()}; opInputs = new INDArray[]{o.x().dup(), o.y().dup()}; } } else if(op.getOp() instanceof DynamicCustomOp){ INDArray[] arr = ((DynamicCustomOp) op.getOp()).inputArguments(); opInputs = new INDArray[arr.length]; + opInputsOrig = new INDArray[arr.length]; for( int i=0; i return -1; } - // FIXME: int cast return (int) rDiagBinTotalCount.size(1); } @@ -394,7 +393,6 @@ public class EvaluationCalibration extends BaseEvaluation double[] mpb = meanPredictionBins; double[] fp = fracPositives; - // FIXME: int cast meanPredictionBins = new double[(int) (totalCountBins.length() - numZeroBins)]; fracPositives = new double[meanPredictionBins.length]; int j = 0; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCBinary.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCBinary.java index 3695a692b..a4d54e3c7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCBinary.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCBinary.java @@ -154,7 +154,6 @@ public class ROCBinary extends BaseEvaluation { if(labels2d.dataType() != predictions2d.dataType()) labels2d = labels2d.castTo(predictions2d.dataType()); - // FIXME: int cast int n = (int) labels2d.size(1); if (underlying == null) { underlying = new ROC[n]; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCMultiClass.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCMultiClass.java index 07266399a..a2a1ed16e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCMultiClass.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCMultiClass.java @@ -208,7 +208,6 @@ public class ROCMultiClass extends BaseEvaluation { if(labels2d.dataType() != predictions2d.dataType()) labels2d = labels2d.castTo(predictions2d.dataType()); - // FIXME: int cast int n = (int) labels2d.size(1); if (underlying == null) { underlying = new ROC[n]; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java index e1a0d1f82..cc206f0df 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java @@ -216,7 +216,6 @@ public class RegressionEvaluation extends BaseEvaluation { } private static List createDefaultColumnNames(long nColumns) { - // FIXME: int cast List list = new ArrayList<>((int) nColumns); for (int i = 0; i < nColumns; i++) list.add("col_" + i); @@ -244,7 +243,6 @@ public class RegressionEvaluation extends BaseEvaluation { labels = labels.castTo(predictions.dataType()); if (!initialized) { - // FIXME: int cast initialize((int) labels.size(1)); } //References for the calculations is this section: @@ -394,7 +392,6 @@ public class RegressionEvaluation extends BaseEvaluation { if (exampleCountPerColumn == null) { return 0; } - // FIXME: int cast return (int) exampleCountPerColumn.size(1); } return columnNames.size(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatNode.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatNode.java index 48b7a69bc..ef116b97b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatNode.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatNode.java @@ -67,29 +67,41 @@ public final class FlatNode extends Table { public ByteBuffer outputTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 38, 1); } public FlatArray scalar() { return scalar(new FlatArray()); } public FlatArray scalar(FlatArray obj) { int o = __offset(40); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; } + public String controlDeps(int j) { int o = __offset(42); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int controlDepsLength() { int o = __offset(42); return o != 0 ? __vector_len(o) : 0; } + public String varControlDeps(int j) { int o = __offset(44); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; } + public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; } public static int createFlatNode(FlatBufferBuilder builder, - int id, - int nameOffset, - byte opType, - long opNum, - int propertiesOffset, - int inputOffset, - int inputPairedOffset, - int outputOffset, - int extraParamsOffset, - int extraIntegerOffset, - int extraBoolsOffset, - int dimensionsOffset, - int device, - int scope_id, - int scope_nameOffset, - int outputNamesOffset, - int opNameOffset, - int outputTypesOffset, - int scalarOffset) { - builder.startObject(19); + int id, + int nameOffset, + byte opType, + long opNum, + int propertiesOffset, + int inputOffset, + int inputPairedOffset, + int outputOffset, + int extraParamsOffset, + int extraIntegerOffset, + int extraBoolsOffset, + int dimensionsOffset, + int device, + int scope_id, + int scope_nameOffset, + int outputNamesOffset, + int opNameOffset, + int outputTypesOffset, + int scalarOffset, + int controlDepsOffset, + int varControlDepsOffset, + int controlDepForOffset) { + builder.startObject(22); FlatNode.addOpNum(builder, opNum); + FlatNode.addControlDepFor(builder, controlDepForOffset); + FlatNode.addVarControlDeps(builder, varControlDepsOffset); + FlatNode.addControlDeps(builder, controlDepsOffset); FlatNode.addScalar(builder, scalarOffset); FlatNode.addOutputTypes(builder, outputTypesOffset); FlatNode.addOpName(builder, opNameOffset); @@ -111,7 +123,7 @@ public final class FlatNode extends Table { return FlatNode.endFlatNode(builder); } - public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(19); } + public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); } public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); } public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); } public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); } @@ -151,6 +163,15 @@ public final class FlatNode extends Table { public static int createOutputTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); } public static void startOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); } public static void addScalar(FlatBufferBuilder builder, int scalarOffset) { builder.addOffset(18, scalarOffset, 0); } + public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(19, controlDepsOffset, 0); } + public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addVarControlDeps(FlatBufferBuilder builder, int varControlDepsOffset) { builder.addOffset(20, varControlDepsOffset, 0); } + public static int createVarControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startVarControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); } + public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } public static int endFlatNode(FlatBufferBuilder builder) { int o = builder.endObject(); return o; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatVariable.java index 76335c1ae..4845f7320 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatVariable.java @@ -29,16 +29,28 @@ public final class FlatVariable extends Table { public FlatArray ndarray(FlatArray obj) { int o = __offset(12); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; } public int device() { int o = __offset(14); return o != 0 ? bb.getInt(o + bb_pos) : 0; } public byte variabletype() { int o = __offset(16); return o != 0 ? bb.get(o + bb_pos) : 0; } + public String controlDeps(int j) { int o = __offset(18); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int controlDepsLength() { int o = __offset(18); return o != 0 ? __vector_len(o) : 0; } + public String controlDepForOp(int j) { int o = __offset(20); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int controlDepForOpLength() { int o = __offset(20); return o != 0 ? __vector_len(o) : 0; } + public String controlDepsForVar(int j) { int o = __offset(22); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int controlDepsForVarLength() { int o = __offset(22); return o != 0 ? __vector_len(o) : 0; } public static int createFlatVariable(FlatBufferBuilder builder, - int idOffset, - int nameOffset, - byte dtype, - int shapeOffset, - int ndarrayOffset, - int device, - byte variabletype) { - builder.startObject(7); + int idOffset, + int nameOffset, + byte dtype, + int shapeOffset, + int ndarrayOffset, + int device, + byte variabletype, + int controlDepsOffset, + int controlDepForOpOffset, + int controlDepsForVarOffset) { + builder.startObject(10); + FlatVariable.addControlDepsForVar(builder, controlDepsForVarOffset); + FlatVariable.addControlDepForOp(builder, controlDepForOpOffset); + FlatVariable.addControlDeps(builder, controlDepsOffset); FlatVariable.addDevice(builder, device); FlatVariable.addNdarray(builder, ndarrayOffset); FlatVariable.addShape(builder, shapeOffset); @@ -49,7 +61,7 @@ public final class FlatVariable extends Table { return FlatVariable.endFlatVariable(builder); } - public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(7); } + public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(10); } public static void addId(FlatBufferBuilder builder, int idOffset) { builder.addOffset(0, idOffset, 0); } public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); } public static void addDtype(FlatBufferBuilder builder, byte dtype) { builder.addByte(2, dtype, 0); } @@ -59,6 +71,15 @@ public final class FlatVariable extends Table { public static void addNdarray(FlatBufferBuilder builder, int ndarrayOffset) { builder.addOffset(4, ndarrayOffset, 0); } public static void addDevice(FlatBufferBuilder builder, int device) { builder.addInt(5, device, 0); } public static void addVariabletype(FlatBufferBuilder builder, byte variabletype) { builder.addByte(6, variabletype, 0); } + public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(7, controlDepsOffset, 0); } + public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addControlDepForOp(FlatBufferBuilder builder, int controlDepForOpOffset) { builder.addOffset(8, controlDepForOpOffset, 0); } + public static int createControlDepForOpVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepForOpVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addControlDepsForVar(FlatBufferBuilder builder, int controlDepsForVarOffset) { builder.addOffset(9, controlDepsForVarOffset, 0); } + public static int createControlDepsForVarVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startControlDepsForVarVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } public static int endFlatVariable(FlatBufferBuilder builder) { int o = builder.endObject(); return o; @@ -67,3 +88,4 @@ public final class FlatVariable extends Table { public static void finishSizePrefixedFlatVariableBuffer(FlatBufferBuilder builder, int offset) { builder.finishSizePrefixed(offset); } } + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/ui/LogFileWriter.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/ui/LogFileWriter.java index 0165a18f3..7bba2b765 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/ui/LogFileWriter.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/ui/LogFileWriter.java @@ -487,11 +487,15 @@ public class LogFileWriter { //Create outputs list: List outputs = sd.outputs(); - int[] outputListStrOffsets = new int[outputs.size()]; - for (int i = 0; i < outputListStrOffsets.length; i++) { - outputListStrOffsets[i] = fbb.createString(outputs.get(i)); + int outputsOffset = 0; + if(outputs != null && !outputs.isEmpty()) { + int[] outputListStrOffsets = new int[outputs.size()]; + for (int i = 0; i < outputListStrOffsets.length; i++) { + outputListStrOffsets[i] = fbb.createString(outputs.get(i)); + } + outputsOffset = UIGraphStructure.createInputsVector(fbb, outputListStrOffsets); } - int outputsOffset = UIGraphStructure.createInputsVector(fbb, outputListStrOffsets); + //Create variables list Map varMap = sd.getVariables(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java index 82bfdc843..05ac2495c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java @@ -25,11 +25,7 @@ import org.nd4j.imports.descriptors.onnx.OnnxDescriptorParser; import org.nd4j.imports.descriptors.onnx.OpDescriptor; import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser; import org.nd4j.linalg.api.ops.*; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.*; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; import org.nd4j.linalg.api.ops.impl.layers.convolution.*; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -370,6 +366,8 @@ public class DifferentialFunctionClassHolder { return Merge.class; case Switch.OP_NAME: return Switch.class; + case LoopCond.OP_NAME: + return LoopCond.class; case ExternalErrorsFunction.OP_NAME: return ExternalErrorsFunction.class; default: diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index fcf3fe630..f7bbc0620 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -46,6 +46,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class, org.nd4j.linalg.api.ops.custom.BarnesHutGains.class, org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize.class, + org.nd4j.linalg.api.ops.custom.KnnMinDistance.class, org.nd4j.linalg.api.ops.custom.SpTreeCell.class, org.nd4j.linalg.api.ops.custom.Flatten.class, org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd.class, @@ -69,13 +70,9 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThan.class, org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThanOrEqual.class, org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastNotEqual.class, - org.nd4j.linalg.api.ops.impl.controlflow.If.class, - org.nd4j.linalg.api.ops.impl.controlflow.IfDerivative.class, org.nd4j.linalg.api.ops.impl.controlflow.Select.class, org.nd4j.linalg.api.ops.impl.controlflow.Where.class, org.nd4j.linalg.api.ops.impl.controlflow.WhereNumpy.class, - org.nd4j.linalg.api.ops.impl.controlflow.While.class, - org.nd4j.linalg.api.ops.impl.controlflow.WhileDerivative.class, org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter.class, org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit.class, org.nd4j.linalg.api.ops.impl.controlflow.compat.LoopCond.class, @@ -326,7 +323,6 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.BinCount.class, org.nd4j.linalg.api.ops.impl.transforms.CheckNumerics.class, org.nd4j.linalg.api.ops.impl.transforms.Cholesky.class, - org.nd4j.linalg.api.ops.impl.transforms.Constant.class, org.nd4j.linalg.api.ops.impl.transforms.Histogram.class, org.nd4j.linalg.api.ops.impl.transforms.HistogramFixedWidth.class, org.nd4j.linalg.api.ops.impl.transforms.IdentityN.class, @@ -581,8 +577,14 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.random.impl.ProbablisticMerge.class, org.nd4j.linalg.api.ops.random.impl.Range.class, org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution.class, - org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class - + org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class, + org.nd4j.linalg.api.ops.custom.AdjustContrast.class, + org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class, + org.nd4j.linalg.api.ops.custom.BitCast.class, + org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class, + org.nd4j.linalg.api.ops.custom.DivideNoNan.class, + org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes.class, + org.nd4j.linalg.api.ops.custom.FakeQuantWithMinMaxVarsPerChannel.class ); static { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/BaseGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/BaseGraphMapper.java deleted file mode 100644 index 95f238973..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/BaseGraphMapper.java +++ /dev/null @@ -1,413 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.imports.graphmapper; - -import org.nd4j.linalg.util.ArrayUtil; -import org.nd4j.shade.protobuf.Message; -import org.nd4j.shade.protobuf.TextFormat; -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.apache.commons.io.IOUtils; -import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.samediff.VariableType; -import org.nd4j.autodiff.samediff.internal.SameDiffOp; -import org.nd4j.autodiff.samediff.internal.Variable; -import org.nd4j.base.Preconditions; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.Op; -import org.nd4j.linalg.exception.ND4JIllegalStateException; - -import java.io.*; -import java.util.*; - -/** - * Base implementation for importing a graph - * - * @param the type of graph - * @param the type of node - * @param the attribute type - */ -@Slf4j -public abstract class BaseGraphMapper implements GraphMapper { - - - @Override - public Op.Type opTypeForNode(NODE_TYPE nodeDef) { - DifferentialFunction opWithTensorflowName = getMappedOp(getOpType(nodeDef)); - if (opWithTensorflowName == null) - throw new NoOpNameFoundException("No op found with name " + getOpType(nodeDef)); - Op.Type type = opWithTensorflowName.opType(); - return type; - - } - - - @Override - public void mapProperties(DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map> propertyMappings) { - val mappings = propertyMappings.get(getOpType(node)); - if (mappings == null || mappings.isEmpty()) { - return; - } - - - for (val entry : mappings.entrySet()) { - mapProperty(entry.getKey(), on, node, graph, sameDiff, propertyMappings); - } - } - - - /** - * @param inputStream - * @return - */ - @Override - public SameDiff importGraph(InputStream inputStream) { - return importGraph(inputStream, Collections.>emptyMap(), null); - } - - @Override - public SameDiff importGraph(InputStream inputStream, Map> opImportOverrides, - OpImportFilter opFilter) { - GRAPH_TYPE def = readGraph(inputStream, opImportOverrides); - return importGraph(def, opImportOverrides, opFilter); - } - - protected GRAPH_TYPE readGraph(InputStream inputStream, Map> opImportOverrides) { - byte[] bytes = null; - GRAPH_TYPE def = null; - try { - bytes = IOUtils.toByteArray(inputStream); //Buffers internally - def = parseGraphFrom(bytes); - } catch (IOException e) { - try (BufferedInputStream bis2 = new BufferedInputStream(new ByteArrayInputStream(bytes)); BufferedReader reader = new BufferedReader(new InputStreamReader(bis2))) { - Message.Builder builder = getNewGraphBuilder(); - - StringBuilder str = new StringBuilder(); - String line = null; - while ((line = reader.readLine()) != null) { - str.append(line);//.append("\n"); - } - - TextFormat.getParser().merge(str.toString(), builder); - def = (GRAPH_TYPE) builder.build(); - } catch (Exception e2) { - e2.printStackTrace(); - } - } - - return def; - } - - /** - * @param graphFile - * @return - */ - @Override - public SameDiff importGraph(File graphFile) { - return importGraph(graphFile, Collections.>emptyMap(), null); - } - - @Override - public SameDiff importGraph(File graphFile, Map> opImportOverrides, - OpImportFilter opFilter) { - GRAPH_TYPE def = null; - try (FileInputStream fis = new FileInputStream(graphFile)) { - return importGraph(fis, opImportOverrides, opFilter); - } catch (Exception e) { - throw new ND4JIllegalStateException("Error encountered loading graph file: " + graphFile.getAbsolutePath(), e); - } - } - - @Override - public Map nameIndexForGraph(GRAPH_TYPE graph) { - List nodes = getNodeList(graph); - Map ret = new HashMap<>(); - for (NODE_TYPE node : nodes) { - ret.put(getName(node), node); - } - return ret; - } - - @Override - public Map nodesByName(GRAPH_TYPE graph) { - val nodeTypes = getNodeList(graph); - Map ret = new LinkedHashMap<>(); - for (int i = 0; i < nodeTypes.size(); i++) { - ret.put(getName(nodeTypes.get(i)), nodeTypes.get(i)); - } - return ret; - } - - /** - * This method converts given TF - * - * @param tfGraph - * @return - */ - @Override - public SameDiff importGraph(GRAPH_TYPE tfGraph) { - return importGraph(tfGraph, Collections.>emptyMap(), null); - } - - @Override - public SameDiff importGraph(GRAPH_TYPE tfGraph, Map> opImportOverrides, - OpImportFilter opFilter) { - - SameDiff diff = SameDiff.create(); - ImportState importState = new ImportState<>(); - importState.setSameDiff(diff); - importState.setGraph(tfGraph); - - Map variablesForGraph = variablesForGraph(tfGraph); - importState.setVariables(variablesForGraph); - - - //Add each of the variables first - before importing ops - Map stringNodes = new HashMap<>(); //Key: name of string variable. Value: if it's a constant - for (Map.Entry entry : variablesForGraph.entrySet()) { - if (shouldSkip((NODE_TYPE) entry.getValue())) { //TODO only works for TF - //Skip some nodes, for example reduction indices (a lot of ND4J/SameDiff ops use int[] etc, not an INDArray/Variable) - continue; - } - - //First: check if we're skipping the op entirely. If so: don't create the output variables for it. - NODE_TYPE node = (NODE_TYPE) entry.getValue(); //TODO this only works for TF - String opType = getOpType(node); - String opName = getName(node); - if(opFilter != null && opFilter.skipOp(node, importState.getSameDiff(), null, importState.getGraph() )){ - log.info("Skipping variables for op: {} (name: {})", opType, opName); - continue; - } - - //Similarly, if an OpImportOverride is defined, don't create the variables now, as these might be the wrong type - //For example, the OpImportOverride might replace the op with some placeholders - // If we simply created them now, we'd create the wrong type (Array not placeholder) - if(opImportOverrides != null && opImportOverrides.containsKey(opType)){ - log.info("Skipping variables for op due to presence of OpImportOverride: {} (name: {})", opType, opName); - continue; - } - - - DataType dt = dataTypeForTensor(entry.getValue(), 0); - INDArray arr = getNDArrayFromTensor(entry.getKey(), entry.getValue(), tfGraph); - long[] shape = hasShape((NODE_TYPE) entry.getValue()) ? getShape((NODE_TYPE) entry.getValue()) : null; //TODO only works for TF - - //Not all variables have datatypes available on import - we have to infer these at a later point - // so we'll leave datatypes as null and infer them once all variables/ops have been imported - if(dt == DataType.UNKNOWN) - dt = null; - - if (isPlaceHolder(entry.getValue())) { - diff.placeHolder(entry.getKey(), dt, shape); - } else if (isConstant(entry.getValue())) { - Preconditions.checkNotNull(arr, "Array is null for placeholder variable %s", entry.getKey()); - diff.constant(entry.getKey(), arr); - } else { - //Could be variable, or could be array type (i.e., output of op/"activations") - //TODO work out which! - - SDVariable v; - if(shape == null || ArrayUtil.contains(shape, 0)){ - //No shape, or 0 in shape -> probably not a variable... - v = diff.var(entry.getKey(), VariableType.ARRAY, null, dt, (long[])null); - } else { - v = diff.var(entry.getKey(), dt, shape); - } - if (arr != null) - diff.associateArrayWithVariable(arr, v); - } - -// NODE_TYPE node = (NODE_TYPE) entry.getValue(); //TODO this only works for TF - List controlDependencies = getControlDependencies(node); - if (controlDependencies != null) { - Variable v = diff.getVariables().get(entry.getKey()); - v.setControlDeps(controlDependencies); - } - } - - //Map ops - val tfNodesList = getNodeList(tfGraph); - for (NODE_TYPE node : tfNodesList) { - String opType = getOpType(node); - OpImportOverride importOverride = null; - if(opImportOverrides != null){ - importOverride = opImportOverrides.get(opType); - } - - if(opFilter != null && opFilter.skipOp(node, importState.getSameDiff(), null, null)){ - String opName = getName(node); - log.info("Skipping op due to op filter: {}", opType, opName); - continue; - } - - if (!opsToIgnore().contains(opType) || isOpIgnoreException(node)) { - mapNodeType(node, importState, importOverride, opFilter); - } - } - - - /* - At this point, we have a few remaining things to do: - 1. Make sure all datatypes are set on all variables. TF doesn't have datatype info an all op outputs for some reason, so we have to infer in manually - 2. Make sure all op output variables have been created - 3. Make sure all SameDiffOp.outputsOfOp is set - 4. Make sure all Variable.outputOfOp is set - 5. Make sure all Variable.controlDepsForVar have been populated (reverse lookup of Variable.controlDeps) - */ - - //Make sure Variable.outputOfOp is set - for(Variable v : diff.getVariables().values()){ - if(v.getVariable().isPlaceHolder() || v.getVariable().isConstant()) - continue; - - //Expect variable names of output variables to be: opName, opName:1, opName:2, etc - String n = v.getName(); - String opName = n; - if(v.getName().matches(".*:\\d+")){ - //i.e., "something:2" - int idx = n.lastIndexOf(':'); - opName = n.substring(0,idx); - } - - if(diff.getOps().containsKey(opName)) { - //Variable is the output of an op - v.setOutputOfOp(opName); - - //Also double check variable type... - if(v.getVariable().getVariableType() != VariableType.ARRAY) - v.getVariable().setVariableType(VariableType.ARRAY); - } - } - - //Initialize any missing output variables - for (SameDiffOp op : diff.getOps().values()) { - DifferentialFunction df = op.getOp(); - initOutputVariables(diff, df); - } - - //Make sure all Variable.controlDepsForVar have been populated (reverse lookup of Variable.controlDeps) - //i.e., if control dependency x -> y exists, then: - // (a) x.controlDepsForVar should contain "y" - // (b) y.controlDeps should contain "x" - //Need to do this before output datatype calculation, as these control dep info is used in sessions - for(Map.Entry e : diff.getVariables().entrySet()){ - Variable v = e.getValue(); - if(v.getControlDeps() != null){ - for(String s : v.getControlDeps()){ - Variable v2 = diff.getVariables().get(s); - if(v2.getControlDepsForVar() == null) - v2.setControlDepsForVar(new ArrayList()); - if(!v2.getControlDepsForVar().contains(e.getKey())){ - //Control dep v2 -> v exists, so put v.name into v2.controlDepsForVar - v2.getControlDepsForVar().add(e.getKey()); - } - } - } - } - - //Same thing for op control dependencies... - for(Map.Entry e : diff.getOps().entrySet()){ - SameDiffOp op = e.getValue(); - if(op.getControlDeps() != null){ - for(String s : op.getControlDeps()){ - //Control dependency varS -> op exists - Variable v = diff.getVariables().get(s); - if(v.getControlDepsForOp() == null) - v.setControlDepsForOp(new ArrayList()); - if(!v.getControlDepsForOp().contains(e.getKey())) - v.getControlDepsForOp().add(e.getKey()); - } - } - } - - - //Infer variable datatypes to ensure all variables have datatypes... - boolean anyUnknown = false; - for(SDVariable v : diff.variables()){ - if(v.dataType() == null) - anyUnknown = true; - } - if(anyUnknown){ - Map dataTypes = diff.calculateOutputDataTypes(); - for(SDVariable v : diff.variables()){ - if(v.dataType() == null){ - v.setDataType(dataTypes.get(v.getVarName())); - } - } - } - - //Validate the graph structure - validateGraphStructure(diff); - - return diff; - } - - protected void initOutputVariables(SameDiff sd, DifferentialFunction df) { - String[] outNames = sd.getOutputsForOp(df); - SDVariable[] outVars; - if (outNames == null) { - outVars = sd.generateOutputVariableForOp(df, df.getOwnName() != null ? df.getOwnName() : df.opName(), true); - outNames = new String[outVars.length]; - for (int i = 0; i < outVars.length; i++) { - outNames[i] = outVars[i].getVarName(); - } - sd.getOps().get(df.getOwnName()).setOutputsOfOp(Arrays.asList(outNames)); - } - - for (String s : outNames) { - sd.getVariables().get(s).setOutputOfOp(df.getOwnName()); - } - } - - - @Override - public boolean validTensorDataType(TENSOR_TYPE tensorType) { - return dataTypeForTensor(tensorType, 0) != DataType.UNKNOWN; - } - - public void validateGraphStructure(SameDiff sameDiff) { - //First: Check placeholders. When SDVariables are added with null shapes, these can be interpreted as a placeholder - // but null shapes might simply mean shape isn't available during import right when the variable is added - //Idea here: if a "placeholder" is the output of any function, it's not really a placeholder - for (SDVariable v : sameDiff.variables()) { - String name = v.getVarName(); - if (sameDiff.isPlaceHolder(name)) { - String varOutputOf = sameDiff.getVariables().get(name).getOutputOfOp(); - } - } - - //Second: check that all op inputs actually exist in the graph - for (SameDiffOp op : sameDiff.getOps().values()) { - List inputs = op.getInputsToOp(); - if (inputs == null) - continue; - - for (String s : inputs) { - if (sameDiff.getVariable(s) == null) { - throw new IllegalStateException("Import validation failed: op \"" + op.getName() + "\" of type " + op.getOp().getClass().getSimpleName() - + " has input \"" + s + "\" that does not have a corresponding variable in the graph"); - } - } - } - } - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/GraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/GraphMapper.java deleted file mode 100644 index 2d89a2b07..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/GraphMapper.java +++ /dev/null @@ -1,429 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.imports.graphmapper; - -import org.nd4j.shade.protobuf.Message; -import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.Op; - -import java.io.File; -import java.io.IOException; -import java.io.InputStream; -import java.util.List; -import java.util.Map; -import java.util.Set; - -/** - * Map graph proto types to - * - * {@link SameDiff} instances - * @param the proto type for the graph - * @param the proto type for the node - * @param the proto type for the attribute - * @param the proto type for the tensor - *@author Adam Gibson - */ -public interface GraphMapper { - - /** - * Import a graph as SameDiff from the given file - * @param graphFile Input stream pointing to graph file to import - * @return Imported graph - */ - SameDiff importGraph(InputStream graphFile); - - SameDiff importGraph(InputStream graphFile, Map> opImportOverrides, - OpImportFilter opFilter); - - /** - * Import a graph as SameDiff from the given file - * @param graphFile Graph file to import - * @return Imported graph - * @see #importGraph(File, Map) - */ - SameDiff importGraph(File graphFile); - - /** - * Import a graph as SameDiff from the given file, with optional op import overrides.
- * The {@link OpImportOverride} instances allow the operation import to be overridden - useful for importing ops - * that have not been mapped for import in SameDiff yet, and also for non-standard/user-defined functions. - * - * @param graphFile Graph file to import - * @param opImportOverrides May be null. If non-null: used to import the specified operations. Key is the name of the - * operation to import, value is the object used to import it - * @return Imported graph - */ - SameDiff importGraph(File graphFile, Map> opImportOverrides, - OpImportFilter opFilter); - - /** - * This method converts given graph type (in its native format) to SameDiff - * @param graph Graph to import - * @return Imported graph - */ - SameDiff importGraph(GRAPH_TYPE graph); - - /** - * This method converts given graph type (in its native format) to SameDiff
- * The {@link OpImportOverride} instances allow the operation import to be overridden - useful for importing ops - * that have not been mapped for import in SameDiff yet, and also for non-standard/user-defined functions. - * @param graph Graph to import - * @return Imported graph - */ - SameDiff importGraph(GRAPH_TYPE graph, Map> opImportOverrides, - OpImportFilter opFilter); - - - /** - * Returns true if this node is a special case - * (maybe because of name or other scenarios) - * that should override {@link #opsToIgnore()} - * in certain circumstances - * @param node the node to check - * @return true if this node is an exception false otherwise - */ - boolean isOpIgnoreException(NODE_TYPE node); - - /** - * Get the nodes sorted by n ame - * from a given graph - * @param graph the graph to get the nodes for - * @return the map of the nodes by name - * for a given graph - */ - Map nodesByName(GRAPH_TYPE graph); - - /** - * Get the target mapping key (usually based on the node name) - * for the given function - * @param function the function - * @param node the node to derive the target mapping from - * @return - */ - String getTargetMappingForOp(DifferentialFunction function, NODE_TYPE node); - - - /** - * - * @param on - * @param node - * @param graph - * @param sameDiff - * @param propertyMappings - */ - void mapProperties(DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map> propertyMappings); - - - /** - * - * @param name - * @param on - * @param node - * @param graph - * @param sameDiff - * @param propertyMappingsForFunction - */ - void mapProperty(String name, DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map> propertyMappingsForFunction); - - /** - * Get the node from the graph - * @param graph the graph to get the node from - * @param name the name of the node to get from the graph - * @return - */ - NODE_TYPE getNodeWithNameFromGraph(GRAPH_TYPE graph,String name); - - /** - * Returns true if the given node is a place holder - * @param node the node to check - * @return true if the node is a place holder or not - */ - boolean isPlaceHolderNode(TENSOR_TYPE node); - - /** - * Get the list of control dependencies for the current node (or null if none exist) - * - * @param node Node to get the control dependencies (if any) for - * @return - */ - List getControlDependencies(NODE_TYPE node); - - /** - * Dump a binary proto file representation as a - * plain string in to the target text file - * @param inputFile - * @param outputFile - */ - void dumpBinaryProtoAsText(File inputFile,File outputFile); - - - /** - * Dump a binary proto file representation as a - * plain string in to the target text file - * @param inputFile - * @param outputFile - */ - void dumpBinaryProtoAsText(InputStream inputFile,File outputFile); - - - /** - * Get the mapped op name - * for a given op - * relative to the type of node being mapped. - * The input name should be based on a tensorflow - * type or onnx type, not the nd4j name - * @param name the tensorflow or onnx name - * @return the function based on the values in - * {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder} - */ - DifferentialFunction getMappedOp(String name); - - - /** - * Get the variables for the given graph - * @param graphType the graph to get the variables for - * @return a map of variable name to tensor - */ - Map variablesForGraph(GRAPH_TYPE graphType); - - /** - * - * @param name - * @param node - * @return - */ - String translateToSameDiffName(String name, NODE_TYPE node); - - - /** - * - * @param graph - * @return - */ - Map nameIndexForGraph(GRAPH_TYPE graph); - - /** - * Returns an op type for the given input node - * @param nodeType the node to use - * @return the optype for the given node - */ - Op.Type opTypeForNode(NODE_TYPE nodeType); - - /** - * Returns a graph builder for initial definition and parsing. - * @return - */ - Message.Builder getNewGraphBuilder(); - - /** - * Parse a graph from an input stream - * @param inputStream the input stream to load from - * @return - */ - GRAPH_TYPE parseGraphFrom(byte[] inputStream) throws IOException; - - /** - * Parse a graph from an input stream - * @param inputStream the input stream to load from - * @return - */ - GRAPH_TYPE parseGraphFrom(InputStream inputStream) throws IOException; - - - /** - * Map a node in to the import state covering the {@link SameDiff} instance - * @param tfNode the node to map - * @param importState the current import state - * @param opFilter Optional filter for skipping operations - */ - void mapNodeType(NODE_TYPE tfNode, ImportState importState, - OpImportOverride opImportOverride, - OpImportFilter opFilter); - - - /** - * - * @param tensorType - * @param outputNum - * @return - */ - DataType dataTypeForTensor(TENSOR_TYPE tensorType, int outputNum); - - boolean isStringType(TENSOR_TYPE tensor); - - /** - * - * @param nodeType - * @param key - * @return - */ - String getAttrValueFromNode(NODE_TYPE nodeType,String key); - - - /** - * - * @param attrType - * @return - */ - long[] getShapeFromAttribute(ATTR_TYPE attrType); - - /** - * Returns true if the given node is a place holder type - * (think a yet to be determined shape)_ - * @param nodeType - * @return - */ - boolean isPlaceHolder(TENSOR_TYPE nodeType); - - - /** - * Returns true if the given node is a constant - * @param nodeType - * @return - */ - boolean isConstant(TENSOR_TYPE nodeType); - - /** - * - * - * @param tensorName - * @param tensorType - * @param graph - * @return - */ - INDArray getNDArrayFromTensor(String tensorName, TENSOR_TYPE tensorType, GRAPH_TYPE graph); - - - /** - * Get the shape for the given tensor type - * @param tensorType - * @return - */ - long[] getShapeFromTensor(TENSOR_TYPE tensorType); - - - /** - * Ops to ignore for mapping - * @return - */ - Set opsToIgnore(); - - /** - * Get the input node for the given node - * @param node the node - * @param index hte index - * @return - */ - String getInputFromNode(NODE_TYPE node, int index); - - /** - * Get the number of inputs for a node. - * @param nodeType the node to get the number of inputs for - * @return - */ - int numInputsFor(NODE_TYPE nodeType); - - /** - * Whether the data type for the tensor is valid - * for creating an {@link INDArray} - * @param tensorType the tensor proto to test - * @return - */ - boolean validTensorDataType(TENSOR_TYPE tensorType); - - - /** - * Get the shape of the attribute value - * @param attr the attribute value - * @return the shape of the attribute if any or null - */ - long[] getShapeFromAttr(ATTR_TYPE attr); - - /** - * Get the attribute - * map for given node - * @param nodeType the node - * @return the attribute map for the attribute - */ - Map getAttrMap(NODE_TYPE nodeType); - - /** - * Get the name of the node - * @param nodeType the node - * to get the name for - * @return - */ - String getName(NODE_TYPE nodeType); - - /** - * - * @param nodeType - * @return - */ - boolean alreadySeen(NODE_TYPE nodeType); - - /** - * - * @param nodeType - * @return - */ - boolean isVariableNode(NODE_TYPE nodeType); - - /** - * - * - * @param opType - * @return - */ - boolean shouldSkip(NODE_TYPE opType); - - /** - * - * @param nodeType - * @return - */ - boolean hasShape(NODE_TYPE nodeType); - - /** - * - * @param nodeType - * @return - */ - long[] getShape(NODE_TYPE nodeType); - - /** - * - * @param nodeType - * @param graph - * @return - */ - INDArray getArrayFrom(NODE_TYPE nodeType, GRAPH_TYPE graph); - - - String getOpType(NODE_TYPE nodeType); - - /** - * - * @param graphType - * @return - */ - List getNodeList(GRAPH_TYPE graphType); -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.java deleted file mode 100644 index 7a651fb88..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.java +++ /dev/null @@ -1,652 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.imports.graphmapper.onnx; - -import org.nd4j.shade.protobuf.ByteString; -import org.nd4j.shade.protobuf.Message; -import org.nd4j.shade.guava.primitives.Floats; -import org.nd4j.shade.guava.primitives.Ints; -import org.nd4j.shade.guava.primitives.Longs; -import lombok.val; -import onnx.Onnx; -import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.imports.converters.DifferentialFunctionClassHolder; -import org.nd4j.imports.descriptors.properties.AttributeAdapter; -import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.imports.graphmapper.BaseGraphMapper; -import org.nd4j.imports.graphmapper.ImportState; -import org.nd4j.imports.graphmapper.OpImportFilter; -import org.nd4j.imports.graphmapper.OpImportOverride; -import org.nd4j.linalg.api.buffer.DataBuffer; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.exception.ND4JIllegalStateException; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.util.ArrayUtil; - -import java.io.*; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.util.*; - -/** - * A mapper for onnx graphs to - * {@link org.nd4j.autodiff.samediff.SameDiff} instances. - * - * @author Adam Gibson - */ -public class OnnxGraphMapper extends BaseGraphMapper { - private static OnnxGraphMapper INSTANCE = new OnnxGraphMapper(); - - - public static OnnxGraphMapper getInstance() { - return INSTANCE; - } - - - @Override - public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) { - try { - Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(inputFile); - BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true)); - for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) { - bufferedWriter.write(node.toString() + "\n"); - } - - bufferedWriter.flush(); - bufferedWriter.close(); - - } catch (IOException e) { - e.printStackTrace(); - } - } - - - - /** - * Init a function's attributes - * @param mappedTfName the onnx name to pick (sometimes ops have multiple names - * @param on the function to map - * @param attributesForNode the attributes for the node - * @param node - * @param graph - */ - public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map attributesForNode, Onnx.NodeProto node, Onnx.GraphProto graph) { - val properties = on.mappingsForFunction(); - val tfProperties = properties.get(mappedTfName); - val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on); - val attributeAdapters = on.attributeAdaptersForFunction(); - for(val entry : tfProperties.entrySet()) { - val tfAttrName = entry.getValue().getTfAttrName(); - val currentField = fields.get(entry.getKey()); - - AttributeAdapter adapter = null; - if(tfAttrName != null) { - if(currentField == null) { - continue; - } - if(attributeAdapters != null && !attributeAdapters.isEmpty()) { - val mappers = attributeAdapters.get(on.tensorflowName()); - val adapterFor = mappers.get(entry.getKey()); - adapter = adapterFor; - } - - - if(attributesForNode.containsKey(tfAttrName)) { - val attr = attributesForNode.get(tfAttrName); - switch (attr.getType()) { - case STRING: - val setString = attr.getS().toStringUtf8(); - if(adapter != null) { - adapter.mapAttributeFor(setString,currentField,on); - } - else - on.setValueFor(currentField,setString); - break; - case INT: - val setInt = (int) attr.getI(); - if(adapter != null) { - adapter.mapAttributeFor(setInt,currentField,on); - } - else - on.setValueFor(currentField,setInt); - break; - case INTS: - val setList = attr.getIntsList(); - if(!setList.isEmpty()) { - val intList = Ints.toArray(setList); - if(adapter != null) { - adapter.mapAttributeFor(intList,currentField,on); - } - else - on.setValueFor(currentField,intList); - } - break; - case FLOATS: - val floatsList = attr.getFloatsList(); - if(!floatsList.isEmpty()) { - val floats = Floats.toArray(floatsList); - if(adapter != null) { - adapter.mapAttributeFor(floats,currentField,on); - } - - else - on.setValueFor(currentField,floats); - break; - } - break; - case TENSOR: - val tensorToGet = mapTensorProto(attr.getT()); - if(adapter != null) { - adapter.mapAttributeFor(tensorToGet,currentField,on); - } - else - on.setValueFor(currentField,tensorToGet); - break; - - } - } - } - - - } - } - - @Override - public boolean isOpIgnoreException(Onnx.NodeProto node) { - return false; - } - - @Override - public String getTargetMappingForOp(DifferentialFunction function, Onnx.NodeProto node) { - return function.opName(); - } - - - @Override - public void mapProperty(String name, DifferentialFunction on, Onnx.NodeProto node, Onnx.GraphProto graph, SameDiff sameDiff, Map> propertyMappingsForFunction) { - val mapping = propertyMappingsForFunction.get(name).get(getTargetMappingForOp(on, node)); - val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on); - /** - * Map ints and the like. Need to figure out how attribute mapping should work. - * - * - */ - - val propsForFunction = on.propertiesForFunction(); - - if(mapping.getTfAttrName() == null) { - int tfMappingIdx = mapping.getTfInputPosition(); - if(tfMappingIdx < 0) - tfMappingIdx += node.getInputCount(); - - val input = node.getInput(tfMappingIdx); - val inputNode = getInstance().getNodeWithNameFromGraph(graph,input); - INDArray arr = sameDiff.getArrForVarName(input); - val field = fields.get(mapping.getPropertyNames()[0]); - val type = field.getType(); - if(type.equals(int[].class)) { - try { - field.set(arr.data().asInt(),on); - } catch (IllegalAccessException e) { - e.printStackTrace(); - } - } - else if(type.equals(int.class) || type.equals(long.class) || type.equals(Long.class) || type.equals(Integer.class)) { - try { - field.set(arr.getInt(0),on); - } catch (IllegalAccessException e) { - e.printStackTrace(); - } - } - else if(type.equals(float.class) || type.equals(double.class) || type.equals(Float.class) || type.equals(Double.class)) { - try { - field.set(arr.getDouble(0),on); - } catch (IllegalAccessException e) { - e.printStackTrace(); - } - } - - - - /** - * Figure out whether it's an int array - * or a double array, or maybe a scalar. - */ - - } - else { - val tfMappingAttrName = mapping.getOnnxAttrName(); - val attr = getAttrMap(node).get(tfMappingAttrName); - val type = attr.getType(); - val field = fields.get(mapping.getPropertyNames()[0]); - - Object valueToSet = null; - switch(type) { - case INT: - valueToSet = attr.getI(); - break; - case FLOAT: - valueToSet = attr.getF(); - break; - case STRING: - valueToSet = attr.getF(); - break; - - } - - try { - field.set(valueToSet,on); - } catch (IllegalAccessException e) { - e.printStackTrace(); - } - - } - } - - - @Override - public Onnx.NodeProto getNodeWithNameFromGraph(Onnx.GraphProto graph, String name) { - for(int i = 0; i < graph.getNodeCount(); i++) { - val node = graph.getNode(i); - if(node.getName().equals(name)) - return node; - } - - return null; - } - - @Override - public boolean isPlaceHolderNode(Onnx.TypeProto.Tensor node) { - return false; - } - - @Override - public List getControlDependencies(Onnx.NodeProto node) { - throw new UnsupportedOperationException("Not yet implemented"); - } - - @Override - public void dumpBinaryProtoAsText(File inputFile, File outputFile) { - try { - Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile))); - BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true)); - for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) { - bufferedWriter.write(node.toString()); - } - - bufferedWriter.flush(); - bufferedWriter.close(); - - } catch (IOException e) { - e.printStackTrace(); - } - } - - - - - /** - * - * @param name the tensorflow or onnx name - * @return - */ - @Override - public DifferentialFunction getMappedOp(String name) { - return DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(name); - } - - - - @Override - public Map variablesForGraph(Onnx.GraphProto graphProto) { - /** - * Need to figure out why - * gpu_0/conv1_1 isn't present in VGG - */ - Map ret = new HashMap<>(); - for(int i = 0; i < graphProto.getInputCount(); i++) { - ret.put(graphProto.getInput(i).getName(),graphProto.getInput(i).getType().getTensorType()); - } - - for(int i = 0; i < graphProto.getOutputCount(); i++) { - ret.put(graphProto.getOutput(i).getName(),graphProto.getOutput(i).getType().getTensorType()); - } - - for(int i = 0; i < graphProto.getNodeCount(); i++) { - val node = graphProto.getNode(i); - val name = node.getName().isEmpty() ? String.valueOf(i) : node.getName(); - //add -1 as place holder value representing the shape needs to be filled in - if(!ret.containsKey(name)) { - addDummyTensor(name,ret); - } - - for(int j = 0; j < node.getInputCount(); j++) { - if(!ret.containsKey(node.getInput(j))) { - addDummyTensor(node.getInput(j),ret); - } - } - - - for(int j = 0; j < node.getOutputCount(); j++) { - if(!ret.containsKey(node.getOutput(j))) { - addDummyTensor(node.getOutput(j),ret); - } - } - } - - return ret; - } - - @Override - public String translateToSameDiffName(String name, Onnx.NodeProto node) { - return null; - } - - - protected void addDummyTensor(String name, Map to) { - Onnx.TensorShapeProto.Dimension dim = Onnx.TensorShapeProto.Dimension. - newBuilder() - .setDimValue(-1) - .build(); - Onnx.TypeProto.Tensor typeProto = Onnx.TypeProto.Tensor.newBuilder() - .setShape( - Onnx.TensorShapeProto.newBuilder() - .addDim(dim) - .addDim(dim).build()) - .build(); - to.put(name,typeProto); - } - - @Override - public Message.Builder getNewGraphBuilder() { - return Onnx.GraphProto.newBuilder(); - } - - @Override - public Onnx.GraphProto parseGraphFrom(byte[] inputStream) throws IOException { - return Onnx.ModelProto.parseFrom(inputStream).getGraph(); - } - - @Override - public Onnx.GraphProto parseGraphFrom(InputStream inputStream) throws IOException { - return Onnx.ModelProto.parseFrom(inputStream).getGraph(); - } - - @Override - public void mapNodeType(Onnx.NodeProto tfNode, ImportState importState, - OpImportOverride opImportOverride, - OpImportFilter opFilter) { - val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(tfNode.getOpType()); - if(differentialFunction == null) { - throw new NoOpNameFoundException("No op name found " + tfNode.getOpType()); - } - - val diff = importState.getSameDiff(); - val idx = importState.getGraph().getNodeList().indexOf(tfNode); - val name = !tfNode.getName().isEmpty() ? tfNode.getName() : String.valueOf(idx); - try { - val newInstance = differentialFunction.getClass().newInstance(); - val args = new SDVariable[tfNode.getInputCount()]; - - newInstance.setSameDiff(importState.getSameDiff()); - - newInstance.initFromOnnx(tfNode,diff,getAttrMap(tfNode),importState.getGraph()); - importState.getSameDiff().putOpForId(newInstance.getOwnName(),newInstance); - //ensure we can track node name to function instance later. - diff.setBaseNameForFunctionInstanceId(tfNode.getName(),newInstance); - //diff.addVarNameForImport(tfNode.getName()); - } - catch (Exception e) { - e.printStackTrace(); - } - - - - } - - - - @Override - public DataType dataTypeForTensor(Onnx.TypeProto.Tensor tensorProto, int outputNum) { - return nd4jTypeFromOnnxType(tensorProto.getElemType()); - } - - @Override - public boolean isStringType(Onnx.TypeProto.Tensor tensor) { - return tensor.getElemType() == Onnx.TensorProto.DataType.STRING; - } - - - /** - * Convert an onnx type to the proper nd4j type - * @param dataType the data type to convert - * @return the nd4j type for the onnx type - */ - public DataType nd4jTypeFromOnnxType(Onnx.TensorProto.DataType dataType) { - switch (dataType) { - case DOUBLE: return DataType.DOUBLE; - case FLOAT: return DataType.FLOAT; - case FLOAT16: return DataType.HALF; - case INT32: - case INT64: return DataType.INT; - default: return DataType.UNKNOWN; - } - } - - @Override - public String getAttrValueFromNode(Onnx.NodeProto nodeProto, String key) { - for(Onnx.AttributeProto attributeProto : nodeProto.getAttributeList()) { - if(attributeProto.getName().equals(key)) { - return attributeProto.getS().toString(); - } - } - - throw new ND4JIllegalStateException("No key found for " + key); - } - - @Override - public long[] getShapeFromAttribute(Onnx.AttributeProto attributeProto) { - return Longs.toArray(attributeProto.getT().getDimsList()); - } - - @Override - public boolean isPlaceHolder(Onnx.TypeProto.Tensor nodeType) { - return false; - } - - @Override - public boolean isConstant(Onnx.TypeProto.Tensor nodeType) { - return false; - } - - - @Override - public INDArray getNDArrayFromTensor(String tensorName, Onnx.TypeProto.Tensor tensorProto, Onnx.GraphProto graph) { - DataType type = dataTypeForTensor(tensorProto, 0); - if(!tensorProto.isInitialized()) { - throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized"); - } - - Onnx.TensorProto tensor = null; - for(int i = 0; i < graph.getInitializerCount(); i++) { - val initializer = graph.getInitializer(i); - if(initializer.getName().equals(tensorName)) { - tensor = initializer; - break; - } - } - - if(tensor == null) - return null; - - ByteString bytes = tensor.getRawData(); - ByteBuffer byteBuffer = bytes.asReadOnlyByteBuffer().order(ByteOrder.nativeOrder()); - ByteBuffer directAlloc = ByteBuffer.allocateDirect(byteBuffer.capacity()).order(ByteOrder.nativeOrder()); - directAlloc.put(byteBuffer); - directAlloc.rewind(); - long[] shape = getShapeFromTensor(tensorProto); - DataBuffer buffer = Nd4j.createBuffer(directAlloc,type, ArrayUtil.prod(shape)); - INDArray arr = Nd4j.create(buffer).reshape(shape); - return arr; - } - - public INDArray mapTensorProto(Onnx.TensorProto tensor) { - if(tensor == null) - return null; - - - DataType type = nd4jTypeFromOnnxType(tensor.getDataType()); - - ByteString bytes = tensor.getRawData(); - ByteBuffer byteBuffer = bytes.asReadOnlyByteBuffer().order(ByteOrder.nativeOrder()); - ByteBuffer directAlloc = ByteBuffer.allocateDirect(byteBuffer.capacity()).order(ByteOrder.nativeOrder()); - directAlloc.put(byteBuffer); - directAlloc.rewind(); - long[] shape = getShapeFromTensor(tensor); - DataBuffer buffer = Nd4j.createBuffer(directAlloc,type, ArrayUtil.prod(shape)); - INDArray arr = Nd4j.create(buffer).reshape(shape); - return arr; - } - - @Override - public long[] getShapeFromTensor(onnx.Onnx.TypeProto.Tensor tensorProto) { - val ret = new long[Math.max(2,tensorProto.getShape().getDimCount())]; - int dimCount = tensorProto.getShape().getDimCount(); - if(dimCount >= 2) - for(int i = 0; i < ret.length; i++) { - ret[i] = (int) tensorProto.getShape().getDim(i).getDimValue(); - } - else { - ret[0] = 1; - for(int i = 1; i < ret.length; i++) { - ret[i] = (int) tensorProto.getShape().getDim(i - 1).getDimValue(); - } - } - - - return ret; - } - - - /** - * Get the shape from a tensor proto. - * Note that this is different from {@link #getShapeFromTensor(Onnx.TensorProto)} - * @param tensorProto the tensor to get the shape from - * @return - */ - public long[] getShapeFromTensor(Onnx.TensorProto tensorProto) { - val ret = new long[Math.max(2,tensorProto.getDimsCount())]; - int dimCount = tensorProto.getDimsCount(); - if(dimCount >= 2) - for(int i = 0; i < ret.length; i++) { - ret[i] = (int) tensorProto.getDims(i); - } - else { - ret[0] = 1; - for(int i = 1; i < ret.length; i++) { - ret[i] = (int) tensorProto.getDims(i - 1); - } - } - - - return ret; - } - - @Override - public Set opsToIgnore() { - return Collections.emptySet(); - } - - - @Override - public String getInputFromNode(Onnx.NodeProto node, int index) { - return node.getInput(index); - } - - @Override - public int numInputsFor(Onnx.NodeProto nodeProto) { - return nodeProto.getInputCount(); - } - - - @Override - public long[] getShapeFromAttr(Onnx.AttributeProto attr) { - return Longs.toArray(attr.getT().getDimsList()); - } - - @Override - public Map getAttrMap(Onnx.NodeProto nodeProto) { - Map proto = new HashMap<>(); - for(int i = 0; i < nodeProto.getAttributeCount(); i++) { - Onnx.AttributeProto attributeProto = nodeProto.getAttribute(i); - proto.put(attributeProto.getName(),attributeProto); - } - return proto; - } - - @Override - public String getName(Onnx.NodeProto nodeProto) { - return nodeProto.getName(); - } - - @Override - public boolean alreadySeen(Onnx.NodeProto nodeProto) { - return false; - } - - @Override - public boolean isVariableNode(Onnx.NodeProto nodeProto) { - return nodeProto.getOpType().contains("Var"); - } - - @Override - public boolean shouldSkip(Onnx.NodeProto opType) { - return false; - } - - @Override - public boolean hasShape(Onnx.NodeProto nodeProto) { - return false; - } - - @Override - public long[] getShape(Onnx.NodeProto nodeProto) { - return null; - } - - @Override - public INDArray getArrayFrom(Onnx.NodeProto nodeProto, Onnx.GraphProto graph) { - - return null; - } - - @Override - public String getOpType(Onnx.NodeProto nodeProto) { - return nodeProto.getOpType(); - } - - @Override - public List getNodeList(Onnx.GraphProto graphProto) { - return graphProto.getNodeList(); - } - - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java index 3ad3267c2..233609b19 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java @@ -16,11 +16,10 @@ package org.nd4j.imports.graphmapper.tf; -import org.nd4j.shade.protobuf.Message; -import org.nd4j.shade.guava.primitives.Floats; -import org.nd4j.shade.guava.primitives.Ints; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.apache.commons.io.IOUtils; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -31,661 +30,620 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.converters.DifferentialFunctionClassHolder; import org.nd4j.imports.descriptors.properties.AttributeAdapter; import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser; -import org.nd4j.imports.graphmapper.BaseGraphMapper; -import org.nd4j.imports.graphmapper.ImportState; -import org.nd4j.imports.graphmapper.OpImportFilter; -import org.nd4j.imports.graphmapper.OpImportOverride; import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMapper; import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMappers; +import org.nd4j.imports.tensorflow.TFImportOverride; +import org.nd4j.imports.tensorflow.TFOpImportFilter; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.controlflow.IfImportState; -import org.nd4j.linalg.exception.ND4JIllegalStateException; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge; +import org.nd4j.shade.guava.primitives.Floats; +import org.nd4j.shade.guava.primitives.Ints; +import org.nd4j.shade.protobuf.InvalidProtocolBufferException; +import org.nd4j.shade.protobuf.Message; +import org.nd4j.shade.protobuf.TextFormat; import org.tensorflow.framework.*; -import org.tensorflow.framework.DataType; import java.io.*; -import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.*; /** - * Map tensorflow graph protos - * to the intermediate representation - * for samediff. + * Import a TensorFlow frozen graph in ProtoBuf (.pb) format, to SameDiff * - * @author Adam Gibson + * @author Alex Black */ @Slf4j -public class TFGraphMapper extends BaseGraphMapper { - private Set seenNodes = new LinkedHashSet<>(); - public final static String VALUE_ATTR_KEY = "value"; - public final static String SHAPE_KEY = "shape"; - private static TFGraphMapper MAPPER_INSTANCE = new TFGraphMapper(); - private Set graphMapper = new HashSet(){{ - //While and If - //While -> Enter - /** - * Need to work on coping with variables - * that are marked as "shouldSkip" - * - * Possibly consider replacing should skip - * with a special handler interface. Something like - * - * public interface ImportOpHandler - */ - add("LoopCond"); - /** - * We should skip this for the sake of while..but not if. - * Need to be a bit more flexible here. - */ - add("Merge"); - add("Exit"); - add("NextIteration"); - add("NoOp"); - add("Switch"); - }}; - //singleton - private TFGraphMapper() {} +public class TFGraphMapper { /** - * Singleton. Get the needed instance. - * @return + * Import a frozen TensorFlow protobuf (.pb) file from the specified file + * + * @param f Frozen TensorFlow model pb file to import + * @return Imported graph */ - public static TFGraphMapper getInstance() { - return MAPPER_INSTANCE; + public static SameDiff importGraph(@NonNull File f) { + return importGraph(f, null, null); } - @Override - public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) { - try { - GraphDef graphDef = GraphDef.parseFrom(inputFile); - BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true)); - for(NodeDef node : graphDef.getNodeList()) { - bufferedWriter.write(node.toString()); - } - - bufferedWriter.flush(); - bufferedWriter.close(); - + /** + * Import a frozen TensorFlow protobuf (.pb) file from the specified file, with optional overrides + * + * @param f Frozen TensorFlow model pb file to import + * @param importOverride Optional import override for specific ops, keyed by op name + * @param opFilter Optional filter - ops to exclude/ignore + * @return Imported graph + */ + public static SameDiff importGraph(@NonNull File f, Map importOverride, TFOpImportFilter opFilter) { + Preconditions.checkState(f.exists(), "File does not exist: %s", f); + try (InputStream is = new BufferedInputStream(new FileInputStream(f))) { + return importGraph(is, importOverride, opFilter); } catch (IOException e) { - e.printStackTrace(); + throw new RuntimeException(e); } } - @Override - public boolean isOpIgnoreException(NodeDef node) { - //if statements should not be ignored -/* - if(node.getOp().equals("Merge")) { - boolean ret = false; - for(int i = 0; i < node.getInputCount(); i++) { - //while loop - ret = ret || !node.getInput(i).endsWith("/Enter") || !node.getInput(i).endsWith("/NextIteration"); + /** + * Import a frozen TensorFlow protobuf (.pb) file, via an input stream + * + * @param is Stream for a frozen TensorFlow model pb file to import + * @return Imported graph + */ + public static SameDiff importGraph(@NonNull InputStream is) { + return importGraph(is, null, null); + } + /** + * Import a frozen TensorFlow protobuf file in text format (.pb.txt) file via an input stream, with optional overrides + * + * @param is Stream for a frozen TensorFlow model pb file to import + * @param importOverride Optional import override for specific ops, keyed by op name + * @param opFilter Optional filter - ops to exclude/ignore + * @return Imported graph + */ + public static SameDiff importGraphTxt(@NonNull InputStream is, Map importOverride, TFOpImportFilter opFilter) { + GraphDef tfGraph; + try { + Message.Builder builder = GraphDef.newBuilder(); + String content = IOUtils.toString(is, StandardCharsets.UTF_8); + TextFormat.getParser().merge(content, builder); + tfGraph = (GraphDef) builder.build(); + } catch (IOException e) { + throw new RuntimeException(e); + } + + return importGraph(tfGraph, importOverride, opFilter); + } + + /** + * Import a frozen TensorFlow protobuf (.pb) file via an input stream, with optional overrides + * + * @param is Stream for a frozen TensorFlow model pb file to import + * @param importOverride Optional import override for specific ops, keyed by op name + * @param opFilter Optional filter - ops to exclude/ignore + * @return Imported graph + */ + public static SameDiff importGraph(@NonNull InputStream is, Map importOverride, TFOpImportFilter opFilter) { + GraphDef tfGraph; + try { + tfGraph = GraphDef.parseFrom(is); + } catch (IOException e) { + throw new RuntimeException(e); + } + + return importGraph(tfGraph, importOverride, opFilter); + } + + /** + * Import a TensorFlow model from a GraphDef + * + * @param tfGraph TensorFlow model GraphDef + * @return Imported model + */ + public static SameDiff importGraph(@NonNull GraphDef tfGraph) { + return importGraph(tfGraph, null, null); + } + + /** + * Import a TensorFlow model from a GraphDef, with optional import overrides + * + * @param tfGraph TensorFlow model GraphDef + * @param importOverride Optional import override for specific ops, keyed by op name + * @param opFilter Optional filter - ops to exclude/ignore + * @return Imported model + */ + public static SameDiff importGraph(@NonNull GraphDef tfGraph, Map importOverride, TFOpImportFilter opFilter) { + + /* + First, build an in-memory representation of the graph that allows us to build the graph incrementally + If we can build the graph incrementally, we can make sure that the added variables are set up with the correct + datatype and (once implemented) greedy shape inference + */ + Set availableToAddSet = new HashSet<>(); //TODO maybe unnecessary? + Queue availableToAdd = new LinkedList<>(); + + Map remainingNodes = new HashMap<>(); //All other nodes, not in availableToAdd + + Map> nodeInputTo = new HashMap<>(); // For op x -> y, x is key, y is value. Note that these are OP names not VARIABLE names + + int nNodes = tfGraph.getNodeCount(); + + //First, add any constants, placeholders, and zero-input ops + SameDiff sd = SameDiff.create(); + for (int i = 0; i < nNodes; i++) { + NodeDef nd = tfGraph.getNode(i); + String op = nd.getOp(); + String name = nd.getName(); + + int nInputs = nd.getInputCount(); + + if ("Const".equals(op) || "Placeholder".equals(op) || nInputs == 0) { + availableToAdd.add(nd); + availableToAddSet.add(name); + } else { + remainingNodes.put(name, nd); + for (int in = 0; in < nInputs; in++) { + String inOpName = stripControl(nd.getInput(in)); + inOpName = stripVarSuffix(inOpName); + + if (!nodeInputTo.containsKey(inOpName)) { + nodeInputTo.put(inOpName, new HashSet()); + } + nodeInputTo.get(inOpName).add(name); + } + } + } + + Map mergeOpsPostProcess = new HashMap<>(); + + //Go through ops in order, and add to the graph + Map> constControlDeps = new HashMap<>(); //Key: constant name. Value: control dependencies + while (!availableToAdd.isEmpty()) { + NodeDef nd = availableToAdd.remove(); + String name = nd.getName(); + String opName = nd.getOp(); + int nIn = nd.getInputCount(); + + availableToAddSet.remove(name); + + log.trace("Adding operation to graph: {} (name={})", opName, name); + + boolean skipCase = false; + if(opFilter != null && opFilter.skipOp(nd, sd, nd.getAttrMap(), tfGraph)){ + log.debug("Skipping op {} of type {} due to op filter", name, opName); + //Don't continue at this point - we still need to process what this feeds into... + skipCase = true; + } else { + if (importOverride == null || !importOverride.containsKey(name)) { + //Standard case + if ("Const".equals(opName)) { + //Get array, create a constant + TensorProto tfTensor = nd.getAttrOrThrow("value").getTensor(); + TFTensorMapper m = TFTensorMappers.newMapper(tfTensor); + INDArray arr = m.toNDArray(); + sd.constant(name, arr); + int inputCount = nd.getInputCount(); + if (inputCount > 0) { + //Very likely control dependency. i.e., "we must execute op X before the constant is really available to be used" + List l = new ArrayList<>(inputCount); + for (int i = 0; i < inputCount; i++) { + String n = nd.getInput(i); + if (!isControlDep(n)) { + throw new IllegalStateException("Found non-control dependency input \"" + n + "\" for constant \"" + name + "\""); + } + String n2 = stripControl(n); + l.add(n2); + } + constControlDeps.put(name, l); + } + } else if ("Placeholder".equals(opName) || "PlaceholderWithDefault".equals(opName)) { + //TODO support the "WithDefault" array + + Map attrMap = nd.getAttrMap(); + boolean shapeAvailable = attrMap.containsKey("shape"); + long[] shape; + if (shapeAvailable) { + TensorShapeProto shapeProto = attrMap.get("shape").getShape(); + shape = shapeFromShapeProto(shapeProto); + } else { + //Some placeholders don't have any shape restrictions - i.e., accept anything... + shape = null; + } + + + org.tensorflow.framework.DataType tfDtype = attrMap.get("dtype").getType(); + org.nd4j.linalg.api.buffer.DataType dt = convertType(tfDtype); + sd.placeHolder(name, dt, shape); + } else { + /* + Normal ops. Process in the following order: + 1. Create the op instance + 2. Add op to graph + 3. Import from TF (to set attributes) + 4. Calculate output dtypes + 5. Create and add output variables to graph + + Note: one constraint on this order is that some ops import modify the graph structure. + Notable example: concat op - it removes the axis op and converts the value to an iArg + https://github.com/eclipse/deeplearning4j/issues/8285 + */ + DifferentialFunction dfInstance = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(opName); + Preconditions.checkState(dfInstance != null, "Could not find class for TF Ops: {}", opName); + + DifferentialFunction df; + try { + df = dfInstance.getClass().newInstance(); + } catch (Throwable t) { + //Should never happen because function was already created via no-arg constructor earlier + throw new RuntimeException(t); + } + df.setSameDiff(sd); + df.setOwnName(name); + + //Process inputs + List inNames = new ArrayList<>(nIn); + List controlDeps = null; + for (int i = 0; i < nIn; i++) { + String origInName = nd.getInput(i); + String inName = stripControl(origInName); + boolean isControlDep = isControlDep(origInName); + if (isControlDep) { + if (controlDeps == null) + controlDeps = new ArrayList<>(); + controlDeps.add(inName); + } + + if (!isControlDep) { + inNames.add(inName); + } + + //Update Variable.inputsForOp for all variables that feed into this op + // Such variables must have already been created, given we process in order + Variable v = sd.getVariables().get(inName); + + if (v == null && df instanceof Merge) { + //Edge case for import - we allow merge ops to be added before both inputs are available + //This is to break the cycles in loops, otherwise we can't process anything in order + mergeOpsPostProcess.put(df.getOwnName(), inName); + continue; + } + + if (!isControlDep && (v.getInputsForOp() == null || !v.getInputsForOp().contains(name))) { + //May already be present - for example, add(x,x) + if (v.getInputsForOp() == null) + v.setInputsForOp(new ArrayList()); + v.getInputsForOp().add(name); + } else if (isControlDep) { + if (v.getControlDepsForOp() == null) + v.setControlDepsForOp(new ArrayList()); + if (!v.getControlDepsForOp().contains(name)) { + v.getControlDepsForOp().add(name); + } + } + } + + //Create SameDiffOp instance and add to graph + SameDiffOp op = SameDiffOp.builder() + .name(name) + .op(df) + .inputsToOp(inNames) + //.outputsOfOp(outNames) //We'll set this later + .controlDeps(controlDeps) + .build(); + sd.getOps().put(name, op); + + + Map attrMap = nd.getAttrMap(); + df.initFromTensorFlow(nd, sd, attrMap, tfGraph); //TODO REMOVE TFGRAPH ENTIRELY FROM THIS CALL - it encourages hacky and really brittle stuff like input array to attribute conversion + + //DType calculate for output variables (set/correct if necessary) + List newInNames = sd.getOps().get(name).getInputsToOp(); //Just in case import has modified this, like for concat case + List newInDtypes = new ArrayList<>(newInNames.size()); + if (df instanceof Merge) { + //Merge op: as noted elsewhere, we allow merge to be processed when only one of the inputs is available + // to break cycles for loops + //We know that Merge op has the restriction of the same datatype for both inputs, so we'll + SDVariable v1 = sd.getVariable(newInNames.get(0)); + SDVariable v2 = sd.getVariable(newInNames.get(1)); + org.nd4j.linalg.api.buffer.DataType dt1 = (v1 == null ? v2.dataType() : v1.dataType()); + org.nd4j.linalg.api.buffer.DataType dt2 = (v2 == null ? v1.dataType() : v2.dataType()); + newInDtypes.add(dt1); + newInDtypes.add(dt2); + } else { + for (String s : newInNames) { + SDVariable v = sd.getVariable(s); + newInDtypes.add(v.dataType()); + } + } + + List outDTypes = df.calculateOutputDataTypes(newInDtypes); + SDVariable[] outSDVars = new SDVariable[outDTypes.size()]; + Variable[] outVars = new Variable[outDTypes.size()]; + List outNames = new ArrayList<>(outDTypes.size()); + + //Create output variables and add to graph + for (int i = 0; i < outDTypes.size(); i++) { + org.nd4j.linalg.api.buffer.DataType dt = outDTypes.get(i); + String varName = name + (i == 0 ? "" : ":" + i); + outSDVars[i] = sd.var(varName, VariableType.ARRAY, null, dt, (long[]) null); + outNames.add(varName); + + outVars[i] = Variable.builder() + .name(varName) + .variable(outSDVars[i]) + .inputsForOp(null) //This is updated incrementally as other ops are added + .controlDepsForOp(null) //Control deps are handled later + .controlDepsForVar(null) + .outputOfOp(name) + .build(); + + sd.getVariables().put(varName, outVars[i]); + log.trace("Added variable to graph: {} (output of op {})", varName, name); + } + sd.getOps().get(name).setOutputsOfOp(outNames); + + log.trace("Imported op: {} (name={})", opName, name); + } + } else { + //Import override case + TFImportOverride o = importOverride.get(name); + + log.debug("Importing op {} using override {}", opName, importOverride); + + //First, get inputs: + List inputs = new ArrayList<>(nIn); + List controlDeps = null; + for (int i = 0; i < nIn; i++) { + String inName = nd.getInput(i); + boolean controlDep = isControlDep(inName); + + SDVariable v = sd.getVariable(name); + + if (controlDep) { + if (controlDeps == null) + controlDeps = new ArrayList<>(); + controlDeps.add(v); + } else { + inputs.add(v); + } + + o.initFromTensorFlow(inputs, controlDeps, nd, sd, nd.getAttrMap(), tfGraph); + } + } } - return ret; - } - else if(node.getOp().equals("Switch")) { - boolean ret = false; - for(int i = 0; i < node.getInputCount(); i++) { - //while loop - ret = ret || !node.getInput(i).endsWith("/Merge") || !node.getInput(i).endsWith("/LoopCond"); + //Now that we have just added an op (or variable) - check what this feeds into, and see what we can now process + // as a result + if (nodeInputTo.containsKey(name)) { + Set set = nodeInputTo.get(name); + for (String nextOp : set) { + NodeDef nextOpDef = remainingNodes.get(nextOp); + if (nextOpDef == null) { + if (sd.getOps().containsKey(nextOp)) { + //Already processed this. + //Almost certainly the close of a loop - like NextIteration -> Merge case + continue; + } + //Should never happen + throw new IllegalStateException("Could not find op definition for op to import: " + nextOp); + } + int nInNext = nextOpDef.getInputCount(); + boolean allAlreadyInGraph = true; + int nonControlSeenCount = 0; + for (int i = 0; i < nInNext; i++) { + String s = nextOpDef.getInput(i); + String inName = stripControl(nextOpDef.getInput(i)); + +// log.info("Input: {}, {}", s, inName); + + if (!sd.hasVariable(inName) && !skipCase) { +// log.info("Not found: {} for op {}", inName, nextOpDef.getName()); + allAlreadyInGraph = false; + break; + } else if (!isControlDep(s)) { + nonControlSeenCount++; + } + } + + //Merge ops are an edge case. We'll allow these to be executed with just ONE input, to break + // the cycle in loops. In loops, generally we have (Enter, NextIteration) -> Merge, which + // of course can't be done if we strictly require all inputs to be available + boolean mergeCase = (nonControlSeenCount > 0 && "Merge".equals(nextOpDef.getOp())); + + if (allAlreadyInGraph || mergeCase) { + //Can process this op, add it to the queue for processing + if (!availableToAddSet.contains(nextOp)) { + //Avoid processing same op multiple times, for repeated inputs to one op, etc + availableToAdd.add(nextOpDef); + availableToAddSet.add(nextOp); + log.trace("Added to processing queue: {} (name={})", nextOpDef.getOp(), nextOp); + } + } + } } + //Finally, remove the just processed op from remainingNodes map: + remainingNodes.remove(name); + } + + //Post process the control dependencies, if any (done after because dependencies may not exist when imported) + for (Map.Entry> e : constControlDeps.entrySet()) { + String varName = e.getKey(); + List cdOpNames = e.getValue(); + sd.getVariables().get(varName).setControlDeps(cdOpNames); + + for (String s : cdOpNames) { + SameDiffOp sdo = sd.getOps().get(s); + if (sdo.getControlDepFor() == null) + sdo.setControlDepFor(new ArrayList()); + List l = sdo.getControlDepFor(); + if (!l.contains(s)) + l.add(varName); + } + } + + //Post process the merge ops - all we are missing is a Variable.getInputsForOp().add(mergeOpName); + for (Map.Entry e : mergeOpsPostProcess.entrySet()) { + Variable v = sd.getVariables().get(e.getValue()); + if (v.getInputsForOp() == null) + v.setInputsForOp(new ArrayList()); + v.getInputsForOp().add(e.getKey()); + } + + Preconditions.checkState(remainingNodes.isEmpty(), "%s Unprocessed nodes: %s", remainingNodes.size(), remainingNodes.keySet()); + + return sd; + } + + + /** + * Get the shape from a TensorShapeProto + * + * @param tensorShapeProto Shape + * @return Shape as long[] + */ + private static long[] shapeFromShapeProto(TensorShapeProto tensorShapeProto) { + long[] shape = new long[tensorShapeProto.getDimList().size()]; + for (int i = 0; i < shape.length; i++) { + shape[i] = tensorShapeProto.getDim(i).getSize(); + } + + return shape; + } + + /** + * Convert from TF proto datatype to ND4J datatype + * + * @param tfType TF datatype + * @return ND4J datatype + */ + public static org.nd4j.linalg.api.buffer.DataType convertType(org.tensorflow.framework.DataType tfType) { + switch (tfType) { + case DT_DOUBLE: + return org.nd4j.linalg.api.buffer.DataType.DOUBLE; + case DT_FLOAT: + return org.nd4j.linalg.api.buffer.DataType.FLOAT; + case DT_HALF: + return org.nd4j.linalg.api.buffer.DataType.HALF; + case DT_BFLOAT16: + return org.nd4j.linalg.api.buffer.DataType.BFLOAT16; + case DT_INT8: + return org.nd4j.linalg.api.buffer.DataType.BYTE; + case DT_INT16: + return org.nd4j.linalg.api.buffer.DataType.SHORT; + case DT_INT32: + return org.nd4j.linalg.api.buffer.DataType.INT; + case DT_INT64: + return org.nd4j.linalg.api.buffer.DataType.LONG; + case DT_UINT8: + return org.nd4j.linalg.api.buffer.DataType.UBYTE; + case DT_STRING: + return org.nd4j.linalg.api.buffer.DataType.UTF8; + case DT_BOOL: + return org.nd4j.linalg.api.buffer.DataType.BOOL; + + default: + return org.nd4j.linalg.api.buffer.DataType.UNKNOWN; + } + } + + /** + * @return True if the specified name represents a control dependency (starts with "^") + */ + protected static boolean isControlDep(String name) { + return name.startsWith("^"); + } + + /** + * @return The specified name without the leading "^" character (if any) that appears for control dependencies + */ + protected static String stripControl(String name) { + if (name.startsWith("^")) { + return name.substring(1); + } + return name; + } + + /** + * Remove the ":1" etc suffix for a variable name to get the op name + * + * @param varName Variable name + * @return Variable name without any number suffix + */ + protected static String stripVarSuffix(String varName) { + if (varName.matches(".*:\\d+")) { + int idx = varName.lastIndexOf(':'); + String ret = varName.substring(0, idx); return ret; } -*/ - return true; + return varName; } - @Override - public String getTargetMappingForOp(DifferentialFunction function, NodeDef node) { - return function.opName(); + /** + * Convert the tensor to an NDArray (if possible and if array is available) + * + * @param node Node to get NDArray from + * @return NDArray + */ + public static INDArray getNDArrayFromTensor(NodeDef node) { + //placeholder of some kind + if (!node.getAttrMap().containsKey("value")) { + return null; + } + + val tfTensor = node.getAttrOrThrow("value").getTensor(); + INDArray out = mapTensorProto(tfTensor); + return out; } - @Override - public NodeDef getNodeWithNameFromGraph(GraphDef graph, String name) { - for(int i = 0; i < graph.getNodeCount(); i++) { + /** + * Convert a TensorProto to an INDArray + * + * @param tfTensor Tensor proto + * @return INDArray + */ + public static INDArray mapTensorProto(TensorProto tfTensor) { + TFTensorMapper m = TFTensorMappers.newMapper(tfTensor); + if (m == null) { + throw new RuntimeException("Not implemented datatype: " + tfTensor.getDtype()); + } + INDArray out = m.toNDArray(); + return out; + } + + @Deprecated //To be removed + public static NodeDef getNodeWithNameFromGraph(GraphDef graph, String name) { + for (int i = 0; i < graph.getNodeCount(); i++) { val node = graph.getNode(i); - if(node.getName().equals(name)) + if (node.getName().equals(name)) return node; } return null; } - @Override - public void mapProperty(String name, DifferentialFunction on, NodeDef node, GraphDef graph, SameDiff sameDiff, Map> propertyMappingsForFunction) { - if(node == null) { - throw new ND4JIllegalStateException("No node found for name " + name); - } - - - val mapping = propertyMappingsForFunction.get(getOpType(node)).get(name); - val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on); - - - if(mapping.getTfInputPosition() != null && mapping.getTfInputPosition() < node.getInputCount()) { - int tfMappingIdx = mapping.getTfInputPosition(); - if(tfMappingIdx < 0) - tfMappingIdx += node.getInputCount(); - - val input = node.getInput(tfMappingIdx); - val inputNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph,input); - INDArray arr = getArrayFrom(inputNode,graph); - if(arr == null && sameDiff.hasVariable(input)) { - arr = sameDiff.getArrForVarName(input); - } - - if(arr == null && inputNode != null) { - sameDiff.addPropertyToResolve(on,name); - sameDiff.addVariableMappingForField(on,name,getNodeName(inputNode.getName())); - return; - } else if(inputNode == null) { - //TODO need to do anything here given new design? - //sameDiff.addAsPlaceHolder(input); - return; - } - - val field = fields.get(name); - val type = field.getType(); - if(type.equals(int[].class)) { - on.setValueFor(field,arr.data().asInt()); - } - else if(type.equals(int.class) || type.equals(long.class) || type.equals(Long.class) || type.equals(Integer.class)) { - if(mapping.getShapePosition() != null) { - on.setValueFor(field,arr.size(mapping.getShapePosition())); - } - else - on.setValueFor(field,arr.getInt(0)); - - } - else if(type.equals(float.class) || type.equals(double.class) || type.equals(Float.class) || type.equals(Double.class)) { - on.setValueFor(field,arr.getDouble(0)); - } - - - } - else { - val tfMappingAttrName = mapping.getTfAttrName(); - if(tfMappingAttrName == null) { - return; - } - - if(!node.containsAttr(tfMappingAttrName)) { - return; - } - - - val attr = node.getAttrOrThrow(tfMappingAttrName); - val type = attr.getType(); - if(fields == null) { - throw new ND4JIllegalStateException("No fields found for op [" + mapping + "]"); - } - - if(mapping.getPropertyNames() == null) { - throw new ND4JIllegalStateException("no property found for [" + name + "] in op [" + on.opName()+"]"); - } - - val field = fields.get(mapping.getPropertyNames()[0]); - - Object valueToSet = null; - switch(type) { - case DT_BOOL: - valueToSet = attr.getB(); - break; - case DT_INT8: - valueToSet = attr.getI(); - break; - case DT_INT16: - valueToSet = attr.getI(); - break; - case DT_INT32: - valueToSet = attr.getI(); - break; - case DT_FLOAT: - valueToSet = attr.getF(); - break; - case DT_DOUBLE: - valueToSet = attr.getF(); - break; - case DT_STRING: - valueToSet = attr.getS(); - break; - case DT_INT64: - valueToSet = attr.getI(); - break; - - - } - - if(field != null && valueToSet != null) - on.setValueFor(field,valueToSet); - } - } - - - /** - * {@inheritDoc} - */ - @Override - public boolean isPlaceHolderNode(NodeDef node) { - return node.getOp().startsWith("Placeholder"); - } - - - /** - * {@inheritDoc} - */ - @Override - public void dumpBinaryProtoAsText(File inputFile, File outputFile) { - try { - GraphDef graphDef = GraphDef.parseFrom(new BufferedInputStream(new FileInputStream(inputFile))); - BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true)); - for(NodeDef node : graphDef.getNodeList()) { - bufferedWriter.write(node.toString()); - } - - bufferedWriter.flush(); - bufferedWriter.close(); - - } catch (IOException e) { - e.printStackTrace(); - } - } - - @Override - public long[] getShapeFromAttr(AttrValue attr) { - return shapeFromShapeProto(attr.getShape()); - } - - @Override - public Map getAttrMap(NodeDef nodeDef) { - return nodeDef.getAttrMap(); - } - - @Override - public String getName(NodeDef nodeDef) { - return nodeDef.getName(); - } - - @Override - public boolean alreadySeen(NodeDef nodeDef) { - return seenNodes.contains(nodeDef.getName()); - } - - @Override - public boolean isVariableNode(NodeDef nodeDef) { - boolean isVar = nodeDef.getOp().startsWith("VariableV") || nodeDef.getOp().equalsIgnoreCase("const"); - return isVar; - } - - @Override - public boolean shouldSkip(NodeDef opType) { - if(opType == null) - return true; - - boolean endsWithRead = opType.getName().endsWith("/read"); - return endsWithRead; - } - - @Override - public boolean hasShape(NodeDef nodeDef) { - return nodeDef.containsAttr(SHAPE_KEY); - } - - @Override - public long[] getShape(NodeDef nodeDef) { - return getShapeFromAttr(nodeDef.getAttrOrThrow(SHAPE_KEY)); - } - - @Override - public INDArray getArrayFrom(NodeDef nodeDef, GraphDef graph) { - if(nodeDef == null) { + @Deprecated //To be removed + public static INDArray getArrayFrom(NodeDef nodeDef, GraphDef graph) { + if (nodeDef == null) { return null; } - return getNDArrayFromTensor(nodeDef.getName(),nodeDef, graph); - } - - @Override - public String getOpType(NodeDef nodeDef) { - return nodeDef.getOp(); - } - - /** - * - * @param graphDef - * @return - */ - @Override - public List getNodeList(GraphDef graphDef) { - return graphDef.getNodeList(); - } - - /** - * - * @param name the tensorflow or onnx name - * @return - */ - @Override - public DifferentialFunction getMappedOp(String name) { - return DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(name); - } - - - /** - * Map a tensorflow node name - * to the samediff equivalent - * for import - * @param name the name to change - * @return the input tensorflow name - */ - public String getNodeName(String name) { - //tensorflow adds colons to the end of variables representing input index, this strips those off - String ret = name; - if(ret.startsWith("^")) - ret = ret.substring(1); - if(ret.endsWith("/read")) { - ret = ret.replace("/read",""); - } - if(ret.endsWith(":0")){ - ret = ret.substring(0, ret.length()-2); - } - return ret; - } - - public boolean isControlDependency(String name){ - return name.startsWith("^"); - } - - - - @Override - public Map variablesForGraph(GraphDef graphDef) { - Map ret = new LinkedHashMap<>(); - List nodeList = graphDef.getNodeList(); - for(NodeDef nodeDef : nodeList) { - if(nodeDef.getName().endsWith("/read")) { - continue; - } - - - val name = translateToSameDiffName(nodeDef.getName(), nodeDef); - ret.put(name,nodeDef); - } - - return ret; - } - - @Override - public String translateToSameDiffName(String name, NodeDef node) { - if(isVariableNode(node) || isPlaceHolder(node)) { - return name; - } - - StringBuilder stringBuilder = new StringBuilder(); - //strip arg number - if(name.contains(":")) { - name = name.substring(0,name.lastIndexOf(':')); - stringBuilder.append(name); - } - else { - stringBuilder.append(name); - } - - - return stringBuilder.toString(); - } - - //Strip the variable suffix to give the node name: "Unique:1" -> "Unique" - public String varNameToOpName(String varName){ - int idx = varName.lastIndexOf(':'); - if(idx < 0) - return varName; - return varName.substring(0, idx); - } - - public static int varNameToOpOutputNumber(String varName){ - int idx = varName.lastIndexOf(':'); - if(idx < 0) - return 0; - String n = varName.substring(idx+1); - return Integer.parseInt(n); - } - - - @Override - public Message.Builder getNewGraphBuilder() { - return GraphDef.newBuilder(); - } - - @Override - public GraphDef parseGraphFrom(byte[] inputStream) throws IOException { - return GraphDef.parseFrom(inputStream); - } - - @Override - public GraphDef parseGraphFrom(InputStream inputStream) throws IOException { - return GraphDef.parseFrom(inputStream); - } - - protected void importCondition(String conditionName, NodeDef tfNode, ImportState importState) { - /** - * Cond structure: - * - */ - } - - @Override - public void mapNodeType(NodeDef tfNode, ImportState importState, - OpImportOverride importOverride, - OpImportFilter opFilter) { - if (shouldSkip(tfNode) || alreadySeen(tfNode) || isVariableNode(tfNode)) { - return; - } - - - SameDiff diff = importState.getSameDiff(); - if (isVariableNode(tfNode)) { - List dimensions = new ArrayList<>(); - Map attributes = getAttrMap(tfNode); - if (attributes.containsKey(VALUE_ATTR_KEY)) { - diff.var(getName(tfNode),getArrayFrom(tfNode,importState.getGraph())); - } - else if (attributes.containsKey(SHAPE_KEY)) { - AttrValue shape = attributes.get(SHAPE_KEY); - long[] shapeArr = getShapeFromAttr(shape); - int dims = shapeArr.length; - if (dims > 0) { - // even vector is 2d in nd4j - if (dims == 1) - dimensions.add(1L); - - for (int e = 0; e < dims; e++) { - // TODO: eventually we want long shapes :( - dimensions.add(getShapeFromAttr(shape)[e]); - } - } - } - } - - else if(isPlaceHolder(tfNode)) { - SDVariable var = diff.getVariable(getName(tfNode)); - Preconditions.checkState(var.isPlaceHolder(), "Variable should be marked as placeholder at this point: %s", var); - } else { - val opName = tfNode.getOp(); - - if(importOverride != null){ - //First, get inputs: - int numInputs = tfNode.getInputCount(); - List inputs = new ArrayList<>(numInputs); - List controlDeps = null; - for( int i=0; i this) - if (v == null) { - //Check 'op skip' edge case - boolean shouldSkip = false; - if(opFilter != null){ - //Get the input node - List l = importState.getGraph().getNodeList(); - NodeDef inputNodeDef = null; - for(NodeDef nd : l){ - if(inName.equals(nd.getName())){ - inputNodeDef = nd; - break; - } - } - Preconditions.checkState(inputNodeDef != null, "Could not find node with name \"%s\"", inName); - shouldSkip = true; - } - - if(!shouldSkip) { - //First: try to work out the datatype of this input node - //Given we haven't already imported it at this point, it must be the 2nd or later output of an op - - String inputOpName = varNameToOpName(inName); - NodeDef inputOp = importState.getVariables().get(inputOpName); - int outputIdx = varNameToOpOutputNumber(name); - org.nd4j.linalg.api.buffer.DataType dt = dataTypeForTensor(inputOp, outputIdx); - if (dt == org.nd4j.linalg.api.buffer.DataType.UNKNOWN) - dt = null; //Infer it later - - - v = diff.var(name, VariableType.ARRAY, null, dt, (long[]) null); - } - } - - if(controlDep){ - if(controlDeps == null) - controlDeps = new ArrayList<>(); - controlDeps.add(v); - } else { - inputs.add(v); - } - } - - log.info("Importing op {} using override {}", opName, importOverride); - importOverride.initFromTensorFlow(inputs, controlDeps, tfNode, diff, getAttrMap(tfNode), importState.getGraph()); - } else { - - val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(opName); - if (differentialFunction == null) { - throw new ND4JIllegalStateException("No tensorflow op found for " + opName + " possibly missing operation class?"); - } - try { - DifferentialFunction newInstance = differentialFunction.getClass().newInstance(); - List args = new ArrayList<>(); - List controlDeps = null; - newInstance.setOwnName(tfNode.getName()); - - int x = 0; - for (int i = 0; i < tfNode.getInputCount(); i++) { - String inName = tfNode.getInput(i); - String inputOpName = varNameToOpName(inName); - NodeDef inputNode = importState.getVariables().get(inputOpName); - - if (shouldSkip(inputNode) && !inName.endsWith("/read")) - continue; - - boolean controlDep = isControlDependency(inName); - String name = getNodeName(inName); - - SDVariable v = diff.getVariable(name); - - //At this point, all placeholders, variables and constants should have been imported - //This: this should be an array type variable (i.e., activations) - if (v == null) { - //First: try to work out the datatype of this input node - //Given we haven't already imported it at this point, it must be the 2nd or later output of an op - - NodeDef inputOp = importState.getVariables().get(inputOpName); - int outputIdx = varNameToOpOutputNumber(name); - org.nd4j.linalg.api.buffer.DataType dt = dataTypeForTensor(inputOp, outputIdx); - if (dt == org.nd4j.linalg.api.buffer.DataType.UNKNOWN) - dt = null; //Infer it later - - - v = diff.var(name, VariableType.ARRAY, null, dt, (long[]) null); - } - - if (controlDep) { - //Is only a control dependency input to op, not a real data input - if (controlDeps == null) - controlDeps = new ArrayList<>(); - if (!controlDeps.contains(name)) - controlDeps.add(name); - } else { - //Is a standard/"real" op input - args.add(v); - } - } - - - diff.addArgsFor(args.toArray(new SDVariable[args.size()]), newInstance); - newInstance.setSameDiff(importState.getSameDiff()); - - if (controlDeps != null) { - SameDiffOp op = diff.getOps().get(newInstance.getOwnName()); - op.setControlDeps(controlDeps); - - //Also record this on the variables: - for (String s : controlDeps) { - Variable v = diff.getVariables().get(s); - if (v.getControlDepsForOp() == null) - v.setControlDeps(new ArrayList()); - List l = v.getControlDepsForOp(); - if (!l.contains(op.getName())) - l.add(op.getName()); - } - } - - newInstance.initFromTensorFlow(tfNode, diff, getAttrMap(tfNode), importState.getGraph()); - mapProperties(newInstance, tfNode, importState.getGraph(), importState.getSameDiff(), newInstance.mappingsForFunction()); - importState.getSameDiff().putOpForId(newInstance.getOwnName(), newInstance); - //ensure we can track node name to function instance later. - diff.setBaseNameForFunctionInstanceId(tfNode.getName(), newInstance); - } catch (Exception e) { - log.error("Failed to import op [{}]", opName); - throw new RuntimeException(e); - } - } - } - } - - - /** - * Calls {@link #initFunctionFromProperties(DifferentialFunction, Map, NodeDef, GraphDef)} - * using {@link DifferentialFunction#tensorflowName()} - * @param on the function to use init on - * @param attributesForNode the attributes for the node - * @param node - * @param graph - */ - public void initFunctionFromProperties(DifferentialFunction on, Map attributesForNode, NodeDef node, GraphDef graph) { - initFunctionFromProperties(on.tensorflowName(),on,attributesForNode,node,graph); + return getNDArrayFromTensor(nodeDef); } /** * Init a function's attributes - * @param mappedTfName the tensorflow name to pick (sometimes ops have multiple names - * @param on the function to map + * + * @param mappedTfName the tensorflow name to pick (sometimes ops have multiple names + * @param on the function to map * @param attributesForNode the attributes for the node * @param node * @param graph + * @deprecated To be removed */ - public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map attributesForNode, NodeDef node, GraphDef graph) { + @Deprecated + public static void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map attributesForNode, NodeDef node, GraphDef graph) { val properties = on.mappingsForFunction(); val tfProperties = properties.get(mappedTfName); val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on); @@ -699,8 +657,8 @@ public class TFGraphMapper extends BaseGraphMapper need to map data format before mapping strides //Solution: map nodes without adapters before nodes with adapters. This doesn't guarantee we'll always be // mapping in the right order (for example, we might have adapter(x) depends on adapter(y)) but it should catch most cases - Map map; - if(attributeAdapters == null || !attributeAdapters.containsKey(mappedTfName)) { + Map map; + if (attributeAdapters == null || !attributeAdapters.containsKey(mappedTfName)) { map = tfProperties; } else { map = new LinkedHashMap<>(); @@ -718,24 +676,24 @@ public class TFGraphMapper extends BaseGraphMapper entry : map.entrySet()){ + for (Map.Entry entry : map.entrySet()) { val tfAttrName = entry.getValue().getTfAttrName(); val currentField = fields.get(entry.getKey()); AttributeAdapter adapter = null; - if(attributeAdapters != null && !attributeAdapters.isEmpty()) { + if (attributeAdapters != null && !attributeAdapters.isEmpty()) { val mappers = attributeAdapters.get(mappedTfName); val adapterFor = mappers.get(entry.getKey()); adapter = adapterFor; } - if(tfAttrName != null) { - if(currentField == null) { + if (tfAttrName != null) { + if (currentField == null) { continue; } - if(attributesForNode.containsKey(tfAttrName)) { + if (attributesForNode.containsKey(tfAttrName)) { val attr = attributesForNode.get(tfAttrName); switch (attr.getValueCase()) { case B: @@ -743,77 +701,69 @@ public class TFGraphMapper extends BaseGraphMapper 0){ - for(int i=0; i 0){ //Looks like a few OpDef instances have outputs but don't actually list them... example: NoOp - Preconditions.checkState(outNum < actualOutputCount, "Cannot get output argument %s from op %s with %s output variables - variable %s", outNum, actualOutputCount, tensorProto.getName(), tensorProto.getName()); - - int argIdx = outNum; - if(outputArgCount != actualOutputCount){ - //Map backwards accunting for fact that each output arg might correspond to multiple variables: for output variable x, which argument is this? - int idx = 0; - int soFar = 0; - while(soFar + outVarsPerOutputArg[idx] <= outNum){ - soFar += outVarsPerOutputArg[idx++]; - } - argIdx = idx; - } - - OpDef.ArgDef argDef = opDef.getOutputArg(argIdx); - String typeAttr = argDef.getTypeAttr(); - if(typeAttr != null && tensorProto.containsAttr(typeAttr)){ - tfType = tensorProto.getAttrOrThrow(typeAttr).getType(); - } else { - return org.nd4j.linalg.api.buffer.DataType.UNKNOWN; - } - - } else { - if(tensorProto.getOp().equals("NoOp")){ - return org.nd4j.linalg.api.buffer.DataType.UNKNOWN; - } else if(tensorProto.getOp().equals("Assert")){ - return org.nd4j.linalg.api.buffer.DataType.BOOL; - } - //Not in ops.proto - log.debug("No TensorFlow descriptor found for tensor \"{}\", op \"{}\"", tensorProto.getName(), tensorProto.getOp()); - - //No descriptor... try to fall back on common type attribute names - if(!tensorProto.containsAttr("dtype") && !tensorProto.containsAttr("Tidx") && !tensorProto.containsAttr("T")) - return org.nd4j.linalg.api.buffer.DataType.UNKNOWN; - - tfType = tensorProto.containsAttr("dtype") ? tensorProto.getAttrOrThrow("dtype").getType() - : tensorProto.containsAttr("T") ? tensorProto.getAttrOrThrow("T").getType() : tensorProto - .getAttrOrThrow("Tidx").getType(); - } - - return convertType(tfType); - } - - public static org.nd4j.linalg.api.buffer.DataType convertType(org.tensorflow.framework.DataType tfType){ - switch(tfType) { - case DT_DOUBLE: return org.nd4j.linalg.api.buffer.DataType.DOUBLE; - case DT_FLOAT: return org.nd4j.linalg.api.buffer.DataType.FLOAT; - case DT_HALF: return org.nd4j.linalg.api.buffer.DataType.HALF; - case DT_BFLOAT16: return org.nd4j.linalg.api.buffer.DataType.BFLOAT16; - case DT_INT8: return org.nd4j.linalg.api.buffer.DataType.BYTE; - case DT_INT16: return org.nd4j.linalg.api.buffer.DataType.SHORT; - case DT_INT32: return org.nd4j.linalg.api.buffer.DataType.INT; - case DT_INT64: return org.nd4j.linalg.api.buffer.DataType.LONG; - case DT_UINT8: return org.nd4j.linalg.api.buffer.DataType.UBYTE; - case DT_STRING: return org.nd4j.linalg.api.buffer.DataType.UTF8; - case DT_BOOL: return org.nd4j.linalg.api.buffer.DataType.BOOL; - - default: return org.nd4j.linalg.api.buffer.DataType.UNKNOWN; - } - } - - @Override - public boolean isStringType(NodeDef tensorProto){ - DataType dt = null; - if(tensorProto.containsAttr("dtype")){ - dt = tensorProto.getAttrOrThrow("dtype").getType(); - } else if(tensorProto.containsAttr("T")){ - dt = tensorProto.getAttrOrThrow("T").getType(); - } else if(tensorProto.containsAttr("Tidx")){ - dt = tensorProto.getAttrOrThrow("Tidx").getType(); - } - - return dt == DataType.DT_STRING || dt == DataType.DT_STRING_REF; - } - - - @Override - public String getAttrValueFromNode(NodeDef nodeDef, String key) { - return nodeDef.getAttrOrThrow(key).getS().toStringUtf8(); - } - - @Override - public long[] getShapeFromAttribute(AttrValue attrValue) { - TensorShapeProto shape = attrValue.getShape(); - long[] ret = new long[shape.getDimCount()]; - for(int i = 0; i < ret.length; i++) { - ret[i] = (int) shape.getDim(i).getSize(); + if (ret.endsWith(":0")) { + ret = ret.substring(0, ret.length() - 2); } return ret; } - @Override - public boolean isPlaceHolder(NodeDef nodeDef) { - return nodeDef.getOp().startsWith("Placeholder"); + /** + * Determine if the node represents a variable node (based on op name) + * + * @param nodeDef Node to check if a variable + * @return True if a variable node + */ + public static boolean isVariableNode(NodeDef nodeDef) { + boolean isVar = nodeDef.getOp().startsWith("VariableV") || nodeDef.getOp().equalsIgnoreCase("const"); + return isVar; } - @Override - public boolean isConstant(NodeDef nodeDef) { - return nodeDef.getOp().startsWith("Const"); - } - - @Override - public List getControlDependencies(NodeDef node){ - int numInputs = node.getInputCount(); - if(numInputs == 0) - return null; - - List out = null; - for( int i=0; i(); - out.add(getNodeName(in)); //Remove "^" prefix - } - } - return out; - } - - @Override - public INDArray getNDArrayFromTensor(String tensorName, NodeDef node, GraphDef graph) { - //placeholder of some kind - if(!node.getAttrMap().containsKey("value")) { - return null; - } - - val tfTensor = node.getAttrOrThrow("value").getTensor(); - INDArray out = mapTensorProto(tfTensor); - return out; - } - - - - public INDArray mapTensorProto(TensorProto tfTensor) { - - TFTensorMapper m = TFTensorMappers.newMapper(tfTensor); - if(m == null){ - throw new RuntimeException("Not implemented datatype: " + tfTensor.getDtype()); - } - INDArray out = m.toNDArray(); - return out; - } - - protected static void setFloat16ValueFromInt(INDArray arr, int idx, int bytesAsPaddedInt){ - ByteBuffer bb = arr.data().pointer().asByteBuffer(); - bb.put(2*idx, (byte)((bytesAsPaddedInt >> 8) & 0xff)); - bb.put(2*idx+1, (byte)(bytesAsPaddedInt & 0xff)); - } - - @Override - public long[] getShapeFromTensor(NodeDef tensorProto) { - if(tensorProto.containsAttr("shape")) { - return shapeFromShapeProto(tensorProto.getAttrOrThrow("shape").getShape()); - - } - //yet to be determined shape, or tied to an op where output shape is dynamic - else if(!tensorProto.containsAttr("value")) { - return null; - - } - else - return shapeFromShapeProto(tensorProto.getAttrOrThrow("value").getTensor().getTensorShape()); - } - - @Override - public Set opsToIgnore() { - return graphMapper; - } - - - @Override - public String getInputFromNode(NodeDef node, int index) { - return node.getInput(index); - } - - @Override - public int numInputsFor(NodeDef nodeDef) { - return nodeDef.getInputCount(); - } - - private long[] shapeFromShapeProto(TensorShapeProto tensorShapeProto) { - long[] shape = new long[tensorShapeProto.getDimList().size()]; - for(int i = 0; i < shape.length; i++) { - shape[i] = tensorShapeProto.getDim(i).getSize(); - } - - return shape; - } - - /** - * Returns the node for an if statement - * @param from the starting node (a merge node that represents a conditional) - * @param graph the graph to search - * @return an import state representing the nodes for each scope + * Determine if the node is a placeholder + * + * @param nodeDef Node to check + * @return True if the node is a placeholder */ - public IfImportState nodesForIf(NodeDef from, GraphDef graph) { - //Assume we start with a switch statement - int currNodeIndex = graph.getNodeList().indexOf(from); - val trueDefName = from.getInput(1); - val falseDefName = from.getInput(0); - val scopeId = UUID.randomUUID().toString(); - val scopeName = scopeId + "-" + trueDefName.substring(0,trueDefName.indexOf("/")); - val trueDefScopeName = scopeName + "-true-scope"; - val falseDefScopeName = scopeName + "-false-scope"; - - - boolean onFalseDefinition = true; - //start with the true - boolean onTrueDefinition = false; - - List falseBodyNodes = new ArrayList<>(); - List trueBodyNodes = new ArrayList<>(); - List conditionNodes = new ArrayList<>(); - Set seenNames = new LinkedHashSet<>(); - /** - * Accumulate a list backwards to get proper ordering. - * - */ - for(int i = currNodeIndex; i >= 0; i--) { - //switch to false names - if(graph.getNode(i).getName().equals(trueDefName)) { - onFalseDefinition = false; - onTrueDefinition = true; - } - - //on predicate now - if(graph.getNode(i).getName().contains("pred_id")) { - onTrueDefinition = false; - } - //don't readd the same node, this causes a stackoverflow - if(onTrueDefinition && !graph.getNode(i).equals(from)) { - trueBodyNodes.add(graph.getNode(i)); - } - else if(onFalseDefinition && !graph.getNode(i).equals(from)) { - falseBodyNodes.add(graph.getNode(i)); - } - //condition scope now - else { - val currNode = graph.getNode(i); - if(currNode.equals(from)) - continue; - - //break only after bootstrapping the first node (the predicate id node) - if(!seenNames.contains(graph.getNode(i).getName()) && !graph.getNode(i).getName().contains("pred_id")) { - break; - } - - /** - * Continuously add inputs seen for each node in the sub graph that occurs. - * Starting from the predicate id, any node that has inputs in the condition scope - * are by definition within the scope. Any node not encountered after that is considered out of scope. - * This means we break. - */ - for(int inputIdx = 0; inputIdx < currNode.getInputCount(); inputIdx++) { - seenNames.add(currNode.getInput(inputIdx)); - } - - - - //ensure the "current node" is added as well - seenNames.add(graph.getNode(i).getName()); - conditionNodes.add(graph.getNode(i)); - } - } - - /** - * Since we are going over the graph backwards, - * we need to reverse the nodes to ensure proper ordering. - */ - Collections.reverse(falseBodyNodes); - Collections.reverse(trueBodyNodes); - Collections.reverse(conditionNodes); - - - return IfImportState.builder() - .condNodes(conditionNodes) - .falseNodes(falseBodyNodes) - .trueNodes(trueBodyNodes) - .conditionBodyScopeName(falseDefScopeName) - .falseBodyScopeName(falseDefScopeName) - .trueBodyScopeName(trueDefScopeName) - .conditionBodyScopeName(scopeName) - .build(); + public static boolean isPlaceHolder(NodeDef nodeDef) { + return nodeDef.getOp().startsWith("Placeholder"); } - - - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java index 8f59a7ef7..39d8e1577 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java @@ -226,22 +226,24 @@ public class TensorFlowImportValidator { } public static TFImportStatus checkModelForImport(String path, InputStream is, boolean exceptionOnRead) throws IOException { - TFGraphMapper m = TFGraphMapper.getInstance(); try { int opCount = 0; Set opNames = new HashSet<>(); try(InputStream bis = new BufferedInputStream(is)) { - GraphDef graphDef = m.parseGraphFrom(bis); - List nodes = m.getNodeList(graphDef); + GraphDef graphDef = GraphDef.parseFrom(bis); + List nodes = new ArrayList<>(graphDef.getNodeCount()); + for( int i=0; i Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); @@ -88,7 +87,6 @@ public abstract class BaseLapack implements Lapack { @Override public void potrf(INDArray A, boolean lower) { - // FIXME: int cast if (A.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); @@ -134,7 +132,6 @@ public abstract class BaseLapack implements Lapack { @Override public void geqrf(INDArray A, INDArray R) { - // FIXME: int cast if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); @@ -188,7 +185,6 @@ public abstract class BaseLapack implements Lapack { throw new Error("syev: V must be the length of the matrix dimension."); } - // FIXME: int cast if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); @@ -222,7 +218,6 @@ public abstract class BaseLapack implements Lapack { @Override public void gesvd(INDArray A, INDArray S, INDArray U, INDArray VT) { - // FIXME: int cast if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); @@ -279,7 +274,6 @@ public abstract class BaseLapack implements Lapack { */ @Override public INDArray getLFactor(INDArray A) { - // FIXME: int cast if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); @@ -304,7 +298,6 @@ public abstract class BaseLapack implements Lapack { @Override public INDArray getUFactor(INDArray A) { - // FIXME: int cast if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel2.java index 6114c52d5..34b89824e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel2.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.blas.impl; import lombok.val; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.blas.Level2; import org.nd4j.linalg.api.blas.params.GemvParameters; import org.nd4j.linalg.api.buffer.DataBuffer; @@ -25,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.profiler.OpProfiler; @@ -113,10 +115,10 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, A, X, Y); - // FIXME: int cast - if (A.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X, Y); + if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); dgbmv(order, TransA, (int) A.rows(), (int) A.columns(), KL, KU, alpha, A, (int) A.size(0), X, X.stride(-1), beta, Y, Y.stride(-1)); } else { @@ -142,10 +144,10 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, A, X, Y); - // FIXME: int cast - if (X.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X, Y); + if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); dger(order, (int) A.rows(), (int) A.columns(), alpha, X, X.stride(-1), Y, Y.stride(-1), A, (int) A.size(0)); } else { DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, X, Y); @@ -173,12 +175,13 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, A, X, Y); - // FIXME: int cast - + if (X.length() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } if (X.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X, Y); dsbmv(order, Uplo, (int) X.length(), (int) A.columns(), alpha, A, (int) A.size(0), X, X.stride(-1), beta, Y, - (int) Y.stride(-1)); + Y.stride(-1)); } else { DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, X, Y); ssbmv(order, Uplo, (int) X.length(), (int) A.columns(), (float) alpha, A, (int) A.size(0), X, X.stride(-1), (float) beta, @@ -202,7 +205,9 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, Ap, X, Y); - // FIXME: int cast + if (X.length() > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } if (Ap.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, X, Y); @@ -231,7 +236,8 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { OpProfiler.getInstance().processBlasCall(false, Ap, X); - // FIXME: int cast + if (X.length() > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); if (X.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, X); @@ -260,7 +266,8 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, A, X, Y); - // FIXME int cast + if (X.length() > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); if (X.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X, Y); @@ -291,7 +298,8 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, A, X, Y); - // FIXME: int cast + if (X.length() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); if (X.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X, Y); @@ -321,7 +329,8 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, A, X); - // FIXME: int cast + if (X.length() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); if (X.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X); @@ -347,7 +356,8 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, A, X, Y); - // FIXME: int cast + if (X.length() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); if (X.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X, Y); @@ -376,7 +386,9 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, A, X); - // FIXME: int cast + if (X.length() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } if (X.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X); @@ -402,7 +414,9 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, A, X); - // FIXME: int cast + if (X.length() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE ) { + throw new ND4JArraySizeException(); + } if (X.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X); @@ -429,7 +443,8 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, Ap, X); - // FIXME: int cast + if (Ap.length() > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); if (X.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, X); @@ -457,7 +472,8 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, Ap, X); - // FIXME: int cast + if (X.length() > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); if (X.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, X, Ap); @@ -485,7 +501,8 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, A, X); - // FIXME: int cast + if (X.length() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); if (A.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X); @@ -513,7 +530,8 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, A, X); - // FIXME: int cast + if (A.length() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); if (X.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel3.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel3.java index 3c015c5dc..e08c4c0a9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel3.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel3.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.factory.NDArrayFactory; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.profiler.OpProfiler; @@ -129,7 +130,10 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, A, B, C); - // FIXME: int cast + if (C.rows() > Integer.MAX_VALUE || C.columns() > Integer.MAX_VALUE || + A.size(0) > Integer.MAX_VALUE || B.size(0) > Integer.MAX_VALUE || C.size(0) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } if (A.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, B, C); @@ -163,7 +167,11 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, A, C); - // FIXME: int cast + if (C.rows() > Integer.MAX_VALUE || + A.size(0) > Integer.MAX_VALUE || + C.size(0) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } if (A.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, C); @@ -198,7 +206,10 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, A, B, C); - // FIXME: int cast + if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE || + A.size(0) > Integer.MAX_VALUE || B.size(0) > Integer.MAX_VALUE || C.size(0) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } if (A.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, B, C); @@ -234,7 +245,10 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, A, B, C); - // FIXME: int cast + if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE || + A.size(0) > Integer.MAX_VALUE || B.size(0) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } if (A.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, B, C); @@ -269,7 +283,10 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 { if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) OpProfiler.getInstance().processBlasCall(false, A, B); - // FIXME: int cast + if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE || + A.size(0) > Integer.MAX_VALUE || B.size(0) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } if (A.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, B); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 7d1acb609..16931e434 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -86,6 +86,7 @@ import java.io.*; import java.nio.IntBuffer; import java.nio.LongBuffer; import java.util.*; +import java.util.concurrent.atomic.AtomicLong; import static org.nd4j.linalg.factory.Nd4j.*; @@ -124,6 +125,9 @@ public abstract class BaseNDArray implements INDArray, Iterable { protected transient JvmShapeInfo jvmShapeInfo; + private static final AtomicLong arrayCounter = new AtomicLong(0); + protected transient final long arrayId = arrayCounter.getAndIncrement(); + //Precalculate these arrays (like [3,2,1,0], [2,1,0], [1,0], [0] etc) for use in TAD, to avoid creating same int[]s over and over private static final int[][] tadFinalPermuteDimensions; @@ -139,7 +143,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { } public BaseNDArray() { - } @Override @@ -306,11 +309,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { * @param ordering the ordering of the ndarray */ public BaseNDArray(int[] shape, int[] stride, long offset, char ordering) { - this(Nd4j.createBuffer(ArrayUtil.prodLong(shape)), shape, stride, offset, ordering); + this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape)), shape, stride, offset, ordering); } public BaseNDArray(long[] shape, long[] stride, long offset, char ordering) { - this(Nd4j.createBuffer(ArrayUtil.prodLong(shape)), shape, stride, offset, ordering); + this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape)), shape, stride, offset, ordering); } /** @@ -323,19 +326,19 @@ public abstract class BaseNDArray implements INDArray, Iterable { * @param initialize Whether to initialize the INDArray. If true: initialize. If false: don't. */ public BaseNDArray(int[] shape, int[] stride, long offset, char ordering, boolean initialize) { - this(Nd4j.createBuffer(ArrayUtil.prodLong(shape), initialize), shape, stride, offset, ordering); + this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), shape, stride, offset, ordering); } public BaseNDArray(long[] shape, long[] stride, long offset, char ordering, boolean initialize) { - this(Nd4j.createBuffer(ArrayUtil.prodLong(shape), initialize), shape, stride, offset, ordering); + this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), shape, stride, offset, ordering); } public BaseNDArray(DataType type, long[] shape, long[] stride, long offset, char ordering, boolean initialize) { - this(Nd4j.createBuffer(type, ArrayUtil.prodLong(shape), initialize), type, shape, stride, offset, ordering); + this(Nd4j.createBuffer(type, shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), type, shape, stride, offset, ordering); } public BaseNDArray(DataType type, long[] shape, long[] stride, long offset, char ordering, boolean initialize, MemoryWorkspace workspace) { - this(Nd4j.createBuffer(type, ArrayUtil.prodLong(shape), initialize, workspace), type, shape, stride, offset, ordering); + this(Nd4j.createBuffer(type, shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize, workspace), type, shape, stride, offset, ordering); } @@ -1068,7 +1071,8 @@ public abstract class BaseNDArray implements INDArray, Iterable { long offset = index * tensorLength / NDArrayMath.lengthPerSlice(ret2); if (sliceIdx == 0 && length == NDArrayMath.lengthPerSlice(ret2)) { - // FIXME: LONG + if (offset > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); ret2 = ret2.slice((int) offset); if (dimension.length == 1 && ret2.isRowVectorOrScalar()) return ret2; @@ -1078,7 +1082,8 @@ public abstract class BaseNDArray implements INDArray, Iterable { else if (length == NDArrayMath.lengthPerSlice(ret2)) { offset -= ret2.slices() * (offset / ret2.slices()); - // FIXME: LONG + if (offset > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); ret2 = ret2.slice((int) offset); if (dimension.length == 1 && ret2.isRowVectorOrScalar()) return ret2; @@ -3522,16 +3527,21 @@ public abstract class BaseNDArray implements INDArray, Iterable { } @Override - public INDArray repmat(int[] shape) { + public INDArray repmat(long[] shape) { Nd4j.getCompressor().autoDecompress(this); - - long rows = rows() * shape[0]; long cols = columns() * shape[1]; INDArray ret = reshape(1, length()).repeat(0, shape[0]).reshape(rows, columns()).repeat(0, shape[1]); return ret.reshape(rows, cols); } + @Deprecated + @Override + public INDArray repmat(int[] shape) { + long[] longShape = ArrayUtil.toLongArray(shape); + return repmat(longShape); + } + @Override public INDArray repeat(int dimension, long... repeats) { Nd4j.getCompressor().autoDecompress(this); @@ -3669,9 +3679,9 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.create(data, shape, strides, 0, ordering()); } + @Deprecated @Override public INDArray reshape(char order, int... newShape) { - // FIXME: int cast return reshape(order, ArrayUtil.toLongArray(newShape)); } @@ -3973,7 +3983,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { @Override public int columns() { - // FIXME: int cast if (isMatrix()) return (int) size(1); else if (Shape.isColumnVectorShape(shape())) { @@ -3988,7 +3997,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { @Override public int rows() { - // FIXME: if (isMatrix()) return (int) size(0); else if (Shape.isRowVectorShape(shape())) { @@ -4570,7 +4578,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { } else { - // FIXME: int cast int[] repeat = new int[shape.length]; for(int i = 0; i < shape.length; i++) { if(i < rank()) { @@ -4600,9 +4607,9 @@ public abstract class BaseNDArray implements INDArray, Iterable { return broadcast(Nd4j.createUninitialized(this.dataType(), shape, this.ordering())); } + @Deprecated @Override public INDArray dimShuffle(Object[] rearrange, int[] newOrder, boolean[] broadCastable) { - // FIXME: int cast return dimShuffle(rearrange, ArrayUtil.toLongArray(newOrder), broadCastable); } @@ -4916,6 +4923,8 @@ public abstract class BaseNDArray implements INDArray, Iterable { @Override public String toString(@NonNull NDArrayStrings options){ + if(wasClosed()) + return ""; if (!isCompressed() && !preventUnpack) return options.format(this); else if (isCompressed() && compressDebug) @@ -5600,4 +5609,9 @@ public abstract class BaseNDArray implements INDArray, Iterable { return false; } + + @Override + public long getId(){ + return arrayId; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index a2860b582..de80e9413 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -871,6 +871,9 @@ public interface INDArray extends Serializable, AutoCloseable { * @param shape the new shape of this ndarray * @return the shape to fill out to */ + INDArray repmat(long... shape); + + @Deprecated INDArray repmat(int... shape); /** @@ -2814,4 +2817,10 @@ public interface INDArray extends Serializable, AutoCloseable { * @see org.nd4j.linalg.api.ndarray.BaseNDArray#toString(long, boolean, int) */ String toStringFull(); + + /** + * A unique ID for the INDArray object instance. Does not account for views. + * @return INDArray unique ID + */ + long getId(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java index a41dc8790..f6148ed7c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java @@ -63,19 +63,11 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp this.sameDiff = sameDiff; this.inPlace = inPlace; this.dimension = dimension; - if(Shape.isPlaceholderShape(i_v1.getShape())) { - sameDiff.addPropertyToResolve(this,i_v1.getVarName()); - } - - if(Shape.isPlaceholderShape(i_v2.getShape())) { - sameDiff.addPropertyToResolve(this,i_v2.getVarName()); - } sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this); } else { throw new IllegalArgumentException("Input not null variables."); } - } public BaseBroadcastBoolOp(SameDiff sameDiff) { @@ -108,16 +100,6 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp this(sameDiff, i_v, i_v.getShape(), inPlace, dimension, null); } - public BaseBroadcastBoolOp(SameDiff sameDiff, - SDVariable i_v, - int[] shape, - boolean inPlace, - int[] dimension, - Object[] extraArgs) { - // FIXME: int cast - this(sameDiff, i_v, ArrayUtil.toLongArray(shape), inPlace, dimension, extraArgs); - } - public BaseBroadcastBoolOp(SameDiff sameDiff, SDVariable i_v, long[] shape, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java index 7f0d7e40c..9de4d3fbd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java @@ -64,19 +64,10 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp { this.sameDiff = sameDiff; this.inPlace = inPlace; this.dimension = dimension; - if(Shape.isPlaceholderShape(i_v1.getShape())) { - sameDiff.addPropertyToResolve(this,i_v1.getVarName()); - } - - if(Shape.isPlaceholderShape(i_v2.getShape())) { - sameDiff.addPropertyToResolve(this,i_v2.getVarName()); - } sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this); - } else { throw new IllegalArgumentException("Input not null variables."); } - } public BaseBroadcastOp(SameDiff sameDiff) { @@ -109,16 +100,6 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp { this(sameDiff, i_v, i_v.getShape(), inPlace, dimension, null); } - public BaseBroadcastOp(SameDiff sameDiff, - SDVariable i_v, - int[] shape, - boolean inPlace, - int[] dimension, - Object[] extraArgs) { - // FIXME: int cast - this(sameDiff, i_v, ArrayUtil.toLongArray(shape), inPlace, dimension, extraArgs); - } - public BaseBroadcastOp(SameDiff sameDiff, SDVariable i_v, long[] shape, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java index 6e5682962..445a0d585 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java @@ -53,11 +53,8 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum this.dimensions = dimensions; f().validateDifferentialFunctionsameDiff(i_v); sameDiff.addArgsFor(new SDVariable[]{i_v},this); - if(Shape.isPlaceholderShape(i_v.getShape())) { - sameDiff.addPropertyToResolve(this,i_v.getVarName()); - } - this.xVertexId = i_v.getVarName(); + this.xVertexId = i_v.name(); } else { throw new IllegalArgumentException("Input not null variable."); } @@ -75,17 +72,9 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum this.dimensions = dimensions; f().validateDifferentialFunctionsameDiff(i_v); f().validateDifferentialFunctionsameDiff(i_v2); - this.xVertexId = i_v.getVarName(); - this.yVertexId = i_v2.getVarName(); + this.xVertexId = i_v.name(); + this.yVertexId = i_v2.name(); sameDiff.addArgsFor(new SDVariable[]{i_v,i_v2},this); - - if(Shape.isPlaceholderShape(i_v.getShape())) { - sameDiff.addPropertyToResolve(this,i_v.getVarName()); - } - - if(Shape.isPlaceholderShape(i_v2.getShape())) { - sameDiff.addPropertyToResolve(this,i_v2.getVarName()); - } } else { throw new IllegalArgumentException("Input not null variable."); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java index 8c9cdf4e0..65bdb5c1e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java @@ -24,6 +24,7 @@ import onnx.Onnx; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -200,48 +201,17 @@ public abstract class BaseOp extends DifferentialFunction implements Op { @Override public void setX(INDArray x) { - if (x == null) { - if (args() != null && args().length >= 1) { - SDVariable firstArg = args()[0]; - if (firstArg.getArr() != null) - this.x = firstArg.getArr(); - } else - throw new ND4JIllegalStateException("Unable to set null array for x. Also unable to infer from differential function arguments"); - } else - this.x = x; + this.x = x; } @Override public void setZ(INDArray z) { - if (z == null) { - SDVariable getResult = sameDiff.getVariable(zVertexId); - if (getResult != null) { - if (getResult.getArr() != null) - this.z = getResult.getArr(); - else if(sameDiff.getShapeForVarName(getResult.getVarName()) != null) { - val shape = sameDiff.getShapeForVarName(getResult.getVarName()); - sameDiff.setArrayForVariable(getResult.getVarName(),getResult.getWeightInitScheme().create(getResult.dataType(), shape)); - } - else - throw new ND4JIllegalStateException("Unable to set null array for z. Also unable to infer from differential function arguments"); - - } else - throw new ND4JIllegalStateException("Unable to set null array for z. Also unable to infer from differential function arguments"); - } else - this.z = z; + this.z = z; } @Override public void setY(INDArray y) { - if (y == null) { - if (args() != null && args().length > 1) { - SDVariable firstArg = args()[1]; - if (firstArg.getArr() != null) - this.y = firstArg.getArr(); - } else - throw new ND4JIllegalStateException("Unable to set null array for y. Also unable to infer from differential function arguments"); - } else - this.y = y; + this.y = y; } @Override @@ -265,13 +235,19 @@ public abstract class BaseOp extends DifferentialFunction implements Op { return z; } + @Override + public INDArray getInputArgument(int index){ + Preconditions.checkState(index >= 0 && index < 2, "Input argument index must be 0 or 1, got %s", index); + return index == 0 ? x : y; + } + @Override public SDVariable[] outputVariables(String baseName) { if(zVertexId == null) { val outputNames = sameDiff.getOutputsForOp(this); //no need to dynamically create if already exists if(outputNames != null) { - zVertexId = sameDiff.getVariable(outputNames[0]).getVarName(); + zVertexId = sameDiff.getVariable(outputNames[0]).name(); return new SDVariable[]{sameDiff.getVariable(outputNames[0])}; @@ -285,7 +261,7 @@ public abstract class BaseOp extends DifferentialFunction implements Op { return newVars; } - sameDiff.setArrayForVariable(newVars[0].getVarName(),inputArr); + sameDiff.setArrayForVariable(newVars[0].name(),inputArr); z = inputArr; if(sameDiff.getOutputsForOp(this) == null) sameDiff.addOutgoingFor(newVars,this); @@ -403,4 +379,11 @@ public abstract class BaseOp extends DifferentialFunction implements Op { //Always 1 for legacy/base ops return 1; } + + @Override + public void clearArrays(){ + x = null; + y = null; + z = null; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java index 10c26d29e..9e5b8f67b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java @@ -16,7 +16,6 @@ package org.nd4j.linalg.api.ops; -import org.nd4j.shade.guava.primitives.Ints; import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; @@ -24,21 +23,14 @@ import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper; -import org.nd4j.imports.graphmapper.tf.TFGraphMapper; -import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; -import org.nd4j.linalg.exception.ND4JIllegalStateException; -import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; @@ -69,12 +61,8 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp { this.dimensions = dimensions; f().validateDifferentialFunctionsameDiff(i_v); this.keepDims = keepDims; - this.xVertexId = i_v.getVarName(); + this.xVertexId = i_v.name(); sameDiff.addArgsFor(new String[]{xVertexId},this); - if(Shape.isPlaceholderShape(i_v.getShape())) { - sameDiff.addPropertyToResolve(this,i_v.getVarName()); - } - } else { throw new IllegalArgumentException("Input not null variable."); } @@ -93,8 +81,8 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp { this.dimensions = dimensions; - this.xVertexId = i_v.getVarName(); - this.yVertexId = i_v2.getVarName(); + this.xVertexId = i_v.name(); + this.yVertexId = i_v2.name(); f().validateDifferentialFunctionsameDiff(i_v); f().validateDifferentialFunctionsameDiff(i_v2); this.keepDims = keepDims; @@ -219,14 +207,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp { @Override public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - if (!attributesForNode.containsKey("axes")) { - this.dimensions = new int[] { Integer.MAX_VALUE }; - } - else { - val map = OnnxGraphMapper.getInstance().getAttrMap(node); - val dims = Ints.toArray(map.get("axes").getIntsList()); - this.dimensions = dims; - } + } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java index 54b53f489..3abca3db2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java @@ -74,11 +74,8 @@ public abstract class BaseScalarBoolOp extends BaseOp implements ScalarOp { super(sameDiff,inPlace,extraArgs); this.scalarValue = Nd4j.scalar(i_v.dataType(), scalar); if (i_v != null) { - this.xVertexId = i_v.getVarName(); + this.xVertexId = i_v.name(); sameDiff.addArgsFor(new String[]{xVertexId},this); - if(Shape.isPlaceholderShape(i_v.getShape())) { - sameDiff.addPropertyToResolve(this,i_v.getVarName()); - } f().validateDifferentialFunctionsameDiff(i_v); } else { throw new IllegalArgumentException("Input not null variable."); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java index 0048c9402..ce74a7cd1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java @@ -93,11 +93,8 @@ public abstract class BaseScalarOp extends BaseOp implements ScalarOp { Object[] extraArgs) { super(sameDiff,inPlace,extraArgs); this.scalarValue = Nd4j.scalar(i_v.dataType(), scalar); - this.xVertexId = i_v.getVarName(); + this.xVertexId = i_v.name(); sameDiff.addArgsFor(new String[]{xVertexId},this); - if(Shape.isPlaceholderShape(i_v.getShape())) { - sameDiff.addPropertyToResolve(this,i_v.getVarName()); - } f().validateDifferentialFunctionsameDiff(i_v); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformAnyOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformAnyOp.java index 046e296e5..8efba0fdf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformAnyOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformAnyOp.java @@ -49,10 +49,6 @@ public abstract class BaseTransformAnyOp extends BaseTransformOp implements Tran super(sameDiff, i_v1, i_v2, extraArgs); } - public BaseTransformAnyOp(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, Object[] extraArgs) { - super(sameDiff, i_v, shape, inPlace, extraArgs); - } - public BaseTransformAnyOp(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformBoolOp.java index 68a06f61c..df0e04d5c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformBoolOp.java @@ -40,10 +40,6 @@ public abstract class BaseTransformBoolOp extends BaseTransformOp implements Tra super(sameDiff, i_v1, i_v2, inPlace); } - public BaseTransformBoolOp(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, Object[] extraArgs) { - super(sameDiff, i_v, shape, inPlace, extraArgs); - } - public BaseTransformBoolOp(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformFloatOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformFloatOp.java index fc778af90..ee97a4bba 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformFloatOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformFloatOp.java @@ -33,10 +33,6 @@ import java.util.List; public abstract class BaseTransformFloatOp extends BaseTransformOp implements TransformFloatOp { - public BaseTransformFloatOp(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, Object[] extraArgs) { - super(sameDiff, i_v, shape, inPlace, extraArgs); - } - public BaseTransformFloatOp(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java index 8afc68a52..4e498edeb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java @@ -56,16 +56,9 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp { f().validateDifferentialFunctionsameDiff(i_v2); this.sameDiff = sameDiff; this.inPlace = inPlace; - this.xVertexId = i_v1.getVarName(); - this.yVertexId = i_v2.getVarName(); + this.xVertexId = i_v1.name(); + this.yVertexId = i_v2.name(); sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this); - if(Shape.isPlaceholderShape(i_v1.getShape())) { - sameDiff.addPropertyToResolve(this,i_v1.getVarName()); - } - - if(Shape.isPlaceholderShape(i_v2.getShape())) { - sameDiff.addPropertyToResolve(this,i_v2.getVarName()); - } } else { throw new IllegalArgumentException("Input not null variables."); } @@ -87,18 +80,9 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp { f().validateDifferentialFunctionsameDiff(i_v1); f().validateDifferentialFunctionsameDiff(i_v2); this.sameDiff = sameDiff; - this.xVertexId = i_v1.getVarName(); - this.yVertexId = i_v2.getVarName(); + this.xVertexId = i_v1.name(); + this.yVertexId = i_v2.name(); sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this); - - if(Shape.isPlaceholderShape(i_v1.getShape())) { - sameDiff.addPropertyToResolve(this,i_v1.getVarName()); - } - - if(Shape.isPlaceholderShape(i_v2.getShape())) { - sameDiff.addPropertyToResolve(this,i_v2.getVarName()); - } - } else { throw new IllegalArgumentException("Input not null variables."); } @@ -112,15 +96,6 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp { this(sameDiff,i_v,i_v.getShape(),inPlace,null); } - public BaseTransformOp(SameDiff sameDiff, - SDVariable i_v, - int[] shape, - boolean inPlace, - Object[] extraArgs) { - // FIXME: int cast ! - this(sameDiff, i_v, ArrayUtil.toLongArray(shape), inPlace, extraArgs); - } - public BaseTransformOp(SameDiff sameDiff, SDVariable i_v, long[] shape, @@ -130,14 +105,8 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp { if (i_v != null) { f().validateDifferentialFunctionsameDiff(i_v); - this.xVertexId = i_v.getVarName(); + this.xVertexId = i_v.name(); sameDiff.addArgsFor(new SDVariable[]{i_v},this); - - if(Shape.isPlaceholderShape(i_v.getShape())) { - sameDiff.addPropertyToResolve(this,i_v.getVarName()); - } - - } else { throw new IllegalArgumentException("Input must not null variable."); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformSameOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformSameOp.java index 7fc34d0e5..b04c24c8c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformSameOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformSameOp.java @@ -44,10 +44,6 @@ public abstract class BaseTransformSameOp extends BaseTransformOp implements Tra super(sameDiff, i_v1, i_v2, extraArgs); } - public BaseTransformSameOp(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, Object[] extraArgs) { - super(sameDiff, i_v, shape, inPlace, extraArgs); - } - public BaseTransformSameOp(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformStrictOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformStrictOp.java index ff40ebae4..ff89e49ba 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformStrictOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformStrictOp.java @@ -41,10 +41,6 @@ public abstract class BaseTransformStrictOp extends BaseTransformOp implements T super(sameDiff, i_v1, i_v2, inPlace); } - public BaseTransformStrictOp(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, Object[] extraArgs) { - super(sameDiff, i_v, shape, inPlace, extraArgs); - } - public BaseTransformStrictOp(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java index 9deb230df..c228c9b8f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java @@ -119,4 +119,9 @@ public interface CustomOp { * otherwise throws an {@link org.nd4j.linalg.exception.ND4JIllegalStateException} */ void assertValidForExecution(); + + /** + * Clear the input and output INDArrays, if any are set + */ + void clearArrays(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java index d2190098c..99e930176 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java @@ -223,7 +223,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { if (args().length >= 1) { val arr = args()[0].getArr(); if (arr != null) { - sameDiff.setArrayForVariable(newVars[0].getVarName(), arr); + sameDiff.setArrayForVariable(newVars[0].name(), arr); addOutputArgument(arr); } } @@ -263,7 +263,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { @Override public INDArray[] outputArguments() { if (!outputArguments.isEmpty()) { - return outputArguments.toArray(new INDArray[outputArguments.size()]); + return outputArguments.toArray(new INDArray[0]); } return new INDArray[0]; } @@ -271,7 +271,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { @Override public INDArray[] inputArguments() { if (!inputArguments.isEmpty()) - return inputArguments.toArray(new INDArray[inputArguments.size()]); + return inputArguments.toArray(new INDArray[0]); return new INDArray[0]; } @@ -389,6 +389,13 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { } public void setInputArgument(int index, INDArray input) { + if(index >= inputArguments.size() ){ + List oldArgs = inputArguments; + inputArguments = new ArrayList<>(index+1); + inputArguments.addAll(oldArgs); + while(inputArguments.size() <= index) + inputArguments.add(null); + } inputArguments.set(index, input); } @@ -400,12 +407,12 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { } public void setOutputArgument(int index, INDArray output) { - if(index == outputArguments.size()){ - //For example, setOutputArgument(0,arr) on empty list - outputArguments.add(output); - } else { - outputArguments.set(index, output); + while(index >= outputArguments.size()){ + //Resize list, in case we want to specify arrays not in order they are defined + //For example, index 1 on empty list, then index 0 + outputArguments.add(null); } + outputArguments.set(index, output); } @Override @@ -608,6 +615,12 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { } + @Override + public void clearArrays(){ + inputArguments.clear(); + outputArguments.clear(); + } + protected static INDArray[] wrapOrNull(INDArray in){ return in == null ? null : new INDArray[]{in}; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/Op.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/Op.java index 3e5644439..ca0a816c0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/Op.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/Op.java @@ -167,4 +167,9 @@ public interface Op { * @return the equivalent {@link CustomOp} */ CustomOp toCustomOp(); + + /** + * Clear the input and output INDArrays, if any are set + */ + void clearArrays(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/impl/AggregateCBOW.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/impl/AggregateCBOW.java deleted file mode 100644 index a9d327a35..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/impl/AggregateCBOW.java +++ /dev/null @@ -1,172 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.aggregates.impl; - -import lombok.NonNull; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.aggregates.BaseAggregate; -import org.nd4j.linalg.factory.Nd4j; - -/** - * @author raver119@gmail.com - */ -@Deprecated -public class AggregateCBOW extends BaseAggregate { - private int vectorLength; - - /** - * Optional constructor for ParagraphVectors PV-DM implementation - * - * @param syn0 - * @param syn1 - * @param syn1Neg - * @param expTable - * @param negTable - * @param wordIdx - * @param idxSyn0 - * @param idxSyn1 - * @param codes - * @param negativeRounds - * @param ngStarter - * @param vectorLength - * @param alpha - * @param nextRandom - * @param vocabSize - * @param numLabels - * @param trainWords - */ - public AggregateCBOW(@NonNull INDArray syn0, INDArray syn1, INDArray syn1Neg, @NonNull INDArray expTable, - INDArray negTable, int wordIdx, int[] idxSyn0, int[] idxSyn1, int[] codes, int negativeRounds, - int ngStarter, int vectorLength, double alpha, long nextRandom, int vocabSize, int numLabels, - boolean trainWords, INDArray inferenceVector) { - this(syn0, syn1, syn1Neg, expTable, negTable, wordIdx, idxSyn0, idxSyn1, codes, negativeRounds, ngStarter, - vectorLength, alpha, nextRandom, vocabSize); - - indexingArguments.set(9, numLabels); - indexingArguments.set(10, trainWords ? 1 : 0); - indexingArguments.set(11, inferenceVector == null ? 0 : 1); // set inference to true - - arguments.set(5, inferenceVector); - } - - /** - * Default constructor for CBOW implementation wrapper - * @param syn0 - * @param syn1 - * @param syn1Neg - * @param expTable - * @param negTable - * @param wordIdx - * @param idxSyn0 - * @param idxSyn1 - * @param codes - * @param negativeRounds - * @param ngStarter - * @param vectorLength - * @param alpha - * @param nextRandom - * @param vocabSize - */ - public AggregateCBOW(@NonNull INDArray syn0, INDArray syn1, INDArray syn1Neg, @NonNull INDArray expTable, - INDArray negTable, int wordIdx, int[] idxSyn0, int[] idxSyn1, int[] codes, int negativeRounds, - int ngStarter, int vectorLength, double alpha, long nextRandom, int vocabSize) { - indexingArguments.add(vectorLength); - indexingArguments.add(idxSyn1.length); - indexingArguments.add(negativeRounds); - - // FIXME: int cast - indexingArguments.add((int) expTable.length()); - indexingArguments.add(vocabSize); - indexingArguments.add(ngStarter); - indexingArguments.add(negTable == null ? 0 : (int) negTable.length()); - indexingArguments.add(idxSyn0.length); - indexingArguments.add(wordIdx); - indexingArguments.add(0); // number of labels. 0 by default - indexingArguments.add(1); // trainWords? true by default - indexingArguments.add(0); // is inference? false by default - - - arguments.add(syn0); - arguments.add(syn1); - arguments.add(expTable); - arguments.add(syn1Neg); - arguments.add(negTable); - arguments.add(null); - - intArrayArguments.add(idxSyn0); - intArrayArguments.add(idxSyn1); - intArrayArguments.add(codes); - - realArguments.add(alpha); - realArguments.add((double) nextRandom); - - this.vectorLength = vectorLength; - } - - @Override - public String name() { - return "aggregate_cbow"; - } - - @Override - public int opNum() { - return 4; - } - - @Override - public int maxArguments() { - return 6; - } - - @Override - public int maxShapes() { - return 0; - } - - @Override - public int maxIntArrays() { - return 3; - } - - @Override - public int maxIntArraySize() { - return 40; - } - - @Override - public int maxIndexArguments() { - return 12; - } - - @Override - public int maxRealArguments() { - return 2; - } - - @Override - public int getSharedMemorySize() { - return (vectorLength * Nd4j.sizeOfDataType() * 2) + 512; - } - - @Override - public int getThreadsPerInstance() { - if (vectorLength > 768) - return 768; - - return vectorLength; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/impl/AggregateDot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/impl/AggregateDot.java deleted file mode 100644 index a5ef4a4da..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/impl/AggregateDot.java +++ /dev/null @@ -1,107 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.aggregates.impl; - -import lombok.NonNull; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.aggregates.BaseAggregate; -import org.nd4j.linalg.factory.Nd4j; - -/** - * This op describes Dot call that'll happen soon(TM) in batch mode - * - * @author raver119@gmail.com - */ -@Deprecated -public class AggregateDot extends BaseAggregate { - private int vectorLength; - - public AggregateDot(@NonNull INDArray x, @NonNull INDArray y) { - this.arguments.add(x); - this.arguments.add(y); - - // FIXME: int cast - - this.indexingArguments.add((int) x.length()); - this.vectorLength = (int) x.length(); - } - - /** - * This method returns amount of shared memory required for this specific Aggregate. - * PLEASE NOTE: this method is especially important for CUDA backend. On CPU backend it might be ignored, depending on Aggregate. - * - * @return - */ - @Override - public int getSharedMemorySize() { - return (getThreadsPerInstance() * Nd4j.sizeOfDataType()) + 512; - } - - /** - * This method returns desired number of threads per Aggregate instance - * PLEASE NOTE: this method is especially important for CUDA backend. On CPU backend it might be ignored, depending on Aggregate. - * - * @return - */ - @Override - public int getThreadsPerInstance() { - if (vectorLength > 768) - return 768; - - return vectorLength; - } - - @Override - public String name() { - return "aggregate_dot"; - } - - @Override - public int opNum() { - return 1; - } - - @Override - public int maxArguments() { - return 2; - } - - @Override - public int maxShapes() { - return 0; - } - - @Override - public int maxIntArrays() { - return 0; - } - - @Override - public int maxIntArraySize() { - return 0; - } - - @Override - public int maxIndexArguments() { - return 1; - } - - @Override - public int maxRealArguments() { - return 0; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/impl/AggregateSkipGram.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/impl/AggregateSkipGram.java deleted file mode 100644 index 7fa52ece2..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/impl/AggregateSkipGram.java +++ /dev/null @@ -1,165 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.aggregates.impl; - -import lombok.NonNull; -import lombok.extern.slf4j.Slf4j; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.aggregates.BaseAggregate; -import org.nd4j.linalg.factory.Nd4j; - -/** - * This aggregate encapsulates AggregateSkipGram training round for a given word and context - * - * @author raver119@gmail.com - */ -@Slf4j -@Deprecated -public class AggregateSkipGram extends BaseAggregate { - private int vectorLength; - - public AggregateSkipGram(INDArray syn0, INDArray syn1, INDArray syn1Neg, INDArray expTable, INDArray negTable, - int idxSyn0, int[] idxSyn1, int[] codes, int negativeRounds, int ngStarter, int vectorLength, - double alpha, long nextRandom, int vocabSize, INDArray inferenceVector) { - this(syn0, syn1, syn1Neg, expTable, negTable, idxSyn0, idxSyn1, codes, negativeRounds, ngStarter, vectorLength, - alpha, nextRandom, vocabSize); - - arguments.set(5, inferenceVector); - - indexingArguments.set(8, inferenceVector == null ? 0 : 1); // set isInference to true - } - - public AggregateSkipGram(@NonNull INDArray syn0, INDArray syn1, INDArray syn1Neg, @NonNull INDArray expTable, - INDArray negTable, int idxSyn0, int[] idxSyn1, int[] codes, int negativeRounds, int ngStarter, - int vectorLength, double alpha, long nextRandom, int vocabSize) { - indexingArguments.add(idxSyn0); - indexingArguments.add(vectorLength); - indexingArguments.add(idxSyn1.length); - indexingArguments.add(negativeRounds); - - // FIXME: int cast - indexingArguments.add((int) expTable.length()); - indexingArguments.add(vocabSize); - indexingArguments.add(ngStarter); - - indexingArguments.add(negTable == null ? 0 : (int) negTable.length()); - indexingArguments.add(0); - - arguments.add(syn0); - arguments.add(syn1); - arguments.add(expTable); - arguments.add(syn1Neg); - arguments.add(negTable); - arguments.add(null); - - intArrayArguments.add(idxSyn1); - intArrayArguments.add(codes); - - realArguments.add(alpha); - realArguments.add((double) nextRandom); - - this.vectorLength = vectorLength; - } - - /** - * This is special signature suitable for use with VoidParameterServer, never ever use it outside of spark-nlp - * - * @param w1 - * @param w2 - * @param lr - * @param vectorLength - */ - // TODO: probably this signature should be removed? - public AggregateSkipGram(int w1, int w2, int[] codes, int[] points, int negSamples, double lr, int vectorLength) { - indexingArguments.add(w1); - indexingArguments.add(w2); - indexingArguments.add(vectorLength); - - intArrayArguments.add(codes); - intArrayArguments.add(points); - - realArguments.add(lr); - } - - - /** - * This method returns amount of shared memory required for this specific Aggregate. - * PLEASE NOTE: this method is especially important for CUDA backend. On CPU backend it might be ignored, depending on Aggregate. - * - * @return - */ - @Override - public int getSharedMemorySize() { - return (vectorLength * Nd4j.sizeOfDataType()) + 512; - } - - /** - * This method returns desired number of threads per Aggregate instance - * PLEASE NOTE: this method is especially important for CUDA backend. On CPU backend it might be ignored, depending on Aggregate. - * - * @return - */ - @Override - public int getThreadsPerInstance() { - if (vectorLength > 768) - return 768; - - return vectorLength; - } - - @Override - public String name() { - return "aggregate_skipgram"; - } - - @Override - public int opNum() { - return 3; - } - - @Override - public int maxArguments() { - return 6; - } - - @Override - public int maxShapes() { - return 0; - } - - @Override - public int maxIntArrays() { - return 2; - } - - @Override - public int maxIntArraySize() { - // we hardcode 40 here, due to w2v codeLength mechanics - // TODO: make sure this limitation doesn't bother with spark environment - return 40; - } - - @Override - public int maxIndexArguments() { - return 10; - } - - @Override - public int maxRealArguments() { - return 2; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/impl/HierarchicSoftmax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/impl/HierarchicSoftmax.java deleted file mode 100644 index de494dbff..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/impl/HierarchicSoftmax.java +++ /dev/null @@ -1,114 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.aggregates.impl; - -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.aggregates.BaseAggregate; -import org.nd4j.linalg.factory.Nd4j; - -/** - * This Op describes HS round for AggregateSkipGram/CBOW Hierarchic Softmax - * - * @author raver119@gmail.com - */ -@Deprecated -public class HierarchicSoftmax extends BaseAggregate { - private int vectorLength; - - public HierarchicSoftmax(INDArray syn0, INDArray syn1, INDArray expTable, INDArray neu1e, int code, double lr) { - arguments.add(syn0); - arguments.add(syn1); - arguments.add(expTable); - arguments.add(neu1e); - - // FIXME: int cast - - indexingArguments.add((int) neu1e.length()); - indexingArguments.add((int) expTable.length()); - indexingArguments.add(code); - indexingArguments.add(0); // set isInference to false - - realArguments.add(lr); - - this.vectorLength = (int) neu1e.length(); - } - - /** - * This method returns amount of shared memory required for this specific Aggregate. - * PLEASE NOTE: this method is especially important for CUDA backend. On CPU backend it might be ignored, depending on Aggregate. - * - * @return - */ - @Override - public int getSharedMemorySize() { - return (getThreadsPerInstance() * Nd4j.sizeOfDataType()) + 512; - } - - /** - * This method returns desired number of threads per Aggregate instance - * PLEASE NOTE: this method is especially important for CUDA backend. On CPU backend it might be ignored, depending on Aggregate. - * - * @return - */ - @Override - public int getThreadsPerInstance() { - if (vectorLength > 768) - return 768; - - return vectorLength; - } - - @Override - public int opNum() { - return 0; - } - - @Override - public String name() { - return "aggregate_hs"; - } - - @Override - public int maxArguments() { - return 4; - } - - @Override - public int maxShapes() { - return 0; - } - - @Override - public int maxIntArrays() { - return 0; - } - - @Override - public int maxIntArraySize() { - return 0; - } - - @Override - public int maxIndexArguments() { - return 5; - } - - @Override - public int maxRealArguments() { - return 1; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java new file mode 100644 index 000000000..80f29e577 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java @@ -0,0 +1,30 @@ +package org.nd4j.linalg.api.ops.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +public class AdjustContrast extends BaseAdjustContrast { + + public AdjustContrast() {super();} + + public AdjustContrast(INDArray in, double factor, INDArray out) { + super(in, factor, out); + } + + public AdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor) { + super(sameDiff,new SDVariable[]{in,factor}); + } + + @Override + public String opName() { + return "adjust_contrast"; + } + + @Override + public String tensorflowName() { + return "AdjustContrast"; + } +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java new file mode 100644 index 000000000..5adfcbafd --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java @@ -0,0 +1,30 @@ +package org.nd4j.linalg.api.ops.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +public class AdjustContrastV2 extends BaseAdjustContrast { + + public AdjustContrastV2() {super();} + + public AdjustContrastV2(INDArray in, double factor, INDArray out) { + super(in, factor, out); + } + + public AdjustContrastV2(SameDiff sameDiff, SDVariable in, SDVariable factor) { + super( sameDiff,new SDVariable[]{in,factor}); + } + + @Override + public String opName() { + return "adjust_contrast_v2"; + } + + @Override + public String tensorflowName() { + return "AdjustContrastV2"; + } +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java new file mode 100644 index 000000000..cadef80e6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java @@ -0,0 +1,25 @@ +package org.nd4j.linalg.api.ops.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +public abstract class BaseAdjustContrast extends DynamicCustomOp { + public BaseAdjustContrast() { + } + + public BaseAdjustContrast(INDArray in, double factor, INDArray out) { + Preconditions.checkArgument(in.rank() >= 3, + String.format("AdjustContrast: op expects rank of input array to be >= 3, but got %d instead", in.rank())); + inputArguments.add(in); + outputArguments.add(out); + + addTArgument(factor); + } + + public BaseAdjustContrast(SameDiff sameDiff, SDVariable[] vars) { + super("", sameDiff, vars); + } +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java new file mode 100644 index 000000000..ee0adfb94 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java @@ -0,0 +1,32 @@ +package org.nd4j.linalg.api.ops.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; + +public class BitCast extends DynamicCustomOp { + public BitCast() {} + + public BitCast(INDArray in, int dataType, INDArray out) { + inputArguments.add(in); + outputArguments.add(out); + iArguments.add(Long.valueOf(dataType)); + } + + public BitCast(SameDiff sameDiff, SDVariable in, SDVariable dataType) { + super("", sameDiff, new SDVariable[]{in, dataType}); + } + + @Override + public String opName() { + return "bitcast"; + } + + @Override + public String tensorflowName() { + return "Bitcast"; + } +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java new file mode 100644 index 000000000..d69c73da4 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java @@ -0,0 +1,31 @@ +package org.nd4j.linalg.api.ops.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; + +public class CompareAndBitpack extends DynamicCustomOp { + public CompareAndBitpack() {} + + public CompareAndBitpack(INDArray in, double threshold, INDArray out) { + inputArguments.add(in); + inputArguments.add(Nd4j.scalar(threshold)); + outputArguments.add(out); + } + + public CompareAndBitpack(SameDiff sameDiff, SDVariable threshold) { + super("", sameDiff, new SDVariable[]{threshold}); + } + + @Override + public String opName() { + return "compare_and_bitpack"; + } + + @Override + public String tensorflowName() { + return "CompareAndBitpack"; + } +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java new file mode 100644 index 000000000..400830ec3 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java @@ -0,0 +1,32 @@ +package org.nd4j.linalg.api.ops.custom; + +import org.apache.commons.math3.analysis.function.Divide; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +public class DivideNoNan extends DynamicCustomOp { + public DivideNoNan() { + } + + public DivideNoNan(INDArray in1, INDArray in2, INDArray out) { + inputArguments.add(in1); + inputArguments.add(in2); + outputArguments.add(out); + } + + public DivideNoNan(SameDiff sameDiff, SDVariable in1, SDVariable in2) { + super("", sameDiff, new SDVariable[]{in1, in2}); + } + + @Override + public String opName() { + return "divide_no_nan"; + } + + @Override + public String tensorflowName() { + return "DivNoNan"; + } +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java new file mode 100644 index 000000000..4c672a66c --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java @@ -0,0 +1,32 @@ +package org.nd4j.linalg.api.ops.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +public class DrawBoundingBoxes extends DynamicCustomOp { + public DrawBoundingBoxes() {} + + public DrawBoundingBoxes(INDArray images, INDArray boxes, INDArray colors, + INDArray output) { + inputArguments.add(images); + inputArguments.add(boxes); + inputArguments.add(colors); + outputArguments.add(output); + } + + public DrawBoundingBoxes(SameDiff sameDiff, SDVariable boxes, SDVariable colors) { + super("", sameDiff, new SDVariable[]{boxes, colors}); + } + + @Override + public String opName() { + return "draw_bounding_boxes"; + } + + @Override + public String tensorflowName() { + return "DrawBoundingBoxes"; + } +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java new file mode 100644 index 000000000..303ac8458 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java @@ -0,0 +1,36 @@ +package org.nd4j.linalg.api.ops.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp { + public FakeQuantWithMinMaxVarsPerChannel() {} + + public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max, + INDArray output) { + Preconditions.checkArgument(min.isVector() && max.isVector() && + min.length() == max.length(), + "FakeQuantWithMinMaxVarsPerChannel: min and max should be 1D tensors with the same length"); + inputArguments.add(x); + inputArguments.add(min); + inputArguments.add(max); + outputArguments.add(output); + } + + public FakeQuantWithMinMaxVarsPerChannel(SameDiff sameDiff, SDVariable x, SDVariable min, SDVariable max) { + super("", sameDiff, new SDVariable[]{x, min, max}); + } + + @Override + public String opName() { + return "fake_quant_with_min_max_vars_per_channel"; + } + + @Override + public String tensorflowName() { + return "FakeQuantWithMinMaxVarsPerChannel"; + } +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/KnnMinDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/KnnMinDistance.java new file mode 100644 index 000000000..16656766f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/KnnMinDistance.java @@ -0,0 +1,23 @@ +package org.nd4j.linalg.api.ops.custom; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +public class KnnMinDistance extends DynamicCustomOp { + + public KnnMinDistance() { + } + + public KnnMinDistance(INDArray point, INDArray lowest, INDArray highest, INDArray distance) { + inputArguments.add(point); + inputArguments.add(lowest); + inputArguments.add(highest); + + outputArguments.add(distance); + } + + @Override + public String opName() { + return "knn_mindistance"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java index 0741e512e..cb805a775 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java @@ -245,4 +245,9 @@ public class ScatterUpdate implements CustomOp { public void assertValidForExecution() { } + + @Override + public void clearArrays() { + op.clearArrays(); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java index aed50c987..d8bf3f695 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java @@ -39,13 +39,18 @@ import java.util.*; @NoArgsConstructor public class BiasAdd extends DynamicCustomOp { + protected boolean nchw = true; - public BiasAdd(SameDiff sameDiff, SDVariable input, SDVariable bias) { + public BiasAdd(SameDiff sameDiff, SDVariable input, SDVariable bias, boolean nchw) { super(null, sameDiff, new SDVariable[] {input, bias}, false); + bArguments.clear(); + bArguments.add(nchw); } - public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output){ + public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output, boolean nchw){ super(new INDArray[]{input, bias}, wrapOrNull(output)); + bArguments.clear(); + bArguments.add(nchw); } @Override @@ -56,7 +61,11 @@ public class BiasAdd extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph); - + if(attributesForNode.containsKey("data_format")){ + nchw = "NCHW".equalsIgnoreCase(attributesForNode.get("data_format").getS().toStringUtf8()); + } + bArguments.clear(); + bArguments.add(nchw); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastAMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastAMax.java index 726d82b0a..c1fceccb6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastAMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastAMax.java @@ -50,7 +50,7 @@ public class BroadcastAMax extends BaseBroadcastOp { super(sameDiff, i_v, dimension, inPlace); } - public BroadcastAMax(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, int[] dimension, Object[] extraArgs) { + public BroadcastAMax(SameDiff sameDiff, SDVariable i_v, long[] shape, boolean inPlace, int[] dimension, Object[] extraArgs) { super(sameDiff, i_v, shape, inPlace, dimension, extraArgs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastCopyOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastCopyOp.java index 639ee24ec..00700b3c6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastCopyOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastCopyOp.java @@ -48,7 +48,7 @@ public class BroadcastCopyOp extends BaseBroadcastOp { super(sameDiff, i_v, dimension, inPlace); } - public BroadcastCopyOp(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, int[] dimension, Object[] extraArgs) { + public BroadcastCopyOp(SameDiff sameDiff, SDVariable i_v, long[] shape, boolean inPlace, int[] dimension, Object[] extraArgs) { super(sameDiff, i_v, shape, inPlace, dimension, extraArgs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMin.java index fe47e2bc3..8a7234532 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMin.java @@ -58,7 +58,7 @@ public class BroadcastMin extends BaseBroadcastOp { super(sameDiff, i_v, dimension, inPlace); } - public BroadcastMin(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, int[] dimension, Object[] extraArgs) { + public BroadcastMin(SameDiff sameDiff, SDVariable i_v, long[] shape, boolean inPlace, int[] dimension, Object[] extraArgs) { super(sameDiff, i_v, shape, inPlace, dimension, extraArgs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastRSubOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastRSubOp.java index 0ddf777c2..1a4ec9887 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastRSubOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastRSubOp.java @@ -46,7 +46,7 @@ public class BroadcastRSubOp extends BaseBroadcastOp { super(sameDiff, i_v, dimension, inPlace); } - public BroadcastRSubOp(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, int[] dimension, Object[] extraArgs) { + public BroadcastRSubOp(SameDiff sameDiff, SDVariable i_v, long[] shape, boolean inPlace, int[] dimension, Object[] extraArgs) { super(sameDiff, i_v, shape, inPlace, dimension, extraArgs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastSubOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastSubOp.java index 035e6d882..e060db4b6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastSubOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastSubOp.java @@ -52,7 +52,7 @@ public class BroadcastSubOp extends BaseBroadcastOp { super(sameDiff, i_v, dimension, inPlace); } - public BroadcastSubOp(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, int[] dimension, Object[] extraArgs) { + public BroadcastSubOp(SameDiff sameDiff, SDVariable i_v, long[] shape, boolean inPlace, int[] dimension, Object[] extraArgs) { super(sameDiff, i_v, shape, inPlace, dimension, extraArgs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastGreaterThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastGreaterThan.java index 733744fcd..63f1e2c45 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastGreaterThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastGreaterThan.java @@ -53,7 +53,7 @@ public class BroadcastGreaterThan extends BaseBroadcastBoolOp { super(sameDiff, i_v, dimension, inPlace); } - public BroadcastGreaterThan(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, int[] dimension, Object[] extraArgs) { + public BroadcastGreaterThan(SameDiff sameDiff, SDVariable i_v, long[] shape, boolean inPlace, int[] dimension, Object[] extraArgs) { super(sameDiff, i_v, shape, inPlace, dimension, extraArgs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThan.java index 0f715a56a..9fab3350f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThan.java @@ -54,7 +54,7 @@ public class BroadcastLessThan extends BaseBroadcastBoolOp { super(sameDiff, i_v, dimension, inPlace); } - public BroadcastLessThan(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, int[] dimension, Object[] extraArgs) { + public BroadcastLessThan(SameDiff sameDiff, SDVariable i_v, long[] shape, boolean inPlace, int[] dimension, Object[] extraArgs) { super(sameDiff, i_v, shape, inPlace, dimension, extraArgs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThanOrEqual.java index 903919d4b..e9ee1db2c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThanOrEqual.java @@ -54,7 +54,7 @@ public class BroadcastLessThanOrEqual extends BaseBroadcastBoolOp { super(sameDiff, i_v, dimension, inPlace); } - public BroadcastLessThanOrEqual(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, int[] dimension, Object[] extraArgs) { + public BroadcastLessThanOrEqual(SameDiff sameDiff, SDVariable i_v, long[] shape, boolean inPlace, int[] dimension, Object[] extraArgs) { super(sameDiff, i_v, shape, inPlace, dimension, extraArgs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/If.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/If.java deleted file mode 100644 index 03dc26313..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/If.java +++ /dev/null @@ -1,402 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.controlflow; - -import lombok.*; -import lombok.extern.slf4j.Slf4j; -import onnx.Onnx; -import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.samediff.SameDiffConditional; -import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.imports.graphmapper.tf.TFGraphMapper; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.CustomOp; -import org.nd4j.linalg.api.ops.CustomOpDescriptor; -import org.nd4j.linalg.api.ops.Op; -import org.nd4j.linalg.api.shape.LongShapeDescriptor; -import org.nd4j.linalg.util.HashUtil; -import org.nd4j.weightinit.impl.ZeroInitScheme; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; - -import java.util.*; - -/** - * Equivalent to tensorflow's conditional op. - * Runs one of 2 {@link SameDiff.SameDiffFunctionDefinition} - * depending on a predicate {@link org.nd4j.autodiff.samediff.SameDiff.SameDiffConditional} - * - * - * @author Adam Gibson - */ -@NoArgsConstructor -@Slf4j -public class If extends DifferentialFunction implements CustomOp { - - @Getter - protected SameDiff loopBodyExecution,predicateExecution,falseBodyExecution; - - - @Getter - protected SameDiffConditional predicate; - @Getter - protected SameDiffFunctionDefinition trueBody,falseBody; - - @Getter - protected String blockName,trueBodyName,falseBodyName; - - @Getter - protected SDVariable[] inputVars; - - @Getter - protected Boolean trueBodyExecuted = null; - - @Getter - protected SDVariable targetBoolean; - - protected SDVariable dummyResult; - - @Getter - @Setter - protected SDVariable[] outputVars; - - public If(If ifStatement) { - this.sameDiff = ifStatement.sameDiff; - this.outputVars = ifStatement.outputVars; - this.falseBodyExecution = ifStatement.falseBodyExecution; - this.trueBodyExecuted = ifStatement.trueBodyExecuted; - this.falseBody = ifStatement.falseBody; - this.trueBodyExecuted = ifStatement.trueBodyExecuted; - this.dummyResult = ifStatement.dummyResult; - this.inputVars = ifStatement.inputVars; - this.dummyResult = this.sameDiff.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme(), DataType.FLOAT, 1); - if(sameDiff.getShapeForVarName(dummyResult.getVarName()) == null) - sameDiff.putShapeForVarName(dummyResult.getVarName(),new long[]{1,1}); - - - - - } - - @Builder - public If(String blockName, - SameDiff parent, - SDVariable[] inputVars, - SameDiffFunctionDefinition conditionBody, - SameDiffConditional predicate, - SameDiffFunctionDefinition trueBody, - SameDiffFunctionDefinition falseBody) { - - this.sameDiff = parent; - parent.putOpForId(getOwnName(),this); - this.inputVars = inputVars; - this.predicate = predicate; - - parent.addArgsFor(inputVars,this); - this.trueBody = trueBody; - this.falseBody = falseBody; - this.blockName = blockName; - //need to add the op to the list of ops to be executed when running backwards - this.dummyResult = parent.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme('f'), DataType.FLOAT, 1); - parent.addOutgoingFor(new SDVariable[]{dummyResult},this); - - //create a samediff sub graph for running just the execution - //return a reference to the loop for referencing during actual execution - SameDiff sameDiff = SameDiff.create(); - //store the reference to the result array and the same diff execution instance - this.targetBoolean = predicate.eval(sameDiff,conditionBody, inputVars); - this.predicateExecution = sameDiff; - //store references to the loop body - String trueBodyName = "true-body-" + UUID.randomUUID().toString(); - this.trueBodyName = trueBodyName; - - String falseBodyName = "false-body-" + UUID.randomUUID().toString(); - this.falseBodyName = trueBodyName; - - //running define function will setup a proper same diff instance - this.loopBodyExecution = parent.defineFunction(trueBodyName,trueBody,inputVars); - this.falseBodyExecution = parent.defineFunction(falseBodyName,falseBody,inputVars); - parent.defineFunction(blockName,conditionBody,inputVars); - parent.putSubFunction("predicate-eval-body-" + UUID.randomUUID().toString(),sameDiff); - //get a reference to the actual loop body - this.loopBodyExecution = parent.getFunction(trueBodyName); - } - - - /** - * Toggle whether the true body was executed - * or the false body - * @param trueBodyExecuted - */ - public void exectedTrueOrFalse(boolean trueBodyExecuted) { - if(trueBodyExecuted) - this.trueBodyExecuted = true; - else - this.trueBodyExecuted = false; - } - - - - @Override - public SDVariable[] outputVariables(String baseName) { - return new SDVariable[]{dummyResult}; - } - - @Override - public List doDiff(List f1) { - List ret = new ArrayList<>(); - ret.addAll(Arrays.asList(new IfDerivative(this).outputVariables())); - return ret; - } - - @Override - public String toString() { - return opName(); - } - - @Override - public String opName() { - return "if"; - } - - @Override - public long opHash() { - return HashUtil.getLongHash(opName()); - } - - @Override - public boolean isInplaceCall() { - return false; - } - - @Override - public INDArray[] outputArguments() { - return new INDArray[0]; - } - - @Override - public INDArray[] inputArguments() { - return new INDArray[0]; - } - - @Override - public long[] iArgs() { - return new long[0]; - } - - @Override - public double[] tArgs() { - return new double[0]; - } - - @Override - public boolean[] bArgs() { - return new boolean[0]; - } - - @Override - public void addIArgument(int... arg) { - - } - - @Override - public void addIArgument(long... arg) { - - } - - @Override - public void addBArgument(boolean... arg) { - - } - - @Override - public void removeIArgument(Integer arg) { - - } - - @Override - public Boolean getBArgument(int index) { - return null; - } - - @Override - public Long getIArgument(int index) { - return null; - } - - @Override - public int numIArguments() { - return 0; - } - - @Override - public void addTArgument(double... arg) { - - } - - @Override - public void removeTArgument(Double arg) { - - } - - @Override - public Double getTArgument(int index) { - return null; - } - - @Override - public int numTArguments() { - return 0; - } - - @Override - public int numBArguments() { - return 0; - } - - @Override - public void addInputArgument(INDArray... arg) { - - } - - @Override - public void removeInputArgument(INDArray arg) { - - } - - @Override - public INDArray getInputArgument(int index) { - return null; - } - - @Override - public int numInputArguments() { - return 0; - } - - @Override - public void addOutputArgument(INDArray... arg) { - - } - - @Override - public void removeOutputArgument(INDArray arg) { - - } - - @Override - public INDArray getOutputArgument(int index) { - return null; - } - - @Override - public int numOutputArguments() { - return 0; - } - - @Override - public Op.Type opType() { - return Op.Type.CONDITIONAL; - } - - @Override - public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - //cond is only part of while loops - if(nodeDef.getName().contains("/cond/")) - return; - //usually should be a merge node for a conditional - val ifNodes = TFGraphMapper.getInstance().nodesForIf(nodeDef,graph); - - - val trueScopeGraphDefBuilder = GraphDef.newBuilder(); - for(val node : ifNodes.getTrueNodes()) { - trueScopeGraphDefBuilder.addNode(node); - } - - - val trueScope = TFGraphMapper.getInstance().importGraph(trueScopeGraphDefBuilder.build()); - - - val falseScopeGraphDefBuilder = GraphDef.newBuilder(); - for(val node : ifNodes.getFalseNodes()) { - falseScopeGraphDefBuilder.addNode(node); - - } - - val falseScope = TFGraphMapper.getInstance().importGraph(falseScopeGraphDefBuilder.build()); - - - val condScopeGraphDefBuilder = GraphDef.newBuilder(); - for(val node : ifNodes.getCondNodes()) { - condScopeGraphDefBuilder.addNode(node); - - } - - - val condScope = TFGraphMapper.getInstance().importGraph(condScopeGraphDefBuilder.build()); - - - - initWith.putSubFunction(ifNodes.getTrueBodyScopeName(),trueScope); - initWith.putSubFunction(ifNodes.getFalseBodyScopeName(),falseScope); - initWith.putSubFunction(ifNodes.getConditionBodyScopeName(),condScope); - - this.loopBodyExecution = trueScope; - this.falseBodyExecution = falseScope; - this.predicateExecution = condScope; - } - - - @Override - public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - - } - - - - @Override - public List calculateOutputShape() { - return Arrays.asList(LongShapeDescriptor.fromShape(new long[0], DataType.BOOL)); - } - - @Override - public CustomOpDescriptor getDescriptor() { - return null; - } - - @Override - public void assertValidForExecution() { - - } - - - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("This operation has no TF counterpart"); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/IfDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/IfDerivative.java deleted file mode 100644 index 77b2eafa1..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/IfDerivative.java +++ /dev/null @@ -1,93 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.controlflow; - -import lombok.NoArgsConstructor; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.samediff.SameDiffConditional; -import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition; -import org.nd4j.linalg.api.shape.LongShapeDescriptor; - -import java.util.List; - -@NoArgsConstructor -public class IfDerivative extends If { - - private If ifDelegate; - - public IfDerivative(If ifBlock) { - super(ifBlock); - this.ifDelegate = ifBlock; - } - - @Override - public Boolean getTrueBodyExecuted() { - return ifDelegate.trueBodyExecuted; - } - - - @Override - public SameDiffFunctionDefinition getFalseBody() { - return ifDelegate.falseBody; - } - - @Override - public SameDiff getFalseBodyExecution() { - return ifDelegate.falseBodyExecution; - } - - @Override - public String getBlockName() { - return ifDelegate.blockName; - } - - @Override - public String getFalseBodyName() { - return ifDelegate.falseBodyName; - } - - @Override - public SameDiff getLoopBodyExecution() { - return ifDelegate.loopBodyExecution; - } - - @Override - public SameDiffConditional getPredicate() { - return ifDelegate.getPredicate(); - } - - @Override - public SameDiff getPredicateExecution() { - return ifDelegate.predicateExecution; - } - - @Override - public List calculateOutputShape() { - return super.calculateOutputShape(); - } - - @Override - public String opName() { - return "if_bp"; - } - - @Override - public List diff(List i_v1) { - throw new UnsupportedOperationException("Unable to take the derivative of the derivative for if"); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/Select.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/Select.java index 2a3403ae8..5fdebd03d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/Select.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/Select.java @@ -55,7 +55,7 @@ public class Select extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/While.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/While.java deleted file mode 100644 index e26b0ea5f..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/While.java +++ /dev/null @@ -1,660 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.controlflow; - -import lombok.*; -import lombok.extern.slf4j.Slf4j; -import onnx.Onnx; -import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.samediff.SameDiffConditional; -import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.imports.converters.DifferentialFunctionClassHolder; -import org.nd4j.imports.graphmapper.tf.TFGraphMapper; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.CustomOp; -import org.nd4j.linalg.api.ops.CustomOpDescriptor; -import org.nd4j.linalg.api.ops.Op; -import org.nd4j.linalg.api.shape.LongShapeDescriptor; -import org.nd4j.linalg.exception.ND4JIllegalArgumentException; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.weightinit.impl.ZeroInitScheme; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; - -import java.util.*; -import java.util.concurrent.atomic.AtomicInteger; - -/** - * Equivalent to tensorflow's while loop - * Takes in: - * loopVars - * loop body - * condition - * - * runs loop till condition is false. - * @author Adam Gibson - */ -@NoArgsConstructor -@Slf4j -public class While extends DifferentialFunction implements CustomOp { - private AtomicInteger startPosition; - - - - @Getter - protected SameDiff loopBodyExecution,predicateExecution; - - - @Getter - protected SameDiffConditional predicate; - @Getter - protected SameDiffFunctionDefinition trueBody; - - @Getter - protected String blockName,trueBodyName; - - @Getter - protected SDVariable[] inputVars; - - - @Getter - protected SDVariable targetBoolean; - - protected SDVariable dummyResult; - - @Getter - @Setter - protected SDVariable[] outputVars; - - @Getter - protected int numLooped = 0; - - /** - * Mainly meant for tensorflow import. - * This allows {@link #initFromTensorFlow(NodeDef, SameDiff, Map, GraphDef)} - * to continue from a parent while loop - * using the same graph - * @param startPosition the start position for the import scan - */ - public While(AtomicInteger startPosition) { - this.startPosition = startPosition; - } - - public While(While whileStatement) { - this.sameDiff = whileStatement.sameDiff; - this.outputVars = whileStatement.outputVars; - this.loopBodyExecution = whileStatement.loopBodyExecution; - this.numLooped = whileStatement.numLooped; - this.dummyResult = whileStatement.dummyResult; - this.predicate = whileStatement.predicate; - this.predicateExecution = whileStatement.predicateExecution; - this.inputVars = whileStatement.inputVars; - this.dummyResult = this.sameDiff.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme('f'), DataType.FLOAT, 1); - } - - - - @Builder - public While(String blockName, - SameDiff parent, - SDVariable[] inputVars, - SameDiffConditional predicate, - SameDiffFunctionDefinition condition, - SameDiffFunctionDefinition trueBody) { - init(blockName,parent,inputVars,predicate,condition,trueBody); - } - - - private void init(String blockName, - SameDiff parent, - SDVariable[] inputVars, - SameDiffConditional predicate, - SameDiffFunctionDefinition condition, - SameDiffFunctionDefinition trueBody) { - this.sameDiff = parent; - this.inputVars = inputVars; - this.predicate = predicate; - this.trueBody = trueBody; - this.blockName = blockName; - this.dummyResult = parent.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme('f'), DataType.FLOAT, 1); - parent.putOpForId(getOwnName(),this); - - parent.addArgsFor(inputVars,this); - parent.addOutgoingFor(new SDVariable[]{dummyResult},this); - - - //create a samediff sub graph for running just the execution - //return a reference to the loop for referencing during actual execution - SameDiff sameDiff = SameDiff.create(); - //store the reference to the result array and the same diff execution instance - this.targetBoolean = predicate.eval(sameDiff,condition, inputVars); - this.predicateExecution = sameDiff; - //store references to the loop body - String trueBodyName = "true-body-" + UUID.randomUUID().toString(); - this.trueBodyName = trueBodyName; - //running define function will setup a proper same diff instance - parent.defineFunction(trueBodyName,trueBody,inputVars); - parent.defineFunction(blockName,condition,inputVars); - parent.putSubFunction("predicate-eval-body",sameDiff); - //get a reference to the actual loop body - this.loopBodyExecution = parent.getFunction(trueBodyName); - - } - - - @Override - public SDVariable[] outputVariables(String baseName) { - return new SDVariable[]{dummyResult}; - } - - @Override - public List doDiff(List f1) { - List ret = new ArrayList<>(); - ret.addAll(Arrays.asList(new WhileDerivative(this).outputVariables())); - return ret; - } - - - - /** - * Increments the loop counter. - * This should be called when the loop - * actually executes. - */ - public void incrementLoopCounter() { - numLooped++; - } - - @Override - public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - doImport(nodeDef,initWith,attributesForNode,graph,new LinkedHashSet(),new AtomicInteger(0)); - } - - - private void doImport(NodeDef nodeDef,SameDiff initWith,Map attributesForNode,GraphDef graph,Set skipSet,AtomicInteger currIndex) { - val uniqueId = java.util.UUID.randomUUID().toString(); - skipSet.add(nodeDef.getName()); - val scopeCondition = SameDiff.create(); - val scopeLoop = SameDiff.create(); - initWith.putSubFunction("condition-" + uniqueId,scopeCondition); - initWith.putSubFunction("loopbody-" + uniqueId,scopeLoop); - this.loopBodyExecution = scopeLoop; - this.predicateExecution = scopeCondition; - this.startPosition = currIndex; - - log.info("Adding 2 new scopes for WHILE {}"); - - - val nodes = graph.getNodeList(); - - /** - * Plan is simple: - * 1) we read all declarations of variables used within loop - * 2) we set up conditional scope - * 3) we set up body scope - * 4) ??? - * 5) PROFIT! - */ - - for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) { - val tfNode = nodes.get(currIndex.get()); - - if (!tfNode.getOp().equalsIgnoreCase("enter")) { - //skipSet.add(tfNode.getName()); - break; - } - -// if (skipSet.contains(tfNode.getName())) -// continue; - - skipSet.add(tfNode.getName()); - - val vars = new SDVariable[tfNode.getInputCount()]; - for (int e = 0; e < tfNode.getInputCount(); e++) { - val input = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(e)); - vars[e] = initWith.getVariable(input) == null ? initWith.var(input, (LongShapeDescriptor) null,new ZeroInitScheme()) : initWith.getVariable(input); - scopeCondition.var(vars[e]); - scopeLoop.var(vars[e]); - } - - this.inputVars = vars; - } - - - // now we're skipping Merge step, since we've already captured variables at Enter step - int mergedCnt = 0; - for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) { - val tfNode = nodes.get(currIndex.get()); - - if (!tfNode.getOp().equalsIgnoreCase("merge")) { - scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()), (LongShapeDescriptor) null,new ZeroInitScheme()); - break; - } - - skipSet.add(tfNode.getName()); - val var = scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()), (LongShapeDescriptor)null,new ZeroInitScheme()); - scopeCondition.var(var); - initWith.var(var); - mergedCnt++; - } - - - // now, we're adding conditional scope - for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) { - val tfNode = nodes.get(currIndex.get()); - - // we're parsing up to condition - if (tfNode.getOp().equalsIgnoreCase("LoopCond")) { - skipSet.add(tfNode.getName()); - currIndex.incrementAndGet(); - break; - } - - boolean isConst = tfNode.getOp().equalsIgnoreCase("const"); - boolean isVar = tfNode.getOp().startsWith("VariableV"); - boolean isPlaceholder = tfNode.getOp().startsWith("Placeholder"); - - - if (isConst || isVar || isPlaceholder) { - val var = scopeCondition.var(tfNode.getName(), (LongShapeDescriptor) null,new ZeroInitScheme()); - scopeLoop.var(var); - initWith.var(var); - log.info("Adding condition var [{}]", var.getVarName()); - - } - else if(!skipSet.contains(tfNode.getName())) { - val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName()); - func.initFromTensorFlow(tfNode,scopeCondition,nodeDef.getAttrMap(),graph); - func.setSameDiff(scopeLoop); - - } - - skipSet.add(tfNode.getName()); - } - - - - // time to skip some Switch calls - int switchCnt = 0; - for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) { - val tfNode = nodes.get(currIndex.get()); - - // we're parsing up to condition - if (!tfNode.getOp().equalsIgnoreCase("Switch")) - break; - - switchCnt++; - skipSet.add(tfNode.getName()); - } - - // now we're parsing Identity step - int identityCnt = 0; - for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) { - val tfNode = nodes.get(currIndex.get()); - - - if (!tfNode.getOp().equalsIgnoreCase("Identity")) { - break; - } - - - val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName()); - func.initFromTensorFlow(tfNode,initWith,nodeDef.getAttrMap(),graph); - func.setSameDiff(scopeLoop); - - - val variables = new SDVariable[tfNode.getInputCount()]; - for(int i = 0; i < tfNode.getInputCount(); i++) { - val testVar = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i))); - if(testVar == null) { - variables[i] = initWith.var(tfNode.getInput(i), (LongShapeDescriptor) null,new ZeroInitScheme()); - scopeCondition.var(variables[i]); - scopeLoop.var(variables[i]); - continue; - } - else { - - variables[i] = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i))); - scopeCondition.var(variables[i]); - scopeLoop.var(variables[i]); - } - - } - - scopeLoop.addArgsFor(variables,func); - skipSet.add(tfNode.getName()); - } - - - // parsing body scope - for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) { - val tfNode = nodes.get(currIndex.get()); - - if (skipSet.contains(tfNode.getName())) { - log.info("Skipping: {}", tfNode.getName()); - continue; - } - - if (tfNode.getOp().equalsIgnoreCase("NextIteration")) { -// skipSet.add(tfNode.getName()); - break; - } - - if (skipSet.contains(tfNode.getName())) { - log.info("Skipping: {}", tfNode.getName()); - continue; - } - - - - boolean isConst = tfNode.getOp().equalsIgnoreCase("const"); - boolean isVar = tfNode.getOp().startsWith("VariableV"); - boolean isPlaceholder = tfNode.getOp().startsWith("Placeholder"); - - - if (isConst || isVar || isPlaceholder) { - val var = scopeLoop.var(tfNode.getName(), (LongShapeDescriptor) null,new ZeroInitScheme()); - log.info("Adding body var [{}]",var.getVarName()); - - } else { - log.info("starting on [{}]: {}", tfNode.getName(), tfNode.getOp()); - - if (tfNode.getOp().equalsIgnoreCase("enter")) { - log.info("NEW LOOP ----------------------------------------"); - val func = new While(currIndex); - func.doImport(nodeDef,initWith,attributesForNode,graph,skipSet,currIndex); - func.setSameDiff(initWith); - log.info("END LOOP ----------------------------------------"); - } else { - val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName()); - - func.initFromTensorFlow(tfNode,initWith,nodeDef.getAttrMap(),graph); - - - func.setSameDiff(scopeCondition); - - val variables = new SDVariable[tfNode.getInputCount()]; - for(int i = 0; i < tfNode.getInputCount(); i++) { - val name = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i)); - variables[i] = scopeCondition.getVariable(name); - if(variables[i] == null) { - if(scopeLoop.getVariable(name) == null) - variables[i] = scopeCondition.var(initWith.getVariable(name)); - else if(scopeLoop.getVariable(name) != null) - variables[i] = scopeLoop.getVariable(name); - else - variables[i] = scopeLoop.var(name, Nd4j.scalar(1.0)); - } - } - - scopeLoop.addArgsFor(variables,func); - - - } - } - - skipSet.add(tfNode.getName()); - } - - - val returnInputs = new ArrayList(); - val returnOutputs = new ArrayList(); - - // mapping NextIterations, to Return op - for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) { - val tfNode = nodes.get(currIndex.get()); - - if (!tfNode.getOp().equalsIgnoreCase("NextIteration")) - break; - - skipSet.add(tfNode.getName()); - - val inputName = TFGraphMapper.getInstance().getNodeName(tfNode.getName()); - val input = initWith.getVariable(inputName) == null ? initWith.var(inputName, (LongShapeDescriptor) null,new ZeroInitScheme()) : initWith.getVariable(inputName) ; - returnInputs.add(input); - } - - - this.outputVars = returnOutputs.toArray(new SDVariable[returnOutputs.size()]); - this.inputVars = returnInputs.toArray(new SDVariable[returnInputs.size()]); - initWith.addArgsFor(inputVars,this); - initWith.addOutgoingFor(outputVars,this); - - // we should also map While/Exit to libnd4j while - int exitCnt = 0; - for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) { - val tfNode = nodes.get(currIndex.get()); - - if (!tfNode.getOp().equalsIgnoreCase("Exit")) { - //skipSet.add(tfNode.getName()); - break; - } - - skipSet.add(tfNode.getName()); - val inputName = TFGraphMapper.getInstance().getNodeName(tfNode.getName()); - val input = initWith.getVariable(inputName) == null ? initWith.var(inputName, (LongShapeDescriptor) null,new ZeroInitScheme()) : initWith.getVariable(inputName) ; - } - - - //the output of the condition should always be a singular scalar - //this is a safe assumption - val conditionVars = scopeCondition.ops(); - if(conditionVars.length < 1) { - throw new ND4JIllegalArgumentException("No functions found!"); - } - this.targetBoolean = conditionVars[conditionVars.length - 1].outputVariables()[0]; - - log.info("-------------------------------------------"); - - } - - @Override - public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - - } - - - @Override - public String toString() { - return opName(); - } - - @Override - public String opName() { - return "while"; - } - - @Override - public long opHash() { - return opName().hashCode(); - } - - @Override - public boolean isInplaceCall() { - return false; - } - - @Override - public INDArray[] outputArguments() { - return new INDArray[0]; - } - - @Override - public INDArray[] inputArguments() { - return new INDArray[0]; - } - - @Override - public long[] iArgs() { - return new long[0]; - } - - @Override - public double[] tArgs() { - return new double[0]; - } - - @Override - public void addIArgument(int... arg) { - - } - - @Override - public void addIArgument(long... arg) { - - } - - @Override - public void removeIArgument(Integer arg) { - - } - - @Override - public Long getIArgument(int index) { - return null; - } - - @Override - public int numIArguments() { - return 0; - } - - @Override - public void addTArgument(double... arg) { - - } - - @Override - public void removeTArgument(Double arg) { - - } - - @Override - public Double getTArgument(int index) { - return null; - } - - @Override - public int numTArguments() { - return 0; - } - - @Override - public int numBArguments() { - return 0; - } - - @Override - public void addInputArgument(INDArray... arg) { - - } - - @Override - public void removeInputArgument(INDArray arg) { - - } - - @Override - public boolean[] bArgs() { - return new boolean[0]; - } - - @Override - public void addBArgument(boolean... arg) { - - } - - @Override - public Boolean getBArgument(int index) { - return null; - } - - @Override - public INDArray getInputArgument(int index) { - return null; - } - - @Override - public int numInputArguments() { - return 0; - } - - @Override - public void addOutputArgument(INDArray... arg) { - - } - - @Override - public void removeOutputArgument(INDArray arg) { - - } - - @Override - public INDArray getOutputArgument(int index) { - return null; - } - - @Override - public int numOutputArguments() { - return 0; - } - @Override - public List calculateOutputShape() { - List ret = new ArrayList<>(); - for(SDVariable var : args()) { - ret.add(sameDiff.getShapeDescriptorForVarName(var.getVarName())); - } - return ret; - } - - @Override - public CustomOpDescriptor getDescriptor() { - return CustomOpDescriptor.builder().build(); - } - - @Override - public void assertValidForExecution() { - - } - - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No *singular (eg: use tensorflowNames() found for this op " + opName()); - } - - @Override - public String[] tensorflowNames() { - throw new NoOpNameFoundException("This operation has no TF counterpart"); - } - - - @Override - public Op.Type opType() { - return Op.Type.LOOP; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/WhileDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/WhileDerivative.java deleted file mode 100644 index d9aaf2af0..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/WhileDerivative.java +++ /dev/null @@ -1,96 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.controlflow; - -import lombok.NoArgsConstructor; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.samediff.SameDiffConditional; -import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ops.Op; - -/** - * While loop derivative - * @author Adam Gibson - */ -@NoArgsConstructor -public class WhileDerivative extends While { - private While delegate; - - public WhileDerivative(While delegate) { - super(delegate); - this.delegate = delegate; - } - - - - @Override - public SameDiffFunctionDefinition getTrueBody() { - return delegate.trueBody; - } - - @Override - public String getTrueBodyName() { - return delegate.getTrueBodyName(); - } - - @Override - public SameDiffConditional getPredicate() { - return delegate.getPredicate(); - } - - @Override - public SameDiff getPredicateExecution() { - return delegate.getPredicateExecution(); - } - - @Override - public SDVariable[] getInputVars() { - return delegate.getInputVars(); - } - - @Override - public String getBlockName() { - return delegate.getBlockName(); - } - - @Override - public SameDiff getLoopBodyExecution() { - return delegate.getLoopBodyExecution(); - } - - @Override - public int getNumLooped() { - return delegate.getNumLooped(); - } - - @Override - public String opName() { - return "while_bp"; - } - - @Override - public Op.Type opType() { - return Op.Type.CONDITIONAL; - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow name for while backprop"); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java index 3f56096a2..1bb451bf1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java @@ -55,7 +55,7 @@ public abstract class BaseCompatOp extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode,nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode,nodeDef, graph); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/LoopCond.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/LoopCond.java index 4f5d11b38..769f7c509 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/LoopCond.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/LoopCond.java @@ -32,9 +32,11 @@ import java.util.List; import java.util.Map; public class LoopCond extends BaseCompatOp { + public static final String OP_NAME = "loop_cond"; + @Override public String opName() { - return "loop_cond"; + return OP_NAME; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java index 4ede302dd..0fee6c238 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java @@ -74,8 +74,6 @@ public class CropAndResize extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); - String method = attributesForNode.get("method").getS().toStringUtf8(); if(method.equalsIgnoreCase("nearest")){ this.method = Method.NEAREST; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java index 62194c044..8922df9e5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java @@ -120,4 +120,10 @@ public class ExtractImagePatches extends DynamicCustomOp { //TF includes redundant leading and training 1s for kSizes, strides, rates (positions 0/3) return new int[]{(int)ilist.getI(1), (int)ilist.getI(2)}; } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatypes for %s, got %s", getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java index 5ae8f85ea..be6eb3730 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java @@ -74,7 +74,7 @@ public class ResizeBilinear extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); this.alignCorners = attributesForNode.get("align_corners").getB(); addArgs(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeNearestNeighbor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeNearestNeighbor.java index ecb48f922..ea339ae2c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeNearestNeighbor.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeNearestNeighbor.java @@ -50,7 +50,7 @@ public class ResizeNearestNeighbor extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java index fd2134aad..208daebf2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java @@ -57,7 +57,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp { public ExternalErrorsFunction(){ } public String getGradPlaceholderName(){ - return arg().getVarName() + "-grad"; + return arg().name() + "-grad"; } @Override @@ -70,7 +70,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp { out = sameDiff.getVariable(name); } else { out = sameDiff.zero(name, Nd4j.dataType(), 1); - sameDiff.getOps().get(getOwnName()).setOutputsOfOp(Collections.singletonList(out.getVarName())); + sameDiff.getOps().get(getOwnName()).setOutputsOfOp(Collections.singletonList(out.name())); sameDiff.getVariables().get(name).setOutputOfOp(getOwnName()); } } @@ -83,7 +83,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp { if (gradVariables == null) { gradVariables = new HashMap<>(); for(SDVariable arg : args()){ - INDArray gradArr = gradients.get(arg.getVarName()); + INDArray gradArr = gradients.get(arg.name()); SDVariable grad; DataType dt = arg.dataType(); String n = getGradPlaceholderName(); @@ -94,7 +94,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp { } else { grad = sameDiff.var(n, VariableType.PLACEHOLDER, null, dt); } - gradVariables.put(arg.getVarName(), grad); + gradVariables.put(arg.name(), grad); out.add(grad); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java index bad975cb5..20ff5918c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java @@ -26,8 +26,6 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.base.Preconditions; -import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -41,7 +39,6 @@ import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.lang.reflect.Field; import java.util.*; @@ -106,7 +103,7 @@ public class BatchNorm extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); //Switch order: TF uses [input, gamma, beta, mean, variance]; libnd4j expects [input, mean, variance, gamma, beta] SameDiffOp op = initWith.getOps().get(this.getOwnName()); List list = op.getInputsToOp(); @@ -140,13 +137,12 @@ public class BatchNorm extends DynamicCustomOp { @Override public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); - addArgs(); + } @Override public String opName() { - return "batchnorm_new"; + return "batchnorm"; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java index 2fc814fb3..852c865f7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java @@ -21,33 +21,20 @@ import lombok.Getter; import lombok.NoArgsConstructor; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; -import lombok.val; -import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.imports.converters.DifferentialFunctionClassHolder; -import org.nd4j.imports.descriptors.properties.AttributeAdapter; -import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter; -import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueNDArrayShapeAdapter; -import org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater; -import org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter; -import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper; -import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.util.ArrayUtil; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; import java.lang.reflect.Field; -import java.util.*; +import java.util.Collections; +import java.util.List; +import java.util.Map; /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java index 5e077e3fc..3794469ae 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java @@ -31,7 +31,6 @@ import org.nd4j.imports.converters.DifferentialFunctionClassHolder; import org.nd4j.imports.descriptors.properties.AttributeAdapter; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.descriptors.properties.adapters.*; -import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -122,7 +121,7 @@ public class Conv2D extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } @@ -138,8 +137,7 @@ public class Conv2D extends DynamicCustomOp { @Override public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); - addArgs(); + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java index 8c4e40e8a..665e7dd99 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java @@ -251,7 +251,7 @@ public class Conv3D extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java index c69292dd9..2bba1c2e3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java @@ -186,22 +186,22 @@ public class DeConv2D extends DynamicCustomOp { public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { val aStrides = nodeDef.getAttrOrThrow("strides"); val tfStrides = aStrides.getList().getIList(); - int sH = 1; - int sW = 1; - int kH = 1; - int kW = 1; + long sH = 1; + long sW = 1; + long kH = 1; + long kW = 1; val aPadding = nodeDef.getAttrOrDefault("padding", null); val paddingMode = aPadding.getS().toStringUtf8(); val args = args(); - INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr(); + INDArray arr = sameDiff.getVariable(args[1].name()).getArr(); if (arr == null) { - arr = TFGraphMapper.getInstance().getNDArrayFromTensor(nodeDef.getInput(0), nodeDef, graph); + arr = TFGraphMapper.getNDArrayFromTensor(nodeDef); // TODO: arguable. it might be easier to permute weights once //arr = (arr.permute(3, 2, 0, 1).dup('c')); - val varForOp = initWith.getVariable(args[1].getVarName()); + val varForOp = initWith.getVariable(args[1].name()); if (arr != null) initWith.associateArrayWithVariable(arr, varForOp); @@ -214,21 +214,18 @@ public class DeConv2D extends DynamicCustomOp { dataFormat = attr.getS().toStringUtf8().toLowerCase(); } - // FIXME: int cast - - if (dataFormat.equalsIgnoreCase(DeConv2DConfig.NCHW)) { - sH = tfStrides.get(2).intValue(); - sW = tfStrides.get(3).intValue(); + sH = tfStrides.get(2).longValue(); + sW = tfStrides.get(3).longValue(); - kH = (int) arr.size(2); - kW = (int) arr.size(3); + kH = arr.size(2); + kW = arr.size(3); } else { - sH = tfStrides.get(1).intValue(); - sW = tfStrides.get(2).intValue(); + sH = tfStrides.get(1).longValue(); + sW = tfStrides.get(2).longValue(); - kH = (int) arr.size(0); - kW = (int) arr.size(1); + kH = arr.size(0); + kW = arr.size(1); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java index bc4f996b1..dfabc89dd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java @@ -214,7 +214,7 @@ public class DeConv2DTF extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } @@ -240,9 +240,9 @@ public class DeConv2DTF extends DynamicCustomOp { } @Override - public List calculateOutputDataTypes(List inputDataTypes){ + public List calculateOutputDataTypes(List inputDataTypes){ //inShape, weights, input int n = args().length; Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); - return Collections.singletonList(inputDataTypes.get(0)); + return Collections.singletonList(inputDataTypes.get(2)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java index 077f6a64b..0a153422d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java @@ -158,10 +158,10 @@ public class DeConv3D extends DynamicCustomOp { val paddingMode = aPadding.getS().toStringUtf8(); val args = args(); - INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr(); + INDArray arr = sameDiff.getVariable(args[1].name()).getArr(); if (arr == null) { - arr = TFGraphMapper.getInstance().getNDArrayFromTensor(nodeDef.getInput(0), nodeDef, graph); - val varForOp = initWith.getVariable(args[1].getVarName()); + arr = TFGraphMapper.getNDArrayFromTensor(nodeDef); + val varForOp = initWith.getVariable(args[1].name()); if (arr != null) initWith.associateArrayWithVariable(arr, varForOp); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java index 6715f742a..704f8bdd4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java @@ -77,7 +77,7 @@ public class DepthToSpace extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); boolean isNHWC = dataFormat.equals("NHWC"); addIArgument(blockSize, isNHWC ? 1 : 0); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java index ec2bb1d3f..4b10909a0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java @@ -29,14 +29,15 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.converters.DifferentialFunctionClassHolder; import org.nd4j.imports.descriptors.properties.AttributeAdapter; import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.imports.descriptors.properties.adapters.*; -import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper; +import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter; +import org.nd4j.imports.descriptors.properties.adapters.NDArrayShapeAdapter; +import org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater; +import org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -136,7 +137,7 @@ public class DepthwiseConv2D extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); /* @@ -162,8 +163,7 @@ public class DepthwiseConv2D extends DynamicCustomOp { @Override public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); - addArgs(); + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java index 90bdcdb45..e591c9f1c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java @@ -75,7 +75,7 @@ public class SpaceToDepth extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); boolean isNHWC = dataFormat == null ? true : dataFormat.equals("NHWC"); addIArgument(blockSize, isNHWC ? 1 : 0); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java index 3b12187e3..21756d99b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java @@ -64,7 +64,7 @@ public class SoftmaxCrossEntropyLoss extends BaseLoss { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java index 9c385c425..15f556c64 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java @@ -55,7 +55,7 @@ public class SparseSoftmaxCrossEntropyLossWithLogits extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); //Switch order: TF uses [logits, labels]; libnd4j expects [labels, logits] SameDiffOp op = initWith.getOps().get(this.getOwnName()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java index 7d711ca58..fa6bceef7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java @@ -193,12 +193,6 @@ public class Mmul extends DynamicCustomOp { .transposeA(isTransposeA).transposeB(isTransposeB) .build(); this.mt = mMulTranspose; - val args = args(); - for(val arg : args) { - if(sameDiff.isPlaceHolder(arg.getVarName()) || arg.getShape() == null) { - sameDiff.addPropertyToResolve(this,arg.getVarName()); - } - } iArguments.clear(); addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()), ArrayUtil.fromBoolean(mt.isTransposeB())); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java index 80b767676..7afdf5166 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java @@ -64,7 +64,7 @@ public class Moments extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java index 09f4ac2f4..36b4600c9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java @@ -60,7 +60,7 @@ public class NormalizeMoments extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java index ad19598eb..6c8aa5901 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java @@ -63,7 +63,7 @@ public class ScatterAdd extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java index ea7ef3da7..4e7563e4a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java @@ -86,7 +86,7 @@ public class ScatterDiv extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java index 33f8db980..65162aad3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java @@ -60,7 +60,7 @@ public class ScatterMax extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java index 00322b259..8d8fe4e33 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java @@ -60,7 +60,7 @@ public class ScatterMin extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java index 1db426364..2790667cd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java @@ -62,7 +62,7 @@ public class ScatterMul extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNd.java index a589fa1ae..a72801760 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNd.java @@ -67,7 +67,7 @@ public class ScatterNd extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { @@ -80,8 +80,8 @@ public class ScatterNd extends DynamicCustomOp { } @Override - public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), inputDataTypes); + public List calculateOutputDataTypes(List inputDataTypes){ //Indices, updates, shape + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(1)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdAdd.java index 7dd2b9462..c79ec058d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdAdd.java @@ -66,7 +66,7 @@ public class ScatterNdAdd extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdSub.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdSub.java index 42c539f58..8efc6717f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdSub.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdSub.java @@ -66,7 +66,7 @@ public class ScatterNdSub extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdUpdate.java index aeb3c9872..bf95b448d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdUpdate.java @@ -66,7 +66,7 @@ public class ScatterNdUpdate extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java index 375d5bc6b..382806779 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java @@ -79,7 +79,7 @@ public class ScatterSub extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java index ccfc541de..980ae7f8c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java @@ -73,8 +73,6 @@ public class ScatterUpdate extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); - if (nodeDef.containsAttr("use_locking")) { if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { bArguments.add(true); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java index 5c6beb945..c860152ca 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java @@ -130,10 +130,7 @@ public class Concat extends DynamicCustomOp { val variable = initWith.getVariable(input); // concat dimension is only possible - if (variable != null && variable.getArr() == null) { - sameDiff.addPropertyToResolve(this, input); - - } else if (variable != null) { + if (variable != null) { val arr = variable.getArr(); if (arr.length() == 1) { concatDimension = arr.getInt(0); @@ -151,6 +148,7 @@ public class Concat extends DynamicCustomOp { removeInputArgument(inputArgs[inputArguments().length - 1]); } + //TODO Fix this: https://github.com/eclipse/deeplearning4j/issues/8285 sameDiff.removeArgFromOp(input,this); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java index 5c50b983d..a13a03184 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java @@ -69,8 +69,8 @@ public class ExpandDims extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - val targetNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, nodeDef.getInput(1)); - val dimArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", targetNode, graph); + val targetNode = TFGraphMapper.getNodeWithNameFromGraph(graph, nodeDef.getInput(1)); + val dimArr = TFGraphMapper.getNDArrayFromTensor(targetNode); if (dimArr != null) { int axis = dimArr.data().asInt()[0]; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java index 5613cc85f..fd6ec5240 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java @@ -22,13 +22,9 @@ import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -73,12 +69,12 @@ public class Gather extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); } @Override public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); + } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java index cfe4fe8be..b8ef51d57 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java @@ -17,26 +17,13 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.NoArgsConstructor; -import lombok.val; -import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; -import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper; -import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.util.ArrayUtil; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; /** * GatherND op diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java index 2d5dcc63c..841fec7b0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java @@ -88,7 +88,7 @@ public class OneHot extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); if(attributesForNode.containsKey("T")) { outputType = TFGraphMapper.convertType(attributesForNode.get("T").getType()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java index 1856e6804..3a1605d8b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java @@ -64,7 +64,7 @@ public class ParallelStack extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java index 568b14a44..c05df441c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java @@ -50,21 +50,6 @@ public class Rank extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {input}, inPlace); } - - @Override - public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - val name = TFGraphMapper.getInstance().getNodeName(nodeDef.getName()); - val input = initWith.getVariable(name); - val outputVertex = input.getVarName(); - if (!initWith.isPlaceHolder(input.getVarName()) && initWith.shapeAlreadyExistsForVarName(outputVertex)) { - val inputShape = initWith.getShapeForVarName(input.getVarName()); - val resultLength = Nd4j.scalar(inputShape.length); - val thisResultId = outputVertex; - initWith.setArrayForVariable(thisResultId, resultLength); - initWith.putShapeForVarName(thisResultId, new long[]{1, 1}); - } - } - @Override public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java index af8940bf4..4bf920da9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java @@ -101,7 +101,7 @@ public class Repeat extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addIArgument(jaxis); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java index b30bacc22..44d9b79fe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java @@ -21,20 +21,18 @@ import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper; -import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.exception.ND4JIllegalStateException; -import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.util.*; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; /** * Reshape function @@ -70,32 +68,7 @@ public class Reshape extends DynamicCustomOp { if (!nodeDef.containsAttr("TShape") && nodeDef.getInputCount() == 1) { this.shape = new long[]{}; return; - } else if (nodeDef.getInputCount() > 1) { - val shapeNode = nodeDef.getInput(1); - NodeDef shapeNodeInGraph = null; - for (int i = 0; i < graph.getNodeCount(); i++) { - if (graph.getNode(i).getName().equals(shapeNode)) { - shapeNodeInGraph = graph.getNode(i); - - } - } - - val arr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", shapeNodeInGraph, graph); - if (arr != null && arr.isEmpty()) { - // special case: empty array - this.shape = new long[0]; - - } else if (arr != null) { - this.shape = arr.data().asLong(); - //all TF is c - if (!ArrayUtil.containsAnyNegative(this.shape)) - addIArgument(this.shape); - else { - arrName = nodeDef.getName(); - } - - } - } else { + } else if(nodeDef.getInputCount() == 1){ val shape = nodeDef.getAttrOrThrow("Tshape"); if (!shape.hasShape()) { val shapeRet = new long[2]; @@ -127,8 +100,7 @@ public class Reshape extends DynamicCustomOp { @Override public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - val shape = new OnnxGraphMapper().getShape(node); - this.shape = shape; + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java index a2f6bd208..67454e231 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java @@ -65,13 +65,13 @@ public class SequenceMask extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - val targetNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, nodeDef.getInput(1)); - val maxlen = TFGraphMapper.getInstance().getNDArrayFromTensor("value", targetNode, graph); + val targetNode = TFGraphMapper.getNodeWithNameFromGraph(graph, nodeDef.getInput(1)); + val maxlen = TFGraphMapper.getNDArrayFromTensor(targetNode); if (maxlen == null){ // No 2nd input this.is_static_maxlen = true; } - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (is_static_maxlen) { addIArgument(this.maxLen); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Split.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Split.java index f11c10c1c..685623d32 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Split.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Split.java @@ -54,7 +54,7 @@ public class Split extends DynamicCustomOp { this.numSplit = numSplits; addIArgument(numSplits); - val splitDim = TFGraphMapper.getInstance().getArrayFrom(TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph,nodeDef.getInput(0)),graph); + val splitDim = TFGraphMapper.getArrayFrom(TFGraphMapper.getNodeWithNameFromGraph(graph,nodeDef.getInput(0)),graph); if(splitDim != null) { this.splitDim = splitDim.getInt(0); addIArgument(splitDim.getInt(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SplitV.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SplitV.java index 134f0dbe3..2407bc0f1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SplitV.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SplitV.java @@ -49,7 +49,7 @@ public class SplitV extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - val splitDim = TFGraphMapper.getInstance().getArrayFrom(TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph,nodeDef.getInput(0)),graph); + val splitDim = TFGraphMapper.getArrayFrom(TFGraphMapper.getNodeWithNameFromGraph(graph,nodeDef.getInput(0)),graph); if(splitDim != null) { this.splitDim = splitDim.getInt(0); addIArgument(splitDim.getInt(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java index 6cd09f9bd..d2bf9d71b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java @@ -88,7 +88,7 @@ public class Stack extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java index 782c70859..32fd91b3a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java @@ -114,7 +114,7 @@ public class Transpose extends DynamicCustomOp { } - INDArray permuteArrayOp = TFGraphMapper.getInstance().getNDArrayFromTensor("value", permuteDimsNode, graph); + INDArray permuteArrayOp = TFGraphMapper.getNDArrayFromTensor(permuteDimsNode); if (permuteArrayOp != null) { this.permuteDims = permuteArrayOp.data().asInt(); } @@ -124,13 +124,7 @@ public class Transpose extends DynamicCustomOp { return; } - INDArray arr = sameDiff.getArrForVarName(arg().getVarName()); - if (arr == null) { - val arrVar = sameDiff.getVariable(arg().getVarName()); - - arr = arrVar.getWeightInitScheme().create(arrVar.dataType(), arrVar.getShape()); - sameDiff.setArrayForVariable(arg().getVarName(), arr); - } + INDArray arr = sameDiff.getArrForVarName(arg().name()); if(permuteArrayOp != null){ addInputArgument(arr, permuteArrayOp); @@ -138,16 +132,12 @@ public class Transpose extends DynamicCustomOp { addInputArgument(arr); } - - if (arr != null && permuteDims == null) { this.permuteDims = ArrayUtil.reverseCopy(ArrayUtil.range(0, arr.rank())); } if (permuteDims != null && permuteDims.length < arg().getShape().length) throw new ND4JIllegalStateException("Illegal permute found. Not all dimensions specified"); - - } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/BaseTensorOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/BaseTensorOp.java index b027750fc..a9e67f9f6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/BaseTensorOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/BaseTensorOp.java @@ -47,8 +47,8 @@ public abstract class BaseTensorOp extends DynamicCustomOp { public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { val inputOne = nodeDef.getInput(1); val varFor = initWith.getVariable(inputOne); - val nodeWithIndex = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph,inputOne); - val var = TFGraphMapper.getInstance().getArrayFrom(nodeWithIndex,graph); + val nodeWithIndex = TFGraphMapper.getNodeWithNameFromGraph(graph,inputOne); + val var = TFGraphMapper.getArrayFrom(nodeWithIndex,graph); if(var != null) { val idx = var.getInt(0); addIArgument(idx); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArray.java index b22434a71..4ecdf947d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArray.java @@ -70,7 +70,7 @@ public class TensorArray extends BaseTensorOp { } } - val arr = TFGraphMapper.getInstance().getNDArrayFromTensor("value",iddNode,graph); + val arr = TFGraphMapper.getNDArrayFromTensor(iddNode); if (arr != null) { int idx = arr.getInt(0); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayConcat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayConcat.java index 5b3584d11..b32687844 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayConcat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayConcat.java @@ -72,7 +72,7 @@ public class TensorArrayConcat extends BaseTensorOp { public List calculateOutputDataTypes(java.util.List inputDataType){ //Same output type as the TensorArray - which is defined by input 0 SDVariable tArr = arg(0); - TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.getVarName()); + TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.name()); org.nd4j.linalg.api.buffer.DataType dt = t3.getTensorArrayDataType(); return Collections.singletonList(dt); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayGather.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayGather.java index 185ef429a..70d200dea 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayGather.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayGather.java @@ -72,7 +72,7 @@ public class TensorArrayGather extends BaseTensorOp { public List calculateOutputDataTypes(java.util.List inputDataType){ //Same output type as the TensorArray - which is defined by input 0 SDVariable tArr = arg(0); - TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.getVarName()); + TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.name()); org.nd4j.linalg.api.buffer.DataType dt = t3.getTensorArrayDataType(); return Collections.singletonList(dt); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayRead.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayRead.java index 2b8114543..a92185e57 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayRead.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayRead.java @@ -73,7 +73,7 @@ public class TensorArrayRead extends BaseTensorOp { dt = importDataType; } else { SDVariable tArr = arg(0); - DifferentialFunction op = sameDiff.getVariableOutputOp(tArr.getVarName()); + DifferentialFunction op = sameDiff.getVariableOutputOp(tArr.name()); TensorArray t3 = (TensorArray) op; dt = t3.getTensorArrayDataType(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/CheckNumerics.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/CheckNumerics.java index 78fdcabba..3334c9bbd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/CheckNumerics.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/CheckNumerics.java @@ -71,9 +71,9 @@ public class CheckNumerics extends DynamicCustomOp { SDVariable msg = initWith.constant(name + "/message", Nd4j.scalar(str)); List newInputs = new ArrayList<>(2); newInputs.addAll(initWith.getOps().get(name).getInputsToOp()); - newInputs.add(msg.getVarName()); + newInputs.add(msg.name()); initWith.getOps().get(name).setInputsToOp(newInputs); - initWith.getVariables().get(msg.getVarName()).setInputsForOp(Collections.singletonList(getOwnName())); } + initWith.getVariables().get(msg.name()).setInputsForOp(Collections.singletonList(getOwnName())); } @Override public List calculateOutputDataTypes(List inputDataTypes){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Cholesky.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Cholesky.java index 4d321d920..9a935aa8f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Cholesky.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Cholesky.java @@ -18,12 +18,15 @@ package org.nd4j.linalg.api.ops.impl.transforms; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; +import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -46,11 +49,17 @@ public class Cholesky extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); } @Override public List doDiff(List f1) { throw new UnsupportedOperationException(); } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatypes for %s, got %s", getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Constant.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Constant.java deleted file mode 100644 index e003f4325..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Constant.java +++ /dev/null @@ -1,91 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms; - -import lombok.Data; -import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ops.BaseTransformOp; -import org.nd4j.linalg.api.ops.BaseTransformSameOp; - -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.UUID; - -@Data -public class Constant extends BaseTransformSameOp { - - - public Constant() { - } - - - protected Constant(SameDiff sameDiff, - SDVariable i_v, - long[] shape, - boolean inPlace) { - super(); - sameDiff.putOrUpdateShapeForVarName(i_v.getVarName(), shape, false); - this.xVertexId = i_v.getVarName(); - this.inPlace = inPlace; - this.sameDiff = sameDiff; - } - - public Constant(SameDiff sameDiff, SDVariable i_v, long[] shape) { - this(sameDiff, i_v, shape, false); - } - - - @Override - public List doDiff(List i_v) { - return Collections.singletonList(sameDiff.zerosLike(arg())); - } - - - @Override - public DifferentialFunction dup() { - Constant ret = new Constant(sameDiff, sameDiff.getVariable(outputVariables()[0].getVarName()) - , sameDiff.getShapeForVarName(outputVariables()[0].getVarName())); - Constant differentialFunction = ret; - return differentialFunction; - } - - - @Override - public int opNum() { - return 15; - } - - @Override - public String opName() { - return "constant"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow opName found for " + opName()); - } - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java index 052700b0a..05993cd7f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java @@ -118,7 +118,7 @@ public class MaxOut extends BaseTransformOp { if(arg() == null) throw new ND4JIllegalStateException("No arg found for op!"); - val arr = sameDiff.getArrForVarName(arg().getVarName()); + val arr = sameDiff.getArrForVarName(arg().name()); if(arr == null) return Collections.emptyList(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/NthElement.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/NthElement.java index 958df5579..fcf6390cf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/NthElement.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/NthElement.java @@ -18,12 +18,15 @@ package org.nd4j.linalg.api.ops.impl.transforms; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; +import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -47,7 +50,7 @@ public class NthElement extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); this.reverse = attributesForNode.get("reverse").getB(); addArgs(); @@ -70,4 +73,10 @@ public class NthElement extends DynamicCustomOp { public List doDiff(List f1) { throw new UnsupportedOperationException(); } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ //Input and number + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java index 72d2823b5..7ea1bb38b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java @@ -99,8 +99,8 @@ public class Pad extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2), - "Expected 1 or 2 input datatypes for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() >= 1 && inputDataTypes.size() <= 3), + "Expected 1-3 input datatypes for %s, got %s", getClass(), inputDataTypes); //input, padding, pad value return Collections.singletonList(inputDataTypes.get(0)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java index 9c04aeb12..3874c040b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java @@ -120,7 +120,7 @@ public class CumProd extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } @@ -143,7 +143,8 @@ public class CumProd extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List dataTypes){ - Preconditions.checkState(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), dataTypes); + Preconditions.checkState(dataTypes != null && (dataTypes.size() == 1 || dataTypes.size() == 2), + "Expected 1 or 2 input datatype for %s, got %s", getClass(), dataTypes); //2nd optional input - axis return Collections.singletonList(dataTypes.get(0)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java index b8c7d5c51..6720b5a75 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java @@ -122,7 +122,7 @@ public class CumSum extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } @@ -144,7 +144,8 @@ public class CumSum extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List dataTypes){ - Preconditions.checkState(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), dataTypes); + Preconditions.checkState(dataTypes != null && (dataTypes.size() == 1 || dataTypes.size() == 2), + "Expected 1 or 2 input datatype for %s, got %s", getClass(), dataTypes); //2nd optional input - axis return Collections.singletonList(dataTypes.get(0)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Dilation2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Dilation2D.java index 79e793174..fc909261b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Dilation2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Dilation2D.java @@ -19,12 +19,14 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.converters.DifferentialFunctionClassHolder; import org.nd4j.imports.descriptors.properties.AttributeAdapter; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.descriptors.properties.adapters.*; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.util.ArrayUtil; @@ -32,9 +34,7 @@ import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.Map; +import java.util.*; /** * Dilation2D op wrapper @@ -90,7 +90,7 @@ public class Dilation2D extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode,nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode,nodeDef, graph); addArgs(); } @@ -185,4 +185,11 @@ public class Dilation2D extends DynamicCustomOp { public String tensorflowName() { return "Dilation2D"; } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ //Input and weights, optional rates/strides + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() >= 2 && inputDataTypes.size() <= 4, + "Expected 2 to 4 input datatypes for %s, got %s", getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java index 06d52f777..8581c51fb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java @@ -74,7 +74,7 @@ public class DynamicPartition extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java index af4097870..a5ffbced5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java @@ -54,7 +54,6 @@ public class Fill extends DynamicCustomOp { public Fill(SameDiff sameDiff, SDVariable shape, DataType outputDataType, double value) { super(null,sameDiff, new SDVariable[] {shape}, false); this.value = value; - val shp = shape.getArr(); this.outputDataType = outputDataType; addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InTopK.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InTopK.java index 7a69306ab..28fe7c305 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InTopK.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InTopK.java @@ -74,7 +74,7 @@ public class InTopK extends DynamicCustomOp { } Preconditions.checkState(kNode != null, "Could not find 'k' parameter node for op: %s", thisName); - INDArray arr = TFGraphMapper.getInstance().getNDArrayFromTensor(inputName, kNode, graph); + INDArray arr = TFGraphMapper.getNDArrayFromTensor(kNode); this.k = arr.getInt(0); addIArgument(k); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MirrorPad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MirrorPad.java index 1e84fa3f2..bed056888 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MirrorPad.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MirrorPad.java @@ -43,7 +43,7 @@ public class MirrorPad extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); iArguments.add(isSymmetric ? 1L : 0L); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ParallelConcat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ParallelConcat.java index e54a9dc40..3d167ea9c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ParallelConcat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ParallelConcat.java @@ -42,7 +42,7 @@ public class ParallelConcat extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); // We might want to import everything here? i.e. shape in advance? } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java index 0906451cc..078af1088 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java @@ -75,7 +75,7 @@ public class ReverseSequence extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArguments(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Svd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Svd.java index 186205df7..21ceaed83 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Svd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Svd.java @@ -47,6 +47,20 @@ public class Svd extends DynamicCustomOp { public Svd(){ } + public Svd(INDArray input, boolean full_matrices, INDArray s, INDArray u, INDArray v) { + inputArguments.add(input); + fullUV = full_matrices; + computeUv = true; + switchNum = DEFAULT_SWITCHNUM; + + + outputArguments.add(s); + outputArguments.add(u); + outputArguments.add(v); + + addIArgument(ArrayUtil.fromBoolean(fullUV), ArrayUtil.fromBoolean(computeUv), switchNum); + } + public Svd(SameDiff sd, SDVariable input, boolean fullUV, boolean computeUv){ this(sd, input, fullUV, computeUv, DEFAULT_SWITCHNUM); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/TopK.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/TopK.java index 0779fd693..e9d40264f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/TopK.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/TopK.java @@ -82,7 +82,7 @@ public class TopK extends DynamicCustomOp { if (kNode != null) { Preconditions.checkState(kNode != null, "Could not find 'k' parameter node for op: %s", thisName); - INDArray arr = TFGraphMapper.getInstance().getNDArrayFromTensor(inputName, kNode, graph); + INDArray arr = TFGraphMapper.getNDArrayFromTensor(kNode); this.k = arr.getInt(0); addIArgument(ArrayUtil.fromBoolean(sorted), k); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java index 98a479542..eb8b820ef 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java @@ -84,7 +84,7 @@ public class Cast extends BaseDynamicTransformOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java index 4d8209b8a..16afb4316 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java @@ -28,13 +28,19 @@ import java.util.Collections; import java.util.List; /** - * + * TanhDerivative: calculated dL/dIn from dL/dOut and In */ public class TanhDerivative extends DynamicCustomOp { public TanhDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { super(sameDiff, new SDVariable[]{i_v1, i_v2}); } + /** + * + * @param x Input + * @param y Gradient at output (dL/dOut) + * @param z Output array, gradient at input (dL/dIn - to be calculated) + */ public TanhDerivative(INDArray x, INDArray y, INDArray z) { super(null, new INDArray[]{x, y}, new INDArray[]{z}); } @@ -42,6 +48,10 @@ public class TanhDerivative extends DynamicCustomOp { public TanhDerivative() { } + /** + * @param x Input + * @param y Gradient at output (dL/dOut) + */ public TanhDerivative(INDArray x, INDArray y) { this(x, y, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java index 23d176603..49ef2fb09 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java @@ -37,10 +37,6 @@ public class ASinh extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } - public ASinh(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, Object[] extraArgs) { - super(sameDiff, i_v, shape, inPlace, extraArgs); - } - public ASinh(SameDiff sameDiff, SDVariable i_v) { super(sameDiff, i_v, false); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java index f3fb41464..763a64fc4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java @@ -43,11 +43,8 @@ public abstract class BaseRandomOp extends BaseOp implements RandomOp { public BaseRandomOp(SameDiff sameDiff, SDVariable i_v) { Preconditions.checkNotNull(i_v, "Input variable can't be null with this constructor"); this.sameDiff = sameDiff; - this.xVertexId = i_v.getVarName(); + this.xVertexId = i_v.name(); sameDiff.addArgsFor(new String[]{xVertexId},this); - if(Shape.isPlaceholderShape(i_v.getShape())) { - sameDiff.addPropertyToResolve(this,i_v.getVarName()); - } } public BaseRandomOp(SameDiff sd, long[] shape){ @@ -73,11 +70,7 @@ public abstract class BaseRandomOp extends BaseOp implements RandomOp { if(shape != null){ return Collections.singletonList(LongShapeDescriptor.fromShape(shape, Nd4j.defaultFloatingPointType())); } else { - List ret = new ArrayList<>(1); - val shape = sameDiff.getShapeForVarName(args()[0].getVarName()); - if (shape != null) - ret.add(LongShapeDescriptor.fromShape(shape, Shape.pickPairwiseDataType(args()[0].dataType(), Nd4j.dataType()))); - return ret; + return Collections.singletonList(LongShapeDescriptor.fromShape(shape, Shape.pickPairwiseDataType(args()[0].dataType(), Nd4j.dataType()))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistributionEx.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistributionEx.java index ff05d6c9f..3ec26a927 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistributionEx.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistributionEx.java @@ -33,7 +33,7 @@ import java.util.Map; * @author raver119@gmail.com */ public class BinomialDistributionEx extends BaseRandomOp { - private int trials; + private long trials; private double probability; public BinomialDistributionEx() { @@ -46,7 +46,7 @@ public class BinomialDistributionEx extends BaseRandomOp { * @param trials * @param probability */ - public BinomialDistributionEx(@NonNull INDArray z, int trials, double probability) { + public BinomialDistributionEx(@NonNull INDArray z, long trials, double probability) { super(z, z, z); this.trials = trials; this.probability = probability; @@ -59,7 +59,7 @@ public class BinomialDistributionEx extends BaseRandomOp { * @param trials * @param probabilities array with probability value for each trial */ - public BinomialDistributionEx(@NonNull INDArray z, int trials, @NonNull INDArray probabilities) { + public BinomialDistributionEx(@NonNull INDArray z, long trials, @NonNull INDArray probabilities) { super(z, probabilities, z); if (z.length() != probabilities.length()) throw new IllegalStateException("Length of probabilities array should match length of target array"); @@ -82,8 +82,7 @@ public class BinomialDistributionEx extends BaseRandomOp { * @param probabilities */ public BinomialDistributionEx(@NonNull INDArray z, @NonNull INDArray probabilities) { - // FIXME: int cast - this(z, (int) probabilities.length(), probabilities); + this(z, probabilities.length(), probabilities); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/rng/distribution/BaseDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/rng/distribution/BaseDistribution.java index 831ee144a..e32010827 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/rng/distribution/BaseDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/rng/distribution/BaseDistribution.java @@ -229,7 +229,6 @@ public abstract class BaseDistribution implements Distribution { if (sampleSize <= 0) { throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize); } - // FIXME: int cast double[] out = new double[(int) sampleSize]; for (int i = 0; i < sampleSize; i++) { out[i] = sample(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java index 4452bc7e4..80f9910db 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java @@ -21,6 +21,7 @@ import lombok.val; import org.apache.commons.math3.exception.NumberIsTooLargeException; import org.apache.commons.math3.exception.OutOfRangeException; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.custom.Svd; import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution; import org.nd4j.linalg.api.rng.distribution.BaseDistribution; import org.nd4j.linalg.factory.Nd4j; @@ -221,7 +222,7 @@ public class OrthogonalDistribution extends BaseDistribution { @Override public INDArray sample(long[] shape){ - int numRows = 1; + long numRows = 1; for (int i = 0; i < shape.length - 1; i++) numRows *= shape[i]; long numCols = shape[shape.length - 1]; @@ -231,21 +232,20 @@ public class OrthogonalDistribution extends BaseDistribution { val flatShape = new long[]{numRows, numCols}; val flatRng = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(dtype, flatShape, Nd4j.order()), 0.0, 1.0), random); - long m = flatRng.rows(); - long n = flatRng.columns(); + val m = flatRng.rows(); + val n = flatRng.columns(); val s = Nd4j.create(dtype, m < n ? m : n); - val u = m < n ? Nd4j.create(dtype, m, n) : Nd4j.create(dtype, m, m); + val u = Nd4j.create(dtype, m, m); val v = Nd4j.create(dtype, new long[] {n, n}, 'f'); - Nd4j.getBlasWrapper().lapack().gesvd(flatRng, s, u, v); + Nd4j.exec(new Svd(flatRng, true, s, u, v)); - // FIXME: int cast if (gains == null) { - if (u.rows() == numRows && u.columns() == numCols) { - return v.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape); - } else { + if (u.rows() >= numRows && u.columns() >= numCols) { return u.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape); + } else { + return v.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape); } } else { throw new UnsupportedOperationException(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java index ef4331b0a..51711b3d2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java @@ -923,15 +923,13 @@ public class Shape { * @param indices Indices array to get the offset for (must be same length as array rank) * @return Buffer offset fo the specified indices */ - public static long getOffset(IntBuffer shapeInformation, int[] indices) { - // FIXME: int cast + /*public static long getOffset(IntBuffer shapeInformation, int[] indices) { return getOffset(shapeInformation, ArrayUtil.toLongArray(indices)); } public static long getOffset(LongBuffer shapeInformation, int[] indices) { - // FIXME: int cast return getOffset(shapeInformation, ArrayUtil.toLongArray(indices)); - } + }*/ public static long getOffset(LongBuffer shapeInformation, long... indices) { int rank = rank(shapeInformation); @@ -968,8 +966,8 @@ public class Shape { * @param indices Indices array to get the offset for (must be same length as array rank) * @return Buffer offset fo the specified indices */ + @Deprecated public static long getOffset(DataBuffer shapeInformation, int[] indices) { - // FIXME: int cast return getOffset(shapeInformation, ArrayUtil.toLongArray(indices)); } public static long getOffset(DataBuffer shapeInformation, long... indices) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java index b31e6e036..9c0645156 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java @@ -159,9 +159,8 @@ public class Convolution { public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int dh, int dw, boolean isSameMode) { Nd4j.getCompressor().autoDecompress(img); //Input: NCHW format - // FIXME: int cast - int outH = outputSize((int) img.size(2), kh, sy, ph, dh, isSameMode); - int outW = outputSize((int) img.size(3), kw, sx, pw, dw, isSameMode); + long outH = outputSize(img.size(2), kh, sy, ph, dh, isSameMode); + long outW = outputSize(img.size(3), kw, sx, pw, dw, isSameMode); //[miniBatch,depth,kH,kW,outH,outW] INDArray out = Nd4j.create(new long[]{img.size(0), img.size(1), kh, kw, outH, outW}, 'c'); @@ -277,9 +276,8 @@ public class Convolution { output = Nd4j.createUninitialized(img.dataType(), new long[]{img.size(0), img.size(1), kh, kw, oH, oW}, 'c'); } else { - // FIXME: int cast - int oH = ((int) img.size(2) - (kh + (kh - 1) * (1 - 1)) + 2 * ph) / sy + 1; - int oW = ((int) img.size(3) - (kw + (kw - 1) * (1 - 1)) + 2 * pw) / sx + 1; + long oH = (img.size(2) - (kh + (kh - 1) * (1 - 1)) + 2 * ph) / sy + 1; + long oW = (img.size(3) - (kw + (kw - 1) * (1 - 1)) + 2 * pw) / sx + 1; output = Nd4j.createUninitialized(img.dataType(), new long[]{img.size(0), img.size(1), kh, kw, oH, oW}, 'c'); } @@ -314,7 +312,7 @@ public class Convolution { * @return */ @Deprecated - public static int outSize(int size, int k, int s, int p, int dilation, boolean coverAll) { + public static long outSize(long size, long k, long s, long p, int dilation, boolean coverAll) { k = effectiveKernelSize(k, dilation); if (coverAll) @@ -323,7 +321,7 @@ public class Convolution { return (size + p * 2 - k) / s + 1; } - public static int outputSize(int size, int k, int s, int p, int dilation, boolean isSameMode) { + public static long outputSize(long size, long k, long s, long p, int dilation, boolean isSameMode) { k = effectiveKernelSize(k, dilation); if (isSameMode) { @@ -333,7 +331,7 @@ public class Convolution { } } - public static int effectiveKernelSize(int kernel, int dilation) { + public static long effectiveKernelSize(long kernel, int dilation) { return kernel + (kernel - 1) * (dilation - 1); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java index cb354c2d1..5cfecc6fe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java @@ -584,7 +584,6 @@ public class DataSet implements org.nd4j.linalg.dataset.api.DataSet { */ @Override public int numInputs() { - // FIXME: int cast return (int) getFeatures().size(1); } @@ -1134,13 +1133,11 @@ public class DataSet implements org.nd4j.linalg.dataset.api.DataSet { @Override public int numOutcomes() { - // FIXME: int cast return (int) getLabels().size(1); } @Override public int numExamples() { - // FIXME: int cast if (getFeatures() != null) return (int) getFeatures().size(0); else if (getLabels() != null) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/KFoldIterator.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/KFoldIterator.java index 70c754c6e..4d7d257e1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/KFoldIterator.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/KFoldIterator.java @@ -99,13 +99,11 @@ public class KFoldIterator implements DataSetIterator { @Override public int inputColumns() { - // FIXME: int cast return (int) allData.getFeatures().size(1); } @Override public int totalOutcomes() { - // FIXME: int cast return (int) allData.getLabels().size(1); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestDataSetIterator.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestDataSetIterator.java index cc3f6905d..f57025344 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestDataSetIterator.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestDataSetIterator.java @@ -72,13 +72,11 @@ public class TestDataSetIterator implements DataSetIterator { @Override public int inputColumns() { - // FIXME: int cast return (int)list.get(0).getFeatures().columns(); } @Override public int totalOutcomes() { - // FIXME: int cast return (int) list.get(0).getLabels().columns(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/ImagePreProcessingScaler.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/ImagePreProcessingScaler.java index 598eac4f6..5fc4491c8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/ImagePreProcessingScaler.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/ImagePreProcessingScaler.java @@ -20,6 +20,7 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -125,7 +126,9 @@ public class ImagePreProcessingScaler implements DataNormalization { @Override public void transformLabel(INDArray label) { - //No op + Preconditions.checkState(label != null && label.rank() == 4, "Labels can only be transformed for segmentation use" + + " cases using this preprocesser - i.e., labels must be rank 4. Got: %ndShape", label); + transform(label); } @Override @@ -161,7 +164,9 @@ public class ImagePreProcessingScaler implements DataNormalization { @Override public void revertLabels(INDArray labels) { - //No op + Preconditions.checkState(labels != null && labels.rank() == 4, "Labels can only be transformed for segmentation use" + + " cases using this preprocesser - i.e., labels must be rank 4. Got: %ndShape", labels); + revertFeatures(labels); } @Override @@ -171,9 +176,7 @@ public class ImagePreProcessingScaler implements DataNormalization { @Override public void fitLabel(boolean fitLabels) { - if (fitLabels) { - log.warn("Labels fitting not currently supported for ImagePreProcessingScaler. Labels will not be modified"); - } + //No-op } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessor.java index 11d0bd9a6..d7839e8d8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessor.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessor.java @@ -61,7 +61,6 @@ public class RGBtoGrayscaleDataSetPreProcessor implements DataSetPreProcessor { B.muli(BLUE_RATIO); R.addi(G).addi(B); - // FIXME: int cast result.putSlice((int)n, R); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java index a664d9ee5..907290ef9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java @@ -29,6 +29,7 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.random.impl.Range; import org.nd4j.linalg.api.rng.distribution.Distribution; import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.primitives.AtomicDouble; @@ -921,8 +922,8 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory { int arrOffset = 0; - // FIXME: int cast - + if (ret.tensorsAlongDimension(dimension) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); INDArray[] retAlongDimensionArrays = new INDArray[(int) ret.tensorsAlongDimension(dimension)]; for (int i = 0; i < retAlongDimensionArrays.length; i++) retAlongDimensionArrays[i] = ret.tensorAlongDimension(i, dimension); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 98f92aaaa..25960a8a8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -5212,6 +5212,8 @@ public class Nd4j { } } } + + backend.logBackendInit(); } catch (Exception e) { throw new RuntimeException(e); } @@ -5625,19 +5627,38 @@ public class Nd4j { * @return an ndarray created from the in memory * numpy pointer */ - @SuppressWarnings("WeakerAccess") public static INDArray createFromNpyPointer(Pointer pointer) { return INSTANCE.createFromNpyPointer(pointer); } /** - * Create from a given Numpy .npy file. + * Create an INDArray from a given Numpy .npy file. + * + * @param path Path to the .npy file to read + * @return the created ndarray + */ + public static INDArray readNpy(@NonNull String path){ + return readNpy(new File(path)); + } + + /** + * Create an INDArray from a given Numpy .npy file. * * @param file the file to create the ndarray from * @return the created ndarray */ - public static INDArray createFromNpyFile(File file) { + public static INDArray readNpy(@NonNull File file){ + return createFromNpyFile(file); + } + + /** + * Create an INDArray from a given Numpy .npy file. + * + * @param file the file to create the ndarray from + * @return the created ndarray + */ + public static INDArray createFromNpyFile(@NonNull File file) { if (!file.exists()) throw new IllegalArgumentException("File [" + file.getAbsolutePath() + "] doesn't exist"); @@ -5654,7 +5675,7 @@ public class Nd4j { * @return the loaded ndarray */ @SuppressWarnings("unused") - public static INDArray createNpyFromInputStream(InputStream is) throws IOException { + public static INDArray createNpyFromInputStream(@NonNull InputStream is) throws IOException { byte[] content = IOUtils.toByteArray(is); return createNpyFromByteArray(content); } @@ -5668,7 +5689,7 @@ public class Nd4j { * @param input the input byte array with the npy format * @return the equivalent {@link INDArray} */ - public static INDArray createNpyFromByteArray(byte[] input) { + public static INDArray createNpyFromByteArray(@NonNull byte[] input) { ByteBuffer byteBuffer = ByteBuffer.allocateDirect(input.length); byteBuffer.put(input); byteBuffer.rewind(); @@ -5708,15 +5729,22 @@ public class Nd4j { for (int e = 0; e < shapeInfo.length; e++) shapeInfo[e] = array.shape(e); - if (Shape.isEmpty(shapeInfo)) - return Nd4j.empty(); + val shapeOf = Shape.shapeOf(shapeInfo); + DataType _dtype = FlatBuffersMapper.getDataTypeFromByte(dtype); + if (Shape.isEmpty(shapeInfo)) { + if(Shape.rank(shapeInfo) == 0) { + return Nd4j.empty(); + } else { + return Nd4j.create(_dtype, shapeOf); + } + } char ordering = shapeInfo[shapeInfo.length - 1] == 99 ? 'c' : 'f'; - val shapeOf = Shape.shapeOf(shapeInfo); + val stridesOf = Shape.stridesOf(shapeInfo); - val _dtype = FlatBuffersMapper.getDataTypeFromByte(dtype); + val _order = FlatBuffersMapper.getOrderFromByte(order); val prod = rank > 0 ? ArrayUtil.prod(shapeOf) : 1; @@ -5803,12 +5831,18 @@ public class Nd4j { } } case UBYTE: - UInt8Buffer b = new UInt8Buffer(ArrayUtil.prod(shapeOf)); - val sb = bb.order(_order).asReadOnlyBuffer(); - for (int e = 0; e < prod; e++) - b.put(e, sb.get(e)); + case BFLOAT16: + case UINT16: + INDArray arr = Nd4j.createUninitialized(_dtype, shapeOf); + ByteBuffer obb = bb.order(_order); + int pos = obb.position(); + byte[] bArr = new byte[obb.limit() - pos]; - return Nd4j.create(b, shapeOf); + for (int e = 0; e < bArr.length; e++) { + bArr[e] = obb.get(e + pos); + } + arr.data().asNio().put(bArr); + return arr; default: throw new UnsupportedOperationException("Unknown datatype: [" + _dtype + "]"); } @@ -6541,7 +6575,8 @@ public class Nd4j { */ @Deprecated public static void scatterUpdate(ScatterUpdate.UpdateOp op, @NonNull INDArray array, @NonNull INDArray indices, @NonNull INDArray updates, int... axis) { - Preconditions.checkArgument(indices.dataType() == DataType.INT, "Indices should have INT data type"); + Preconditions.checkArgument(indices.dataType() == DataType.INT || indices.dataType() == DataType.LONG, + "Indices should have INT data type"); Preconditions.checkArgument(array.dataType() == updates.dataType(), "Array and updates should have the same data type"); getExecutioner().scatterUpdate(op, array, indices, updates, axis); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java index 6bc48778f..77f56c613 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java @@ -106,8 +106,6 @@ public class BooleanIndexing { MatchCondition op = new MatchCondition(n, condition, dimension); INDArray arr = Nd4j.getExecutioner().exec(op); - // FIXME: int cast - boolean[] result = new boolean[(int) arr.length()]; for (int i = 0; i < arr.length(); i++) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/Indices.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/Indices.java index 3ca99c50e..c17295409 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/Indices.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/Indices.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.indexing; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.shade.guava.primitives.Ints; import org.nd4j.shade.guava.primitives.Longs; import org.nd4j.linalg.api.ndarray.INDArray; @@ -59,8 +60,8 @@ public class Indices { double otherTest = ((double) index) / arr.size(-1); int test = (int) Math.floor(otherTest); - // FIXME: int cast - + if (arr.vectorsAlongDimension(-1) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); int vectors = (int) arr.vectorsAlongDimension(-1); if (test >= vectors) return vectors - 1; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/legacy/AdaGrad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/legacy/AdaGrad.java index 0bf673a49..6793f04f5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/legacy/AdaGrad.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/legacy/AdaGrad.java @@ -37,7 +37,7 @@ public class AdaGrad implements Serializable { public static final double DEFAULT_ADAGRAD_EPSILON = 1e-6; public INDArray historicalGradient; - public int[] shape; + public long[] shape; protected double learningRate = 1e-1; // learning rate protected int numIterations = 0; private double epsilon = DEFAULT_ADAGRAD_EPSILON; @@ -73,7 +73,7 @@ public class AdaGrad implements Serializable { * @param learningRate */ public AdaGrad(int rows, int cols, double learningRate) { - this.shape = new int[] {rows, cols}; + this.shape = new long[] {rows, cols}; this.learningRate = learningRate; } @@ -81,7 +81,7 @@ public class AdaGrad implements Serializable { this(rows, cols, 0.1); } - public AdaGrad(int[] shape, double learningRate) { + public AdaGrad(long[] shape, double learningRate) { this.shape = shape; this.learningRate = learningRate; } @@ -124,7 +124,7 @@ public class AdaGrad implements Serializable { return ret; } - public double getGradient(double gradient, int column, int[] shape) { + public double getGradient(double gradient, int column, long[] shape) { boolean historicalInitialized = false; if (this.historicalGradient == null) { this.historicalGradient = Nd4j.ones(shape); @@ -143,7 +143,7 @@ public class AdaGrad implements Serializable { return adjustedGradient; } - public INDArray getGradient(INDArray gradient, int slice, int[] shape) { + public INDArray getGradient(INDArray gradient, int slice, long[] shape) { boolean historicalInitialized = false; INDArray sqrtHistory; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/DataSetUtils.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/DataSetUtils.java index 302b2d102..49c17fa11 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/DataSetUtils.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/DataSetUtils.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.util; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.tools.BTools; import org.nd4j.tools.InfoLine; import org.nd4j.tools.InfoValues; @@ -178,7 +179,9 @@ public class DataSetUtils { InfoValues iv; // double j_Dbl = -1; - // FIXME: int cast + if (in_INDA.rows() > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } int i_CharsCount = BTools.getIndexCharsCount( (int) in_INDA.rows() - 1 ); // oinfo = ""; @@ -219,7 +222,8 @@ public class DataSetUtils { c_I = 0; // if ( ot_INDA != null ) { - // FIXME: int cast + if (ot_INDA.columns() - 1 > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); for ( int j = (int) ot_INDA.columns() - 1; j >= 0; j-- ) { // if ( c_I > c_End_I ) break; @@ -346,7 +350,8 @@ public class DataSetUtils { InfoValues iv; // double j_Dbl = -1; - // FIXME: int cast + if (INDA.rows() - 1 > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); int i_CharsCount = BTools.getIndexCharsCount( (int) INDA.rows() - 1 ); // if ( !turned ) { //= standard @@ -366,7 +371,8 @@ public class DataSetUtils { iv.vsL.add( BTools.getSInt( i, i_CharsCount ) ); // int c_I = 0; - // FIXME: int cast + if (INDA.columns() - 1 > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); for ( int j = (int) INDA.columns() - 1; j >= 0; j-- ) { // if ( c_I > c_End_I ) break; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/NDArrayUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/NDArrayUtil.java index fd09351fb..23ae9b984 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/NDArrayUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/NDArrayUtil.java @@ -84,7 +84,6 @@ public class NDArrayUtil { n = n.reshape(-1); - // FIXME: int cast long[] ret = new long[(int) n.length()]; for (int i = 0; i < n.length(); i++) ret[i] = (long) n.getFloat(i); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/binary/BinarySerde.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/binary/BinarySerde.java index bf0fdf9cb..60bb0378b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/binary/BinarySerde.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/binary/BinarySerde.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.compression.CompressedDataBuffer; import org.nd4j.linalg.compression.CompressionDescriptor; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; @@ -91,7 +92,8 @@ public class BinarySerde { if (type != DataType.COMPRESSED) { ByteBuffer slice = byteBuffer.slice(); //wrap the data buffer for the last bit - // FIXME: int cast + if (Shape.length(shapeBuff) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); DataBuffer buff = Nd4j.createBuffer(slice, type, (int) Shape.length(shapeBuff)); //advance past the data int position = byteBuffer.position() + (buff.getElementSize() * (int) buff.length()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/autodiff/execution/NativeGraphExecutioner.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/autodiff/execution/NativeGraphExecutioner.java index 8253b67bb..1abd7a3de 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/autodiff/execution/NativeGraphExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/autodiff/execution/NativeGraphExecutioner.java @@ -27,6 +27,7 @@ import org.nd4j.autodiff.execution.conf.OutputMode; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.graph.FlatArray; import org.nd4j.graph.FlatResult; import org.nd4j.graph.FlatVariable; @@ -115,15 +116,17 @@ public class NativeGraphExecutioner implements GraphExecutioner { for (int e = 0; e < fr.variablesLength(); e++) { FlatVariable var = fr.variables(e); + String varName = var.name(); // log.info("Var received: id: [{}:{}/<{}>];", var.id().first(), var.id().second(), var.name()); FlatArray ndarray = var.ndarray(); - INDArray val = Nd4j.createFromFlatArray(ndarray); results[e] = val; if (var.name() != null && sd.variableMap().containsKey(var.name())) { - sd.associateArrayWithVariable(val, sd.variableMap().get(var.name())); + if(sd.getVariable(varName).getVariableType() != VariableType.ARRAY){ + sd.associateArrayWithVariable(val, sd.variableMap().get(var.name())); + } } else { if (sd.variableMap().get(var.name()) != null) { sd.associateArrayWithVariable(val,sd.getVariable(var.name())); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java index a060232db..8f621668b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java @@ -1075,7 +1075,7 @@ public interface NativeOps { Pointer dX, @Cast("Nd4jLong *") LongPointer dXShapeInfo, @Cast("Nd4jLong *") LongPointer dxOffsets, Pointer hY, @Cast("Nd4jLong *") LongPointer hYShapeInfo, @Cast("Nd4jLong *") LongPointer hyOffsets, Pointer dY, @Cast("Nd4jLong *") LongPointer dYShapeInfo, @Cast("Nd4jLong *") LongPointer dyOffsets, - IntPointer hIndices, IntPointer dIndices); + Pointer hIndices, @Cast("Nd4jLong *") LongPointer hIndicesShapeInfo, Pointer dIndices, @Cast("Nd4jLong *") LongPointer dIndicesShapeInfo); //void fillUtf8String(PointerPointer extraPointers, String[] string, int numStrings, Pointer buffer); Pointer createUtf8String(PointerPointer extraPointers, String string, int length); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java index ae31ea7b8..de9edfc2e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java @@ -20,6 +20,7 @@ import java.util.Properties; import lombok.Getter; import org.bytedeco.javacpp.Loader; import org.nd4j.config.ND4JEnvironmentVars; +import org.nd4j.config.ND4JSystemProperties; import org.nd4j.context.Nd4jContext; import org.nd4j.linalg.factory.Nd4j; import org.slf4j.Logger; @@ -101,7 +102,12 @@ public class NativeOpsHolder { } //deviceNativeOps.setOmpNumThreads(4); - log.info("Number of threads used for OpenMP: {}", deviceNativeOps.ompGetMaxThreads()); + String logInitProperty = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION, "true"); + boolean logInit = Boolean.parseBoolean(logInitProperty); + + if(logInit) { + log.info("Number of threads used for OpenMP: {}", deviceNativeOps.ompGetMaxThreads()); + } } catch (Exception | Error e) { throw new RuntimeException( "ND4J is probably missing dependencies. For more information, please refer to: http://nd4j.org/getstarted.html", diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/MemoryTracker.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/MemoryTracker.java index c790b40ab..c7eada206 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/MemoryTracker.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/MemoryTracker.java @@ -47,7 +47,7 @@ public class MemoryTracker { val f = new AtomicLong(NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceFreeMemory(i)); - log.debug("Free memory on device_{}: {}", i, f); + //log.debug("Free memory on device_{}: {}", i, f); freePerDevice.add(i, f); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java index a6a5a45e4..34970dc19 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java @@ -16,14 +16,24 @@ package org.nd4j.linalg.jcublas; +import lombok.extern.slf4j.Slf4j; import org.bytedeco.javacpp.Loader; +import org.nd4j.config.ND4JSystemProperties; +import org.nd4j.linalg.api.environment.Nd4jEnvironment; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.Resource; +import org.nd4j.nativeblas.Nd4jCuda; + +import java.util.List; +import java.util.Map; +import java.util.Properties; /** * */ +@Slf4j public class JCublasBackend extends Nd4jBackend { @@ -76,4 +86,34 @@ public class JCublasBackend extends Nd4jBackend { return JCublasNDArray.class; } + @Override + public void logBackendInit() { + String logInitProperty = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION, "true"); + boolean logInit = Boolean.parseBoolean(logInitProperty); + + if(logInit) { + try { + Nd4jCuda.Environment e = Nd4jCuda.Environment.getInstance(); + int blasMajor = e.blasMajorVersion(); + int blasMinor = e.blasMinorVersion(); + int blasPatch = e.blasPatchVersion(); + log.info("ND4J CUDA build version: {}.{}.{}", blasMajor, blasMinor, blasPatch); + int nGPUs = Nd4jEnvironment.getEnvironment().getNumGpus(); + + Properties props = Nd4j.getExecutioner().getEnvironmentInformation(); + List> devicesList = (List>) props.get(Nd4jEnvironment.CUDA_DEVICE_INFORMATION_KEY); + + for (int i = 0; i < nGPUs; i++) { + Map dev = devicesList.get(i); + String name = (String) dev.get(Nd4jEnvironment.CUDA_DEVICE_NAME_KEY); + int major = ((Number) dev.get(Nd4jEnvironment.CUDA_DEVICE_MAJOR_VERSION_KEY)).intValue(); + int minor = ((Number) dev.get(Nd4jEnvironment.CUDA_DEVICE_MINOR_VERSION_KEY)).intValue(); + long totalMem = ((Number) dev.get(Nd4jEnvironment.CUDA_TOTAL_MEMORY_KEY)).longValue(); + log.info("CUDA device {}: [{}]; cc: [{}.{}]; Total memory: [{}]", i, name, major, minor, totalMem); + } + } catch (Throwable t) { + log.debug("Error logging CUDA backend versions and devices", t); + } + } + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java index 9931fcaa9..0bcb6e562 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java @@ -455,7 +455,6 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { @Override public INDArray pullRows(INDArray source, int sourceDimension, long[] indexes) { - // FIXME: int cast return pullRows(source, sourceDimension, ArrayUtil.toInts(indexes)); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java index 3eade74e9..13991f63b 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java @@ -850,7 +850,9 @@ public class JcublasLapack extends BaseLapack { if (A.ordering() == 'c') a = A.dup('f'); - // FIXME: int cast + if (A.rows() > Integer.MAX_VALUE) { + throw new RuntimeException("Rows overflow"); + } int M = (int) A.rows(); if (Nd4j.getExecutioner() instanceof GridExecutioner) @@ -925,7 +927,10 @@ public class JcublasLapack extends BaseLapack { if (A.ordering() == 'c') a = A.dup('f'); - // FIXME: int cast + if (A.rows() > Integer.MAX_VALUE) { + throw new RuntimeException("Rows overflow"); + } + int M = (int) A.rows(); if (Nd4j.getExecutioner() instanceof GridExecutioner) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 6c95d3ce5..8fe744b38 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -1000,8 +1000,13 @@ public class CudaExecutioner extends DefaultOpExecutioner { val dataType = op.resultType(); - val ret = Nd4j.createUninitialized(dataType, retShape); - op.setZ(ret); + if( op.z() == null ){ + val ret = Nd4j.createUninitialized(dataType, retShape); + op.setZ(ret); + } else if(op.z().dataType() != dataType || !Arrays.equals(retShape, op.z().shape())){ + throw new ND4JIllegalStateException("Output array for op " + op.getClass().getSimpleName() + " should have type " + dataType + " and shape " + Arrays.toString(retShape) + + " but has datatype " + op.z().dataType() + " and shape " + Arrays.toString(op.z().shape())); + } val eb = op.extraArgsDataBuff(op.z().dataType() == DataType.BOOL || op.getOpType() == Op.Type.REDUCE_LONG ? op.x().dataType() : op.z().dataType()); Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(eb, context) : null; @@ -1890,14 +1895,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { @SuppressWarnings("unchecked") public void printEnvironmentInformation() { super.printEnvironmentInformation(); - - Properties env = getEnvironmentInformation(); - - List> devicesList = (List>) env.get(Nd4jEnvironment.CUDA_DEVICE_INFORMATION_KEY); - for (Map dev : devicesList) { - log.info("Device Name: [{}]; CC: [{}.{}]; Total/free memory: [{}]", dev.get(Nd4jEnvironment.CUDA_DEVICE_NAME_KEY), - dev.get(Nd4jEnvironment.CUDA_DEVICE_MAJOR_VERSION_KEY), dev.get(Nd4jEnvironment.CUDA_DEVICE_MINOR_VERSION_KEY), dev.get(Nd4jEnvironment.CUDA_TOTAL_MEMORY_KEY)); - } } @Override @@ -2466,7 +2463,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { nativeOps.scatterUpdate(stuff, op.ordinal(), (int) indices.length(), null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(tadX.getFirst()), null, AtomicAllocator.getInstance().getPointer(array, context), (LongPointer) AtomicAllocator.getInstance().getPointer(tadX.getFirst()), (LongPointer) AtomicAllocator.getInstance().getPointer(tadX.getSecond()), null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(tadY.getFirst()), null, AtomicAllocator.getInstance().getPointer(updates, context), (LongPointer) AtomicAllocator.getInstance().getPointer(tadY.getFirst()), (LongPointer) AtomicAllocator.getInstance().getPointer(tadY.getSecond()), - null, (IntPointer) AtomicAllocator.getInstance().getPointer(indices, context)); + AtomicAllocator.getInstance().getHostPointer(indices), (LongPointer) AtomicAllocator.getInstance().getHostPointer(indices.shapeInfoDataBuffer()), AtomicAllocator.getInstance().getPointer(indices, context), (LongPointer) AtomicAllocator.getInstance().getPointer(indices.shapeInfoDataBuffer(), context)); if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java index cf779f537..32f1b0a10 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java @@ -52,20 +52,26 @@ public class CudaOpContext extends BaseOpContext implements OpContext { @Override public void setIArguments(long... arguments) { - super.setIArguments(arguments); - nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length); + if (arguments.length > 0) { + super.setIArguments(arguments); + nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length); + } } @Override public void setBArguments(boolean... arguments) { - super.setBArguments(arguments); - nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length); + if (arguments.length > 0) { + super.setBArguments(arguments); + nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length); + } } @Override public void setTArguments(double... arguments) { - super.setTArguments(arguments); - nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length); + if (arguments.length > 0) { + super.setTArguments(arguments); + nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length); + } } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 0ddcb6266..efa70d691 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -1,4 +1,4 @@ -// Targeted by JavaCPP version 1.5.2-SNAPSHOT: DO NOT EDIT THIS FILE +// Targeted by JavaCPP version 1.5.1-1: DO NOT EDIT THIS FILE package org.nd4j.nativeblas; @@ -598,6 +598,10 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { public native @Cast("bool") boolean isCPU(); + public native int blasMajorVersion(); + public native int blasMinorVersion(); + public native int blasPatchVersion(); + public native @StdVector Pair capabilities(); } @@ -3045,19 +3049,19 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, @Cast("Nd4jLong*") LongPointer dXOffsets, Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer hYOffsets, Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, @Cast("Nd4jLong*") LongPointer dYOffsets, - IntPointer hIindexes, IntPointer dIindexes); + Pointer hIindexes, @Cast("Nd4jLong*") LongPointer hIndicesShapeInfo, Pointer dIindexes, @Cast("Nd4jLong*") LongPointer dIndicesShapeInfo); public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer hXOffsets, Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXOffsets, Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer hYOffsets, Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYOffsets, - IntBuffer hIindexes, IntBuffer dIindexes); + Pointer hIindexes, @Cast("Nd4jLong*") LongBuffer hIndicesShapeInfo, Pointer dIindexes, @Cast("Nd4jLong*") LongBuffer dIndicesShapeInfo); public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] hXOffsets, Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, @Cast("Nd4jLong*") long[] dXOffsets, Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] hYOffsets, Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, @Cast("Nd4jLong*") long[] dYOffsets, - int[] hIindexes, int[] dIindexes); + Pointer hIindexes, @Cast("Nd4jLong*") long[] hIndicesShapeInfo, Pointer dIindexes, @Cast("Nd4jLong*") long[] dIndicesShapeInfo); public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongPointer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongBuffer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); @@ -3120,6 +3124,13 @@ public native @Cast("Nd4jPointer") Pointer lcCopyStream(OpaqueLaunchContext lc); public native @Cast("Nd4jPointer") Pointer lcBlasHandle(OpaqueLaunchContext lc); public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc); + +public native int binaryLevel(); +public native int optimalLevel(); + +public native @Cast("bool") boolean isMinimalRequirementsMet(); +public native @Cast("bool") boolean isOptimalRequirementsMet(); + // #endif //NATIVEOPERATIONS_NATIVEOPS_H @@ -4696,6 +4707,7 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc * k - depth * value - scalar value to assign */ + public native void p(@Cast("const Nd4jLong") long i, @Cast("const Nd4jLong") long j, @Cast("const Nd4jLong") long k, @Cast("const Nd4jLong") long l, @Const @ByRef NDArray value); /** * creates array which points on certain sub-range of this array, sub-range is defined by given indices @@ -4924,7 +4936,7 @@ NDArray NDArray::operator()(const Nd4jLong i) const { } else { Nd4jLong idx[MAX_RANK]; shape::ind2subC(rankOf(), shapeOf(), i, idx); - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), idx); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -4955,7 +4967,7 @@ NDArray& NDArray::operator()(const Nd4jLong i) { } else { Nd4jLong idx[MAX_RANK]; shape::ind2subC(rankOf(), shapeOf(), i, idx); - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), idx); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -4972,7 +4984,7 @@ NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j) const { throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !"); Nd4jLong coords[2] = {i, j}; - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), coords); // TODO: do we really want a view here? auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); @@ -4988,7 +5000,7 @@ NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j) { throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !"); Nd4jLong coords[2] = {i, j}; - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), coords); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -5007,7 +5019,7 @@ NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !"); Nd4jLong coords[3] = {i, j, k}; - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), coords); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -5024,7 +5036,7 @@ NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !"); Nd4jLong coords[3] = {i, j, k}; - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), coords); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -5040,7 +5052,7 @@ NDArray NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !"); Nd4jLong coords[4] = {t, u, v, w}; - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), coords); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -5054,7 +5066,7 @@ NDArray& NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !"); Nd4jLong coords[4] = {t, u, v, w}; - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), coords); // FIXME auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); @@ -5070,7 +5082,7 @@ NDArray NDArray::operator()(const Nd4jLong* idx) const { if (idx[i] >= sizeAt(i)) throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !"); - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), idx); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -5085,7 +5097,7 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { if (idx[i] >= sizeAt(i)) throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !"); - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), idx); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -6578,9 +6590,6 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { // #include // #include -// #ifdef HAVE_MKLDNN -// #endif - // CUDA-specific includes // #ifdef __CUDACC__ // #endif @@ -6647,8 +6656,6 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { public native int getBranch(); public native void setBranch(int branch); -// #ifdef HAVE_MKLDNN -// #endif /** * * @return @@ -7956,9 +7963,7 @@ public static final int PREALLOC_SIZE = 33554432; * @param indices the indices to iterate over * @return the double at the specified index */ - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer stride, @Cast("const Nd4jLong*") LongPointer indices, int rank); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer stride, @Cast("const Nd4jLong*") LongBuffer indices, int rank); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] stride, @Cast("const Nd4jLong*") long[] indices, int rank); + @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices, @Cast("Nd4jLong") long baseOffset/*=0*/); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices, @Cast("Nd4jLong") long baseOffset/*=0*/); @@ -7979,34 +7984,26 @@ public static final int PREALLOC_SIZE = 33554432; /** * Convert a linear index to the corresponding coordinates - * for example if shape is {2, 4}, then index 5 corresponds to following coordinates - * -> [1, 1] in case of c order - * -> [1, 2] in case of f order + * for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, 1] */ - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongPointer coords, byte order/*='c'*/); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongBuffer coords, byte order/*='c'*/); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") long[] coords, byte order/*='c'*/); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") long[] coords); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongPointer coords, byte order/*='c'*/); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongBuffer coords, byte order/*='c'*/); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") long[] coords, byte order/*='c'*/); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") long[] coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords); + + /** * Convert coordinates to the corresponding linear index (sequence number in other words) - * for example if shape is {2, 4}, then: - * in case of c order and coordinates [1, 1] index 5 is returned - * in case of f order and coordinates [1, 2] index 5 is returned + * for example if shape is {2, 4} and coordinates [1, 1] then index 5 is returned */ - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords, byte order/*='c'*/); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords, byte order/*='c'*/); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords, byte order/*='c'*/); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords); /** @@ -8018,36 +8015,16 @@ public static final int PREALLOC_SIZE = 33554432; */ /* calculates an array buffer offset for given "index" using following formula: offset = coord_0*stride_0 + coord_1*stride_1 + ... + coord_{rank-1}*stride_{rank-1} - * arrLen - array length */ - @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntPointer shapeInfo, @Cast("uint") int arrLen); - @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntBuffer shapeInfo, @Cast("uint") int arrLen); - @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") int[] shapeInfo, @Cast("uint") int arrLen); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong") long arrLen); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong") long arrLen); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong") long arrLen); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOrderOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong") long arrLen, byte order); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOrderOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong") long arrLen, byte order); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOrderOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong") long arrLen, byte order); - @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer lShapeInfo, @Cast("const uint*") IntPointer uShapeInfo, @Cast("Nd4jLong") long arrLen, @Cast("const bool") boolean useUnsigned); - @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer lShapeInfo, @Cast("const uint*") IntBuffer uShapeInfo, @Cast("Nd4jLong") long arrLen, @Cast("const bool") boolean useUnsigned); - @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] lShapeInfo, @Cast("const uint*") int[] uShapeInfo, @Cast("Nd4jLong") long arrLen, @Cast("const bool") boolean useUnsigned); - - /** - * Compute the real linear indices for the given shape and stride - */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer computeIndices(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer computeIndices(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] computeIndices(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] stride); - - /** - * Compute the real linear indices for the - * given shape buffer. Shape,stride and rank are derived - * from the buffer - */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer computeIndices( @Cast("Nd4jLong*") LongPointer shapeBuffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer computeIndices( @Cast("Nd4jLong*") LongBuffer shapeBuffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] computeIndices( @Cast("Nd4jLong*") long[] shapeBuffer); + @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntPointer shapeInfo); + @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntBuffer shapeInfo); + @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") int[] shapeInfo); + @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo); + @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo); + @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo); + @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer lShapeInfo, @Cast("const uint*") IntPointer uShapeInfo, @Cast("const bool") boolean useUnsigned); + @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer lShapeInfo, @Cast("const uint*") IntBuffer uShapeInfo, @Cast("const bool") boolean useUnsigned); + @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] lShapeInfo, @Cast("const uint*") int[] uShapeInfo, @Cast("const bool") boolean useUnsigned); @Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongPointer shapeInfo); @Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeInfo); @@ -8326,20 +8303,62 @@ public static final int PREALLOC_SIZE = 33554432; * for the given rank and shape. */ -/** - * Compute the real linear indices for the given shape and stride - */ - -/** -* Compute the real linear indices for the given shape and stride -*/ - +////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////////// +// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) { + +// const Nd4jLong ews = shapeInfo[shapeInfo[0] + shapeInfo[0] + 2]; + +// if(ews > 0 && order(shapeInfo) == 'c') +// if (ews == 1) +// return index; +// else +// return ews * index; + +// Nd4jLong offset = 0; +// Nd4jLong rank = shapeInfo[0]; +// for(int i = 1; i <= shapeInfo[0]; ++i) { +// arrLen /= shapeInfo[i]; +// if(arrLen > 0 && shapeInfo[i] > 1) { +// offset += (index / arrLen) * shapeInfo[i + rank]; +// index %= arrLen; +// } +// } +// return offset; +// } + +// INLINEDEF _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo, uint arrLen) { + +// const uint rank = shapeInfo[0]; +// const uint ews = shapeInfo[rank + rank + 2]; + +// if(ews > 0 && shapeInfo[rank + rank + 3] == 99) +// if (ews == 1) +// return index; +// else +// return ews * index; + +// uint offset = 0; + +// for(uint i = 1; i <= rank; ++i) { +// arrLen /= shapeInfo[i]; +// if(arrLen > 0 && shapeInfo[i] > 1) { +// offset += (index / arrLen) * shapeInfo[i + rank]; +// index %= arrLen; +// } +// } +// return offset; +// } + ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// + + ////////////////////////////////////////////////////////////////////// /** @@ -8708,6 +8727,10 @@ public static final int PREALLOC_SIZE = 33554432; * @return the double at the specified index */ +////////////////////////////////////////////////////////////////////////// + +////////////////////////////////////////////////////////////////////////// + @@ -9036,6 +9059,8 @@ public static final int PREALLOC_SIZE = 33554432; ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// + @@ -9302,6 +9327,81 @@ public static final int PREALLOC_SIZE = 33554432; // #endif //LIBND4J_OPDESCRIPTOR_H +// Parsed from ops/declarable/PlatformHelper.h + +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +// #ifndef SD_PLATFORMHELPER_H +// #define SD_PLATFORMHELPER_H + +// #include +// #include +// #include +// #include +// #include + /** + * This abstract class defines methods used by platform-specific helpers implementations + */ + @Namespace("nd4j::ops::platforms") @NoOffset public static class PlatformHelper extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public PlatformHelper(Pointer p) { super(p); } + + + public native @StdString BytePointer name(); + + public native @Cast("Nd4jLong") long hash(); + + /** + * This method checks, if given helper can be used with given input/output/configuration options + * + * @param context + * @return + */ + public native @Cast("bool") boolean isUsable(@ByRef Context context); + + /** + * This method invokes helper. Typically this method replaces actual op execution + * + * @param context + * @return + */ + public native @Cast("Nd4jStatus") int invokeHelper(@ByRef Context context); + + /** + * Helper method, needed for compatibility with DeclarableOp macros + * @param ctx + * @param inputId + * @return + */ + public native NDArray getZ(@ByRef Context ctx, int inputId); + } + + + + + +// #endif //SD_PLATFORMHELPER_H + + // Parsed from ops/declarable/BroadcastableOp.h /******************************************************************************* @@ -9879,6 +9979,7 @@ public static final int PREALLOC_SIZE = 33554432; // #include // #include // #include +// #include // handlers part // #include @@ -9906,7 +10007,7 @@ public static final int PREALLOC_SIZE = 33554432; public native @Cast("char*") String getAllCustomOperations(); /** - * This method registers operation + * This method registers operation in our registry, so we can use them later * * @param op */ @@ -9914,10 +10015,16 @@ public static final int PREALLOC_SIZE = 33554432; public native @Cast("bool") boolean registerOperation(@Cast("char*") BytePointer name, DeclarableOp op); public native @Cast("bool") boolean registerOperation(DeclarableOp op); + public native void registerHelper(PlatformHelper op); + + public native @Cast("bool") boolean hasHelper(@Cast("Nd4jLong") long hash); + public native DeclarableOp getOperation(@Cast("char*") String name); public native DeclarableOp getOperation(@Cast("char*") BytePointer name); public native DeclarableOp getOperation(@Cast("Nd4jLong") long hash); + public native PlatformHelper getPlatformHelper(@Cast("Nd4jLong") long hash); + public native @Cast("Nd4jLong*") @StdVector LongPointer getAllHashes(); public native int numberOfOperations(); @@ -10044,6 +10151,10 @@ public static final int PREALLOC_SIZE = 33554432; // #include // #endif +// used for MKLDNN etc +// #if !defined(__STANDALONE_BUILD__) +// #include "config.h" +// #endif // #include // #include @@ -10080,6 +10191,8 @@ public static final int PREALLOC_SIZE = 33554432; public native Workspace getWorkspace(); public native void setWorkspace(Workspace theWorkspace); + public native Pointer engine(); + public native int getDeviceID(); public native void setDeviceID(int deviceID); public native ErrorReference errorReference(); @@ -10438,6 +10551,7 @@ public static final int PREALLOC_SIZE = 33554432; // #include // #include // #include +// #include // #include // #include // #include diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuBackend.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuBackend.java index 7ba2aa6d7..627105bda 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuBackend.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuBackend.java @@ -60,4 +60,9 @@ public class CpuBackend extends Nd4jBackend { public Class getNDArrayClass() { return NDArray.class; } + + @Override + public void logBackendInit() { + //No additional logging for CPU backend + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java index a328d788e..a1746134c 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java @@ -120,7 +120,7 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { log.warn("Warning: Initializing ND4J with " + binLevel + " binary on a CPU with " + optLevel + " support"); log.warn("Using ND4J with " + optLevel + " will improve performance. See deeplearning4j.org/cpu for more details"); log.warn("Or set environment variable " + ND4JEnvironmentVars.ND4J_IGNORE_AVX + "=true to suppress this warning"); - log.warn("************************************************************************************************"); + log.warn("*************************************************************************************************"); } blas = new CpuBlas(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/blas/CpuLapack.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/blas/CpuLapack.java index 39731f90f..70866f6f7 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/blas/CpuLapack.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/blas/CpuLapack.java @@ -20,6 +20,7 @@ import lombok.val; import org.nd4j.linalg.api.blas.impl.BaseLapack; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -40,7 +41,9 @@ public class CpuLapack extends BaseLapack { } protected static int getLda(INDArray A) { - // FIXME: int cast + if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } return A.ordering() == 'f' ? (int) A.rows() : (int) A.columns(); } //========================= diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java index 8db359d01..9431a3453 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java @@ -49,20 +49,26 @@ public class CpuOpContext extends BaseOpContext implements OpContext { @Override public void setIArguments(long... arguments) { - super.setIArguments(arguments); - nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length); + if (arguments.length > 0) { + super.setIArguments(arguments); + nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length); + } } @Override public void setBArguments(boolean... arguments) { - super.setBArguments(arguments); - nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length); + if (arguments.length > 0) { + super.setBArguments(arguments); + nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length); + } } @Override public void setTArguments(double... arguments) { - super.setTArguments(arguments); - nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length); + if (arguments.length > 0) { + super.setTArguments(arguments); + nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length); + }; } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index 663eb862e..a2964b7a6 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -1974,7 +1974,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { loop.scatterUpdate(null, op.ordinal(), (int) indices.length(), array.data().addressPointer(), (LongPointer) tadX.getFirst().addressPointer(), (LongPointer) tadX.getSecond().addressPointer(), null, null, null, updates.data().addressPointer(), (LongPointer) tadY.getFirst().addressPointer(), (LongPointer) tadY.getSecond().addressPointer(), null, null, null, - (IntPointer) indices.data().addressPointer(), null); + indices.data().addressPointer(), (LongPointer) indices.shapeInfoDataBuffer().addressPointer(), null, null); if (loop.lastErrorCode() != 0) throw new RuntimeException(loop.lastErrorMessage()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 8d92e09ad..f915c8152 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -1,4 +1,4 @@ -// Targeted by JavaCPP version 1.5.2-SNAPSHOT: DO NOT EDIT THIS FILE +// Targeted by JavaCPP version 1.5.1-1: DO NOT EDIT THIS FILE package org.nd4j.nativeblas; @@ -598,6 +598,10 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { public native @Cast("bool") boolean isCPU(); + public native int blasMajorVersion(); + public native int blasMinorVersion(); + public native int blasPatchVersion(); + public native @StdVector Pair capabilities(); } @@ -3045,19 +3049,19 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, @Cast("Nd4jLong*") LongPointer dXOffsets, Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer hYOffsets, Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, @Cast("Nd4jLong*") LongPointer dYOffsets, - IntPointer hIindexes, IntPointer dIindexes); + Pointer hIindexes, @Cast("Nd4jLong*") LongPointer hIndicesShapeInfo, Pointer dIindexes, @Cast("Nd4jLong*") LongPointer dIndicesShapeInfo); public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer hXOffsets, Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXOffsets, Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer hYOffsets, Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYOffsets, - IntBuffer hIindexes, IntBuffer dIindexes); + Pointer hIindexes, @Cast("Nd4jLong*") LongBuffer hIndicesShapeInfo, Pointer dIindexes, @Cast("Nd4jLong*") LongBuffer dIndicesShapeInfo); public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] hXOffsets, Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, @Cast("Nd4jLong*") long[] dXOffsets, Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] hYOffsets, Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, @Cast("Nd4jLong*") long[] dYOffsets, - int[] hIindexes, int[] dIindexes); + Pointer hIindexes, @Cast("Nd4jLong*") long[] hIndicesShapeInfo, Pointer dIindexes, @Cast("Nd4jLong*") long[] dIndicesShapeInfo); public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongPointer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongBuffer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); @@ -4703,6 +4707,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * k - depth * value - scalar value to assign */ + public native void p(@Cast("const Nd4jLong") long i, @Cast("const Nd4jLong") long j, @Cast("const Nd4jLong") long k, @Cast("const Nd4jLong") long l, @Const @ByRef NDArray value); /** * creates array which points on certain sub-range of this array, sub-range is defined by given indices @@ -4931,7 +4936,7 @@ NDArray NDArray::operator()(const Nd4jLong i) const { } else { Nd4jLong idx[MAX_RANK]; shape::ind2subC(rankOf(), shapeOf(), i, idx); - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), idx); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -4962,7 +4967,7 @@ NDArray& NDArray::operator()(const Nd4jLong i) { } else { Nd4jLong idx[MAX_RANK]; shape::ind2subC(rankOf(), shapeOf(), i, idx); - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), idx); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -4979,7 +4984,7 @@ NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j) const { throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !"); Nd4jLong coords[2] = {i, j}; - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), coords); // TODO: do we really want a view here? auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); @@ -4995,7 +5000,7 @@ NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j) { throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !"); Nd4jLong coords[2] = {i, j}; - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), coords); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -5014,7 +5019,7 @@ NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !"); Nd4jLong coords[3] = {i, j, k}; - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), coords); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -5031,7 +5036,7 @@ NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !"); Nd4jLong coords[3] = {i, j, k}; - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), coords); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -5047,7 +5052,7 @@ NDArray NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !"); Nd4jLong coords[4] = {t, u, v, w}; - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), coords); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -5061,7 +5066,7 @@ NDArray& NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !"); Nd4jLong coords[4] = {t, u, v, w}; - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), coords); // FIXME auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); @@ -5077,7 +5082,7 @@ NDArray NDArray::operator()(const Nd4jLong* idx) const { if (idx[i] >= sizeAt(i)) throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !"); - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), idx); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -5092,7 +5097,7 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { if (idx[i] >= sizeAt(i)) throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !"); - auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf()); + auto xOffset = shape::getOffset(getShapeInfo(), idx); auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); @@ -7958,9 +7963,7 @@ public static final int PREALLOC_SIZE = 33554432; * @param indices the indices to iterate over * @return the double at the specified index */ - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer stride, @Cast("const Nd4jLong*") LongPointer indices, int rank); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer stride, @Cast("const Nd4jLong*") LongBuffer indices, int rank); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] stride, @Cast("const Nd4jLong*") long[] indices, int rank); + @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices, @Cast("Nd4jLong") long baseOffset/*=0*/); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices, @Cast("Nd4jLong") long baseOffset/*=0*/); @@ -7981,34 +7984,26 @@ public static final int PREALLOC_SIZE = 33554432; /** * Convert a linear index to the corresponding coordinates - * for example if shape is {2, 4}, then index 5 corresponds to following coordinates - * -> [1, 1] in case of c order - * -> [1, 2] in case of f order + * for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, 1] */ - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongPointer coords, byte order/*='c'*/); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongBuffer coords, byte order/*='c'*/); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") long[] coords, byte order/*='c'*/); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") long[] coords); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongPointer coords, byte order/*='c'*/); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongBuffer coords, byte order/*='c'*/); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") long[] coords, byte order/*='c'*/); - @Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") long[] coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords); + + /** * Convert coordinates to the corresponding linear index (sequence number in other words) - * for example if shape is {2, 4}, then: - * in case of c order and coordinates [1, 1] index 5 is returned - * in case of f order and coordinates [1, 2] index 5 is returned + * for example if shape is {2, 4} and coordinates [1, 1] then index 5 is returned */ - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords, byte order/*='c'*/); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords, byte order/*='c'*/); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords, byte order/*='c'*/); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords); /** @@ -8020,36 +8015,16 @@ public static final int PREALLOC_SIZE = 33554432; */ /* calculates an array buffer offset for given "index" using following formula: offset = coord_0*stride_0 + coord_1*stride_1 + ... + coord_{rank-1}*stride_{rank-1} - * arrLen - array length */ - @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntPointer shapeInfo, @Cast("uint") int arrLen); - @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntBuffer shapeInfo, @Cast("uint") int arrLen); - @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") int[] shapeInfo, @Cast("uint") int arrLen); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong") long arrLen); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong") long arrLen); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong") long arrLen); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOrderOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong") long arrLen, byte order); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOrderOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong") long arrLen, byte order); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOrderOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong") long arrLen, byte order); - @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer lShapeInfo, @Cast("const uint*") IntPointer uShapeInfo, @Cast("Nd4jLong") long arrLen, @Cast("const bool") boolean useUnsigned); - @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer lShapeInfo, @Cast("const uint*") IntBuffer uShapeInfo, @Cast("Nd4jLong") long arrLen, @Cast("const bool") boolean useUnsigned); - @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] lShapeInfo, @Cast("const uint*") int[] uShapeInfo, @Cast("Nd4jLong") long arrLen, @Cast("const bool") boolean useUnsigned); - - /** - * Compute the real linear indices for the given shape and stride - */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer computeIndices(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer computeIndices(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] computeIndices(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] stride); - - /** - * Compute the real linear indices for the - * given shape buffer. Shape,stride and rank are derived - * from the buffer - */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer computeIndices( @Cast("Nd4jLong*") LongPointer shapeBuffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer computeIndices( @Cast("Nd4jLong*") LongBuffer shapeBuffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] computeIndices( @Cast("Nd4jLong*") long[] shapeBuffer); + @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntPointer shapeInfo); + @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntBuffer shapeInfo); + @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") int[] shapeInfo); + @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo); + @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo); + @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo); + @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer lShapeInfo, @Cast("const uint*") IntPointer uShapeInfo, @Cast("const bool") boolean useUnsigned); + @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer lShapeInfo, @Cast("const uint*") IntBuffer uShapeInfo, @Cast("const bool") boolean useUnsigned); + @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] lShapeInfo, @Cast("const uint*") int[] uShapeInfo, @Cast("const bool") boolean useUnsigned); @Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongPointer shapeInfo); @Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeInfo); @@ -8328,20 +8303,62 @@ public static final int PREALLOC_SIZE = 33554432; * for the given rank and shape. */ -/** - * Compute the real linear indices for the given shape and stride - */ - -/** -* Compute the real linear indices for the given shape and stride -*/ - +////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////////// +// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) { + +// const Nd4jLong ews = shapeInfo[shapeInfo[0] + shapeInfo[0] + 2]; + +// if(ews > 0 && order(shapeInfo) == 'c') +// if (ews == 1) +// return index; +// else +// return ews * index; + +// Nd4jLong offset = 0; +// Nd4jLong rank = shapeInfo[0]; +// for(int i = 1; i <= shapeInfo[0]; ++i) { +// arrLen /= shapeInfo[i]; +// if(arrLen > 0 && shapeInfo[i] > 1) { +// offset += (index / arrLen) * shapeInfo[i + rank]; +// index %= arrLen; +// } +// } +// return offset; +// } + +// INLINEDEF _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo, uint arrLen) { + +// const uint rank = shapeInfo[0]; +// const uint ews = shapeInfo[rank + rank + 2]; + +// if(ews > 0 && shapeInfo[rank + rank + 3] == 99) +// if (ews == 1) +// return index; +// else +// return ews * index; + +// uint offset = 0; + +// for(uint i = 1; i <= rank; ++i) { +// arrLen /= shapeInfo[i]; +// if(arrLen > 0 && shapeInfo[i] > 1) { +// offset += (index / arrLen) * shapeInfo[i + rank]; +// index %= arrLen; +// } +// } +// return offset; +// } + ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// + + ////////////////////////////////////////////////////////////////////// /** @@ -8710,6 +8727,10 @@ public static final int PREALLOC_SIZE = 33554432; * @return the double at the specified index */ +////////////////////////////////////////////////////////////////////////// + +////////////////////////////////////////////////////////////////////////// + @@ -9038,6 +9059,8 @@ public static final int PREALLOC_SIZE = 33554432; ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// + @@ -12262,6 +12285,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #include // #include // #include +// #include // #include // #include // #include @@ -13850,6 +13874,31 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + /** + * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: + * 1) if shapes are equal that's pairwise operation, result will have the same shape. + * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. + * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. + * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. + * + * This operation returns Z = Divide(X, Y) with exception, 0 if Y = 0 + */ +// #if NOT_EXCLUDED(OP_divide_no_nan) + @Namespace("nd4j::ops") public static class divide_no_nan extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public divide_no_nan(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public divide_no_nan(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public divide_no_nan position(long position) { + return (divide_no_nan)super.position(position); + } + + public divide_no_nan() { super((Pointer)null); allocate(); } + private native void allocate(); + } +// #endif /** * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: * 1) if shapes are equal that's pairwise operation, result will have the same shape. @@ -14385,6 +14434,54 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); private native void allocate(); } // #endif + + /** + * Broadcastable igamma implementation + * + * igamma(a, x) = gamma(а, x) / Gamma(a) - Gamma distribution function P(a,x) + * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } + * gamma(a, x) = int from 0 to x { t ^ {a - 1} e^{-t}dt } + * \tparam T + */ +// #if NOT_EXCLUDED(OP_igamma) + @Namespace("nd4j::ops") public static class igamma extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public igamma(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public igamma(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public igamma position(long position) { + return (igamma)super.position(position); + } + + public igamma() { super((Pointer)null); allocate(); } + private native void allocate(); + } +// #endif + /** + * Broadcastable igammac implementation + * igammac(a, x) = Gamma(a,x)/Gamma(а) - Gamma distribution function Q(a,x) + * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } + * Gamma(a, x) = int from x to infinity { t ^ {a - 1} e^{-t}dt } + * \tparam T + */ +// #if NOT_EXCLUDED(OP_igammac) + @Namespace("nd4j::ops") public static class igammac extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public igammac(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public igammac(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public igammac position(long position) { + return (igammac)super.position(position); + } + + public igammac() { super((Pointer)null); allocate(); } + private native void allocate(); + } +// #endif @@ -15842,6 +15939,26 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + ////////////////////////////////////////////////////////////////////////// +// #if NOT_EXCLUDED(OP_lstmLayer) + @Namespace("nd4j::ops") public static class lstmLayer extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayer(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmLayer(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayer position(long position) { + return (lstmLayer)super.position(position); + } + + public lstmLayer() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + + ////////////////////////////////////////////////////////////////////////// /** * Implementation of operations for Simple Recurrent Unit cell: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi @@ -17079,16 +17196,16 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array * * Input arrays: - * input: input array, considered as batch of matrices - * diagonal: array containing elements to be inserted into input array, - * following rank condition should be satisfied: diagonal_rank = input_rank - 1, - * the shapes of diagonal and input arrays must be equal except last dimension of input array, - * for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C], - * also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions - * that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2]) + * 0: input array, considered as batch of matrices + * 1: diagonal array containing elements to be inserted into input array, + * following rank condition should be satisfied: diagonal_rank = input_rank - 1, + * the shapes of diagonal and input arrays must be equal except last dimension of input array, + * for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C], + * also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions + * that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2]) * * Output array: - * has the same shape as input, corresponding diagonal elements are substituted + * 0: has the same shape as input, corresponding diagonal elements are substituted */ // #if NOT_EXCLUDED(OP_matrix_set_diag) @Namespace("nd4j::ops") public static class matrix_set_diag extends DeclarableOp { @@ -17109,8 +17226,16 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #endif /** - * Returns a batched matrix tensor with diagonal values given (as TF.matrix_diag). - */ + * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of output array, + * rest output elements are set to zeros + * + * Input array: + * diagonal: array containing elements to be inserted into output array, + * following rank condition is present: diagonal_rank = ouput_rank - 1 + * + * Output array: + * 0: is considered as batch of matrices, if for example diagonal array has shape [A,B,C] then output array has shape [A,B,C,C] + */ @Namespace("nd4j::ops") public static class matrix_diag extends DeclarableCustomOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ @@ -17130,13 +17255,13 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * This op calculates regularized incomplete beta integral Ix(a, b). * Implementation is based on two algorithms depending on input values of a and b: - * - when a and b are both > maxValue (3000.), then apply Gauss-Legendre quadrature method - * - when a and b are both <= maxValue (3000.), then apply modified Lentz’s algorithm for continued fractions + * - when a and b are both > maxValue (3000.), then Gauss-Legendre quadrature method is applied + * - when a and b are both <= maxValue (3000.), then modified Lentz’s algorithm for continued fractions is applied * * Input arrays: - * a: define power t^{a-1}, must be > 0, type float. - * b: define power (1-t)^{b-1}, must be > 0, type float. - * x: define upper limit of integration, must be within (0 <= x <= 1) range, type float. + * a: defines power t^{a-1}, must be > 0, type float. + * b: defines power (1-t)^{b-1}, must be > 0, type float. + * x: defines upper limit of integration, must be within (0 <= x <= 1) range, type float. * * Output array: * 0: values of regularized incomplete beta integral that corresponds to variable upper limit x, type float @@ -18250,6 +18375,50 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + /** + * This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean ) + * Input arrays: + * 0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels. + * + * T arguments: + * 0 - contrast factor + * + */ +// #if NOT_EXCLUDED(OP_adjust_contrast) + @Namespace("nd4j::ops") public static class adjust_contrast extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public adjust_contrast(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public adjust_contrast(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public adjust_contrast position(long position) { + return (adjust_contrast)super.position(position); + } + + public adjust_contrast() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } + @Namespace("nd4j::ops") public static class adjust_contrast_v2 extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public adjust_contrast_v2(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public adjust_contrast_v2(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public adjust_contrast_v2 position(long position) { + return (adjust_contrast_v2)super.position(position); + } + + public adjust_contrast_v2() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + + + /** * This operation rearranges data from depth into blocks of spatial data. This is the reverse transformation @@ -19634,6 +19803,37 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + /** + * draw_bounding_boxes op - modified input image with given colors exept given boxes. + * + * input params: + * 0 - images tensor (4D) with shape {batch, width, height, channels}, where channes is 1 (BW image), + * 3 (RGB) or 4 (RGBA) + * 1 - boxes tensor (3D) with shape {batch, number_of_boxes, 4} where last dimension encoded as + * (y_min, x_min, y_max, x_max), all values in between 0. and 1. + * 2 - colours tensor (2D) with shape {number_of_boxes, channels} -- bordering color set (palette) + * + * output: + * 0 - 4D tensor with same shape as images (input 0) + */ +// #if NOT_EXCLUDED(OP_draw_bounding_boxes) + @Namespace("nd4j::ops") public static class draw_bounding_boxes extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public draw_bounding_boxes(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public draw_bounding_boxes(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public draw_bounding_boxes position(long position) { + return (draw_bounding_boxes)super.position(position); + } + + public draw_bounding_boxes() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + /** * roll - op porting from numpy (https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.roll.html) * @@ -20496,10 +20696,14 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * 1 - scales - 1D-tensor with shape (num_boxes) by float type * 2 - output_size - 0D-tensor by int type (optional) * float args: - * 0 - threshold - threshold value for overlap checks (optional, by default 0.5) + * 0 - overlap_threshold - threshold value for overlap checks (optional, by default 0.5) + * 1 - score_threshold - the threshold for deciding when to remove boxes based on score (optional, by default -inf) * int args: * 0 - output_size - as arg 2 used for same target. Eigher this or arg 2 should be provided. * + * output: + * - vector with size M, where M <= output_size by int type + * * */ // #if NOT_EXCLUDED(OP_image_non_max_suppression) @Namespace("nd4j::ops") public static class non_max_suppression extends DeclarableCustomOp { @@ -20519,6 +20723,39 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + /* + * image.non_max_suppression_overlaps op. + * input: + * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type + * 1 - scales - 1D-tensor with shape (num_boxes) by float type + * 2 - output_size - 0D-tensor by int type (optional) + * float args: + * 0 - overlap_threshold - threshold value for overlap checks (optional, by default 0.5) + * 1 - score_threshold - the threshold for deciding when to remove boxes based on score (optional, by default -inf) + * int args: + * 0 - output_size - as arg 2 used for same target. Eigher this or arg 2 should be provided. + * + * output: + * 0 - 1D integer tensor with shape [M], epresenting the selected indices from the overlaps tensor, where M <= max_output_size + * */ +// #if NOT_EXCLUDED(OP_image_non_max_suppression_overlaps) + @Namespace("nd4j::ops") public static class non_max_suppression_overlaps extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public non_max_suppression_overlaps(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public non_max_suppression_overlaps(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public non_max_suppression_overlaps position(long position) { + return (non_max_suppression_overlaps)super.position(position); + } + + public non_max_suppression_overlaps() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + /* * cholesky op - decomposite positive square symetric matrix (or matricies when rank > 2). * input: @@ -20623,6 +20860,67 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif +/** + * fake_quant_with_min_max_vals_per_channel - tf.quantization.fake_quant_with_min_max_vars_per_channel + * + * input params: + * 0 - NDArray (input) - at least 2D. + * 1 - 1D Tensor - min values (min length equals to last dim of input) + * 2 - 1D Tensor - max value (length equals to min) + * + * int params (optional): + * 0 - num_bits (allowed interval [2, 16], default 8) + * 1 - narrow_range (default False) + * + * output: + * 0 - NDArray with the same shape as input + */ +// #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel) + @Namespace("nd4j::ops") public static class fake_quant_with_min_max_vars_per_channel extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public fake_quant_with_min_max_vars_per_channel(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public fake_quant_with_min_max_vars_per_channel(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public fake_quant_with_min_max_vars_per_channel position(long position) { + return (fake_quant_with_min_max_vars_per_channel)super.position(position); + } + + public fake_quant_with_min_max_vars_per_channel() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + + /** + * compare_and_bitpack - compare with greater and pack result with uint8 + * + * input params: + * 0 - NDArray (input) + * 1 - 0D Tensor - threshold + * + * + * output: + * 0 - NDArray with the same shape as input and type uint8 + */ +// #if NOT_EXCLUDED(OP_compare_and_bitpack) + @Namespace("nd4j::ops") public static class compare_and_bitpack extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public compare_and_bitpack(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public compare_and_bitpack(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public compare_and_bitpack position(long position) { + return (compare_and_bitpack)super.position(position); + } + + public compare_and_bitpack() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif @@ -21142,12 +21440,12 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif /** * Local response normalization implementation as TF. * input: 4D array - * + * * T args: * * 0: bias @@ -21155,8 +21453,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * 2: beta * * Int arg: depth - optional local radius - * - * output - 4D array + * + * output - 4D array */ // #if NOT_EXCLUDED(OP_lrn) @Namespace("nd4j::ops") public static class lrn extends DeclarableOp { @@ -21178,10 +21476,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * Local response normalization - backprop variant. - * input: + * input: * 0 - 4D array of data * 1 - epsilon - 4D array of approximation - * + * * T args: * * 0: bias @@ -21211,21 +21509,21 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #endif /** - * Batch normalization implementation. + * Batch normalization implementation. * Reference: https://arxiv.org/abs/1502.03167v3 - * + * * Expected arguments: * input: input array (any number of dimensions) * mean: * variance: * gamma: * beta: - * + * * Int args: * 0: apply scale * 1: apply offset - * - * + * + * * T args: * 0: epsilon */ @@ -21246,27 +21544,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif -// #if NOT_EXCLUDED(OP_batchnorm_new) - @Namespace("nd4j::ops") public static class batchnorm_new extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public batchnorm_new(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public batchnorm_new(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public batchnorm_new position(long position) { - return (batchnorm_new)super.position(position); - } - - public batchnorm_new() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } -// #endif /** * back prop in batch normalization - * + * * Expected arguments: * input: input array (any number of dimensions) * mean: @@ -21274,11 +21555,11 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * gamma: optional * beta: optional * dLdOut: next epsilon - * + * * Int args: * 0: apply scale - * 1: apply offset - * + * 1: apply offset + * * T args: * 0: epsilon * @@ -21286,8 +21567,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * dL/dInput * dL/dMean * dL/dVariance - * dL/dGamma - * dL/dBeta + * dL/dGamma, optional + * dL/dBeta, optional */ // #if NOT_EXCLUDED(OP_batchnorm) @Namespace("nd4j::ops") public static class batchnorm_bp extends DeclarableCustomOp { @@ -21314,7 +21595,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * x: parameters, any shape * y: gradients. same shape as x * lr: optional, learning rate - * + * * T args: * 0: optional, learning rate */ @@ -21333,25 +21614,25 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public apply_sgd() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } + } // #endif /** * This operation performs batch normalization of layer, it is based on following article http://arxiv.org/abs/1502.03167. * Expected arguments: * x: input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] (data format = NCHW), where - * bS - batch size - * iH - input height - * iW - input width + * bS - batch size + * iH - input height + * iW - input width * iD - input depth (or number of channels) * scale: 1D input array of scale factors, shape [iD] * offset: 1D input array of offsets (shifts), shape [iD] * mean: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false * variance: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false - * + * * T input arguments: * 0: epsilon, it is optional argument, default value is 0.001, this is small number to be added to the variance of x - * + * * integer input arguments: * 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW * 1: isTraining, may have two values: zero -> inference, unity -> training @@ -23102,6 +23383,28 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif + /** + * This operation change type of input and modified shape of output to conform with given data type + * + * all as above op + * */ +// #if NOT_EXCLUDED(OP_bitcast) + @Namespace("nd4j::ops") public static class bitcast extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bitcast(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bitcast(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bitcast position(long position) { + return (bitcast)super.position(position); + } + + public bitcast() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java index 15ed777d4..09e94acc7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java @@ -17,10 +17,13 @@ package org.nd4j.autodiff; import org.junit.Test; +import org.nd4j.autodiff.listeners.At; +import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.AbstractSession; import org.nd4j.autodiff.samediff.internal.InferenceSession; +import org.nd4j.autodiff.samediff.internal.memory.NoOpMemoryMgr; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; @@ -73,7 +76,7 @@ public class TestSessions extends BaseNd4jTest { m.put("y", y); Map outMap = is.output(Collections.singletonList("out"), m, null, - Collections.emptyList(), true, null); + Collections.emptyList(), null, At.defaultAt(Operation.TRAINING)); assertEquals(1, outMap.size()); assertEquals(outExp, outMap.get("out")); @@ -111,7 +114,7 @@ public class TestSessions extends BaseNd4jTest { System.out.println("----------------------------------"); Map outMap = is.output(Collections.singletonList("d"), m, null, - Collections.emptyList(), false, null); + Collections.emptyList(), null, At.defaultAt(Operation.TRAINING)); assertEquals(1, outMap.size()); assertEquals(dExp, outMap.get("d")); @@ -143,10 +146,10 @@ public class TestSessions extends BaseNd4jTest { System.out.println("----------------------------------"); InferenceSession is = new InferenceSession(sd); -// String outName = merge.getVarName(); - String outName = outVar.getVarName(); +// String outName = merge.name(); + String outName = outVar.name(); Map outMap = is.output(Collections.singletonList(outName), m, null, - Collections.emptyList(), false, null); + Collections.emptyList(), null, At.defaultAt(Operation.TRAINING)); assertEquals(1, outMap.size()); INDArray out = outMap.get(outName); @@ -178,11 +181,11 @@ public class TestSessions extends BaseNd4jTest { m.put("b", bArr); InferenceSession is = new InferenceSession(sd); - String n = merge.getVarName(); + String n = merge.name(); System.out.println("----------------------------------"); Map outMap = is.output(Collections.singletonList(n), m, null, Collections.emptyList(), - false, null); + null, At.defaultAt(Operation.TRAINING)); assertEquals(1, outMap.size()); assertEquals(expTrue, outMap.get(n)); @@ -191,12 +194,12 @@ public class TestSessions extends BaseNd4jTest { //Check false case: bArr.assign(0); is = new InferenceSession(sd); - outMap = is.output(Collections.singletonList(n), m, null, Collections.emptyList(), false, null); + outMap = is.output(Collections.singletonList(n), m, null, Collections.emptyList(), null, At.defaultAt(Operation.TRAINING)); assertEquals(1, outMap.size()); assertEquals(expFalse, outMap.get(n)); } - @Test(timeout = 60000L) + @Test(timeout = 20000L) public void testSwitchWhile() throws Exception{ /* @@ -212,18 +215,19 @@ public class TestSessions extends BaseNd4jTest { for( int numIter : new int[]{1,3}) { File f = new ClassPathResource("tf_graphs/examples/while1/iter_" + numIter + "/frozen_model.pb").getFile(); - SameDiff sd = TFGraphMapper.getInstance().importGraph(f); + SameDiff sd = TFGraphMapper.importGraph(f); System.out.println(sd.summary()); System.out.println("----------------------------------"); //This particular test/graph doesn't use placeholders InferenceSession is = new InferenceSession(sd); + is.setMmgr(new NoOpMemoryMgr()); //So arrays aren't deallocated during execution String n = "while/Exit"; String n2 = "while/Exit_1"; Map m = is.output(Arrays.asList(n, n2), Collections.emptyMap(), null, - Collections.emptyList(), false, null); + Collections.emptyList(), null, At.defaultAt(Operation.TRAINING)); assertEquals(2, m.size()); INDArray exp = Nd4j.scalar((float)numIter); @@ -231,7 +235,6 @@ public class TestSessions extends BaseNd4jTest { assertEquals(exp, m.get(n)); assertEquals(exp, m.get(n2)); - Map frameParents = is.getFrameParents(); Map outputs = is.getNodeOutputs(); //Some sanity checks on the internal state: //Check 1: "while/Less" should be executed numIter+1 times... i.e., numIter times through the loop, plus once to exit diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/execution/GraphExecutionerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/execution/GraphExecutionerTest.java index 177a4b795..8fee7f888 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/execution/GraphExecutionerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/execution/GraphExecutionerTest.java @@ -118,7 +118,7 @@ public class GraphExecutionerTest extends BaseNd4jTest { SDVariable result = sdVariable.add(scalarOne); SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE); - log.info("TOTAL: {}; Id: {}", total.getVarName(), total); + log.info("TOTAL: {}; Id: {}", total.name(), total); INDArray[] resB = executionerB.executeGraph(sameDiff, configVarSpace); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/internal/TestDependencyTracker.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/internal/TestDependencyTracker.java new file mode 100644 index 000000000..f9bd75d4a --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/internal/TestDependencyTracker.java @@ -0,0 +1,166 @@ +package org.nd4j.autodiff.internal; + +import org.junit.Test; +import org.nd4j.autodiff.samediff.internal.DependencyList; +import org.nd4j.autodiff.samediff.internal.DependencyTracker; +import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +import java.util.Collections; + +import static junit.framework.TestCase.assertNotNull; +import static org.junit.Assert.*; + +public class TestDependencyTracker { + + @Test + public void testSimple(){ + + DependencyTracker dt = new DependencyTracker<>(); + + dt.addDependency("y", "x"); + assertTrue(dt.hasDependency("y")); + assertFalse(dt.hasDependency("x")); + assertFalse(dt.hasDependency("z")); + + DependencyList dl = dt.getDependencies("y"); + assertEquals("y", dl.getDependencyFor()); + assertNotNull(dl.getDependencies()); + assertNull(dl.getOrDependencies()); + assertEquals(Collections.singletonList("x"), dl.getDependencies()); + + dt.removeDependency("y", "x"); + assertFalse(dt.hasDependency("y")); + assertFalse(dt.hasDependency("x")); + dl = dt.getDependencies("y"); + assertTrue(dl.getDependencies() == null || dl.getDependencies().isEmpty()); + assertTrue(dl.getOrDependencies() == null || dl.getOrDependencies().isEmpty()); + + + //Or dep + dt.addOrDependency("y", "x1", "x2"); + assertTrue(dt.hasDependency("y")); + dl = dt.getDependencies("y"); + assertTrue(dl.getDependencies() == null || dl.getDependencies().isEmpty()); + assertTrue(dl.getOrDependencies() != null && !dl.getOrDependencies().isEmpty()); + assertEquals(Collections.singletonList(new Pair<>("x1", "x2")), dl.getOrDependencies()); + + dt.removeDependency("y", "x1"); + assertFalse(dt.hasDependency("y")); + dl = dt.getDependencies("y"); + assertTrue(dl.getDependencies() == null || dl.getDependencies().isEmpty()); + assertTrue(dl.getOrDependencies() == null || dl.getOrDependencies().isEmpty()); + + dt.addOrDependency("y", "x1", "x2"); + dl = dt.getDependencies("y"); + assertTrue(dl.getDependencies() == null || dl.getDependencies().isEmpty()); + assertTrue(dl.getOrDependencies() != null && !dl.getOrDependencies().isEmpty()); + assertEquals(Collections.singletonList(new Pair<>("x1", "x2")), dl.getOrDependencies()); + dt.removeDependency("y", "x2"); + assertTrue(dt.isEmpty()); + } + + @Test + public void testSatisfiedBeforeAdd(){ + DependencyTracker dt = new DependencyTracker<>(); + + //Check different order of adding dependencies: i.e., mark X as satisfied, then add x -> y dependency + // and check that y is added to satisfied list... + dt.markSatisfied("x", true); + dt.addDependency("y", "x"); + assertTrue(dt.hasNewAllSatisfied()); + assertEquals("y", dt.getNewAllSatisfied()); + + //Same as above - x satisfied, add x->y, then add z->y + //y should go from satisfied to not satisfied + dt.clear(); + assertTrue(dt.isEmpty()); + dt.markSatisfied("x", true); + dt.addDependency("y", "x"); + assertTrue(dt.hasNewAllSatisfied()); + dt.addDependency("y", "z"); + assertFalse(dt.hasNewAllSatisfied()); + + + //x satisfied, then or(x,y) -> z added + dt.markSatisfied("x", true); + dt.addOrDependency("z", "x", "y"); + assertTrue(dt.hasNewAllSatisfied()); + assertEquals("z", dt.getNewAllSatisfied()); + + + //x satisfied, then or(x,y) -> z added, then or(a,b)->z added (should be unsatisfied) + dt.clear(); + assertTrue(dt.isEmpty()); + dt.markSatisfied("x", true); + dt.addOrDependency("z", "x", "y"); + assertTrue(dt.hasNewAllSatisfied()); + dt.addOrDependency("z", "a", "b"); + assertFalse(dt.hasNewAllSatisfied()); + } + + @Test + public void testMarkUnsatisfied(){ + + DependencyTracker dt = new DependencyTracker<>(); + dt.addDependency("y", "x"); + dt.markSatisfied("x", true); + assertTrue(dt.hasNewAllSatisfied()); + + dt.markSatisfied("x", false); + assertFalse(dt.hasNewAllSatisfied()); + dt.markSatisfied("x", true); + assertTrue(dt.hasNewAllSatisfied()); + assertEquals("y", dt.getNewAllSatisfied()); + assertFalse(dt.hasNewAllSatisfied()); + + + //Same for OR dependencies + dt.clear(); + assertTrue(dt.isEmpty()); + dt.addOrDependency("z", "x", "y"); + dt.markSatisfied("x", true); + assertTrue(dt.hasNewAllSatisfied()); + + dt.markSatisfied("x", false); + assertFalse(dt.hasNewAllSatisfied()); + dt.markSatisfied("x", true); + assertTrue(dt.hasNewAllSatisfied()); + assertEquals("z", dt.getNewAllSatisfied()); + assertFalse(dt.hasNewAllSatisfied()); + } + + + @Test + public void testIdentityDependencyTracker(){ + IdentityDependencyTracker dt = new IdentityDependencyTracker<>(); + assertTrue(dt.isEmpty()); + + INDArray y1 = Nd4j.scalar(0); + INDArray y2 = Nd4j.scalar(0); + String x1 = "x1"; + dt.addDependency(y1, x1); + + assertFalse(dt.hasNewAllSatisfied()); + assertTrue(dt.hasDependency(y1)); + assertFalse(dt.hasDependency(y2)); + assertFalse(dt.isSatisfied(x1)); + + DependencyList dl = dt.getDependencies(y1); + assertSame(y1, dl.getDependencyFor()); //Should be same object + assertEquals(Collections.singletonList(x1), dl.getDependencies()); + assertNull(dl.getOrDependencies()); + + + //Mark as satisfied, check if it's added to list + dt.markSatisfied(x1, true); + assertTrue(dt.isSatisfied(x1)); + assertTrue(dt.hasNewAllSatisfied()); + INDArray get = dt.getNewAllSatisfied(); + assertSame(y1, get); + assertFalse(dt.hasNewAllSatisfied()); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index 539901a41..b84d7ceea 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -79,7 +79,7 @@ public class LayerOpValidation extends BaseOpValidation { TestCase tc = new TestCase(sameDiff) .gradientCheck(true) - .expectedOutput(res.getVarName(), exp); + .expectedOutput(res.name(), exp); System.out.println(sameDiff.summary()); System.out.println("============================"); @@ -112,7 +112,7 @@ public class LayerOpValidation extends BaseOpValidation { TestCase tc = new TestCase(sameDiff) .gradientCheck(true) - .expectedOutput(res.getVarName(), exp); + .expectedOutput(res.name(), exp); String err = OpValidation.validate(tc); @@ -123,27 +123,24 @@ public class LayerOpValidation extends BaseOpValidation { public void testBiasAdd() { Nd4j.getRandom().setSeed(12345); - for (boolean rank1Bias : new boolean[]{false, true}) { + SameDiff sameDiff = SameDiff.create(); + INDArray input = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(new long[]{2, 4}); + INDArray b = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).divi(4); - SameDiff sameDiff = SameDiff.create(); - INDArray input = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(new long[]{2, 4}); - INDArray b = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(rank1Bias ? new long[]{4} : new long[]{1, 4}).divi(4); + SDVariable sdInput = sameDiff.var("input", input); + SDVariable sdBias = sameDiff.var("bias", b); - SDVariable sdInput = sameDiff.var("input", input); - SDVariable sdBias = sameDiff.var("bias", b); + SDVariable res = sameDiff.nn().biasAdd(sdInput, sdBias, true); + SDVariable loss = sameDiff.standardDeviation(res, true); - SDVariable res = sameDiff.nn().biasAdd(sdInput, sdBias); - SDVariable loss = sameDiff.standardDeviation(res, true); + INDArray exp = input.addRowVector(b); - INDArray exp = input.addRowVector(b); + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(res.name(), exp); - TestCase tc = new TestCase(sameDiff) - .gradientCheck(true) - .expectedOutput(res.getVarName(), exp); - - String err = OpValidation.validate(tc); - assertNull(err); - } + String err = OpValidation.validate(tc); + assertNull(err); } @Test @@ -594,7 +591,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable out = sd.cnn().sconv2d(vars, c); out = sd.nn().tanh("out", out); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27 val outShape = outArr.shape(); assertArrayEquals(new long[]{mb, depthWise * nIn, 27, 27}, outShape); @@ -640,7 +637,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable out = sd.cnn().sconv2d(vars, c); out = sd.nn().tanh("out", out); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (8-2+0)/1+1 = 7 val outShape = outArr.shape(); assertArrayEquals(new long[]{mb, nOut, 7, 7}, outShape); @@ -691,7 +688,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable out = sd.cnn().deconv2d(vars, deconv); out = sd.nn().tanh("out", out); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); //Expected output size: out = (in + k + 2*p)/ s - 1 = (8 + 2+0)/1 - 1 = 9 val outShape = outArr.shape(); assertArrayEquals(new long[]{mb, nOut, 9, 9}, outShape); @@ -739,7 +736,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable out = sd.cnn().conv2d("conv", vars, c); out = sd.nn().tanh("out", out); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27 val outShape = outArr.shape(); assertArrayEquals(new long[]{mb, nOut, 27, 27}, outShape); @@ -773,7 +770,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable outPool = sd.cnn().maxPooling2d(in, pooling2DConfig); SDVariable out = sd.nn().tanh("out", outPool); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); val outShape = outArr.shape(); // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; assertArrayEquals(new long[]{mb, nIn, 7, 7}, outShape); @@ -831,7 +828,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable outPool = sd.cnn().avgPooling2d(in, pooling2DConfig); SDVariable out = sd.nn().tanh("out", outPool); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); val outShape = outArr.shape(); // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; assertArrayEquals(new long[]{mb, nIn, 7, 7}, outShape); @@ -999,7 +996,7 @@ public class LayerOpValidation extends BaseOpValidation { } ); - TestCase tc = new TestCase(sd).gradientCheck(false).expectedOutput(res.getVarName(), expected); + TestCase tc = new TestCase(sd).gradientCheck(false).expectedOutput(res.name(), expected); String err = OpValidation.validate(tc); assertNull(err); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java index eb228bf1f..7f6daf78f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java @@ -423,11 +423,11 @@ public class MiscOpValidation extends BaseOpValidation { } SDVariable loss = sd.sum(scatter); //.standardDeviation(scatter, true); //.sum(scatter); //TODO stdev might be better here as gradients are non-symmetrical... - sd.execAndEndResult(); + TestCase tc = new TestCase(sd) .expected(scatter, exp) - .gradCheckSkipVariables(indices.getVarName()); + .gradCheckSkipVariables(indices.name()); String error = OpValidation.validate(tc); if(error != null){ @@ -493,7 +493,7 @@ public class MiscOpValidation extends BaseOpValidation { TestCase tc = new TestCase(sd) .testName(msg) - .gradCheckSkipVariables(indices.getVarName()); + .gradCheckSkipVariables(indices.name()); if (gatherExp != null) { tc.expected(gather, gatherExp); @@ -586,18 +586,19 @@ public class MiscOpValidation extends BaseOpValidation { SDVariable varMul = varMulPre.mul("d", sdVariable1); SDVariable sum = sameDiff.sum("ret", varMul, Integer.MAX_VALUE); - sameDiff.execBackwards(Collections.emptyMap()); + Map m = sameDiff.outputAll(null); + Map gm = sameDiff.calculateGradients(null, m.keySet()); - SDVariable finalResult = sameDiff.grad(sum.getVarName()); + SDVariable finalResult = sameDiff.grad(sum.name()); - SDVariable cGrad = sameDiff.grad(varMulPre.getVarName()); + SDVariable cGrad = sameDiff.grad(varMulPre.name()); - SDVariable mulGradResult = sameDiff.grad(varMul.getVarName()); - SDVariable aGrad = sameDiff.grad(sdVariable.getVarName()); - SDVariable wGrad = sameDiff.grad(sdVariable1.getVarName()); - SDVariable dGrad = sameDiff.grad(varMul.getVarName()); + SDVariable mulGradResult = sameDiff.grad(varMul.name()); + SDVariable aGrad = sameDiff.grad(sdVariable.name()); + SDVariable wGrad = sameDiff.grad(sdVariable1.name()); + SDVariable dGrad = sameDiff.grad(varMul.name()); - INDArray scalarGradTest = finalResult.getArr(); + INDArray scalarGradTest = gm.get(sum.name()); assertEquals(scalar, scalarGradTest); @@ -737,11 +738,10 @@ public class MiscOpValidation extends BaseOpValidation { SDVariable B2 = sd.var("B2", B); SDVariable[] batchMul = sd.batchMmul(new SDVariable[] {A1, A2}, new SDVariable[] {B1, B2}); - sd.exec(Collections.emptyMap(), sd.outputs()); - - INDArray resultingMatrix = batchMul[0].getArr(); - System.out.print(resultingMatrix); + Map m = sd.output(Collections.emptyMap(), sd.outputs()); + INDArray resultingMatrix = m.get(batchMul[0].name()); + //System.out.print(resultingMatrix); } @@ -769,14 +769,14 @@ public class MiscOpValidation extends BaseOpValidation { SDVariable mmul = sd.f().mmul(f, s, mt); sd.updateVariableNameAndReference(mmul, "mmul"); - INDArray out = sd.execAndEndResult(); + INDArray out = mmul.eval(); INDArray exp = first.transpose().mmul(second); assertEquals(exp, out); SDVariable loss = sd.standardDeviation(mmul, true); String err = OpValidation.validate(new TestCase(sd) - .expected(mmul.getVarName(), exp)); + .expected(mmul.name(), exp)); assertNull(err); } @@ -1265,17 +1265,17 @@ public class MiscOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); SDVariable var = sd.var("in", Nd4j.create(new long[]{1}).assign(5)); - SDVariable merged = sd.math().mergeAvg(var); + SDVariable merged = sd.math().mergeAvg("merged", var); SDVariable sum = sd.sum(merged); - sd.execAndEndResult(); - sd.execBackwards(Collections.emptyMap()); + Map m = sd.output(Collections.emptyMap(), "merged"); + Map gm = sd.calculateGradients(null, "in"); - INDArray out = merged.getArr(); + INDArray out = m.get("merged"); assertEquals(1, out.rank()); - INDArray inGrad = var.getGradient().getArr(); - assertEquals(1, inGrad.rank()); //Fails here, getting rank 2 + INDArray inGrad = gm.get("in"); + assertEquals(1, inGrad.rank()); } @Test @@ -1286,7 +1286,7 @@ public class MiscOpValidation extends BaseOpValidation { SDVariable var = sd.var("in", i); SDVariable diag = sd.math().diagPart(var); - INDArray out = sd.execAndEndResult(); + INDArray out = diag.eval(); assertEquals(1, out.rank()); } @@ -1643,10 +1643,10 @@ public class MiscOpValidation extends BaseOpValidation { SDVariable v = new StopGradient(sd, w).outputVariable(); SDVariable loss = v.std(true); - sd.execBackwards(null); + Map gm = sd.calculateGradients(null, v.name(), w.name()); - INDArray vArr = v.getGradient().getArr(); - INDArray wArr = w.getGradient().getArr(); + INDArray vArr = gm.get(v.name()); + INDArray wArr = gm.get(w.name()); System.out.println(vArr); System.out.println(wArr); @@ -1668,18 +1668,18 @@ public class MiscOpValidation extends BaseOpValidation { INDArray expLoss = in.std(true); String err = OpValidation.validate(new TestCase(sd) - .expectedOutput(checkNumerics.getVarName(), in) + .expectedOutput(checkNumerics.name(), in) .placeholderValue("in", in) .expectedOutput("loss", expLoss)); Preconditions.checkState(err == null, err); //Also check that it actually does what it's supposed to: - sd.execAll(Collections.singletonMap("in", in)); + sd.outputAll(Collections.singletonMap("in", in)); in.putScalar(0, Double.NaN); try { - sd.execAll(Collections.singletonMap("in", in)); + sd.outputAll(Collections.singletonMap("in", in)); fail("Expected exception"); } catch (Throwable t){ //OK @@ -1687,14 +1687,14 @@ public class MiscOpValidation extends BaseOpValidation { in.putScalar(0, Double.POSITIVE_INFINITY); try { - sd.execAll(Collections.singletonMap("in", in)); + sd.outputAll(Collections.singletonMap("in", in)); fail("Expected exception"); } catch (Throwable t){ //OK } in.putScalar(0, 0.0); - sd.execAll(Collections.singletonMap("in", in)); + sd.outputAll(Collections.singletonMap("in", in)); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java index 24ca6540b..802ed9be9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java @@ -117,8 +117,8 @@ public class ReductionOpValidation extends BaseOpValidation { SDVariable loss = nonZero.add(zero).castTo(DataType.DOUBLE).std(true); String error = OpValidation.validate(new TestCase(sd) - .expectedOutput(nonZero.getVarName(), Nd4j.scalar(DataType.LONG, i == 0 ? 2.0 : 4.0)) - .expectedOutput(zero.getVarName(), Nd4j.scalar(DataType.LONG, i == 0 ? 2.0 : 0.0)) + .expectedOutput(nonZero.name(), Nd4j.scalar(DataType.LONG, i == 0 ? 2.0 : 4.0)) + .expectedOutput(zero.name(), Nd4j.scalar(DataType.LONG, i == 0 ? 2.0 : 0.0)) .gradientCheck(false) ); if (error != null) @@ -148,7 +148,7 @@ public class ReductionOpValidation extends BaseOpValidation { SDVariable zeroFraction = sd.math().zeroFraction(input); String error = OpValidation.validate(new TestCase(sd) - .expectedOutput(zeroFraction.getVarName(), Nd4j.scalar(i == 0 ? 0.5f : 0.0f)) + .expectedOutput(zeroFraction.name(), Nd4j.scalar(i == 0 ? 0.5f : 0.0f)) .gradientCheck(i != 0) ); if (error != null) @@ -429,7 +429,7 @@ public class ReductionOpValidation extends BaseOpValidation { tc.gradientCheck(gradientCheckable); if(exp != null){ - tc.expectedOutput(loss.getVarName(), exp); + tc.expectedOutput(loss.name(), exp); } String error = OpValidation.validate(tc); @@ -996,7 +996,7 @@ public class ReductionOpValidation extends BaseOpValidation { String msg = name + " - dims=" + Arrays.toString(reduceDims); - INDArray out = sd.execAndEndResult(); + INDArray out = reduced.eval(); log.info(msg + " - expected shape: " + Arrays.toString(expShape) + ", out=" + Arrays.toString(out.shape()) + ", outExp=" + Arrays.toString(expOut.shape())); @@ -1069,10 +1069,10 @@ public class ReductionOpValidation extends BaseOpValidation { sd.associateArrayWithVariable(inputArr, input); sd.associateArrayWithVariable(labelArr, label); - INDArray result = sd.execAndEndResult(); + INDArray result = loss.eval(); assertEquals(1, result.length()); - sd.execBackwards(Collections.emptyMap()); + sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java index 8ecdc4eac..988b8da69 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java @@ -76,11 +76,11 @@ public class RnnOpValidation extends BaseOpValidation { LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y List toExec = new ArrayList<>(); for(SDVariable sdv : v.getAllOutputs()){ - toExec.add(sdv.getVarName()); + toExec.add(sdv.name()); } //Test forward pass: - Map m = sd.exec(null, toExec); + Map m = sd.output(null, toExec); //Weights and bias order: [i, f, z, o] @@ -179,11 +179,11 @@ public class RnnOpValidation extends BaseOpValidation { LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y List toExec = new ArrayList<>(); for(SDVariable sdv : v.getAllOutputs()){ - toExec.add(sdv.getVarName()); + toExec.add(sdv.name()); } //Test forward pass: - Map m = sd.exec(null, toExec); + Map m = sd.output(null, toExec); INDArray out0 = Nd4j.create(new float[]{0.27817473f, 0.53092605f}, new int[]{1,2}); //Input mod gate INDArray out1 = Nd4j.create(new float[]{-0.18100877f, 0.19417824f}, new int[]{1,2}); //CS (pre tanh) @@ -233,11 +233,11 @@ public class RnnOpValidation extends BaseOpValidation { List v = sd.rnn().gru("gru", x, hLast, weights).getAllOutputs(); List toExec = new ArrayList<>(); for(SDVariable sdv : v){ - toExec.add(sdv.getVarName()); + toExec.add(sdv.name()); } //Test forward pass: - Map m = sd.exec(null, toExec); + Map m = sd.output(null, toExec); //Weights and bias order: [r, u], [c] diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index 2cbdfd3fd..401dddc56 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -128,7 +128,7 @@ public class ShapeOpValidation extends BaseOpValidation { //Using stdev here: mean/sum would backprop the same gradient for each input... SDVariable stdev = sd.standardDeviation("out", reshape, true); - INDArray out = sd.execAndEndResult(); + INDArray out = stdev.eval(); INDArray expOut = in.getArr().std(true, Integer.MAX_VALUE); String msg = "toShape=" + Arrays.toString(toShape) + ", order=" + order; @@ -244,10 +244,10 @@ public class ShapeOpValidation extends BaseOpValidation { //Using stdev here: mean/sum would backprop the same gradient for each input... SDVariable stdev = sd.standardDeviation("out", expand, true); - INDArray out = sd.execAndEndResult(); + Map m = sd.outputAll(null); INDArray expOut = in.getArr().std(true); - assertArrayEquals(expExpandShape, expand.getArr().shape()); + assertArrayEquals(expExpandShape, m.get(expand.name()).shape()); INDArray expExpand = inArr.dup('c').reshape(expExpandShape); String msg = "expandDim=" + i + ", source=" + p.getSecond(); @@ -256,7 +256,7 @@ public class ShapeOpValidation extends BaseOpValidation { TestCase tc = new TestCase(sd); tc.testName(msg) .expectedOutput("out", expOut) - .expectedOutput(expand.getVarName(), expExpand); + .expectedOutput(expand.name(), expExpand); String error = OpValidation.validate(tc); if(error != null){ @@ -304,19 +304,19 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray exp = inArr.dup('c').reshape('c', expShapePostSqueeze); - sd.execAndEndResult(); + Map m = sd.outputAll(null); - INDArray squeezed = squeeze.getArr(); + INDArray squeezed = m.get(squeeze.name()); // assertArrayEquals(expShapePostSqueeze, squeezed.shape()); - INDArray out = sd.execAndEndResult(); + INDArray out = m.get(stdev.name()); INDArray expOut = in.getArr().std(true, Integer.MAX_VALUE); assertEquals(expOut, out); String msg = "squeezeDim=" + i + ", source=" + p.getSecond(); TestCase tc = new TestCase(sd) .testName(msg) - .expected(squeeze.getVarName(), exp) + .expected(squeeze.name(), exp) .expectedOutput("out", expOut); @@ -546,7 +546,7 @@ public class ShapeOpValidation extends BaseOpValidation { .testName(msg); String error = OpValidation.validate(tc, true); if(error != null){ - failed.add(msg); + failed.add(msg + " - " + error); } } } @@ -618,7 +618,7 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable stack = sd.stack(axis, in); - INDArray out = sd.execAndEndResult(); + INDArray out = stack.eval(); assertArrayEquals(expOutShape, out.shape()); if (ArrayUtil.prodLong(shape) == 1) { @@ -712,9 +712,9 @@ public class ShapeOpValidation extends BaseOpValidation { String msg = "Unstacked shape = " + Arrays.toString(shape) + ", stacked shape = " + Arrays.toString(stackedShape) + ", axis=" + axis + ", numInputs=" + numInputs; - sd.execAndEndResult(); + Map m = sd.outputAll(null); for (SDVariable v : unstacked) { - assertArrayEquals(msg, shape, v.getArr().shape()); + assertArrayEquals(msg, shape, m.get(v.name()).shape()); } TestCase tc = new TestCase(sd).testName(msg); @@ -884,7 +884,7 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray exp = arr.dup('c').reshape('c', 4,3); String err = OpValidation.validate(new TestCase(sameDiff) - .expectedOutput(result1.getVarName(), exp)); + .expectedOutput(result1.name(), exp)); assertNull(err); } @@ -920,7 +920,7 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable result = sameDiff.transpose(x); SDVariable loss = sameDiff.standardDeviation(result, true); - String err = OpValidation.validate(new TestCase(sameDiff).expectedOutput(result.getVarName(), arr.transpose())); + String err = OpValidation.validate(new TestCase(sameDiff).expectedOutput(result.name(), arr.transpose())); assertNull(err); } @@ -1022,17 +1022,16 @@ public class ShapeOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); INDArray ia = Nd4j.create(new double[]{1,2,3}); SDVariable in = sd.var(ia); - SDVariable constant = sd.constant(in, 3); - SDVariable loss = constant.std(true); + SDVariable loss = in.std(true); - assertNull(OpValidation.validate(new TestCase(sd).expected(constant, ia))); + assertNull(OpValidation.validate(new TestCase(sd).expected(in, ia))); //Case 1: shape is provided + scalar sd = SameDiff.create(); ia = Nd4j.scalar(3.0); in = sd.var(ia); - constant = sd.constant(in, 3,4,5); + SDVariable constant = sd.constant(Nd4j.create(DataType.FLOAT, 3,4,5)); INDArray exp = Nd4j.valueArrayOf(new long[]{3,4,5}, 3.0); loss = constant.std(true); @@ -1149,7 +1148,7 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable loss = sameDiff.standardDeviation(result, true); String err = OpValidation.validate(new TestCase(sameDiff) - .expected(result.getVarName(), expected) + .expected(result.name(), expected) .gradientCheck(false)); assertNull(err); } @@ -1172,7 +1171,7 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray outExp = Nd4j.scalar(d); String err = OpValidation.validate(new TestCase(sd) - .expected(md.getVarName(), outExp)); + .expected(md.name(), outExp)); assertNull(err); } @@ -1196,7 +1195,7 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray outExp = Nd4j.scalar(d); String err = OpValidation.validate(new TestCase(sd) - .expected(md.getVarName(), outExp)); + .expected(md.name(), outExp)); assertNull(err); } @@ -1227,7 +1226,7 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray outExp = Nd4j.scalar(d); String err = OpValidation.validate(new TestCase(sd) - .expected(md.getVarName(), outExp)); + .expected(md.name(), outExp)); assertNull(err); } @@ -1247,7 +1246,7 @@ public class ShapeOpValidation extends BaseOpValidation { //System.out.println(d); String err = OpValidation.validate(new TestCase(sd) - .expected(md.getVarName(), Nd4j.scalar(d))); + .expected(md.name(), Nd4j.scalar(d))); assertNull(err); } @@ -1332,7 +1331,7 @@ public class ShapeOpValidation extends BaseOpValidation { .testName(op) .expected(sm, exp) .gradientCheck(true) - .gradCheckSkipVariables(segments.getVarName()); + .gradCheckSkipVariables(segments.name()); String err = OpValidation.validate(tc); if(err != null) @@ -1383,7 +1382,7 @@ public class ShapeOpValidation extends BaseOpValidation { String err = OpValidation.validate(new TestCase(sameDiff) .expected(result1, expected) - .gradCheckSkipVariables(lengths.getVarName())); + .gradCheckSkipVariables(lengths.name())); assertNull(err); // Test with dynamic maxlen @@ -1591,8 +1590,8 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 6, 6).reshape(2, 3)); SDVariable x = sameDiff.var("x", arr); SDVariable result = sameDiff.permute(x, 1, 0); - sameDiff.execAll(null); - assertArrayEquals(new long[]{3, 2}, result.getShape()); + Map m = sameDiff.outputAll(null); + assertArrayEquals(new long[]{3, 2}, m.get(result.name()).shape()); } @@ -1629,10 +1628,10 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable slice_full = sd.slice(in, new int[]{0, 0}, new int[]{3, 4}); SDVariable subPart = sd.slice(in, new int[]{1, 2}, new int[]{2, 2}); - sd.exec(Collections.emptyMap(), sd.outputs()); + Map m = sd.outputAll(Collections.emptyMap()); - assertEquals(inArr, slice_full.getArr()); - assertEquals(inArr.get(interval(1, 3), interval(2, 4)), subPart.getArr()); + assertEquals(inArr, m.get(slice_full.name())); + assertEquals(inArr.get(interval(1, 3), interval(2, 4)), m.get(subPart.name())); } @@ -1645,10 +1644,10 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable slice_full = sd.slice(in, new int[]{0, 0, 0}, new int[]{3, 4, 5}); SDVariable subPart = sd.slice(in, new int[]{1, 2, 3}, new int[]{2, 2, 1}); - sd.exec(Collections.emptyMap(), sd.outputs()); + Map m = sd.outputAll(null); - assertEquals(inArr, slice_full.getArr()); - assertEquals(inArr.get(interval(1, 3), interval(2, 4), interval(3, 4)), subPart.getArr()); + assertEquals(inArr, m.get(slice_full.name())); + assertEquals(inArr.get(interval(1, 3), interval(2, 4), interval(3, 4)), m.get(subPart.name())); } @Test @@ -1661,7 +1660,7 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable subPart = sd.stridedSlice(in, new int[]{1, 2}, new int[]{3, 4}, new int[]{1, 1}); // SDVariable subPart2 = sd.stridedSlice(in, new int[]{0, 0}, new int[]{4, 5}, new int[]{2, 2}); - sd.execAll(null); + sd.outputAll(null); assertEquals(inArr, slice_full.getArr()); assertEquals(inArr.get(interval(1, 3), interval(2, 4)), subPart.getArr()); @@ -1678,7 +1677,7 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable slice1 = sd.stridedSlice(in, new int[]{-999, 0}, new int[]{2, 4}, new int[]{1, 1}, 1 << 1, 0, 0, 0, 0); SDVariable slice2 = sd.stridedSlice(in, new int[]{1, 0}, new int[]{-999, 4}, new int[]{1, 1}, 0, 1, 0, 0, 0); - sd.execAll(null); + sd.outputAll(null); assertEquals(inArr.get(NDArrayIndex.interval(0, 2), NDArrayIndex.all()), slice1.getArr()); assertEquals(inArr.get(NDArrayIndex.interval(1, 3), NDArrayIndex.all()), slice2.getArr()); @@ -1695,7 +1694,7 @@ public class ShapeOpValidation extends BaseOpValidation { //[1:3,...,1:4] -> [1:3,:,1:4] SDVariable slice2 = sd.stridedSlice(in, new int[]{1, 1}, new int[]{3, 4}, new int[]{1, 1}, 0, 0, 1 << 1, 0, 0); - sd.execAll(Collections.emptyMap()); + sd.outputAll(Collections.emptyMap()); assertEquals(inArr.get(interval(1, 3), all(), all()), slice.getArr()); assertEquals(inArr.get(interval(1, 3), all(), all()), slice2.getArr()); @@ -1708,7 +1707,7 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable in = sd.var("in", inArr); SDVariable slice = sd.stridedSlice(in, new int[]{-999, 0, 0, 0}, new int[]{-999, 3, 4, 5}, new int[]{-999, 1, 1, 1}, 0, 0, 0, 1, 0); - INDArray out = sd.execAndEndResult(); + INDArray out = slice.eval(); assertArrayEquals(new long[]{1, 3, 4, 5}, out.shape()); assertEquals(inArr, out.get(point(0), all(), all(), all())); @@ -1720,7 +1719,7 @@ public class ShapeOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); SDVariable slice = sd.stridedSlice(in, new int[]{1, 1, -999, 1}, new int[]{3, 3, -999, 4}, new int[]{1, 1, -999, 1}, 0, 0, 0, 1 << 2, 0); - INDArray out = sd.execAndEndResult(); + INDArray out = slice.eval(); assertArrayEquals(new long[]{2, 2, 1, 3}, slice.getArr().shape()); } @@ -1735,7 +1734,7 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable slice2 = sd.stridedSlice(in, new int[]{2, 0, 0}, new int[]{-999, 4, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1); SDVariable slice3 = sd.stridedSlice(in, new int[]{1, 2, 1}, new int[]{-999, -999, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1 | 1 << 1); - sd.execAll(null); + sd.outputAll(null); assertEquals(inArr.get(point(0), all(), all()), slice.getArr()); assertEquals(inArr.get(point(2), all(), all()), slice2.getArr()); @@ -1880,8 +1879,8 @@ public class ShapeOpValidation extends BaseOpValidation { // log.info(sd.summary()); - sd.exec(Collections.emptyMap(), Lists.newArrayList(s)); - sd.execBackwards(Collections.emptyMap()); + sd.output(Collections.emptyMap(), Lists.newArrayList(s)); + sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); } } @@ -2405,8 +2404,8 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable gathered = sd.gather(input, indices, 1); SDVariable loss = gathered.std(true); - sd.exec(null, gathered.getVarName()); - sd.setLossVariables(gathered.getVarName()); + sd.output((Map)null, gathered.name()); + sd.setLossVariables(gathered.name()); String err = OpValidation.validate(new TestCase(sd) .gradCheckEpsilon(1e-3) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index fdd2b3160..bb3bab213 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -115,37 +115,37 @@ public class TransformOpValidation extends BaseOpValidation { switch (i){ case 0: out = in.mul(2); - tc.expectedOutput(out.getVarName(), inArr.mul(2)); + tc.expectedOutput(out.name(), inArr.mul(2)); msg = "mul - " + inOrder; break; case 1: out = in.div(2); - tc.expectedOutput(out.getVarName(), inArr.div(2)); + tc.expectedOutput(out.name(), inArr.div(2)); msg = "div - " + inOrder; break; case 2: out = in.add(2); - tc.expectedOutput(out.getVarName(), inArr.add(2)); + tc.expectedOutput(out.name(), inArr.add(2)); msg = "add - " + inOrder; break; case 3: out = in.sub(2); - tc.expectedOutput(out.getVarName(), inArr.sub(2)); + tc.expectedOutput(out.name(), inArr.sub(2)); msg = "sub - " + inOrder; break; case 4: out = in.rdiv(2); - tc.expectedOutput(out.getVarName(), inArr.rdiv(2)); + tc.expectedOutput(out.name(), inArr.rdiv(2)); msg = "rdiv - " + inOrder; break; case 5: out = in.rsub(2); - tc.expectedOutput(out.getVarName(), inArr.rsub(2)); + tc.expectedOutput(out.name(), inArr.rsub(2)); msg = "rsub - " + inOrder; break; case 6: out = sd.math().pow(in,2); - tc.expectedOutput(out.getVarName(), Transforms.pow(inArr, 2)); + tc.expectedOutput(out.name(), Transforms.pow(inArr, 2)); msg = "pow - " + inOrder; break; case 7: @@ -584,219 +584,219 @@ public class TransformOpValidation extends BaseOpValidation { switch (i) { case 0: t = in.add(5.0); - tc.expectedOutput(t.getVarName(), ia.add(5.0)); + tc.expectedOutput(t.name(), ia.add(5.0)); break; case 1: t = in.sub(5.0); - tc.expectedOutput(t.getVarName(), ia.sub(5.0)); + tc.expectedOutput(t.name(), ia.sub(5.0)); break; case 2: t = in.mul(2.5); - tc.expectedOutput(t.getVarName(), ia.mul(2.5)); + tc.expectedOutput(t.name(), ia.mul(2.5)); break; case 3: t = in.div(4.0); - tc.expectedOutput(t.getVarName(), ia.div(4.0)); + tc.expectedOutput(t.name(), ia.div(4.0)); break; case 4: t = in.rsub(5.0); - tc.expectedOutput(t.getVarName(), ia.rsub(5.0)); + tc.expectedOutput(t.name(), ia.rsub(5.0)); break; case 5: t = in.rdiv(1.0); - tc.expectedOutput(t.getVarName(), ia.rdiv(1.0)); + tc.expectedOutput(t.name(), ia.rdiv(1.0)); break; case 6: t = sd.math().pow(in, 2.5); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut); - tc.expectedOutput(t.getVarName(), Transforms.pow(ia, 2.5, true)); + tc.expectedOutput(t.name(), Transforms.pow(ia, 2.5, true)); break; case 7: t = sd.nn().sigmoid(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(2).subi(1.0); - tc.expectedOutput(t.getVarName(), Transforms.sigmoid(ia, true)); + tc.expectedOutput(t.name(), Transforms.sigmoid(ia, true)); break; case 8: t = sd.math().tanh(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(2).subi(1.0); - tc.expectedOutput(t.getVarName(), Transforms.tanh(ia, true)); + tc.expectedOutput(t.name(), Transforms.tanh(ia, true)); break; case 9: ia.assign(Nd4j.rand(DataType.DOUBLE, ia.shape())); t = sd.math().tan(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut); - tc.expectedOutput(t.getVarName(), Transforms.tan(ia)); + tc.expectedOutput(t.name(), Transforms.tan(ia)); break; case 10: t = sd.math().cos(in); - tc.expectedOutput(t.getVarName(), Transforms.cos(ia, true)); + tc.expectedOutput(t.name(), Transforms.cos(ia, true)); break; case 11: t = sd.math().sin(in); - tc.expectedOutput(t.getVarName(), Transforms.sin(ia, true)); + tc.expectedOutput(t.name(), Transforms.sin(ia, true)); break; case 12: t = sd.nn().softplus(in); - tc.expectedOutput(t.getVarName(), Transforms.softPlus(ia, true)); + tc.expectedOutput(t.name(), Transforms.softPlus(ia, true)); break; case 13: t = sd.math().log(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut); - tc.expectedOutput(t.getVarName(), Transforms.log(ia, true)); + tc.expectedOutput(t.name(), Transforms.log(ia, true)); break; case 14: t = sd.math().neg(in); INDArray exp14 = ia.neg(); - tc.expectedOutput(t.getVarName(), exp14); + tc.expectedOutput(t.name(), exp14); break; case 15: t = sd.math().acos(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(1.8).subi(0.9); - tc.expectedOutput(t.getVarName(), Transforms.acos(ia, true)); + tc.expectedOutput(t.name(), Transforms.acos(ia, true)); break; case 16: t = sd.math().acosh(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).addi(1.01); //Only defined for x >= 1 - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new ACosh(ia.dup()))); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new ACosh(ia.dup()))); break; case 17: t = sd.math().asin(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(1.8).subi(0.9); - tc.expectedOutput(t.getVarName(), Transforms.asin(ia, true)); + tc.expectedOutput(t.name(), Transforms.asin(ia, true)); break; case 18: t = sd.math().atan(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(4).subi(2); - tc.expectedOutput(t.getVarName(), Transforms.atan(ia, true)); + tc.expectedOutput(t.name(), Transforms.atan(ia, true)); break; case 19: t = sd.math().atanh(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(1.8).subi(0.9); - tc.expectedOutput(t.getVarName(), Transforms.atanh(ia, true)); + tc.expectedOutput(t.name(), Transforms.atanh(ia, true)); break; case 20: t = sd.math().cosh(in); - tc.expectedOutput(t.getVarName(), Transforms.cosh(ia, true)); + tc.expectedOutput(t.name(), Transforms.cosh(ia, true)); break; case 21: t = sd.math().cube(in); - tc.expectedOutput(t.getVarName(), Transforms.pow(ia, 3.0, true)); + tc.expectedOutput(t.name(), Transforms.pow(ia, 3.0, true)); break; case 22: t = sd.nn().elu(in); - tc.expectedOutput(t.getVarName(), Transforms.elu(ia, true)); + tc.expectedOutput(t.name(), Transforms.elu(ia, true)); break; case 23: //TODO SHOULDN'T THIS HAVE A DIMENSION ARG??? t = sd.nn().softmax(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new SoftMax(ia.dup()))[0]); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new SoftMax(ia.dup()))[0]); break; case 24: t = sd.math().sqrt(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut); - tc.expectedOutput(t.getVarName(), Transforms.sqrt(ia, true)); + tc.expectedOutput(t.name(), Transforms.sqrt(ia, true)); break; case 25: t = sd.math().square(in); - tc.expectedOutput(t.getVarName(), Transforms.pow(ia, 2.0, true)); + tc.expectedOutput(t.name(), Transforms.pow(ia, 2.0, true)); break; case 26: t = sd.transpose(in); - tc.expectedOutput(t.getVarName(), ia.transpose().dup()); + tc.expectedOutput(t.name(), ia.transpose().dup()); break; case 27: t = sd.math().abs(in); - tc.expectedOutput(t.getVarName(), Transforms.abs(ia, true)); + tc.expectedOutput(t.name(), Transforms.abs(ia, true)); break; case 28: t = sd.math().sinh(in); - tc.expectedOutput(t.getVarName(), Transforms.sinh(ia, true)); + tc.expectedOutput(t.name(), Transforms.sinh(ia, true)); break; case 29: t = sd.math().asinh(in); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new ASinh(ia.dup()))); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new ASinh(ia.dup()))); break; case 30: t = sd.math().exp(in); - tc.expectedOutput(t.getVarName(), Transforms.exp(ia, true)); + tc.expectedOutput(t.name(), Transforms.exp(ia, true)); break; case 31: t = sd.math().floor(in); - tc.expectedOutput(t.getVarName(), Transforms.floor(ia, true)); + tc.expectedOutput(t.name(), Transforms.floor(ia, true)); break; case 32: t = sd.nn().relu(in, 0.0); ia = Nd4j.rand(minibatch, nOut); - tc.expectedOutput(t.getVarName(), Transforms.relu(ia, true)); + tc.expectedOutput(t.name(), Transforms.relu(ia, true)); break; case 33: t = sd.nn().hardTanh(in); ia = Nd4j.rand(minibatch, nOut).muli(2).subi(1.0); - tc.expectedOutput(t.getVarName(), Transforms.hardTanh(ia, true)); + tc.expectedOutput(t.name(), Transforms.hardTanh(ia, true)); break; case 34: t = sd.nn().logSigmoid(in); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new LogSigmoid(ia.dup()))); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new LogSigmoid(ia.dup()))); break; case 35: t = sd.nn().swish(in); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Swish(ia.dup()))); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new Swish(ia.dup()))); break; case 36: t = sd.math().sign(in); - tc.expectedOutput(t.getVarName(), Transforms.sign(ia, true)); + tc.expectedOutput(t.name(), Transforms.sign(ia, true)); break; case 37: t = sd.nn().softsign(in); - tc.expectedOutput(t.getVarName(), Transforms.softsign(ia, true)); + tc.expectedOutput(t.name(), Transforms.softsign(ia, true)); break; case 38: t = sd.nn().leakyRelu(in, 0.0); ia = Nd4j.rand(minibatch, nOut); - tc.expectedOutput(t.getVarName(), Transforms.leakyRelu(ia, true)); + tc.expectedOutput(t.name(), Transforms.leakyRelu(ia, true)); break; case 39: if(OpValidationSuite.IGNORE_FAILING) continue; t = sd.nn().logSoftmax(in); ia = Nd4j.rand(minibatch, nOut).muli(10).subi(5); - tc.expectedOutput(t.getVarName(), Transforms.log(Transforms.softmax(ia, true))); + tc.expectedOutput(t.name(), Transforms.log(Transforms.softmax(ia, true))); stdevLoss = true; break; case 40: t = sd.nn().selu(in); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new SELU(ia.dup()))); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new SELU(ia.dup()))); break; case 41: t = sd.gt(in, 1.0).castTo(DataType.DOUBLE); - tc.expectedOutput(t.getVarName(), ia.gt(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); + tc.expectedOutput(t.name(), ia.gt(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); break; case 42: t = sd.gte(in, 1.0).castTo(DataType.DOUBLE); - tc.expectedOutput(t.getVarName(), ia.gte(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); + tc.expectedOutput(t.name(), ia.gte(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); break; case 43: t = sd.lt(in, 1.0).castTo(DataType.DOUBLE); - tc.expectedOutput(t.getVarName(), ia.lt(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); + tc.expectedOutput(t.name(), ia.lt(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); break; case 44: t = sd.lte(in, 1.0).castTo(DataType.DOUBLE); - tc.expectedOutput(t.getVarName(), ia.lte(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); + tc.expectedOutput(t.name(), ia.lte(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); break; case 45: t = sd.eq(in, 2.0).castTo(DataType.DOUBLE); ia = Nd4j.linspace(1, minibatch * nOut, minibatch * nOut, DataType.DOUBLE).reshape('c', minibatch, nOut); - tc.expectedOutput(t.getVarName(), ia.eq(2.0).castTo(DataType.DOUBLE)).gradientCheck(false); + tc.expectedOutput(t.name(), ia.eq(2.0).castTo(DataType.DOUBLE)).gradientCheck(false); break; case 46: t = sd.neq(in, 2.0).castTo(DataType.DOUBLE); ia = Nd4j.linspace(1, minibatch * nOut, minibatch * nOut, DataType.DOUBLE).reshape('c', minibatch, nOut); - tc.expectedOutput(t.getVarName(), ia.neq(2.0).castTo(DataType.DOUBLE)).gradientCheck(false); + tc.expectedOutput(t.name(), ia.neq(2.0).castTo(DataType.DOUBLE)).gradientCheck(false); break; case 47: t = sd.math().ceil(in); - tc.expectedOutput(t.getVarName(), Transforms.ceil(ia, true)); + tc.expectedOutput(t.name(), Transforms.ceil(ia, true)); break; case 48: ia = Nd4j.randn(DataType.DOUBLE, ia.shape()).muli(2); @@ -804,7 +804,7 @@ public class TransformOpValidation extends BaseOpValidation { INDArray expOut48 = ia.dup(); BooleanIndexing.replaceWhere(expOut48, -3, Conditions.lessThan(-3)); BooleanIndexing.replaceWhere(expOut48, 2, Conditions.greaterThan(2)); - tc.expectedOutput(t.getVarName(), expOut48); + tc.expectedOutput(t.name(), expOut48); break; case 49: //Clip by norm, dimension 0, some below threshold, some above @@ -825,7 +825,7 @@ public class TransformOpValidation extends BaseOpValidation { expOut49.putColumn(j, origCol.mul(clip / origCol.norm2Number().doubleValue())); } } - tc.expectedOutput(t.getVarName(), expOut49); + tc.expectedOutput(t.name(), expOut49); //System.out.println(expOut.norm2(0)); break; //TODO clip by norm along other dimensions @@ -837,7 +837,7 @@ public class TransformOpValidation extends BaseOpValidation { .addIntegerArguments(dim) .addInputs(ia).addOutputs(expOut50).build(); Nd4j.getExecutioner().exec(reverse); - tc.expectedOutput(t.getVarName(), expOut50); + tc.expectedOutput(t.name(), expOut50); break; case 51: dim = 0; @@ -850,7 +850,7 @@ public class TransformOpValidation extends BaseOpValidation { .addIntegerArguments((exclusive) ? 1 : 0, (reverseBool) ? 1 : 0, dim) .addInputs(ia).addOutputs(expOut51).build(); Nd4j.getExecutioner().exec(cumsum); - tc.expectedOutput(t.getVarName(), expOut51); + tc.expectedOutput(t.name(), expOut51); break; case 52: if(OpValidationSuite.IGNORE_FAILING){ @@ -869,7 +869,7 @@ public class TransformOpValidation extends BaseOpValidation { expOut52.putScalar(s0, s1, prod); } } - tc.expectedOutput(t.getVarName(), expOut52); + tc.expectedOutput(t.name(), expOut52); break; case 53: if(OpValidationSuite.IGNORE_FAILING){ @@ -881,90 +881,90 @@ public class TransformOpValidation extends BaseOpValidation { INDArray expOut53 = Nd4j.create(DataType.DOUBLE, 2, 2); DynamicCustomOp op = DynamicCustomOp.builder("diag").addInputs(ia).addOutputs(expOut53).build(); Nd4j.getExecutioner().exec(op); - tc.expectedOutput(t.getVarName(), expOut53); + tc.expectedOutput(t.name(), expOut53); break; case 54: t = sd.math().erf(in); INDArray expOut54 = Nd4j.createUninitialized(DataType.DOUBLE, ia.shape(), ia.ordering()); Nd4j.getExecutioner().exec(new Erf(ia, expOut54)); - tc.expectedOutput(t.getVarName(), expOut54); + tc.expectedOutput(t.name(), expOut54); break; case 55: t = sd.math().erfc(in); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Erfc(ia, Nd4j.createUninitialized(ia.shape(), ia.ordering())))); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new Erfc(ia, Nd4j.createUninitialized(ia.shape(), ia.ordering())))); break; case 56: t = sd.math().expm1(in); - tc.expectedOutput(t.getVarName(),Transforms.expm1(ia, true)); + tc.expectedOutput(t.name(),Transforms.expm1(ia, true)); break; case 57: t = sd.math().log1p(in); ia = Nd4j.rand(minibatch, nOut); - tc.expectedOutput(t.getVarName(), Transforms.log1p(ia, true)); + tc.expectedOutput(t.name(), Transforms.log1p(ia, true)); break; case 58: t = sd.math().round(in); - tc.expectedOutput(t.getVarName(), Transforms.round(ia, true)); + tc.expectedOutput(t.name(), Transforms.round(ia, true)); break; case 59: ia = Nd4j.create(new float[]{4, 2}).castTo(DataType.DOUBLE); // in = sd.var("in", new int[]{1, 2}); t = sd.math().rsqrt(in); - tc.expectedOutput(t.getVarName(),Nd4j.getExecutioner().exec(new RSqrt(ia, Nd4j.create(ia.shape(), ia.ordering())))); + tc.expectedOutput(t.name(),Nd4j.getExecutioner().exec(new RSqrt(ia, Nd4j.create(ia.shape(), ia.ordering())))); break; case 60: t = sd.nn().relu6(in, 0); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut); - tc.expectedOutput(t.getVarName(),Transforms.relu6(ia, true)); + tc.expectedOutput(t.name(),Transforms.relu6(ia, true)); break; case 61: ia = Nd4j.create(new float[] {2, 2}).castTo(DataType.DOUBLE); sd.associateArrayWithVariable(ia, in); double value = 42; t = sd.fill(in.castTo(DataType.INT), DataType.DOUBLE, value); - tc.expectedOutput(t.getVarName(), Nd4j.valueArrayOf(new int[]{2,2}, 42)).gradientCheck(false); + tc.expectedOutput(t.name(), Nd4j.valueArrayOf(new int[]{2,2}, 42)).gradientCheck(false); opName = "fill"; break; case 62: t = sd.nn().hardSigmoid(in); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new HardSigmoid(ia, ia.dup()))); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new HardSigmoid(ia, ia.dup()))); break; case 63: t = sd.scalarMax(in, 0.5); - tc.expectedOutput(t.getVarName(), Transforms.max(ia, 0.5, true)); + tc.expectedOutput(t.name(), Transforms.max(ia, 0.5, true)); break; case 64: t = sd.scalarMin(in, 0.5); - tc.expectedOutput(t.getVarName(), Transforms.min(ia, 0.5, true)); + tc.expectedOutput(t.name(), Transforms.min(ia, 0.5, true)); break; case 65: t = sd.assign(in, 0.5); - tc.expectedOutput(t.getVarName(), ia.dup().assign(0.5)); + tc.expectedOutput(t.name(), ia.dup().assign(0.5)); break; case 66: t = sd.scalarFloorMod(in, 0.5); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new ScalarFMod(ia.dup(), 0.5))); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new ScalarFMod(ia.dup(), 0.5))); break; case 67: t = sd.math().reciprocal(in); - tc.expectedOutput(t.getVarName(), ia.rdiv(1.0)); + tc.expectedOutput(t.name(), ia.rdiv(1.0)); break; case 68: t = sd.shape(in).castTo(DataType.DOUBLE); - tc.expectedOutput(t.getVarName(), Nd4j.create(ArrayUtil.toDouble(ia.shape()))).gradientCheck(false); + tc.expectedOutput(t.name(), Nd4j.create(ArrayUtil.toDouble(ia.shape()))).gradientCheck(false); break; case 69: t = sd.rank(in).castTo(DataType.DOUBLE); - tc.expectedOutput(t.getVarName(), Nd4j.scalar((double)ia.rank())).gradientCheck(false); + tc.expectedOutput(t.name(), Nd4j.scalar((double)ia.rank())).gradientCheck(false); break; case 70: t = sd.onesLike(in); - tc.expectedOutput(t.getVarName(), Nd4j.ones(ia.shape())); + tc.expectedOutput(t.name(), Nd4j.ones(ia.shape())); break; case 71: ia = Nd4j.randn(DataType.DOUBLE, nOut, nOut); t = sd.math().diagPart(in); - tc.expectedOutput(t.getVarName(), Nd4j.create(new double[]{ia.getDouble(0,0), ia.getDouble(1,1), ia.getDouble(2,2), ia.getDouble(3,3)}).castTo(DataType.DOUBLE)); + tc.expectedOutput(t.name(), Nd4j.create(new double[]{ia.getDouble(0,0), ia.getDouble(1,1), ia.getDouble(2,2), ia.getDouble(3,3)}).castTo(DataType.DOUBLE)); break; case 72: t = sd.identity(in); @@ -1087,109 +1087,109 @@ public class TransformOpValidation extends BaseOpValidation { switch (i) { case 0: t = in1.add(in2); - tc.expectedOutput(t.getVarName(), ia.add(ib)); + tc.expectedOutput(t.name(), ia.add(ib)); break; case 1: t = in1.sub(in2); - tc.expectedOutput(t.getVarName(),ia.sub(ib)); + tc.expectedOutput(t.name(),ia.sub(ib)); break; case 2: t = in1.mul(in2); - tc.expectedOutput(t.getVarName(), ia.mul(ib)); + tc.expectedOutput(t.name(), ia.mul(ib)); break; case 3: t = in1.div(in2); - tc.expectedOutput(t.getVarName(), ia.div(ib)); + tc.expectedOutput(t.name(), ia.div(ib)); break; case 4: t = in1.rsub(in2); - tc.expectedOutput(t.getVarName(), ia.rsub(ib)); + tc.expectedOutput(t.name(), ia.rsub(ib)); break; case 5: ia.assign(Nd4j.rand(ia.shape())).addi(0.5); ib.assign(Nd4j.rand(ib.shape())).addi(0.5); t = in1.rdiv(in2); - tc.expectedOutput(t.getVarName(), ia.rdiv(ib)); + tc.expectedOutput(t.name(), ia.rdiv(ib)); break; case 6: t = sd.eq(in1, in2); opName = "eq"; - tc.expectedOutput(t.getVarName(), ia.eq(ib)).gradientCheck(false); + tc.expectedOutput(t.name(), ia.eq(ib)).gradientCheck(false); break; case 7: t = sd.neq(in1, in2); opName = "neq"; - tc.expectedOutput(t.getVarName(), ia.neq(ib)).gradientCheck(false);; + tc.expectedOutput(t.name(), ia.neq(ib)).gradientCheck(false);; break; case 8: t = sd.gt(in1, in2); opName = "gt"; - tc.expectedOutput(t.getVarName(), ia.gt(ib)).gradientCheck(false); + tc.expectedOutput(t.name(), ia.gt(ib)).gradientCheck(false); break; case 9: t = sd.lt(in1, in2); opName = "lt"; - tc.expectedOutput(t.getVarName(), ia.lt(ib)).gradientCheck(false); + tc.expectedOutput(t.name(), ia.lt(ib)).gradientCheck(false); break; case 10: t = sd.gte(in1, in2); opName = "gte"; INDArray expOut10 = Nd4j.create(DataType.BOOL, ia.shape()); Nd4j.getExecutioner().exec(new GreaterThanOrEqual(new INDArray[]{ia, ib}, new INDArray[]{expOut10})); - tc.expectedOutput(t.getVarName(), expOut10).gradientCheck(false); + tc.expectedOutput(t.name(), expOut10).gradientCheck(false); break; case 11: t = sd.lte(in1, in2); opName = "lte"; INDArray expOut11 = Nd4j.create(DataType.BOOL, ia.shape()); Nd4j.getExecutioner().exec(new LessThanOrEqual(new INDArray[]{ia, ib}, new INDArray[]{expOut11})); - tc.expectedOutput(t.getVarName(), expOut11).gradientCheck(false); + tc.expectedOutput(t.name(), expOut11).gradientCheck(false); break; case 12: ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5)); t = sd.math().or(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL)); opName = "or"; - tc.expectedOutput(t.getVarName(), Transforms.or(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false); + tc.expectedOutput(t.name(), Transforms.or(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false); break; case 13: ib = Nd4j.randn(DataType.DOUBLE, nOut, nOut); t = sd.mmul(in1, in2); - tc.expectedOutput(t.getVarName(), ia.mmul(ib)); + tc.expectedOutput(t.name(), ia.mmul(ib)); break; case 14: t = sd.max(in1, in2); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Max(ia, ib, ia.dup()))[0]); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new Max(ia, ib, ia.dup()))[0]); break; case 15: t = sd.min(in1, in2); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Min(ia, ib, ia.dup()))[0]); + tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new Min(ia, ib, ia.dup()))[0]); break; case 16: ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5)); t = sd.math().and(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL)); opName = "and"; - tc.expectedOutput(t.getVarName(), Transforms.and(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false); + tc.expectedOutput(t.name(), Transforms.and(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false); break; case 17: ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5)); t = sd.math().xor(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL)); opName = "xor"; - tc.expectedOutput(t.getVarName(), Transforms.xor(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false); + tc.expectedOutput(t.name(), Transforms.xor(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false); break; case 18: t = sd.assign(in1, in2); - tc.expectedOutput(t.getVarName(), ib); + tc.expectedOutput(t.name(), ib); break; case 19: t = sd.math().atan2(in1, in2); - tc.expectedOutput(t.getVarName(), Transforms.atan2(ib, ia)); //Note: y,x order for samediff; x,y order for transforms + tc.expectedOutput(t.name(), Transforms.atan2(ib, ia)); //Note: y,x order for samediff; x,y order for transforms break; case 20: t = sd.math().mergeAdd(in1, in2, in2); - tc.expectedOutput(t.getVarName(), ia.add(ib).add(ib)); + tc.expectedOutput(t.name(), ia.add(ib).add(ib)); break; case 21: t = in1.squaredDifference(in2); @@ -1199,7 +1199,7 @@ public class TransformOpValidation extends BaseOpValidation { .addOutputs(expOut21) .build(); Nd4j.getExecutioner().exec(squareDiff); - tc.expectedOutput(t.getVarName(), expOut21); + tc.expectedOutput(t.name(), expOut21); break; case 22: //set diag @@ -1210,7 +1210,7 @@ public class TransformOpValidation extends BaseOpValidation { expOut22.putScalar(j,j, ib.getDouble(j)); } t = sd.math().setDiag(in1, in2); - tc.expectedOutput(t.getVarName(), expOut22); + tc.expectedOutput(t.name(), expOut22); break; default: throw new RuntimeException(); @@ -1341,7 +1341,6 @@ public class TransformOpValidation extends BaseOpValidation { } } - //TODO UPDATE TO OP VALIDATION OR DELETE @Test public void testLogGrad() { SameDiff sameDiff = SameDiff.create(); @@ -1349,7 +1348,7 @@ public class TransformOpValidation extends BaseOpValidation { SDVariable log = sameDiff.math().log(input); SDVariable sum = sameDiff.sum(log, Integer.MAX_VALUE); INDArray result = null; - sameDiff.execBackwards(Collections.emptyMap()); + sameDiff.calculateGradients(Collections.emptyMap(), sameDiff.getVariables().keySet()); } @@ -1362,8 +1361,8 @@ public class TransformOpValidation extends BaseOpValidation { SDVariable input = sameDiff.var("x", inputs.get("x")); SDVariable sigmoid = sameDiff.nn().sigmoid(input); SDVariable sum = sameDiff.sum(sigmoid, Integer.MAX_VALUE); - sameDiff.execBackwards(Collections.emptyMap()); - INDArray arr = input.gradient().getArr(); + Map m = sameDiff.calculateGradients(Collections.emptyMap(), sameDiff.getVariables().keySet()); + INDArray arr = m.get(input.name()); assertTrue(Nd4j.create(new double[][]{ {0.1966, 0.1050}, {0.0452, 0.0177} @@ -1384,12 +1383,12 @@ public class TransformOpValidation extends BaseOpValidation { public void testRank0EdgeCase(){ SameDiff sd = SameDiff.create(); SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4}))); - double d0 = sd.execAndEndResult().getDouble(0); + double d0 = v1.eval().getDouble(0); assertEquals(8, d0, 0); SDVariable v2 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4}))).div(2.0); - sd.exec(Collections.emptyMap(), sd.outputs()); - double d1 = v2.getArr().getDouble(0); + Map m = sd.outputAll(Collections.emptyMap()); + double d1 = m.get(v2.name()).getDouble(0); assertEquals(4, d1, 0); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java index bec6e0349..9d89aec91 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java @@ -78,82 +78,6 @@ public class FailingSameDiffTests extends BaseNd4jTest { assertArrayEquals(new long[]{3,3}, list.get(0).getShape()); } - @Test(timeout = 10000L) - public void testWhileLoop() { - OpValidationSuite.ignoreFailing(); - SameDiff sameDiff = SameDiff.create(); - sameDiff.whileStatement(new DefaultSameDiffConditional(), new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable eqResult = sameDiff.neq(variableInputs[0], variableInputs[1]); - return new SDVariable[]{eqResult}; - } - }, new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable ret = variableInputs[1].add(1.0); - return new SDVariable[]{variableInputs[0], ret}; - } - }, new SDVariable[]{ - sameDiff.one("one", new long[]{1, 1}), - sameDiff.var("two", new long[]{1, 1}), - - }); - - sameDiff.exec(Collections.emptyMap()); - } - - @Test(timeout = 10000L) - public void testWhileBackwards() { - OpValidationSuite.ignoreFailing(); - SameDiff sameDiff = SameDiff.create(); - sameDiff.whileStatement(new DefaultSameDiffConditional(), new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable eqResult = sameDiff.neq(variableInputs[0], variableInputs[1]); - return new SDVariable[]{eqResult}; - } - }, new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable ret = variableInputs[1].add(1.0); - return new SDVariable[]{variableInputs[0], ret}; - } - }, new SDVariable[]{ - sameDiff.one("one", new long[]{1, 1}), - sameDiff.var("two", new long[]{1, 1}), - - }); - - sameDiff.execBackwards(Collections.emptyMap()); - SameDiff exec = sameDiff.getFunction("grad"); - } - - @Test(timeout = 10000L) - public void testWhileLoop2() { - OpValidationSuite.ignoreFailing(); - SameDiff sameDiff = SameDiff.create(); - sameDiff.whileStatement(new DefaultSameDiffConditional(), new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable eqResult = sameDiff.neq(variableInputs[0], variableInputs[1]); - return new SDVariable[]{eqResult}; - } - }, new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable ret = variableInputs[1].add(1.0); - return new SDVariable[]{variableInputs[0], ret}; - } - }, new SDVariable[]{ - sameDiff.one("one", new long[]{1, 1}), - sameDiff.var("two", new long[]{1, 1}), - - }); - - sameDiff.exec(Collections.emptyMap(), sameDiff.outputs()); - } - @Test public void testExecutionDifferentShapesTransform(){ OpValidationSuite.ignoreFailing(); @@ -163,12 +87,12 @@ public class FailingSameDiffTests extends BaseNd4jTest { SDVariable tanh = sd.math().tanh(in); INDArray exp = Transforms.tanh(in.getArr(), true); - INDArray out = sd.execAndEndResult(); + INDArray out = tanh.eval(); assertEquals(exp, out); //Now, replace with minibatch 5: in.setArray(Nd4j.linspace(1,20,20, DataType.DOUBLE).reshape(5,4)); - INDArray out2 = sd.execAndEndResult(); + INDArray out2 = tanh.eval(); assertArrayEquals(new long[]{5,4}, out2.shape()); exp = Transforms.tanh(in.getArr(), true); @@ -200,12 +124,12 @@ public class FailingSameDiffTests extends BaseNd4jTest { SDVariable mmul = sd.mmul(in,w).add(b); INDArray exp = in.getArr().mmul(w.getArr()).addiRowVector(b.getArr()); - INDArray out = sd.execAndEndResult(); + INDArray out = mmul.eval(); assertEquals(exp, out); //Now, replace with minibatch 5: in.setArray(Nd4j.linspace(1,20,20, DataType.DOUBLE).reshape(5,4)); - INDArray out2 = sd.execAndEndResult(); + INDArray out2 = mmul.eval(); assertArrayEquals(new long[]{5,5}, out2.shape()); exp = in.getArr().mmul(w.getArr()).addiRowVector(b.getArr()); @@ -213,11 +137,10 @@ public class FailingSameDiffTests extends BaseNd4jTest { //Generate gradient function, and exec SDVariable loss = mmul.std(true); - sd.execBackwards(Collections.emptyMap()); + sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); in.setArray(Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape(3,4)); - sd.execAndEndResult(); - out2 = mmul.getArr(); + out2 = mmul.eval(); assertArrayEquals(new long[]{3,5}, out2.shape()); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java index c291a5556..e704c9337 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java @@ -173,7 +173,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { } if(execFirst){ - sd.exec(Collections.singletonMap("in", arr), Collections.singletonList(x.getVarName())); + sd.output(Collections.singletonMap("in", arr), Collections.singletonList(x.name())); } File f = testDir.newFile(); @@ -186,7 +186,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { List varsRestored = restored.variables(); assertEquals(varsOrig.size(), varsRestored.size()); for (int j = 0; j < varsOrig.size(); j++) { - assertEquals(varsOrig.get(j).getVarName(), varsRestored.get(j).getVarName()); + assertEquals(varsOrig.get(j).name(), varsRestored.get(j).name()); } DifferentialFunction[] fOrig = sd.ops(); @@ -200,10 +200,10 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { assertEquals(sd.getLossVariables(), restored.getLossVariables()); - Map m = sd.exec(Collections.singletonMap("in", arr), Collections.singletonList(x.getVarName())); - INDArray outOrig = m.get(x.getVarName()); - Map m2 = restored.exec(Collections.singletonMap("in", arr), Collections.singletonList(x.getVarName())); - INDArray outRestored = m2.get(x.getVarName()); + Map m = sd.output(Collections.singletonMap("in", arr), Collections.singletonList(x.name())); + INDArray outOrig = m.get(x.name()); + Map m2 = restored.output(Collections.singletonMap("in", arr), Collections.singletonList(x.name())); + INDArray outRestored = m2.get(x.name()); assertEquals(String.valueOf(i), outOrig, outRestored); @@ -317,10 +317,10 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { } for(SDVariable v : sd.variables()){ - if(v.isPlaceHolder()) + if(v.isPlaceHolder() || v.getVariableType() == VariableType.ARRAY) continue; - SDVariable v2 = sd2.getVariable(v.getVarName()); + SDVariable v2 = sd2.getVariable(v.name()); INDArray a1 = v.getArr(); INDArray a2 = v2.getArr(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java index 083fdbfe8..b3b82e16c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java @@ -57,17 +57,17 @@ public class GraphTransformUtilTests extends BaseNd4jTest { SDVariable sub = add.sub(add2); - assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add.getVarName()))); - assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add2.getVarName()))); - assertFalse(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(sub.getVarName()))); + assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add.name()))); + assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add2.name()))); + assertFalse(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(sub.name()))); - assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add.getVarName()))); - assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add2.getVarName()))); - assertFalse(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(sub.getVarName()))); + assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add.name()))); + assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add2.name()))); + assertFalse(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(sub.name()))); - assertTrue(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(add.getVarName()))); - assertTrue(OpPredicate.opNameMatches("ad.*").matches(sd, sd.getVariableOutputOp(add2.getVarName()))); - assertFalse(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(sub.getVarName()))); + assertTrue(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(add.name()))); + assertTrue(OpPredicate.opNameMatches("ad.*").matches(sd, sd.getVariableOutputOp(add2.name()))); + assertFalse(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(sub.name()))); SubGraphPredicate p = SubGraphPredicate.withRoot(OpPredicate.classEquals(AddOp.class)); @@ -76,11 +76,11 @@ public class GraphTransformUtilTests extends BaseNd4jTest { assertEquals(2, l.size()); SubGraph sg1 = l.get(0); - assertTrue(sg1.getRootNode() == sd.getVariableOutputOp(add.getVarName())); + assertTrue(sg1.getRootNode() == sd.getVariableOutputOp(add.name())); assertEquals(0, sg1.getChildNodes().size()); SubGraph sg2 = l.get(1); - assertTrue(sg2.getRootNode() == sd.getVariableOutputOp(add2.getVarName())); + assertTrue(sg2.getRootNode() == sd.getVariableOutputOp(add2.name())); assertEquals(0, sg2.getChildNodes().size()); } @@ -118,7 +118,7 @@ public class GraphTransformUtilTests extends BaseNd4jTest { }); INDArray exp2 = p1.div(p2).mul(p1.sub(p2)); - INDArray out2 = sd2.getVariable(mul.getVarName()).eval(); + INDArray out2 = sd2.getVariable(mul.name()).eval(); assertEquals(exp2, out2); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java index 659fc5438..9532cb0f5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java @@ -33,18 +33,18 @@ public class NameScopeTests extends BaseNd4jTest { SDVariable v = sd.var("x"); try(NameScope ns = sd.withNameScope("nameScope")){ SDVariable v2 = sd.var("x2"); - assertEquals("nameScope/x2", v2.getVarName()); + assertEquals("nameScope/x2", v2.name()); assertTrue(sd.getVariables().containsKey("nameScope/x2")); assertEquals("nameScope", sd.currentNameScope()); SDVariable v3 = sd.var("x"); - assertEquals("nameScope/x", v3.getVarName()); + assertEquals("nameScope/x", v3.name()); assertTrue(sd.getVariables().containsKey("nameScope/x")); try(NameScope ns2 = sd.withNameScope("scope2")){ assertEquals("nameScope/scope2", sd.currentNameScope()); SDVariable v4 = sd.var("x"); - assertEquals("nameScope/scope2/x", v4.getVarName()); + assertEquals("nameScope/scope2/x", v4.name()); assertTrue(sd.getVariables().containsKey("nameScope/scope2/x")); } @@ -76,19 +76,19 @@ public class NameScopeTests extends BaseNd4jTest { } SDVariable a = sd.var("a", DataType.FLOAT, 1); - assertEquals("x", x.getVarName()); - assertEquals("s1/y", y.getVarName()); - assertEquals("s1/s2/z", z.getVarName()); - assertEquals("a", a.getVarName()); + assertEquals("x", x.name()); + assertEquals("s1/y", y.name()); + assertEquals("s1/s2/z", z.name()); + assertEquals("a", a.name()); - assertTrue(add.getVarName(), add.getVarName().startsWith("s1/")); - assertEquals("s1/addxy", addWithName.getVarName()); + assertTrue(add.name(), add.name().startsWith("s1/")); + assertEquals("s1/addxy", addWithName.name()); - assertTrue(merge.getVarName(), merge.getVarName().startsWith("s1/s2/")); - assertEquals("s1/s2/mmax", mergeWithName.getVarName()); + assertTrue(merge.name(), merge.name().startsWith("s1/s2/")); + assertEquals("s1/s2/mmax", mergeWithName.name()); Set allowedVarNames = new HashSet<>(Arrays.asList("x", "s1/y", "s1/s2/z", "a", - add.getVarName(), addWithName.getVarName(), merge.getVarName(), mergeWithName.getVarName())); + add.name(), addWithName.name(), merge.name(), mergeWithName.name())); Set allowedOpNames = new HashSet<>(); //Check op names: @@ -102,8 +102,8 @@ public class NameScopeTests extends BaseNd4jTest { //Check fields - Variable, SDOp, etc for(Variable v : sd.getVariables().values()){ - assertTrue(v.getVariable().getVarName(), allowedVarNames.contains(v.getVariable().getVarName())); - assertEquals(v.getName(), v.getVariable().getVarName()); + assertTrue(v.getVariable().name(), allowedVarNames.contains(v.getVariable().name())); + assertEquals(v.getName(), v.getVariable().name()); if(v.getInputsForOp() != null){ for(String s : v.getInputsForOp()){ assertTrue(s, allowedOpNames.contains(s)); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffSpecifiedLossVarsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffSpecifiedLossVarsTests.java index 5d6b7bfc3..303739ea1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffSpecifiedLossVarsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffSpecifiedLossVarsTests.java @@ -108,14 +108,14 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest { sd.fit(ds); } - for(String s : new String[]{"w", "b", badd.getVarName(), add.getVarName(), "l1", "l2"}){ + for(String s : new String[]{"w", "b", badd.name(), add.name(), "l1", "l2"}){ SDVariable gradVar = sd.getVariable(s).gradient(); assertNotNull(s, gradVar); } //Unused: assertFalse(shape.hasGradient()); try{ assertNull(shape.gradient()); } catch (IllegalStateException e){ assertTrue(e.getMessage().contains("only floating point variables")); } - for(String s : new String[]{unused1.getVarName(), unused2.getVarName(), unused3.getVarName()}){ + for(String s : new String[]{unused1.name(), unused2.name(), unused3.name()}){ assertNull(sd.getVariable(s).gradient()); } } @@ -151,20 +151,20 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest { sd.setLossVariables("loss1"); sd.createGradFunction(); for(SDVariable v : new SDVariable[]{ph1, w1, b1, mmul1, badd1, loss1}){ - assertNotNull(v.getVarName(), v.gradient()); + assertNotNull(v.name(), v.gradient()); } for(SDVariable v : new SDVariable[]{ph2, w2, b2, mmul2, badd2, loss2}){ - assertNull(v.getVarName(), v.gradient()); + assertNull(v.name(), v.gradient()); } //Now, set to other loss function sd.setLossVariables("loss2"); sd.createGradFunction(); for(SDVariable v : new SDVariable[]{ph1, w1, b1, mmul1, badd1, loss1}){ - assertNull(v.getVarName(), v.gradient()); + assertNull(v.name(), v.gradient()); } for(SDVariable v : new SDVariable[]{ph2, w2, b2, mmul2, badd2, loss2}){ - assertNotNull(v.getVarName(), v.gradient()); + assertNotNull(v.name(), v.gradient()); } //Train the first side of the graph. The other side should remain unmodified! diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index 7d17b3604..69c69388b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -16,12 +16,7 @@ package org.nd4j.autodiff.samediff; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.junit.Assert.*; import static org.junit.Assume.assumeNotNull; import static org.nd4j.linalg.indexing.NDArrayIndex.all; @@ -29,6 +24,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; import java.io.IOException; import java.lang.reflect.Field; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -43,6 +39,7 @@ import org.junit.Ignore; import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.nd4j.OpValidationSuite; +import org.nd4j.autodiff.samediff.api.OutAndGrad; import org.nd4j.autodiff.samediff.impl.DefaultSameDiffConditional; import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.autodiff.validation.TestCase; @@ -150,7 +147,7 @@ public class SameDiffTests extends BaseNd4jTest { sd.associateArrayWithVariable(Nd4j.create(new double[]{1, 2, 3, 4, 5, 6}, new long[]{2, 3}), input); - sd.execAndEndResult(); + sd.outputAll(null); nodeA.isPlaceHolder(); } @@ -183,10 +180,10 @@ public class SameDiffTests extends BaseNd4jTest { sd.associateArrayWithVariable(inputArr, input); sd.associateArrayWithVariable(labelArr, label); - INDArray result = sd.execAndEndResult(); + INDArray result = avgMSE.eval(); assertEquals(1, result.length()); - sd.execBackwards(Collections.emptyMap()); + sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); } @Test @@ -207,10 +204,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable x = sameDiff.var("x", arr); SDVariable result = sameDiff.sum(x, 1); //[1,4].sum(1) == [1] - sameDiff.exec(Collections.emptyMap(), sameDiff.outputs()); - INDArray exp = Nd4j.scalar(arr.sumNumber().floatValue()).reshape(1); - INDArray resultArr = result.getArr(); + INDArray resultArr = result.eval(); assertEquals(exp, resultArr); } @@ -225,7 +220,7 @@ public class SameDiffTests extends BaseNd4jTest { Map m = new HashMap<>(); m.put("x", x); m.put("y", y); - INDArray out = sameDiff.exec(m, Collections.singletonList(output.getVarName())).get(output.getVarName()); + INDArray out = sameDiff.output(m, Collections.singletonList(output.name())).get(output.name()); INDArray outputAssertion = x.add(y); assertEquals(outputAssertion, out); } @@ -242,10 +237,9 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable sdTargets = sameDiff.var("targets", targets); SDVariable res = sameDiff.loss().weightedCrossEntropyWithLogits(sdTargets, sdInputs, sdWeights); - sameDiff.exec(Collections.emptyMap(), sameDiff.outputs()); - INDArray resultArray = res.getArr(); - assertArrayEquals(new long[]{1, 5}, res.getShape()); + INDArray resultArray = res.eval(); + assertArrayEquals(new long[]{1, 5}, resultArray.shape()); } @Test @@ -269,7 +263,7 @@ public class SameDiffTests extends BaseNd4jTest { sd.associateArrayWithVariable(inputArr, input); sd.associateArrayWithVariable(labelArr, label); - INDArray result = sd.execAndEndResult(); + INDArray result = score.eval(); assertNotNull(result); //*** Fails Here - Null output *** assertEquals(1, result.length()); } @@ -283,8 +277,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable result = sameDiff.math().cosineSimilarity(x, y, 1); SDVariable addResult = result.add(result); SDVariable finalReshape = sameDiff.reshape(addResult, 1, 2); - sameDiff.exec(Collections.emptyMap(), sameDiff.outputs()); - assertArrayEquals(new long[]{1, 2}, finalReshape.getShape()); + Map out = sameDiff.output(Collections.emptyMap(), finalReshape.name()); + assertArrayEquals(new long[]{1, 2}, out.get(finalReshape.name()).shape()); } @Test @@ -295,8 +289,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable y = sameDiff.var("y", arr); SDVariable result = sameDiff.mmul(x, y); SDVariable otherResult = result.add(result); - sameDiff.exec(Collections.emptyMap(), sameDiff.outputs()); - assertArrayEquals(new long[]{2, 2}, result.getShape()); + Map m = sameDiff.outputAll(null); + assertArrayEquals(new long[]{2, 2}, m.get(result.name()).shape()); } @@ -307,7 +301,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable x = sameDiff.var("x", arr); SDVariable sigmoid = sameDiff.nn().sigmoid("s", x); INDArray assertion = Transforms.sigmoid(arr); - INDArray eval = sameDiff.exec(Collections.singletonMap("x", arr), Collections.singletonList("s")).get("s"); + INDArray eval = sameDiff.output(Collections.singletonMap("x", arr), Collections.singletonList("s")).get("s"); assertEquals(assertion, eval); } @@ -317,8 +311,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable var = sameDiff.var("one", Nd4j.scalar(1.0)); SDVariable variable2 = sameDiff.var("two", Nd4j.scalar(1.0)); val sum = var.add(variable2); - sum.eval(); - assertArrayEquals(new long[0], sum.getShape()); + INDArray out = sum.eval(); + assertArrayEquals(new long[0], out.shape()); } @@ -329,8 +323,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable firstVar = first.var("one", new long[]{2, 2}); SDVariable secondVar = second.var(firstVar); - assertTrue(firstVar.getArr() == secondVar.getArr()); - assertEquals(firstVar.getVarName(), secondVar.getVarName()); + assertEquals(firstVar.getArr(), secondVar.getArr()); + assertEquals(firstVar.name(), secondVar.name()); } @@ -343,8 +337,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable secondVar = second.var(firstVar); assumeNotNull(firstVar.getArr()); - assertTrue(firstVar.getArr() == secondVar.getArr()); - assertEquals(firstVar.getVarName(), secondVar.getVarName()); + assertEquals(firstVar.getArr(), secondVar.getArr()); + assertEquals(firstVar.name(), secondVar.name()); } @@ -369,7 +363,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable x = sameDiff.var("x", arr); SDVariable s = x.mul("s", x); INDArray assertion = arr.mul(arr); - INDArray eval = sameDiff.exec(Collections.singletonMap("x", arr), Collections.singletonList("s")).get("s"); + INDArray eval = sameDiff.output(Collections.singletonMap("x", arr), Collections.singletonList("s")).get("s"); assertEquals(assertion, eval); } @@ -386,7 +380,7 @@ public class SameDiffTests extends BaseNd4jTest { Map vars = new HashMap<>(); vars.put("x", arr); vars.put("y", yArr); - INDArray eval = sameDiff.exec(vars, Collections.singletonList(sigmoid.getVarName())).get(sigmoid.getVarName()); + INDArray eval = sameDiff.output(vars, Collections.singletonList(sigmoid.name())).get(sigmoid.name()); assertEquals(assertion, eval); } @@ -413,7 +407,7 @@ public class SameDiffTests extends BaseNd4jTest { public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { SDVariable x = sameDiff.var("x", inputs.get("x")); SDVariable y = sameDiff.var("y", inputs.get("y")); - return new SDVariable[]{x.div(y)}; + return new SDVariable[]{x.div("out", y)}; } }, xAndY); @@ -422,14 +416,14 @@ public class SameDiffTests extends BaseNd4jTest { public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { SDVariable x = sameDiff.var("x", inputs.get("x")); SDVariable y = sameDiff.var("y", inputs.get("y")); - return new SDVariable[]{x.rdiv(y)}; + return new SDVariable[]{x.rdiv("out", y)}; } }, xAndY); INDArray assertionForDiv = Nd4j.valueArrayOf(4, 4.0); INDArray assertionForRDiv = Nd4j.valueArrayOf(4, 0.25); - assertEquals(assertionForDiv, sameDiff.getFunction("div").execAndEndResult()); - assertEquals(assertionForRDiv, sameDiff.getFunction("rdiv").execAndEndResult()); + assertEquals(assertionForDiv, sameDiff.getFunction("div").outputSingle(null, "out")); + assertEquals(assertionForRDiv, sameDiff.getFunction("rdiv").outputSingle(null, "out")); } @@ -444,12 +438,12 @@ public class SameDiffTests extends BaseNd4jTest { @Override public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { SDVariable x = sameDiff.var("x", inputs.get("x")); - return new SDVariable[]{sameDiff.math().neg(x)}; + return new SDVariable[]{sameDiff.math().neg("out", x)}; } }, xAndY); INDArray assertionForDiv = Nd4j.valueArrayOf(4, -1); - assertEquals(assertionForDiv, sameDiff.getFunction("neg").execAndEndResult()); + assertEquals(assertionForDiv, sameDiff.getFunction("neg").outputSingle(null, "out")); } @@ -470,7 +464,7 @@ public class SameDiffTests extends BaseNd4jTest { }, inputs); INDArray assertion = sumInput.sum(1); - INDArray out = sameDiff.getFunction("sum").exec(Collections.emptyMap(), Collections.singletonList("sum")) + INDArray out = sameDiff.getFunction("sum").output(Collections.emptyMap(), Collections.singletonList("sum")) .get("sum"); assertEquals(assertion, out); } @@ -483,7 +477,7 @@ public class SameDiffTests extends BaseNd4jTest { */ SameDiff sameDiff = SameDiff.create(); SDVariable sdVariable = sameDiff.var("one", Nd4j.scalar(1.0)); - assumeNotNull(sameDiff.getVariable(sdVariable.getVarName())); + assumeNotNull(sameDiff.getVariable(sdVariable.name())); } @@ -499,7 +493,7 @@ public class SameDiffTests extends BaseNd4jTest { SameDiff sameDiff = SameDiff.create(); SDVariable sdVariable = sameDiff.var("one", Nd4j.scalar(1.0)); SDVariable add = sdVariable.add(1.0); - assertEquals(sameDiff.getVariable(add.getVarName()), add); + assertEquals(sameDiff.getVariable(add.name()), add); } @@ -507,17 +501,8 @@ public class SameDiffTests extends BaseNd4jTest { public void testUpdateVariable() { SameDiff sameDiff = SameDiff.create(); SDVariable one = sameDiff.one("one", new long[]{1, 1}); - sameDiff.updateVariableName(one.getVarName(), "one-diff"); - assertEquals(one.getArr(), sameDiff.getVariable("one-diff").getArr()); - } - - - @Test(expected = IllegalStateException.class) - public void testPlaceHolderWithFullShape() { - val sd = SameDiff.create(); - val placeholder = sd.placeHolder("somevar", DataType.FLOAT, 2, 2); - assertTrue(sd.isPlaceHolder(placeholder.getVarName())); - sd.resolveVariablesWith(Collections.singletonMap(placeholder.getVarName(), Nd4j.linspace(1, 4, 4))); + one.rename("one-diff"); + assertEquals(one.eval(), sameDiff.getVariable("one-diff").eval()); } @@ -544,143 +529,6 @@ public class SameDiffTests extends BaseNd4jTest { } - - @Test - public void testIfStatementTrueBodyBackwards() { - OpValidationSuite - .ignoreFailing(); //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations - SameDiff sameDiff = SameDiff.create(); - SameDiffFunctionDefinition conditionBody = new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable sum = sameDiff.sum(variableInputs[0], Integer.MAX_VALUE); - SDVariable result = sameDiff.gt(sum, 1.0); - return new SDVariable[]{result}; - } - }; - - SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable add = variableInputs[0].add(1.0); - return new SDVariable[]{add}; - } - }; - - SameDiffFunctionDefinition falseBody = new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable sub = variableInputs[0].sub(1.0); - return new SDVariable[]{sub}; - } - }; - - //true body trigger - SDVariable[] firstInputs = new SDVariable[]{ - sameDiff.var("one", new long[]{1, 1}) - - }; - - sameDiff.ifStatement(new DefaultSameDiffConditional(), conditionBody, trueBody, falseBody, firstInputs); - sameDiff.execBackwards(Collections.emptyMap()); - SameDiff grad = sameDiff.getFunction("grad"); - /* If ifBlock = (If) grad.getFunction(new long[]{1},new long[]{2}); - SameDiff assertComparision = SameDiff.create(); - SDVariable initialInput = assertComparision.zero("zero",new long[]{1,1}); - initialInput.addi(1.0); - assumeNotNull(ifBlock.getTrueBodyExecuted()); - assertTrue(ifBlock.getTrueBodyExecuted()); - assertEquals(Nd4j.scalar(1.00),initialInput.getArr()); - assertEquals(Nd4j.scalar(1.0),ifBlock.getLoopBodyExecution().getVariableForVertexId(2).getArr()); -*/ - } - - - @Test - public void testIfStatementTrueBody() { - OpValidationSuite - .ignoreFailing(); //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations - SameDiff sameDiff = SameDiff.create(); - - SameDiffFunctionDefinition conditionBody = new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable sum = sameDiff.sum(variableInputs[0], Integer.MAX_VALUE); - SDVariable result = sameDiff.gt(sum, 1.0); - return new SDVariable[]{result}; - } - }; - - SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable add = variableInputs[0].add(1.0); - return new SDVariable[]{add}; - } - }; - - SameDiffFunctionDefinition falseBody = new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable sub = variableInputs[0].sub(1.0); - return new SDVariable[]{sub}; - } - }; - - //true body trigger - SDVariable[] firstInputs = new SDVariable[]{ - sameDiff.var("one", new long[]{1, 1}) - - }; - - sameDiff.ifStatement(new DefaultSameDiffConditional(), conditionBody, trueBody, falseBody, firstInputs); - sameDiff.exec(Collections.emptyMap()); - } - - - @Test - public void testIfStatementFalseBody() { - OpValidationSuite - .ignoreFailing(); //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations - SameDiff sameDiff = SameDiff.create(); - - SameDiffFunctionDefinition conditionBody = new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable sum = sameDiff.sum(variableInputs[0], Integer.MAX_VALUE); - SDVariable result = sameDiff.gt(sum, 1.0); - return new SDVariable[]{result}; - } - }; - - SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable add = variableInputs[0].add(1.0); - return new SDVariable[]{add}; - } - }; - - SameDiffFunctionDefinition falseBody = new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable sub = variableInputs[0].sub(1.0); - return new SDVariable[]{sub}; - } - }; - - //false body trigger - SDVariable[] secondInputs = new SDVariable[]{ - sameDiff.setupFunction(sameDiff.var("two", new long[]{1, 1})) - - }; - - sameDiff.ifStatement(new DefaultSameDiffConditional(), conditionBody, trueBody, falseBody, secondInputs); - - sameDiff.exec(Collections.emptyMap()); - } - - @Test public void testAutoBroadcastAddMatrixVector() { SameDiff sameDiff = SameDiff.create(); @@ -690,8 +538,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable left = sameDiff.var("arr", arr); SDVariable right = sameDiff.var("row", row); SDVariable test = left.add(right); - sameDiff.exec(Collections.emptyMap(), sameDiff.outputs()); - assertEquals(assertion, test.getArr()); + assertEquals(assertion, test.eval()); } @@ -740,7 +587,7 @@ public class SameDiffTests extends BaseNd4jTest { sd.associateArrayWithVariable(w, sd.getVariable("W")); sd.associateArrayWithVariable(b, sd.getVariable("b")); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); assertArrayEquals(new long[]{minibatch, nOut}, outArr.shape()); } @@ -780,7 +627,7 @@ public class SameDiffTests extends BaseNd4jTest { sd.associateArrayWithVariable(weightsArr, weights); sd.associateArrayWithVariable(biasArr, bias); - INDArray result = sd.execAndEndResult(); + INDArray result = avgMSE.eval(); } @@ -798,7 +645,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray inArr = Nd4j.create(10, 9, 8); sd.associateArrayWithVariable(inArr, in); - INDArray out = sd.execAndEndResult(); //Exception here, dim0=-1 case only + INDArray out = mean2.eval(); long[] shape = out.shape(); assertArrayEquals(msg, new long[]{10}, shape); @@ -813,10 +660,10 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable in = sd.var("in", new long[]{10, 9, 8}); SDVariable mean1 = sd.mean(in, 2); //[10,9] out SDVariable mean2 = sd.mean(mean1, 1); //[10] out - sd.execAndEndResult(); + Map m = sd.output((Map)null, mean1.name(), mean2.name()); - INDArray m1 = mean1.getArr(); - INDArray m2 = mean2.getArr(); + INDArray m1 = m.get(mean1.name()); + INDArray m2 = m.get(mean2.name()); assertArrayEquals(new long[]{10, 9}, m1.shape()); assertArrayEquals(new long[]{10}, m2.shape()); @@ -829,20 +676,20 @@ public class SameDiffTests extends BaseNd4jTest { SameDiff sd2 = SameDiff.create(); SDVariable in2 = sd2.var("in", new long[]{10, 9, 8}); SDVariable meanA = sd2.mean(in2, 0); //[9,8] out - sd2.exec(null, sd2.outputs()); - assertArrayEquals(new long[]{9, 8}, meanA.getShape()); + Map out = sd2.outputAll(null); + assertArrayEquals(new long[]{9, 8}, out.get(meanA.name()).shape()); SDVariable meanB = sd2.mean(meanA, 0); //[8] out - sd2.exec(null, sd2.outputs()); - assertArrayEquals(new long[]{8}, meanB.getShape()); + Map m = sd2.outputAll(null); + assertArrayEquals(new long[]{8}, m.get(meanB.name()).shape()); - assertArrayEquals(meanA.getShape(), meanA.getArr().shape()); - assertArrayEquals(meanB.getShape(), meanB.getArr().shape()); + assertArrayEquals(new long[]{9, 8}, m.get(meanA.name()).shape()); + assertArrayEquals(new long[]{8}, m.get(meanB.name()).shape()); - sd2.exec(Collections.emptyMap(), sd2.outputs()); + m = sd2.outputAll(null); - INDArray mA = meanA.getArr(); - INDArray mB = meanB.getArr(); + INDArray mA = m.get(meanA.name()); + INDArray mB = m.get(meanB.name()); assertArrayEquals(new long[]{9, 8}, mA.shape()); assertArrayEquals(new long[]{8}, mB.shape()); @@ -858,10 +705,10 @@ public class SameDiffTests extends BaseNd4jTest { val f = m.add(2.0); val s = in2.add(5.0); - val arr = sd.execSingle(null, s.getVarName()); - log.info("Result M: {}", m.getArr()); - log.info("Result F: {}", f.getArr()); - log.info("Result S: {}", s.getArr()); + Map map = sd.outputAll(null); + log.info("Result M: {}", map.get(m.name())); + log.info("Result F: {}", map.get(f.name())); + log.info("Result S: {}", map.get(s.name())); } @Test @@ -910,8 +757,8 @@ public class SameDiffTests extends BaseNd4jTest { val input2 = sd.var("input2", vector); val output = sd .mmul("output", input1, input2, MMulTranspose.builder().transposeA(true).transposeB(false).build()); - output.eval(); - assertArrayEquals(new long[]{3, 1}, output.getShape()); + INDArray out = output.eval(); + assertArrayEquals(new long[]{3, 1}, out.shape()); } @Test @@ -943,10 +790,8 @@ public class SameDiffTests extends BaseNd4jTest { SameDiff sameDiff = SameDiff.create(); SDVariable twoByTwo = sameDiff.var("initial", Nd4j.linspace(1, 4, 4, DataType.FLOAT).reshape(2, 2)); SDVariable sum = sameDiff.sum(twoByTwo, Integer.MAX_VALUE); - sameDiff.execBackwards(Collections.emptyMap()); - SameDiff grad = sameDiff.getFunction("grad"); - SDVariable gradArr = sameDiff.grad(twoByTwo.getVarName()); - assertEquals(Nd4j.ones(DataType.FLOAT, 2, 2), gradArr.getArr()); + Map grads = sameDiff.calculateGradients(Collections.emptyMap(), sameDiff.getVariables().keySet()); + assertEquals(Nd4j.ones(DataType.FLOAT, 2, 2), grads.get(twoByTwo.name())); } @@ -966,7 +811,7 @@ public class SameDiffTests extends BaseNd4jTest { }, params); SameDiff logisticGraph = sameDiff.getFunction("rsubop"); - INDArray output = logisticGraph.exec(params, Collections.singletonList("rsub")).get("rsub"); + INDArray output = logisticGraph.output(params, Collections.singletonList("rsub")).get("rsub"); assertEquals(Nd4j.ones(4).muli(-1), output); } @@ -999,7 +844,7 @@ public class SameDiffTests extends BaseNd4jTest { SameDiff logisticGraph = sameDiffOuter.getFunction("oneminuspredictions"); Map inputsSubset = new HashMap<>(); inputsSubset.put("y", inputs.get("y")); - INDArray output = logisticGraph.exec(inputsSubset, Collections.singletonList("rsub")).get("rsub"); + INDArray output = logisticGraph.output(inputsSubset, Collections.singletonList("rsub")).get("rsub"); INDArray assertion = Nd4j.create(new double[]{0, 0, 1, 0}, new int[]{4, 1}); assertEquals(assertion, output); @@ -1057,7 +902,7 @@ public class SameDiffTests extends BaseNd4jTest { SameDiff sameDiff = SameDiff.create(); SDVariable twoByTwo = sameDiff.var("first", Nd4j.linspace(1, 4, 4).reshape('c', 2, 2)); SDVariable add = twoByTwo.add(1.0); - INDArray test = sameDiff.execAndEndResult(); + INDArray test = add.eval(); INDArray assertion = Nd4j.linspace(1, 4, 4).reshape('c', 2, 2).add(1.0); assertEquals(assertion, test); } @@ -1070,8 +915,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable sdVariable = sameDiff.var("ones", ones); SDVariable result = sdVariable.add(1.0); SDVariable total = sameDiff.sum(result, Integer.MAX_VALUE); - sameDiff.execAndEndResult(); - assertEquals(56, total.getArr().getDouble(0), 1e-1); + INDArray out = total.eval(); + assertEquals(56, out.getDouble(0), 1e-1); } @@ -1097,11 +942,11 @@ public class SameDiffTests extends BaseNd4jTest { INDArray expZ = expMmul.addRowVector(iBias); INDArray expOut = Transforms.sigmoid(expZ, true); - sd.exec(Collections.emptyMap(), sd.outputs()); + Map m = sd.outputAll(Collections.emptyMap()); - assertEquals(expMmul, mmul.getArr()); - assertEquals(expZ, z.getArr()); - assertEquals(expOut, out.getArr()); + assertEquals(expMmul, m.get(mmul.name())); + assertEquals(expZ, m.get(z.name())); + assertEquals(expOut, m.get(out.name())); } @Test @@ -1178,8 +1023,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable sqDiff = diff.mul("sqDiff", diff); SDVariable totSum = sd.sum("totSum", sqDiff, Integer.MAX_VALUE); //Loss function... - sd.exec(Collections.emptyMap(), sd.outputs()); - INDArray outAct = sd.getVariable("out").getArr(); + Map m = sd.output(Collections.emptyMap(), "out"); + INDArray outAct = m.get("out"); assertEquals(a.toString(), outExp, outAct); // L = sum_i (label - out)^2 @@ -1187,10 +1032,11 @@ public class SameDiffTests extends BaseNd4jTest { INDArray dLdOutExp = outExp.sub(labelArr).mul(2); INDArray dLdInExp = a.getActivationFunction().backprop(inArr.dup(), dLdOutExp.dup()).getFirst(); - sd.execBackwards(Collections.emptyMap()); - SameDiff gradFn = sd.getFunction("grad"); - INDArray dLdOutAct = gradFn.getVariable("out-grad").getArr(); - INDArray dLdInAct = gradFn.getVariable("in-grad").getArr(); + Map grads = sd.calculateGradients(null, "out", "in"); +// sd.execBackwards(Collections.emptyMap()); +// SameDiff gradFn = sd.getFunction("grad"); + INDArray dLdOutAct = grads.get("out"); + INDArray dLdInAct = grads.get("in"); assertEquals(a.toString(), dLdOutExp, dLdOutAct); assertEquals(a.toString(), dLdInExp, dLdInAct); @@ -1234,7 +1080,7 @@ public class SameDiffTests extends BaseNd4jTest { 0.0, 1); out = sd.nn().tanh("out", out); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); assertArrayEquals(new long[]{1, 10}, outArr.shape()); } @@ -1256,10 +1102,10 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable out = sd.cnn().localResponseNormalization(sdInput, lrn); SDVariable sdOut = sd.math().tanh("out", out); - sd.exec(Collections.emptyMap(), sd.outputs()); + Map map = sd.output(Collections.emptyMap(), "out", out.name()); for (int i = 0; i < 4; i++) { - assertEquals(1, out.getArr().get(all(), NDArrayIndex.point(i), all(), all()).getInt(0)); + assertEquals(1, map.get(out.name()).get(all(), NDArrayIndex.point(i), all(), all()).getInt(0)); } } @@ -1280,10 +1126,10 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable sum = mean.add(variance); SDVariable out = sd.math().tanh("out", sum); - INDArray outArr = sd.execAndEndResult(); + Map m = sd.outputAll(null); - INDArray meanArray = mean.getArr(); - INDArray varArray = variance.getArr(); + INDArray meanArray = m.get(mean.name()); + INDArray varArray = m.get(variance.name()); assertEquals(meanArray.getDouble(0), 2.5, 1e-5); assertEquals(varArray.getDouble(0), 1.25, 1e-5); @@ -1309,15 +1155,14 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable sum = normMean.add(normVariance); SDVariable out = sd.math().tanh("out", sum); - INDArray outArr = sd.execAndEndResult(); + Map m = sd.outputAll(null); - INDArray meanArray = normMean.getArr(); - INDArray varArray = normVariance.getArr(); + INDArray meanArray = m.get(normMean.name()); + INDArray varArray = m.get(normVariance.name()); assertEquals(meanArray.getDouble(0, 0), 1, 1e-5); assertEquals(meanArray.getDouble(0, 1), 2, 1e-5); assertArrayEquals(meanArray.shape(), varArray.shape()); - } @@ -1353,7 +1198,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable out = sd.cnn().depthWiseConv2d(in, dW, b, c); out = sd.math().tanh("out", out); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27 val outShape = outArr.shape(); assertArrayEquals(new long[]{mb, depthWise * nIn, 27, 27}, outShape); @@ -1369,11 +1214,11 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable v = sd.var("in", arr); SDVariable mean = sd.mean("mean", v); - INDArray out = sd.execAndEndResult(); + INDArray out = mean.eval(); assertEquals(out, arr.mean(Integer.MAX_VALUE)); - sd.execBackwards(Collections.emptyMap()); - INDArray dLdIn = sd.grad("in").getArr(); + Map m = sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); + INDArray dLdIn = m.get("in"); //If L = mean(in) //then dL/dIn = 1/N @@ -1391,11 +1236,11 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable v = sd.var("in", arr); SDVariable mean = sd.sum("sum", v); - INDArray out = sd.execAndEndResult(); + INDArray out = mean.eval(); assertEquals(out, arr.sum(Integer.MAX_VALUE)); - sd.execBackwards(Collections.emptyMap()); - INDArray dLdIn = sd.grad("in").getArr(); + Map m = sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); + INDArray dLdIn = m.get("in"); //If L = sum(in) //then dL/dIn = 1 @@ -1414,10 +1259,10 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable v = sd.var("in", arr); SDVariable stdev = sd.standardDeviation("stdev", v, biasCorrected); - INDArray out = sd.execAndEndResult(); + INDArray out = stdev.eval(); assertEquals(out, arr.std(biasCorrected, Integer.MAX_VALUE)); - sd.execBackwards(Collections.emptyMap()); + Map g = sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); INDArray dLdIn = sd.grad("in").getArr(); //If L = stdev(in) @@ -1444,11 +1289,11 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable v = sd.var("in", arr); SDVariable var = sd.variance("var", v, biasCorrected); - INDArray out = sd.execAndEndResult(); + INDArray out = var.eval(); assertEquals(out, arr.var(biasCorrected, Integer.MAX_VALUE)); - sd.execBackwards(Collections.emptyMap()); - INDArray dLdIn = sd.grad("in").getArr(); + Map g = sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); + INDArray dLdIn = g.get("in"); //If L = var(in) //then dL/dIn = 2/(N-1) * (in-mean) @@ -1472,17 +1317,17 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable v = sd.var("in", arr); SDVariable min = sd.min("min", v); - INDArray out = sd.execAndEndResult(); + INDArray out = min.eval(); assertEquals(out, arr.min(Integer.MAX_VALUE)); - sd.execBackwards(Collections.emptyMap()); + Map g = sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); INDArray dLdIn = sd.grad("in").getArr(); //If L = min(in) //then dL/dIn = 1 if in_i == min(in) or 0 otherwise //Note that we don't have an "IsMin" op, so use IsMax(neg(in)) which is equivalent - INDArray exp = Nd4j.getExecutioner().exec(new IsMax(arr.neg()))[0].castTo(Nd4j.defaultFloatingPointType()); + INDArray exp = Nd4j.exec(new IsMax(arr.neg()))[0].castTo(Nd4j.defaultFloatingPointType()); assertEquals(exp, dLdIn); } @@ -1497,16 +1342,16 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable v = sd.var("in", arr); SDVariable min = sd.max("max", v); - INDArray out = sd.execAndEndResult(); + INDArray out = min.eval(); assertEquals(out, arr.max(Integer.MAX_VALUE)); - sd.execBackwards(Collections.emptyMap()); + sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); INDArray dLdIn = sd.grad("in").getArr(); //If L = max(in) //then dL/dIn = 1 if in_i == max(in) or 0 otherwise - INDArray exp = Nd4j.getExecutioner().exec(new IsMax(arr.dup()))[0].castTo(DataType.DOUBLE); + INDArray exp = Nd4j.exec(new IsMax(arr.dup()))[0].castTo(DataType.DOUBLE); assertEquals(exp, dLdIn); } @@ -1522,10 +1367,10 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable prod = sd.prod("prod", v); double p = arr.prodNumber().doubleValue(); - INDArray out = sd.execAndEndResult(); + INDArray out = prod.eval(); assertEquals(out, arr.prod(Integer.MAX_VALUE)); - sd.execBackwards(Collections.emptyMap()); + Map g = sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); INDArray dLdIn = sd.grad("in").getArr(); //If L = prod(in) @@ -1551,8 +1396,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray expOut = in.getArr().sub(label.getArr()); expOut.muli(expOut); - System.out.println("About to exec"); - INDArray out = sd.execAndEndResult(); //JVM crash + INDArray out = sqDiff.eval(); assertEquals(out, expOut); } @@ -1565,7 +1409,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable in = sd.var("in", Nd4j.create(2, 3)); SDVariable expanded = sd.f().expandDims(in, i); - INDArray out = sd.execAndEndResult(); + INDArray out = expanded.eval(); switch (i) { case 0: assertArrayEquals(new long[]{1, 2, 3}, out.shape()); @@ -1588,11 +1432,11 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable var0 = sd.var("in", DataType.DOUBLE, new long[]{3, 4}); SDVariable out = sd.zerosLike("out", var0); - INDArray out1 = sd.execAndEndResult(); + INDArray out1 = out.eval(); assertEquals(Nd4j.zeros(3, 4), out1); sd.associateArrayWithVariable(Nd4j.create(3, 4), var0); - INDArray out2 = sd.execAndEndResult(); + INDArray out2 = out.eval(); assertEquals(Nd4j.zeros(DataType.DOUBLE, 3, 4), out2); } @@ -1602,11 +1446,11 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable var0 = sd.var("in", new long[]{3, 4}); SDVariable out = sd.onesLike("out", var0); - INDArray out1 = sd.execAndEndResult(); + INDArray out1 = out.eval(); assertEquals(Nd4j.ones(3, 4), out1); sd.associateArrayWithVariable(Nd4j.create(3, 4), var0); - INDArray out2 = sd.execAndEndResult(); + INDArray out2 = out.eval(); assertEquals(Nd4j.ones(3, 4), out2); } @@ -1618,12 +1462,12 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable ones = sd.onesLike("ones", var0); SDVariable out = sd.sum("oun", ones); - INDArray outArr = sd.execAndEndResult(); + INDArray outArr = out.eval(); assertEquals(Nd4j.scalar(12.0), outArr); - sd.execBackwards(Collections.emptyMap()); + Map m = sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet()); - assertEquals(Nd4j.create(3, 4), sd.grad("in").getArr()); + assertEquals(Nd4j.create(3, 4), m.get("in")); } @@ -1634,7 +1478,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray a = Nd4j.rand(new long[]{3, 4, 5}); INDArray b = Nd4j.rand(new long[]{3, 4, 5}); - INDArray expOut = Nd4j.getExecutioner().exec(new ManhattanDistance(a, b, 0)); + INDArray expOut = Nd4j.exec(new ManhattanDistance(a, b, 0)); val expShape = new long[]{4, 5}; @@ -1662,7 +1506,7 @@ public class SameDiffTests extends BaseNd4jTest { double maxSum = max.sumNumber().doubleValue(); double jd = 1.0 - minSum / maxSum; - INDArray out = sd.execAndEndResult(); + INDArray out = jaccard.eval(); assertEquals(1, out.length()); assertEquals(jd, out.getDouble(0), 1e-6); @@ -1710,36 +1554,36 @@ public class SameDiffTests extends BaseNd4jTest { case 4: t = sd.gte(in1, in2); expOut = Nd4j.create(DataType.BOOL, ia.shape()); - Nd4j.getExecutioner().exec(new GreaterThanOrEqual(new INDArray[]{ia, ib}, new INDArray[]{expOut})); + Nd4j.exec(new GreaterThanOrEqual(new INDArray[]{ia, ib}, new INDArray[]{expOut})); break; case 5: t = sd.lte(in1, in2); expOut = Nd4j.create(DataType.BOOL, ia.shape()); - Nd4j.getExecutioner().exec(new LessThanOrEqual(new INDArray[]{ia, ib}, new INDArray[]{expOut})); + Nd4j.exec(new LessThanOrEqual(new INDArray[]{ia, ib}, new INDArray[]{expOut})); break; case 6: - ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); - ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5)); + ia = Nd4j.exec(new BernoulliDistribution(ia, 0.5)); + ib = Nd4j.exec(new BernoulliDistribution(ib, 0.5)); t = sd.math().or(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL)); expOut = Transforms.or(ia, ib); break; case 7: t = sd.max(in1, in2); - expOut = Nd4j.getExecutioner().exec(new Max(ia, ib, ia.dup()))[0]; + expOut = Nd4j.exec(new Max(ia, ib, ia.dup()))[0]; break; case 8: t = sd.min(in1, in2); - expOut = Nd4j.getExecutioner().exec(new Min(ia, ib, ia.dup()))[0]; + expOut = Nd4j.exec(new Min(ia, ib, ia.dup()))[0]; break; case 9: - ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); - ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5)); + ia = Nd4j.exec(new BernoulliDistribution(ia, 0.5)); + ib = Nd4j.exec(new BernoulliDistribution(ib, 0.5)); t = sd.math().and(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL)); expOut = Transforms.and(ia, ib); break; case 10: - ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); - ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5)); + ia = Nd4j.exec(new BernoulliDistribution(ia, 0.5)); + ib = Nd4j.exec(new BernoulliDistribution(ib, 0.5)); t = sd.math().xor(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL)); expOut = Transforms.xor(ia, ib); break; @@ -1748,7 +1592,7 @@ public class SameDiffTests extends BaseNd4jTest { } log.info("Executing: " + i); - INDArray out = sd.execAndEndResult(); + INDArray out = t.eval(); assertEquals(expOut, out); } @@ -1776,22 +1620,22 @@ public class SameDiffTests extends BaseNd4jTest { switch (i) { case 0: t = sd.math().isNonDecreasing(in1); - Nd4j.getExecutioner().exec(new IsNonDecreasing(new INDArray[]{ia}, new INDArray[]{expOut})); + Nd4j.exec(new IsNonDecreasing(new INDArray[]{ia}, new INDArray[]{expOut})); break; case 1: t = sd.math().isStrictlyIncreasing(in1); - Nd4j.getExecutioner().exec(new IsStrictlyIncreasing(new INDArray[]{ia}, new INDArray[]{expOut})); + Nd4j.exec(new IsStrictlyIncreasing(new INDArray[]{ia}, new INDArray[]{expOut})); break; case 2: t = sd.isNumericTensor(in1); - Nd4j.getExecutioner().exec(new IsNumericTensor(new INDArray[]{ia}, new INDArray[]{expOut})); + Nd4j.exec(new IsNumericTensor(new INDArray[]{ia}, new INDArray[]{expOut})); break; default: throw new RuntimeException(); } log.info("Executing: " + i); - INDArray out = sd.execAndEndResult(); + INDArray out = t.eval(); assertEquals(expOut, out); } @@ -1810,7 +1654,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable in = sd.var("in", inArr); SDVariable expand = sd.f().expandDims(in, i); - INDArray out = sd.execAndEndResult(); + INDArray out = expand.eval(); INDArray expOut; switch (i) { @@ -1851,7 +1695,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable in = sd.var("in", inArr); SDVariable squeeze = sd.f().squeeze(in, i); - INDArray out = sd.execAndEndResult(); + INDArray out = squeeze.eval(); INDArray expOut; switch (i) { @@ -1890,7 +1734,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable expand = sd.expandDims(in, i); SDVariable squeeze = sd.squeeze(expand, i); - INDArray out = sd.execAndEndResult(); + INDArray out = squeeze.eval(); String msg = "expand/Squeeze=" + i + ", source=" + p.getSecond(); @@ -1918,7 +1762,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable squeeze = sd.squeeze(in, i); SDVariable expand = sd.expandDims(squeeze, i); - INDArray out = sd.execAndEndResult(); + INDArray out = expand.eval(); String msg = "expand/Squeeze=" + i + ", source=" + p.getSecond(); @@ -1937,8 +1781,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable labelsVar = sd.constant("labels", labels); SDVariable predictionsVar = sd.constant("predictions", pred); SDVariable weightsVar = sd.constant("weights", weights); - sd.math().confusionMatrix("cm", labelsVar, predictionsVar, numClasses, weightsVar); - INDArray out = sd.execAndEndResult(); + SDVariable cm = sd.math().confusionMatrix("cm", labelsVar, predictionsVar, numClasses, weightsVar); + INDArray out = cm.eval(); INDArray exp = Nd4j.create(new float[][]{{0, 0, 0, 0, 0}, {0, 0, 10, 0, 0}, {0, 0, 100, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 1000}}).castTo(DataType.INT); @@ -1957,7 +1801,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable in = sd.var("in", inArr); SDVariable argmax = sd.argmax("argmax", in, dim); - INDArray out = sd.execAndEndResult(); + INDArray out = argmax.eval(); INDArray exp = Nd4j.argMax(inArr, dim); @@ -1977,7 +1821,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable in = sd.var("in", inArr); SDVariable argmin = sd.argmin("argmin", in, dim); - INDArray out = sd.execAndEndResult(); + INDArray out = argmin.eval(); INDArray exp = Nd4j.argMax(inArr.neg(), dim); //argmin(x) == argmax(-x) @@ -2140,7 +1984,7 @@ public class SameDiffTests extends BaseNd4jTest { System.out.println(in); INDArray exp = Nd4j.pullRows(in, 1, new int[]{0, 1, 5}); //Along dimension 1 -> equiv to "indexes for axis 0" - INDArray act = sd.execAndEndResult(); + INDArray act = gather.eval(); assertEquals(exp, act); } @@ -2158,7 +2002,7 @@ public class SameDiffTests extends BaseNd4jTest { .addOutputs(out) .build(); - Nd4j.getExecutioner().exec(op); + Nd4j.exec(op); System.out.println(out); @@ -2194,10 +2038,9 @@ public class SameDiffTests extends BaseNd4jTest { INDArray expNaN = Nd4j.create(new boolean[]{false, false}); SDVariable isnan = sd.math().isNaN(in); - sd.exec(Collections.emptyMap(), sd.outputs()); - assertEquals(expFinite, finite.getArr()); - assertEquals(expInfinite, infinite.getArr()); - assertEquals(expNaN, isnan.getArr()); + assertEquals(expFinite, finite.eval()); + assertEquals(expInfinite, infinite.eval()); + assertEquals(expNaN, isnan.eval()); } @@ -2291,7 +2134,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable write0 = tensorArray.write(var2, 0, var1); SDVariable write1 = tensorArray.write(write0, 1, var2); SDVariable result = tensorArray.stack(write1); - sd.exec(null, result.getVarName()); + sd.output((Map)null, result.name()); assertEquals(Nd4j.pile(arr1, arr2), result.eval()); } @@ -2392,12 +2235,12 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable sum = in.sum(1); INDArray exp = in.getArr().sum(1).reshape(3); - INDArray out = sd.execAndEndResult(); + INDArray out = sum.eval(); assertEquals(exp, out); //Now, replace with minibatch 5: in.setArray(Nd4j.linspace(1, 20, 20).reshape(5, 4)); - INDArray out2 = sd.execAndEndResult(); + INDArray out2 = sum.eval(); assertArrayEquals(new long[]{5}, out2.shape()); exp = in.getArr().sum(1).reshape(5); @@ -2412,12 +2255,12 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable sum = in.argmax(1); INDArray exp = in.getArr().argMax(1).reshape(3); - INDArray out = sd.execAndEndResult(); + INDArray out = sum.eval(); assertEquals(exp, out); //Now, replace with minibatch 5: in.setArray(Nd4j.linspace(1, 20, 20).reshape(5, 4)); - INDArray out2 = sd.execAndEndResult(); + INDArray out2 = sum.eval(); assertArrayEquals(new long[]{5}, out2.shape()); exp = in.getArr().argMax(1).reshape(5); @@ -2436,12 +2279,11 @@ public class SameDiffTests extends BaseNd4jTest { gradMap.put("out", externalGrad); ExternalErrorsFunction fn = sd.f().externalErrors(out); - sd.execAndEndResult(); Map m = new HashMap<>(); m.put("out-grad", externalGrad); - sd.execBackwards(m); + Map grads = sd.calculateGradients(m, sd.getVariables().keySet()); - INDArray gradVar = var.getGradient().getArr(); + INDArray gradVar = grads.get(var.name()); assertEquals(externalGrad.mul(0.5), gradVar); @@ -2449,7 +2291,7 @@ public class SameDiffTests extends BaseNd4jTest { externalGrad = Nd4j.linspace(1, 12, 12).reshape(3, 4).muli(10); m.put("out-grad", externalGrad); - sd.execBackwards(m); + grads = sd.calculateGradients(m, sd.getVariables().keySet()); gradVar = var.getGradient().getArr(); @@ -2468,73 +2310,26 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable out = sd.mmul(in, w); SDVariable loss = out.std("out", true); - INDArray outArr = sd.execAndEndResult().dup(); - sd.execBackwards(Collections.emptyMap()); + INDArray outArr = loss.eval(); +// sd.execBackwards(Collections.emptyMap()); + Map grads = sd.calculateGradients(null, in.name(), w.name(), out.name()); Map origGrad = new HashMap<>(); - origGrad.put("in", in.gradient().getArr().dup()); - origGrad.put("w", w.gradient().getArr().dup()); - origGrad.put("out", out.gradient().getArr().dup()); + origGrad.put("in", grads.get(in.name()).dup()); + origGrad.put("w", grads.get(w.name()).dup()); + origGrad.put("out", grads.get(out.name()).dup()); in.getArr().assign(Nd4j.rand(in.getArr().shape())); - INDArray outArr2 = sd.execAndEndResult(); - sd.execBackwards(Collections.emptyMap()); + INDArray outArr2 = loss.eval(); +// sd.execBackwards(Collections.emptyMap()); + grads = sd.calculateGradients(null, in.name(), w.name(), out.name()); assertNotEquals(outArr, outArr2); //Ensure gradients are also changed: - assertNotEquals(origGrad.get("in"), in.gradient().getArr()); - assertNotEquals(origGrad.get("w"), w.gradient().getArr()); - assertNotEquals(origGrad.get("out"), out.gradient().getArr()); - } - - @Test - public void testUpdatingInplaceFwd() { - SameDiff sd = SameDiff.create(); - SDVariable in = sd.var("in", Nd4j.linspace(1, 12, 12).reshape(3, 4)); - SDVariable w = sd.var("w", Nd4j.linspace(1, 20, 20).reshape(4, 5)); - SDVariable out = sd.mmul(in, w); - SDVariable loss = out.std("out", true); - - INDArray outArr = sd.execAndEndResult().dup(); - sd.execBackwards(Collections.emptyMap()); - - Map origGrad = new HashMap<>(); - origGrad.put("in", in.gradient().getArr().dup()); - origGrad.put("w", w.gradient().getArr().dup()); - origGrad.put("out", out.gradient().getArr().dup()); - - in.getArr().muli(5); - - //check gradient function copy of array - SameDiff sdGrad = sd.getFunction("grad"); - INDArray gradArrIn = sdGrad.getVariable("in").getArr(); - assertEquals(in.getArr(), gradArrIn); - } - - @Test - public void testUpdatingAssociateFwd() { - SameDiff sd = SameDiff.create(); - SDVariable in = sd.var("in", Nd4j.linspace(1, 12, 12).reshape(3, 4)); - SDVariable w = sd.var("w", Nd4j.linspace(1, 20, 20).reshape(4, 5)); - SDVariable out = sd.mmul(in, w); - SDVariable loss = out.std("out", true); - - INDArray outArr = sd.execAndEndResult().dup(); - sd.execBackwards(Collections.emptyMap()); - - Map origGrad = new HashMap<>(); - origGrad.put("in", in.gradient().getArr().dup()); - origGrad.put("w", w.gradient().getArr().dup()); - origGrad.put("out", out.gradient().getArr().dup()); - - INDArray newIn = in.getArr().dup().muli(5); - in.setArray(newIn); - - //check gradient function copy of array - SameDiff sdGrad = sd.getFunction("grad"); - INDArray gradArrIn = sdGrad.getVariable("in").getArr(); - assertEquals(newIn, gradArrIn); + assertNotEquals(origGrad.get("in"), grads.get(in.name())); + assertNotEquals(origGrad.get("w"), grads.get(w.name())); + assertNotEquals(origGrad.get("out"), grads.get(out.name())); } @Test @@ -2544,27 +2339,25 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable out = in.mul(2.0); SDVariable loss = out.std("out", true); - INDArray outArr = sd.execAndEndResult().dup(); - sd.execBackwards(Collections.emptyMap()); - - SameDiff sdGrad = sd.getFunction("grad"); + INDArray outArr = loss.eval(); + Map grads = sd.calculateGradients(null, in.name(), out.name()); Map origGrad = new HashMap<>(); - origGrad.put("in", in.gradient().getArr().dup()); - origGrad.put("out", out.gradient().getArr().dup()); + origGrad.put("in", grads.get(in.name()).dup()); + origGrad.put("out", grads.get(out.name()).dup()); double stdBefore = in.getArr().stdNumber().doubleValue(); in.getArr().assign(Nd4j.rand(in.getArr().shape())); double stdAfter = in.getArr().stdNumber().doubleValue(); System.out.println("Before vs. after: " + stdBefore + ", " + stdAfter); - INDArray outArr2 = sd.execAndEndResult(); - sd.execBackwards(Collections.emptyMap()); + INDArray outArr2 = loss.eval(); + grads = sd.calculateGradients(null, in.name(), out.name()); assertNotEquals(outArr, outArr2); //Ensure gradients are also changed: - assertNotEquals(origGrad.get("in"), in.gradient().getArr()); - assertNotEquals(origGrad.get("out"), out.gradient().getArr()); + assertNotEquals(origGrad.get("in"), grads.get(in.name())); + assertNotEquals(origGrad.get("out"), grads.get(out.name())); } @Test @@ -2589,10 +2382,10 @@ public class SameDiffTests extends BaseNd4jTest { Map phMap = new HashMap<>(); phMap.put(fn.getGradPlaceholderName(), grad); - log.info("--------------- sd.execAndEndResult() ---------------"); - sd.execAndEndResult(); + log.info("--------------- out.eval() ---------------"); + out.eval(); log.info("--------------- sd.execBackwards() #1 ---------------"); - sd.execBackwards(phMap); + sd.calculateGradients(phMap, "in", "W", "b"); log.info("--------------- sd.execBackwards() #2 ---------------"); System.out.println(sd.getFunction("grad").summary()); @@ -2601,8 +2394,8 @@ public class SameDiffTests extends BaseNd4jTest { grad = Nd4j.linspace(1, 8, 8).reshape(2, 4); phMap.put(fn.getGradPlaceholderName(), grad); - sd.execBackwards(phMap); - INDArray inGrad = in.getGradient().getArr(); + Map grads = sd.calculateGradients(phMap, sd.getVariables().keySet()); + INDArray inGrad = grads.get(in.name()); assertArrayEquals(new long[]{2, 5}, inGrad.shape()); } @@ -2658,9 +2451,8 @@ public class SameDiffTests extends BaseNd4jTest { Map placeholders = new HashMap<>(); placeholders.put("x", x); placeholders.put("y", y); - sd.createGradFunction(); //Otherwise: xSd.gradient() etc won't be defined - sd.execBackwards(placeholders, Arrays.asList(xSd.gradient().getVarName(), ySd.gradient().getVarName())); - INDArray xGradientEnforced = add.getGradient().getArr(true); + Map grads = sd.calculateGradients(placeholders, xSd.name(), ySd.name()); + INDArray xGradientEnforced = grads.get("x"); assertNotNull(xGradientEnforced); } @@ -2723,6 +2515,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray inArr = Nd4j.rand(DataType.FLOAT, 1, 3); in.setArray(inArr); INDArray inArr2 = Nd4j.rand(DataType.FLOAT, 3, 4); + in2.setArray(inArr2); TrainingConfig c = TrainingConfig.builder() .updater(new Adam(0.1)) @@ -2778,7 +2571,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray out2 = tanh.eval(); - assertEquals(out, out2); + assertNotEquals(out, out2); assertEquals(VariableType.VARIABLE, w.getVariableType()); assertEquals(VariableType.VARIABLE, b.getVariableType()); assertEquals(VariableType.ARRAY, add.getVariableType()); @@ -2918,12 +2711,12 @@ public class SameDiffTests extends BaseNd4jTest { SameDiff sd = SameDiff.create(); SDVariable linspace = sd.linspace("at", DataType.DOUBLE, 1, 15, 15); SDVariable a = sd.reshape("a", linspace, 3, 5); - SDVariable b = sd.one("b", DataType.DOUBLE, 3, 5); + SDVariable b = sd.var("b", Nd4j.ones(DataType.DOUBLE, 3, 5)); SDVariable out = a.mul(b); + out.markAsLoss(); out.eval(); - sd.execBackwards(null, "a"); out.eval(); sd.grad("a").eval(); @@ -2938,12 +2731,12 @@ public class SameDiffTests extends BaseNd4jTest { public void testNonScalarOutput2() { SameDiff sd = SameDiff.create(); SDVariable a = sd.reshape("a", sd.linspace("at", DataType.DOUBLE, 1, 15, 15), 3, 5); - SDVariable b = sd.one("b", DataType.DOUBLE, 3, 5); + SDVariable b = sd.var("b", Nd4j.ones(DataType.DOUBLE, 3, 5)); SDVariable out = a.mul(b).mean(1); + out.markAsLoss(); out.eval(); - sd.execBackwards(null, Lists.asList("a", new String[]{})); //System.out.println(out.eval()); INDArray actGrad = sd.grad("a").eval(); @@ -2958,15 +2751,16 @@ public class SameDiffTests extends BaseNd4jTest { public void testNonScalarOutput3() { SameDiff sd = SameDiff.create(); SDVariable a = sd.reshape("a", sd.linspace("at", DataType.DOUBLE, 1, 15, 15), 3, 5); - SDVariable b = sd.one("b", DataType.DOUBLE, 3, 5);//.add(3); + SDVariable b = sd.var("b", Nd4j.ones(DataType.DOUBLE, 3, 5));//.add(3); SDVariable out = a.mul(b).mean(0, 1); + out.markAsLoss(); out.eval(); - sd.execBackwards(null, "a"); + Map g = sd.calculateGradients(null, "a"); //System.out.println(out.eval()); - INDArray gradAct = sd.grad("a").eval(); + INDArray gradAct = g.get("a"); INDArray expGrad = Nd4j.valueArrayOf(new long[]{3, 5}, 1.0 / 12, DataType.DOUBLE); String err = OpValidation.validate(new TestCase(sd).gradientCheck(true)); @@ -2984,7 +2778,7 @@ public class SameDiffTests extends BaseNd4jTest { Map m = new HashMap<>(); m.put("b", Nd4j.rand(DataType.DOUBLE, 4, 5)); - sd.execBackwards(m, "a", "b"); + Map g = sd.calculateGradients(m, "a", "b"); b.setArray(m.get("b")); @@ -3006,7 +2800,7 @@ public class SameDiffTests extends BaseNd4jTest { final SDVariable out = a.mmul(b).add(c.mmul(d)).sum(); out.markAsLoss(); - sd.execBackwards(null); + Map g = sd.calculateGradients(null, sd.getVariables().keySet()); } @Test @@ -3018,7 +2812,7 @@ public class SameDiffTests extends BaseNd4jTest { a.add(b.add(c)).sum().markAsLoss(); - sd.execBackwards(Collections.singletonMap("c", Nd4j.rand(4, 4))); + sd.calculateGradients(Collections.singletonMap("c", Nd4j.rand(4, 4)), sd.getVariables().keySet()); assertNotNull(sd.grad("a")); assertNull(sd.grad("b")); assertNull(sd.grad("c")); @@ -3094,16 +2888,16 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable v2 = sd.var("y", Nd4j.rand(DataType.FLOAT, 4, 5)); SDVariable v3 = v1.mmul("oldName", v2); - INDArray out = sd.execSingle(null, "oldName"); + INDArray out = sd.outputSingle(null, "oldName"); SDVariable renamed = v3.rename("newName"); assertTrue(v3 == renamed); - assertEquals("newName", renamed.getVarName()); + assertEquals("newName", renamed.name()); assertNull(sd.getVariable("oldName")); assertNotNull(sd.getVariable("newName")); - INDArray out2 = sd.execSingle(null, "newName"); + INDArray out2 = sd.outputSingle(null, "newName"); assertEquals(out, out2); } @@ -3117,7 +2911,7 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable v3 = v1.mmul("oldName", v2); SDVariable v4 = v3.std("out", false); - INDArray out = sd.execSingle(Collections.singletonMap("x", Nd4j.rand(DataType.FLOAT, 3, 4)), "out"); + INDArray out = sd.outputSingle(Collections.singletonMap("x", Nd4j.rand(DataType.FLOAT, 3, 4)), "out"); sd.setTrainingConfig(TrainingConfig.builder() .updater(new Adam(1e-3)) @@ -3133,6 +2927,7 @@ public class SameDiffTests extends BaseNd4jTest { @Test public void testPlaceholderShapeValidation() { SameDiff sd = SameDiff.create(); + SDVariable scalar = sd.scalar("scalar", 0.0f); SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, 3, 4); SDVariable ph2 = sd.placeHolder("ph2", DataType.FLOAT, -1, 4); SDVariable ph3 = sd.placeHolder("ph3", DataType.FLOAT, 3, -1); @@ -3177,7 +2972,7 @@ public class SameDiffTests extends BaseNd4jTest { //Also try training: SDVariable sum = sd.math.mergeAdd(ph1, ph2, ph3, ph4); - SDVariable mean = sum.mean(); + SDVariable mean = sum.add(scalar).mean(); MultiDataSet mds = new MultiDataSet(new INDArray[]{wrongShape, wrongShape, wrongShape, wrongShape}, null); sd.setTrainingConfig(TrainingConfig.builder() @@ -3214,7 +3009,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray inputArr = Nd4j.rand(DataType.FLOAT, minibatch, nIn); - Map m = sd.exec(Collections.singletonMap("in", inputArr), "softmax"); + Map m = sd.output(Collections.singletonMap("in", inputArr), "softmax"); assertEquals(1, m.size()); assertTrue(m.containsKey("softmax")); @@ -3224,7 +3019,7 @@ public class SameDiffTests extends BaseNd4jTest { Map allPh = new HashMap<>(); allPh.put("in", inputArr); allPh.put("label", labelUnused); - m = sd.exec(allPh, "softmax"); + m = sd.output(allPh, "softmax"); assertEquals(1, m.size()); assertTrue(m.containsKey("softmax")); INDArray out2 = m.get("softmax"); @@ -3254,7 +3049,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray inputArr = Nd4j.rand(DataType.FLOAT, minibatch, nIn); - Map m = sd.exec(Collections.singletonMap("in", inputArr), "softmax"); + Map m = sd.output(Collections.singletonMap("in", inputArr), "softmax"); assertEquals(1, m.size()); assertTrue(m.containsKey("softmax")); @@ -3265,7 +3060,7 @@ public class SameDiffTests extends BaseNd4jTest { allPh.put("in", inputArr); allPh.put("label", labelUnused); allPh.put("in2", Nd4j.scalar(1.0f)); - m = sd.exec(allPh, "softmax"); + m = sd.output(allPh, "softmax"); assertEquals(1, m.size()); assertTrue(m.containsKey("softmax")); INDArray out2 = m.get("softmax"); @@ -3289,7 +3084,7 @@ public class SameDiffTests extends BaseNd4jTest { assertEquals(DataType.FLOAT, tanh.dataType()); assertEquals(DataType.FLOAT, stdev.dataType()); - Map out = sd.exec(null, "x", "y", "z", "tanh", "stdev"); + Map out = sd.output((Map)null, "x", "y", "z", "tanh", "stdev"); for (Map.Entry e : out.entrySet()) { assertEquals(e.getKey(), DataType.FLOAT, e.getValue().dataType()); } @@ -3308,7 +3103,7 @@ public class SameDiffTests extends BaseNd4jTest { assertEquals(DataType.DOUBLE, tanh.dataType()); assertEquals(DataType.DOUBLE, stdev.dataType()); - out = sd.exec(null, "x", "y", "z", "tanh", "stdev"); + out = sd.output((Map)null, "x", "y", "z", "tanh", "stdev"); for (Map.Entry e : out.entrySet()) { assertEquals(e.getKey(), DataType.DOUBLE, e.getValue().dataType()); } @@ -3337,7 +3132,7 @@ public class SameDiffTests extends BaseNd4jTest { Map ph = Collections.singletonMap("x", Nd4j.rand(DataType.FLOAT, 3, 4)); - Map out = sd.exec(ph, "x", "y", "xD", "yD", "a", "r"); + Map out = sd.output(ph, "x", "y", "xD", "yD", "a", "r"); for (Map.Entry e : out.entrySet()) { if (e.getKey().equals("x") || e.getKey().equals("y")) { assertEquals(e.getKey(), DataType.FLOAT, e.getValue().dataType()); @@ -3360,7 +3155,7 @@ public class SameDiffTests extends BaseNd4jTest { assertEquals(DataType.DOUBLE, add.dataType()); assertEquals(DataType.DOUBLE, relu.dataType()); - out = sd.exec(ph, "x", "y", "xD", "yD", "a", "r"); + out = sd.output(ph, "x", "y", "xD", "yD", "a", "r"); for (Map.Entry e : out.entrySet()) { assertEquals(e.getKey(), DataType.DOUBLE, e.getValue().dataType()); } @@ -3395,9 +3190,9 @@ public class SameDiffTests extends BaseNd4jTest { assertNotNull(w.gradient()); assertNotNull(b.gradient()); - sd.execBackwards(Collections.singletonMap("in", in)); - assertNotNull(ph.gradient().getArr()); - assertNotNull(w.gradient().getArr()); + Map m = sd.calculateGradients(Collections.singletonMap("in", in), ph.name(), w.name()); + assertNotNull(m.get(ph.name())); + assertNotNull(m.get(w.name())); } else { sd.createGradFunction(); assertNull(ph.gradient()); @@ -3411,29 +3206,28 @@ public class SameDiffTests extends BaseNd4jTest { @Test public void testIf() throws IOException { - SameDiff SD = SameDiff.create(); - SDVariable a = SD.placeHolder("a", DataType.DOUBLE); - SDVariable b = SD.var("b", Nd4j.createFromArray(5.0)); - SDVariable c = SD.var("c", Nd4j.createFromArray(9.0)); + SameDiff sd = SameDiff.create(); + SDVariable a = sd.placeHolder("a", DataType.DOUBLE); + SDVariable b = sd.var("b", Nd4j.createFromArray(5.0)); + SDVariable c = sd.var("c", Nd4j.createFromArray(9.0)); - SDVariable output = SD.ifCond("out", null, (sd) -> a.lt(b), (sd) -> c, (sd) -> c.add(5)); + SDVariable output = sd.ifCond("out", null, s -> a.lt(b), s -> c, s -> c.add(5)); Map firstBranch = Maps.newHashMap(); firstBranch.put("a", Nd4j.createFromArray(3.0)); - assertEquals(Nd4j.createFromArray(9.0), SD.exec(firstBranch, "out").get("out")); + assertEquals(Nd4j.createFromArray(9.0), sd.output(firstBranch, "out").get("out")); Map secondBranch = Maps.newHashMap(); secondBranch.put("a", Nd4j.createFromArray(7.0)); - assertEquals(Nd4j.createFromArray(14.0), SD.exec(secondBranch, "out").get("out")); - - //TODO complains that it can't deserialize a meta type, but there are no meta type ops here - // looks like a difference between Op.Type and OpType. Switch is saved as a OpType.LOGIC - SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); - - assertEquals(Nd4j.createFromArray(9.0), SD.exec(firstBranch, "out").get("out")); - assertEquals(Nd4j.createFromArray(14.0), SD.exec(secondBranch, "out").get("out")); + System.out.println(sd.summary()); + INDArray outArr = sd.output(secondBranch, "out").get("out"); + assertEquals(Nd4j.createFromArray(14.0), outArr); + ByteBuffer bb = sd.asFlatBuffers(false); + sd = SameDiff.fromFlatBuffers(bb); + assertEquals(Nd4j.createFromArray(9.0), sd.output(firstBranch, "out").get("out")); + assertEquals(Nd4j.createFromArray(14.0), sd.output(secondBranch, "out").get("out")); } @Test @@ -3456,7 +3250,7 @@ public class SameDiffTests extends BaseNd4jTest { SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); - assertEquals(Nd4j.createFromArray(10.0), SD.exec(null, "out").get("out")); + assertEquals(Nd4j.createFromArray(10.0), SD.output(Collections.emptyMap(), "out").get("out")); } @Test @@ -3473,11 +3267,11 @@ public class SameDiffTests extends BaseNd4jTest { INDArray out = sum[1].eval(); assertEquals(15, out.getInt(0)); - String outName = sum[1].getVarName(); + String outName = sum[1].name(); SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); - assertEquals(15, SD.exec(null, outName).get(outName).getInt(0)); + assertEquals(15, SD.output(Collections.emptyMap(), outName).get(outName).getInt(0)); } @Test @@ -3499,11 +3293,11 @@ public class SameDiffTests extends BaseNd4jTest { INDArray out = sum[1].eval(); assertEquals(35, out.getInt(0)); - String outName = sum[1].getVarName(); + String outName = sum[1].name(); SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); - assertEquals(35, SD.exec(null, outName).get(outName).getInt(0)); + assertEquals(35, SD.output(Collections.emptyMap(), outName).get(outName).getInt(0)); } @@ -3525,11 +3319,11 @@ public class SameDiffTests extends BaseNd4jTest { INDArray out = sum[1].eval(); assertEquals(115, out.getInt(0)); - String outName = sum[1].getVarName(); + String outName = sum[1].name(); SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); - assertEquals(115, SD.exec(null, outName).get(outName).getInt(0)); + assertEquals(115, SD.output(Collections.emptyMap(), outName).get(outName).getInt(0)); } @Test @@ -3609,4 +3403,54 @@ public class SameDiffTests extends BaseNd4jTest { String err = OpValidation.validate(tc); assertNull(err); } + + @Test + public void testSameDiffSeedReproducibilityVarInit() { + + SameDiff sd0 = SameDiff.create(); + SameDiff sd1 = SameDiff.create(); + Nd4j.getRandom().setSeed(12345); + SDVariable rand0 = sd0.var("random", new UniformInitScheme('c', 3), DataType.FLOAT, 3, 1); + + Nd4j.getRandom().setSeed(12345); + SDVariable rand1 = sd1.var("random", new UniformInitScheme('c', 3), DataType.FLOAT, 3, 1); + + + Nd4j.getRandom().setSeed(0); + System.out.println(rand0.eval()); + + Nd4j.getRandom().setSeed(0); + System.out.println(rand1.eval()); + + INDArray a0 = rand0.eval(); + Nd4j.getRandom().setSeed(0); + INDArray a1 = rand1.eval(); + assertEquals(a0, a1); + } + + + @Test + public void testCalculateGradientsAndOutputs(){ + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4); + SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3)); + SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3)); + SDVariable z = in.mmul(w).add("z", b); + SDVariable softmax = sd.nn.softmax("softmax", z); + + Map ph = Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 2, 4)); + List outputs = Arrays.asList("in", "z", "softmax"); + List grads = Arrays.asList("in", "w", "z"); + + OutAndGrad oag = sd.calculateGradientsAndOutputs(ph, outputs, grads); + Map outs = oag.getOutputs(); + Map g = oag.getGradients(); + + + Map outExp = sd.output(ph, outputs); + Map gExp = sd.calculateGradients(ph, grads); + + assertEquals(outExp, outs); + assertEquals(gExp, g); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java index 8f1ab016f..e1123f5d8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java @@ -19,22 +19,28 @@ package org.nd4j.autodiff.samediff.listeners; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import java.util.Arrays; -import java.util.List; +import java.util.*; +import lombok.NonNull; import org.junit.Test; -import org.nd4j.autodiff.listeners.Operation; +import org.nd4j.autodiff.listeners.*; import org.nd4j.autodiff.listeners.impl.ScoreListener; import org.nd4j.autodiff.listeners.records.History; +import org.nd4j.autodiff.listeners.records.LossCurve; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.TrainingConfig; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation.Metric; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.IrisDataSetIterator; +import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator; +import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import org.nd4j.linalg.factory.Nd4j; @@ -49,6 +55,11 @@ public class ListenerTest extends BaseNd4jTest { super(backend); } + @Override + public char ordering() { + return 'c'; + } + @Test public void irisHistoryTest() { @@ -112,8 +123,237 @@ public class ListenerTest extends BaseNd4jTest { assertTrue("Accuracy < 75%, was " + acc, acc >= 0.75); } - @Override - public char ordering() { - return 'c'; + @Test + public void testListenerCalls(){ + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4); + SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3); + SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3)); + SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3)); + SDVariable z = in.mmul(w).add(b); + SDVariable softmax = sd.nn.softmax("softmax", z); + SDVariable loss = sd.loss.logLoss("loss" ,label, softmax); + + TestListener tl = new TestListener(Operation.INFERENCE); + sd.setListeners(tl); + + //Check listener called during inference + Map phMap = Collections.singletonMap("in", Nd4j.rand(1, 4)); + + for( int i=1; i<=5; i++ ) { + INDArray out = sd.outputSingle(phMap, "softmax"); + + assertEquals(0, tl.epochStartCount); + assertEquals(0, tl.epochEndCount); + assertEquals(0, tl.validationDoneCount); + assertEquals(0, tl.iterationStartCount); + assertEquals(0, tl.iterationDoneCount); + assertEquals(Collections.singletonMap(Operation.INFERENCE, i), tl.operationStartCount); + assertEquals(Collections.singletonMap(Operation.INFERENCE, i), tl.operationEndCount); + assertEquals(3*i, tl.preOpExecutionCount); //mmul, add, softmax + assertEquals(3*i, tl.opExecutionCount); + assertEquals(3*i, tl.activationAvailableCount); //mmul, add, softmax outputs + assertEquals(0, tl.preUpdateCount); //Inference -> no updating + } + + //Check listener NOT called during inference when set to Operation.TRAINING + tl = new TestListener(Operation.TRAINING); + sd.setListeners(tl); + sd.outputSingle(phMap, "softmax"); + + assertEquals(0, tl.epochStartCount); + assertEquals(0, tl.epochEndCount); + assertEquals(0, tl.validationDoneCount); + assertEquals(0, tl.iterationStartCount); + assertEquals(0, tl.iterationDoneCount); + assertEquals(Collections.emptyMap(), tl.operationStartCount); + assertEquals(Collections.emptyMap(), tl.operationEndCount); + assertEquals(0, tl.preOpExecutionCount); + assertEquals(0, tl.opExecutionCount); + assertEquals(0, tl.activationAvailableCount); + assertEquals(0, tl.preUpdateCount); + + //Check listener called during gradient calculation + tl = new TestListener(Operation.TRAINING); + sd.setListeners(tl); + phMap = new HashMap<>(); + phMap.put("in", Nd4j.rand( DataType.FLOAT, 1, 4)); + phMap.put("label", Nd4j.createFromArray(0f, 1f, 0f).reshape(1, 3)); + + for( int i=1; i<=3; i++ ) { + sd.calculateGradients(phMap, "in", "w", "b"); + assertEquals(0, tl.epochStartCount); + assertEquals(0, tl.epochEndCount); + assertEquals(0, tl.validationDoneCount); + assertEquals(0, tl.iterationStartCount); + assertEquals(0, tl.iterationDoneCount); + assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationStartCount); + assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationEndCount); + assertEquals(7*i, tl.preOpExecutionCount); //mmul, add, softmax, loss grad, softmax backward, add backward, mmul backward + assertEquals(7*i, tl.opExecutionCount); + assertEquals(11*i, tl.activationAvailableCount); //mmul, add, softmax, loss grad (weight, in, label), softmax bp, add backward (z, b), mmul (in, w) + assertEquals(0, tl.preUpdateCount); + } + + + //Check listener NOT called during gradient calculation - when listener is still set to INFERENCE mode + tl = new TestListener(Operation.INFERENCE); + sd.setListeners(tl); + for( int i=1; i<=3; i++ ) { + sd.calculateGradients(phMap, "in", "w", "b"); + assertEquals(0, tl.epochStartCount); + assertEquals(0, tl.epochEndCount); + assertEquals(0, tl.validationDoneCount); + assertEquals(0, tl.iterationStartCount); + assertEquals(0, tl.iterationDoneCount); + assertEquals(Collections.emptyMap(), tl.operationStartCount); + assertEquals(Collections.emptyMap(), tl.operationEndCount); + assertEquals(0, tl.preOpExecutionCount); + assertEquals(0, tl.opExecutionCount); + assertEquals(0, tl.activationAvailableCount); + assertEquals(0, tl.preUpdateCount); + } + + //Check fit: + tl = new TestListener(Operation.TRAINING); + sd.setListeners(tl); + sd.setTrainingConfig(TrainingConfig.builder() + .dataSetFeatureMapping("in") + .dataSetLabelMapping("label") + .updater(new Adam(1e-3)) + .build()); + + SingletonDataSetIterator dsi = new SingletonDataSetIterator(new DataSet(phMap.get("in"), phMap.get("label"))); + for( int i=1; i<=3; i++ ) { + sd.fit(dsi, 1); + assertEquals(i, tl.epochStartCount); + assertEquals(i, tl.epochEndCount); + assertEquals(0, tl.validationDoneCount); + assertEquals(i, tl.iterationStartCount); + assertEquals(i, tl.iterationDoneCount); + assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationStartCount); + assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationEndCount); + assertEquals(7*i, tl.preOpExecutionCount); //mmul, add, softmax, loss grad, softmax backward, add backward, mmul backward + assertEquals(7*i, tl.opExecutionCount); + assertEquals(11*i, tl.activationAvailableCount); //mmul, add, softmax, loss grad (weight, in, label), softmax bp, add backward (z, b), mmul (in, w) + assertEquals(2*i, tl.preUpdateCount); //w, b + } + + + //Check evaluation: + tl = new TestListener(Operation.EVALUATION); + sd.setListeners(tl); + + for( int i=1; i<=3; i++ ) { + sd.evaluate(dsi, "softmax", new Evaluation()); + assertEquals(0, tl.epochStartCount); + assertEquals(0, tl.epochEndCount); + assertEquals(0, tl.validationDoneCount); + assertEquals(0, tl.iterationStartCount); + assertEquals(0, tl.iterationDoneCount); + assertEquals(Collections.singletonMap(Operation.EVALUATION, i), tl.operationStartCount); + assertEquals(Collections.singletonMap(Operation.EVALUATION, i), tl.operationEndCount); + assertEquals(3*i, tl.preOpExecutionCount); //mmul, add, softmax + assertEquals(3*i, tl.opExecutionCount); + assertEquals(3*i, tl.activationAvailableCount); //mmul, add, softmax + assertEquals(0, tl.preUpdateCount); //w, b + } + } + + private static class TestListener implements Listener { + + public TestListener(Operation operation){ + this.operation = operation; + } + + private final Operation operation; + + private int epochStartCount = 0; + private int epochEndCount = 0; + private int validationDoneCount = 0; + private int iterationStartCount = 0; + private int iterationDoneCount = 0; + private Map operationStartCount = new HashMap<>(); + private Map operationEndCount = new HashMap<>(); + private int preOpExecutionCount = 0; + private int opExecutionCount = 0; + private int activationAvailableCount = 0; + private int preUpdateCount = 0; + + + @Override + public ListenerVariables requiredVariables(SameDiff sd) { + return null; + } + + @Override + public boolean isActive(Operation operation) { + return this.operation == null || this.operation == operation; + } + + @Override + public void epochStart(SameDiff sd, At at) { + epochStartCount++; + } + + @Override + public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) { + epochEndCount++; + return ListenerResponse.CONTINUE; + } + + @Override + public ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis) { + validationDoneCount++; + return ListenerResponse.CONTINUE; + } + + @Override + public void iterationStart(SameDiff sd, At at, MultiDataSet data, long etlTimeMs) { + iterationStartCount++; + } + + @Override + public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) { + iterationDoneCount++; + } + + @Override + public void operationStart(SameDiff sd, Operation op) { + if(!operationStartCount.containsKey(op)) { + operationStartCount.put(op, 1); + } else { + operationStartCount.put(op, operationStartCount.get(op) + 1); + } + } + + @Override + public void operationEnd(SameDiff sd, Operation op) { + if(!operationEndCount.containsKey(op)) { + operationEndCount.put(op, 1); + } else { + operationEndCount.put(op, operationEndCount.get(op) + 1); + } + } + + @Override + public void preOpExecution(SameDiff sd, At at, SameDiffOp op) { + preOpExecutionCount++; + } + + @Override + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + opExecutionCount++; + } + + @Override + public void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName, INDArray activation) { + activationAvailableCount++; + } + + @Override + public void preUpdate(SameDiff sd, At at, Variable v, INDArray update) { + preUpdateCount++; + } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java index ffb7c319d..157fc5b46 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java @@ -109,6 +109,8 @@ public class FileReadWriteTests extends BaseNd4jTest { for (int i = 0; i < s.outputsLength(); i++) { outputs.add(s.outputs(i)); } + if(outputs.isEmpty()) + outputs = null; assertEquals(sd.outputs(), outputs); //Check variables diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java index 2411af627..44b465fd3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java @@ -63,7 +63,7 @@ public class UIListenerTest { Map m = new HashMap<>(); iter.reset(); m.put("in", iter.next().getFeatures()); - INDArray out = sd.execSingle(m, "softmax"); + INDArray out = sd.outputSingle(m, "softmax"); assertNotNull(out); assertArrayEquals(new long[]{150, 3}, out.shape()); } @@ -181,7 +181,8 @@ public class UIListenerTest { SameDiff sd2 = SameDiff.create(); SDVariable in1 = sd2.placeHolder("in1", DataType.FLOAT, -1, 4); SDVariable in2 = sd2.placeHolder("in2", DataType.FLOAT, -1, 4); - SDVariable mul = in1.mul(in2); + SDVariable w = sd2.var("w", DataType.FLOAT, 1, 4); + SDVariable mul = in1.mul(in2).mul(w); SDVariable loss = mul.std(true); sd2.setTrainingConfig(TrainingConfig.builder() .dataSetFeatureMapping("in") diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java index 219ccc19c..012ba3434 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java @@ -161,7 +161,6 @@ public class EvaluationCalibrationTest extends BaseNd4jTest { ec.eval(labels, arr); int[] expLabelCounts = labels.sum(0).data().asInt(); - // FIXME: int cast int[] expPredictionCount = new int[(int) labels.size(1)]; INDArray argmax = Nd4j.argMax(arr, 1); for (int i = 0; i < argmax.length(); i++) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java index 1f6301b2f..d320ad6e3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java @@ -75,9 +75,10 @@ public class ExecutionTests extends BaseNd4jTest { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraphTxt(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream(), null, null); + System.out.println(tg.summary()); - Map result_0 = tg.exec(Collections.emptyMap(), tg.outputs()); + Map result_0 = tg.outputAll(null); val exp_0 = Nd4j.create(DataType.FLOAT, 3).assign(3.0); assertEquals(exp_0, result_0.get("Sum")); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java index 68f09c864..08e4a959c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java @@ -23,7 +23,6 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.TrainingConfig; import org.nd4j.autodiff.samediff.transform.*; -import org.nd4j.base.Preconditions; import org.nd4j.graph.ui.LogFileWriter; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.imports.tensorflow.TFImportOverride; @@ -35,7 +34,6 @@ import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.learning.config.Adam; -import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.resources.Downloader; import org.nd4j.util.ArchiveUtils; @@ -109,7 +107,7 @@ public class BERTGraphTest extends BaseNd4jTest { //Skip the "IteratorV2" op - we don't want or need this TFOpImportFilter filter = (nodeDef, initWith, attributesForNode, graph) -> { return "IteratorV2".equals(nodeDef.getName()); }; - SameDiff sd = TFGraphMapper.getInstance().importGraph(f, m, filter); + SameDiff sd = TFGraphMapper.importGraph(f, m, filter); /* Modify the network to remove hard-coded dropout operations for inference. @@ -176,7 +174,7 @@ public class BERTGraphTest extends BaseNd4jTest { //Find pre-dropout input variable: SDVariable newOut = null; for(SDVariable v : inputs){ - if(v.getVarName().endsWith("/BiasAdd") || v.getVarName().endsWith("/Softmax") || v.getVarName().endsWith("/add_1") || v.getVarName().endsWith("/Tanh")){ + if(v.name().endsWith("/BiasAdd") || v.name().endsWith("/Softmax") || v.name().endsWith("/add_1") || v.name().endsWith("/Tanh")){ newOut = v; break; } @@ -251,7 +249,7 @@ public class BERTGraphTest extends BaseNd4jTest { placeholderValues.put("IteratorGetNext:1", mask); placeholderValues.put("IteratorGetNext:4", segmentIdxs); - Map out = sd.exec(placeholderValues, "loss/Softmax"); + Map out = sd.output(placeholderValues, "loss/Softmax"); INDArray softmax = out.get("loss/Softmax"); // System.out.println("OUTPUT - Softmax"); // System.out.println(softmax); @@ -317,7 +315,7 @@ public class BERTGraphTest extends BaseNd4jTest { //Skip the "IteratorV2" op - we don't want or need this TFOpImportFilter filter = (nodeDef, initWith, attributesForNode, graph) -> { return "IteratorV2".equals(nodeDef.getName()); }; - SameDiff sd = TFGraphMapper.getInstance().importGraph(f, m, filter); + SameDiff sd = TFGraphMapper.importGraph(f, m, filter); /* Set floatConstants = new HashSet<>(Arrays.asList( @@ -337,8 +335,8 @@ public class BERTGraphTest extends BaseNd4jTest { //For training, convert weights and biases from constants to variables: for(SDVariable v : sd.variables()){ - if(v.isConstant() && v.dataType().isFPType() && !v.getArr().isScalar() && !floatConstants.contains(v.getVarName())){ //Skip scalars - trainable params - log.info("Converting to variable: {} - dtype: {} - shape: {}", v.getVarName(), v.dataType(), Arrays.toString(v.getArr().shape())); + if(v.isConstant() && v.dataType().isFPType() && !v.getArr().isScalar() && !floatConstants.contains(v.name())){ //Skip scalars - trainable params + log.info("Converting to variable: {} - dtype: {} - shape: {}", v.name(), v.dataType(), Arrays.toString(v.getArr().shape())); v.convertToVariable(); } } @@ -395,14 +393,14 @@ public class BERTGraphTest extends BaseNd4jTest { placeholderValues.put("IteratorGetNext:4", segmentIdxs); placeholderValues.put("label", labelArr); - INDArray lossArr = sd.exec(placeholderValues, "loss").get("loss"); + INDArray lossArr = sd.output(placeholderValues, "loss").get("loss"); assertTrue(lossArr.isScalar()); double scoreBefore = lossArr.getDouble(0); for( int i=0; i<5; i++ ){ sd.fit(mds); } - lossArr = sd.exec(placeholderValues, "loss").get("loss"); + lossArr = sd.output(placeholderValues, "loss").get("loss"); assertTrue(lossArr.isScalar()); double scoreAfter = lossArr.getDouble(0); @@ -431,7 +429,7 @@ public class BERTGraphTest extends BaseNd4jTest { return "IteratorV2".equals(nodeDef.getName()); }; - SameDiff sd = TFGraphMapper.getInstance().importGraph(f, m, filter); + SameDiff sd = TFGraphMapper.importGraph(f, m, filter); LogFileWriter w = new LogFileWriter(new File("C:/Temp/BERT_UI.bin")); long bytesWritten = w.writeGraphStructure(sd); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java index ecc81c981..c57f8c5d9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java @@ -30,10 +30,13 @@ import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import org.nd4j.autodiff.execution.conf.OutputMode; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.listeners.Listener; -import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.InferenceSession; import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.internal.memory.ArrayCloseMemoryMgr; +import org.nd4j.autodiff.samediff.internal.memory.CloseValidationMemoryMgr; import org.nd4j.autodiff.validation.OpValidation; +import org.nd4j.autodiff.validation.TestCase; import org.nd4j.base.Preconditions; import org.nd4j.imports.TFGraphs.listener.OpExecOrderListener; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; @@ -62,6 +65,7 @@ import org.springframework.core.io.support.ResourcePatternResolver; import java.io.*; import java.net.URI; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.*; import java.util.regex.Pattern; @@ -84,7 +88,7 @@ public class TFGraphTestAllHelper { @Override public SameDiff apply(File file, String name) { try(InputStream is = new BufferedInputStream(new FileInputStream(file))){ - SameDiff sd = TFGraphMapper.getInstance().importGraph(is); + SameDiff sd = TFGraphMapper.importGraph(is); return sd; } catch (IOException e){ throw new RuntimeException(e); @@ -138,7 +142,19 @@ public class TFGraphTestAllHelper { " must be null or both must be provided"); Nd4j.EPS_THRESHOLD = 1e-3; - SameDiff graph = getGraphAfterExec(baseDir, modelFilename, modelName, inputs, execType, loader, null); + Set outputsToCheck = new HashSet<>(); + for(String s : predictions.keySet()) { + // we need to convert name from python name format with . on indices, to :. i.e.: output.1 -> output:1 + if (s.matches(".*\\.\\d+")) { + int idx = s.lastIndexOf('.'); + s = s.substring(0, idx) + ":" + s.substring(idx+1); + } + outputsToCheck.add(s); + } + + Pair> p = getGraphAfterExec(baseDir, modelFilename, modelName, inputs, execType, loader, null, outputsToCheck); + SameDiff graph = p.getFirst(); + Map sameDiffPredictions = p.getSecond(); //Collect coverage info about ops OpValidation.collectTensorflowImportCoverage(graph); @@ -156,7 +172,7 @@ public class TFGraphTestAllHelper { nd4jNode = outputNode.replaceAll("\\.", ":"); try { - nd4jPred = graph.getVariable(nd4jNode).getArr(); + nd4jPred = sameDiffPredictions.get(nd4jNode); } catch (NullPointerException e) { throw new NullPointerException("Can't find SameDiff variable with name [" + nd4jNode + "]"); } @@ -270,6 +286,12 @@ public class TFGraphTestAllHelper { log.info("\n========================================================\n"); } + //Serialize and deserialize, check equality: + ByteBuffer serialized = graph.asFlatBuffers(true); + Preconditions.checkNotNull(serialized, "Serialization failed? Null output"); + OpValidation.checkDeserializedEquality(graph, serialized, new TestCase(graph).testName(modelName).placeholderValues(inputs)); + + Nd4j.EPS_THRESHOLD = 1e-5; } @@ -285,7 +307,9 @@ public class TFGraphTestAllHelper { " must be null or both must be provided"); Nd4j.EPS_THRESHOLD = 1e-3; OpExecOrderListener listener = new OpExecOrderListener(); //Used to collect exec order - SameDiff graph = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener)); + Pair> p = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener), null); + SameDiff graph = p.getFirst(); + Map sdPredictions = p.getSecond(); //Collect coverage info about ops OpValidation.collectTensorflowImportCoverage(graph); @@ -313,7 +337,7 @@ public class TFGraphTestAllHelper { log.info("\n\tFORCING no check on " + varName); } else { assertArrayEquals("Shape not equal on node " + varName, tfValue.shape(), graph.getVariable(varName).getShape()); - INDArray sdVal = graph.getVariable(varName).getArr(); + INDArray sdVal = sdPredictions.get(varName); if(maxRelErrorOverride != null){ INDArray diff = Transforms.abs(tfValue.sub(sdVal), false); INDArray absErrorMask = diff.gte(minAbsErrorOverride); //value 1 if x[i] > minAbsError; value 0 otherwise. Used to get rid of 1e-30 vs. 1e-29 type failures @@ -362,30 +386,33 @@ public class TFGraphTestAllHelper { Nd4j.EPS_THRESHOLD = 1e-5; } - public static SameDiff getGraphAfterExec(String baseDir, String modelFilename, String modelName, Map inputs, - ExecuteWith executeWith, BiFunction graphLoaderFunction, List listeners) throws IOException { + public static Pair> getGraphAfterExec(String baseDir, String modelFilename, String modelName, Map inputs, + ExecuteWith executeWith, BiFunction graphLoaderFunction, List listeners, + Set requiredOutputs) throws IOException { log.info("\n\tRUNNING TEST " + modelName + "..."); SameDiff graph = graphLoaderFunction.apply(new ClassPathResource(baseDir + "/" + modelName + "/" + modelFilename).getFile(), modelName); if(listeners != null){ graph.setListeners(listeners); } -// = TFGraphMapper.getInstance().importGraph(new ClassPathResource(baseDir + "/" + modelName + "/" + modelFilename).getInputStream()); - //System.out.println(graph.summary()); + + if(requiredOutputs == null){ + requiredOutputs = graph.variableMap().keySet(); + } + + Map outMap = null; if (executeWith.equals(ExecuteWith.SAMEDIFF)) { - List outputs = graph.outputs(); - if(outputs.isEmpty()){ - //Edge case: no ops - List vars = graph.variables(); - outputs = new ArrayList<>(); - for(SDVariable v : vars) { - outputs.add(v.getVarName()); - } - } - if (!inputs.isEmpty()) { - graph.exec(inputs, outputs); //This is expected to be just one result - } else { - graph.exec(Collections.emptyMap(), outputs); //there are graphs with no placeholders like g_00 - } + //Set memory manager - check that all arrays (other than the ones we requested as output) + CloseValidationMemoryMgr mmgr = new CloseValidationMemoryMgr(graph, new ArrayCloseMemoryMgr()); + long tid = Thread.currentThread().getId(); + if(!graph.getSessions().containsKey(tid)) + graph.getSessions().put(tid, new InferenceSession(graph)); + //Execute + graph.getSessions().get(tid).setMmgr(mmgr); + outMap = graph.output(inputs, new ArrayList<>(requiredOutputs)); + + //Check that all arrays were released + mmgr.assertAllReleasedExcept(outMap.values()); + graph.getSessions().clear(); } else if (executeWith.equals(ExecuteWith.LIBND4J)) { for (String input : inputs.keySet()) { graph.associateArrayWithVariable(inputs.get(input), graph.variableMap().get(input)); @@ -396,7 +423,6 @@ public class TFGraphTestAllHelper { val executioner = new NativeGraphExecutioner(); val results = executioner.executeGraph(graph, configuration); - //graph.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/non2d_1.fb")); } else if (executeWith.equals(ExecuteWith.JUST_PRINT)) { for (String input : inputs.keySet()) { graph.associateArrayWithVariable(inputs.get(input), graph.variableMap().get(input)); @@ -405,7 +431,8 @@ public class TFGraphTestAllHelper { val string = graph.asFlatPrint(); log.info("Graph structure: \n{}", string); } - return graph; + + return new Pair<>(graph, outMap); } private static String[] modelDirNames(String base_dir, ExecuteWith executeWith, String modelFileName) throws IOException { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java index 1da31d863..3a39dac37 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java @@ -53,7 +53,7 @@ public class TFGraphTestList { public static String[] modelNames = new String[]{ // "cnn2d_nn/nhwc_b1_k12_s12_d12_SAME" - "cnn2d_layers/channels_last_b1_k2_s1_d1_SAME_elu" + "accumulate_n/rank0" }; @After diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java index 8429637fd..05edef2b8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java @@ -255,7 +255,7 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we OpValidationSuite.ignoreFailing(); } -// if(!modelName.startsWith("ssd")){ +// if(!modelName.startsWith("mobilenet_v2_1.0_224")){ // OpValidationSuite.ignoreFailing(); // } currentTestDir = testDir.newFolder(); @@ -282,9 +282,12 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we } //Libnd4j exec: + /* + //AB 2019/10/19 - Libnd4j execution disabled pending execution rewrite currentTestDir = testDir.newFolder(); log.info("----- Libnd4j Exec: {} -----", modelName); TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, TFGraphTestAllHelper.ExecuteWith.LIBND4J, LOADER, maxRE, minAbs); + */ } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java index 43a9b2911..77d3c759e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java @@ -105,11 +105,9 @@ public class ValidateZooModelPredictions extends BaseNd4jTest { //Perform inference List inputs = sd.inputs(); assertEquals(1, inputs.size()); - List outputs = sd.outputs(); - assertEquals(1, outputs.size()); - String out = outputs.get(0); - Map m = sd.exec(Collections.singletonMap(inputs.get(0), img), out); + String out = "MobilenetV1/Predictions/Softmax"; + Map m = sd.output(Collections.singletonMap(inputs.get(0), img), out); INDArray outArr = m.get(out); @@ -167,7 +165,7 @@ public class ValidateZooModelPredictions extends BaseNd4jTest { assertEquals(1, inputs.size()); String out = "softmax_tensor"; - Map m = sd.exec(Collections.singletonMap(inputs.get(0), img), out); + Map m = sd.output(Collections.singletonMap(inputs.get(0), img), out); INDArray outArr = m.get(out); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java index a501f9ff4..22b8b4492 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java @@ -39,7 +39,6 @@ import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.api.ops.impl.controlflow.If; import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; @@ -100,25 +99,25 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testSingleExample_1() { - val g =TFGraphMapper.getInstance().importGraph(new File("C:\\Users\\raver\\Downloads\\mnist.pb")); + val g = TFGraphMapper.importGraph(new File("C:\\Users\\raver\\Downloads\\mnist.pb")); val array = Nd4j.ones(1, 28, 28); g.associateArrayWithVariable(array, "flatten_1_input"); //g.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/mnist.fb"), ExecutorConfiguration.builder().outputMode(OutputMode.VARIABLE_SPACE).build()); - g.execAndEndResult(); + g.outputAll(null); } @Test public void testAssertImport_1() { - val graph = TFGraphMapper.getInstance().importGraph(new File("C:\\Users\\raver\\Downloads\\test.pb")); + val graph = TFGraphMapper.importGraph(new File("C:\\Users\\raver\\Downloads\\test.pb")); } @Test public void testArgMaxImport_2() throws Exception { - val graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("/tf_graphs/examples/reductions/argmax3,4,5_-1/frozen_graph.pbtxt").getInputStream()); + val graph = TFGraphMapper.importGraph(new ClassPathResource("/tf_graphs/examples/reductions/argmax3,4,5_-1/frozen_graph.pbtxt").getInputStream()); graph.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/argmax_macos.fb"), ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build(), true); @@ -127,78 +126,16 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testArgMaxImport_1() throws Exception { - val graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("/tf_graphs/argmax.pb.txt").getInputStream()); + val graph = TFGraphMapper.importGraph(new ClassPathResource("/tf_graphs/argmax.pb.txt").getInputStream()); log.info(graph.asFlatPrint()); - val result = graph.execAndEndResult(); + val result = graph.outputAll(null).get(graph.outputs().get(0)); val exp = Nd4j.createFromArray(new long[]{2, 2, 2}); assertEquals(exp, result); } - - @Test - public void testIfStatementNodes() throws Exception { - // /home/agibsonccc/code/dl4j-test-resources/src/main/resources/tf_graphs/examples/simple_cond/frozen_graph.pbtxt - val resourceInputStream = new ClassPathResource("/tf_graphs/examples/simple_cond/frozen_model.pb").getInputStream(); - val mapper = TFGraphMapper.getInstance(); - val readGraph = TFGraphMapper.getInstance().parseGraphFrom(resourceInputStream); - val nodes = mapper.nodesByName(readGraph); - /** - * Work backwards starting fom the condition id (usually a name containing condid/pred_id: - - */ - - val firstInput = nodes.get("cond5/Merge"); - val ifNodes = mapper.nodesForIf(firstInput,readGraph); - assertEquals(5,ifNodes.getFalseNodes().size()); - assertEquals(5,ifNodes.getTrueNodes().size()); - assertEquals(10,ifNodes.getCondNodes().size()); - - - val secondInput = nodes.get("cond6/Merge"); - val ifNodesTwo = mapper.nodesForIf(secondInput,readGraph); - assertEquals(5,ifNodesTwo.getFalseNodes().size()); - assertEquals(5,ifNodesTwo.getTrueNodes().size()); - assertEquals(6,ifNodesTwo.getCondNodes().size()); - - - val parentContext = SameDiff.create(); - val ifStatement = new If(); - ifStatement.initFromTensorFlow(firstInput,parentContext,Collections.emptyMap(),readGraph); - assertNotNull(ifStatement.getLoopBodyExecution()); - assertNotNull(ifStatement.getFalseBodyExecution()); - assertNotNull(ifStatement.getPredicateExecution()); - - } - - @Test - @Ignore - public void testIfIgnoreWhileMerge() throws Exception { - val resourceInputStream = new ClassPathResource("/tf_graphs/examples/simple_while/frozen_model.pb").getInputStream(); - val mapper = TFGraphMapper.getInstance(); - val readGraph = TFGraphMapper.getInstance().parseGraphFrom(resourceInputStream); - val nodes = mapper.nodesByName(readGraph); - val firstInput = nodes.get("output/Merge"); - assertNotNull(firstInput); - assertFalse(mapper.isOpIgnoreException(firstInput)); - - val resourceInputStreamIf = new ClassPathResource("/tf_graphs/examples/simple_cond/frozen_model.pb").getInputStream(); - val readGraphIf = TFGraphMapper.getInstance().parseGraphFrom(resourceInputStreamIf); - val nodesif = mapper.nodesByName(readGraphIf); - /** - * Work backwards starting fom the condition id (usually a name containing condid/pred_id: - - */ - - val secondInput = nodesif.get("cond5/Merge"); - assertNotNull(secondInput); - assertTrue(mapper.isOpIgnoreException(secondInput)); - - } - - @Test public void testHashEquality1() { long hash = HashUtil.getLongHash("Conv2D"); @@ -222,7 +159,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test @Ignore public void importGraph1() throws Exception { - SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/max_add_2.pb.txt").getInputStream()); + SameDiff graph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/max_add_2.pb.txt").getInputStream()); assertNotNull(graph); @@ -245,7 +182,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test @Ignore public void importGraph2() throws Exception { - SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensorflow_inception_graph.pb").getInputStream()); + SameDiff graph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensorflow_inception_graph.pb").getInputStream()); assertNotNull(graph); } @@ -254,7 +191,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test @Ignore public void importGraph3() throws Exception { - SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/max_log_reg.pb.txt").getInputStream()); + SameDiff graph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/max_log_reg.pb.txt").getInputStream()); assertNotNull(graph); } @@ -262,7 +199,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test @Ignore public void testImportIris() throws Exception { - SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/train_iris.pb").getInputStream()); + SameDiff graph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/train_iris.pb").getInputStream()); assertNotNull(graph); } @@ -271,7 +208,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test @Ignore public void importGraph4() throws Exception { - SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/max_multiply.pb.txt").getInputStream()); + SameDiff graph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/max_multiply.pb.txt").getInputStream()); assertNotNull(graph); @@ -285,7 +222,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { graph.var("Placeholder", p0); graph.var("Placeholder_1", p1); - val res = graph.execAndEndResult(); + val res = graph.outputAll(null).get(graph.outputs().get(0)); @@ -306,7 +243,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { val rawGraph = GraphDef.parseFrom(new ClassPathResource("tf_graphs/lenet_cnn.pb").getInputStream()); val nodeNames = rawGraph.getNodeList().stream().map(node -> node.getName()).collect(Collectors.toList()); System.out.println(nodeNames); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/lenet_cnn.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/lenet_cnn.pb").getInputStream()); val convNode = tg.getVariable("conv2d/kernel"); @@ -322,14 +259,14 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testIntermediate2() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/max_lstm.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/max_lstm.pb").getInputStream()); } @Test public void testIntermediate1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensorflow_inception_graph.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensorflow_inception_graph.pb").getInputStream()); assertTrue(tg.getVariable("input") != null); // assertTrue(tg.getVariableSpace().getVariable("input").isPlaceholder()); @@ -348,7 +285,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testIntermediateLoop1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/simple_while.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/simple_while.pb.txt").getInputStream()); assertNotNull(tg); @@ -363,7 +300,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test @Ignore public void testWeirdConvImport() { - val tg = TFGraphMapper.getInstance().importGraph(new File("/home/agibsonccc/code/raver_tfimport_test1/profiling_conv.pb.txt")); + val tg = TFGraphMapper.importGraph(new File("/home/agibsonccc/code/raver_tfimport_test1/profiling_conv.pb.txt")); assertNotNull(tg); } @@ -371,7 +308,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testIntermediateLoop3() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/nested_while.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/nested_while.pb.txt").getInputStream()); assertNotNull(tg); @@ -397,14 +334,14 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Ignore public void testIntermediateStridedSlice1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_slice.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_slice.pb.txt").getInputStream()); assertNotNull(tg); val constIn = tg.getVariable("StridedSlice/input"); assertNotNull(constIn); - val arr = tg.getArrForVarName(constIn.getVarName()); + val arr = tg.getArrForVarName(constIn.name()); assertEquals(139.5, arr.sumNumber().doubleValue(), 1e-5); @@ -473,7 +410,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Ignore public void testIntermediateTensorArraySimple1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array.pb.txt").getInputStream()); tg.setArrayForVariable("input_matrix",Nd4j.ones(3,2)); assertNotNull(tg); @@ -500,7 +437,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Ignore public void testIntermediateTensorArrayLoop1() throws Exception { val input = Nd4j.linspace(1, 10, 10, DataType.FLOAT).reshape(5, 2); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array_loop.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array_loop.pb.txt").getInputStream()); tg.setArrayForVariable("input_matrix",input); assertNotNull(tg); @@ -545,7 +482,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testIntermediateReduction() throws Exception { Nd4j.create(1); - SameDiff tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream()); + SameDiff tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream()); SDVariable sumResultVar = tg.getVariable("Sum"); /* val func = tg.getFunctionForVertexId(sumResultVar.getVertexId()); @@ -709,14 +646,13 @@ public class TensorFlowImportTest extends BaseNd4jTest { } */ - SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/bias_add/frozen_model.pb").getInputStream()); + SameDiff graph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/bias_add/frozen_model.pb").getInputStream()); assertNotNull(graph); INDArray input = Nd4j.linspace(1,40,40, DataType.FLOAT).reshape(10,4); INDArray expectedOutput = Nd4j.linspace(1,40,40, DataType.FLOAT).reshape(10,4).addRowVector(Nd4j.linspace(1,4,4, DataType.FLOAT)); - INDArray actual = graph.execSingle(Collections.singletonMap("input",input), graph.outputs().get(0)); + INDArray actual = graph.outputSingle(Collections.singletonMap("input",input), graph.outputs().get(0)); assertEquals(input,graph.getVariable("input").getArr()); - assertArrayEquals(input.shape(),graph.getShapeForVarName(graph.getVariable("input").getVarName())); assertEquals(expectedOutput,actual); } @@ -724,17 +660,17 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testImportMapping1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/ae_00/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/ae_00/frozen_model.pb").getInputStream()); val variables = new HashMap(); for (val var : tg.variables()) { - variables.put(var.getVarName(), var); + variables.put(var.name(), var); } val functions = new HashMap(); for (val func: tg.ops()) { val ownName = func.getOwnName(); - String outName = func.outputVariables()[0].getVarName(); + String outName = func.outputVariables()[0].name(); assertTrue("Missing ownName: [" + ownName +"]",variables.containsKey(ownName)); assertEquals(ownName, outName); @@ -744,7 +680,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testCondMapping1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()); assertNotNull(tg); tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simpleif_0_1.fb")); @@ -759,7 +695,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testCondMapping2() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()); assertNotNull(tg); val input = Nd4j.create(2, 2).assign(-1); @@ -767,7 +703,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simpleif_0.fb")); //log.info("{}", tg.asFlatPrint()); - val array = tg.execAndEndResult(); + val array = tg.outputAll(null).get(tg.outputs().get(0)); val exp = Nd4j.create(2, 2).assign(1); assertNotNull(array); assertEquals(exp, array); @@ -776,7 +712,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testWhileMapping1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()); assertNotNull(tg); val input = Nd4j.create(2, 2).assign(1); tg.associateArrayWithVariable(input, tg.getVariable("input_0")); @@ -786,7 +722,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { //log.info("{}", tg.asFlatPrint()); - val array = tg.execAndEndResult(); + val array = tg.outputAll(null).get(tg.outputs().get(0)); val exp = Nd4j.create(2, 2).assign(1); assertNotNull(array); assertEquals(exp, array); @@ -795,7 +731,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testWhileMapping2() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()); assertNotNull(tg); val input = Nd4j.scalar(4.0); tg.associateArrayWithVariable(input, tg.getVariable("input_1")); @@ -804,7 +740,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { //log.info("{}", tg.asFlatPrint()); /* - val array = tg.execAndEndResult(); + val array = tg.outputAll(null).get(tg.outputs().get(0)); val exp = Nd4j.create(2, 2).assign(2); assertNotNull(array); assertEquals(exp, array);*/ @@ -813,7 +749,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testWhileMapping3() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()); assertNotNull(tg); val input = Nd4j.scalar(9.0); tg.associateArrayWithVariable(input, tg.getVariable("input_1")); @@ -822,7 +758,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { //log.info("{}", tg.asFlatPrint()); - val array = tg.execAndEndResult(); + val array = tg.outputAll(null).get(tg.outputs().get(0)); val exp = Nd4j.create(2, 2).assign(4); assertNotNull(array); assertEquals(exp, array); @@ -832,7 +768,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testWhileDualMapping1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream()); assertNotNull(tg); val input0 = Nd4j.create(2, 2).assign(-4.0); val input1 = Nd4j.scalar(1.0); @@ -843,7 +779,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { //log.info("{}", tg.asFlatPrint()); - val array = tg.execAndEndResult(); + INDArray array = tg.outputAll(null).get(tg.outputs().get(0)); val exp = Nd4j.create(2, 2).assign(-1); assertNotNull(array); assertEquals(exp, array); @@ -852,7 +788,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testWhileDualMapping2() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream()); assertNotNull(tg); val input0 = Nd4j.create(2, 2).assign(-9.0); val input1 = Nd4j.scalar(1.0); @@ -863,7 +799,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { //log.info("{}", tg.asFlatPrint()); - val array = tg.execAndEndResult(); + val array = tg.outputAll(null).get(tg.outputs().get(0)); val exp = Nd4j.create(2, 2).assign(-3); assertNotNull(array); assertEquals(exp, array); @@ -873,7 +809,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testMixedWhileCond1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_nested/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_nested/frozen_model.pb").getInputStream()); assertNotNull(tg); val input0 = Nd4j.create(2, 2).assign(1.0); val input1 = Nd4j.create(3, 3).assign(2.0); @@ -885,7 +821,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { //log.info("{}", tg.asFlatPrint()); - val array = tg.execAndEndResult(); + Map m = tg.outputAll(null); + val array = m.get(tg.outputs().get(0)); //val array = tg.getVariable("output").getArr(); val exp = Nd4j.create(2, 2).assign(15.0); assertNotNull(array); @@ -896,7 +833,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Ignore public void testProfConv() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new File("/home/raver119/develop/workspace/models/profiling_conv.pb.txt")); + val tg = TFGraphMapper.importGraph(new File("/home/raver119/develop/workspace/models/profiling_conv.pb.txt")); assertNotNull(tg); tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/profiling_conv.fb")); @@ -907,7 +844,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testCrash_119_matrix_diag() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/partition_stitch_misc/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/partition_stitch_misc/frozen_model.pb").getInputStream()); assertNotNull(tg); val input0 = Nd4j.create(2, 5, 4).assign(1.0); @@ -926,7 +863,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testCrash_119_tensor_dot_misc() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/tensor_dot_misc/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/tensor_dot_misc/frozen_model.pb").getInputStream()); assertNotNull(tg); val input0 = Nd4j.create(36, 3, 4, 5).assign(1.0); @@ -943,7 +880,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testCrash_119_transpose() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/transpose/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/transpose/frozen_model.pb").getInputStream()); assertNotNull(tg); val input0 = Nd4j.create(new double[]{0.98114507, 0.96400015, 0.58669623, 0.60073098, 0.75425418, 0.44258752, 0.76373084, 0.96593234, 0.34067846}, new int[] {3, 3}); @@ -960,7 +897,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testCrash_119_simpleif_0() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()); assertNotNull(tg); val input0 = Nd4j.create(new float[] {1, 2, 3, 4}, new int[] {2, 2}); @@ -977,7 +914,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testCrash_119_ae_00() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/ae_00/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/ae_00/frozen_model.pb").getInputStream()); assertNotNull(tg); val input0 = Nd4j.create(new double[] {0.98174960, 0.44406342, 0.50100771, 1.00000000, -0.94038386, 0.46501783, -0.49040590, 0.98153842, -0.00198260, 0.49108310, -0.06085236, 0.93523693, -0.05857396, -0.46633510, -0.02806635, -0.96879626, -0.03938015, -0.51578135, -0.06333921, -1.00000000}, new int[] {5, 4}); @@ -992,7 +929,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testCrash_119_expand_dim() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream()); assertNotNull(tg); val input0 = Nd4j.create(new double[] {0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743}, new int[] {3, 4}); @@ -1007,7 +944,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testCrash_119_reduce_dim_false() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream()); assertNotNull(tg); @@ -1019,7 +956,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testCrash_119_reduce_dim_true() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/reduce_dim_true.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/reduce_dim_true.pb.txt").getInputStream()); assertNotNull(tg); tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/reduce_dim_true.fb"), ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build(), true); @@ -1027,11 +964,11 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testTensorArray_119_1() throws Exception { - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array.pb.txt").getInputStream()); assertNotNull(tg); val input_matrix = Nd4j.ones(3, 2); - val array = tg.execSingle(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0)); + val array = tg.outputSingle(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0)); val exp = Nd4j.create(new float[] {1, 1, 2, 2, 3, 3}, new int[]{3, 2}); @@ -1040,12 +977,12 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testTensorArray_119_2() throws Exception { - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array_read.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array_read.pb.txt").getInputStream()); assertNotNull(tg); val input_matrix = Nd4j.ones(3, 2); - val array = tg.exec(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0)); + val array = tg.output(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0)).get(tg.outputs().get(0)); val exp = Nd4j.create(new float[] {2, 2}, new int[]{2}); @@ -1057,10 +994,10 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testTensorArray_119_3() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array_unstack.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array_unstack.pb.txt").getInputStream()); assertNotNull(tg); - val array = tg.execSingle(Collections.emptyMap(), tg.outputs().get(0)); + val array = tg.outputSingle(Collections.emptyMap(), tg.outputs().get(0)); val exp = Nd4j.create(new float[] {5, 6, 7, 8}, new int[]{4}); @@ -1069,12 +1006,12 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testTensorArray_119_4() throws Exception { - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array_loop.pb.txt").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array_loop.pb.txt").getInputStream()); assertNotNull(tg); val input_matrix = Nd4j.linspace(1, 10, 10, DataType.FLOAT).reshape(5, 2); log.info("Graph: {}", tg.asFlatPrint()); - val array = tg.execSingle(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0)); + val array = tg.outputSingle(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0)); val exp = Nd4j.create(new float[] {3,6, 9,12, 15,18, 21,24, 27,30}, new int[]{5, 2}); @@ -1084,15 +1021,15 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testLossImport_1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/losses/log_loss_rank2_axis1_SUM_OVER_BATCH_SIZE/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/losses/log_loss_rank2_axis1_SUM_OVER_BATCH_SIZE/frozen_model.pb").getInputStream()); - tg.execAndEndResult(); + tg.outputAll(null); } @Test public void testG_1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/g_08/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/g_08/frozen_model.pb").getInputStream()); val g = tg.asFlatBuffers(true); } @@ -1101,9 +1038,9 @@ public class TensorFlowImportTest extends BaseNd4jTest { public void testBoolImport_1() throws Exception { Nd4j.create(1); for (int e = 0; e < 1000; e++){ - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/reduce_any/rank0/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/reduce_any/rank0/frozen_model.pb").getInputStream()); - Map result = tg.exec(Collections.emptyMap(), tg.outputs()); + Map result = tg.output(Collections.emptyMap(), tg.outputs()); assertNotNull(result); assertTrue(result.size() > 0); @@ -1113,9 +1050,9 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testLogical_1() throws Exception { Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/transforms/logicalxor_3,4_3,4/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/transforms/logicalxor_3,4_3,4/frozen_model.pb").getInputStream()); - tg.execAndEndResult(); + tg.outputAll(null); } @Test @@ -1123,18 +1060,18 @@ public class TensorFlowImportTest extends BaseNd4jTest { // tf_graphs/examples/ssd_inception_v2_coco_2018_01_28/frozen_inference_graph.pb Nd4j.create(1); - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/ssd_inception_v2_coco_2018_01_28/frozen_inference_graph.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/ssd_inception_v2_coco_2018_01_28/frozen_inference_graph.pb").getInputStream()); assertNotNull(tg); } @Test(expected = ND4JIllegalStateException.class) public void testNonFrozenGraph1() throws Exception { - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/unfrozen_simple_ae.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/unfrozen_simple_ae.pb").getInputStream()); } @Test public void testRandomGraph() throws Exception { - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/assert_equal/scalar_float32/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/assert_equal/scalar_float32/frozen_model.pb").getInputStream()); assertNotNull(tg); tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/scalar_float32.fb")); @@ -1142,7 +1079,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testRandomGraph2() throws Exception { - val tg = TFGraphMapper.getInstance().importGraph(new File("c:\\develop\\mobilenet_v2_1.0_224_frozen.pb")); + val tg = TFGraphMapper.importGraph(new File("c:\\develop\\mobilenet_v2_1.0_224_frozen.pb")); assertNotNull(tg); tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/mobilenet_v2.fb")); @@ -1151,7 +1088,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test @Ignore public void testRandomGraph3() throws Exception { - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/assert_equal/3,4_3,4_float32/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/assert_equal/3,4_3,4_float32/frozen_model.pb").getInputStream()); assertNotNull(tg); log.info("{}", tg.asFlatPrint()); @@ -1161,7 +1098,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test public void testControlDependencies1() throws Exception { - SameDiff sd = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/cond/cond_true/frozen_model.pb").getInputStream()); + SameDiff sd = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/cond/cond_true/frozen_model.pb").getInputStream()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java index 1035bda7a..06c85b289 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java @@ -103,7 +103,7 @@ public class ImportModelDebugger { File modelFile = new File("C:\\Temp\\TF_Graphs\\cifar10_gan_85\\tf_model.pb"); File rootDir = new File("C:\\Temp\\TF_Graphs\\cifar10_gan_85"); - SameDiff sd = TFGraphMapper.getInstance().importGraph(modelFile); + SameDiff sd = TFGraphMapper.importGraph(modelFile); ImportDebugListener l = ImportDebugListener.builder(rootDir) .checkShapesOnly(true) @@ -118,7 +118,7 @@ public class ImportModelDebugger { List outputs = sd.outputs(); - sd.exec(ph, outputs); + sd.output(ph, outputs); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 31d51d59a..bac06b981 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -5273,7 +5273,6 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray exp = Nd4j.linspace(0, 9, 10, DataType.DOUBLE); int cnt = 0; for (long i = matrix.rows() - 1; i >= 0; i--) { - // FIXME: int cast matrix.getRow((int) i).assign(cnt); cnt++; } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java index f88e78408..f1fdf9c57 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java @@ -342,7 +342,6 @@ public class ShufflesTests extends BaseNd4jTest { public float[] measureState(INDArray data) { // for 3D we save 0 element for each slice. - // FIXME: int cast float[] result = new float[(int) data.shape()[0]]; for (int x = 0; x < data.shape()[0]; x++) { @@ -390,7 +389,6 @@ public class ShufflesTests extends BaseNd4jTest { } public float[] measureState(INDArray data) { - // FIXME: int cast float[] result = new float[data.rows()]; for (int x = 0; x < data.rows(); x++) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/aggregates/AggregatesTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/aggregates/AggregatesTests.java deleted file mode 100644 index 91d80cd56..000000000 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/aggregates/AggregatesTests.java +++ /dev/null @@ -1,178 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.aggregates; - -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.OpValidationSuite; -import org.nd4j.linalg.BaseNd4jTest; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.aggregates.Aggregate; -import org.nd4j.linalg.api.ops.aggregates.impl.AggregateAxpy; -import org.nd4j.linalg.api.ops.aggregates.impl.AggregateSkipGram; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.factory.Nd4jBackend; - -import java.util.ArrayList; -import java.util.List; - -import static org.junit.Assert.assertEquals; - -/** - * @author raver119@gmail.com - */ -@RunWith(Parameterized.class) -public class AggregatesTests extends BaseNd4jTest { - - public AggregatesTests(Nd4jBackend backend) { - super(backend); - } - - @Before - public void setUp() { - //DataTypeUtil.setDTypeForContext(DataType.DOUBLE); - } - - @Test - public void testAggregate1() { - INDArray arrayX = Nd4j.ones(10); - INDArray arrayY = Nd4j.zeros(10); - - INDArray exp1 = Nd4j.ones(10); - - AggregateAxpy axpy = new AggregateAxpy(arrayX, arrayY, 1.0f); - - Nd4j.getExecutioner().exec(axpy); - - assertEquals(exp1, arrayY); - } - - - @Test - public void testBatchedAggregate1() { - OpValidationSuite.ignoreFailing(); //CRASHING - INDArray arrayX1 = Nd4j.ones(DataType.FLOAT, 10); - INDArray arrayY1 = Nd4j.zeros(DataType.FLOAT,10); - - INDArray arrayX2 = Nd4j.ones(DataType.FLOAT,10); - INDArray arrayY2 = Nd4j.zeros(DataType.FLOAT,10); - - INDArray exp1 = Nd4j.create(DataType.FLOAT,10).assign(1f); - INDArray exp2 = Nd4j.create(DataType.FLOAT,10).assign(1f); - - AggregateAxpy axpy1 = new AggregateAxpy(arrayX1, arrayY1, 1.0f); - AggregateAxpy axpy2 = new AggregateAxpy(arrayX2, arrayY2, 1.0f); - - List batch = new ArrayList<>(); - batch.add(axpy1); - batch.add(axpy2); - - Nd4j.getExecutioner().exec(batch); - - assertEquals(exp1, arrayY1); - assertEquals(exp2, arrayY2); - } - - @Test - public void testBatchedAggregate2() { - INDArray arrayX1 = Nd4j.ones(10); - INDArray arrayY1 = Nd4j.zeros(10).assign(2.0f); - - INDArray arrayX2 = Nd4j.ones(10); - INDArray arrayY2 = Nd4j.zeros(10).assign(2.0f); - - INDArray arrayX3 = Nd4j.ones(10); - INDArray arrayY3 = Nd4j.ones(10); - - INDArray exp1 = Nd4j.create(10).assign(4f); - INDArray exp2 = Nd4j.create(10).assign(3f); - INDArray exp3 = Nd4j.create(10).assign(3f); - - AggregateAxpy axpy1 = new AggregateAxpy(arrayX1, arrayY1, 2.0f); - AggregateAxpy axpy2 = new AggregateAxpy(arrayX2, arrayY2, 1.0f); - AggregateAxpy axpy3 = new AggregateAxpy(arrayX3, arrayY3, 2.0f); - - List batch = new ArrayList<>(); - batch.add(axpy1); - batch.add(axpy2); - batch.add(axpy3); - - Nd4j.getExecutioner().exec(batch); - - assertEquals(exp1, arrayY1); - assertEquals(exp2, arrayY2); - assertEquals(exp3, arrayY3); - } - - @Test - public void testBatchedSkipGram1() { - OpValidationSuite.ignoreFailing(); //CRASHING - INDArray syn0 = Nd4j.create(DataType.FLOAT, 10, 10).assign(0.01f); - INDArray syn1 = Nd4j.create(DataType.FLOAT,10, 10).assign(0.02f); - INDArray syn1Neg = Nd4j.ones(DataType.FLOAT,10, 10).assign(0.03f); - INDArray expTable = Nd4j.create(DataType.FLOAT,10000).assign(0.5f); - - double lr = 0.001; - - int idxSyn0_1 = 0; - int idxSyn0_2 = 3; - - INDArray expSyn0 = Nd4j.create(DataType.FLOAT,10).assign(0.01f); - INDArray expSyn1_1 = Nd4j.create(DataType.FLOAT,10).assign(0.020005); // gradient is 0.00005 - INDArray expSyn1_2 = Nd4j.create(DataType.FLOAT,10).assign(0.019995f); // gradient is -0.00005 - - - INDArray syn0row_1 = syn0.getRow(idxSyn0_1); - INDArray syn0row_2 = syn0.getRow(idxSyn0_2); - - AggregateSkipGram op1 = new AggregateSkipGram(syn0, syn1, syn1Neg, expTable, null, idxSyn0_1, new int[] {1, 2}, - new int[] {0, 1}, 0, 0, 10, lr, 1L, 10); - AggregateSkipGram op2 = new AggregateSkipGram(syn0, syn1, syn1Neg, expTable, null, idxSyn0_2, new int[] {4, 5}, - new int[] {0, 1}, 0, 0, 10, lr, 1L, 10); - - - List batch = new ArrayList<>(); - batch.add(op1); - batch.add(op2); - - Nd4j.getExecutioner().exec(batch); - - /* - Since expTable contains all-equal values, and only difference for ANY index is code being 0 or 1, syn0 row will stay intact, - because neu1e will be full of 0.0f, and axpy will have no actual effect - */ - assertEquals(expSyn0, syn0row_1); - assertEquals(expSyn0, syn0row_2); - - // syn1 row 1 modified only once - assertEquals(expSyn1_1, syn1.getRow(1)); - assertEquals(expSyn1_1, syn1.getRow(4)); - - // syn1 row 2 modified only once - assertEquals(expSyn1_2, syn1.getRow(2)); - assertEquals(expSyn1_2, syn1.getRow(5)); - } - - - @Override - public char ordering() { - return 'c'; - } -} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/aggregates/HierarchicSoftmaxTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/aggregates/HierarchicSoftmaxTests.java deleted file mode 100644 index 3f07b6a2a..000000000 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/aggregates/HierarchicSoftmaxTests.java +++ /dev/null @@ -1,472 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.aggregates; - -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.aggregates.impl.AggregateCBOW; -import org.nd4j.linalg.api.ops.aggregates.impl.AggregateSkipGram; -import org.nd4j.linalg.api.ops.aggregates.impl.HierarchicSoftmax; -import org.nd4j.linalg.api.ops.impl.nlp.CbowRound; -import org.nd4j.linalg.api.ops.impl.nlp.SkipGramRound; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.factory.Nd4jBackend; - -import java.util.Arrays; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; - -/** - * This tests pack covers simple gradient checks for AggregateSkipGram, CBOW and HierarchicSoftmax - * - * @author raver119@gmail.com - */ -@Slf4j -@RunWith(Parameterized.class) -public class HierarchicSoftmaxTests extends BaseNd4jTest { - - - public HierarchicSoftmaxTests(Nd4jBackend backend) { - super(backend); - } - - @Before - public void setUp() { - // DataTypeUtil.setDTypeForContext(DataType.DOUBLE); - } - - @Test - public void testHSGradient1() { - INDArray syn0 = Nd4j.ones(10, 10).assign(0.01f); - INDArray syn1 = Nd4j.ones(10, 10).assign(0.02f); - INDArray expTable = Nd4j.create(10000).assign(0.5f); - INDArray neu1e = Nd4j.create(10); - - INDArray expSyn0 = Nd4j.create(10).assign(0.01f); - INDArray expSyn1 = Nd4j.create(10).assign(0.020005); - INDArray expNeu1e = Nd4j.create(10).assign(0.00001f); - - int idxSyn0 = 1; - int idxSyn1 = 1; - int code = 0; - - double lr = 0.001; - - HierarchicSoftmax op = - new HierarchicSoftmax(syn0.getRow(idxSyn0), syn1.getRow(idxSyn1), expTable, neu1e, code, lr); - - Nd4j.getExecutioner().exec(op); - - INDArray syn0row = syn0.getRow(idxSyn0); - INDArray syn1row = syn1.getRow(idxSyn1); - - // expected gradient is 0.0005 - // expected neu1 = 0.00001 - // expected syn1 = 0.020005 - - assertEquals(expNeu1e, neu1e); - - assertEquals(expSyn1, syn1row); - - // we hadn't modified syn0 at all yet - assertEquals(expSyn0, syn0row); - } - - @Test - public void testSGGradient1() { - INDArray syn0 = Nd4j.create(DataType.DOUBLE, 10, 10).assign(0.01f); - INDArray syn1 = Nd4j.create(DataType.DOUBLE,10, 10).assign(0.02f); - INDArray syn1Neg = Nd4j.create(DataType.DOUBLE,10, 10).assign(0.03f); - INDArray expTable = Nd4j.create(DataType.DOUBLE,10000).assign(0.5f); - - double lr = 0.001; - - int idxSyn0 = 0; - - INDArray expSyn0 = Nd4j.create(DataType.DOUBLE,10).assign(0.01001f); - INDArray expSyn1_1 = Nd4j.create(DataType.DOUBLE,10).assign(0.020005); - - INDArray syn0row = syn0.getRow(idxSyn0); - - log.info("syn0row before: {}", Arrays.toString(syn0row.dup().data().asFloat())); - - AggregateSkipGram op = new AggregateSkipGram(syn0, syn1, syn1Neg, expTable, null, idxSyn0, new int[] {1}, - new int[] {0}, 0, 0, 10, lr, 1L, 10); - //Nd4j.getExecutioner().exec(op); - val sg = new SkipGramRound(idxSyn0, syn0, syn1, expTable, new int[] {1}, new byte[]{0}, lr, 1L, Nd4j.empty(syn0.dataType())); - Nd4j.getExecutioner().exec(sg); - - log.info("syn0row after: {}", Arrays.toString(syn0row.dup().data().asFloat())); - - assertEquals(expSyn0, syn0row); - assertEquals(expSyn1_1, syn1.getRow(1)); - } - - @Test - public void testSGGradient2() { - INDArray syn0 = Nd4j.create(10, 10).assign(0.01f); - INDArray syn1 = Nd4j.create(10, 10).assign(0.02f); - INDArray syn1Neg = Nd4j.ones(10, 10).assign(0.03f); - INDArray expTable = Nd4j.create(10000).assign(0.5f); - - double lr = 0.001; - - int idxSyn0 = 0; - - INDArray expSyn0 = Nd4j.create(10).assign(0.01f); - INDArray expSyn1_1 = Nd4j.create(10).assign(0.020005); // gradient is 0.00005 - INDArray expSyn1_2 = Nd4j.create(10).assign(0.019995f); // gradient is -0.00005 - - - INDArray syn0row = syn0.getRow(idxSyn0); - - - log.info("syn1row2 before: {}", Arrays.toString(syn1.getRow(2).dup().data().asFloat())); - - AggregateSkipGram op = new AggregateSkipGram(syn0, syn1, null, expTable, null, idxSyn0, new int[] {1, 2}, - new int[] {0, 1}, 0, 0, 10, lr, 1L, 10); - //Nd4j.getExecutioner().exec(op); - val sg = new SkipGramRound(idxSyn0, syn0, syn1, expTable, new int[] {1, 2}, new byte[]{0, 1}, lr, 1L, Nd4j.empty(syn0.dataType())); - Nd4j.getExecutioner().exec(sg); - - /* - Since expTable contains all-equal values, and only difference for ANY index is code being 0 or 1, syn0 row will stay intact, - because neu1e will be full of 0.0f, and axpy will have no actual effect - */ - assertEquals(expSyn0, syn0row); - - // syn1 row 1 modified only once - assertArrayEquals(expSyn1_1.data().asFloat(), syn1.getRow(1).dup().data().asFloat(), 1e-7f); - - log.info("syn1row2 after: {}", Arrays.toString(syn1.getRow(2).dup().data().asFloat())); - - // syn1 row 2 modified only once - assertArrayEquals(expSyn1_2.data().asFloat(), syn1.getRow(2).dup().data().asFloat(), 1e-7f); - } - - /** - * This particular test does nothing: neither HS or Neh is executed - * - * @throws Exception - */ - @Test - public void testSGGradientNoOp() { - INDArray syn0 = Nd4j.create(10, 10).assign(0.01f); - INDArray syn1 = Nd4j.create(10, 10).assign(0.02f); - INDArray syn1Neg = Nd4j.ones(10, 10).assign(0.03f); - INDArray expTable = Nd4j.create(10000).assign(0.5f); - INDArray table = null; - - double lr = 0.001; - - int idxSyn0 = 0; - INDArray expSyn0 = Nd4j.create(10).assign(0.01f); - INDArray expSyn1 = syn1.dup(); - - AggregateSkipGram op = new AggregateSkipGram(syn0, syn1, syn1Neg, expTable, table, idxSyn0, new int[] {}, - new int[] {}, 0, 0, 10, lr, 1L, 10); - - Nd4j.getExecutioner().exec(op); - - assertEquals(expSyn0, syn0.getRow(idxSyn0)); - assertEquals(expSyn1, syn1); - } - - @Test - public void testSGGradientNegative1() { - INDArray syn0 = Nd4j.create(10, 10).assign(0.01f); - INDArray syn1 = Nd4j.create(10, 10).assign(0.02f); - INDArray syn1Neg = Nd4j.ones(10, 10).assign(0.03f); - INDArray expTable = Nd4j.create(10000).assign(0.5f); - INDArray table = Nd4j.create(100000); - - double lr = 0.001; - - INDArray expSyn0 = Nd4j.create(10).assign(0.01f); - - int idxSyn0 = 1; - - log.info("syn0row1 after: {}", Arrays.toString(syn0.getRow(idxSyn0).dup().data().asFloat())); - - - AggregateSkipGram op = new AggregateSkipGram(syn0, syn1, syn1Neg, expTable, table, idxSyn0, new int[] {}, - new int[] {}, 1, 3, 10, lr, 2L, 10); - //Nd4j.getExecutioner().exec(op); - - val sg = new SkipGramRound(idxSyn0, 3, syn0, syn1Neg, expTable, table, 1, lr, 2L, Nd4j.empty(syn0.dataType())); - Nd4j.getExecutioner().exec(sg); - - log.info("syn0row1 after: {}", Arrays.toString(syn0.getRow(idxSyn0).dup().data().asFloat())); - - // we expect syn0 to be equal, since 2 rounds with +- gradients give the same output value for neu1e - assertEquals(expSyn0, syn0.getRow(idxSyn0)); - } - - - @Test - public void testCBOWGradient1() { - INDArray syn0 = Nd4j.create(10, 10).assign(0.01f); - INDArray syn1 = Nd4j.create(10, 10).assign(0.02f); - INDArray expTable = Nd4j.create(10000).assign(0.5f); - - double lr = 0.025; - - INDArray syn0row_before_0 = syn0.getRow(0).dup(); - INDArray syn0row_before_1 = syn0.getRow(1).dup(); - INDArray syn0row_before_2 = syn0.getRow(2).dup(); - - AggregateCBOW op = new AggregateCBOW(syn0, syn1, null, expTable, null, 0, new int[] {0, 1, 2}, new int[] {4, 5}, - new int[] {1, 1}, 0, 0, 10, lr, 2L, 10); - //Nd4j.getExecutioner().exec(op); - - val sg = new CbowRound(0, new int[] {0, 1, 2}, new int[] {0,0,0}, syn0, syn1, expTable, new int[] {4, 5}, new byte[]{1, 1}, lr, 2L, Nd4j.empty(syn0.dataType()), 1); - Nd4j.getExecutioner().exec(sg); - - INDArray syn0row_0 = syn0.getRow(0); - INDArray syn0row_1 = syn0.getRow(1); - INDArray syn0row_2 = syn0.getRow(2); - - INDArray syn1row_4 = syn1.getRow(4); - INDArray syn1row_5 = syn1.getRow(5); - INDArray syn1row_6 = syn1.getRow(6); - - INDArray expSyn0row_0 = Nd4j.create(10).assign(0.0095f); - INDArray expSyn1row_4 = Nd4j.create(10).assign(0.019875f); - INDArray expSyn1row_6 = Nd4j.create(10).assign(0.02f); - - assertNotEquals(syn0row_before_0, syn0row_0); - assertNotEquals(syn0row_before_1, syn0row_1); - assertNotEquals(syn0row_before_2, syn0row_2); - - // neu1 is expected to be 0.01 - // dot is expected to be 0.002 - // g is expected -0.0125 for both rounds: both codes are 1, so (1 - 1 - 0.5) * 0.025 = -0.0125 - // neu1e is expected to be -0.00025 after first round ( g * syn1 + neu1e) (-0.0125 * 0.02 + 0.000) - // neu1e is expected to be -0.00050 after second round (-0.0125 * 0.02 + -0.00025) - // syn1 is expected to be 0.019875 after first round (g * neu1 + syn1) (-0.0125 * 0.01 + 0.02 ) - // syn1 is expected to be 0.019875 after second round (g * neu1 + syn1) (-0.0125 * 0.01 + 0.02 ) NOTE: each of round uses it's own syn1 index - - // syn0 is expected to be 0.0095f after op (syn0 += neu1e) (0.01 += -0.0005) - - log.info("syn1row4[0]: {}", syn1row_4.getFloat(0)); - - assertEquals(expSyn0row_0, syn0row_0); - assertEquals(expSyn0row_0, syn0row_1); - assertEquals(expSyn0row_0, syn0row_2); - - assertEquals(expSyn1row_4, syn1row_4); - assertEquals(expSyn1row_4, syn1row_5); - assertEquals(expSyn1row_6, syn1row_6); - - } - - @Test - public void testCBOWGradientNoOp1() { - INDArray syn0 = Nd4j.create(10, 10).assign(0.01f); - INDArray syn1 = Nd4j.create(10, 10).assign(0.02f); - INDArray syn1Neg = Nd4j.ones(10, 10).assign(0.03f); - INDArray expTable = Nd4j.create(10000).assign(0.5f); - INDArray table = Nd4j.create(100000); - - double lr = 0.025; - - INDArray expSyn0 = syn0.dup(); - INDArray expSyn1 = syn1.dup(); - INDArray expSyn1Neg = syn1Neg.dup(); - - AggregateCBOW op = new AggregateCBOW(syn0, syn1, syn1Neg, expTable, table, 0, new int[] {}, new int[] {}, - new int[] {}, 0, 0, 10, lr, 2L, 10); - - Nd4j.getExecutioner().exec(op); - - assertEquals(expSyn0, syn0); - assertEquals(expSyn1, syn1); - assertEquals(expSyn1Neg, syn1Neg); - } - - @Test - public void testCBOWGradientNegative1() { - INDArray syn0 = Nd4j.create(10, 10).assign(0.01f); - INDArray syn1 = Nd4j.create(10, 10).assign(0.02f); - INDArray syn1Neg = Nd4j.create(10, 10).assign(0.03f); - INDArray expTable = Nd4j.create(10000).assign(0.5f); - INDArray table = Nd4j.create(100000); - - double lr = 0.025; - - INDArray syn0dup = syn0.dup(); - INDArray syn1dup = syn1.dup(); - INDArray syn1NegDup = syn1Neg.dup(); - - INDArray expSyn0_row0 = Nd4j.create(10).assign(0.0096265625); - INDArray expSyn0_row3 = Nd4j.create(10).assign(0.01f); - INDArray expSyn1Neg_row6 = Nd4j.create(10).assign(0.030125f); - - AggregateCBOW op = new AggregateCBOW(syn0, syn1, syn1Neg, expTable, table, 0, new int[] {0, 1, 2}, new int[] {}, new int[] {}, 2, 6, 10, lr, 2L, 10); - //Nd4j.getExecutioner().exec(op); - - val sg = new CbowRound(0, new int[]{0, 1, 2}, new int[] {0, 0, 0}, 6, syn0, syn1Neg, expTable, table, 2, lr, 2L, Nd4j.empty(syn0.dataType()), 1); - Nd4j.getExecutioner().exec(sg); - - - assertNotEquals(syn0dup, syn0); - assertNotEquals(syn1NegDup, syn1Neg); - assertEquals(syn1dup, syn1); - - // neu1 is expected to be 0.01 - // dot is expected to be 0.003 (dot += 0.01 * 0.03) for round 1 & 2. - // dot is expected to be 0.002987 for round 3 (because syn1Neg for idx 8 is modified at round 2) - // g is expected to be 0.0125 for the first round (code is 1) - // g is expected to be -0.0125 for the second round (code is 0) - // g is expected to be -0.0125 for the third round (code is 0) - // neu1e is expected to be 0.000375 after first round (0.0125 * 0.03 + 0.00) - // neu1e is expected to be 0.00 after second round (-0.0125 * 0.03 + 0.000375) - // neu1e is expected to be -0.0003734375 after third round (-0.0125 * 0.029875 + 0.00) - // syn1Neg idx6 is expected to be 0.030125 after first round (0.0125 * 0.01 + 0.03) - // syn1Neg idx8 is expected to be 0.029875 after second round (-0.0125 * 0.01 + 0.03) - // syn1Neg idx8 is expected to be 0.02975 after third round (-0.0125 * 0.01 + 0.029875) - // syn0 idx0 is expected to be 0.00 after training (0.01 += -0.0003734375) - - log.info("syn1neg_row6 after: {}", Arrays.toString(syn1Neg.getRow(6).dup().data().asFloat())); - - // checking target first - assertEquals(expSyn1Neg_row6, syn1Neg.getRow(6)); - - assertEquals(expSyn0_row0, syn0.getRow(0)); - assertEquals(expSyn0_row0, syn0.getRow(1)); - assertEquals(expSyn0_row0, syn0.getRow(2)); - - // these rows shouldn't change - assertEquals(expSyn0_row3, syn0.getRow(3)); - assertEquals(expSyn0_row3, syn0.getRow(4)); - assertEquals(expSyn0_row3, syn0.getRow(5)); - assertEquals(expSyn0_row3, syn0.getRow(6)); - assertEquals(expSyn0_row3, syn0.getRow(7)); - assertEquals(expSyn0_row3, syn0.getRow(8)); - assertEquals(expSyn0_row3, syn0.getRow(9)); - } - - - @Test - public void testCBOWInference1() { - INDArray syn0 = Nd4j.create(10, 10).assign(0.01f); - INDArray syn1 = Nd4j.create(10, 10).assign(0.02f); - INDArray syn1Neg = Nd4j.create(10, 10).assign(0.03f); - INDArray expTable = Nd4j.create(10000).assign(0.5f); - INDArray table = Nd4j.create(100000); - - double lr = 0.025; - - INDArray syn0dup = syn0.dup(); - INDArray syn1dup = syn1.dup(); - INDArray syn1NegDup = syn1Neg.dup(); - - INDArray inference = Nd4j.create(10).assign(0.04f); - INDArray dup = inference.dup(); - INDArray expInference = Nd4j.create(10).assign(0.0395f); - - log.info("Empty vector: {}", Arrays.toString(inference.data().asFloat())); - - /* - surrounding words are 0 and 1 - */ - AggregateCBOW op = new AggregateCBOW(syn0, syn1, null, expTable, null, 0, new int[] {0, 1}, new int[] {4, 5}, - new int[] {1, 1}, 0, 0, 10, lr, 2L, 10, 0, false, inference); - - Nd4j.getExecutioner().exec(op); - - /* - syn0, syn1 and syn1Neg should stay intact during inference - */ - assertEquals(syn0dup, syn0); - assertEquals(syn1dup, syn1); - assertEquals(syn1NegDup, syn1Neg); - - /** - * neu1 is expected to be 0.02 - * syn1 is expected to be 0.02 - * dot is expected to be 0.04 ( 0.02 * 0.02 * 10) - * g is expected to be -0.0125 for BOTH rounds, since we're not changing syn1 values during inference - * neu1e is expected to be -0.00025 at first round (-0.0125 * 0.02 + 0.00) - * neu1e is expected to be -0.0005 at second round (-0.0125 * 0.02 + -0.00025) - * inference is expected to be 0.0395 after training (0.04 + -0.0005) - */ - - assertNotEquals(dup, inference); - - log.info("Inferred vector: {}", Arrays.toString(inference.data().asFloat())); - - assertEquals(expInference, inference); - - } - - @Test - public void testSGInference1() { - INDArray syn0 = Nd4j.create(10, 10).assign(0.01f); - INDArray syn1 = Nd4j.create(10, 10).assign(0.02f); - INDArray syn1Neg = Nd4j.create(10, 10).assign(0.03f); - INDArray expTable = Nd4j.create(10000).assign(0.5f); - INDArray table = Nd4j.create(100000); - - double lr = 0.025; - - INDArray syn0dup = syn0.dup(); - INDArray syn1dup = syn1.dup(); - INDArray syn1NegDup = syn1Neg.dup(); - - INDArray inference = Nd4j.create(10).assign(0.04f); - INDArray dup = inference.dup(); - INDArray expInference = Nd4j.create(10).assign(0.0395f); - - AggregateSkipGram op = new AggregateSkipGram(syn0, syn1, syn1Neg, expTable, null, 0, new int[] {1, 2}, - new int[] {1, 1}, 0, 0, 10, lr, 1L, 10, inference); - - Nd4j.getExecutioner().exec(op); - - /* - syn0, syn1 and syn1Neg should stay intact during inference - */ - assertEquals(syn0dup, syn0); - assertEquals(syn1dup, syn1); - assertEquals(syn1NegDup, syn1Neg); - - assertNotEquals(dup, inference); - - /** - * dot is expected to be 0.008 for both rounds - * g is expected to be -0.0125 for both rounds, since we don't update syn0/syn1 before end of SG round - * neu1e is expected to be -0.00025 after first round (-0.0125 * 0.02 + 0.00) - * neu1e is expected to be -0.0005 after first round (-0.0125 * 0.02 + -0.00025) - * inferenceVector is expected to be 0.0395 after training (0.04 + -0.0005) - */ - - assertEquals(expInference, inference); - } - - @Override - public char ordering() { - return 'c'; - } -} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreationUtil.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreationUtil.java index 907bdd04a..2dc595f54 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreationUtil.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreationUtil.java @@ -38,31 +38,26 @@ public class TestNDArrayCreationUtil extends BaseNd4jTest { @Test public void testShapes() { - // FIXME: int cast long[] shape2d = {2, 3}; for (Pair p : NDArrayCreationUtil.getAllTestMatricesWithShape(2, 3, 12345, DataType.DOUBLE)) { assertArrayEquals(p.getSecond(), shape2d, p.getFirst().shape()); } - // FIXME: int cast long[] shape3d = {2, 3, 4}; for (Pair p : NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, shape3d, DataType.DOUBLE)) { assertArrayEquals(p.getSecond(), shape3d, p.getFirst().shape()); } - // FIXME: int cast long[] shape4d = {2, 3, 4, 5}; for (Pair p : NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, ArrayUtil.toInts(shape4d), DataType.DOUBLE)) { assertArrayEquals(p.getSecond(), shape4d, p.getFirst().shape()); } - // FIXME: int cast long[] shape5d = {2, 3, 4, 5, 6}; for (Pair p : NDArrayCreationUtil.getAll5dTestArraysWithShape(12345, ArrayUtil.toInts(shape5d), DataType.DOUBLE)) { assertArrayEquals(p.getSecond(), shape5d, p.getFirst().shape()); } - // FIXME: int cast long[] shape6d = {2, 3, 4, 5, 6, 7}; for (Pair p : NDArrayCreationUtil.getAll6dTestArraysWithShape(12345, ArrayUtil.toInts(shape6d), DataType.DOUBLE)) { assertArrayEquals(p.getSecond(), shape6d, p.getFirst().shape()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTests.java index 09350092c..e04a41714 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTests.java @@ -1304,7 +1304,7 @@ public class ConvolutionTests extends BaseNd4jTest { @Test public void testConvOutWidthAndHeight() { - int outSize = Convolution.outSize(2, 1, 1, 2, 1, false); + long outSize = Convolution.outSize(2, 1, 1, 2, 1, false); assertEquals(6, outSize); } /* diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTestsC.java index 46d93fff5..cd7ef26b6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTestsC.java @@ -55,7 +55,7 @@ public class ConvolutionTestsC extends BaseNd4jTest { @Test public void testConvOutWidthAndHeight() { - int outSize = Convolution.outSize(2, 1, 1, 2, 1, false); + long outSize = Convolution.outSize(2, 1, 1, 2, 1, false); assertEquals(6, outSize); } @@ -415,14 +415,13 @@ public class ConvolutionTestsC extends BaseNd4jTest { int outH = (int)Math.ceil(input.size(2)/(double)s[0]); int outW = (int)Math.ceil(input.size(3)/(double)s[1]); - // FIXME: int cast - int totalPadH = (outH-1)*s[0] + k[0] - (int) input.size(2); - int totalPadW = (outW-1)*s[1] + k[1] - (int) input.size(3); + long totalPadH = (outH-1)*s[0] + k[0] - input.size(2); + long totalPadW = (outW-1)*s[1] + k[1] - input.size(3); - int topPad = totalPadH/2; - int bottomPad = totalPadH - topPad; - int leftPad = totalPadW/2; - int rightPad = totalPadW - leftPad; + long topPad = totalPadH/2; + long bottomPad = totalPadH - topPad; + long leftPad = totalPadW/2; + long rightPad = totalPadW - leftPad; INDArray outGrad = Nd4j.create(input.shape()); @@ -432,10 +431,10 @@ public class ConvolutionTestsC extends BaseNd4jTest { for( int x=0; x max){ max = v; - maxPos = new int[]{kTLy + kY, kTLx + kX}; + maxPos = new long[]{kTLy + kY, kTLx + kX}; } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java new file mode 100644 index 000000000..c1e5a8704 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java @@ -0,0 +1,93 @@ +package org.nd4j.linalg.convolution; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.CustomOp; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.linalg.ops.transforms.Transforms; +import org.nd4j.resources.Resources; + +import java.io.File; +import java.util.*; + +import static org.junit.Assert.*; + +public class DeconvTests extends BaseNd4jTest { + + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + + public DeconvTests(Nd4jBackend backend) { + super(backend); + } + + @Override + public char ordering() { + return 'c'; + } + + @Test + public void compareKeras() throws Exception { + File f = testDir.newFolder(); + Resources.copyDirectory("keras/deconv", f); + + File[] files = f.listFiles(); + + Set tests = new HashSet<>(); + for(File file : files){ + String n = file.getName(); + if(!n.startsWith("mb")) + continue; + + int idx = n.lastIndexOf('_'); + String name = n.substring(0, idx); + tests.add(name); + } + + List l = new ArrayList<>(tests); + Collections.sort(l); + assertFalse(l.isEmpty()); + + for(String s : l){ + String s2 = s.replaceAll("[a-zA-Z]", ""); + String[] nums = s2.split("_"); + int mb = Integer.parseInt(nums[0]); + int k = Integer.parseInt(nums[1]); + int size = Integer.parseInt(nums[2]); + int stride = Integer.parseInt(nums[3]); + boolean same = s.contains("same"); + int d = Integer.parseInt(nums[5]); + boolean nchw = s.contains("nchw"); + + INDArray w = Nd4j.readNpy(new File(f, s + "_W.npy")); + INDArray b = Nd4j.readNpy(new File(f, s + "_b.npy")); + INDArray in = Nd4j.readNpy(new File(f, s + "_in.npy")).castTo(DataType.FLOAT); + INDArray expOut = Nd4j.readNpy(new File(f, s + "_out.npy")); + + CustomOp op = DynamicCustomOp.builder("deconv2d") + .addInputs(in, w, b) + .addIntegerArguments( + k, k, + stride, stride, + 0, 0, //padding + d, d, + same ? 1 : 0, + nchw ? 0 : 1) + .callInplace(false) + .build(); + INDArray out = Nd4j.create(op.calculateOutputShape().get(0)); + out.assign(Double.NaN); + op.addOutputArgument(out); + Nd4j.exec(op); + + boolean eq = expOut.equalsWithEps(out, 1e-4); + assertTrue(eq); + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java index c3a66200b..c6efda9b0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java @@ -88,8 +88,7 @@ public class CrashTest extends BaseNd4jTest { INDArray y = Nd4j.create(64, 64, 1024); for (int i = 0; i < ITERATIONS; i++) { - // FIXME: int cast - int slice = RandomUtils.nextInt(0, (int) x.shape()[0]); + long slice = RandomUtils.nextLong(0, x.shape()[0]); op(x.slice(slice), y.slice(slice), i); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index ded23f810..ad38f39d7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -26,8 +26,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.custom.Flatten; -import org.nd4j.linalg.api.ops.custom.ScatterUpdate; +import org.nd4j.linalg.api.ops.custom.*; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpStatus; import org.nd4j.linalg.api.ops.impl.reduce.Mmul; @@ -807,4 +806,129 @@ public class CustomOpsTests extends BaseNd4jTest { Nd4j.getExecutioner().commit(); } + + @Test + public void testAdjustContrast() { + INDArray in = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 4*4*3).reshape(4,4,3); + INDArray out = Nd4j.zeros(DataType.DOUBLE,4, 4, 3); + + INDArray expected = Nd4j.createFromArray(new double[]{-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5, + 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5, + 26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5, + 50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5 + }).reshape(4,4,3); + Nd4j.exec(new AdjustContrast(in, 2.0, out)); + + assertArrayEquals(out.shape(), in.shape()); + assertEquals(expected, out); + } + + @Test + public void testAdjustContrastV2() { + INDArray in = Nd4j.linspace(DataType.DOUBLE,1.0,1.0, 4*4*3).reshape(4,4,3); + INDArray out = Nd4j.createUninitialized(4,4,3); + + INDArray expected = Nd4j.createFromArray(new double[]{-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5, + 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5, + 26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5, + 50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5 + }).reshape(4,4,3); + + Nd4j.exec(new AdjustContrastV2(in, 2.0, out)); + + assertArrayEquals(out.shape(), in.shape()); + assertEquals(expected, out); + } + + @Test + public void testBitCast() { + INDArray in = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 8).reshape(2,2,2); + INDArray out = Nd4j.createUninitialized(2,2); + + Nd4j.exec(new BitCast(in, DataType.DOUBLE.toInt(), out)); + + INDArray expected = Nd4j.createFromArray(new double[]{2., 512., 8192., 131072.032 }).reshape(2,2); + assertArrayEquals(new long[]{2,2}, out.shape()); + assertEquals(expected, out); + } + + @Test + public void testCompareAndBitpack() { + INDArray in = Nd4j.createFromArray(new double[]{-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, + -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}).reshape( 2,3,4); + INDArray out = Nd4j.createUninitialized(DataType.UBYTE, 2,3,4); + INDArray expected = Nd4j.createFromArray(new byte[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}). + reshape(2,3,4); + + Nd4j.exec(new CompareAndBitpack(in ,2.0, out)); + assertArrayEquals(new long[]{2,3,4}, out.shape()); + } + + @Test + public void testDivideNoNan() { + INDArray in1 = Nd4j.rand(DataType.DOUBLE, 2,3,4); + INDArray in2 = Nd4j.rand(DataType.DOUBLE, 2,3,4); + INDArray out = Nd4j.createUninitialized(DataType.DOUBLE, 2,3,4); + + Nd4j.exec(new DivideNoNan(in1, in2, out)); + assertArrayEquals(new long[]{2,3,4}, out.shape()); + } + + @Test + public void testDrawBoundingBoxes() { + INDArray images = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 2*4*5*3).reshape(2,4,5,3); + INDArray boxes = Nd4j.createFromArray(new float[]{ 0.0f , 0.0f , 1.0f , 1.0f, + 0.1f, 0.2f, 0.9f, 0.8f, + 0.3f, 0.3f, 0.7f, 0.7f, + 0.4f, 0.4f, 0.6f, 0.6f}).reshape(2,2,4); + INDArray colors = Nd4j.createFromArray(new float[]{ + 201.0f, 202.0f, 203.0f, 127.0f, 128.0f, 129.0f}). + reshape(2,3); + INDArray output = Nd4j.create(DataType.FLOAT, images.shape()); + INDArray expected = Nd4j.createFromArray(new float[]{127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, + 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, + 127.f, 128.f, 129.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, + 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, + 201.f, 202.f, 203.f, 201.f ,202.f ,203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, + + 61.f, 62.f, 63.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 70.f, 71.f, 72.f, 73.f, 74.f, 75.f, + 76.f, 77.f, 78.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, + 91.f, 92.f, 93.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 100.f, 101.f, 102.f, 103.f, 104.f, 105.f, + 106.f, 107.f, 108.f, 109.f, 110.f, 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f}). + reshape(2,4,5,3); + + Nd4j.exec(new DrawBoundingBoxes(images, boxes, colors, output)); + + assertArrayEquals(images.shape(), output.shape()); + assertEquals(expected, output); + } + + @Test + public void FakeQuantWithMinMaxVarsPerChannel() { + + INDArray x = Nd4j.createFromArray(new float[]{-63.80f, -63.75f, -63.4f, -63.5f, 0.0f, 0.1f}). + reshape(1,2,3,1); + + INDArray min = Nd4j.createFromArray(new float[]{-63.65f}); + INDArray max = Nd4j.createFromArray(new float[]{0.1f}); + + INDArray output = Nd4j.createUninitialized(DataType.FLOAT, 1,2,3,1); + INDArray expected = Nd4j.createFromArray(new float[]{-63.75f, -63.75f, -63.5f, -63.5f, 0.f, 0.f}). + reshape(1,2,3,1); + + Nd4j.exec(new FakeQuantWithMinMaxVarsPerChannel(x,min,max,output)); + + assertEquals(expected, output); + } + + @Test + public void testKnnMinDistance() { + INDArray point = Nd4j.rand(DataType.FLOAT, 1, 20); + INDArray lowest = Nd4j.rand(DataType.FLOAT, 1, 20); + INDArray highest = Nd4j.rand(DataType.FLOAT, 1, 20); + INDArray distance = Nd4j.scalar(0.f); + + Nd4j.exec(new KnnMinDistance(point, lowest, highest, distance)); + System.out.println(distance); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java index 1f8547a27..6dfe1e4c7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java @@ -20,11 +20,13 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.preprocessor.ImageMultiPreProcessingScaler; import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.ops.transforms.Transforms; import static org.junit.Assert.assertEquals; @@ -42,32 +44,32 @@ public class ImagePreProcessortTest extends BaseNd4jTest { @Test public void simpleImageTest() { - INDArray rChannels = Nd4j.zeros(10, 10).addi(128); - INDArray gChannels = Nd4j.zeros(10, 10).addi(64); - INDArray bChannels = Nd4j.zeros(10, 10).addi(255); - INDArray image = Nd4j.vstack(rChannels, gChannels, bChannels).reshape(3, 10, 10); + INDArray rChannels = Nd4j.zeros(DataType.FLOAT, 10, 10).addi(128); + INDArray gChannels = Nd4j.zeros(DataType.FLOAT, 10, 10).addi(64); + INDArray bChannels = Nd4j.zeros(DataType.FLOAT, 10, 10).addi(255); + INDArray image = Nd4j.vstack(rChannels, gChannels, bChannels).reshape(1, 3, 10, 10); INDArray orig = image.dup(); //System.out.println(Arrays.toString(image.shape())); - DataSet ds = new DataSet(image.reshape(1, 3, 10, 10), Nd4j.ones(1, 1)); + DataSet ds = new DataSet(image, Nd4j.ones(1, 1)); ImagePreProcessingScaler myScaler = new ImagePreProcessingScaler(); //So this should scale to 0.5,0.25 and 1; INDArray expected = image.mul(0); - expected.slice(0, 0).addi(0.5); - expected.slice(1, 0).addi(0.25); - expected.slice(2, 0).addi(1.0); + expected.slice(0, 1).addi(0.5); + expected.slice(1, 1).addi(0.25); + expected.slice(2, 1).addi(1.0); myScaler.transform(ds); assertTrue(Transforms.abs(ds.getFeatures().sub(expected)).maxNumber().doubleValue() <= 0.01); //Now giving it 16 bits instead of the default //System.out.println(Arrays.toString(image.shape())); - ds = new DataSet(image.reshape(1, 3, 10, 10), Nd4j.ones(1, 1)); + ds = new DataSet(image, Nd4j.ones(1, 1)); myScaler = new ImagePreProcessingScaler(0, 1, 16); //So this should scale to 0.5,0.25 and 1; expected = image.mul(0); - expected.slice(0, 0).addi(0.5 / 256); - expected.slice(1, 0).addi(0.25 / 256); - expected.slice(2, 0).addi(1.0 / 256); + expected.slice(0, 1).addi(0.5 / 256); + expected.slice(1, 1).addi(0.25 / 256); + expected.slice(2, 1).addi(1.0 / 256); myScaler.transform(ds); assertTrue(Transforms.abs(ds.getFeatures().sub(expected)).maxNumber().doubleValue() <= 0.01); @@ -88,6 +90,16 @@ public class ImagePreProcessortTest extends BaseNd4jTest { myScaler.transform(before); myScaler.revertFeatures(before); assertEquals(orig, before); + + + //Test labels (segmentation case) + before = orig.dup(); + myScaler = new ImagePreProcessingScaler(0, 1); + myScaler.transformLabel(before); + expected = orig.div(255); + assertEquals(expected, before); + myScaler.revertLabels(before); + assertEquals(orig, before); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessor3D4DTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessor3D4DTest.java index 62edf3499..18776292a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessor3D4DTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessor3D4DTest.java @@ -323,7 +323,6 @@ public class PreProcessor3D4DTest extends BaseNd4jTest { this.samples = samples; this.origin = origin; - // FIXME: int cast numFeatures = (int) featureScale.size(0); maxN = samples * timeSteps; INDArray template = Nd4j.linspace(origin, origin + timeSteps - 1, timeSteps).reshape(1, -1); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestPCA.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestPCA.java index 3548dc2e2..a28c026cc 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestPCA.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestPCA.java @@ -153,9 +153,8 @@ public class TestPCA extends BaseNd4jTest { System.out.println("Eigenvalues:\n" + ns.format(myPCA.getEigenvalues())); double variance = 0.0; - // FIXME: int cast // sample 1000 of the randomly generated samples with the reduced basis set - for (int i = 0; i < 1000; i++) + for (long i = 0; i < 1000; i++) variance += myPCA.estimateVariance(m.getRow(i), reduced70.columns()); variance /= 1000.0; System.out.println("Fraction of variance using 70% variance with " + reduced70.columns() + " columns: " diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java index 52ede954a..e04250f69 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java @@ -545,7 +545,6 @@ public class OpExecutionerTests extends BaseNd4jTest { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray slice = arr.slice(0); - // FIXME: int cast val expected = new double[(int) slice.length()]; for (int i = 0; i < slice.length(); i++) expected[i] = (float) Math.exp(slice.getDouble(i)); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java index 0df75ac74..72be040c5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java @@ -411,7 +411,6 @@ public class OpExecutionerTestsC extends BaseNd4jTest { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray slice = arr.slice(0); - // FIXME: int cast val expected = new double[(int) slice.length()]; for (int i = 0; i < slice.length(); i++) expected[i] = (float) Math.exp(slice.getDouble(i)); @@ -852,7 +851,6 @@ public class OpExecutionerTestsC extends BaseNd4jTest { val next = iter.next(); double d = fourd.getDouble(next); - // FIXME: int cast sums[(int) next[0]] += d; sumSquares[(int) next[0]] += d * d; } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java index f95842cc2..ed8f4d441 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java @@ -1375,6 +1375,24 @@ public class RandomTests extends BaseNd4jTest { log.info("Array: {}", array); } + @Test + public void testOrthogonalDistribution2() { + val dist = new OrthogonalDistribution(1.0); + + val array = dist.sample(new int[] {9, 6}); + + log.info("Array: {}", array); + } + + @Test + public void testOrthogonalDistribution3() { + val dist = new OrthogonalDistribution(1.0); + + val array = dist.sample(new int[] {9, 9}); + + log.info("Array: {}", array); + } + @Test public void reproducabilityTest(){ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java index 7e9f1e91c..2483f03e6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java @@ -100,9 +100,8 @@ public class PaddingTestsC extends BaseNd4jTest { val h = linspaced.size(2); val w = linspaced.size(3); - // FIXME: int cast - int outWidth = Convolution.outSize((int) h, kh, sy, ph, 1, true); - int outHeight = Convolution.outSize((int) w, kw, sx, pw, 1, true); + long outWidth = Convolution.outSize(h, kh, sy, ph, 1, true); + long outHeight = Convolution.outSize(w, kw, sx, pw, 1, true); INDArray padded = Nd4j.pad(linspaced, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}}); System.out.println(padded); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java index fee72b2d3..d98e9218e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java @@ -319,6 +319,11 @@ public class BasicWorkspaceTests extends BaseNd4jTest { long reqMemory = 5 * Nd4j.sizeOfDataType(array1.dataType()); assertEquals(reqMemory + reqMemory % 8, wsI.getPrimaryOffset()); assertEquals(array1, array2); + + INDArray array3 = Nd4j.createUninitializedDetached(DataType.FLOAT, new long[0]); + assertTrue(array3.isScalar()); + assertEquals(1, array3.length()); + assertEquals(1, array3.data().length()); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/binary/BinarySerdeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/binary/BinarySerdeTest.java index 295ad5d67..d7db01c7b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/binary/BinarySerdeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/binary/BinarySerdeTest.java @@ -126,7 +126,6 @@ public class BinarySerdeTest extends BaseNd4jTest { Nd4j.getCompressor().compressi(arr, "GZIP"); for (int i = 0; i < numTrials; i++) { StopWatch oldStopWatch = new StopWatch(); - // FIXME: int cast BufferedOutputStream bos = new BufferedOutputStream(new ByteArrayOutputStream((int) arr.length())); DataOutputStream dos = new DataOutputStream(bos); oldStopWatch.start(); diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/collections/WeakIdentityHashMap.java b/nd4j/nd4j-common/src/main/java/org/nd4j/collections/WeakIdentityHashMap.java new file mode 100644 index 000000000..c336befd7 --- /dev/null +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/collections/WeakIdentityHashMap.java @@ -0,0 +1,161 @@ +package org.nd4j.collections; + +import lombok.*; + +import java.lang.ref.Reference; +import java.lang.ref.ReferenceQueue; +import java.lang.ref.WeakReference; +import java.util.*; + +/** + * A hash map implementation with weak identity keys. + * For details, see {@link WeakHashMap} and {@link IdentityHashMap} + * + * @param Key type + * @param Value type + * @author Alex Black + */ +public class WeakIdentityHashMap implements Map { + + protected final Map, V> map; + protected final ReferenceQueue refQueue; + + public WeakIdentityHashMap(){ + map = new HashMap<>(); + refQueue = new ReferenceQueue<>(); + } + + //Clear references to any map keys that have been GC'd + protected void clearReferences(){ + Reference r; + while((r = refQueue.poll()) != null){ + map.remove(r); + } + } + + @Override + public int size() { + clearReferences(); + return map.size(); + } + + @Override + public boolean isEmpty() { + clearReferences(); + return map.isEmpty(); + } + + @Override + public boolean containsKey(Object key) { + clearReferences(); + return map.containsKey(new KeyRef<>(key)); + } + + @Override + public boolean containsValue(Object value) { + clearReferences(); + return map.containsValue(value); + } + + @Override + public V get(Object key) { + clearReferences(); + return map.get(new KeyRef<>(key)); + } + + @Override + public V put(K key, V value) { + clearReferences(); + map.put(new KeyRef<>(key), value); + return value; + } + + @Override + public V remove(Object key) { + clearReferences(); + return map.remove(new KeyRef<>(key)); + } + + @Override + public void putAll(Map m) { + clearReferences(); + for(Map.Entry e : m.entrySet()){ + map.put(new KeyRef<>(e.getKey()), e.getValue()); + } + } + + @Override + public void clear() { + map.clear(); + clearReferences(); + } + + @Override + public Set keySet() { + clearReferences(); + Set ret = new HashSet<>(); + for(KeyRef k : map.keySet() ){ + K key = k.get(); + if(key != null) + ret.add(key); + } + return ret; + } + + @Override + public Collection values() { + clearReferences(); + return map.values(); + } + + @Override + public Set> entrySet() { + clearReferences(); + Set> ret = new HashSet<>(); + for(Map.Entry, V> e : map.entrySet()){ + K k = e.getKey().get(); + if(k != null){ + ret.add(new Entry(k, e.getValue())); + } + } + return ret; + } + + + protected static class KeyRef extends WeakReference { + private final int hash; + public KeyRef(@NonNull K referent) { + super(referent); + this.hash = System.identityHashCode(referent); + } + + @Override + public int hashCode(){ + return hash; + } + + @Override + public boolean equals(Object o){ + if(this == o){ + return true; + } + if(o instanceof WeakReference){ + return this.get() == ((WeakReference) o).get(); + } + return false; + } + } + + @Data + @AllArgsConstructor + protected static class Entry implements Map.Entry { + protected K key; + protected V value; + + @Override + public V setValue(V value){ + this.value = value; + return value; + } + } +} diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java index cf54d4357..caeb0d47b 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java @@ -163,7 +163,8 @@ public class ArrayUtil { } public static long[] nTimes(long n, long toReplicate) { - // FIXME: int cast + if (n > Integer.MAX_VALUE) + throw new RuntimeException("Index overflow in nTimes"); val ret = new long[(int) n]; Arrays.fill(ret, toReplicate); return ret; @@ -1329,8 +1330,6 @@ public class ArrayUtil { * @return the shape for tensor matrix multiply */ public static long[] getTensorMmulShape(long[] aShape, long[] bShape, int[][] axes) { - // FIXME: int cast - int validationLength = Math.min(axes[0].length, axes[1].length); for (int i = 0; i < validationLength; i++) { @@ -2970,7 +2969,9 @@ public class ArrayUtil { } public static long[] buildInterleavedVector(Random rng, long length) { - // FIXME: int cast + if (length > Integer.MAX_VALUE) { + throw new RuntimeException("Integer overflow"); + } val result = new long[(int) length]; List indexes = new ArrayList<>(); diff --git a/nd4j/nd4j-context/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java b/nd4j/nd4j-context/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java index bcdfba202..ec4739b86 100644 --- a/nd4j/nd4j-context/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java +++ b/nd4j/nd4j-context/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java @@ -21,8 +21,6 @@ import org.nd4j.config.ND4JEnvironmentVars; import org.nd4j.config.ND4JSystemProperties; import org.nd4j.context.Nd4jContext; import org.nd4j.linalg.io.Resource; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.io.File; import java.io.IOException; @@ -152,6 +150,9 @@ public abstract class Nd4jBackend { */ public static Nd4jBackend load() throws NoAvailableBackendException { + String logInitProperty = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION, "true"); + boolean logInit = Boolean.parseBoolean(logInitProperty); + List backends = new ArrayList<>(1); ServiceLoader loader = ServiceLoader.load(Nd4jBackend.class); try { @@ -183,7 +184,9 @@ public abstract class Nd4jBackend { error = e.getMessage(); } if (!available) { - log.warn("Skipped [{}] backend (unavailable): {}", backend.getClass().getSimpleName(), error); + if(logInit) { + log.warn("Skipped [{}] backend (unavailable): {}", backend.getClass().getSimpleName(), error); + } continue; } @@ -193,7 +196,9 @@ public abstract class Nd4jBackend { e.printStackTrace(); } - log.info("Loaded [{}] backend", backend.getClass().getSimpleName()); + if(logInit) { + log.info("Loaded [{}] backend", backend.getClass().getSimpleName()); + } return backend; } @@ -273,6 +278,8 @@ public abstract class Nd4jBackend { return getClass().getName(); } + public abstract void logBackendInit(); + @SuppressWarnings("serial") public static class NoAvailableBackendException extends Exception { diff --git a/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/GraphInferenceGrpcClient.java b/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/GraphInferenceGrpcClient.java index e063d16ae..b1ea7e76d 100644 --- a/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/GraphInferenceGrpcClient.java +++ b/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/GraphInferenceGrpcClient.java @@ -147,7 +147,7 @@ public class GraphInferenceGrpcClient { val arrOff = array.toFlatArray(builder); byte variableType = 0; //TODO is this OK here? - val varOff = FlatVariable.createFlatVariable(builder, idPair, nameOff, FlatBuffersMapper.getDataTypeAsByte(array.dataType()),0, arrOff, -1, variableType); + val varOff = FlatVariable.createFlatVariable(builder, idPair, nameOff, FlatBuffersMapper.getDataTypeAsByte(array.dataType()),0, arrOff, -1, variableType, 0, 0, 0); ins[cnt++] = varOff; } diff --git a/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java b/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java index d8dd3f4df..1c84da010 100644 --- a/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java +++ b/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java @@ -43,7 +43,7 @@ public class GraphInferenceGrpcClientTest { val graphId = RandomUtils.nextLong(0, Long.MAX_VALUE); // preparing and registering graph (it's optional, and graph might be embedded into Docker image - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream()); assertNotNull(tg); client.registerGraph(graphId, tg, ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build()); @@ -66,7 +66,7 @@ public class GraphInferenceGrpcClientTest { val graphId = RandomUtils.nextLong(0, Long.MAX_VALUE); // preparing and registering graph (it's optional, and graph might be embedded into Docker image - val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream()); + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream()); assertNotNull(tg); client.registerGraph(graphId, tg, ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build()); diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java index 37dbfdb05..bf1c0ebc1 100644 --- a/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java +++ b/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java @@ -200,7 +200,7 @@ public class SameDiffServlet implements ModelServingServlet { map.put(n, mds.getFeatures(cnt++)); } - val output = sdModel.exec(map, orderedOutputNodes); + val output = sdModel.output(map, orderedOutputNodes); val arrays = new INDArray[output.size()]; // now we need to get ordered output arrays, as specified in server constructor diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java index 677326e86..e05a5b6f9 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java @@ -77,7 +77,6 @@ public class AeronNDArraySerdeTest { Nd4j.getCompressor().compressi(arr, "GZIP"); for (int i = 0; i < numTrials; i++) { StopWatch oldStopWatch = new StopWatch(); - // FIXME: int cast BufferedOutputStream bos = new BufferedOutputStream(new ByteArrayOutputStream((int) arr.length())); DataOutputStream dos = new DataOutputStream(bos); oldStopWatch.start(); diff --git a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/ProtoBufToFlatBufConversion.java b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/ProtoBufToFlatBufConversion.java index 73bd00036..ee53c37e2 100644 --- a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/ProtoBufToFlatBufConversion.java +++ b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/ProtoBufToFlatBufConversion.java @@ -56,7 +56,7 @@ public class ProtoBufToFlatBufConversion { */ public static void convert(String inFile, String outFile) throws IOException, org.nd4j.linalg.exception.ND4JIllegalStateException { - SameDiff tg = TFGraphMapper.getInstance().importGraph(new File(inFile)); + SameDiff tg = TFGraphMapper.importGraph(new File(inFile)); tg.asFlatFile(new File(outFile)); } @@ -90,7 +90,7 @@ public class ProtoBufToFlatBufConversion { }; - SameDiff sd = TFGraphMapper.getInstance().importGraph(new File(inFile), m, filter); + SameDiff sd = TFGraphMapper.importGraph(new File(inFile), m, filter); SubGraphPredicate p = SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/mul")) // .../dropout/mul diff --git a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala index e72ca37ed..79a185a68 100644 --- a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala +++ b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala @@ -63,7 +63,7 @@ case class SDIndexWrapper(end: Long) { case class SDIndexWrapper1(start: Int) { def ::(end: Int): SDIndex = - SDIndex.interval(start, end) + SDIndex.interval(start.toLong, end.toLong) } object --- extends SDIndex {