From d86dd5b131404282d14cb089e0134299b0be2179 Mon Sep 17 00:00:00 2001 From: Andrii T <39699084+atuzhykov@users.noreply.github.com> Date: Wed, 8 Apr 2020 17:20:48 +0300 Subject: [PATCH] DL4J and SameDiff integration tests + LSTMLayer java op class (#353) * init in this branch Signed-off-by: Andrii Tuzhykov * Lenetet Mnist workflow Signed-off-by: Andrii Tuzhykov * small fix for calculations Signed-off-by: Andrii Tuzhykov * for Alex to check placeholder null pointer issue Signed-off-by: Andrii Tuzhykov * CNN3D workflow Signed-off-by: Andrii Tuzhykov * state for launching on dxg to regenterate dl4j examples Signed-off-by: Andrii Tuzhykov * SD RNN test case workflow Signed-off-by: Andrii Tuzhykov * small fixes Signed-off-by: Andrii Tuzhykov * checkpoint at lstmBlock: Input array 1 (x) rank must be got input with rank 2 issue Signed-off-by: Andrii Tuzhykov * Fix LSTMLayer inputs order Signed-off-by: Andrii Tuzhykov * lstm mismatch with c++ op issue Signed-off-by: Andrii Tuzhykov * LSTMLayer config draft Signed-off-by: Andrii Tuzhykov * LSTMLayer config draft v2 Signed-off-by: Andrii Tuzhykov * have doubt I had to do this Signed-off-by: Andrii Tuzhykov * NDRNN generated by codegen Signed-off-by: Andrii Tuzhykov * LSTMLayerTestCases draft Signed-off-by: Andrii Tuzhykov * minor fixes again * added LSTMLayer testcases to nd4j-tests + setted Preconditions in LSTMLayer constructors Signed-off-by: Andrii Tuzhykov * added lost SDCNNtestcases Signed-off-by: Andrii Tuzhykov * overrided getNumOutputs from DynamicCustomOp in LSTMLayer and reorganized LSTMLayerOutputs according to cpp op Signed-off-by: Andrii Tuzhykov * finished with LSTMLayerOutputs Signed-off-by: Andrii Tuzhykov * Fix MKLDNN platform checks (i.e., when MKLDNN can be used vs. 
not) Signed-off-by: Alex Black * Fix LSTMLayerWeights input order Signed-off-by: Alex Black * More fixes Signed-off-by: Alex Black * minor fixes Signed-off-by: Andrii Tuzhykov * fixed LSTMLayer testcases Signed-off-by: Andrii Tuzhykov * finished SameDiffRNNTestCase Signed-off-by: Andrii Tuzhykov * finished all testcases + minor fixes Signed-off-by: Andrii Tuzhykov * Multiple generation-related fixes Signed-off-by: Alex Black * Fix multiple issues Signed-off-by: Alex Black * More fixes Signed-off-by: Alex Black * LSTM fixes Signed-off-by: Alex Black * Regenerate ND4J namespaces and fix multiple issues Signed-off-by: Alex Black * changed SameDiffRNNTestCase Signed-off-by: Andrii Tuzhykov * Small fix Signed-off-by: Alex Black * added Nd4j.getRandom().setSeed(12345) where needed Signed-off-by: Andrii Tuzhykov * #8828 Fix ND4J profiler NaN/Inf checks when using OpContext Signed-off-by: Alex Black * #8828 Fix ND4J profiler NaN/Inf checks when using OpContext Signed-off-by: Alex Black * Tweak to weight init for SameDiff CNN test case Signed-off-by: Alex Black * Tweaks for test cases Signed-off-by: Alex Black * Ignore failing tests until fixed Signed-off-by: Alex Black * Fix Signed-off-by: Alex Black Co-authored-by: Alex Black --- .../eval/EvaluationCalibration.java | 2 +- .../conf/layers/RecurrentAttentionLayer.java | 4 +- .../IntegrationTestBaselineGenerator.java | 86 +- .../integration/IntegrationTestsSameDiff.java | 17 + .../testcases/dl4j/CNN2DTestCases.java | 5 +- .../testcases/dl4j/MLPTestCases.java | 7 +- .../testcases/dl4j/RNNTestCases.java | 8 +- .../testcases/dl4j/UnsupervisedTestCases.java | 4 +- .../testcases/samediff/SameDiffCNNCases.java | 398 + .../samediff/SameDiffMLPTestCases.java | 188 +- .../samediff/SameDiffRNNTestCases.java | 289 + .../declarable/platform/mkldnn/lstmLayer.cpp | 20 +- .../DifferentialFunctionFactory.java | 2 +- .../nd4j/autodiff/samediff/SDVariable.java | 7 +- .../org/nd4j/autodiff/samediff/SameDiff.java | 271 +- 
.../nd4j/autodiff/samediff/ops/SDBaseOps.java | 8103 ++++++++++------- .../org/nd4j/autodiff/samediff/ops/SDCNN.java | 29 +- .../nd4j/autodiff/samediff/ops/SDMath.java | 12 +- .../org/nd4j/autodiff/samediff/ops/SDRNN.java | 233 +- .../src/main/java/org/nd4j/enums/CellAct.java | 45 + .../src/main/java/org/nd4j/enums/GateAct.java | 45 + .../java/org/nd4j/enums/LSTMDataFormat.java | 36 + .../org/nd4j/enums/LSTMDirectionMode.java | 38 + .../src/main/java/org/nd4j/enums/OutAct.java | 45 + .../java/org/nd4j/enums/RnnDataFormat.java | 32 + .../converters/ImportClassMapping.java | 1 + .../ops/executioner/DefaultOpExecutioner.java | 68 +- .../ops/executioner/OpExecutionerUtil.java | 37 +- .../layers/convolution/MaxPoolWithArgmax.java | 8 +- .../ops/impl/layers/convolution/SConv2D.java | 2 +- .../ops/impl/layers/recurrent/LSTMBlock.java | 144 + .../ops/impl/layers/recurrent/LSTMLayer.java | 173 +- .../recurrent/config/LSTMActivations.java | 48 + .../recurrent/config/LSTMDataFormat.java | 41 + .../recurrent/config/LSTMDirectionMode.java | 38 + .../recurrent/config/LSTMLayerConfig.java | 119 + .../recurrent/outputs/LSTMLayerOutputs.java | 190 +- .../recurrent/weights/LSTMLayerWeights.java | 99 + .../nd4j/linalg/api/ops/impl/reduce/Mmul.java | 2 + .../api/ops/impl/reduce/custom/BatchMmul.java | 20 + .../linalg/api/ops/impl/shape/GatherNd.java | 13 +- .../linalg/api/ops/impl/shape/Linspace.java | 21 +- .../linalg/api/ops/impl/shape/MeshGrid.java | 7 + .../linalg/api/ops/impl/shape/Reshape.java | 7 +- .../api/ops/impl/shape/SequenceMask.java | 11 +- .../nd4j/linalg/api/ops/impl/shape/Slice.java | 4 + .../nd4j/linalg/api/ops/impl/shape/Stack.java | 2 +- .../api/ops/impl/shape/StridedSlice.java | 12 +- .../linalg/api/ops/impl/shape/Unstack.java | 11 +- .../linalg/api/ops/impl/transforms/Pad.java | 4 + .../transforms/custom/DynamicPartition.java | 5 +- .../ops/impl/transforms/custom/ListDiff.java | 8 +- .../ops/impl/transforms/custom/XwPlusB.java | 8 +- 
.../api/ops/impl/transforms/dtype/Cast.java | 22 +- .../linalg/api/ops/random/impl/Range.java | 7 + .../org/nd4j/linalg/factory/ops/NDBase.java | 214 +- .../org/nd4j/linalg/factory/ops/NDCNN.java | 14 +- .../org/nd4j/linalg/factory/ops/NDLoss.java | 7 +- .../org/nd4j/linalg/factory/ops/NDMath.java | 106 +- .../org/nd4j/linalg/factory/ops/NDNN.java | 46 +- .../org/nd4j/linalg/factory/ops/NDRNN.java | 106 +- .../ops/executioner/CudaExecutioner.java | 26 +- .../nativecpu/ops/NativeOpExecutioner.java | 13 +- .../opvalidation/LayerOpValidation.java | 251 +- .../opvalidation/MiscOpValidation.java | 10 +- .../opvalidation/RnnOpValidation.java | 6 +- .../opvalidation/ShapeOpValidation.java | 44 +- .../opvalidation/TransformOpValidation.java | 2 +- .../nd4j/autodiff/samediff/SameDiffTests.java | 85 +- .../nd4j/linalg/nativ/OpsMappingTests.java | 2 +- .../profiling/OperationProfilerTests.java | 55 + .../java/org/nd4j/linalg/util/ArrayUtil.java | 15 + 72 files changed, 8063 insertions(+), 3997 deletions(-) create mode 100644 deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffCNNCases.java create mode 100644 deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffRNNTestCases.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/CellAct.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/GateAct.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/LSTMDataFormat.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/LSTMDirectionMode.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/OutAct.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/RnnDataFormat.java create mode 100644 
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlock.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMActivations.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMDataFormat.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMDirectionMode.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMLayerWeights.java diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/EvaluationCalibration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/EvaluationCalibration.java index 4a4299042..bda8b21b2 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/EvaluationCalibration.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/EvaluationCalibration.java @@ -25,7 +25,7 @@ import org.nd4j.shade.jackson.annotation.JsonProperty; */ @Deprecated @Getter -@EqualsAndHashCode +@EqualsAndHashCode(callSuper = true) public class EvaluationCalibration extends org.nd4j.evaluation.classification.EvaluationCalibration implements org.deeplearning4j.eval.IEvaluation { /** diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java index e4e5b7d21..d12e0ec74 100644 --- 
a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java @@ -185,7 +185,9 @@ public class RecurrentAttentionLayer extends SameDiffLayer { final val R = paramTable.get(RECURRENT_WEIGHT_KEY); final val b = paramTable.get(BIAS_KEY); - SDVariable[] inputSlices = sameDiff.unstack(layerInput, 2); + long[] shape = layerInput.getShape(); + Preconditions.checkState(shape != null, "Null shape for input placeholder"); + SDVariable[] inputSlices = sameDiff.unstack(layerInput, 2, (int)shape[2]); this.timeSteps = inputSlices.length; SDVariable[] outputSlices = new SDVariable[timeSteps]; SDVariable prev = null; diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java index a493337c8..01b3b2e53 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java @@ -20,7 +20,10 @@ package org.deeplearning4j.integration; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; +import org.deeplearning4j.integration.testcases.dl4j.*; +import org.deeplearning4j.integration.testcases.samediff.SameDiffCNNCases; import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases; +import org.deeplearning4j.integration.testcases.samediff.SameDiffRNNTestCases; import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; @@ -66,14 +69,36 
@@ public class IntegrationTestBaselineGenerator { } runGeneration( - SameDiffMLPTestCases.getMLPMnist() + + // DL4J integration test cases. + +// CNN1DTestCases.getCnn1dTestCaseCharRNN(), +// CNN2DTestCases.testLenetTransferDropoutRepeatability(), +//// CNN2DTestCases.getCnn2DSynthetic(), +// CNN2DTestCases.getLenetMnist(), +// CNN2DTestCases.getVGG16TransferTinyImagenet(), +// CNN2DTestCases.getYoloHouseNumbers(), +// CNN3DTestCases.getCnn3dTestCaseSynthetic(), +// MLPTestCases.getMLPMnist(), +// MLPTestCases.getMLPMoon(), +// RNNTestCases.getRnnCharacterTestCase(), +// RNNTestCases.getRnnCsvSequenceClassificationTestCase1(), +// RNNTestCases.getRnnCsvSequenceClassificationTestCase2(), +// UnsupervisedTestCases.getVAEMnistAnomaly(), + + // Samediff test cases done + SameDiffMLPTestCases.getMLPMnist(), + SameDiffMLPTestCases.getMLPMoon(), + SameDiffCNNCases.getLenetMnist(), + SameDiffCNNCases.getCnn3dSynthetic(), + SameDiffRNNTestCases.getRnnCsvSequenceClassificationTestCase1() ); } private static void runGeneration(TestCase... testCases) throws Exception { - for( TestCase tc : testCases ) { + for (TestCase tc : testCases) { final ModelType modelType = tc.modelType(); //Basic validation: @@ -122,18 +147,18 @@ public class IntegrationTestBaselineGenerator { mln = new MultiLayerNetwork(mlc); mln.init(); m = mln; - } else if (config instanceof ComputationGraphConfiguration){ + } else if (config instanceof ComputationGraphConfiguration) { ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config; json = cgc.toJson(); cg = new ComputationGraph(cgc); cg.init(); m = cg; } else { - sd = (SameDiff)config; + sd = (SameDiff) config; } File savedModel = new File(testBaseDir, IntegrationTestRunner.RANDOM_INIT_UNTRAINED_MODEL_FILENAME); - if(modelType != ModelType.SAMEDIFF) { + if (modelType != ModelType.SAMEDIFF) { File configFile = new File(testBaseDir, "config." + (modelType == ModelType.MLN ? 
"mlc.json" : "cgc.json")); FileUtils.writeStringToFile(configFile, json, StandardCharsets.UTF_8); log.info("RANDOM_INIT test - saved configuration: {}", configFile.getAbsolutePath()); @@ -147,10 +172,10 @@ public class IntegrationTestBaselineGenerator { m = tc.getPretrainedModel(); if (m instanceof MultiLayerNetwork) { mln = (MultiLayerNetwork) m; - } else if(m instanceof ComputationGraph){ + } else if (m instanceof ComputationGraph) { cg = (ComputationGraph) m; } else { - sd = (SameDiff)m; + sd = (SameDiff) m; } } @@ -158,7 +183,7 @@ public class IntegrationTestBaselineGenerator { //Generate predictions to compare against if (tc.isTestPredictions()) { List> inputs = modelType != ModelType.SAMEDIFF ? tc.getPredictionsTestData() : null; - List> inputsSd = modelType == ModelType.SAMEDIFF ? tc.getPredictionsTestDataSameDiff() : null; + List> inputsSd = modelType == ModelType.SAMEDIFF ? tc.getPredictionsTestDataSameDiff() : null; // Preconditions.checkState(inputs != null && inputs.size() > 0, "Input data is null or length 0 for test: %s", tc.getTestName()); @@ -178,7 +203,7 @@ public class IntegrationTestBaselineGenerator { Nd4j.write(out, dos); } } - } else if(modelType == ModelType.CG) { + } else if (modelType == ModelType.CG) { for (Pair p : inputs) { INDArray[] out = cg.output(false, p.getFirst(), p.getSecond(), null); @@ -192,11 +217,11 @@ public class IntegrationTestBaselineGenerator { } } else { List outNames = tc.getPredictionsNamesSameDiff(); - for( Map ph : inputsSd ){ - Map out = sd.output(ph, outNames); + for (Map ph : inputsSd) { + Map out = sd.output(ph, outNames); //Save the output... 
- for(String s : outNames){ + for (String s : outNames) { File f = new File(predictionsTestDir, "output_" + (count++) + "_" + s + ".bin"); try (DataOutputStream dos = new DataOutputStream(new FileOutputStream(f))) { Nd4j.write(out.get(s), dos); @@ -211,7 +236,7 @@ public class IntegrationTestBaselineGenerator { //Compute and save gradients: if (tc.isTestGradients()) { INDArray gradientFlat = null; - Map grad; + Map grad; if (modelType == ModelType.MLN) { MultiDataSet data = tc.getGradientsTestData(); mln.setInput(data.getFeatures(0)); @@ -220,7 +245,7 @@ public class IntegrationTestBaselineGenerator { mln.computeGradientAndScore(); gradientFlat = mln.getFlattenedGradients(); grad = m.gradient().gradientForVariable(); - } else if(modelType == ModelType.CG) { + } else if (modelType == ModelType.CG) { MultiDataSet data = tc.getGradientsTestData(); cg.setInputs(data.getFeatures()); cg.setLabels(data.getLabels()); @@ -229,17 +254,17 @@ public class IntegrationTestBaselineGenerator { gradientFlat = cg.getFlattenedGradients(); grad = m.gradient().gradientForVariable(); } else { - Map ph = tc.getGradientsTestDataSameDiff(); + Map ph = tc.getGradientsTestDataSameDiff(); List allVars = new ArrayList<>(); - for(SDVariable v : sd.variables()){ - if(v.getVariableType() == VariableType.VARIABLE){ + for (SDVariable v : sd.variables()) { + if (v.getVariableType() == VariableType.VARIABLE) { allVars.add(v.name()); } } grad = sd.calculateGradients(ph, allVars); } - if(modelType != ModelType.SAMEDIFF) { + if (modelType != ModelType.SAMEDIFF) { File gFlatFile = new File(testBaseDir, IntegrationTestRunner.FLAT_GRADIENTS_FILENAME); IntegrationTestRunner.write(gradientFlat, gFlatFile); } @@ -254,25 +279,25 @@ public class IntegrationTestBaselineGenerator { } //Test pretraining - if(tc.isTestUnsupervisedTraining()){ + if (tc.isTestUnsupervisedTraining()) { log.info("Performing layerwise pretraining"); MultiDataSetIterator iter = tc.getUnsupervisedTrainData(); INDArray paramsPostTraining; 
- if(modelType == ModelType.MLN){ + if (modelType == ModelType.MLN) { int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN(); Preconditions.checkState(layersToTrain != null, "Layer indices must not be null"); DataSetIterator dsi = new MultiDataSetWrapperIterator(iter); - for( int i : layersToTrain){ + for (int i : layersToTrain) { mln.pretrainLayer(i, dsi); } paramsPostTraining = mln.params(); - } else if(modelType == ModelType.CG) { + } else if (modelType == ModelType.CG) { String[] layersToTrain = tc.getUnsupervisedTrainLayersCG(); Preconditions.checkState(layersToTrain != null, "Layer names must not be null"); - for( String i : layersToTrain){ + for (String i : layersToTrain) { cg.pretrainLayer(i, iter); } paramsPostTraining = cg.params(); @@ -290,20 +315,20 @@ public class IntegrationTestBaselineGenerator { MultiDataSetIterator trainData = tc.getTrainingData(); CollectScoresListener l = new CollectScoresListener(1); - if(modelType != ModelType.SAMEDIFF) + if (modelType != ModelType.SAMEDIFF) m.setListeners(l); History h = null; if (modelType == ModelType.MLN) { mln.fit(trainData); - } else if(modelType == ModelType.CG) { + } else if (modelType == ModelType.CG) { cg.fit(trainData); } else { h = sd.fit(trainData, 1); } double[] scores; - if(modelType != ModelType.SAMEDIFF){ + if (modelType != ModelType.SAMEDIFF) { scores = l.getListScore().toDoubleArray(); } else { scores = h.lossCurve().getLossValues().toDoubleVector(); @@ -314,11 +339,11 @@ public class IntegrationTestBaselineGenerator { FileUtils.writeStringToFile(f, String.join(",", s), StandardCharsets.UTF_8); if (tc.isTestParamsPostTraining()) { - if(modelType == ModelType.SAMEDIFF){ + if (modelType == ModelType.SAMEDIFF) { File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_SAMEDIFF_DIR); p.mkdirs(); - for(SDVariable v : sd.variables()){ - if(v.getVariableType() == VariableType.VARIABLE){ + for (SDVariable v : sd.variables()) { + if (v.getVariableType() == VariableType.VARIABLE) { 
INDArray arr = v.getArr(); File p2 = new File(p, v.name() + ".bin"); IntegrationTestRunner.write(arr, p2); @@ -331,7 +356,6 @@ public class IntegrationTestBaselineGenerator { } } - if (tc.isTestEvaluation()) { IEvaluation[] evals = tc.getNewEvaluations(); MultiDataSetIterator iter = tc.getEvaluationTestData(); @@ -339,7 +363,7 @@ public class IntegrationTestBaselineGenerator { if (modelType == ModelType.MLN) { DataSetIterator dsi = new MultiDataSetWrapperIterator(iter); mln.doEvaluation(dsi, evals); - } else if(modelType == ModelType.CG){ + } else if (modelType == ModelType.CG) { cg.doEvaluation(iter, evals); } else { evals = tc.doEvaluationSameDiff(sd, iter, evals); diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java index f16a5e187..de5bc0ea1 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java @@ -16,6 +16,7 @@ package org.deeplearning4j.integration; import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.integration.testcases.samediff.SameDiffCNNCases; import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases; import org.junit.Rule; import org.junit.Test; @@ -37,4 +38,20 @@ public class IntegrationTestsSameDiff extends BaseDL4JTest { IntegrationTestRunner.runTest(SameDiffMLPTestCases.getMLPMnist(), testDir); } + @Test + public void testMLPMoon() throws Exception { + IntegrationTestRunner.runTest(SameDiffMLPTestCases.getMLPMoon(), testDir); + } + + @Test + public void testLenetMnist() throws Exception { + IntegrationTestRunner.runTest(SameDiffCNNCases.getLenetMnist(), testDir); + } + + @Test + public void testCnn3dSynthetic() throws Exception { + 
IntegrationTestRunner.runTest(SameDiffCNNCases.getCnn3dSynthetic(), testDir); + } + + } diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java index 71336c0a6..b857b2fb2 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java @@ -194,6 +194,8 @@ public class CNN2DTestCases { testParamsPostTraining = false; //Skip - requires saving all params (approx 500mb) testEvaluation = false; testOverfitting = false; + maxRelativeErrorOutput = 0.2; + minAbsErrorOutput = 0.05; //Max value is around 0.22 } @Override @@ -314,6 +316,7 @@ public class CNN2DTestCases { ComputationGraph model = new TransferLearning.GraphBuilder(pretrained) .fineTuneConfiguration(fineTuneConf) .removeVertexKeepConnections("conv2d_9") + .removeVertexAndConnections("outputs") .addLayer("convolution2d_9", new ConvolutionLayer.Builder(1,1) .nIn(1024) @@ -393,7 +396,7 @@ public class CNN2DTestCases { @Override public ModelType modelType() { - return ModelType.CG; + return ModelType.MLN; } @Override diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/MLPTestCases.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/MLPTestCases.java index 232219f04..4264531aa 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/MLPTestCases.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/MLPTestCases.java @@ -77,6 +77,10 @@ public class MLPTestCases { testOverfitting = true; maxRelativeErrorOverfit = 2e-2; 
minAbsErrorOverfit = 1e-2; + maxRelativeErrorGradients = 0.01; + minAbsErrorGradients = 0.05; + maxRelativeErrorParamsPostTraining = 0.01; + minAbsErrorParamsPostTraining = 0.05; } @Override @@ -135,8 +139,7 @@ public class MLPTestCases { public IEvaluation[] getNewEvaluations(){ return new IEvaluation[]{ new Evaluation(), - new ROCMultiClass(), - new EvaluationCalibration() + new ROCMultiClass() }; } diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java index f89643380..29f382735 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java @@ -24,6 +24,7 @@ import org.nd4j.evaluation.classification.EvaluationCalibration; import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.shade.guava.io.Files; import org.deeplearning4j.integration.TestCase; import org.deeplearning4j.integration.testcases.dl4j.misc.CharacterIterator; @@ -91,7 +92,7 @@ public class RNNTestCases { } private int miniBatchSize = 32; - private int exampleLength = 1000; + private int exampleLength = 200; @Override @@ -101,6 +102,7 @@ public class RNNTestCases { @Override public Object getConfiguration() throws Exception { + Nd4j.getRandom().setSeed(12345); CharacterIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength); int nOut = iter.totalOutcomes(); @@ -113,7 +115,7 @@ public class RNNTestCases { .seed(12345) .l2(0.001) .weightInit(WeightInit.XAVIER) - .updater(new RmsProp(0.1)) + .updater(new Adam(1e-3)) .list() 
.layer(0, new LSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize) .activation(Activation.TANH).build()) @@ -140,7 +142,7 @@ public class RNNTestCases { @Override public MultiDataSetIterator getTrainingData() throws Exception { DataSetIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength); - iter = new EarlyTerminationDataSetIterator(iter, 2); //3 minibatches, 1000/200 = 5 updates per minibatch + iter = new EarlyTerminationDataSetIterator(iter, 2); //2 minibatches, 200/50 = 4 updates per minibatch return new MultiDataSetIteratorAdapter(iter); } diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/UnsupervisedTestCases.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/UnsupervisedTestCases.java index 622a6e9cf..b627f06dc 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/UnsupervisedTestCases.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/UnsupervisedTestCases.java @@ -72,12 +72,12 @@ public class UnsupervisedTestCases { return new NeuralNetConfiguration.Builder() .dataType(DataType.FLOAT) .seed(12345) - .updater(new Adam(0.05)) + .updater(new Adam(1e-3)) .weightInit(WeightInit.XAVIER) .l2(1e-4) .list() .layer(0, new VariationalAutoencoder.Builder() - .activation(Activation.LEAKYRELU) + .activation(Activation.TANH) .encoderLayerSizes(256, 256) //2 encoder layers, each of size 256 .decoderLayerSizes(256, 256) //2 decoder layers, each of size 256 .pzxActivationFunction(Activation.IDENTITY) //p(z|data) activation function diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffCNNCases.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffCNNCases.java new file mode 100644 
index 000000000..74c4f3bfb --- /dev/null +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffCNNCases.java @@ -0,0 +1,398 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.integration.testcases.samediff; + +import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; +import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; +import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; +import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator; +import org.deeplearning4j.integration.ModelType; +import org.deeplearning4j.integration.TestCase; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.TrainingConfig; +import org.nd4j.evaluation.IEvaluation; +import org.nd4j.evaluation.classification.Evaluation; +import org.nd4j.evaluation.classification.EvaluationCalibration; +import org.nd4j.evaluation.classification.ROCMultiClass; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; +import 
org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;

import java.util.*;

/**
 * SameDiff-based CNN integration test cases:
 * <ul>
 *     <li>{@link #getLenetMnist()} - a LeNet-style 2D CNN trained on MNIST</li>
 *     <li>{@link #getCnn3dSynthetic()} - a small 3D CNN on synthetic random data</li>
 * </ul>
 * Each factory method returns a fully configured {@link TestCase} for the integration test framework.
 */
public class SameDiffCNNCases {


    /**
     * LeNet-style CNN on MNIST built directly with the SameDiff API:
     * conv(5x5,20) -> relu -> maxpool(2x2) -> conv(5x5,50) -> relu -> maxpool(2x2)
     * -> dense(500) -> relu -> dense(10) -> softmax, trained with Adam + L2 on log loss.
     *
     * @return the configured test case (random-init type: predictions, gradients, training curves,
     *         params-post-training and evaluation are all checked; overfitting is not)
     */
    public static TestCase getLenetMnist() {
        return new TestCase() {
            {
                testName = "LenetMnistSD";
                testType = TestType.RANDOM_INIT;
                testPredictions = true;
                testTrainingCurves = true;
                testGradients = true;
                testParamsPostTraining = true;
                testEvaluation = true;
                testOverfitting = false;
            }

            @Override
            public ModelType modelType() {
                return ModelType.SAMEDIFF;
            }

            @Override
            public Object getConfiguration() throws Exception {
                Nd4j.getRandom().setSeed(12345);

                int nChannels = 1;      // Number of input channels
                int outputNum = 10;     // The number of possible outcomes

                SameDiff sd = SameDiff.create();
                SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 784);
                SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, outputNum);

                //Input [minibatch, channels=1, height=28, width=28]
                SDVariable in4d = in.reshape(-1, nChannels, 28, 28);

                int kernelHeight = 5;
                int kernelWidth = 5;

                // w0: [kernelHeight=5, kernelWidth=5, inputChannels=1, outputChannels=20]; b0: [20]
                SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, kernelHeight, kernelWidth, nChannels, 20).muli(0.01));
                SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 20).muli(0.01));

                SDVariable layer0 = sd.nn.relu(sd.cnn.conv2d("layer0", in4d, w0, b0, Conv2DConfig.builder()
                        .kH(kernelHeight)
                        .kW(kernelWidth)
                        .sH(1)
                        .sW(1)
                        .dataFormat("NCHW")
                        .build()), 0);

                // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
                // outputsize_H(W) = (28 - 5 + 2*0) / 1 + 1 = 24
                // NCHW: [minibatch, 20, 24, 24]

                SDVariable layer1 = sd.cnn.maxPooling2d("layer1", layer0, Pooling2DConfig.builder()
                        .kH(2).kW(2)
                        .sH(2).sW(2)
                        .isNHWC(false)
                        .build());

                // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
                // outputsize_H(W) = (24 - 2 + 2*0) / 2 + 1 = 12
                // NCHW: [minibatch, 20, 12, 12]

                // w2: [kernelHeight=5, kernelWidth=5, inputChannels=20, outputChannels=50]; b2: [50]
                SDVariable w2 = sd.var("w2", Nd4j.rand(DataType.FLOAT, kernelHeight, kernelWidth, 20, 50).muli(0.01));
                SDVariable b2 = sd.var("b2", Nd4j.rand(DataType.FLOAT, 50).muli(0.01));

                SDVariable layer2 = sd.nn.relu(sd.cnn.conv2d("layer2", layer1, w2, b2, Conv2DConfig.builder()
                        .kH(kernelHeight)
                        .kW(kernelWidth)
                        .sH(1)
                        .sW(1)
                        .dataFormat("NCHW")
                        .build()), 0);

                // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
                // outputsize_H(W) = (12 - 5 + 2*0) / 1 + 1 = 8
                // NCHW: [minibatch, 50, 8, 8]

                SDVariable layer3 = sd.cnn.maxPooling2d("layer3", layer2, Pooling2DConfig.builder()
                        .kH(2).kW(2)
                        .sH(2).sW(2)
                        .isNHWC(false)
                        .build());

                // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
                // outputsize_H(W) = (8 - 2 + 2*0) / 2 + 1 = 4
                // NCHW: [minibatch, 50, 4, 4]

                int channels_height_width = 4 * 4 * 50;
                SDVariable layer3_reshaped = layer3.reshape(-1, channels_height_width);

                SDVariable w4 = sd.var("w4", Nd4j.rand(DataType.FLOAT, channels_height_width, 500).muli(0.01));
                SDVariable b4 = sd.var("b4", Nd4j.rand(DataType.FLOAT, 500).muli(0.01));

                SDVariable layer4 = sd.nn.relu("layer4", layer3_reshaped.mmul(w4).add(b4), 0);

                SDVariable w5 = sd.var("w5", Nd4j.rand(DataType.FLOAT, 500, outputNum));
                SDVariable b5 = sd.var("b5", Nd4j.rand(DataType.FLOAT, outputNum));

                SDVariable out = sd.nn.softmax("out", layer4.mmul(w5).add(b5));
                SDVariable loss = sd.loss.logLoss("loss", label, out);

                //Also set the training configuration:
                sd.setTrainingConfig(TrainingConfig.builder()
                        .updater(new Adam(1e-3))
                        .l2(1e-3)
                        .dataSetFeatureMapping("in")        //features[0] -> "in" placeholder
                        .dataSetLabelMapping("label")       //labels[0] -> "label" placeholder
                        .build());

                return sd;
            }

            @Override
            public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
                //Single minibatch of 8 MNIST examples, mapped to the two placeholders
                DataSet ds = new MnistDataSetIterator(8, true, 12345).next();
                Map<String, INDArray> map = new HashMap<>();
                map.put("in", ds.getFeatures());
                map.put("label", ds.getLabels());
                return map;
            }

            @Override
            public MultiDataSetIterator getTrainingData() throws Exception {
                DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
                //Limit training to 60 minibatches of 16
                iter = new EarlyTerminationDataSetIterator(iter, 60);
                return new MultiDataSetIteratorAdapter(iter);
            }

            @Override
            public MultiDataSetIterator getEvaluationTestData() throws Exception {
                //10 minibatches of 32 from the MNIST test set
                return new MultiDataSetIteratorAdapter(new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, false, 12345), 10));
            }

            @Override
            public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
                DataSetIterator iter = new MnistDataSetIterator(8, true, 12345);

                List<Map<String, INDArray>> list = new ArrayList<>();

                org.nd4j.linalg.dataset.DataSet ds = iter.next();
                ds = ds.asList().get(0);    //First prediction: a single example

                list.add(Collections.singletonMap("in", ds.getFeatures()));
                ds = iter.next();           //Second prediction: a full minibatch of 8
                list.add(Collections.singletonMap("in", ds.getFeatures()));
                return list;
            }

            @Override
            public List<String> getPredictionsNamesSameDiff() {
                return Collections.singletonList("out");
            }

            @Override
            public IEvaluation[] getNewEvaluations() {
                return new IEvaluation[]{
                        new Evaluation(),
                        new ROCMultiClass(),
                        new EvaluationCalibration()};
            }

            @Override
            public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
                //Evaluate the "out" activations against labels[0] of the iterator
                sd.evaluate(iter, "out", 0, evaluations);
                return evaluations;
            }

        };
    }


    /**
     * Small synthetic 3D CNN: conv3d(3x3x3,8, stride 2) -> relu -> maxpool3d(2x2x2, stride 2)
     * -> dense(10) -> softmax, trained with Nesterov momentum on log loss.
     * Input is random NCDHW data of shape [minibatch, 3, 8, 8, 8] with random one-hot labels.
     *
     * @return the configured test case (random-init type)
     */
    public static TestCase getCnn3dSynthetic() {
        return new TestCase() {
            {
                testName = "Cnn3dSynthetic";
                testType = TestType.RANDOM_INIT;
                testPredictions = true;
                testTrainingCurves = true;
                testGradients = true;
                testParamsPostTraining = true;
                testEvaluation = true;
                testOverfitting = false;
            }

            @Override
            public ModelType modelType() {
                return ModelType.SAMEDIFF;
            }

            @Override
            public Object getConfiguration() throws Exception {
                Nd4j.getRandom().setSeed(12345);

                int nChannels = 3;      // Number of input channels
                int outputNum = 10;     // The number of possible outcomes

                SameDiff sd = SameDiff.create();

                //Input in NCDHW: [minibatch, channels=3, depth=8, height=8, width=8]
                SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, nChannels, 8, 8, 8);

                //Fix: minibatch dimension must be dynamic (-1), not nChannels - the test data
                //(randomOneHot(2, 10)) has shape [2, 10], which a fixed leading dim of 3 cannot match
                SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, outputNum);

                // Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]
                // [kernelDepth=3, kernelHeight=3, kernelWidth=3, inputChannels=3, outputChannels=8]
                SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 3, 3, 3, nChannels, 8));
                // Optional 1D bias array with shape [outputChannels]. May be null.
                SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 8));

                SDVariable layer0 = sd.nn.relu(sd.cnn.conv3d("layer0", in, w0, b0, Conv3DConfig.builder()
                        .kH(3)
                        .kW(3)
                        .kD(3)
                        .sH(2)
                        .sW(2)
                        .sD(2)
                        .dataFormat("NCDHW")
                        .build()), 0);

                // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
                // outputsize_H(W)(D) = (8 - 3 + 2*0) / 2 + 1 = 3
                // [minibatch, 8, 3, 3, 3]

                SDVariable layer1 = sd.cnn.maxPooling3d("layer1", layer0, Pooling3DConfig.builder()
                        .kH(2).kW(2).kD(2)
                        .sH(2).sW(2).sD(2)
                        .isNCDHW(true)
                        .build());

                // outputSize = (inputSize - kernelSize + 2*padding) / stride + 1
                // outputsize_H(W)(D) = (3 - 2 + 2*0) / 2 + 1 = 1
                // [minibatch, 8, 1, 1, 1]

                int channels_height_width_depth = 8 * 1 * 1 * 1;

                SDVariable layer1_reshaped = layer1.reshape(-1, channels_height_width_depth);

                //NOTE(review): SameDiff variable names "w4"/"b4" look like copy-paste leftovers from the
                //LeNet case, but are kept as-is so any name-keyed parameter baselines remain valid
                SDVariable w1 = sd.var("w4", Nd4j.rand(DataType.FLOAT, channels_height_width_depth, 10));
                SDVariable b1 = sd.var("b4", Nd4j.rand(DataType.FLOAT, 10));

                SDVariable out = sd.nn.softmax("out", layer1_reshaped.mmul(w1).add(b1));
                SDVariable loss = sd.loss.logLoss("loss", label, out);

                //Also set the training configuration:
                sd.setTrainingConfig(TrainingConfig.builder()
                        .updater(new Nesterovs(0.01, 0.9))
                        .dataSetFeatureMapping("in")        //features[0] -> "in" placeholder
                        .dataSetLabelMapping("label")       //labels[0] -> "label" placeholder
                        .build());

                return sd;
            }

            @Override
            public Map<String, INDArray> getGradientsTestDataSameDiff() throws Exception {
                Nd4j.getRandom().setSeed(12345);
                //NCDHW format: [2, 3, 8, 8, 8] with one-hot labels [2, 10]
                INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8});
                INDArray labels = org.deeplearning4j.integration.TestUtils.randomOneHot(2, 10);

                Map<String, INDArray> map = new HashMap<>();
                map.put("in", arr);
                map.put("label", labels);
                return map;
            }

            @Override
            public List<String> getPredictionsNamesSameDiff() {
                return Collections.singletonList("out");
            }

            @Override
            public List<Map<String, INDArray>> getPredictionsTestDataSameDiff() throws Exception {
                Nd4j.getRandom().setSeed(12345);

                List<Map<String, INDArray>> list = new ArrayList<>();
                INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8});

                list.add(Collections.singletonMap("in", arr));

                return list;
            }

            @Override
            public MultiDataSet getGradientsTestData() throws Exception {
                Nd4j.getRandom().setSeed(12345);
                //NCDHW format; same seed as getGradientsTestDataSameDiff(), so same data
                INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8});
                INDArray labels = org.deeplearning4j.integration.TestUtils.randomOneHot(2, 10);
                return new org.nd4j.linalg.dataset.MultiDataSet(arr, labels);
            }

            @Override
            public MultiDataSetIterator getTrainingData() throws Exception {
                //Train on a single fixed minibatch, repeated
                return new SingletonMultiDataSetIterator(getGradientsTestData());
            }

            @Override
            public MultiDataSetIterator getEvaluationTestData() throws Exception {
                return getTrainingData();
            }

            @Override
            public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) {
                sd.evaluate(iter, "out", 0, evaluations);
                return evaluations;
            }

            @Override
            public IEvaluation[] getNewEvaluations() {
                return new IEvaluation[]{new Evaluation()};
            }

        };
    }
}
org.datavec.api.records.reader.impl.csv.CSVRecordReader; +import org.datavec.api.split.FileSplit; +import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; +import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator; import org.deeplearning4j.integration.ModelType; import org.deeplearning4j.integration.TestCase; import org.nd4j.autodiff.loss.LossReduce; @@ -26,21 +31,34 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.TrainingConfig; import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.classification.Evaluation; +import org.nd4j.evaluation.classification.EvaluationCalibration; +import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.linalg.learning.config.Nesterovs; +import org.nd4j.linalg.primitives.Pair; +import org.nd4j.resources.Resources; +import java.io.File; import java.util.*; +import static org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig.*; + public class SameDiffMLPTestCases { - 
public static TestCase getMLPMnist(){ + public static TestCase getMLPMnist() { return new TestCase() { { testName = "MLPMnistSD"; @@ -69,10 +87,10 @@ public class SameDiffMLPTestCases { SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 784); SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 10); - SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 784, 256)); - SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 256)); - SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 256, 10)); - SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10)); + SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, 784, 256).muli(0.1)); + SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, 256).muli(0.1)); + SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 256, 10).muli(0.1)); + SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10).muli(0.1)); SDVariable a0 = sd.nn.tanh(in.mmul(w0).add(b0)); SDVariable out = sd.nn.softmax("out", a0.mmul(w1).add(b1)); @@ -91,7 +109,7 @@ public class SameDiffMLPTestCases { @Override public List> getPredictionsTestDataSameDiff() throws Exception { - List> out = new ArrayList<>(); + List> out = new ArrayList<>(); DataSetIterator iter = new MnistDataSetIterator(1, true, 12345); out.add(Collections.singletonMap("in", iter.next().getFeatures())); @@ -110,7 +128,7 @@ public class SameDiffMLPTestCases { @Override public Map getGradientsTestDataSameDiff() throws Exception { DataSet ds = new MnistDataSetIterator(8, true, 12345).next(); - Map map = new HashMap<>(); + Map map = new HashMap<>(); map.put("in", ds.getFeatures()); map.put("label", ds.getLabels()); return map; @@ -153,4 +171,160 @@ public class SameDiffMLPTestCases { }; } + + public static TestCase getMLPMoon() { + return new TestCase() { + { + testName = "MLPMoonSD"; + testType = TestType.RANDOM_INIT; + testPredictions = true; + testTrainingCurves = true; + testGradients = true; + testParamsPostTraining = true; + testEvaluation = true; + testOverfitting 
= true; + maxRelativeErrorOverfit = 2e-2; + minAbsErrorOverfit = 1e-2; + } + + @Override + public ModelType modelType() { + return ModelType.SAMEDIFF; + } + + @Override + public Object getConfiguration() throws Exception { + + int numInputs = 2; + int numOutputs = 2; + int numHiddenNodes = 20; + double learningRate = 0.005; + + + Nd4j.getRandom().setSeed(12345); + + //Define the network structure: + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, numInputs); + SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, numOutputs); + + SDVariable w0 = sd.var("w0", Nd4j.rand(DataType.FLOAT, numInputs, numHiddenNodes)); + SDVariable b0 = sd.var("b0", Nd4j.rand(DataType.FLOAT, numHiddenNodes)); + SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, numHiddenNodes, numOutputs)); + SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, numOutputs)); + + SDVariable a0 = sd.nn.relu(in.mmul(w0).add(b0), 0); + SDVariable out = sd.nn.softmax("out", a0.mmul(w1).add(b1)); + SDVariable loss = sd.loss.logLoss("loss", label, out); + + //Also set the training configuration: + sd.setTrainingConfig(TrainingConfig.builder() + .updater(new Nesterovs(learningRate, 0.9)) + .weightDecay(1e-3, true) + .dataSetFeatureMapping("in") //features[0] -> "in" placeholder + .dataSetLabelMapping("label") //labels[0] -> "label" placeholder + .build()); + + return sd; + } + + @Override + public List> getPredictionsTestDataSameDiff() throws Exception { + List> out = new ArrayList<>(); + + File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv"); + + RecordReader rr = new CSVRecordReader(); + rr.initialize(new FileSplit(f)); + DataSetIterator iter = new RecordReaderDataSetIterator(rr, 1, 0, 2); + + out.add(Collections.singletonMap("in", iter.next().getFeatures())); + + + return out; + } + + + @Override + public List getPredictionsNamesSameDiff() throws Exception { + return Collections.singletonList("out"); + } + + @Override + public 
Map getGradientsTestDataSameDiff() throws Exception { + + File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv"); + RecordReader rr = new CSVRecordReader(); + rr.initialize(new FileSplit(f)); + org.nd4j.linalg.dataset.DataSet ds = new RecordReaderDataSetIterator(rr, 5, 0, 2).next(); + + Map map = new HashMap<>(); + map.put("in", ds.getFeatures()); + map.put("label", ds.getLabels()); + return map; + } + + @Override + public MultiDataSetIterator getTrainingData() throws Exception { + File f = Resources.asFile("dl4j-integration-tests/data/moon_data_train.csv"); + RecordReader rr = new CSVRecordReader(); + rr.initialize(new FileSplit(f)); + DataSetIterator iter = new RecordReaderDataSetIterator(rr, 32, 0, 2); + + iter = new EarlyTerminationDataSetIterator(iter, 32); + return new MultiDataSetIteratorAdapter(iter); + } + + @Override + public IEvaluation[] getNewEvaluations() { + return new IEvaluation[]{ + new Evaluation(), + new ROCMultiClass(), + new EvaluationCalibration()}; + } + + @Override + public MultiDataSetIterator getEvaluationTestData() throws Exception { + File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv"); + RecordReader rr = new CSVRecordReader(); + rr.initialize(new FileSplit(f)); + DataSetIterator iter = new RecordReaderDataSetIterator(rr, 32, 0, 2); + return new MultiDataSetIteratorAdapter(iter); + } + + + @Override + public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) { + sd.evaluate(iter, "out", 0, evaluations); + return evaluations; + } + + @Override + public MultiDataSet getOverfittingData() throws Exception { + + File f = Resources.asFile("dl4j-integration-tests/data/moon_data_eval.csv"); + RecordReader rr = new CSVRecordReader(); + rr.initialize(new FileSplit(f)); + return new RecordReaderDataSetIterator(rr, 1, 0, 2).next().toMultiDataSet(); + } + + @Override + public int getOverfitNumIterations() { + return 200; + } + }; + + } } + + + + + + + + 
+ + + + diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffRNNTestCases.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffRNNTestCases.java new file mode 100644 index 000000000..6bc6254c9 --- /dev/null +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffRNNTestCases.java @@ -0,0 +1,289 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.integration.testcases.samediff; + +import org.datavec.api.records.reader.SequenceRecordReader; +import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; +import org.datavec.api.split.NumberedFileInputSplit; +import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator; +import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; +import org.deeplearning4j.integration.ModelType; +import org.deeplearning4j.integration.TestCase; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.TrainingConfig; +import org.nd4j.evaluation.IEvaluation; +import org.nd4j.evaluation.classification.Evaluation; +import org.nd4j.evaluation.classification.EvaluationCalibration; +import org.nd4j.evaluation.classification.ROCMultiClass; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; +import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor; +import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization; +import 
org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.resources.Resources; +import org.nd4j.shade.guava.io.Files; + +import java.io.File; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class SameDiffRNNTestCases { + + public static TestCase getRnnCsvSequenceClassificationTestCase1() { + return new SameDiffRNNTestCases.RnnCsvSequenceClassificationTestCase1(); + } + + protected static class RnnCsvSequenceClassificationTestCase1 extends TestCase { + protected RnnCsvSequenceClassificationTestCase1() { + testName = "RnnCsvSequenceClassification1"; + testType = TestType.RANDOM_INIT; + testPredictions = true; + testTrainingCurves = false; + testGradients = false; + testParamsPostTraining = false; + testEvaluation = true; + testOverfitting = false; //Not much point on this one - it already fits very well... 
+ } + + + protected MultiDataNormalization normalizer; + + protected MultiDataNormalization getNormalizer() throws Exception { + if (normalizer != null) { + return normalizer; + } + + normalizer = new MultiNormalizerStandardize(); + normalizer.fit(getTrainingDataUnnormalized()); + + return normalizer; + } + + + @Override + public ModelType modelType() { + return ModelType.SAMEDIFF; + } + + + @Override + public Object getConfiguration() throws Exception { + Nd4j.getRandom().setSeed(12345); + + + int miniBatchSize = 10; + int numLabelClasses = 6; + int nIn = 60; + int numUnits = 7; + int timeSteps = 3; + + + SameDiff sd = SameDiff.create(); + + SDVariable in = sd.placeHolder("in", DataType.FLOAT, miniBatchSize, timeSteps, nIn); + SDVariable label = sd.placeHolder("label", DataType.FLOAT, miniBatchSize, numLabelClasses); + + + SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, miniBatchSize, numUnits)); + SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, miniBatchSize, numUnits)); + + LSTMLayerConfig c = LSTMLayerConfig.builder() + .lstmdataformat(LSTMDataFormat.NTS) + .directionMode(LSTMDirectionMode.FWD) + .gateAct(LSTMActivations.SIGMOID) + .cellAct(LSTMActivations.TANH) + .outAct(LSTMActivations.TANH) + .retFullSequence(true) + .retLastC(true) + .retLastH(true) + .build(); + + LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer( + in, cLast, yLast, null, + LSTMLayerWeights.builder() + .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits))) + .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits))) + .peepholeWeights(sd.var("inputPeepholeWeights", Nd4j.rand(DataType.FLOAT, 3 * numUnits))) + .bias(sd.var("bias", Nd4j.rand(DataType.FLOAT, 4 * numUnits))) + .build(), + c), c); + + +// Behaviour with default settings: 3d (time series) input with shape +// [miniBatchSize, vectorSize, timeSeriesLength] -> 2d output [miniBatchSize, vectorSize] + SDVariable layer0 = outputs.getOutput(); + + 
SDVariable layer1 = layer0.mean(1); + + SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, numUnits, numLabelClasses)); + SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, numLabelClasses)); + + + SDVariable out = sd.nn.softmax("out", layer1.mmul(w1).add(b1)); + SDVariable loss = sd.loss.logLoss("loss", label, out); + + //Also set the training configuration: + sd.setTrainingConfig(TrainingConfig.builder() + .updater(new Adam(5e-2)) + .l1(1e-3).l2(1e-3) + .dataSetFeatureMapping("in") //features[0] -> "in" placeholder + .dataSetLabelMapping("label") //labels[0] -> "label" placeholder + .build()); + + return sd; + + } + + + @Override + public List> getPredictionsTestDataSameDiff() throws Exception { + + MultiDataSet mds = getTrainingData().next(); + + List> list = new ArrayList<>(); + + list.add(Collections.singletonMap("in", mds.getFeatures()[0].reshape(10, 1, 60))); + //[batchsize, insize] + + return list; + } + + @Override + public List getPredictionsNamesSameDiff() throws Exception { + return Collections.singletonList("out"); + } + + + @Override + public MultiDataSetIterator getTrainingData() throws Exception { + MultiDataSetIterator iter = getTrainingDataUnnormalized(); + MultiDataSetPreProcessor pp = multiDataSet -> { + INDArray l = multiDataSet.getLabels(0); + l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2) - 1)); + multiDataSet.setLabels(0, l); + multiDataSet.setLabelsMaskArray(0, null); + }; + + + iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), pp)); + + return iter; + } + + protected MultiDataSetIterator getTrainingDataUnnormalized() throws Exception { + int miniBatchSize = 10; + int numLabelClasses = 6; + + File featuresDirTrain = Files.createTempDir(); + File labelsDirTrain = Files.createTempDir(); + Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/features/", featuresDirTrain); + Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/train/labels/", 
labelsDirTrain); + + SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(); + trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449)); + SequenceRecordReader trainLabels = new CSVSequenceRecordReader(); + trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449)); + + DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses, + false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); + + MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(trainData); + + return iter; + } + + @Override + public IEvaluation[] getNewEvaluations() { + return new IEvaluation[]{ + new Evaluation(), + new ROCMultiClass(), + new EvaluationCalibration() + }; + } + + @Override + public MultiDataSetIterator getEvaluationTestData() throws Exception { + int miniBatchSize = 10; + int numLabelClasses = 6; + +// File featuresDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/features/").getFile(); +// File labelsDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/labels/").getFile(); + File featuresDirTest = Files.createTempDir(); + File labelsDirTest = Files.createTempDir(); + Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/features/", featuresDirTest); + Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/labels/", labelsDirTest); + + SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(); + trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149)); + SequenceRecordReader trainLabels = new CSVSequenceRecordReader(); + trainLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149)); + + DataSetIterator testData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses, + false, 
SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); + + MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(testData); + + MultiDataSetPreProcessor pp = multiDataSet -> { + INDArray l = multiDataSet.getLabels(0); + l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2) - 1)); + multiDataSet.setLabels(0, l); + multiDataSet.setLabelsMaskArray(0, null); + }; + + + iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), pp)); + + return iter; + } + + @Override + public IEvaluation[] doEvaluationSameDiff(SameDiff sd, MultiDataSetIterator iter, IEvaluation[] evaluations) { + sd.evaluate(iter, "out", 0, evaluations); + return evaluations; + } + } + + +} + + + + + + + + + + + + + diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp index 94c795401..d09a40120 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp @@ -368,7 +368,7 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) { REQUIRE_TRUE(hasSeqLen == false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support array specifying max time step per each example in batch !"); REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC!"); REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !"); - REQUIRE_TRUE((retLastH && retLastC) || (!retLastH && !retLastC), 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !"); + REQUIRE_TRUE(retLastH == retLastC, 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not 
calculate both !"); count = 0; auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output @@ -464,13 +464,21 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) { } PLATFORM_CHECK(lstmLayer, ENGINE_CPU) { + + const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) + const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided const auto hasInitH = B_ARG(2); // indicates whether initial output is provided const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = B_ARG(4); // indicates whether peephole connections are present const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto x = INPUT_VARIABLE(0); // input const auto Wx = INPUT_VARIABLE(1); // input weights const auto Wr = INPUT_VARIABLE(2); // recurrent weights @@ -495,7 +503,15 @@ PLATFORM_CHECK(lstmLayer, ENGINE_CPU) { DataType hLType = hL != nullptr ? hL->dataType() : xType; DataType cLType = cL != nullptr ? 
cL->dataType() : xType; - return block.isUseMKLDNN() && ( + auto featuresSupported = (cellClip == 0) //Cell clipping not supported + && retFullSeq //Always return full sequence in case of MKL DNN + && !hasPH //Peephole connections not supported in MKL DNN + && !hasSeqLen //Sequence length array not supported in MKL DNN + && dataFormat < 2 //Data format - only 0 and 1 supported in MKL DNN- 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn] + && directionMode < 4 //Direction mode - only 0-3 supported in MKL DNN (no extra dim option) - 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat + && retLastH == retLastC; //Return both lastH and lastC, or return neither (not just 1 or other) + + return block.isUseMKLDNN() && featuresSupported && ( (xType==DataType::FLOAT32 && WxType==DataType::FLOAT32 && WrType==DataType::FLOAT32 && bType==DataType::FLOAT32 && hIType==DataType::FLOAT32 && cIType==DataType::FLOAT32 && hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) || (xType==DataType::HALF && WxType==DataType::HALF && WrType==DataType::HALF && bType==DataType::HALF && hIType==DataType::HALF && cIType==DataType::HALF && hType==DataType::HALF && hLType==DataType::HALF && cLType==DataType::HALF) || (xType==DataType::UINT8 && WxType==DataType::INT8 && WrType==DataType::INT8 && bType==DataType::FLOAT32 && hIType==DataType::UINT8 && cIType==DataType::UINT8 && (hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32 || hType==DataType::UINT8 && hLType==DataType::UINT8 && cLType==DataType::UINT8)) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index fcb63ea0a..093e3099b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -2148,7 +2148,7 @@ public class DifferentialFunctionFactory { public SDVariable gatherNd(SDVariable df, SDVariable indices) { validateDifferentialFunctionsameDiff(df); - return new GatherNd(sameDiff(), df, indices, false).outputVariable(); + return new GatherNd(sameDiff(), df, indices).outputVariable(); } public SDVariable trace(SDVariable in){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index 65416a659..3b29e6ccb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.weightinit.WeightInitScheme; import java.io.Serializable; @@ -244,7 +245,7 @@ public class SDVariable implements Serializable { * @return new variable */ public SDVariable assign(Number value){ - return sameDiff.scalarSet(this, value); + return sameDiff.scalarSet(this, value.doubleValue()); } /** @@ -538,7 +539,7 @@ public class SDVariable implements Serializable { * @return Output variable (result of mmul) */ public SDVariable mmul(String name, SDVariable other, @NonNull MMulTranspose mMulTranspose) { - return sameDiff.mmul(name, this, other, mMulTranspose); + return sameDiff.mmul(name, this, other, mMulTranspose.isTransposeA(), mMulTranspose.isTransposeB(), mMulTranspose.isTransposeResult()); } @@ -1403,7 +1404,7 @@ public class SDVariable implements Serializable { * @return 
Output variable */ public SDVariable reshape(int... newShape){ - return sameDiff.reshape(this, newShape); + return sameDiff.reshape(this, ArrayUtil.toLongArray(newShape)); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index ab3279fd0..c51ac28a1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -53,6 +53,7 @@ import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge; import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray; @@ -78,6 +79,7 @@ import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ND4JFileUtils; import org.nd4j.shade.guava.collect.HashBasedTable; +import org.nd4j.shade.guava.collect.Sets; import org.nd4j.shade.guava.collect.Table; import org.nd4j.shade.guava.primitives.Ints; import org.nd4j.weightinit.WeightInitScheme; @@ -104,7 +106,6 @@ import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs; *

* In order to execute the graph, you run one of the execution methods, such as {@link #output(Map, String...)} */ -@AllArgsConstructor @Slf4j public class SameDiff extends SDBaseOps { protected static final String GRAD_FN_KEY = "grad"; @@ -914,6 +915,8 @@ public class SameDiff extends SDBaseOps { } private SameDiff() { + super(null); + super.sd = this; functionFactory = new DifferentialFunctionFactory(this); sameDiffFunctionInstances = new LinkedHashMap<>(); fieldVariableResolutionMapping = HashBasedTable.create(); @@ -4544,7 +4547,7 @@ public class SameDiff extends SDBaseOps { } //Also exclude assert etc ops - doesn't make sense to return these "outputs" to user - if (v.getOutputOfOp() != null) { + if (v.getOutputOfOp() != null && v.getVariable().dataType().isFPType()) { String opName = v.getOutputOfOp(); SameDiffOp o = ops.get(opName); if (o.getOp() instanceof Assert) { @@ -4621,12 +4624,6 @@ public class SameDiff extends SDBaseOps { return varToUpdate; } - @Override - protected SameDiff sd() { - //Helper method for SDBaseOps etc - return this; - } - /** * Updates the variable name property on the passed in variables, its reference in samediff, and returns the variable. @@ -5840,7 +5837,6 @@ public class SameDiff extends SDBaseOps { * See {@link #generateNewVarName(String, int, boolean)} * existingOp is true. 
*/ - @Override public String generateNewVarName(String base, int argIndex) { return generateNewVarName(base, argIndex, true); } @@ -5868,4 +5864,261 @@ public class SameDiff extends SDBaseOps { public String toString(){ return "SameDiff(nVars=" + variables.size() + ",nOps=" + ops.size() + ")"; } + + + + /** + * See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)} + */ + public SDVariable ifCond(@NonNull SameDiffNoArgSingleLambda cond, + @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){ + return ifCond(null, null, cond, trueBody, falseBody); + } + + + /** + * See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)} + */ + public SDVariable ifCond(String ifName, @NonNull SameDiffNoArgSingleLambda cond, + @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){ + return ifCond(null, ifName, cond, trueBody, falseBody); + } + + /** + * Constructs a If statement using the tensorflow style control flow operations (Switch and Merge) + * + * If the result of cond is true, returns the result of trueBody, otherwise returns the result of falseBody + * + * Note that cond and body lambdas are only called once to construct the graph. The constructed graph is used to evaluate. + * + * See Tensorflow Control Flow Implementation + * + * @param outputName Name to give the output variable. If null, doesn't rename + * @param ifName The name of the if block. 
If null, uses "if" + * @param cond A lambda evaluating to the if condition + * @param trueBody A lambda to be executed if cond is true (the if block) + * @param falseBody A lambda to be executed if cond is false (the else block) + * @return The value of trueBody if cond is true, or falseBody if it isn't + */ + public SDVariable ifCond(String outputName, String ifName, @NonNull SameDiffNoArgSingleLambda cond, + @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){ + + ifName = newBlockName(ifName == null ? "if" : ifName); + + NameScope ifScope = sd.withNameScope(ifName); + + NameScope condScope = withNameScope("cond"); + final SDVariable pred = cond.define(this); + condScope.close(); + + if (pred.dataType() != DataType.BOOL) { + //cleanup partially added block + + for(SDVariable v : getVariablesInScope(ifScope)) + this.getVariables().remove(v.name()); + + for(SameDiffOp op : this.getOpsInScope(ifScope)) { + for(String in : op.getInputsToOp()){ + this.removeArgFromOp(in, op.getOp()); + } + this.getOps().remove(op.getName()); + } + + + throw new IllegalStateException("Can not use " + pred.name() + + " as the condition of an If statement, the condition must be a boolean."); + } + + final Map switches = new HashMap<>(); + + final Set declared = Sets.newHashSet(this.variableMap().keySet()); + + this.addArgumentInterceptor(new ArgumentInterceptor() { + @Override + public SDVariable intercept(SDVariable argument) { + + // if its declared in the if, we don't care acout it + if(!declared.contains(argument.name())) + return argument; + + // if we've already added a switch, move on + if(switches.containsKey(argument.name())) + return switches.get(argument.name())[1]; + + SDVariable[] s = f().switchOp(argument, pred); + switches.put(argument.name(), s); + return s[1]; + } + }); + NameScope trueScope = this.withNameScope("trueBody"); + SDVariable trueOut = trueBody.define(this); + this.removeArgumentInterceptor(); + + 
if(declared.contains(trueOut.name())) { + SDVariable[] s = f().switchOp(trueOut, pred); + switches.put(trueOut.name(), s); + trueOut = s[1]; + } + + trueScope.close(); + + final Set declared2 = Sets.newHashSet(variableMap().keySet()); + sd.addArgumentInterceptor(new ArgumentInterceptor() { + @Override + public SDVariable intercept(SDVariable argument) { + + // if its declared in the if, we don't care acout it + if(!declared2.contains(argument.name())) + return argument; + + // if we've already added a switch, move on + if(switches.containsKey(argument.name())) + return switches.get(argument.name())[0]; + + SDVariable[] s = f().switchOp(argument, pred); + switches.put(argument.name(), s); + return s[0]; + } + }); + NameScope falseScope = this.withNameScope("falseBody"); + SDVariable falseOut = falseBody.define(this); + this.removeArgumentInterceptor(); + + if(declared2.contains(falseOut.name())) { + SDVariable[] s = f().switchOp(falseOut, pred); + switches.put(falseOut.name(), s); + falseOut = s[0]; + } + falseScope.close(); + + SDVariable output = f().merge(trueOut, falseOut); + + ifScope.close(); + + return updateVariableNameAndReference(output, outputName); + } + + /** + * See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)} + */ + public SDVariable[] whileLoop(@NonNull SDVariable[] loopVars, + @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){ + return whileLoop(null, null, loopVars, cond, body); + } + + /** + * See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)} + */ + public SDVariable[] whileLoop(String loopName, @NonNull SDVariable[] loopVars, + @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){ + return whileLoop(null, loopName, loopVars, cond, body); + } + + + /** + * Constructs a While loop using the tensorflow style control flow operations (Switch, Merge, Enter, Exit, and NextIteration) + * + * Repeatedly executes body on the loop variables 
and updates them with the results, until cond evaluates to false + * + * Note that cond and body lambdas are only called once to construct the graph. The constructed graph is used for further iterations. + * + * See Tensorflow Control Flow Implementation + * + * @param outputNames Names to give the output variables. If null, doesn't rename + * @param loopName The name of the loop block and frame (must be unique). If null, uses "if" + * @param loopVars Loop variables' inputs + * @param cond A lambda evaluating to the loop condition + * @param body A lambda doing the loop operation and returning the new loop variable values + * @return The values of the loop variables once condition is false + */ + public SDVariable[] whileLoop(String[] outputNames, final String loopName, @NonNull SDVariable[] loopVars, + @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){ + + final String frameName = this.newBlockName(loopName == null ? "while" : loopName); + + NameScope loopScope = this.withNameScope(frameName); + + //SDVariable counter = SD.scalar(SD.generateNewVarName("counter", 0), 0); + + SDVariable[] entered = new SDVariable[loopVars.length]; + for(int i = 0 ; i < loopVars.length ; i++){ + entered[i] = f().enter(loopVars[i], frameName); + } + + //counter = SD.f().enter(counter, frameName); + + SDVariable[] merged = new SDVariable[loopVars.length]; + Merge[] mergeOps = new Merge[loopVars.length]; + for(int i = 0 ; i < loopVars.length ; i++){ + // the second arg will later be replaced with the output of NextIteration + // but that isn't available yet (and can't be, as it depends on this) + mergeOps[i] = new Merge(this, entered[i], entered[i]); + merged[i] = mergeOps[i].outputVariable(); + } + + //Merge counterMerge = new Merge(SD, counter, counter); + //counter = counterMerge.outputVariable(); + + NameScope condScope = this.withNameScope("cond"); + SDVariable cond_result = cond.define(this, merged); + condScope.close(); + + + if (cond_result.dataType() != 
DataType.BOOL) + throw new IllegalStateException("Can not use " + cond_result.name() + " as the condition of an While loop, the condition must be a boolean."); + + + final Set alreadyEntered = Sets.newHashSet(); + SDVariable[] trueSwitches = new SDVariable[loopVars.length]; + SDVariable[] exits = new SDVariable[loopVars.length]; + for(int i = 0 ; i < loopVars.length ; i++){ + SDVariable[] s = f().switchOp(merged[i], cond_result); + trueSwitches[i] = s[1]; + alreadyEntered.add(s[1].name()); + exits[i] = f().exit(s[0]); + } + + //SDVariable[] cs = SD.f().switchOp(counter, cond_result); + //SDVariable counterExit = SD.f().exit(cs[0]); + //counter = cs[1]; + + final Set declared = Sets.newHashSet(this.variableMap().keySet()); + final Map done = new HashMap<>(); + + this.addArgumentInterceptor(new ArgumentInterceptor() { + @Override + public SDVariable intercept(SDVariable argument) { + + if(!declared.contains(argument.name())) + return argument; + + if(alreadyEntered.contains(argument.name())) + return argument; + + if(done.containsKey(argument.name())) + return done.get(argument.name()); + + SDVariable e = f().enter(argument, frameName, true); + done.put(argument.name(), e); + return e; + } + }); + + NameScope bodyScope = this.withNameScope("body"); + SDVariable[] outs = body.define(this, trueSwitches); + bodyScope.close(); + this.removeArgumentInterceptor(); + + //counter.add(1); + + for(int i = 0 ; i < loopVars.length ; i++){ + SDVariable n = f().nextIteration(outs[i]); + mergeOps[i].replaceArg(1,n); + } + + //counterMerge.replaceArg(1, counter); + + loopScope.close(); + return updateVariableNamesAndReferences(exits, outputNames); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java index 8e5d1ca36..3b53e5b65 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,3403 +14,4720 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; -import org.nd4j.shade.guava.collect.Sets; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; -import lombok.NonNull; -import org.nd4j.autodiff.functions.DifferentialFunctionFactory; -import org.nd4j.autodiff.samediff.ArgumentInterceptor; -import org.nd4j.autodiff.samediff.NameScope; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.samediff.SameDiffLambda; -import org.nd4j.autodiff.samediff.SameDiffNoArgSingleLambda; -import org.nd4j.autodiff.samediff.SameDiffSingleLambda; -import org.nd4j.autodiff.samediff.internal.SameDiffOp; -import org.nd4j.linalg.api.blas.params.MMulTranspose; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge; -import org.nd4j.linalg.api.ops.impl.shape.OneHot; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker; import org.nd4j.linalg.indexing.conditions.Condition; -import static org.nd4j.autodiff.samediff.ops.SDValidation.*; - -/** - * Core op 
creator methods available via SameDiff class directly - * - * @author Alex Black - * @see SDMath SDMath for Math operations - * @see SDRandom SDRandom for random number generator operations - * @see SDNN SDNN for general neural network operations - * @see SDCNN SDCNN for Convolutional Neural Network operations - * @see SDRNN SDRNN for Recurrent Neural Network operations - * @see SDLoss SDLoss for loss function operations - */ -public abstract class SDBaseOps { - - /** - * Intended for internal/developer use - */ - protected SDVariable gradientBackwardsMarker(SDVariable x) { - return gradientBackwardsMarker(generateNewVarName(new GradientBackwardsMarker().opName(), 0), x); - } - - /** - * Intended for internal/developer use - */ - protected SDVariable gradientBackwardsMarker(String name, SDVariable x) { - SDVariable result = f().gradientBackwardsMarker(x); - return updateVariableNameAndReference(result, name); - } - - protected abstract String generateNewVarName(String baseName, int argIndex); - - protected abstract DifferentialFunctionFactory f(); - - protected abstract SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName); - - protected abstract SameDiff sd(); - - /** - * Argmax array reduction operation, optionally along specified dimensions.
- * Output values are the index of the maximum value of each slice along the specified dimension - * - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable argmax(SDVariable in, int... dimensions) { - return argmax(null, in, false, dimensions); - } - - /** - * Argmax array reduction operation, optionally along specified dimensions.
- * Output values are the index of the maximum value of each slice along the specified dimension.
- *
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Name of the output variable - * @param in Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) if keepDims = false, or - * of rank (input rank) if keepdims = true - */ - public SDVariable argmax(String name, SDVariable in, boolean keepDims, int... dimensions) { - validateNumerical("argmax", in); - SDVariable ret = f().argmax(in, keepDims, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #argmax(String, SDVariable, boolean, int...) - */ - public SDVariable argmax(SDVariable in, boolean keepDims, int... dimensions) { - return argmax(null, in, keepDims, dimensions); - } - - /** - * Argmax array reduction operation, optionally along specified dimensions.
- * Output values are the index of the maximum value of each slice along the specified dimension - * - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable argmax(String name, SDVariable in, int... dimensions) { - return argmax(name, in, false, dimensions); - } - - /** - * Argmin array reduction operation, optionally along specified dimensions.
- * Output values are the index of the minimum value of each slice along the specified dimension - * - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable argmin(SDVariable in, int... dimensions) { - return argmin(null, in, dimensions); - } - - /** - * Argmin array reduction operation, optionally along specified dimensions.
- * Output values are the index of the minimum value of each slice along the specified dimension - * - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable argmin(String name, SDVariable in, int... dimensions) { - return argmin(name, in, false, dimensions); - } - - /** - * Argmin array reduction operation, optionally along specified dimensions.
- * Output values are the index of the minimum value of each slice along the specified dimension.
- *
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Name of the output variable - * @param in Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) if keepDims = false, or - * of rank (input rank) if keepdims = true - */ - public SDVariable argmin(String name, SDVariable in, boolean keepDims, int... dimensions) { - validateNumerical("argmin", in); - SDVariable ret = f().argmin(in, keepDims, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #argmin(String, SDVariable, boolean, int...) - */ - public SDVariable argmin(SDVariable in, boolean keepDims, int... dimensions) { - return argmin(null, in, keepDims, dimensions); - } - - /** - * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same - * length and each pair taken from these sets has to have dimensions (M, N) and (N, K), - * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead. - * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N). - *

- *

- * The result of this operation will be a batch of multiplied matrices. The - * result has the same length as both input batches and each output matrix is of shape (M, K). - * - * @param matricesA First array of input matrices, all of shape (M, N) or (N, M) - * @param matricesB Second array of input matrices, all of shape (N, K) or (K, N) - * @param transposeA whether first batch of matrices is transposed. - * @param transposeB whether second batch of matrices is transposed. - * @return Array of multiplied SDVariables of shape (M, K) - */ - public SDVariable[] batchMmul(SDVariable[] matricesA, SDVariable[] matricesB, - boolean transposeA, boolean transposeB) { - return batchMmul(null, matricesA, matricesB, transposeA, transposeB); - } - - /** - * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same - * length and each pair taken from these sets has to have dimensions (M, N) and (N, K), - * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead. - * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N). - *

- *

- * The result of this operation will be a batch of multiplied matrices. The - * result has the same length as both input batches and each output matrix is of shape (M, K). - * - * @param matricesA First array of input matrices, all of shape (M, N) or (N, M) - * @param matricesB Second array of input matrices, all of shape (N, K) or (K, N) - * @param transposeA whether first batch of matrices is transposed. - * @param transposeB whether second batch of matrices is transposed. - * @param names names for all provided SDVariables - * @return Array of multiplied SDVariables of shape (M, K) - */ - public SDVariable[] batchMmul(String[] names, SDVariable[] matricesA, SDVariable[] matricesB, - boolean transposeA, boolean transposeB) { - validateSameType("batchMmul", true, matricesA); - validateSameType("batchMmul", true, matricesB); - SDVariable[] result = f().batchMmul(matricesA, matricesB, transposeA, transposeB); - return updateVariableNamesAndReferences(result, names); - } - - protected abstract SDVariable[] updateVariableNamesAndReferences(SDVariable[] variablesToUpdate, String[] newVariableNames); - - /** - * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same - * length and each pair taken from these sets has to have dimensions (M, N) and (N, K), - * respectively. The result of this operation will be a batch of multiplied matrices. The - * result has the same length as both input batches and each output matrix is of shape (M, K). 
- * - * @param matricesA First array of input matrices, all of shape (M, N) - * @param matricesB Second array of input matrices, all of shape (N, K) - * @return Array of multiplied SDVariables of shape (M, K) - */ - public SDVariable[] batchMmul(SDVariable[] matricesA, SDVariable[] matricesB) { - return batchMmul(null, matricesA, matricesB, false, false); - } - - public SDVariable castTo(SDVariable toCast, org.nd4j.linalg.api.buffer.DataType toType) { - return castTo(null, toCast, toType); - } - - public SDVariable castTo(String name, SDVariable toCast, org.nd4j.linalg.api.buffer.DataType toType) { - SDVariable ret = f().cast(toCast, toType); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #concat(String, int, SDVariable...) - */ - public SDVariable concat(int dimension, SDVariable... inputs) { - return concat(null, dimension, inputs); - } - - /** - * Concatenate a set of inputs along the specified dimension.
- * Note that inputs must have identical rank and identical dimensions, other than the dimension to stack on.
- * For example, if 2 inputs have shape [a, x, c] and [a, y, c] and dimension = 1, then the output has shape [a, x+y, c] - * - * @param name Name of the output variable - * @param dimension Dimension to concatenate on - * @param inputs Input variables - * @return Output variable - * @see #stack(String, int, SDVariable...) - */ - public SDVariable concat(String name, int dimension, SDVariable... inputs) { - validateSameType("concat", false, inputs); - SDVariable result = f().concat(dimension, inputs); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #cumprod(String, SDVariable, boolean, boolean, int...) - */ - public SDVariable cumprod(SDVariable in, boolean exclusive, boolean reverse, int... axis) { - return cumprod(null, in, exclusive, reverse, axis); - } - - /** - * Cumulative product operation.
- * For input: [ a, b, c], output is:
- * exclusize=false, reverse=false: [a, a*b, a*b*c]
- * exclusive=true, reverse=false, [0, a, a*b]
- * exclusive=false, reverse=true: [a*b*c, b*c, c]
- * exclusive=true, reverse=true: [b*c, c, 0]

- * - * @param name Name of the output variable - * @param in Input variable - * @param axis Scalar axis argument for dimension to perform cumululative sum operations along - * @param exclusive If true: exclude the first value - * @param reverse If true: reverse the direction of the accumulation - * @return Output variable - */ - public SDVariable cumprod(String name, SDVariable in, boolean exclusive, boolean reverse, int... axis) { - validateNumerical("cumprod", in); - SDVariable ret = f().cumprod(in, exclusive, reverse, axis); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #cumsum(String, SDVariable, boolean, boolean, int...) - */ - public SDVariable cumsum(SDVariable in, boolean exclusive, boolean reverse, int... axis) { - return cumsum(null, in, exclusive, reverse, axis); - } - - /** - * Cumulative sum operation.
- * For input: [ a, b, c], output is:
- * exclusize=false, reverse=false: [a, a+b, a+b+c]
- * exclusive=true, reverse=false, [0, a, a+b]
- * exclusive=false, reverse=true: [a+b+c, b+c, c]
- * exclusive=true, reverse=true: [b+c, c, 0]

- * - * @param name Name of the output variable - * @param in Input variable - * @param axis Scalar axis argument for dimension to perform cumululative sum operations along - * @param exclusive If true: exclude the first value - * @param reverse If true: reverse the direction of the accumulation - * @return Output variable - */ - public SDVariable cumsum(String name, SDVariable in, boolean exclusive, boolean reverse, int... axis) { - validateNumerical("cumsum", in); - SDVariable ret = f().cumsum(in, exclusive, reverse, axis); - return updateVariableNameAndReference(ret, name); - } - - /** - * TODO doc string - * - * @param x - * @param y - * @param dimensions - * @return - */ - public SDVariable dot(SDVariable x, SDVariable y, int... dimensions) { - return dot(null, x, y, dimensions); - } - - /** - * TODO doc string - * - * @param name - * @param x - * @param y - * @param dimensions - * @return - */ - public SDVariable dot(String name, SDVariable x, SDVariable y, int... dimensions) { - SDValidation.validateNumerical("dot", x, y); - SDVariable ret = f().dot(x, y, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #dynamicPartition(String[], SDVariable, SDVariable, int) - */ - public SDVariable[] dynamicPartition(SDVariable x, SDVariable partitions, int numPartitions) { - return dynamicPartition(null, x, partitions, numPartitions); - } - - /** - * Dynamically partition the input variable values into the specified number of paritions, using the indices.
- * Example:
- *

-     * {@code input = [1,2,3,4,5]
-     * numPartitions = 2
-     * partitions = [1,0,0,1,0]
-     * out[0] = [2,3,5]
-     * out[1] = [1,4] }
-     * 
- * - * @param name Names for the output variables. Length must be equal to numPartitions - * @param x Input variable - * @param partitions 1D input with values 0 to numPartitions-1 - * @param numPartitions Number of partitions, >= 1 - * @return Output variables (equal in number to numPartitions) - */ - public SDVariable[] dynamicPartition(String[] name, SDVariable x, SDVariable partitions, int numPartitions) { - SDVariable[] ret = f().dynamicPartition(x, partitions, numPartitions); - return updateVariableNamesAndReferences(ret, name); - } - - /** - * @see #dynamicStitch(String, SDVariable[], SDVariable[]) - */ - public SDVariable dynamicStitch(SDVariable[] indices, SDVariable[] x) { - return dynamicStitch(null, indices, x); - } - - /** - * Dynamically merge the specified input arrays into a single array, using the specified indices - * - * @param name Name of the output variable - * @param indices Indices to use when merging. Must be >= 1, same length as input variables - * @param x Input variables. - * @return Merged output variable - */ - public SDVariable dynamicStitch(String name, SDVariable[] indices, SDVariable[] x) { - SDVariable ret = f().dynamicStitch(indices, x); - return updateVariableNameAndReference(ret, name); - } - - /** - * Equals operation: elementwise x == y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable eq(SDVariable x, double y) { - return eq(null, x, y); - } - - /** - * Equals operation: elementwise x == y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Name of the output variable - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable eq(String name, SDVariable x, double y) { - SDVariable result = f().eq(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Equal to operation: elementwise x == y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable eq(SDVariable x, SDVariable y) { - return eq(null, x, y); - } - - /** - * Equal to operation: elementwise x == y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param name Name of the output variable - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable eq(String name, SDVariable x, SDVariable y) { - SDVariable result = f().eq(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #expandDims(String, SDVariable, int) - */ - public SDVariable expandDims(SDVariable x, int axis) { - return expandDims(null, x, axis); - } - - /** - * Reshape the input by adding a 1 at the specified location.
- * For example, if input has shape [a, b], then output shape is:
- * axis = 0: [1, a, b]
- * axis = 1: [a, 1, b]
- * axis = 2: [a, b, 1]
- * - * @param name Name of the output variable - * @param x Input variable - * @param axis Axis to expand - * @return Output variable - * @see #squeeze(String, SDVariable, int) - */ - public SDVariable expandDims(String name, SDVariable x, int axis) { - SDVariable result = f().expandDims(x, axis); - return updateVariableNameAndReference(result, name); - } - - /** - * Generate an output variable with the specified (dynamic) shape with all elements set to the specified value - * - * @param shape Shape: must be a 1D array/variable - * @param value Value to set all elements to - * @return Output variable - */ - public SDVariable fill(SDVariable shape, org.nd4j.linalg.api.buffer.DataType dataType, double value) { - return fill(null, shape, dataType, value); - } - - /** - * Generate an output variable with the specified (dynamic) shape with all elements set to the specified value - * - * @param name Name of the output variable - * @param shape Shape: must be a 1D array/variable - * @param value Value to set all elements to - * @return Output variable - */ - public SDVariable fill(String name, SDVariable shape, org.nd4j.linalg.api.buffer.DataType dataType, double value) { - SDVariable result = f().fill(shape, dataType, value); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #gather(String, SDVariable, int[], int) - */ - public SDVariable gather(SDVariable df, int[] indices, int axis) { - return gather(null, df, indices, axis); - } - - /** - * Gather slices from the input variable where the indices are specified as fixed int[] values.
- * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length. - * - * @param name name of the output variable - * @param df Input variable - * @param indices Indices to get - * @param axis Axis that the indices refer to - * @return Output variable with slices pulled from the specified axis - */ - public SDVariable gather(String name, SDVariable df, int[] indices, int axis) { - SDVariable ret = f().gather(df, indices, axis); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #gather(String, SDVariable, SDVariable, int) - */ - public SDVariable gather(SDVariable df, SDVariable indices, int axis) { - return gather(null, df, indices, axis); - } - - /** - * Gather slices from the input variable where the indices are specified as dynamic SDVariable values.
- * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length. - * - * @param name name of the output variable - * @param df Input variable - * @param indices Indices to get slices for. Rank 0 or 1 input - * @param axis Axis that the indices refer to - * @return Output variable with slices pulled from the specified axis - */ - public SDVariable gather(String name, SDVariable df, SDVariable indices, int axis) { - SDVariable ret = f().gather(df, indices, axis); - return updateVariableNameAndReference(ret, name); - } - - /** - * TODO doc string - * - * @param df - * @param indices - * @return - */ - public SDVariable gatherNd(SDVariable df, SDVariable indices) { - return gatherNd(null, df, indices); - } - - /** - * TODO doc string - * - * @param name - * @param df - * @param indices - * @return - */ - public SDVariable gatherNd(String name, SDVariable df, SDVariable indices) { - SDVariable ret = f().gatherNd(df, indices); - return updateVariableNameAndReference(ret, name); - } - - /** - * Greater than operation: elementwise x > y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable gt(SDVariable x, double y) { - return gt(null, x, y); - } - - /** - * Greater than operation: elementwise x > y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Name of the output variable - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable gt(String name, SDVariable x, double y) { - validateNumerical("greater than (gt)", x); - SDVariable result = f().gt(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Greater than operation: elementwise x > y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable gt(SDVariable x, SDVariable y) { - return gt(null, x, y); - } - - /** - * Greater than operation: elementwise x > y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param name Name of the output variable - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable gt(String name, SDVariable x, SDVariable y) { - SDValidation.validateNumerical("greater than (gt)", x, y); - SDVariable result = f().gt(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Greater than or equals operation: elementwise x >= y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable gte(SDVariable x, double y) { - return gte(null, x, y); - } - - /** - * Greater than or equals operation: elementwise x >= y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Name of the output variable - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable gte(String name, SDVariable x, double y) { - validateNumerical("greater than or equal (gte)", x); - SDVariable result = f().gte(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Greater than or equal to operation: elementwise x >= y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable gte(SDVariable x, SDVariable y) { - return gte(null, x, y); - } - - /** - * Greater than or equal to operation: elementwise x >= y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param name Name of the output variable - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable gte(String name, SDVariable x, SDVariable y) { - validateNumerical("greater than or equal (gte)", x, y); - SDVariable result = f().gte(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise identity operation: out = x - * - * @param input Input variable - * @return Output variable - */ - public SDVariable identity(SDVariable input) { - return identity(null, input); - } - - /** - * Elementwise identity operation: out = x - * - * @param name name of the output variable - * @param input Input variable - * @return Output variable - */ - public SDVariable identity(String name, SDVariable input) { - SDVariable s = f().identity(input); - return updateVariableNameAndReference(s, name); - } - - /** - * Compute the inverse permutation indices for a permutation operation
- * Example: if input is [2, 0, 1] then output is [1, 2, 0]
- * The idea is that x.permute(input).permute(invertPermutation(input)) == x - * - * @param input 1D indices for permutation - * @return 1D inverted permutation - */ - public SDVariable invertPermutation(SDVariable input) { - return invertPermutation(null, input); - } - - /** - * Compute the inverse permutation indices for a permutation operation
- * Example: if input is [2, 0, 1] then output is [1, 2, 0]
- * The idea is that x.permute(input).permute(invertPermutation(input)) == x - * - * @param name name of the output variable - * @param input 1D indices for permutation - * @return 1D inverted permutation - */ - public SDVariable invertPermutation(String name, SDVariable input) { - validateInteger("invert permutation", input); - SDVariable ret = f().invertPermutation(input, false); - return updateVariableNameAndReference(ret, name); - } - - /** - * Is the director a numeric tensor? In the current version of ND4J/SameDiff, this always returns true/1 - * - * @param x Input variable - * @return Scalar variable with value 1 - */ - public SDVariable isNumericTensor(SDVariable x) { - return isNumericTensor(null, x); - } - - /** - * Is the director a numeric tensor? In the current version of ND4J/SameDiff, this always returns true/1 - * - * @param name Output variable name - * @param x Input variable - * @return Scalar variable with value 1 - */ - public SDVariable isNumericTensor(String name, SDVariable x) { - SDVariable result = f().isNumericTensor(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Create a new 1d array with values evenly spaced between values 'start' and 'stop' - * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0] - * - * @param start Start value - * @param stop Stop value - * @param number Number of values to generate - * @return SDVariable with linearly spaced elements - */ - // TODO: fix or remove, currently it is internal recursion - /*public SDVariable linspace(DataType dataType, double start, double stop, long number) { - return linspace(dataType, start, stop, number); - }*/ - - /** - * Create a new 1d array with values evenly spaced between values 'start' and 'stop' - * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0] - * - * @param name Name of the new variable - * @param dataType Data type of the output array - * @param start Start value - * @param stop 
Stop value - * @param number Number of values to generate - * @return SDVariable with linearly spaced elements - */ - public SDVariable linspace(String name, DataType dataType, double start, double stop, long number) { - SDVariable ret = f().linspace(sd().constant(start), sd().constant(stop), sd().constant(number), dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * Create a new 1d array with values evenly spaced between values 'start' and 'stop' - * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0] - * - * @param name Name of the new variable - * @param from Start value - * @param to Stop value - * @param length Number of values to generate - * @param dt Data type of the output array - * @return SDVariable with linearly spaced elements - */ - public SDVariable linspace(String name, SDVariable from, SDVariable to, SDVariable length, DataType dt) { - SDVariable ret = f().linspace(from, to, length, dt); - return updateVariableNameAndReference(ret, name); - } - - /** - * Less than operation: elementwise x < y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable lt(SDVariable x, double y) { - return lt(null, x, y); - } - - /** - * Less than operation: elementwise x < y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Name of the output variable - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable lt(String name, SDVariable x, double y) { - validateNumerical("less than (lt)", x); - SDVariable result = f().lt(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Less than operation: elementwise x < y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable lt(SDVariable x, SDVariable y) { - return lt(null, x, y); - } - - /** - * Less than operation: elementwise x < y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param name Name of the output variable - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable lt(String name, SDVariable x, SDVariable y) { - validateNumerical("less than (lt)", x, y); - SDVariable result = f().lt(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Less than or equals operation: elementwise x <= y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable lte(SDVariable x, double y) { - return lte(null, x, y); - } - - /** - * Less than or equals operation: elementwise x <= y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Name of the output variable - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable lte(String name, SDVariable x, double y) { - validateNumerical("less than or equal (lte)", x); - SDVariable result = f().lte(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Less than or equal to operation: elementwise x <= y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable lte(SDVariable x, SDVariable y) { - return lte(null, x, y); - } - - /** - * Less than or equal to operation: elementwise x <= y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param name Name of the output variable - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable lte(String name, SDVariable x, SDVariable y) { - validateNumerical("less than or equal (lte)", x, y); - SDVariable result = f().lte(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Returns a boolean mask of equal shape to the input, where the condition is satisfied - value 1 where satisfied, 0 otherwise - * - * @param in Input variable - * @param condition Condition - * @return Boolean mask mariable - */ - public SDVariable matchCondition(SDVariable in, Condition condition) { - return matchCondition(null, in, condition); - } - - /** - * Returns a boolean mask of equal shape to the input, where the condition is satisfied - value 1 where satisfied, 0 otherwise - * - * @param in Input - * @param condition Condition - * @return Boolean mask - */ - public SDVariable matchCondition(String name, SDVariable in, Condition condition) { - SDVariable ret = f().matchCondition(in, condition); - return updateVariableNameAndReference(ret, name); - } - - /** - * Returns a count of the number of elements that satisfy the condition - * - * @param in Input - * @param condition Condition - * @return Number of elements that the condition is satisfied for - */ - public SDVariable matchConditionCount(SDVariable in, Condition condition) { - return matchConditionCount(null, in, condition); - } - - /** - * Returns a count of the number of elements that satisfy the condition - * - * @param name Name of the output variable - * @param in Input - * @param condition Condition - * @return Number of elements that the condition is satisfied for - */ - public SDVariable matchConditionCount(String name, SDVariable in, Condition condition) { - return matchConditionCount(name, in, 
condition, false); - } - - /** - * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Name of the output variable - * @param in Input variable - * @param condition Condition - * @param keepDim If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Number of elements that the condition is satisfied for - */ - public SDVariable matchConditionCount(String name, SDVariable in, Condition condition, boolean keepDim, int... dimensions) { - SDVariable ret = f().matchConditionCount(in, condition, keepDim, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Max array reduction operation, optionally along specified dimensions - * - * @param x Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable max(SDVariable x, int... dimensions) { - return max(null, x, dimensions); - } - - /** - * Max array reduction operation, optionally along specified dimensions - * - * @param name Output variable name - * @param x Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable max(String name, SDVariable x, int... dimensions) { - return max(name, x, false, dimensions); - } - - /** - * Max array reduction operation, optionally along specified dimensions
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Output variable name - * @param x Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable max(String name, SDVariable x, boolean keepDims, int... dimensions) { - validateNumerical("max reduction", x); - SDVariable result = f().max(x, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise maximum operation: out[i] = max(first[i], second[i])
- * Supports broadcasting - * - * @param first First input array - * @param second Second input array - * @return Output variable - */ - public SDVariable max(SDVariable first, SDVariable second) { - return max(null, first, second); - } - - /** - * Element-wise maximum operation: out[i] = max(first[i], second[i])
- * Supports broadcasting - * - * @param name Name of the output variable - * @param first First input array - * @param second Second input array - * @return Output variable - */ - public SDVariable max(String name, SDVariable first, SDVariable second) { - validateNumerical("pairwise maxiumum (max)", first, second); - SDVariable result = f().max(first, second); - return updateVariableNameAndReference(result, name); - } - - /** - * Full array mean reduction operation - * - * @param x Input variable - * @return Output variable - scalar - */ - public SDVariable mean(SDVariable x) { - return mean(null, x); - } - - /** - * Mean (average) array reduction operation, optionally along specified dimensions - * - * @param name Output variable name - * @param x Input variable - * @param dimension Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable mean(String name, SDVariable x, int... dimension) { - return mean(name, x, false, dimension); - } - - /** - * Mean (average) array reduction operation, optionally along specified dimensions
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Output variable name - * @param x Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimension Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable mean(String name, SDVariable x, boolean keepDims, int... dimension) { - validateNumerical("mean reduction", x); - SDVariable result = f().mean(x, keepDims, dimension); - return updateVariableNameAndReference(result, name); - } - - /** - * Mean (average) array reduction operation, optionally along specified dimensions - * - * @param x Input variable - * @param dimension Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable mean(SDVariable x, int... dimension) { - return mean(null, x, dimension); - } - - /** - * Minimum array reduction operation, optionally along specified dimensions. out = min(in) - * - * @param x Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable min(SDVariable x, int... dimensions) { - return min(null, x, dimensions); - } - - /** - * Minimum array reduction operation, optionally along specified dimensions. out = min(in) - * - * @param name Output variable name - * @param x Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable min(String name, SDVariable x, int... 
dimensions) { - return min(name, x, false, dimensions); - } - - /** - * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Output variable name - * @param x Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable min(String name, SDVariable x, boolean keepDims, int... dimensions) { - validateNumerical("min reduction", x); - SDVariable result = f().min(x, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - - } - - /** - * Element-wise minimum operation: out[i] = min(first[i], second[i])
- * Supports broadcasting - * - * @param first First input array - * @param second Second input array - * @return Output variable - */ - public SDVariable min(SDVariable first, SDVariable second) { - return min(null, first, second); - } - - /** - * Element-wise minimum operation: out[i] = min(first[i], second[i])
- * Supports broadcasting - * - * @param name Name of the output variable - * @param first First input array - * @param second Second input array - * @return Output variable - */ - public SDVariable min(String name, SDVariable first, SDVariable second) { - validateNumerical("mean (pairwise)", first, second); - SDVariable result = f().min(first, second); - return updateVariableNameAndReference(result, name); - } - - /** - * Matrix multiplication: out = mmul(x,y)
- * Supports specifying a {@link MMulTranspose} argument to perform operation such as mmul(a^T, b), etc. - * - * @param x First input variable - * @param y Second input variable - * @param transpose Transpose arguments - * @return Output variable - */ - public SDVariable mmul(SDVariable x, SDVariable y, MMulTranspose transpose) { - return mmul(null, x, y, transpose); - - } - - /** - * Matrix multiplication: out = mmul(x,y)
- * Supports specifying a {@link MMulTranspose} argument to perform operation such as mmul(a^T, b), etc. - * - * @param name Output variable name - * @param x First input variable - * @param y Second input variable - * @param transpose Transpose arguments - * @return Output variable - */ - public SDVariable mmul(String name, SDVariable x, SDVariable y, MMulTranspose transpose) { - validateNumerical("matrix multiplication (mmul)", x, y); - SDVariable result = f().mmul(x, y, transpose); - return updateVariableNameAndReference(result, name); - } - - /** - * Matrix multiplication: out = mmul(x,y) - * - * @param x First input variable - * @param y Second input variable - * @return Output variable - */ - public SDVariable mmul(SDVariable x, SDVariable y) { - return mmul(null, x, y); - } - - /** - * Matrix multiplication: out = mmul(x,y) - * - * @param name Output variable name - * @param x First input variable - * @param y Second input variable - * @return Output variable - */ - public SDVariable mmul(String name, SDVariable x, SDVariable y) { - return mmul(name, x, y, MMulTranspose.allFalse()); - } - - /** - * Not equals operation: elementwise x != y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable neq(SDVariable x, double y) { - return neq(null, x, y); - } - - /** - * Not equals operation: elementwise x != y
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Name of the output variable - * @param x Input array - * @param y Double value argument to use in operation - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable neq(String name, SDVariable x, double y) { - validateNumerical("not equals (neq)", x); - SDVariable result = f().neq(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Not equal to operation: elementwise x != y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable neq(SDVariable x, SDVariable y) { - return neq(null, x, y); - } - - /** - * Not equal to operation: elementwise x != y
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param name Name of the output variable - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable neq(String name, SDVariable x, SDVariable y) { - validateNumerical("not equals (neq)", x, y); - SDVariable result = f().neq(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
- * out = sum_i abs(x[i]) - * - * @param name Output variable name - * @param x Input variable - * @param dimensions dimensions to reduce over - * @return Output variable - */ - public SDVariable norm1(String name, SDVariable x, int... dimensions) { - return norm1(name, x, false, dimensions); - } - - /** - * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
- * out = sum_i abs(x[i])
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Output variable name - * @param x Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions dimensions to reduce over - * @return Output variable - */ - public SDVariable norm1(String name, SDVariable x, boolean keepDims, int... dimensions) { - validateNumerical("norm1 reduction", x); - SDVariable result = f().norm1(x, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
- * out = sqrt(sum_i x[i]^2) - * - * @param name Output variable name - * @param x Input variable - * @param dimensions dimensions to reduce over - * @return Output variable - */ - public SDVariable norm2(String name, SDVariable x, int... dimensions) { - return norm2(name, x, false, dimensions); - } - - /** - * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
- * out = sqrt(sum_i x[i]^2)
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Output variable name - * @param x Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions dimensions to reduce over - * @return Output variable - */ - public SDVariable norm2(String name, SDVariable x, boolean keepDims, int... dimensions) { - validateNumerical("norm2 reduction", x); - SDVariable result = f().norm2(x, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the - * specified dimensions - * - * @param name Output variable name - * @param x Input variable - * @param dimensions dimensions to reduce over - * @return Output variable - */ - public SDVariable normmax(String name, SDVariable x, int... dimensions) { - return normmax(name, x, false, dimensions); - } - - /** - * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the - * specified dimensions:
- * out = max(abs(x[i]))
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Output variable name - * @param x Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions dimensions to reduce over - * @return Output variable - */ - public SDVariable normmax(String name, SDVariable x, boolean keepDims, int... dimensions) { - validateNumerical("norm max reduction", x); - SDVariable result = f().normmax(x, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #oneHot(String, SDVariable, int) - */ - public SDVariable oneHot(SDVariable indices, int depth) { - return oneHot(null, indices, depth, -1, 1.00, 0.00); - } - - /** - * Convert the array to a one-hot array with walues {@code on} and {@code off} for each entry
- * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth], - * with {@code out[i, ..., j, in[i,...,j]] = on} with other values being set to {@code off} - * - * @param name Output variable name - * @param indices Indices - value 0 to depth-1 - * @param depth Number of classes - * @return Output variable - */ - public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on, double off) { - return oneHot(name, indices, depth, axis, on, off, OneHot.DEFAULT_DTYPE); - } - - /** - * As per {@link #oneHot(String, SDVariable, int, int, double, double)} but allows configuring the output datatype - */ - public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on, double off, DataType dataType) { - validateInteger("oneHot", "indices", indices); - SDVariable ret = f().onehot(indices, depth, axis, on, off, dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #oneHot(String, SDVariable, int, int, double, double) - */ - public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off) { - return oneHot(null, indices, depth, axis, on, off, OneHot.DEFAULT_DTYPE); - } - - /** - * @see #oneHot(String, SDVariable, int, int, double, double, DataType) - */ - public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off, DataType dataType) { - return oneHot(null, indices, depth, axis, on, off, dataType); - } - - /** - * Convert the array to a one-hot array with walues 0 and 1 for each entry
- * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth], - * with out[i, ..., j, in[i,...,j]] = 1 with other values being set to 0 - * - * @param name Output variable name - * @param indices Indices - value 0 to depth-1 - * @param depth Number of classes - * @return Output variable - * @see #oneHot(SDVariable, int, int, double, double) - */ - public SDVariable oneHot(String name, SDVariable indices, int depth) { - return oneHot(name, indices, depth, -1, 1.00, 0.00); - } - - /** - * Return a variable of all 1s, with the same shape as the input variable. Note that this is dynamic: - * if the input shape changes in later execution, the returned variable's shape will also be updated - * - * @param input Input SDVariable - * @return A new SDVariable with the same (dynamic) shape as the input - */ - public SDVariable onesLike(SDVariable input) { - return onesLike(null, input); - } - - /** - * Return a variable of all 1s, with the same shape as the input variable. Note that this is dynamic: - * if the input shape changes in later execution, the returned variable's shape will also be updated - * - * @param name Name of the new SDVariable - * @param input Input SDVariable - * @return A new SDVariable with the same (dynamic) shape as the input - */ - public SDVariable onesLike(String name, SDVariable input) { - return onesLike(name, input, input.dataType()); - } - - /** - * As per {@link #onesLike(String, SDVariable)} but the output datatype may be specified - */ - public SDVariable onesLike(String name, @NonNull SDVariable input, @NonNull DataType dataType) { - SDVariable ret = f().onesLike(name, input, dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #stack(String, int, SDVariable...) - */ - public SDVariable parallel_stack(SDVariable[] values) { - return parallel_stack(null, values); - } - - /** - * @see #stack(String, int, SDVariable...) 
- */ - public SDVariable parallel_stack(String name, SDVariable[] values) { - validateSameType("parallel_stack", false, values); - SDVariable ret = f().parallel_stack(values); - return updateVariableNameAndReference(ret, name); - } - - /** - * Array permutation operation: permute the dimensions according to the specified permutation indices.
- * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b] - * - * @param x Input variable - * @return Output variable (permuted input) - */ - public SDVariable permute(SDVariable x, int... dimensions) { - return permute(null, x, dimensions); - } - - /** - * Array permutation operation: permute the dimensions according to the specified permutation indices.
- * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b] - * - * @param name Output variable name - * @param x Input variable - * @return Output variable (permuted input) - */ - public SDVariable permute(String name, SDVariable x, int... dimensions) { - SDVariable result = f().permute(x, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * As per {@link #permute(String, SDVariable, int...)} but with SDVariable permute dimension - */ - public SDVariable permute(String name, SDVariable x, SDVariable dimensions){ - SDVariable result = f().permute(x, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Product array reduction operation, optionally along specified dimensions - * - * @param x Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable prod(SDVariable x, int... dimensions) { - return prod(null, x, dimensions); - } - - /** - * Product array reduction operation, optionally along specified dimensions - * - * @param name Output variable name - * @param x Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable prod(String name, SDVariable x, int... dimensions) { - return prod(name, x, false, dimensions); - } - - /** - * Product array reduction operation, optionally along specified dimensions
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Output variable name - * @param x Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable prod(String name, SDVariable x, boolean keepDims, int... dimensions) { - validateNumerical("product reduction (prod)", x); - SDVariable result = f().prod(x, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Create a new variable with a 1d array, where the values start at {@code from} and increment by {@code step} - * up to (but not including) limit.
- * For example, {@code range(1.0, 3.0, 0.5)} will return {@code [1.0, 1.5, 2.0, 2.5]} - * - * @param from Initial/smallest value - * @param to Largest value (exclusive) - * @param step Step size - * @param dataType The output variable datatype - * @return 1D SDVariable with the specified values - */ - public SDVariable range(double from, double to, double step, DataType dataType) { - return range(null, from, to, step, dataType); - } - - /** - * Create a new variable with a 1d array, where the values start at {@code from} and increment by {@code step} - * up to (but not including) limit.
- * For example, {@code range(1.0, 3.0, 0.5)} will return {@code [1.0, 1.5, 2.0, 2.5]} - * - * @param name Name of the new variable - * @param from Initial/smallest value - * @param to Largest value (exclusive) - * @param step Step size - * @return 1D SDVariable with the specified values - */ - public SDVariable range(String name, double from, double to, double step, DataType dataType) { - SDVariable ret = f().range(from, to, step, dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * As per {@link #range(String, double, double, double, DataType)} but with SDVariable arguments - */ - public SDVariable range(String name, SDVariable from, SDVariable to, SDVariable step, DataType dataType) { - SDVariable ret = f().range(from, to, step, dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * Returns the rank (number of dimensions, i.e., length(shape)) of the specified SDVariable as a 0D scalar variable - * - * @param in Input variable - * @return 0D (scalar) output variable with value equal to the rank of the input variable - */ - public SDVariable rank(SDVariable in) { - return rank(null, in); - } - - /** - * Returns the rank (number of dimensions, i.e., length(shape)) of the specified SDVariable as a 0D scalar variable - * - * @param name Name of the output variable - * @param in Input variable - * @return 0D (scalar) output variable with value equal to the rank of the input variable - */ - public SDVariable rank(String name, SDVariable in) { - SDVariable ret = f().rank(in); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #repeat(String, SDVariable, int) - */ - public SDVariable repeat(SDVariable df, int axis) { - return repeat(null, df, axis); - } - - /** - * @see #repeat(String, SDVariable, int) - */ - public SDVariable repeat(String name, SDVariable df, int axis) { - SDVariable ret = f().repeat(df, axis); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise replace 
where condition:
- * out[i] = from[i] if condition(update[i]) is satisfied, or
- * out[i] = update[i] if condition(update[i]) is NOT satisfied - * - * @param update Source array - * @param from Replacement values array (used conditionally). Must be same shape as 'update' array - * @param condition Condition to check on update array elements - * @return New array with values replaced where condition is satisfied - */ - public SDVariable replaceWhere(SDVariable update, SDVariable from, Condition condition) { - return replaceWhere(null, update, from, condition); - } - - /** - * Element-wise replace where condition:
- * out[i] = from[i] if condition(update[i]) is satisfied, or
- * out[i] = update[i] if condition(update[i]) is NOT satisfied - * - * @param name Name of the output variable - * @param update Source array - * @param from Replacement values array (used conditionally). Must be same shape as 'update' array - * @param condition Condition to check on update array elements - * @return New array with values replaced where condition is satisfied - */ - public SDVariable replaceWhere(String name, SDVariable update, SDVariable from, Condition condition) { - SDVariable ret = f().replaceWhere(update, from, condition); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise replace where condition:
- * out[i] = value if condition(update[i]) is satisfied, or
- * out[i] = update[i] if condition(update[i]) is NOT satisfied - * - * @param update Source array - * @param value Value to set at the output, if the condition is satisfied - * @param condition Condition to check on update array elements - * @return New array with values replaced where condition is satisfied - */ - public SDVariable replaceWhere(SDVariable update, Number value, Condition condition) { - return replaceWhere(null, update, value, condition); - } - - /** - * Element-wise replace where condition:
- * out[i] = value if condition(update[i]) is satisfied, or
- * out[i] = update[i] if condition(update[i]) is NOT satisfied - * - * @param name Name of the output variable - * @param update Source array - * @param value Value to set at the output, if the condition is satisfied - * @param condition Condition to check on update array elements - * @return New array with values replaced where condition is satisfied - */ - public SDVariable replaceWhere(String name, SDVariable update, Number value, Condition condition) { - SDVariable ret = f().replaceWhere(update, value, condition); - return updateVariableNameAndReference(ret, name); - } - - /** - * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the - * input, but with the specified shape.
- * Note that prod(shape) must match length(input) == prod(input.shape) - * - * @param x Input variable - * @param shape New shape for variable - * @return Output variable - * @see #reshape(SDVariable, SDVariable) - */ - public SDVariable reshape(SDVariable x, long... shape) { - return reshape(null, x, shape); - } - - /** - * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the - * input, but with the specified shape.
- * Note that prod(shape) must match length(input) == prod(input.shape) - * - * @param name Output variable name - * @param x Input variable - * @param shape New shape for variable - * @return Output variable - * @see #reshape(SDVariable, SDVariable) - */ - public SDVariable reshape(String name, SDVariable x, long... shape) { - SDVariable result = f().reshape(x, shape); - return updateVariableNameAndReference(result, name); - } - - /** - * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the - * input, but with the specified shape.
- * Note that prod(shape) must match length(input) == prod(input.shape) - * - * @param x Input variable - * @param shape New shape for variable - * @return Output variable - * @see #reshape(SDVariable, SDVariable) - */ - public SDVariable reshape(SDVariable x, int... shape) { - return reshape(null, x, shape); - } - - /** - * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the - * input, but with the specified shape.
- * Note that prod(shape) must match length(input) == prod(input.shape) - * - * @param name Output variable name - * @param x Input variable - * @param shape New shape for variable - * @return Output variable - * @see #reshape(SDVariable, SDVariable) - */ - public SDVariable reshape(String name, SDVariable x, int... shape) { - SDVariable result = f().reshape(x, shape); - return updateVariableNameAndReference(result, name); - } - - /** - * Reshape the input variable to the specified (dynamic) shape. The output variable will have the same values as the - * input, but with the specified shape.
- * Note that prod(shape) must match length(input) == prod(input.shape) - * - * @param x Input variable - * @param shape New shape for variable - * @return Output variable - * @see #reshape(SDVariable, int[]) - */ - public SDVariable reshape(SDVariable x, SDVariable shape) { - return reshape(null, x, shape); - } - - /** - * Reshape the input variable to the specified (dynamic) shape. The output variable will have the same values as the - * input, but with the specified shape.
- * Note that prod(shape) must match length(input) == prod(input.shape) - * - * @param name Output variable name - * @param x Input variable - * @param shape New shape for variable - * @return Output variable - * @see #reshape(SDVariable, int[]) - */ - public SDVariable reshape(String name, SDVariable x, SDVariable shape) { - validateInteger("reshape", "shape", shape); - SDVariable result = f().reshape(x, shape); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #reverse(String, SDVariable, int...) - */ - public SDVariable reverse(SDVariable x, int... dimensions) { - return reverse(null, x, dimensions); - } - - /** - * Reverse the values of an array for the specified dimensions
- * If input is:
- * [ 1, 2, 3]
- * [ 4, 5, 6]
- * then
- * reverse(in, 0):
- * [3, 2, 1]
- * [6, 5, 4]
- * reverse(in, 1):
- * [4, 5, 6]
- * [1, 2, 3]
- * - * @param x Input variable - * @param dimensions Dimensions - * @return Output variable - */ - public SDVariable reverse(String name, SDVariable x, int... dimensions) { - SDVariable ret = f().reverse(x, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #reverseSequence(String, SDVariable, SDVariable, int, int) - */ - public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths, int seqDim, int batchDim) { - return reverseSequence(null, x, seq_lengths, seqDim, batchDim); - } - - /** - * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed - * - * @param name Name of the output variable - * @param x Input variable - * @param seq_lengths Length of the sequences - * @param seqDim Sequence dimension - * @param batchDim Batch dimension - * @return Reversed sequences - */ - public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths, int seqDim, int batchDim) { - SDVariable ret = f().reverseSequence(x, seq_lengths, seqDim, batchDim); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #reverseSequence(String, SDVariable, SDVariable, int, int) - */ - public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths) { - return reverseSequence(null, x, seq_lengths); - } - - /** - * @see #reverseSequence(String, SDVariable, SDVariable, int, int) - */ - public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths) { - SDVariable ret = f().reverseSequence(x, seq_lengths); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise scalar floor modulus operation: out = floorMod(in, value). 
- * i.e., returns the remainder after division by 'value' - * - * @param in Input variable - * @param value Scalar value to compare - * @return Output variable - */ - public SDVariable scalarFloorMod(SDVariable in, Number value) { - return scalarFloorMod(null, in, value); - } - - /** - * Element-wise scalar floor modulus operation: out = floorMod(in, value). - * i.e., returns the remainder after division by 'value' - * - * @param name Name of the output variable - * @param in Input variable - * @param value Scalar value to compare - * @return Output variable - */ - public SDVariable scalarFloorMod(String name, SDVariable in, Number value) { - validateNumerical("floorMod", in); - SDVariable ret = f().scalarFloorMod(in, value); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise scalar maximum operation: out = max(in, value) - * - * @param in Input variable - * @param value Scalar value to compare - * @return Output variable - */ - public SDVariable scalarMax(SDVariable in, Number value) { - return scalarMax(null, in, value); - } - - /** - * Element-wise scalar maximum operation: out = max(in, value) - * - * @param name Name of the output variable - * @param in Input variable - * @param value Scalar value to compare - * @return Output variable - */ - public SDVariable scalarMax(String name, SDVariable in, Number value) { - validateNumerical("max", in); - SDVariable ret = f().scalarMax(in, value); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise scalar minimum operation: out = min(in, value) - * - * @param in Input variable - * @param value Scalar value to compare - * @return Output variable - */ - public SDVariable scalarMin(SDVariable in, Number value) { - return scalarMin(null, in, value); - } - - /** - * Element-wise scalar minimum operation: out = min(in, value) - * - * @param name Name of the output variable - * @param in Input variable - * @param value Scalar value to compare - * @return Output 
variable - */ - public SDVariable scalarMin(String name, SDVariable in, Number value) { - validateNumerical("min", in); - SDVariable ret = f().scalarMin(in, value); - return updateVariableNameAndReference(ret, name); - } - - /** - * Return an array with equal shape to the input, but all elements set to value 'set' - * - * @param in Input variable - * @param set Value to set - * @return Output variable - */ - public SDVariable scalarSet(SDVariable in, Number set) { - return scalarSet(null, in, set); - } - - /** - * Return a variable with equal shape to the input, but all elements set to value 'set' - * - * @param name Name of the output variable - * @param in Input variable - * @param set Value to set - * @return Output variable - */ - public SDVariable scalarSet(String name, SDVariable in, Number set) { - SDVariable ret = f().scalarSet(in, set); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #scatterAdd(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable scatterAdd(SDVariable ref, SDVariable indices, SDVariable updates) { - return scatterAdd(null, ref, indices, updates); - } - - /** - * Scatter addition operation.
- * If indices is rank 0 (a scalar), then out[index, ...] += updates[...]
- * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] += updates[i, ...]
- * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] += updates[i, ..., k, ...]
- * Note that if multiple indices refer to the same location, the contributions from each is handled correctly. - * - * @param name Name of the output variable - * @param ref Initial/source variable - * @param indices Indices array - * @param updates Updates to add to the initial/source array - * @return The updated variable - */ - public SDVariable scatterAdd(String name, SDVariable ref, SDVariable indices, SDVariable updates) { - validateInteger("scatterAdd", "indices", indices); - SDVariable ret = f().scatterAdd(ref, indices, updates); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #scatterDiv(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable scatterDiv(SDVariable ref, SDVariable indices, SDVariable updates) { - return scatterDiv(null, ref, indices, updates); - } - - /** - * Scatter division operation.
- * If indices is rank 0 (a scalar), then out[index, ...] /= updates[...]
- * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] /= updates[i, ...]
- * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] /= updates[i, ..., k, ...]
- * Note that if multiple indices refer to the same location, the contributions from each is handled correctly. - * - * @param name Name of the output variable - * @param ref Initial/source variable - * @param indices Indices array - * @param updates Updates to add to the initial/source array - * @return The updated variable - */ - public SDVariable scatterDiv(String name, SDVariable ref, SDVariable indices, SDVariable updates) { - validateInteger("scatterDiv", "indices", indices); - SDVariable ret = f().scatterDiv(ref, indices, updates); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #scatterMax(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable scatterMax(SDVariable ref, SDVariable indices, SDVariable updates) { - return scatterMax(null, ref, indices, updates); - } - - /** - * Scatter max operation.
- * If indices is rank 0 (a scalar), then out[index, ...] = max(updates[...], in[index,...])
- * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = max(updates[i,...], in[indices[i],...])
- * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = max(updates[i, ..., k, ...], in[indices[i], ..., indices[k], ...])
- * Note that if multiple indices refer to the same location, the contributions from each is handled correctly. - * - * @param name Name of the output variable - * @param ref Initial/source variable - * @param indices Indices array - * @param updates Updates to add to the initial/source array - * @return The updated variable - */ - public SDVariable scatterMax(String name, SDVariable ref, SDVariable indices, SDVariable updates) { - validateInteger("scatterMax", "indices", indices); - SDVariable ret = f().scatterMax(ref, indices, updates); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #scatterMin(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable scatterMin(SDVariable ref, SDVariable indices, SDVariable updates) { - return scatterMin(null, ref, indices, updates); - } - - /** - * Scatter min operation.
- * If indices is rank 0 (a scalar), then out[index, ...] = min(updates[...], in[index,...])
- * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = min(updates[i,...], in[indices[i],...])
- * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = min(updates[i, ..., k, ...], in[indices[i], ..., indices[k], ...])
- * Note that if multiple indices refer to the same location, the contributions from each is handled correctly. - * - * @param name Name of the output variable - * @param ref Initial/source variable - * @param indices Indices array - * @param updates Updates to add to the initial/source array - * @return The updated variable - */ - public SDVariable scatterMin(String name, SDVariable ref, SDVariable indices, SDVariable updates) { - validateInteger("scatterMin", "indices", indices); - SDVariable ret = f().scatterMin(ref, indices, updates); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #scatterMul(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable scatterMul(SDVariable ref, SDVariable indices, SDVariable updates) { - return scatterMul(null, ref, indices, updates); - } - - /** - * Scatter multiplication operation.
- * If indices is rank 0 (a scalar), then out[index, ...] *= updates[...]
- * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] *= updates[i, ...]
- * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] *= updates[i, ..., k, ...]
- * Note that if multiple indices refer to the same location, the contributions from each is handled correctly. - * - * @param name Name of the output variable - * @param ref Initial/source variable - * @param indices Indices array - * @param updates Updates to add to the initial/source array - * @return The updated variable - */ - public SDVariable scatterMul(String name, SDVariable ref, SDVariable indices, SDVariable updates) { - validateInteger("scatterMul", "indices", indices); - SDVariable ret = f().scatterMul(ref, indices, updates); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #scatterSub(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable scatterSub(SDVariable ref, SDVariable indices, SDVariable updates) { - return scatterSub(null, ref, indices, updates); - } - - /** - * Scatter subtraction operation.
- * If indices is rank 0 (a scalar), then out[index, ...] -= updates[...]
- * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] -= updates[i, ...]
- * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] -= updates[i, ..., k, ...]
- * Note that if multiple indices refer to the same location, the contributions from each is handled correctly. - * - * @param name Name of the output variable - * @param ref Initial/source variable - * @param indices Indices array - * @param updates Updates to add to the initial/source array - * @return The updated variable - */ - public SDVariable scatterSub(String name, SDVariable ref, SDVariable indices, SDVariable updates) { - validateInteger("scatterSub", "indices", indices); - SDVariable ret = f().scatterSub(ref, indices, updates); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #scatterUpdate(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable scatterUpdate(SDVariable ref, SDVariable indices, SDVariable updates) { - return scatterUpdate(null, ref, indices, updates); - } - - /** - * Scatter update operation.
- * If indices is rank 0 (a scalar), then out[index, ...] = updates[...]
- * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = updates[i, ...]
- * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = updates[i, ..., k, ...]
- * Note that if multiple indices refer to the same location, the output at those locations is undefined - different - * updates may occur in different orders - * - * @param name Name of the output variable - * @param ref Initial/source variable - * @param indices Indices array - * @param updates Updates to add to the initial/source array - * @return The updated variable - */ - public SDVariable scatterUpdate(String name, SDVariable ref, SDVariable indices, SDVariable updates) { - validateInteger("scatterUpdate", "indices", indices); - SDVariable ret = f().scatterUpdate(ref, indices, updates); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #segmentMax(String, SDVariable, SDVariable) - */ - public SDVariable segmentMax(SDVariable data, SDVariable segmentIds) { - return segmentMax(null, data, segmentIds); - } - - /** - * Segment max operation.
- * If data = [3, 6, 1, 4, 9, 2, 8]
- * segmentIds = [0, 0, 1, 1, 1, 2, 2]
- * then output = [6, 9, 8] = [max(3,6), max(1,4,9), max(2,8)]
- * Note that the segment IDs must be sorted from smallest to largest segment. - * See {@link #unsortedSegmentMax(String, SDVariable, SDVariable, int)} - * for the same op without this sorted requirement - * - * @param name Name of the output variable. May be null - * @param data Data to perform segment max on - * @param segmentIds Variable for the segment IDs - * @return Segment max output - */ - public SDVariable segmentMax(String name, SDVariable data, SDVariable segmentIds) { - validateNumerical("segmentMax", "data", data); - validateInteger("segmentMax", "segmentIds", segmentIds); - SDVariable ret = f().segmentMax(data, segmentIds); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #segmentMean(String, SDVariable, SDVariable) - */ - public SDVariable segmentMean(SDVariable data, SDVariable segmentIds) { - return segmentMean(null, data, segmentIds); - } - - /** - * Segment mean operation.
- * If data = [3, 6, 1, 4, 9, 2, 8]
- * segmentIds = [0, 0, 1, 1, 1, 2, 2]
- * then output = [4.5, 4.666, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
- * Note that the segment IDs must be sorted from smallest to largest segment. - * See {@link #unsortedSegmentMean(String, SDVariable, SDVariable, int)} for the same op without this sorted requirement - * - * @param name Name of the output variable. May be null - * @param data Data to perform segment max on - * @param segmentIds Variable for the segment IDs - * @return Segment mean output - */ - public SDVariable segmentMean(String name, SDVariable data, SDVariable segmentIds) { - validateNumerical("segmentMean", "data", data); - validateInteger("segmentMean", "segmentIds", segmentIds); - SDVariable ret = f().segmentMean(data, segmentIds); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #segmentMin(String, SDVariable, SDVariable) - */ - public SDVariable segmentMin(SDVariable data, SDVariable segmentIds) { - return segmentMin(null, data, segmentIds); - } - - /** - * Segment min operation.
- * If data = [3, 6, 1, 4, 9, 2, 8]
- * segmentIds = [0, 0, 1, 1, 1, 2, 2]
- * then output = [3, 1, 2] = [min(3,6), min(1,4,9), min(2,8)]
- * Note that the segment IDs must be sorted from smallest to largest segment. - * See {@link #unsortedSegmentMin(String, SDVariable, SDVariable, int)} for the same op without this sorted requirement - * - * @param name Name of the output variable. May be null - * @param data Data to perform segment max on - * @param segmentIds Variable for the segment IDs - * @return Segment min output - */ - public SDVariable segmentMin(String name, SDVariable data, SDVariable segmentIds) { - validateNumerical("segmentMin", "data", data); - validateInteger("segmentMin", "segmentIds", segmentIds); - SDVariable ret = f().segmentMin(data, segmentIds); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #segmentProd(String, SDVariable, SDVariable) - */ - public SDVariable segmentProd(SDVariable data, SDVariable segmentIds) { - return segmentProd(null, data, segmentIds); - } - - /** - * Segment product operation.
- * If data = [3, 6, 1, 4, 9, 2, 8]
- * segmentIds = [0, 0, 1, 1, 1, 2, 2]
- * then output = [18, 36, 16] = [prod(3,6), prod(1,4,9), prod(2,8)]
- * Note that the segment IDs must be sorted from smallest to largest segment. - * See {@link #unsortedSegmentProd(String, SDVariable, SDVariable, int)} for the same op without this sorted requirement - * - * @param name Name of the output variable. May be null - * @param data Data to perform segment max on - * @param segmentIds Variable for the segment IDs - * @return Segment product output - */ - public SDVariable segmentProd(String name, SDVariable data, SDVariable segmentIds) { - validateNumerical("segmentProd", "data", data); - validateInteger("segmentProd", "segmentIds", segmentIds); - SDVariable ret = f().segmentProd(data, segmentIds); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #segmentSum(String, SDVariable, SDVariable) - */ - public SDVariable segmentSum(SDVariable data, SDVariable segmentIds) { - return segmentSum(null, data, segmentIds); - } - - /** - * Segment sum operation.
- * If data = [3, 6, 1, 4, 9, 2, 8]
- * segmentIds = [0, 0, 1, 1, 1, 2, 2]
- * then output = [9, 14, 10] = [sum(3,6), sum(1,4,9), sum(2,8)]
- * Note that the segment IDs must be sorted from smallest to largest segment. - * See {@link #unsortedSegmentSum(String, SDVariable, SDVariable, int)} for the same op without this sorted requirement - * - * @param name Name of the output variable. May be null - * @param data Data to perform segment max on - * @param segmentIds Variable for the segment IDs - * @return Segment sum output - */ - public SDVariable segmentSum(String name, SDVariable data, SDVariable segmentIds) { - validateNumerical("segmentSum", "data", data); - validateInteger("segmentSum", "segmentIds", segmentIds); - SDVariable ret = f().segmentSum(data, segmentIds); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #sequenceMask(String, SDVariable, SDVariable, DataType) - */ - public SDVariable sequenceMask(SDVariable lengths, int maxLen, DataType dataType) { - return sequenceMask(null, lengths, maxLen, dataType); - } - - /** - * @see #sequenceMask(String, SDVariable, SDVariable, DataType) - */ - public SDVariable sequenceMask(String name, SDVariable lengths, int maxLen, DataType dataType) { - validateInteger("sequenceMask", "lengths", lengths); - SDVariable ret = f().sequenceMask(lengths, maxLen, dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #sequenceMask(String, SDVariable, SDVariable, DataType) - */ - public SDVariable sequenceMask(String name, SDVariable lengths, DataType dataType) { - SDVariable ret = f().sequenceMask(lengths, dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #sequenceMask(String, SDVariable, SDVariable, DataType) - */ - public SDVariable sequenceMask(SDVariable lengths, DataType dataType) { - return sequenceMask(lengths, null, dataType); - } - - /** - * @see #sequenceMask(String, SDVariable, SDVariable, DataType) - */ - public SDVariable sequenceMask(SDVariable lengths, SDVariable maxLen, DataType dataType) { - return sequenceMask(null, lengths, maxLen, dataType); - } - - /** - * 
Generate a sequence mask (with values 0 or 1) based on the specified lengths
- * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0) - * - * @param name Name of the output variable - * @param lengths Lengths of the sequences - * @param maxLen Maximum sequence length - * @return Output variable - */ - public SDVariable sequenceMask(String name, SDVariable lengths, SDVariable maxLen, DataType dataType) { - validateInteger("sequenceMask", "lengths", lengths); - SDVariable ret = f().sequenceMask(lengths, maxLen, dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * Returns the shape of the specified SDVariable as a 1D SDVariable - * - * @param input Input variable - * @return 1D output variable with contents equal to the shape of the input - */ - public SDVariable shape(SDVariable input) { - return shape(null, input); - } - - /** - * Returns the shape of the specified SDVariable as a 1D SDVariable - * - * @param name Name of the output variable - * @param input Input variable - * @return 1D output variable with contents equal to the shape of the input - */ - public SDVariable shape(String name, SDVariable input) { - SDVariable ret = f().shape(input); - return updateVariableNameAndReference(ret, name); - } - - /** - * Returns the size (number of elements, i.e., prod(shape)) of the specified SDVariable as a 0D scalar variable - * - * @param in Input variable - * @return 0D (scalar) output variable with value equal to the number of elements in the specified array - */ - public SDVariable size(SDVariable in) { - return size(null, in); - } - - /** - * Returns the size (number of elements, i.e., prod(shape)) of the specified SDVariable as a 0D scalar variable - * - * @param name Name of the output variable - * @param in Input variable - * @return 0D (scalar) output variable with value equal to the number of elements in the specified array - */ - public SDVariable size(String name, SDVariable in) { - SDVariable ret = f().size(in); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see 
#sizeAt(String, SDVariable, int) - */ - public SDVariable sizeAt(SDVariable in, int dimension) { - return sizeAt(null, in, dimension); - } - - /** - * Returns a rank 0 (scalar) variable for the size of the specified dimension. - * For example, if X has shape [10,20,30] then sizeAt(X,1)=20. Similarly, sizeAt(X,-1)=30 - * - * @param name Name of the output variable - * @param in Input variable - * @param dimension Dimension to get size of - * @return Scalar SDVariable for size at specified variable - */ - public SDVariable sizeAt(String name, SDVariable in, int dimension) { - SDVariable ret = f().sizeAt(in, dimension); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #slice(String, SDVariable, int[], int[]) - */ - public SDVariable slice(SDVariable input, int[] begin, int[] size) { - return slice(null, input, begin, size); - } - - public SDVariable slice(SDVariable input, SDVariable begin, SDVariable size) { - return slice(null, input, begin, size); - } - - /** - * Get a subset of the specified input, by specifying the first element and the size of the array.
- * For example, if input is:
- * [a, b, c]
- * [d, e, f]
- * then slice(input, begin=[0,1], size=[2,1] will return:
- * [b]
- * [e]
- *
- * Note that for each dimension i, begin[i] + size[i] <= input.size(i) - * - * @param name Output variable name - * @param input Variable to get subset of - * @param begin Beginning index. Must be same length as rank of input array - * @param size Size of the output array. Must be same length as rank of input array - * @return Subset of the input - */ - public SDVariable slice(String name, SDVariable input, int[] begin, int[] size) { - SDVariable ret = f().slice(input, begin, size); - return updateVariableNameAndReference(ret, name); - } - - public SDVariable slice(String name, SDVariable input, @NonNull SDVariable begin, @NonNull SDVariable size) { - SDVariable ret = f().slice(input, begin, size); - return updateVariableNameAndReference(ret, name); - } - - - - /** - * Squared L2 norm: see {@link #norm2(String, SDVariable, int...)} - */ - public SDVariable squaredNorm(SDVariable x, int... dimensions) { - return squaredNorm(null, x, false, dimensions); - } - - /** - * Squared L2 norm: see {@link #norm2(String, SDVariable, boolean, int...)} - */ - public SDVariable squaredNorm(String name, SDVariable x, boolean keepDims, int... dimensions) { - validateNumerical("squaredNorm", x); - SDVariable result = f().squaredNorm(x, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Squared L2 norm: see {@link #norm2(String, SDVariable, int...)} - */ - public SDVariable squaredNorm(String name, SDVariable x, int... dimensions) { - return squaredNorm(name, x, false, dimensions); - } - - /** - * Squared L2 norm: see {@link #norm2(String, SDVariable, boolean, int...)} - */ - public SDVariable squaredNorm(SDVariable x, boolean keepDims, int... dimensions) { - return squaredNorm(null, x, keepDims, dimensions); - } - - /** - * @see #squeeze(String, SDVariable, int) - */ - public SDVariable squeeze(SDVariable x, int axis) { - return squeeze(null, x, axis); - } - - /** - * Remove a single dimension of size 1. 
- * For example, if input has shape [a,b,1,c] then squeeze(input, 2) returns an array of shape [a,b,c] - * - * @param name Name of the output variable - * @param x Input variable - * @param axis Size 1 dimension to remove - * @return Output variable - */ - public SDVariable squeeze(String name, SDVariable x, int axis) { - SDVariable result = f().squeeze(x, axis); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #stack(String, int, SDVariable...) - */ - public SDVariable stack(int axis, SDVariable... values) { - return stack(null, axis, values); - } - - /** - * Stack a set of N SDVariables of rank X into one rank X+1 variable. - * If inputs have shape [a,b,c] then output has shape:
- * axis = 0: [N,a,b,c]
- * axis = 1: [a,N,b,c]
- * axis = 2: [a,b,N,c]
- * axis = 3: [a,b,c,N]
- * - * @param name Name of the output variable - * @param axis Axis to stack on - * @param values Input variables to stack. Must have the same shape for all inputs - * @return Output variable - * @see #unstack(String[], SDVariable, int, int) - */ - public SDVariable stack(String name, int axis, SDVariable... values) { - validateSameType("stack", false, values); - SDVariable ret = f().stack(values, axis); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #standardDeviation(String, SDVariable, boolean, int...) - */ - public SDVariable standardDeviation(SDVariable x, boolean biasCorrected, int... dimensions) { - return standardDeviation(null, x, biasCorrected, dimensions); - } - - /** - * Stardard deviation array reduction operation, optionally along specified dimensions - * - * @param name Output variable name - * @param x Input variable - * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected, int... dimensions) { - return standardDeviation(name, x, biasCorrected, false, dimensions); - } - - /** - * Stardard deviation array reduction operation, optionally along specified dimensions
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param x Input variable - * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected, boolean keepDims, int... dimensions) { - validateNumerical("standard deviation", x); - SDVariable result = f().std(x, biasCorrected, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #stridedSlice(String, SDVariable, long[], long[], long[]) - */ - public SDVariable stridedSlice(SDVariable input, int[] begin, int[] end, int[] strides) { - return stridedSlice(null, input, begin, end, strides); - } - - /** - * @see #stridedSlice(String, SDVariable, long[], long[], long[]) - */ - public SDVariable stridedSlice(String name, SDVariable input, int[] begin, int[] end, int[] strides) { - return stridedSlice(name, input, begin, end, strides, 0, 0, 0, 0, 0); - } - - /** - * @see #stridedSlice(String, SDVariable, long[], long[], long[], int, int, int, int, int) - */ - public SDVariable stridedSlice(String name, SDVariable in, int[] begin, int[] end, int[] strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - SDVariable ret = f().stridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #stridedSlice(String, SDVariable, long[], long[], long[]) - */ - public SDVariable stridedSlice(SDVariable input, long[] begin, long[] end, long[] strides) { - return stridedSlice(null, input, begin, end, strides); - 
} - - /** - * Get a subset of the specified input, by specifying the first element, last element, and the strides.
- * For example, if input is:
- * [a, b, c]
- * [d, e, f]
- * [g, h, i]
- * then stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1]) will return:
- * [b, c]
- * [h, i]
- *
- * - * @param name Output variable name - * @param input Variable to get subset of - * @param begin Beginning index. Must be same length as rank of input array - * @param end End index. Must be same length as the rank of the array - * @param strides Stride ("step size") for each dimension. Must be same length as the rank of the array. For example, - * stride of 2 means take every second element. - * @return Subset of the input - */ - public SDVariable stridedSlice(String name, SDVariable input, long[] begin, long[] end, long[] strides) { - return stridedSlice(name, input, begin, end, strides, 0, 0, 0, 0, 0); - } - - /** - * Get a subset of the specified input, by specifying the first element, last element, and the strides.
- * Operates as described in {@link #stridedSlice(SDVariable, long[], long[], long[])} with some extra mask arrays - * as described below. - * - * @param name Output variable name - * @param in Variable to get subset of - * @param begin Beginning index - * @param end End index - * @param strides Stride ("step size") for each dimension. For example, - * stride of 2 means take every second element. - * @param beginMask Bit mask: If the ith bit is set to 1, then the value in the begin long[] is ignored, - * and a value of 0 is used instead for the beginning index for that dimension - * @param endMask Bit mask: If the ith bit is set to 1, then the value in the end long[] is ignored, - * and a value of size(i)-1 is used instead for the end index for that dimension - * @param ellipsisMask Bit mask: only one non-zero value is allowed here. If a non-zero value is set, then other - * dimensions are inserted as required at the specified position - * @param newAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and - * a size 1 dimension is inserted at this point - * @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and - * a size 1 dimension is removed at this point. 
Note that begin/end/stride values must - * result in a size 1 output for these dimensions - * @return A subset of the input array - */ - public SDVariable stridedSlice(String name, SDVariable in, long[] begin, long[] end, long[] strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - SDVariable ret = f().stridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #stridedSlice(String, SDVariable, long[], long[], long[], int, int, int, int, int) - */ - public SDVariable stridedSlice(SDVariable in, int[] begin, int[] end, int[] strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - return stridedSlice(null, in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); - } - - /** - * @see #stridedSlice(String, SDVariable, long[], long[], long[], int, int, int, int, int) - */ - public SDVariable stridedSlice(SDVariable in, long[] begin, long[] end, long[] strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - return stridedSlice(null, in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); - } - - /** - * Sum array reduction operation, optionally along specified dimensions - * - * @param x Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable sum(SDVariable x, int... dimensions) { - return sum(null, x, dimensions); - } - - /** - * Sum array reduction operation, optionally along specified dimensions - * - * @param x Input variable - * @param dimensions Dimensions to reduce over. 
If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) if keepDims = false, or - * of rank (input rank) if keepdims = true - */ - public SDVariable sum(String name, SDVariable x, int... dimensions) { - return sum(name, x, false, dimensions); - } - - /** - * Sum array reduction operation, optionally along specified dimensions.
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Output variable name - * @param x Input variable - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) if keepDims = false, or - * of rank (input rank) if keepdims = true - */ - public SDVariable sum(String name, SDVariable x, boolean keepDims, int... dimensions) { - validateNumerical("sum reduction", x); - SDVariable result = f().sum(x, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #sum(String, SDVariable, boolean, int...) - */ - public SDVariable sum(SDVariable x, boolean keepDims, int... dimensions) { - return sum(null, x, keepDims, dimensions); - } - - /** - * @param x - * @param y - * @param dimensions - * @return - */ - public SDVariable tensorMmul(SDVariable x, - SDVariable y, - int[][] dimensions) { - return tensorMmul(null, x, y, dimensions); - } - - /** - * @param x Input variable x - * @param y Input variable y - * @param dimensions dimensions - * @return Output variable - */ - public SDVariable tensorMmul(String name, - SDVariable x, - SDVariable y, - int[][] dimensions) { - validateNumerical("tensorMmul", x, y); - SDVariable result = f().tensorMmul(x, y, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #tile(String, SDVariable, int...) - */ - public SDVariable tile(SDVariable x, int... repeat) { - return tile(null, x, repeat); - } - - /** - * Repeat (tile) the input tensor the specified number of times.
- * For example, if input is
- * [1, 2]
- * [3, 4]
- * and repeat is [2, 3]
- * then output is
- * [1, 2, 1, 2, 1, 2]
- * [3, 4, 3, 4, 3, 4]
- * [1, 2, 1, 2, 1, 2]
- * [3, 4, 3, 4, 3, 4]
- *
- * - * @param name Output variable name - * @param x Input variable - * @param repeat Number of times to repeat in each axis. Must have length equal to the rank of the input array - * @return Output variable - */ - public SDVariable tile(String name, SDVariable x, int... repeat) { - SDVariable result = f().tile(x, repeat); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #tile(String, SDVariable, int...) - */ - public SDVariable tile(SDVariable x, SDVariable repeat) { - return tile(null, x, repeat); - } - - /** - * @see #tile(String, SDVariable, int...) - */ - public SDVariable tile(String name, SDVariable x, SDVariable repeat) { - SDVariable result = f().tile(x, repeat); - return updateVariableNameAndReference(result, name); - } - /** - * Matrix transpose operation: If input has shape [a,b] output has shape [b,a] - * - * @param x Input variable - * @return Output variable (transposed input) - */ - public SDVariable transpose(SDVariable x) { - return transpose(null, x); - } - - /** - * Matrix transpose operation: If input has shape [a,b] output has shape [b,a] - * - * @param name Output variable name - * @param x Input variable - * @return Output variable (transposed input) - */ - public SDVariable transpose(String name, SDVariable x) { - SDVariable result = f().transpose(x); - return updateVariableNameAndReference(result, name); - } - - /** - * See {@link #unsortedSegmentMax(String, SDVariable, SDVariable, int)} - */ - public SDVariable unsortedSegmentMax(SDVariable data, SDVariable segmentIds, int numSegments) { - return unsortedSegmentMax(null, data, segmentIds, numSegments); - } - - /** - * Unsorted segment max operation. As per {@link #segmentMax(String, SDVariable, SDVariable)} but without - * the requirement for the indices to be sorted.
- * If data = [1, 3, 2, 6, 4, 9, 8]
- * segmentIds = [1, 0, 2, 0, 1, 1, 2]
- * then output = [6, 9, 8] = [max(3,6), max(1,4,9), max(2,8)]
- * - * @param name Name of the output variable - * @param data Data (variable) to perform unsorted segment max on - * @param segmentIds Variable for the segment IDs - * @param numSegments Number of segments - * @return Unsorted segment max output - */ - public SDVariable unsortedSegmentMax(String name, SDVariable data, SDVariable segmentIds, int numSegments) { - validateNumerical("unsortedSegmentMax", "data", data); - validateInteger("unsortedSegmentMax", "segmentIds", segmentIds); - SDVariable ret = f().unsortedSegmentMax(data, segmentIds, numSegments); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #unsortedSegmentMean(String, SDVariable, SDVariable, int)} - */ - public SDVariable unsortedSegmentMean(SDVariable data, SDVariable segmentIds, int numSegments) { - return unsortedSegmentMean(null, data, segmentIds, numSegments); - } - - /** - * Unsorted segment mean operation. As per {@link #segmentMean(String, SDVariable, SDVariable)} but without - * the requirement for the indices to be sorted.
- * If data = [1, 3, 2, 6, 4, 9, 8]
- * segmentIds = [1, 0, 2, 0, 1, 1, 2]
- * then output = [4.5, 4.666, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
- * - * @param name Name of the output variable - * @param data Data (variable) to perform unsorted segment mean on - * @param segmentIds Variable for the segment IDs - * @param numSegments Number of segments - * @return Unsorted segment mean output - */ - public SDVariable unsortedSegmentMean(String name, SDVariable data, SDVariable segmentIds, int numSegments) { - validateNumerical("unsortedSegmentMean", "data", data); - validateInteger("unsortedSegmentMean", "segmentIds", segmentIds); - SDVariable ret = f().unsortedSegmentMean(data, segmentIds, numSegments); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #unsortedSegmentMin(String, SDVariable, SDVariable, int)} - */ - public SDVariable unsortedSegmentMin(SDVariable data, SDVariable segmentIds, int numSegments) { - return unsortedSegmentMin(null, data, segmentIds, numSegments); - } - - /** - * Unsorted segment min operation. As per {@link #segmentMin(String, SDVariable, SDVariable)} but without - * the requirement for the indices to be sorted.
- * If data = [1, 3, 2, 6, 4, 9, 8]
- * segmentIds = [1, 0, 2, 0, 1, 1, 2]
- * then output = [3, 1, 2] = [min(3,6), min(1,4,9), min(2,8)]
- * - * @param name Name of the output variable - * @param data Data (variable) to perform unsorted segment min on - * @param segmentIds Variable for the segment IDs - * @param numSegments Number of segments - * @return Unsorted segment min output - */ - public SDVariable unsortedSegmentMin(String name, SDVariable data, SDVariable segmentIds, int numSegments) { - validateNumerical("unsortedSegmentMin", "data", data); - validateInteger("unsortedSegmentMin", "segmentIds", segmentIds); - SDVariable ret = f().unsortedSegmentMin(data, segmentIds, numSegments); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #unsortedSegmentProd(String, SDVariable, SDVariable, int)} - */ - public SDVariable unsortedSegmentProd(SDVariable data, SDVariable segmentIds, int numSegments) { - return unsortedSegmentProd(null, data, segmentIds, numSegments); - } - - /** - * Unsorted segment product operation. As per {@link #segmentProd(String, SDVariable, SDVariable)} but without - * the requirement for the indices to be sorted.
- * If data = [1, 3, 2, 6, 4, 9, 8]
- * segmentIds = [1, 0, 2, 0, 1, 1, 2]
- * then output = [4.5, 4.666, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
- * - * @param name Name of the output variable - * @param data Data (variable) to perform unsorted segment product on - * @param segmentIds Variable for the segment IDs - * @return Unsorted segment product output - */ - public SDVariable unsortedSegmentProd(String name, SDVariable data, SDVariable segmentIds, int numSegments) { - validateNumerical("unsortedSegmentProd", "data", data); - validateInteger("unsortedSegmentProd", "segmentIds", segmentIds); - SDVariable ret = f().unsortedSegmentProd(data, segmentIds, numSegments); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #unsortedSegmentSqrtN(String, SDVariable, SDVariable, int)} - */ - public SDVariable unsortedSegmentSqrtN(SDVariable data, SDVariable segmentIds, int numSegments) { - return unsortedSegmentSqrtN(null, data, segmentIds, numSegments); - } - - /** - * Unsorted segment sqrtN operation. Simply returns the sqrt of the count of the number of values in each segment
- * If data = [1, 3, 2, 6, 4, 9, 8]
- * segmentIds = [1, 0, 2, 0, 1, 1, 2]
- * then output = [1.414, 1.732, 1.414] = [sqrt(2), sqrtN(3), sqrtN(2)]
- * - * @param name Name of the output variable - * @param data Data (variable) to perform unsorted segment sqrtN on - * @param segmentIds Variable for the segment IDs - * @return Unsorted segment sqrtN output - */ - public SDVariable unsortedSegmentSqrtN(String name, SDVariable data, SDVariable segmentIds, int numSegments) { - SDVariable ret = f().unsortedSegmentSqrtN(data, segmentIds, numSegments); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #unsortedSegmentSum(String, SDVariable, SDVariable, int)} - */ - public SDVariable unsortedSegmentSum(@NonNull SDVariable data, @NonNull SDVariable segmentIds, int numSegments) { - return unsortedSegmentSum(null, data, segmentIds, numSegments); - } - - /** - * Unsorted segment sum operation. As per {@link #segmentSum(String, SDVariable, SDVariable)} but without - * the requirement for the indices to be sorted.
- * If data = [1, 3, 2, 6, 4, 9, 8]
- * segmentIds = [1, 0, 2, 0, 1, 1, 2]
- * then output = [9, 14, 10] = [sum(3,6), sum(1,4,9), sum(2,8)]
- * - * @param name Name of the output variable - * @param data Data (variable) to perform unsorted segment sum on - * @param segmentIds Variable for the segment IDs - * @param numSegments Number of segments - * @return Unsorted segment sum output - */ - public SDVariable unsortedSegmentSum(String name, @NonNull SDVariable data, @NonNull SDVariable segmentIds, int numSegments) { - validateNumerical("unsortedSegmentSum", "data", data); - validateInteger("unsortedSegmentSum", "segmentIds", segmentIds); - SDVariable ret = f().unsortedSegmentSum(data, segmentIds, numSegments); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #unstack(String[], SDVariable, int, int) - */ - public SDVariable[] unstack(SDVariable value, int axis) { - return unstack(null, value, axis); - } - - /** - * @see #unstack(String[], SDVariable, int, int) - */ - public SDVariable[] unstack(String[] names, @NonNull SDVariable value, int axis) { - SDVariable[] ret = f().unstack(value, axis); - return updateVariableNamesAndReferences(ret, names); - } - - /** - * @see #unstack(String[], SDVariable, int, int) - */ - public SDVariable[] unstack(@NonNull SDVariable value, int axis, int num) { - return unstack(null, value, axis, num); - } - - /** - * Unstack a variable of rank X into N rank X-1 variables by taking slices along the specified axis. - * If input has shape [a,b,c] then output has shape: - * axis = 0: [b,c]
- * axis = 1: [a,c]
- * axis = 2: [a,b]
- * - * @param names Output variable names. May be null - * @param value Input variable to unstack - * @param axis Axis to unstack on - * @param num Number of output variables - * @return Output variables - * @see #stack(String, int, SDVariable...) - */ - public SDVariable[] unstack(String[] names, @NonNull SDVariable value, int axis, int num) { - SDVariable[] ret = f().unstack(value, axis, num); - return updateVariableNamesAndReferences(ret, names); - } - - /** - * @see #variance(String, SDVariable, boolean, int...) - */ - public SDVariable variance(@NonNull SDVariable x, boolean biasCorrected, int... dimensions) { - return variance(null, x, biasCorrected, dimensions); - } - - /** - * Variance array reduction operation, optionally along specified dimensions - * - * @param name Output variable name - * @param x Input variable - * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable variance(String name, @NonNull SDVariable x, boolean biasCorrected, int... dimensions) { - return variance(name, x, biasCorrected, false, dimensions); - } - - /** - * Variance array reduction operation, optionally along specified dimensions
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Output variable name - * @param x Input variable - * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable variance(String name, @NonNull SDVariable x, boolean biasCorrected, boolean keepDims, int... dimensions) { - validateNumerical("variance", x); - SDVariable result = f().variance(x, biasCorrected, keepDims, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic: - * if the input shape changes in later execution, the returned variable's shape will also be updated - * - * @param input Input SDVariable - * @return A new SDVariable with the same (dynamic) shape as the input - */ - public SDVariable zerosLike(@NonNull SDVariable input) { - return zerosLike(null, input); - } - - public SDVariable zerosLike(@NonNull SDVariable input, @NonNull DataType dataType) { - return zerosLike(null, input, dataType); - } - /** - * Return a variable of all 0s, with the same shape as the input variable. 
Note that this is dynamic: - * if the input shape changes in later execution, the returned variable's shape will also be updated - * - * @param name Name of the new SDVariable - * @param input Input SDVariable - * @return A new SDVariable with the same (dynamic) shape as the input - */ - public SDVariable zerosLike(String name, @NonNull SDVariable input) { - SDVariable ret = f().zerosLike(name, input); - return updateVariableNameAndReference(ret, name); - } - - public SDVariable zerosLike(String name, @NonNull SDVariable input, @NonNull DataType dataType) { - SDVariable ret = f().zerosLike(name, input, dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #any(String, SDVariable, int...)} - */ - public SDVariable any(SDVariable x, int... dimensions){ - return any(null, x, dimensions); - } - //TODO check any w/ no dimensions - - /** - * Boolean or array reduction operation, optionally along specified dimensions - * - * @param name Name of the output variable - * @param x Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable any(String name, SDVariable x, int... dimensions){ - validateBool("any", x); - SDVariable ret = f().any(x, dimensions); - return updateVariableNameAndReference(ret, name); - } - - - /** - * See {@link #all(String, SDVariable, int...)} - */ - public SDVariable all(SDVariable x, int... dimensions){ - return all(null, x, dimensions); - } - - - /** - * Boolean and array reduction operation, optionally along specified dimensions - * - * @param name Name of the output variable - * @param x Input variable - * @param dimensions Dimensions to reduce over. 
If dimensions are not specified, full array reduction is performed - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable all(String name, SDVariable x, int... dimensions){ - validateBool("all", x); - SDVariable ret = f().all(x, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)} - */ - public SDVariable[] whileLoop(@NonNull SDVariable[] loopVars, - @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){ - return whileLoop(null, null, loopVars, cond, body); - } - - /** - * See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)} - */ - public SDVariable[] whileLoop(String loopName, @NonNull SDVariable[] loopVars, - @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){ - return whileLoop(null, loopName, loopVars, cond, body); - } - - /** - * Constructs a While loop using the tensorflow style control flow operations (Switch, Merge, Enter, Exit, and NextIteration) - * - * Repeatedly executes body on the loop variables and updates them with the results, until cond evaluates to false - * - * Note that cond and body lambdas are only called once to construct the graph. The constructed graph is used for further iterations. - * - * See Tensorflow Control Flow Implementation - * - * @param outputNames Names to give the output variables. If null, doesn't rename - * @param loopName The name of the loop block and frame (must be unique). 
If null, uses "if" - * @param loopVars Loop variables' inputs - * @param cond A lambda evaluating to the loop condition - * @param body A lambda doing the loop operation and returning the new loop variable values - * @return The values of the loop variables once condition is false - */ - public SDVariable[] whileLoop(String[] outputNames, final String loopName, @NonNull SDVariable[] loopVars, - @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){ - - final String frameName = sd().newBlockName(loopName == null ? "while" : loopName); - - NameScope loopScope = sd().withNameScope(frameName); - - //SDVariable counter = SD.scalar(SD.generateNewVarName("counter", 0), 0); - - SDVariable[] entered = new SDVariable[loopVars.length]; - for(int i = 0 ; i < loopVars.length ; i++){ - entered[i] = f().enter(loopVars[i], frameName); - } - - //counter = SD.f().enter(counter, frameName); - - SDVariable[] merged = new SDVariable[loopVars.length]; - Merge[] mergeOps = new Merge[loopVars.length]; - for(int i = 0 ; i < loopVars.length ; i++){ - // the second arg will later be replaced with the output of NextIteration - // but that isn't available yet (and can't be, as it depends on this) - mergeOps[i] = new Merge(sd(), entered[i], entered[i]); - merged[i] = mergeOps[i].outputVariable(); - } - - //Merge counterMerge = new Merge(SD, counter, counter); - //counter = counterMerge.outputVariable(); - - NameScope condScope = sd().withNameScope("cond"); - SDVariable cond_result = cond.define(sd(), merged); - condScope.close(); - - - if (cond_result.dataType() != DataType.BOOL) - throw new IllegalStateException("Can not use " + cond_result.name() + " as the condition of an While loop, the condition must be a boolean."); - - - final Set alreadyEntered = Sets.newHashSet(); - SDVariable[] trueSwitches = new SDVariable[loopVars.length]; - SDVariable[] exits = new SDVariable[loopVars.length]; - for(int i = 0 ; i < loopVars.length ; i++){ - SDVariable[] s = f().switchOp(merged[i], 
cond_result); - trueSwitches[i] = s[1]; - alreadyEntered.add(s[1].name()); - exits[i] = f().exit(s[0]); - } - - //SDVariable[] cs = SD.f().switchOp(counter, cond_result); - //SDVariable counterExit = SD.f().exit(cs[0]); - //counter = cs[1]; - - final Set declared = Sets.newHashSet(sd().variableMap().keySet()); - final Map done = new HashMap<>(); - - sd().addArgumentInterceptor(new ArgumentInterceptor() { - @Override - public SDVariable intercept(SDVariable argument) { - - if(!declared.contains(argument.name())) - return argument; - - if(alreadyEntered.contains(argument.name())) - return argument; - - if(done.containsKey(argument.name())) - return done.get(argument.name()); - - SDVariable e = f().enter(argument, frameName, true); - done.put(argument.name(), e); - return e; - } - }); - - NameScope bodyScope = sd().withNameScope("body"); - SDVariable[] outs = body.define(sd(), trueSwitches); - bodyScope.close(); - sd().removeArgumentInterceptor(); - - //counter.add(1); - - for(int i = 0 ; i < loopVars.length ; i++){ - SDVariable n = f().nextIteration(outs[i]); - mergeOps[i].replaceArg(1,n); - } - - //counterMerge.replaceArg(1, counter); - - loopScope.close(); - return updateVariableNamesAndReferences(exits, outputNames); - } - - /** - * See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)} - */ - public SDVariable ifCond(@NonNull SameDiffNoArgSingleLambda cond, - @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){ - return ifCond(null, null, cond, trueBody, falseBody); - } - - - /** - * See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)} - */ - public SDVariable ifCond(String ifName, @NonNull SameDiffNoArgSingleLambda cond, - @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){ - return ifCond(null, ifName, cond, trueBody, falseBody); - } - - /** - * Constructs a 
If statement using the tensorflow style control flow operations (Switch and Merge) - * - * If the result of cond is true, returns the result of trueBody, otherwise returns the result of falseBody - * - * Note that cond and body lambdas are only called once to construct the graph. The constructed graph is used to evaluate. - * - * See Tensorflow Control Flow Implementation - * - * @param outputName Name to give the output variable. If null, doesn't rename - * @param ifName The name of the if block. If null, uses "if" - * @param cond A lambda evaluating to the if condition - * @param trueBody A lambda to be executed if cond is true (the if block) - * @param falseBody A lambda to be executed if cond is false (the else block) - * @return The value of trueBody if cond is true, or falseBody if it isn't - */ - public SDVariable ifCond(String outputName, String ifName, @NonNull SameDiffNoArgSingleLambda cond, - @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){ - - ifName = sd().newBlockName(ifName == null ? 
"if" : ifName); - - NameScope ifScope = sd().withNameScope(ifName); - - NameScope condScope = sd().withNameScope("cond"); - final SDVariable pred = cond.define(sd()); - condScope.close(); - - if (pred.dataType() != DataType.BOOL) { - //cleanup partially added block - - for(SDVariable v : sd().getVariablesInScope(ifScope)) - sd().getVariables().remove(v.name()); - - for(SameDiffOp op : sd().getOpsInScope(ifScope)) { - for(String in : op.getInputsToOp()){ - sd().removeArgFromOp(in, op.getOp()); - } - sd().getOps().remove(op.getName()); - } - - - throw new IllegalStateException("Can not use " + pred.name() - + " as the condition of an If statement, the condition must be a boolean."); - } - - final Map switches = new HashMap<>(); - - final Set declared = Sets.newHashSet(sd().variableMap().keySet()); - - sd().addArgumentInterceptor(new ArgumentInterceptor() { - @Override - public SDVariable intercept(SDVariable argument) { - - // if its declared in the if, we don't care acout it - if(!declared.contains(argument.name())) - return argument; - - // if we've already added a switch, move on - if(switches.containsKey(argument.name())) - return switches.get(argument.name())[1]; - - SDVariable[] s = f().switchOp(argument, pred); - switches.put(argument.name(), s); - return s[1]; - } - }); - NameScope trueScope = sd().withNameScope("trueBody"); - SDVariable trueOut = trueBody.define(sd()); - sd().removeArgumentInterceptor(); - - if(declared.contains(trueOut.name())) { - SDVariable[] s = f().switchOp(trueOut, pred); - switches.put(trueOut.name(), s); - trueOut = s[1]; - } - - trueScope.close(); - - final Set declared2 = Sets.newHashSet(sd().variableMap().keySet()); - sd().addArgumentInterceptor(new ArgumentInterceptor() { - @Override - public SDVariable intercept(SDVariable argument) { - - // if its declared in the if, we don't care acout it - if(!declared2.contains(argument.name())) - return argument; - - // if we've already added a switch, move on - 
if(switches.containsKey(argument.name())) - return switches.get(argument.name())[0]; - - SDVariable[] s = f().switchOp(argument, pred); - switches.put(argument.name(), s); - return s[0]; - } - }); - NameScope falseScope = sd().withNameScope("falseBody"); - SDVariable falseOut = falseBody.define(sd()); - sd().removeArgumentInterceptor(); - - if(declared2.contains(falseOut.name())) { - SDVariable[] s = f().switchOp(falseOut, pred); - switches.put(falseOut.name(), s); - falseOut = s[0]; - } - falseScope.close(); - - SDVariable output = f().merge(trueOut, falseOut); - - ifScope.close(); - - return updateVariableNameAndReference(output, outputName); - } +public class SDBaseOps { + protected SameDiff sd; + + public SDBaseOps(SameDiff sameDiff) { + this.sd = sameDiff; + } + + /** + * Boolean and array reduction operation, optionally along specified dimensions
+ * + * @param x Input variable (BOOL type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (BOOL type) + */ + public SDVariable all(SDVariable x, int... dimensions) { + SDValidation.validateBool("all", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.bool.All(sd,x, dimensions).outputVariable(); + } + + /** + * Boolean and array reduction operation, optionally along specified dimensions
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (BOOL type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (BOOL type) + */ + public SDVariable all(String name, SDVariable x, int... dimensions) { + SDValidation.validateBool("all", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.bool.All(sd,x, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Boolean or array reduction operation, optionally along specified dimensions
+ * + * @param x Input variable (BOOL type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (BOOL type) + */ + public SDVariable any(SDVariable x, int... dimensions) { + SDValidation.validateBool("any", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.bool.Any(sd,x, dimensions).outputVariable(); + } + + /** + * Boolean or array reduction operation, optionally along specified dimensions
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (BOOL type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (BOOL type) + */ + public SDVariable any(String name, SDVariable x, int... dimensions) { + SDValidation.validateBool("any", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.bool.Any(sd,x, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Argmax array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the maximum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or + * of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable argmax(SDVariable in, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("argmax", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, keepDims, dimensions).outputVariable(); + } + + /** + * Argmax array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the maximum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or + * of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable argmax(String name, SDVariable in, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("argmax", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Argmax array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the maximum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or + * of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable argmax(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("argmax", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, false, dimensions).outputVariable(); + } + + /** + * Argmax array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the maximum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or + * of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable argmax(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("argmax", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Argmin array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the minimum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable argmin(SDVariable in, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("argmin", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, keepDims, dimensions).outputVariable(); + } + + /** + * Argmin array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the minimum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable argmin(String name, SDVariable in, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("argmin", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Argmin array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the minimum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable argmin(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("argmin", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, false, dimensions).outputVariable(); + } + + /** + * Argmin array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the minimum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable argmin(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("argmin", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
+ * length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
+ * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.
+ * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).
+ *
+ * The result of this operation will be a batch of multiplied matrices. The
+ * result has the same length as both input batches and each output matrix is of shape (M, K).
+ * + * @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type) + * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) + * @param transposeA Whether to transpose A arrays or not + * @param transposeB Whether to transpose B arrays or not + */ + public SDVariable[] batchMmul(SDVariable[] inputsA, SDVariable[] inputsB, boolean transposeA, + boolean transposeB) { + SDValidation.validateNumerical("batchMmul", "inputsA", inputsA); + Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); + SDValidation.validateNumerical("batchMmul", "inputsB", inputsB); + Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); + return new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(sd,inputsA, inputsB, transposeA, transposeB).outputVariables(); + } + + /** + * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
+ * length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
+ * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.
+ * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).
+ *
+ * The result of this operation will be a batch of multiplied matrices. The
+ * result has the same length as both input batches and each output matrix is of shape (M, K).
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type) + * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) + * @param transposeA Whether to transpose A arrays or not + * @param transposeB Whether to transpose B arrays or not + */ + public SDVariable[] batchMmul(String[] names, SDVariable[] inputsA, SDVariable[] inputsB, + boolean transposeA, boolean transposeB) { + SDValidation.validateNumerical("batchMmul", "inputsA", inputsA); + Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); + SDValidation.validateNumerical("batchMmul", "inputsB", inputsB); + Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(sd,inputsA, inputsB, transposeA, transposeB).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
+ * length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
+ * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.
+ * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).
+ *
+ * The result of this operation will be a batch of multiplied matrices. The
+ * result has the same length as both input batches and each output matrix is of shape (M, K).
+ * + * @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type) + * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) + */ + public SDVariable[] batchMmul(SDVariable[] inputsA, SDVariable... inputsB) { + SDValidation.validateNumerical("batchMmul", "inputsA", inputsA); + Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); + SDValidation.validateNumerical("batchMmul", "inputsB", inputsB); + Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); + return new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(sd,inputsA, inputsB, false, false).outputVariables(); + } + + /** + * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
+ * length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
+ * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.
+ * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).
+ *
+ * The result of this operation will be a batch of multiplied matrices. The
+ * result has the same length as both input batches and each output matrix is of shape (M, K).
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type) + * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) + */ + public SDVariable[] batchMmul(String[] names, SDVariable[] inputsA, SDVariable... inputsB) { + SDValidation.validateNumerical("batchMmul", "inputsA", inputsA); + Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); + SDValidation.validateNumerical("batchMmul", "inputsB", inputsB); + Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(sd,inputsA, inputsB, false, false).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Cast the array to a new datatype - for example, Integer -> Float
+ * + * @param arg Input variable to cast (NDARRAY type) + * @param datatype Datatype to cast to + * @return output Output array (after casting) (NDARRAY type) + */ + public SDVariable castTo(SDVariable arg, DataType datatype) { + return new org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast(sd,arg, datatype).outputVariable(); + } + + /** + * Cast the array to a new datatype - for example, Integer -> Float
+ * + * @param name name May be null. Name for the output variable + * @param arg Input variable to cast (NDARRAY type) + * @param datatype Datatype to cast to + * @return output Output array (after casting) (NDARRAY type) + */ + public SDVariable castTo(String name, SDVariable arg, DataType datatype) { + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast(sd,arg, datatype).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Concatenate a set of inputs along the specified dimension.
+ * Note that inputs must have identical rank and identical dimensions, other than the dimension to stack on.
+ * For example, if 2 inputs have shape [a, x, c] and [a, y, c] and dimension = 1, then the output has shape [a, x+y, c]
+ * + * Inputs must satisfy the following constraints:
+ * Input arrays must all be the same datatype: isSameType(inputs)
+ * + * @param inputs Input variables (NUMERIC type) + * @param dimension Dimension to concatenate on + * @return output (NUMERIC type) + */ + public SDVariable concat(int dimension, SDVariable... inputs) { + SDValidation.validateNumerical("concat", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + Preconditions.checkArgument(isSameType(inputs), "Input arrays must all be the same datatype"); + return new org.nd4j.linalg.api.ops.impl.shape.Concat(sd,inputs, dimension).outputVariable(); + } + + /** + * Concatenate a set of inputs along the specified dimension.
+ * Note that inputs must have identical rank and identical dimensions, other than the dimension to stack on.
+ * For example, if 2 inputs have shape [a, x, c] and [a, y, c] and dimension = 1, then the output has shape [a, x+y, c]
+ * + * Inputs must satisfy the following constraints:
+ * Input arrays must all be the same datatype: isSameType(inputs)
+ * + * @param name name May be null. Name for the output variable + * @param dimension Dimension to concatenate on + * @param inputs Input variables (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable concat(String name, int dimension, SDVariable... inputs) { + SDValidation.validateNumerical("concat", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + Preconditions.checkArgument(isSameType(inputs), "Input arrays must all be the same datatype"); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Concat(sd,inputs, dimension).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Cumulative product operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a*b, a*b*c]
+ * exclusive=true, reverse=false, [0, a, a*b]
+ * exclusive=false, reverse=true: [a*b*c, b*c, c]
+ * exclusive=true, reverse=true: [b*c, c, 0]
+ * + * @param in Input variable (NUMERIC type) + * @param exclusive If true: exclude the first value + * @param reverse If true: reverse the direction of the accumulation + * @param axis Scalar axis argument for dimension to perform cumulative sum operations along (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cumprod(SDVariable in, boolean exclusive, boolean reverse, int... axis) { + SDValidation.validateNumerical("cumprod", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(sd,in, exclusive, reverse, axis).outputVariable(); + } + + /** + * Cumulative product operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a*b, a*b*c]
+ * exclusive=true, reverse=false, [0, a, a*b]
+ * exclusive=false, reverse=true: [a*b*c, b*c, c]
+ * exclusive=true, reverse=true: [b*c, c, 0]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param exclusive If true: exclude the first value + * @param reverse If true: reverse the direction of the accumulation + * @param axis Scalar axis argument for dimension to perform cumulative sum operations along (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cumprod(String name, SDVariable in, boolean exclusive, boolean reverse, + int... axis) { + SDValidation.validateNumerical("cumprod", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(sd,in, exclusive, reverse, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Cumulative product operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a*b, a*b*c]
+ * exclusive=true, reverse=false, [0, a, a*b]
+ * exclusive=false, reverse=true: [a*b*c, b*c, c]
+ * exclusive=true, reverse=true: [b*c, c, 0]
+ * + * @param in Input variable (NUMERIC type) + * @param axis Scalar axis argument for dimension to perform cumululative sum operations along (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cumprod(SDVariable in, int... axis) { + SDValidation.validateNumerical("cumprod", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(sd,in, false, false, axis).outputVariable(); + } + + /** + * Cumulative product operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a*b, a*b*c]
+ * exclusive=true, reverse=false, [0, a, a*b]
+ * exclusive=false, reverse=true: [a*b*c, b*c, c]
+ * exclusive=true, reverse=true: [b*c, c, 0]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param axis Scalar axis argument for dimension to perform cumululative sum operations along (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cumprod(String name, SDVariable in, int... axis) { + SDValidation.validateNumerical("cumprod", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(sd,in, false, false, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Cumulative sum operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a+b, a+b+c]
+ * exclusive=true, reverse=false, [0, a, a+b]
+ * exclusive=false, reverse=true: [a+b+c, b+c, c]
+ * exclusive=true, reverse=true: [b+c, c, 0]
+ * + * @param in Input variable (NUMERIC type) + * @param exclusive If true: exclude the first value + * @param reverse If true: reverse the direction of the accumulation + * @param axis Scalar axis argument for dimension to perform cumululative sum operations along (Size: AtLeast(min=1)) + * @return output (NUMERIC type) + */ + public SDVariable cumsum(SDVariable in, boolean exclusive, boolean reverse, int... axis) { + SDValidation.validateNumerical("cumsum", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(sd,in, exclusive, reverse, axis).outputVariable(); + } + + /** + * Cumulative sum operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a+b, a+b+c]
+ * exclusive=true, reverse=false, [0, a, a+b]
+ * exclusive=false, reverse=true: [a+b+c, b+c, c]
+ * exclusive=true, reverse=true: [b+c, c, 0]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param exclusive If true: exclude the first value + * @param reverse If true: reverse the direction of the accumulation + * @param axis Scalar axis argument for dimension to perform cumululative sum operations along (Size: AtLeast(min=1)) + * @return output (NUMERIC type) + */ + public SDVariable cumsum(String name, SDVariable in, boolean exclusive, boolean reverse, + int... axis) { + SDValidation.validateNumerical("cumsum", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(sd,in, exclusive, reverse, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Cumulative sum operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a+b, a+b+c]
+ * exclusive=true, reverse=false, [0, a, a+b]
+ * exclusive=false, reverse=true: [a+b+c, b+c, c]
+ * exclusive=true, reverse=true: [b+c, c, 0]
+ * + * @param in Input variable (NUMERIC type) + * @param axis Scalar axis argument for dimension to perform cumululative sum operations along (Size: AtLeast(min=1)) + * @return output (NUMERIC type) + */ + public SDVariable cumsum(SDVariable in, int... axis) { + SDValidation.validateNumerical("cumsum", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(sd,in, false, false, axis).outputVariable(); + } + + /** + * Cumulative sum operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a+b, a+b+c]
+ * exclusive=true, reverse=false, [0, a, a+b]
+ * exclusive=false, reverse=true: [a+b+c, b+c, c]
+ * exclusive=true, reverse=true: [b+c, c, 0]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param axis Scalar axis argument for dimension to perform cumululative sum operations along (Size: AtLeast(min=1)) + * @return output (NUMERIC type) + */ + public SDVariable cumsum(String name, SDVariable in, int... axis) { + SDValidation.validateNumerical("cumsum", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(sd,in, false, false, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Pairwise dot product reduction along dimension
+ * output = sum(i=0 ... size(dim)-1) x[i] * y[i]
+ * + * @param x first input (NUMERIC type) + * @param y second input (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output output variable (NUMERIC type) + */ + public SDVariable dot(SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("dot", "x", x); + SDValidation.validateNumerical("dot", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.Dot(sd,x, y, dimensions).outputVariable(); + } + + /** + * Pairwise dot product reduction along dimension
+ * output = sum(i=0 ... size(dim)-1) x[i] * y[i]
+ * + * @param name name May be null. Name for the output variable + * @param x first input (NUMERIC type) + * @param y second input (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output output variable (NUMERIC type) + */ + public SDVariable dot(String name, SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("dot", "x", x); + SDValidation.validateNumerical("dot", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.Dot(sd,x, y, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Dynamically partition the input variable values into the specified number of paritions, using the indices.
+ * Example:
+ *

+ * input = [1,2,3,4,5]
+ * numPartitions = 2
+ * partitions = [1,0,0,1,0]
+ * out[0] = [2,3,5]
+ * out[1] = [1,4] }
+ *

+ * + * @param x Input variable (NUMERIC type) + * @param partitions 1D input with values 0 to numPartitions-1 (INT type) + * @param numPartitions Number of partitions, >= 1 + */ + public SDVariable[] dynamicPartition(SDVariable x, SDVariable partitions, int numPartitions) { + SDValidation.validateNumerical("dynamicPartition", "x", x); + SDValidation.validateInteger("dynamicPartition", "partitions", partitions); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(sd,x, partitions, numPartitions).outputVariables(); + } + + /** + * Dynamically partition the input variable values into the specified number of paritions, using the indices.
+ * Example:
+ *

+ * input = [1,2,3,4,5]
+ * numPartitions = 2
+ * partitions = [1,0,0,1,0]
+ * out[0] = [2,3,5]
+ * out[1] = [1,4] }
+ *

+ * + * @param names names May be null. Arrays of names for the output variables. + * @param x Input variable (NUMERIC type) + * @param partitions 1D input with values 0 to numPartitions-1 (INT type) + * @param numPartitions Number of partitions, >= 1 + */ + public SDVariable[] dynamicPartition(String[] names, SDVariable x, SDVariable partitions, + int numPartitions) { + SDValidation.validateNumerical("dynamicPartition", "x", x); + SDValidation.validateInteger("dynamicPartition", "partitions", partitions); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(sd,x, partitions, numPartitions).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Dynamically merge the specified input arrays into a single array, using the specified indices
+ * + * @param indices Indices to use when merging. Must be >= 1, same length as input variables (INT type) + * @param x Input variables. (NUMERIC type) + * @return output Merged output variable (NUMERIC type) + */ + public SDVariable dynamicStitch(SDVariable[] indices, SDVariable... x) { + SDValidation.validateInteger("dynamicStitch", "indices", indices); + Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); + SDValidation.validateNumerical("dynamicStitch", "x", x); + Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(sd,indices, x).outputVariable(); + } + + /** + * Dynamically merge the specified input arrays into a single array, using the specified indices
+ * + * @param name name May be null. Name for the output variable + * @param indices Indices to use when merging. Must be >= 1, same length as input variables (INT type) + * @param x Input variables. (NUMERIC type) + * @return output Merged output variable (NUMERIC type) + */ + public SDVariable dynamicStitch(String name, SDVariable[] indices, SDVariable... x) { + SDValidation.validateInteger("dynamicStitch", "indices", indices); + Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); + SDValidation.validateNumerical("dynamicStitch", "x", x); + Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(sd,indices, x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Equals operation: elementwise x == y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable eq(SDVariable x, double y) { + SDValidation.validateNumerical("eq", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals(sd,x, y).outputVariable(); + } + + /** + * Equals operation: elementwise x == y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable eq(String name, SDVariable x, double y) { + SDValidation.validateNumerical("eq", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Equal to operation: elementwise x == y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable eq(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("eq", "x", x); + SDValidation.validateNumerical("eq", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo(sd,x, y).outputVariable(); + } + + /** + * Equal to operation: elementwise x == y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable eq(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("eq", "x", x); + SDValidation.validateNumerical("eq", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Reshape the input by adding a 1 at the specified location.
+ * For example, if input has shape [a, b], then output shape is:
+ * axis = 0: [1, a, b]
+ * axis = 1: [a, 1, b]
+ * axis = 2: [a, b, 1]
+ * + * @param x Input variable (NDARRAY type) + * @param axis Axis to expand + * @return output Output variable (NUMERIC type) + */ + public SDVariable expandDims(SDVariable x, int axis) { + return new org.nd4j.linalg.api.ops.impl.shape.ExpandDims(sd,x, axis).outputVariable(); + } + + /** + * Reshape the input by adding a 1 at the specified location.
+ * For example, if input has shape [a, b], then output shape is:
+ * axis = 0: [1, a, b]
+ * axis = 1: [a, 1, b]
+ * axis = 2: [a, b, 1]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NDARRAY type) + * @param axis Axis to expand + * @return output Output variable (NUMERIC type) + */ + public SDVariable expandDims(String name, SDVariable x, int axis) { + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ExpandDims(sd,x, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Generate an output variable with the specified (dynamic) shape with all elements set to the specified value
+ * + * @param shape Shape: must be a 1D array/variable (INT type) + * @param dataType Datatype of the output array + * @param value Value to set all elements to + * @return output Output variable (NUMERIC type) + */ + public SDVariable fill(SDVariable shape, DataType dataType, double value) { + SDValidation.validateInteger("fill", "shape", shape); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Fill(sd,shape, dataType, value).outputVariable(); + } + + /** + * Generate an output variable with the specified (dynamic) shape with all elements set to the specified value
+ * + * @param name name May be null. Name for the output variable + * @param shape Shape: must be a 1D array/variable (INT type) + * @param dataType Datatype of the output array + * @param value Value to set all elements to + * @return output Output variable (NUMERIC type) + */ + public SDVariable fill(String name, SDVariable shape, DataType dataType, double value) { + SDValidation.validateInteger("fill", "shape", shape); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Fill(sd,shape, dataType, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Gather slices from the input variable where the indices are specified as fixed int[] values.
+ * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length.
+ * + * @param df Input variable (NUMERIC type) + * @param indices Indices to get (Size: AtLeast(min=1)) + * @param axis Axis that the indices refer to + * @return output Output variable with slices pulled from the specified axis (NUMERIC type) + */ + public SDVariable gather(SDVariable df, int[] indices, int axis) { + SDValidation.validateNumerical("gather", "df", df); + Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); + return new org.nd4j.linalg.api.ops.impl.shape.Gather(sd,df, indices, axis).outputVariable(); + } + + /** + * Gather slices from the input variable where the indices are specified as fixed int[] values.
+ * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length.
+ * + * @param name name May be null. Name for the output variable + * @param df Input variable (NUMERIC type) + * @param indices Indices to get (Size: AtLeast(min=1)) + * @param axis Axis that the indices refer to + * @return output Output variable with slices pulled from the specified axis (NUMERIC type) + */ + public SDVariable gather(String name, SDVariable df, int[] indices, int axis) { + SDValidation.validateNumerical("gather", "df", df); + Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Gather(sd,df, indices, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Gather slices from the input variable where the indices are specified as dynamic array values.
+ * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length.
+ * + * @param df Input variable (NUMERIC type) + * @param indices Indices to get slices for. Rank 0 or 1 input (INT type) + * @param axis Axis that the indices refer to + * @return output Output variable with slices pulled from the specified axis (NUMERIC type) + */ + public SDVariable gather(SDVariable df, SDVariable indices, int axis) { + SDValidation.validateNumerical("gather", "df", df); + SDValidation.validateInteger("gather", "indices", indices); + return new org.nd4j.linalg.api.ops.impl.shape.Gather(sd,df, indices, axis).outputVariable(); + } + + /** + * Gather slices from the input variable where the indices are specified as dynamic array values.
+ * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length.
+ * + * @param name name May be null. Name for the output variable + * @param df Input variable (NUMERIC type) + * @param indices Indices to get slices for. Rank 0 or 1 input (INT type) + * @param axis Axis that the indices refer to + * @return output Output variable with slices pulled from the specified axis (NUMERIC type) + */ + public SDVariable gather(String name, SDVariable df, SDVariable indices, int axis) { + SDValidation.validateNumerical("gather", "df", df); + SDValidation.validateInteger("gather", "indices", indices); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Gather(sd,df, indices, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Gather slices from df with shape specified by indices.
+ * + * @param df (NUMERIC type) + * @param indices (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable gatherNd(SDVariable df, SDVariable indices) { + SDValidation.validateNumerical("gatherNd", "df", df); + SDValidation.validateNumerical("gatherNd", "indices", indices); + return new org.nd4j.linalg.api.ops.impl.shape.GatherNd(sd,df, indices).outputVariable(); + } + + /** + * Gather slices from df with shape specified by indices.
+ * + * @param name name May be null. Name for the output variable + * @param df (NUMERIC type) + * @param indices (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable gatherNd(String name, SDVariable df, SDVariable indices) { + SDValidation.validateNumerical("gatherNd", "df", df); + SDValidation.validateNumerical("gatherNd", "indices", indices); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.GatherNd(sd,df, indices).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Greater than operation: elementwise x > y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable gt(SDVariable x, double y) { + SDValidation.validateNumerical("gt", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan(sd,x, y).outputVariable(); + } + + /** + * Greater than operation: elementwise x > y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable gt(String name, SDVariable x, double y) { + SDValidation.validateNumerical("gt", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Greater than operation: elementwise x > y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable gt(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("gt", "x", x); + SDValidation.validateNumerical("gt", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan(sd,x, y).outputVariable(); + } + + /** + * Greater than operation: elementwise x > y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable gt(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("gt", "x", x); + SDValidation.validateNumerical("gt", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Greater than or equals operation: elementwise x >= y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable gte(SDVariable x, double y) { + SDValidation.validateNumerical("gte", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual(sd,x, y).outputVariable(); + } + + /** + * Greater than or equals operation: elementwise x >= y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable gte(String name, SDVariable x, double y) { + SDValidation.validateNumerical("gte", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Greater than or equal to operation: elementwise x >= y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable gte(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("gte", "x", x); + SDValidation.validateNumerical("gte", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual(sd,x, y).outputVariable(); + } + + /** + * Greater than or equal to operation: elementwise x >= y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable gte(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("gte", "x", x); + SDValidation.validateNumerical("gte", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise identity operation: out = x
+ * + * @param input Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable identity(SDVariable input) { + SDValidation.validateNumerical("identity", "input", input); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Identity(sd,input).outputVariable(); + } + + /** + * Elementwise identity operation: out = x
+ * + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable identity(String name, SDVariable input) { + SDValidation.validateNumerical("identity", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Identity(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Compute the inverse permutation indices for a permutation operation
+ * Example: if input is [2, 0, 1] then output is [1, 2, 0]
+ * The idea is that x.permute(input).permute(invertPermutation(input)) == x
+ * + * @param input 1D indices for permutation (INT type) + * @return output 1D inverted permutation (INT type) + */ + public SDVariable invertPermutation(SDVariable input) { + SDValidation.validateInteger("invertPermutation", "input", input); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation(sd,input).outputVariable(); + } + + /** + * Compute the inverse permutation indices for a permutation operation
+ * Example: if input is [2, 0, 1] then output is [1, 2, 0]
+ * The idea is that x.permute(input).permute(invertPermutation(input)) == x
+ * + * @param name name May be null. Name for the output variable + * @param input 1D indices for permutation (INT type) + * @return output 1D inverted permutation (INT type) + */ + public SDVariable invertPermutation(String name, SDVariable input) { + SDValidation.validateInteger("invertPermutation", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Is the director a numeric tensor? In the current version of ND4J/SameDiff, this always returns true/1
+ * + * @param x Input variable (NUMERIC type) + * @return output scalar boolean with value true or false (NDARRAY type) + */ + public SDVariable isNumericTensor(SDVariable x) { + SDValidation.validateNumerical("isNumericTensor", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor(sd,x).outputVariable(); + } + + /** + * Is the director a numeric tensor? In the current version of ND4J/SameDiff, this always returns true/1
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output scalar boolean with value true or false (NDARRAY type) + */ + public SDVariable isNumericTensor(String name, SDVariable x) { + SDValidation.validateNumerical("isNumericTensor", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
+ * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
+ * + * @param dataType Data type of the output array + * @param start Start value + * @param stop Stop value + * @param number Number of values to generate + * @return output INDArray with linearly spaced elements (NUMERIC type) + */ + public SDVariable linspace(DataType dataType, double start, double stop, long number) { + return new org.nd4j.linalg.api.ops.impl.shape.Linspace(sd,dataType, start, stop, number).outputVariable(); + } + + /** + * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
+ * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
+ * + * @param name name May be null. Name for the output variable + * @param dataType Data type of the output array + * @param start Start value + * @param stop Stop value + * @param number Number of values to generate + * @return output INDArray with linearly spaced elements (NUMERIC type) + */ + public SDVariable linspace(String name, DataType dataType, double start, double stop, + long number) { + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Linspace(sd,dataType, start, stop, number).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
+ * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
+ * + * @param start Start value (NUMERIC type) + * @param stop Stop value (NUMERIC type) + * @param number Number of values to generate (LONG type) + * @param dataType Data type of the output array + * @return output INDArray with linearly spaced elements (NUMERIC type) + */ + public SDVariable linspace(SDVariable start, SDVariable stop, SDVariable number, + DataType dataType) { + SDValidation.validateNumerical("linspace", "start", start); + SDValidation.validateNumerical("linspace", "stop", stop); + SDValidation.validateInteger("linspace", "number", number); + return new org.nd4j.linalg.api.ops.impl.shape.Linspace(sd,start, stop, number, dataType).outputVariable(); + } + + /** + * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
+ * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
+ * + * @param name name May be null. Name for the output variable + * @param start Start value (NUMERIC type) + * @param stop Stop value (NUMERIC type) + * @param number Number of values to generate (LONG type) + * @param dataType Data type of the output array + * @return output INDArray with linearly spaced elements (NUMERIC type) + */ + public SDVariable linspace(String name, SDVariable start, SDVariable stop, SDVariable number, + DataType dataType) { + SDValidation.validateNumerical("linspace", "start", start); + SDValidation.validateNumerical("linspace", "stop", stop); + SDValidation.validateInteger("linspace", "number", number); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Linspace(sd,start, stop, number, dataType).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Less than operation: elementwise x < y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable lt(SDVariable x, double y) { + SDValidation.validateNumerical("lt", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan(sd,x, y).outputVariable(); + } + + /** + * Less than operation: elementwise x < y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable lt(String name, SDVariable x, double y) { + SDValidation.validateNumerical("lt", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Less than operation: elementwise x < y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable lt(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("lt", "x", x); + SDValidation.validateNumerical("lt", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan(sd,x, y).outputVariable(); + } + + /** + * Less than operation: elementwise x < y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable lt(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("lt", "x", x); + SDValidation.validateNumerical("lt", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Less than or equals operation: elementwise x <= y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable lte(SDVariable x, double y) { + SDValidation.validateNumerical("lte", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThanOrEqual(sd,x, y).outputVariable(); + } + + /** + * Less than or equals operation: elementwise x <= y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable lte(String name, SDVariable x, double y) { + SDValidation.validateNumerical("lte", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThanOrEqual(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Less than or equal to operation: elementwise x <= y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable lte(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("lte", "x", x); + SDValidation.validateNumerical("lte", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual(sd,x, y).outputVariable(); + } + + /** + * Less than or equal to operation: elementwise x <= y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable lte(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("lte", "x", x); + SDValidation.validateNumerical("lte", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns a boolean mask of equal shape to the input, where the condition is satisfied - value 1 where satisfied, 0 otherwise
+ * + * @param in Input (NUMERIC type) + * @param condition Condition + * @return output Boolean mask (NUMERIC type) + */ + public SDVariable matchCondition(SDVariable in, Condition condition) { + SDValidation.validateNumerical("matchCondition", "in", in); + return new org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform(sd,in, condition).outputVariable(); + } + + /** + * Returns a boolean mask of equal shape to the input, where the condition is satisfied - value 1 where satisfied, 0 otherwise
+ * + * @param name name May be null. Name for the output variable + * @param in Input (NUMERIC type) + * @param condition Condition + * @return output Boolean mask (NUMERIC type) + */ + public SDVariable matchCondition(String name, SDVariable in, Condition condition) { + SDValidation.validateNumerical("matchCondition", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform(sd,in, condition).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns a count of the number of elements that satisfy the condition
+ * + * @param in Input (NUMERIC type) + * @param condition Condition + * @return output Number of elements that the condition is satisfied for (NUMERIC type) + */ + public SDVariable matchConditionCount(SDVariable in, Condition condition) { + SDValidation.validateNumerical("matchConditionCount", "in", in); + return new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition).outputVariable(); + } + + /** + * Returns a count of the number of elements that satisfy the condition
+ * + * @param name name May be null. Name for the output variable + * @param in Input (NUMERIC type) + * @param condition Condition + * @return output Number of elements that the condition is satisfied for (NUMERIC type) + */ + public SDVariable matchConditionCount(String name, SDVariable in, Condition condition) { + SDValidation.validateNumerical("matchConditionCount", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition + * @param keepDim If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Number of elements that the condition is satisfied for (NUMERIC type) + */ + public SDVariable matchConditionCount(SDVariable in, Condition condition, boolean keepDim, + int... dimensions) { + SDValidation.validateNumerical("matchConditionCount", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition, keepDim, dimensions).outputVariable(); + } + + /** + * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition + * @param keepDim If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Number of elements that the condition is satisfied for (NUMERIC type) + */ + public SDVariable matchConditionCount(String name, SDVariable in, Condition condition, + boolean keepDim, int... dimensions) { + SDValidation.validateNumerical("matchConditionCount", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition, keepDim, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Number of elements that the condition is satisfied for (NUMERIC type) + */ + public SDVariable matchConditionCount(SDVariable in, Condition condition, int... dimensions) { + SDValidation.validateNumerical("matchConditionCount", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition, false, dimensions).outputVariable(); + } + + /** + * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Number of elements that the condition is satisfied for (NUMERIC type) + */ + public SDVariable matchConditionCount(String name, SDVariable in, Condition condition, + int... dimensions) { + SDValidation.validateNumerical("matchConditionCount", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Max array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable max(SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("max", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.Max(sd,x, keepDims, dimensions).outputVariable(); + } + + /** + * Max array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable max(String name, SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("max", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.Max(sd,x, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Max array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable max(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("max", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.Max(sd,x, false, dimensions).outputVariable(); + } + + /** + * Max array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable max(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("max", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.Max(sd,x, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise maximum operation: out[i] = max(first[i], second[i])
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * @param first First input array (NUMERIC type) + * @param second Second input array (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable max(SDVariable first, SDVariable second) { + SDValidation.validateNumerical("max", "first", first); + SDValidation.validateNumerical("max", "second", second); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd,first, second).outputVariable(); + } + + /** + * Element-wise maximum operation: out[i] = max(first[i], second[i])
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * @param name name May be null. Name for the output variable + * @param first First input array (NUMERIC type) + * @param second Second input array (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable max(String name, SDVariable first, SDVariable second) { + SDValidation.validateNumerical("max", "first", first); + SDValidation.validateNumerical("max", "second", second); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd,first, second).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Mean (average) array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable mean(SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("mean", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(sd,x, keepDims, dimensions).outputVariable(); + } + + /** + * Mean (average) array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable mean(String name, SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("mean", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(sd,x, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Mean (average) array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable mean(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("mean", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(sd,x, false, dimensions).outputVariable(); + } + + /** + * Mean (average) array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable mean(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("mean", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(sd,x, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable min(SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("min", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.Min(sd,x, keepDims, dimensions).outputVariable(); + } + + /** + * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable min(String name, SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("min", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.Min(sd,x, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable min(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("min", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.Min(sd,x, false, dimensions).outputVariable(); + } + + /** + * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable min(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("min", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.Min(sd,x, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise minimum operation: out[i] = min(first[i], second[i])
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * @param first First input array (NUMERIC type) + * @param second Second input array (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable min(SDVariable first, SDVariable second) { + SDValidation.validateNumerical("min", "first", first); + SDValidation.validateNumerical("min", "second", second); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sd,first, second).outputVariable(); + } + + /** + * Element-wise minimum operation: out[i] = min(first[i], second[i])
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * @param name name May be null. Name for the output variable + * @param first First input array (NUMERIC type) + * @param second Second input array (NUMERIC type) + * @return output Output array, with values equal to the element-wise minimum of the inputs (NUMERIC type) + */ + public SDVariable min(String name, SDVariable first, SDVariable second) { + SDValidation.validateNumerical("min", "first", first); + SDValidation.validateNumerical("min", "second", second); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sd,first, second).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array + * @return output (NUMERIC type) + */ + public SDVariable mmul(SDVariable x, SDVariable y, boolean transposeX, boolean transposeY, + boolean transposeZ) { + SDValidation.validateNumerical("mmul", "x", x); + SDValidation.validateNumerical("mmul", "y", y); + return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, transposeX, transposeY, transposeZ).outputVariable(); + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param name name May be null. Name for the output variable + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array + * @return output (NUMERIC type) + */ + public SDVariable mmul(String name, SDVariable x, SDVariable y, boolean transposeX, + boolean transposeY, boolean transposeZ) { + SDValidation.validateNumerical("mmul", "x", x); + SDValidation.validateNumerical("mmul", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, transposeX, transposeY, transposeZ).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable mmul(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mmul", "x", x); + SDValidation.validateNumerical("mmul", "y", y); + return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, false, false, false).outputVariable(); + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param name name May be null. Name for the output variable + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable mmul(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mmul", "x", x); + SDValidation.validateNumerical("mmul", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, false, false, false).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Not equals operation: elementwise x != y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable neq(SDVariable x, double y) { + SDValidation.validateNumerical("neq", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNotEquals(sd,x, y).outputVariable(); + } + + /** + * Not equals operation: elementwise x != y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable neq(String name, SDVariable x, double y) { + SDValidation.validateNumerical("neq", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNotEquals(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Not equal to operation: elementwise x != y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable neq(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("neq", "x", x); + SDValidation.validateNumerical("neq", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo(sd,x, y).outputVariable(); + } + + /** + * Not equal to operation: elementwise x != y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public SDVariable neq(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("neq", "x", x); + SDValidation.validateNumerical("neq", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
+ * out = sum_i abs(x[i])
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable norm1(SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("norm1", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1(sd,x, keepDims, dimensions).outputVariable(); + } + + /** + * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
+ * out = sum_i abs(x[i])
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable norm1(String name, SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("norm1", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1(sd,x, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
+ * out = sum_i abs(x[i])
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable norm1(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("norm1", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1(sd,x, false, dimensions).outputVariable(); + } + + /** + * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
+ * out = sum_i abs(x[i])
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable norm1(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("norm1", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1(sd,x, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
+ * out = sqrt(sum_i x[i]^2)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable norm2(SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("norm2", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(sd,x, keepDims, dimensions).outputVariable(); + } + + /** + * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
+ * out = sqrt(sum_i x[i]^2)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable norm2(String name, SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("norm2", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(sd,x, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
+ * out = sqrt(sum_i x[i]^2)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable norm2(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("norm2", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(sd,x, false, dimensions).outputVariable(); + } + + /** + * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
+ * out = sqrt(sum_i x[i]^2)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable norm2(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("norm2", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(sd,x, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
+ * specified dimensions:
+ * out = max(abs(x[i]))
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable normmax(SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("normmax", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(sd,x, keepDims, dimensions).outputVariable(); + } + + /** + * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
+ * specified dimensions:
+ * out = max(abs(x[i]))
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable normmax(String name, SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("normmax", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(sd,x, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
+ * specified dimensions:
+ * out = max(abs(x[i]))
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable normmax(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("normmax", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(sd,x, false, dimensions).outputVariable(); + } + + /** + * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
+ * specified dimensions:
+ * out = max(abs(x[i]))
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable normmax(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("normmax", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(sd,x, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Convert the array to a one-hot array with values on and off for each entry
+ * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
+ * with out[i, ..., j, in[i,...,j]] = on with other values being set to off
+ * + * @param indices Indices - value 0 to depth-1 (NUMERIC type) + * @param depth Number of classes + * @param axis + * @param on + * @param off + * @param dataType Output data type + * @return output Output variable (NUMERIC type) + */ + public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off, + DataType dataType) { + SDValidation.validateNumerical("oneHot", "indices", indices); + return new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth, axis, on, off, dataType).outputVariable(); + } + + /** + * Convert the array to a one-hot array with values on and off for each entry
+ * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
+ * with out[i, ..., j, in[i,...,j]] = on with other values being set to off
+ * + * @param name name May be null. Name for the output variable + * @param indices Indices - value 0 to depth-1 (NUMERIC type) + * @param depth Number of classes + * @param axis + * @param on + * @param off + * @param dataType Output data type + * @return output Output variable (NUMERIC type) + */ + public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on, + double off, DataType dataType) { + SDValidation.validateNumerical("oneHot", "indices", indices); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth, axis, on, off, dataType).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Convert the array to a one-hot array with values on and off for each entry
+ * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
+ * with out[i, ..., j, in[i,...,j]] = on with other values being set to off
+ * + * @param indices Indices - value 0 to depth-1 (NUMERIC type) + * @param depth Number of classes + * @param axis + * @param on + * @param off + * @return output Output variable (NUMERIC type) + */ + public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off) { + SDValidation.validateNumerical("oneHot", "indices", indices); + return new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth, axis, on, off, DataType.FLOAT).outputVariable(); + } + + /** + * Convert the array to a one-hot array with values on and off for each entry
+ * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
+ * with out[i, ..., j, in[i,...,j]] = on with other values being set to off
+ * + * @param name name May be null. Name for the output variable + * @param indices Indices - value 0 to depth-1 (NUMERIC type) + * @param depth Number of classes + * @param axis + * @param on + * @param off + * @return output Output variable (NUMERIC type) + */ + public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on, + double off) { + SDValidation.validateNumerical("oneHot", "indices", indices); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth, axis, on, off, DataType.FLOAT).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Convert the array to a one-hot array with values 0 and 1 for each entry
+ * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
+ * with out[i, ..., j, in[i,...,j]] = 1 with other values being set to 0
+ * see oneHot(SDVariable, int, int, double, double)
+ * + * @param indices Indices - value 0 to depth-1 (NUMERIC type) + * @param depth Number of classes + * @return output Output variable (NUMERIC type) + */ + public SDVariable oneHot(SDVariable indices, int depth) { + SDValidation.validateNumerical("oneHot", "indices", indices); + return new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth).outputVariable(); + } + + /** + * Convert the array to a one-hot array with values 0 and 1 for each entry
+ * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
+ * with out[i, ..., j, in[i,...,j]] = 1 with other values being set to 0
+ * see oneHot(SDVariable, int, int, double, double)
+ * + * @param name name May be null. Name for the output variable + * @param indices Indices - value 0 to depth-1 (NUMERIC type) + * @param depth Number of classes + * @return output Output variable (NUMERIC type) + */ + public SDVariable oneHot(String name, SDVariable indices, int depth) { + SDValidation.validateNumerical("oneHot", "indices", indices); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Return a variable of all 1s, with the same shape as the input variable. Note that this is dynamic:
+ * if the input shape changes in later execution, the returned variable's shape will also be updated
+ * + * @param input Input INDArray (NUMERIC type) + * @return output A new INDArray with the same (dynamic) shape as the input (NUMERIC type) + */ + public SDVariable onesLike(SDVariable input) { + SDValidation.validateNumerical("onesLike", "input", input); + return new org.nd4j.linalg.api.ops.impl.shape.OnesLike(sd,input).outputVariable(); + } + + /** + * Return a variable of all 1s, with the same shape as the input variable. Note that this is dynamic:
+ * if the input shape changes in later execution, the returned variable's shape will also be updated
+ * + * @param name name May be null. Name for the output variable + * @param input Input INDArray (NUMERIC type) + * @return output A new INDArray with the same (dynamic) shape as the input (NUMERIC type) + */ + public SDVariable onesLike(String name, SDVariable input) { + SDValidation.validateNumerical("onesLike", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.OnesLike(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * As per onesLike(String, SDVariable) but the output datatype may be specified
+ * + * @param input (NUMERIC type) + * @param dataType + * @return output (NUMERIC type) + */ + public SDVariable onesLike(SDVariable input, DataType dataType) { + SDValidation.validateNumerical("onesLike", "input", input); + return new org.nd4j.linalg.api.ops.impl.shape.OnesLike(sd,input, dataType).outputVariable(); + } + + /** + * As per onesLike(String, SDVariable) but the output datatype may be specified
+ * + * @param name name May be null. Name for the output variable + * @param input (NUMERIC type) + * @param dataType + * @return output (NUMERIC type) + */ + public SDVariable onesLike(String name, SDVariable input, DataType dataType) { + SDValidation.validateNumerical("onesLike", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.OnesLike(sd,input, dataType).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Array permutation operation: permute the dimensions according to the specified permutation indices.
+ * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Permute dimensions (INT type) + * @return output Output variable (permuted input) (NUMERIC type) + */ + public SDVariable permute(SDVariable x, SDVariable dimensions) { + SDValidation.validateNumerical("permute", "x", x); + SDValidation.validateInteger("permute", "dimensions", dimensions); + return new org.nd4j.linalg.api.ops.impl.shape.Permute(sd,x, dimensions).outputVariable(); + } + + /** + * Array permutation operation: permute the dimensions according to the specified permutation indices.
+ * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions Permute dimensions (INT type) + * @return output Output variable (permuted input) (NUMERIC type) + */ + public SDVariable permute(String name, SDVariable x, SDVariable dimensions) { + SDValidation.validateNumerical("permute", "x", x); + SDValidation.validateInteger("permute", "dimensions", dimensions); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Permute(sd,x, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Array permutation operation: permute the dimensions according to the specified permutation indices.
+ * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions (Size: AtLeast(min=0)) + * @return output Output variable (permuted input) (NUMERIC type) + */ + public SDVariable permute(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("permute", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.shape.Permute(sd,x, dimensions).outputVariable(); + } + + /** + * Array permutation operation: permute the dimensions according to the specified permutation indices.
+ * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions (Size: AtLeast(min=0)) + * @return output Output variable (permuted input) (NUMERIC type) + */ + public SDVariable permute(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("permute", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Permute(sd,x, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Product array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public SDVariable prod(SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("prod", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.Prod(sd,x, keepDims, dimensions).outputVariable(); + } + + /** + * Product array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public SDVariable prod(String name, SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("prod", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.Prod(sd,x, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Product array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public SDVariable prod(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("prod", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.Prod(sd,x, false, dimensions).outputVariable(); + } + + /** + * Product array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public SDVariable prod(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("prod", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.Prod(sd,x, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Create a new variable with a 1d array, where the values start at from and increment by step
+ * up to (but not including) limit.
+ * For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, 2.0, 2.5]
+ * + * @param from Initial/smallest value + * @param to Largest value (exclusive) + * @param step Step size + * @param dataType + * @return output INDArray with the specified values (NUMERIC type) + */ + public SDVariable range(double from, double to, double step, DataType dataType) { + return new org.nd4j.linalg.api.ops.random.impl.Range(sd,from, to, step, dataType).outputVariable(); + } + + /** + * Create a new variable with a 1d array, where the values start at from and increment by step
+ * up to (but not including) limit.
+ * For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, 2.0, 2.5]
+ * + * @param name name May be null. Name for the output variable + * @param from Initial/smallest value + * @param to Largest value (exclusive) + * @param step Step size + * @param dataType + * @return output INDArray with the specified values (NUMERIC type) + */ + public SDVariable range(String name, double from, double to, double step, DataType dataType) { + SDVariable out = new org.nd4j.linalg.api.ops.random.impl.Range(sd,from, to, step, dataType).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Create a new variable with a 1d array, where the values start at from and increment by step
+ * up to (but not including) limit.
+ * For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, 2.0, 2.5]
+ * + * @param from Initial/smallest value (NUMERIC type) + * @param to Largest value (exclusive) (NUMERIC type) + * @param step Step size (NUMERIC type) + * @param dataType + * @return output INDArray with the specified values (NUMERIC type) + */ + public SDVariable range(SDVariable from, SDVariable to, SDVariable step, DataType dataType) { + SDValidation.validateNumerical("range", "from", from); + SDValidation.validateNumerical("range", "to", to); + SDValidation.validateNumerical("range", "step", step); + return new org.nd4j.linalg.api.ops.random.impl.Range(sd,from, to, step, dataType).outputVariable(); + } + + /** + * Create a new variable with a 1d array, where the values start at from and increment by step
+ * up to (but not including) limit.
+ * For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, 2.0, 2.5]
+ * + * @param name name May be null. Name for the output variable + * @param from Initial/smallest value (NUMERIC type) + * @param to Largest value (exclusive) (NUMERIC type) + * @param step Step size (NUMERIC type) + * @param dataType + * @return output INDArray with the specified values (NUMERIC type) + */ + public SDVariable range(String name, SDVariable from, SDVariable to, SDVariable step, + DataType dataType) { + SDValidation.validateNumerical("range", "from", from); + SDValidation.validateNumerical("range", "to", to); + SDValidation.validateNumerical("range", "step", step); + SDVariable out = new org.nd4j.linalg.api.ops.random.impl.Range(sd,from, to, step, dataType).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns the rank (number of dimensions, i.e., length(shape)) of the specified INDArray as a 0D scalar variable
+ * + * @param in Input variable (NUMERIC type) + * @return output (scalar) output variable with value equal to the rank of the input variable (NUMERIC type) + */ + public SDVariable rank(SDVariable in) { + SDValidation.validateNumerical("rank", "in", in); + return new org.nd4j.linalg.api.ops.impl.shape.Rank(sd,in).outputVariable(); + } + + /** + * Returns the rank (number of dimensions, i.e., length(shape)) of the specified INDArray as a 0D scalar variable
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @return output (scalar) output variable with value equal to the rank of the input variable (NUMERIC type) + */ + public SDVariable rank(String name, SDVariable in) { + SDValidation.validateNumerical("rank", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Rank(sd,in).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise replace where condition:
+ * out[i] = from[i] if condition(update[i]) is satisfied, or
+ * out[i] = update[i] if condition(update[i]) is NOT satisfied
+ * + * @param update Source array (NUMERIC type) + * @param from Replacement values array (used conditionally). Must be same shape as 'update' array (NUMERIC type) + * @param condition Condition to check on update array elements + * @return output New array with values replaced where condition is satisfied (NUMERIC type) + */ + public SDVariable replaceWhere(SDVariable update, SDVariable from, Condition condition) { + SDValidation.validateNumerical("replaceWhere", "update", update); + SDValidation.validateNumerical("replaceWhere", "from", from); + return new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(sd,update, from, condition).outputVariable(); + } + + /** + * Element-wise replace where condition:
+ * out[i] = from[i] if condition(update[i]) is satisfied, or
+ * out[i] = update[i] if condition(update[i]) is NOT satisfied
+ * + * @param name name May be null. Name for the output variable + * @param update Source array (NUMERIC type) + * @param from Replacement values array (used conditionally). Must be same shape as 'update' array (NUMERIC type) + * @param condition Condition to check on update array elements + * @return output New array with values replaced where condition is satisfied (NUMERIC type) + */ + public SDVariable replaceWhere(String name, SDVariable update, SDVariable from, + Condition condition) { + SDValidation.validateNumerical("replaceWhere", "update", update); + SDValidation.validateNumerical("replaceWhere", "from", from); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(sd,update, from, condition).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise replace where condition:
+ * out[i] = value if condition(update[i]) is satisfied, or
+ * out[i] = update[i] if condition(update[i]) is NOT satisfied
+ * + * @param update Source array (NUMERIC type) + * @param value Value to set at the output, if the condition is satisfied + * @param condition Condition to check on update array elements + * @return output New array with values replaced where condition is satisfied (NUMERIC type) + */ + public SDVariable replaceWhere(SDVariable update, double value, Condition condition) { + SDValidation.validateNumerical("replaceWhere", "update", update); + return new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet(sd,update, value, condition).outputVariable(); + } + + /** + * Element-wise replace where condition:
+ * out[i] = value if condition(update[i]) is satisfied, or
+ * out[i] = update[i] if condition(update[i]) is NOT satisfied
+ * + * @param name name May be null. Name for the output variable + * @param update Source array (NUMERIC type) + * @param value Value to set at the output, if the condition is satisfied + * @param condition Condition to check on update array elements + * @return output New array with values replaced where condition is satisfied (NUMERIC type) + */ + public SDVariable replaceWhere(String name, SDVariable update, double value, + Condition condition) { + SDValidation.validateNumerical("replaceWhere", "update", update); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet(sd,update, value, condition).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
+ * input, but with the specified shape.
+ * Note that prod(shape) must match length(input) == prod(input.shape)
+ * + * @param x Input variable (NUMERIC type) + * @param shape New shape for variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable reshape(SDVariable x, SDVariable shape) { + SDValidation.validateNumerical("reshape", "x", x); + SDValidation.validateNumerical("reshape", "shape", shape); + return new org.nd4j.linalg.api.ops.impl.shape.Reshape(sd,x, shape).outputVariable(); + } + + /** + * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
+ * input, but with the specified shape.
+ * Note that prod(shape) must match length(input) == prod(input.shape)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param shape New shape for variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable reshape(String name, SDVariable x, SDVariable shape) { + SDValidation.validateNumerical("reshape", "x", x); + SDValidation.validateNumerical("reshape", "shape", shape); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Reshape(sd,x, shape).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
+ * input, but with the specified shape.
+ * Note that prod(shape) must match length(input) == prod(input.shape)
+ * + * @param x Input variable (NUMERIC type) + * @param shape New shape for variable (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable reshape(SDVariable x, long... shape) { + SDValidation.validateNumerical("reshape", "x", x); + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return new org.nd4j.linalg.api.ops.impl.shape.Reshape(sd,x, shape).outputVariable(); + } + + /** + * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
+ * input, but with the specified shape.
+ * Note that prod(shape) must match length(input) == prod(input.shape)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param shape New shape for variable (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable reshape(String name, SDVariable x, long... shape) { + SDValidation.validateNumerical("reshape", "x", x); + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Reshape(sd,x, shape).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Reverse the values of an array for the specified dimensions
   * If input is:
   * [ 1, 2, 3]
   * [ 4, 5, 6]
   * then
   * reverse(in, 0):
   * [3, 2, 1]
   * [6, 5, 4]
   * reverse(in, 1):
   * [4, 5, 6]
   * [1, 2, 3]
   *
   * @param x Input variable (NUMERIC type)
   * @param dimensions Dimensions to reverse along (Size: AtLeast(min=0))
   * @return output Output variable (NUMERIC type)
   */
  public SDVariable reverse(SDVariable x, int... dimensions) {
    SDValidation.validateNumerical("reverse", "x", x);
    Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
    return new org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse(sd,x, dimensions).outputVariable();
  }

  /**
   * Reverse the values of an array for the specified dimensions
   * If input is:
   * [ 1, 2, 3]
   * [ 4, 5, 6]
   * then
   * reverse(in, 0):
   * [3, 2, 1]
   * [6, 5, 4]
   * reverse(in, 1):
   * [4, 5, 6]
   * [1, 2, 3]
   *
   * @param name name May be null. Name for the output variable
   * @param x Input variable (NUMERIC type)
   * @param dimensions Dimensions to reverse along (Size: AtLeast(min=0))
   * @return output Output variable (NUMERIC type)
   */
  public SDVariable reverse(String name, SDVariable x, int... dimensions) {
    SDValidation.validateNumerical("reverse", "x", x);
    Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
    SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse(sd,x, dimensions).outputVariable();
    return sd.updateVariableNameAndReference(out, name);
  }

  /**
   * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed
+ * + * @param x Input variable (NUMERIC type) + * @param seq_lengths Length of the sequences (INT type) + * @param seqDim Sequence dimension + * @param batchDim Batch dimension + * @return output Reversed sequences (NUMERIC type) + */ + public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths, int seqDim, + int batchDim) { + SDValidation.validateNumerical("reverseSequence", "x", x); + SDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(sd,x, seq_lengths, seqDim, batchDim).outputVariable(); + } + + /** + * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param seq_lengths Length of the sequences (INT type) + * @param seqDim Sequence dimension + * @param batchDim Batch dimension + * @return output Reversed sequences (NUMERIC type) + */ + public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths, int seqDim, + int batchDim) { + SDValidation.validateNumerical("reverseSequence", "x", x); + SDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(sd,x, seq_lengths, seqDim, batchDim).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed
+ * + * @param x Input variable (NUMERIC type) + * @param seq_lengths Length of the sequences (INT type) + * @return output Reversed sequences (NUMERIC type) + */ + public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths) { + SDValidation.validateNumerical("reverseSequence", "x", x); + SDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(sd,x, seq_lengths, -1, 0).outputVariable(); + } + + /** + * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param seq_lengths Length of the sequences (INT type) + * @return output Reversed sequences (NUMERIC type) + */ + public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths) { + SDValidation.validateNumerical("reverseSequence", "x", x); + SDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(sd,x, seq_lengths, -1, 0).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise scalar floor modulus operation: out = floorMod(in, value).
+ * i.e., returns the remainder after division by 'value'
+ * + * @param in Input variable (NUMERIC type) + * @param value Scalar value to compare + * @return output Output variable (NUMERIC type) + */ + public SDVariable scalarFloorMod(SDVariable in, double value) { + SDValidation.validateNumerical("scalarFloorMod", "in", in); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(sd,in, value).outputVariable(); + } + + /** + * Element-wise scalar floor modulus operation: out = floorMod(in, value).
+ * i.e., returns the remainder after division by 'value'
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param value Scalar value to compare + * @return output Output variable (NUMERIC type) + */ + public SDVariable scalarFloorMod(String name, SDVariable in, double value) { + SDValidation.validateNumerical("scalarFloorMod", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(sd,in, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise scalar maximum operation: out = max(in, value)
   *
   * @param in Input variable (NUMERIC type)
   * @param value Scalar value to compare
   * @return output Output variable (NUMERIC type)
   */
  public SDVariable scalarMax(SDVariable in, double value) {
    SDValidation.validateNumerical("scalarMax", "in", in);
    return new org.nd4j.linalg.api.ops.impl.scalar.ScalarMax(sd,in, value).outputVariable();
  }

  /**
   * Element-wise scalar maximum operation: out = max(in, value)
   *
   * @param name name May be null. Name for the output variable
   * @param in Input variable (NUMERIC type)
   * @param value Scalar value to compare
   * @return output Output variable (NUMERIC type)
   */
  public SDVariable scalarMax(String name, SDVariable in, double value) {
    SDValidation.validateNumerical("scalarMax", "in", in);
    SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarMax(sd,in, value).outputVariable();
    return sd.updateVariableNameAndReference(out, name);
  }

  /**
   * Element-wise scalar minimum operation: out = min(in, value)
+ * + * @param in Input variable (NUMERIC type) + * @param value Scalar value to compare + * @return output Output variable (NUMERIC type) + */ + public SDVariable scalarMin(SDVariable in, double value) { + SDValidation.validateNumerical("scalarMin", "in", in); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarMin(sd,in, value).outputVariable(); + } + + /** + * Element-wise scalar minimum operation: out = min(in, value)
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param value Scalar value to compare + * @return output Output variable (NUMERIC type) + */ + public SDVariable scalarMin(String name, SDVariable in, double value) { + SDValidation.validateNumerical("scalarMin", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarMin(sd,in, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Return a variable with equal shape to the input, but all elements set to value 'set'
+ * + * @param in Input variable (NUMERIC type) + * @param set Value to set + * @return output Output variable (NUMERIC type) + */ + public SDVariable scalarSet(SDVariable in, double set) { + SDValidation.validateNumerical("scalarSet", "in", in); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarSet(sd,in, set).outputVariable(); + } + + /** + * Return a variable with equal shape to the input, but all elements set to value 'set'
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param set Value to set + * @return output Output variable (NUMERIC type) + */ + public SDVariable scalarSet(String name, SDVariable in, double set) { + SDValidation.validateNumerical("scalarSet", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarSet(sd,in, set).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scatter addition operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterAdd(SDVariable ref, SDVariable indices, SDVariable updates) { + SDValidation.validateNumerical("scatterAdd", "ref", ref); + SDValidation.validateNumerical("scatterAdd", "indices", indices); + SDValidation.validateNumerical("scatterAdd", "updates", updates); + return new org.nd4j.linalg.api.ops.impl.scatter.ScatterAdd(sd,ref, indices, updates).outputVariable(); + } + + /** + * Scatter addition operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
+ * + * @param name name May be null. Name for the output variable + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterAdd(String name, SDVariable ref, SDVariable indices, + SDVariable updates) { + SDValidation.validateNumerical("scatterAdd", "ref", ref); + SDValidation.validateNumerical("scatterAdd", "indices", indices); + SDValidation.validateNumerical("scatterAdd", "updates", updates); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scatter.ScatterAdd(sd,ref, indices, updates).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scatter division operation.
   *
   * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] / updates[...]
   * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] / updates[i, ...]
   * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] / updates[i, ..., k, ...]
   * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
   *
   * @param ref Initial/source variable (NUMERIC type)
   * @param indices Indices array (NUMERIC type)
   * @param updates Values to divide the initial/source array entries by (NUMERIC type)
   * @return output The updated variable (NUMERIC type)
   */
  public SDVariable scatterDiv(SDVariable ref, SDVariable indices, SDVariable updates) {
    SDValidation.validateNumerical("scatterDiv", "ref", ref);
    SDValidation.validateNumerical("scatterDiv", "indices", indices);
    SDValidation.validateNumerical("scatterDiv", "updates", updates);
    return new org.nd4j.linalg.api.ops.impl.scatter.ScatterDiv(sd,ref, indices, updates).outputVariable();
  }

  /**
   * Scatter division operation.
   *
   * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] / updates[...]
   * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] / updates[i, ...]
   * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] / updates[i, ..., k, ...]
   * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
   *
   * @param name name May be null. Name for the output variable
   * @param ref Initial/source variable (NUMERIC type)
   * @param indices Indices array (NUMERIC type)
   * @param updates Values to divide the initial/source array entries by (NUMERIC type)
   * @return output The updated variable (NUMERIC type)
   */
  public SDVariable scatterDiv(String name, SDVariable ref, SDVariable indices,
      SDVariable updates) {
    SDValidation.validateNumerical("scatterDiv", "ref", ref);
    SDValidation.validateNumerical("scatterDiv", "indices", indices);
    SDValidation.validateNumerical("scatterDiv", "updates", updates);
    SDVariable out = new org.nd4j.linalg.api.ops.impl.scatter.ScatterDiv(sd,ref, indices, updates).outputVariable();
    return sd.updateVariableNameAndReference(out, name);
  }

  /**
   * Scatter max operation.
   *
   * If indices is rank 0 (a scalar), then out[index, ...] = max(out[index, ...], updates[...])
   * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = max(out[indices[i], ...], updates[i, ...])
   * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = max(out[indices[i], ..., indices[k], ...], updates[i, ..., k, ...])
   * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
   *
   * @param ref Initial/source variable (NUMERIC type)
   * @param indices Indices array (NUMERIC type)
   * @param updates Values to combine (element-wise maximum) with the initial/source array (NUMERIC type)
   * @return output The updated variable (NUMERIC type)
   */
  public SDVariable scatterMax(SDVariable ref, SDVariable indices, SDVariable updates) {
    SDValidation.validateNumerical("scatterMax", "ref", ref);
    SDValidation.validateNumerical("scatterMax", "indices", indices);
    SDValidation.validateNumerical("scatterMax", "updates", updates);
    return new org.nd4j.linalg.api.ops.impl.scatter.ScatterMax(sd,ref, indices, updates).outputVariable();
  }

  /**
   * Scatter max operation.
   *
   * If indices is rank 0 (a scalar), then out[index, ...] = max(out[index, ...], updates[...])
   * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = max(out[indices[i], ...], updates[i, ...])
   * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = max(out[indices[i], ..., indices[k], ...], updates[i, ..., k, ...])
   * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
   *
   * @param name name May be null. Name for the output variable
   * @param ref Initial/source variable (NUMERIC type)
   * @param indices Indices array (NUMERIC type)
   * @param updates Values to combine (element-wise maximum) with the initial/source array (NUMERIC type)
   * @return output The updated variable (NUMERIC type)
   */
  public SDVariable scatterMax(String name, SDVariable ref, SDVariable indices,
      SDVariable updates) {
    SDValidation.validateNumerical("scatterMax", "ref", ref);
    SDValidation.validateNumerical("scatterMax", "indices", indices);
    SDValidation.validateNumerical("scatterMax", "updates", updates);
    SDVariable out = new org.nd4j.linalg.api.ops.impl.scatter.ScatterMax(sd,ref, indices, updates).outputVariable();
    return sd.updateVariableNameAndReference(out, name);
  }

  /**
   * Scatter min operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = min(out[index, ...], updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = min(out[indices[i], ...], updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = min(out[indices[i], ..., indices[k], ...], updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterMin(SDVariable ref, SDVariable indices, SDVariable updates) { + SDValidation.validateNumerical("scatterMin", "ref", ref); + SDValidation.validateNumerical("scatterMin", "indices", indices); + SDValidation.validateNumerical("scatterMin", "updates", updates); + return new org.nd4j.linalg.api.ops.impl.scatter.ScatterMin(sd,ref, indices, updates).outputVariable(); + } + + /** + * Scatter min operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = min(out[index, ...], updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = min(out[indices[i], ...], updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = min(out[indices[i], ..., indices[k], ...], updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param name name May be null. Name for the output variable + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterMin(String name, SDVariable ref, SDVariable indices, + SDVariable updates) { + SDValidation.validateNumerical("scatterMin", "ref", ref); + SDValidation.validateNumerical("scatterMin", "indices", indices); + SDValidation.validateNumerical("scatterMin", "updates", updates); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scatter.ScatterMin(sd,ref, indices, updates).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scatter multiplication operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] * updates[...]
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] * updates[i, ...]
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] * updates[i, ..., k, ...]
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterMul(SDVariable ref, SDVariable indices, SDVariable updates) { + SDValidation.validateNumerical("scatterMul", "ref", ref); + SDValidation.validateNumerical("scatterMul", "indices", indices); + SDValidation.validateNumerical("scatterMul", "updates", updates); + return new org.nd4j.linalg.api.ops.impl.scatter.ScatterMul(sd,ref, indices, updates).outputVariable(); + } + + /** + * Scatter multiplication operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] * updates[...]
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] * updates[i, ...]
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] * updates[i, ..., k, ...]
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param name name May be null. Name for the output variable + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterMul(String name, SDVariable ref, SDVariable indices, + SDVariable updates) { + SDValidation.validateNumerical("scatterMul", "ref", ref); + SDValidation.validateNumerical("scatterMul", "indices", indices); + SDValidation.validateNumerical("scatterMul", "updates", updates); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scatter.ScatterMul(sd,ref, indices, updates).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scatter subtraction operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] - updates[...]
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] - updates[i, ...]
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] - updates[i, ..., k, ...]
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterSub(SDVariable ref, SDVariable indices, SDVariable updates) { + SDValidation.validateNumerical("scatterSub", "ref", ref); + SDValidation.validateNumerical("scatterSub", "indices", indices); + SDValidation.validateNumerical("scatterSub", "updates", updates); + return new org.nd4j.linalg.api.ops.impl.scatter.ScatterSub(sd,ref, indices, updates).outputVariable(); + } + + /** + * Scatter subtraction operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] - updates[...]
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] - updates[i, ...]
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] - updates[i, ..., k, ...]
+ * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
+ * + * @param name name May be null. Name for the output variable + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterSub(String name, SDVariable ref, SDVariable indices, + SDVariable updates) { + SDValidation.validateNumerical("scatterSub", "ref", ref); + SDValidation.validateNumerical("scatterSub", "indices", indices); + SDValidation.validateNumerical("scatterSub", "updates", updates); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scatter.ScatterSub(sd,ref, indices, updates).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scatter update operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = updates[...]
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = updates[i, ...]
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = updates[i, ..., k, ...]
+ * Note that if multiple indices refer to the same location, the result of the overlapping updates is order-dependent — confirm against the ScatterUpdate op.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterUpdate(SDVariable ref, SDVariable indices, SDVariable updates) { + SDValidation.validateNumerical("scatterUpdate", "ref", ref); + SDValidation.validateNumerical("scatterUpdate", "indices", indices); + SDValidation.validateNumerical("scatterUpdate", "updates", updates); + return new org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate(sd,ref, indices, updates).outputVariable(); + } + + /** + * Scatter update operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = updates[...]
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = updates[i, ...]
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = updates[i, ..., k, ...]
+ * Note that if multiple indices refer to the same location, the result of the overlapping updates is order-dependent — confirm against the ScatterUpdate op.
+ * + * @param name name May be null. Name for the output variable + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public SDVariable scatterUpdate(String name, SDVariable ref, SDVariable indices, + SDVariable updates) { + SDValidation.validateNumerical("scatterUpdate", "ref", ref); + SDValidation.validateNumerical("scatterUpdate", "indices", indices); + SDValidation.validateNumerical("scatterUpdate", "updates", updates); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate(sd,ref, indices, updates).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Segment max operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [6, 9, 8] = [max(3,6), max(1,4,9), max(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See {@code unsortedSegmentMax(String, SDVariable, SDVariable, int)}
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentMax(SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentMax", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax(sd,data, segmentIds).outputVariable(); + } + + /** + * Segment max operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [6, 9, 8] = [max(3,6), max(1,4,9), max(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See {@code unsortedSegmentMax(String, SDVariable, SDVariable, int)}
+ * for the same op without this sorted requirement
+ * + * @param name name May be null. Name for the output variable + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentMax(String name, SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentMax", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax(sd,data, segmentIds).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Segment mean operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [4.5, 4.667, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See {@code unsortedSegmentMean(String, SDVariable, SDVariable, int)}
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentMean(SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentMean", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean(sd,data, segmentIds).outputVariable(); + } + + /** + * Segment mean operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [4.5, 4.667, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See {@code unsortedSegmentMean(String, SDVariable, SDVariable, int)}
+ * for the same op without this sorted requirement
+ * + * @param name name May be null. Name for the output variable + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentMean(String name, SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentMean", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean(sd,data, segmentIds).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Segment min operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [3, 1, 2] = [min(3,6), min(1,4,9), min(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See {@code unsortedSegmentMin(String, SDVariable, SDVariable, int)}
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentMin(SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentMin", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin(sd,data, segmentIds).outputVariable(); + } + + /** + * Segment min operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [3, 1, 2] = [min(3,6), min(1,4,9), min(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See {@code unsortedSegmentMin(String, SDVariable, SDVariable, int)}
+ * for the same op without this sorted requirement
+ * + * @param name name May be null. Name for the output variable + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentMin(String name, SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentMin", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin(sd,data, segmentIds).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Segment product operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [18, 36, 16] = [prod(3,6), prod(1,4,9), prod(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See {@code unsortedSegmentProd(String, SDVariable, SDVariable, int)}
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentProd(SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentProd", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd(sd,data, segmentIds).outputVariable(); + } + + /** + * Segment product operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [18, 36, 16] = [prod(3,6), prod(1,4,9), prod(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See {@code unsortedSegmentProd(String, SDVariable, SDVariable, int)}
+ * for the same op without this sorted requirement
+ * + * @param name name May be null. Name for the output variable + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentProd(String name, SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentProd", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd(sd,data, segmentIds).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Segment sum operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [9, 14, 10] = [sum(3,6), sum(1,4,9), sum(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See {@code unsortedSegmentSum(String, SDVariable, SDVariable, int)}
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentSum(SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentSum", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum(sd,data, segmentIds).outputVariable(); + } + + /** + * Segment sum operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [9, 14, 10] = [sum(3,6), sum(1,4,9), sum(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See {@code unsortedSegmentSum(String, SDVariable, SDVariable, int)}
+ * for the same op without this sorted requirement
+ * + * @param name name May be null. Name for the output variable + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public SDVariable segmentSum(String name, SDVariable data, SDVariable segmentIds) { + SDValidation.validateNumerical("segmentSum", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum(sd,data, segmentIds).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Generate a sequence mask (with values 0 or 1) based on the specified lengths
+ * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
+ * + * @param lengths Lengths of the sequences (NUMERIC type) + * @param maxLen Maximum sequence length + * @param dataType + * @return output Output variable (NUMERIC type) + */ + public SDVariable sequenceMask(SDVariable lengths, int maxLen, DataType dataType) { + SDValidation.validateNumerical("sequenceMask", "lengths", lengths); + return new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(sd,lengths, maxLen, dataType).outputVariable(); + } + + /** + * Generate a sequence mask (with values 0 or 1) based on the specified lengths
+ * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
+ * + * @param name name May be null. Name for the output variable + * @param lengths Lengths of the sequences (NUMERIC type) + * @param maxLen Maximum sequence length + * @param dataType + * @return output Output variable (NUMERIC type) + */ + public SDVariable sequenceMask(String name, SDVariable lengths, int maxLen, DataType dataType) { + SDValidation.validateNumerical("sequenceMask", "lengths", lengths); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(sd,lengths, maxLen, dataType).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Generate a sequence mask (with values 0 or 1) based on the specified lengths
+ * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
+ * + * @param lengths Lengths of the sequences (NUMERIC type) + * @param maxLen Maximum sequence length (INT type) + * @param dataType + * @return output Output variable (NUMERIC type) + */ + public SDVariable sequenceMask(SDVariable lengths, SDVariable maxLen, DataType dataType) { + SDValidation.validateNumerical("sequenceMask", "lengths", lengths); + SDValidation.validateInteger("sequenceMask", "maxLen", maxLen); + return new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(sd,lengths, maxLen, dataType).outputVariable(); + } + + /** + * Generate a sequence mask (with values 0 or 1) based on the specified lengths
+ * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
+ * + * @param name name May be null. Name for the output variable + * @param lengths Lengths of the sequences (NUMERIC type) + * @param maxLen Maximum sequence length (INT type) + * @param dataType + * @return output Output variable (NUMERIC type) + */ + public SDVariable sequenceMask(String name, SDVariable lengths, SDVariable maxLen, + DataType dataType) { + SDValidation.validateNumerical("sequenceMask", "lengths", lengths); + SDValidation.validateInteger("sequenceMask", "maxLen", maxLen); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(sd,lengths, maxLen, dataType).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * see sequenceMask(String, SDVariable, SDVariable, DataType)
+ * + * @param lengths (NUMERIC type) + * @param dataType + * @return output (NUMERIC type) + */ + public SDVariable sequenceMask(SDVariable lengths, DataType dataType) { + SDValidation.validateNumerical("sequenceMask", "lengths", lengths); + return new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(sd,lengths, dataType).outputVariable(); + } + + /** + * see sequenceMask(String, SDVariable, SDVariable, DataType)
+ * + * @param name name May be null. Name for the output variable + * @param lengths (NUMERIC type) + * @param dataType + * @return output (NUMERIC type) + */ + public SDVariable sequenceMask(String name, SDVariable lengths, DataType dataType) { + SDValidation.validateNumerical("sequenceMask", "lengths", lengths); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(sd,lengths, dataType).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns the shape of the specified INDArray as a 1D INDArray
+ * + * @param input Input variable (NUMERIC type) + * @return output 1D output variable with contents equal to the shape of the input (NUMERIC type) + */ + public SDVariable shape(SDVariable input) { + SDValidation.validateNumerical("shape", "input", input); + return new org.nd4j.linalg.api.ops.impl.shape.Shape(sd,input).outputVariable(); + } + + /** + * Returns the shape of the specified INDArray as a 1D INDArray
+ * + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @return output 1D output variable with contents equal to the shape of the input (NUMERIC type) + */ + public SDVariable shape(String name, SDVariable input) { + SDValidation.validateNumerical("shape", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Shape(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns the size (number of elements, i.e., prod(shape)) of the specified INDArray as a 0D scalar variable
+ * + * @param in Input variable (NUMERIC type) + * @return output 0D (scalar) output variable with value equal to the number of elements in the specified array (NUMERIC type) + */ + public SDVariable size(SDVariable in) { + SDValidation.validateNumerical("size", "in", in); + return new org.nd4j.linalg.api.ops.impl.shape.Size(sd,in).outputVariable(); + } + + /** + * Returns the size (number of elements, i.e., prod(shape)) of the specified INDArray as a 0D scalar variable
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @return output 0D (scalar) output variable with value equal to the number of elements in the specified array (NUMERIC type) + */ + public SDVariable size(String name, SDVariable in) { + SDValidation.validateNumerical("size", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Size(sd,in).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns a rank 0 (scalar) variable for the size of the specified dimension.
+ * For example, if X has shape [10,20,30] then sizeAt(X,1)=20. Similarly, sizeAt(X,-1)=30
+ * + * @param in Input variable (NUMERIC type) + * @param dimension Dimension to get size of + * @return output Scalar INDArray for size at specified variable (NUMERIC type) + */ + public SDVariable sizeAt(SDVariable in, int dimension) { + SDValidation.validateNumerical("sizeAt", "in", in); + return new org.nd4j.linalg.api.ops.impl.shape.SizeAt(sd,in, dimension).outputVariable(); + } + + /** + * Returns a rank 0 (scalar) variable for the size of the specified dimension.
+ * For example, if X has shape [10,20,30] then sizeAt(X,1)=20. Similarly, sizeAt(X,-1)=30
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimension Dimension to get size of + * @return output Scalar INDArray for size at specified variable (NUMERIC type) + */ + public SDVariable sizeAt(String name, SDVariable in, int dimension) { + SDValidation.validateNumerical("sizeAt", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.SizeAt(sd,in, dimension).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Get a subset of the specified input, by specifying the first element and the size of the array.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * then slice(input, begin=[0,1], size=[2,1]) will return:
+ * [b]
+ * [e]
+ * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
+ * + * @param input input Variable to get subset of (NUMERIC type) + * @param begin Beginning index. Must be same length as rank of input array (Size: AtLeast(min=1)) + * @param size Size of the output array. Must be same length as rank of input array (Size: AtLeast(min=1)) + * @return output Subset of the input (NUMERIC type) + */ + public SDVariable slice(SDVariable input, int[] begin, int... size) { + SDValidation.validateNumerical("slice", "input", input); + Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(size.length >= 1, "size has incorrect size/length. Expected: size.length >= 1, got %s", size.length); + return new org.nd4j.linalg.api.ops.impl.shape.Slice(sd,input, begin, size).outputVariable(); + } + + /** + * Get a subset of the specified input, by specifying the first element and the size of the array.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * then slice(input, begin=[0,1], size=[2,1]) will return:
+ * [b]
+ * [e]
+ * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
+ * + * @param name name May be null. Name for the output variable + * @param input input Variable to get subset of (NUMERIC type) + * @param begin Beginning index. Must be same length as rank of input array (Size: AtLeast(min=1)) + * @param size Size of the output array. Must be same length as rank of input array (Size: AtLeast(min=1)) + * @return output Subset of the input (NUMERIC type) + */ + public SDVariable slice(String name, SDVariable input, int[] begin, int... size) { + SDValidation.validateNumerical("slice", "input", input); + Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(size.length >= 1, "size has incorrect size/length. Expected: size.length >= 1, got %s", size.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Slice(sd,input, begin, size).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Get a subset of the specified input, by specifying the first element and the size of the array.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * then slice(input, begin=[0,1], size=[2,1]) will return:
+ * [b]
+ * [e]
+ * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
+ * + * @param input input Variable to get subset of (NUMERIC type) + * @param begin Beginning index. Must be same length as rank of input array (INT type) + * @param size Size of the output array. Must be same length as rank of input array (INT type) + * @return output Subset of the input (NUMERIC type) + */ + public SDVariable slice(SDVariable input, SDVariable begin, SDVariable size) { + SDValidation.validateNumerical("slice", "input", input); + SDValidation.validateInteger("slice", "begin", begin); + SDValidation.validateInteger("slice", "size", size); + return new org.nd4j.linalg.api.ops.impl.shape.Slice(sd,input, begin, size).outputVariable(); + } + + /** + * Get a subset of the specified input, by specifying the first element and the size of the array.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * then slice(input, begin=[0,1], size=[2,1]) will return:
+ * [b]
+ * [e]
+ * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
+ * + * @param name name May be null. Name for the output variable + * @param input input Variable to get subset of (NUMERIC type) + * @param begin Beginning index. Must be same length as rank of input array (INT type) + * @param size Size of the output array. Must be same length as rank of input array (INT type) + * @return output Subset of the input (NUMERIC type) + */ + public SDVariable slice(String name, SDVariable input, SDVariable begin, SDVariable size) { + SDValidation.validateNumerical("slice", "input", input); + SDValidation.validateInteger("slice", "begin", begin); + SDValidation.validateInteger("slice", "size", size); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Slice(sd,input, begin, size).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x (NUMERIC type) + * @param keepDims + * @param dimensions (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public SDVariable squaredNorm(SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("squaredNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(sd,x, keepDims, dimensions).outputVariable(); + } + + /** + * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x (NUMERIC type) + * @param keepDims + * @param dimensions (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public SDVariable squaredNorm(String name, SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("squaredNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(sd,x, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x (NUMERIC type) + * @param dimensions (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public SDVariable squaredNorm(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("squaredNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(sd,x, false, dimensions).outputVariable(); + } + + /** + * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x (NUMERIC type) + * @param dimensions (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public SDVariable squaredNorm(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("squaredNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(sd,x, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Remove a single dimension of size 1.
+ * For example, if input has shape [a,b,1,c] then squeeze(input, 2) returns an array of shape [a,b,c]
+ * + * @param x Input variable (NUMERIC type) + * @param axis Size 1 dimension to remove + * @return output Output variable (NUMERIC type) + */ + public SDVariable squeeze(SDVariable x, int axis) { + SDValidation.validateNumerical("squeeze", "x", x); + return new org.nd4j.linalg.api.ops.impl.shape.Squeeze(sd,x, axis).outputVariable(); + } + + /** + * Remove a single dimension of size 1.
+ * For example, if input has shape [a,b,1,c] then squeeze(input, 2) returns an array of shape [a,b,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param axis Size 1 dimension to remove + * @return output Output variable (NUMERIC type) + */ + public SDVariable squeeze(String name, SDVariable x, int axis) { + SDValidation.validateNumerical("squeeze", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Squeeze(sd,x, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Stack a set of N INDArray of rank X into one rank X+1 variable.
+ * If inputs have shape [a,b,c] then output has shape:
+ * axis = 0: [N,a,b,c]
+ * axis = 1: [a,N,b,c]
+ * axis = 2: [a,b,N,c]
+ * axis = 3: [a,b,c,N]
+ * see unstack(String[], SDVariable, int, int)
+ * + * @param values Input variables to stack. Must have the same shape for all inputs (NDARRAY type) + * @param axis Axis to stack on + * @return output Output variable (NDARRAY type) + */ + public SDVariable stack(int axis, SDVariable... values) { + Preconditions.checkArgument(values.length >= 1, "values has incorrect size/length. Expected: values.length >= 1, got %s", values.length); + return new org.nd4j.linalg.api.ops.impl.shape.Stack(sd,values, axis).outputVariable(); + } + + /** + * Stack a set of N INDArray of rank X into one rank X+1 variable.
+ * If inputs have shape [a,b,c] then output has shape:
+ * axis = 0: [N,a,b,c]
+ * axis = 1: [a,N,b,c]
+ * axis = 2: [a,b,N,c]
+ * axis = 3: [a,b,c,N]
+ * see unstack(String[], SDVariable, int, int)
+ * + * @param name name May be null. Name for the output variable + * @param axis Axis to stack on + * @param values Input variables to stack. Must have the same shape for all inputs (NDARRAY type) + * @return output Output variable (NDARRAY type) + */ + public SDVariable stack(String name, int axis, SDVariable... values) { + Preconditions.checkArgument(values.length >= 1, "values has incorrect size/length. Expected: values.length >= 1, got %s", values.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Stack(sd,values, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Stardard deviation array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable standardDeviation(SDVariable x, boolean biasCorrected, boolean keepDims, + int... dimensions) { + SDValidation.validateNumerical("standardDeviation", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(sd,x, biasCorrected, keepDims, dimensions).outputVariable(); + } + + /** + * Stardard deviation array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected, + boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("standardDeviation", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(sd,x, biasCorrected, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Stardard deviation array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable standardDeviation(SDVariable x, boolean biasCorrected, int... dimensions) { + SDValidation.validateNumerical("standardDeviation", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(sd,x, biasCorrected, false, dimensions).outputVariable(); + } + + /** + * Stardard deviation array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected, + int... dimensions) { + SDValidation.validateNumerical("standardDeviation", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(sd,x, biasCorrected, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Get a subset of the specified input, by specifying the first element, last element, and the strides.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * [g, h, i]
+ * then stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1], all masks = 0) will return:
+ * [b, c]
+ * [h, i]
+ * + * @param in Variable to get subset of (NUMERIC type) + * @param begin Beginning index (Size: AtLeast(min=1)) + * @param end End index (Size: AtLeast(min=1)) + * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) + * @param beginMask Bit mask: If the ith bit is set to 1, then the value in the begin long[] is ignored, and a value of 0 is used instead for the beginning index for that dimension + * @param endMask Bit mask: If the ith bit is set to 1, then the value in the end long[] is ignored, and a value of size(i)-1 is used instead for the end index for that dimension + * @param ellipsisMask Bit mask: only one non-zero value is allowed here. If a non-zero value is set, then other dimensions are inserted as required at the specified position + * @param newAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is inserted at this point + * @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is removed at this point. Note that begin/end/stride values must result in a size 1 output for these dimensions + * @return output A subset of the input array (NUMERIC type) + */ + public SDVariable stridedSlice(SDVariable in, long[] begin, long[] end, long[] strides, + int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { + SDValidation.validateNumerical("stridedSlice", "in", in); + Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); + Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. 
Expected: strides.length >= 1, got %s", strides.length); + return new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(sd,in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask).outputVariable(); + } + + /** + * Get a subset of the specified input, by specifying the first element, last element, and the strides.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * [g, h, i]
+ * then stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1], all masks = 0) will return:
+ * [b, c]
+ * [h, i]
+ * + * @param name name May be null. Name for the output variable + * @param in Variable to get subset of (NUMERIC type) + * @param begin Beginning index (Size: AtLeast(min=1)) + * @param end End index (Size: AtLeast(min=1)) + * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) + * @param beginMask Bit mask: If the ith bit is set to 1, then the value in the begin long[] is ignored, and a value of 0 is used instead for the beginning index for that dimension + * @param endMask Bit mask: If the ith bit is set to 1, then the value in the end long[] is ignored, and a value of size(i)-1 is used instead for the end index for that dimension + * @param ellipsisMask Bit mask: only one non-zero value is allowed here. If a non-zero value is set, then other dimensions are inserted as required at the specified position + * @param newAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is inserted at this point + * @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is removed at this point. Note that begin/end/stride values must result in a size 1 output for these dimensions + * @return output A subset of the input array (NUMERIC type) + */ + public SDVariable stridedSlice(String name, SDVariable in, long[] begin, long[] end, + long[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask, + int shrinkAxisMask) { + SDValidation.validateNumerical("stridedSlice", "in", in); + Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); + Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. 
Expected: strides.length >= 1, got %s", strides.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(sd,in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Get a subset of the specified input, by specifying the first element, last element, and the strides.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * [g, h, i]
+ * then stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1], all masks = 0) will return:
+ * [b, c]
+ * [h, i]
+ * + * @param in Variable to get subset of (NUMERIC type) + * @param begin Beginning index (Size: AtLeast(min=1)) + * @param end End index (Size: AtLeast(min=1)) + * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) + * @return output A subset of the input array (NUMERIC type) + */ + public SDVariable stridedSlice(SDVariable in, long[] begin, long[] end, long... strides) { + SDValidation.validateNumerical("stridedSlice", "in", in); + Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); + Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length); + return new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(sd,in, begin, end, strides, 0, 0, 0, 0, 0).outputVariable(); + } + + /** + * Get a subset of the specified input, by specifying the first element, last element, and the strides.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * [g, h, i]
+ * then stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1], all masks = 0) will return:
+ * [b, c]
+ * [h, i]
+ * + * @param name name May be null. Name for the output variable + * @param in Variable to get subset of (NUMERIC type) + * @param begin Beginning index (Size: AtLeast(min=1)) + * @param end End index (Size: AtLeast(min=1)) + * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) + * @return output A subset of the input array (NUMERIC type) + */ + public SDVariable stridedSlice(String name, SDVariable in, long[] begin, long[] end, + long... strides) { + SDValidation.validateNumerical("stridedSlice", "in", in); + Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); + Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(sd,in, begin, end, strides, 0, 0, 0, 0, 0).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Sum array reduction operation, optionally along specified dimensions.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable sum(SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("sum", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(sd,x, keepDims, dimensions).outputVariable(); + } + + /** + * Sum array reduction operation, optionally along specified dimensions.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable sum(String name, SDVariable x, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("sum", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(sd,x, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Sum array reduction operation, optionally along specified dimensions.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable sum(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("sum", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(sd,x, false, dimensions).outputVariable(); + } + + /** + * Sum array reduction operation, optionally along specified dimensions.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public SDVariable sum(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("sum", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(sd,x, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * //TODO: Ops must be documented.
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensionsX dimensions for first input array (x) (Size: AtLeast(min=1)) + * @param dimensionsY dimensions for second input array (y) (Size: AtLeast(min=1)) + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array + * @return output Output variable (NUMERIC type) + */ + public SDVariable tensorMmul(SDVariable x, SDVariable y, int[] dimensionsX, int[] dimensionsY, + boolean transposeX, boolean transposeY, boolean transposeZ) { + SDValidation.validateNumerical("tensorMmul", "x", x); + SDValidation.validateNumerical("tensorMmul", "y", y); + Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); + Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length); + return new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(sd,x, y, dimensionsX, dimensionsY, transposeX, transposeY, transposeZ).outputVariable(); + } + + /** + * //TODO: Ops must be documented.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensionsX dimensions for first input array (x) (Size: AtLeast(min=1)) + * @param dimensionsY dimensions for second input array (y) (Size: AtLeast(min=1)) + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array + * @return output Output variable (NUMERIC type) + */ + public SDVariable tensorMmul(String name, SDVariable x, SDVariable y, int[] dimensionsX, + int[] dimensionsY, boolean transposeX, boolean transposeY, boolean transposeZ) { + SDValidation.validateNumerical("tensorMmul", "x", x); + SDValidation.validateNumerical("tensorMmul", "y", y); + Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); + Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(sd,x, y, dimensionsX, dimensionsY, transposeX, transposeY, transposeZ).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * //TODO: Ops must be documented.
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensionsX dimensions for first input array (x) (Size: AtLeast(min=1)) + * @param dimensionsY dimensions for second input array (y) (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable tensorMmul(SDVariable x, SDVariable y, int[] dimensionsX, int... dimensionsY) { + SDValidation.validateNumerical("tensorMmul", "x", x); + SDValidation.validateNumerical("tensorMmul", "y", y); + Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); + Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length); + return new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(sd,x, y, dimensionsX, dimensionsY, false, false, false).outputVariable(); + } + + /** + * //TODO: Ops must be documented.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensionsX dimensions for first input array (x) (Size: AtLeast(min=1)) + * @param dimensionsY dimensions for second input array (y) (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable tensorMmul(String name, SDVariable x, SDVariable y, int[] dimensionsX, + int... dimensionsY) { + SDValidation.validateNumerical("tensorMmul", "x", x); + SDValidation.validateNumerical("tensorMmul", "y", y); + Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); + Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(sd,x, y, dimensionsX, dimensionsY, false, false, false).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Repeat (tile) the input tensor the specified number of times.
+ * For example, if input is
+ * [1, 2]
+ * [3, 4]
+ * and repeat is [2, 3]
+ * then output is
+ * [1, 2, 1, 2, 1, 2]
+ * [3, 4, 3, 4, 3, 4]
+ * [1, 2, 1, 2, 1, 2]
+ * [3, 4, 3, 4, 3, 4]
+ * + * @param x Input variable (NDARRAY type) + * @param repeat Number of times to repeat in each axis. Must have length equal to the rank of the input array (INT type) + * @return output Output variable (NDARRAY type) + */ + public SDVariable tile(SDVariable x, SDVariable repeat) { + SDValidation.validateInteger("tile", "repeat", repeat); + return new org.nd4j.linalg.api.ops.impl.shape.Tile(sd,x, repeat).outputVariable(); + } + + /** + * Repeat (tile) the input tensor the specified number of times.
+ * For example, if input is
+ * [1, 2]
+ * [3, 4]
+ * and repeat is [2, 3]
+ * then output is
+ * [1, 2, 1, 2, 1, 2]
+ * [3, 4, 3, 4, 3, 4]
+ * [1, 2, 1, 2, 1, 2]
+ * [3, 4, 3, 4, 3, 4]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NDARRAY type) + * @param repeat Number of times to repeat in each axis. Must have length equal to the rank of the input array (INT type) + * @return output Output variable (NDARRAY type) + */ + public SDVariable tile(String name, SDVariable x, SDVariable repeat) { + SDValidation.validateInteger("tile", "repeat", repeat); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Tile(sd,x, repeat).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * see tile(String, SDVariable, int...)
+ * + * @param x (NDARRAY type) + * @param repeat (Size: AtLeast(min=1)) + * @return output (NDARRAY type) + */ + public SDVariable tile(SDVariable x, int... repeat) { + Preconditions.checkArgument(repeat.length >= 1, "repeat has incorrect size/length. Expected: repeat.length >= 1, got %s", repeat.length); + return new org.nd4j.linalg.api.ops.impl.shape.Tile(sd,x, repeat).outputVariable(); + } + + /** + * see tile(String, SDVariable, int...)
+ * + * @param name name May be null. Name for the output variable + * @param x (NDARRAY type) + * @param repeat (Size: AtLeast(min=1)) + * @return output (NDARRAY type) + */ + public SDVariable tile(String name, SDVariable x, int... repeat) { + Preconditions.checkArgument(repeat.length >= 1, "repeat has incorrect size/length. Expected: repeat.length >= 1, got %s", repeat.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Tile(sd,x, repeat).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Matrix transpose operation: If input has shape [a,b] output has shape [b,a]
+ * + * @param x Input variable (NDARRAY type) + * @return output transposed input (NDARRAY type) + */ + public SDVariable transpose(SDVariable x) { + return new org.nd4j.linalg.api.ops.impl.shape.Transpose(sd,x).outputVariable(); + } + + /** + * Matrix transpose operation: If input has shape [a,b] output has shape [b,a]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NDARRAY type) + * @return output transposed input (NDARRAY type) + */ + public SDVariable transpose(String name, SDVariable x) { + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Transpose(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Unsorted segment max operation. As per segmentMax(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [6, 9, 8] = [max(3,6), max(1,4,9), max(2,8)]
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentMax(SDVariable data, SDVariable segmentIds, int numSegments) { + SDValidation.validateNumerical("unsortedSegmentMax", "data", data); + SDValidation.validateNumerical("unsortedSegmentMax", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(sd,data, segmentIds, numSegments).outputVariable(); + } + + /** + * Unsorted segment max operation. As per segmentMax(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [6, 9, 8] = [max(3,6), max(1,4,9), max(2,8)]
+ * + * @param name name May be null. Name for the output variable + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentMax(String name, SDVariable data, SDVariable segmentIds, + int numSegments) { + SDValidation.validateNumerical("unsortedSegmentMax", "data", data); + SDValidation.validateNumerical("unsortedSegmentMax", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(sd,data, segmentIds, numSegments).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Unsorted segment mean operation. As per segmentMean(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [4.5, 4.666, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentMean(SDVariable data, SDVariable segmentIds, int numSegments) { + SDValidation.validateNumerical("unsortedSegmentMean", "data", data); + SDValidation.validateNumerical("unsortedSegmentMean", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(sd,data, segmentIds, numSegments).outputVariable(); + } + + /** + * Unsorted segment mean operation. As per segmentMean(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [4.5, 4.666, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
+ * + * @param name name May be null. Name for the output variable + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentMean(String name, SDVariable data, SDVariable segmentIds, + int numSegments) { + SDValidation.validateNumerical("unsortedSegmentMean", "data", data); + SDValidation.validateNumerical("unsortedSegmentMean", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(sd,data, segmentIds, numSegments).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Unsorted segment min operation. As per segmentMin(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [3, 1, 2] = [min(3,6), min(1,4,9), min(2,8)]
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentMin(SDVariable data, SDVariable segmentIds, int numSegments) { + SDValidation.validateNumerical("unsortedSegmentMin", "data", data); + SDValidation.validateNumerical("unsortedSegmentMin", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(sd,data, segmentIds, numSegments).outputVariable(); + } + + /** + * Unsorted segment min operation. As per segmentMin(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [3, 1, 2] = [min(3,6), min(1,4,9), min(2,8)]
+ * + * @param name name May be null. Name for the output variable + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentMin(String name, SDVariable data, SDVariable segmentIds, + int numSegments) { + SDValidation.validateNumerical("unsortedSegmentMin", "data", data); + SDValidation.validateNumerical("unsortedSegmentMin", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(sd,data, segmentIds, numSegments).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Unsorted segment product operation. As per segmentProd(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [18, 36, 16] = [prod(3,6), prod(1,4,9), prod(2,8)]
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentProd(SDVariable data, SDVariable segmentIds, int numSegments) { + SDValidation.validateNumerical("unsortedSegmentProd", "data", data); + SDValidation.validateNumerical("unsortedSegmentProd", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(sd,data, segmentIds, numSegments).outputVariable(); + } + + /** + * Unsorted segment product operation. As per segmentProd(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [18, 36, 16] = [prod(3,6), prod(1,4,9), prod(2,8)]
+ * + * @param name name May be null. Name for the output variable + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentProd(String name, SDVariable data, SDVariable segmentIds, + int numSegments) { + SDValidation.validateNumerical("unsortedSegmentProd", "data", data); + SDValidation.validateNumerical("unsortedSegmentProd", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(sd,data, segmentIds, numSegments).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Unsorted segment sqrtN operation. Simply returns the sqrt of the count of the number of values in each segment
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [1.414, 1.732, 1.414] = [sqrt(2), sqrt(3), sqrt(2)]
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentSqrtN(SDVariable data, SDVariable segmentIds, int numSegments) { + SDValidation.validateNumerical("unsortedSegmentSqrtN", "data", data); + SDValidation.validateNumerical("unsortedSegmentSqrtN", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(sd,data, segmentIds, numSegments).outputVariable(); + } + + /** + * Unsorted segment sqrtN operation. Simply returns the sqrt of the count of the number of values in each segment
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [1.414, 1.732, 1.414] = [sqrt(2), sqrt(3), sqrt(2)]
+ * + * @param name name May be null. Name for the output variable + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentSqrtN(String name, SDVariable data, SDVariable segmentIds, + int numSegments) { + SDValidation.validateNumerical("unsortedSegmentSqrtN", "data", data); + SDValidation.validateNumerical("unsortedSegmentSqrtN", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(sd,data, segmentIds, numSegments).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Unsorted segment sum operation. As per segmentSum(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [9, 14, 10] = [sum(3,6), sum(1,4,9), sum(2,8)]
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentSum(SDVariable data, SDVariable segmentIds, int numSegments) { + SDValidation.validateNumerical("unsortedSegmentSum", "data", data); + SDValidation.validateNumerical("unsortedSegmentSum", "segmentIds", segmentIds); + return new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(sd,data, segmentIds, numSegments).outputVariable(); + } + + /** + * Unsorted segment sum operation. As per segmentSum(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [9, 14, 10] = [sum(3,6), sum(1,4,9), sum(2,8)]
+ * + * @param name name May be null. Name for the output variable + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public SDVariable unsortedSegmentSum(String name, SDVariable data, SDVariable segmentIds, + int numSegments) { + SDValidation.validateNumerical("unsortedSegmentSum", "data", data); + SDValidation.validateNumerical("unsortedSegmentSum", "segmentIds", segmentIds); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(sd,data, segmentIds, numSegments).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Unstack a variable of rank X into N rank X-1 variables by taking slices along the specified axis.
+ * If input has shape [a,b,c] then output has shape:
+ * axis = 0: [b,c]
+ * axis = 1: [a,c]
+ * axis = 2: [a,b]
+ * + * @param value Input variable to unstack (NDARRAY type) + * @param axis Axis to unstack on + * @param num Number of output variables + */ + public SDVariable[] unstack(SDVariable value, int axis, int num) { + return new org.nd4j.linalg.api.ops.impl.shape.Unstack(sd,value, axis, num).outputVariables(); + } + + /** + * Unstack a variable of rank X into N rank X-1 variables by taking slices along the specified axis.
+ * If input has shape [a,b,c] then output has shape:
+ * axis = 0: [b,c]
+ * axis = 1: [a,c]
+ * axis = 2: [a,b]
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param value Input variable to unstack (NDARRAY type) + * @param axis Axis to unstack on + * @param num Number of output variables + */ + public SDVariable[] unstack(String[] names, SDVariable value, int axis, int num) { + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.shape.Unstack(sd,value, axis, num).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Variance array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable variance(SDVariable x, boolean biasCorrected, boolean keepDims, + int... dimensions) { + SDValidation.validateNumerical("variance", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.summarystats.Variance(sd,x, biasCorrected, keepDims, dimensions).outputVariable(); + } + + /** + * Variance array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable variance(String name, SDVariable x, boolean biasCorrected, boolean keepDims, + int... dimensions) { + SDValidation.validateNumerical("variance", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.summarystats.Variance(sd,x, biasCorrected, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Variance array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable variance(SDVariable x, boolean biasCorrected, int... dimensions) { + SDValidation.validateNumerical("variance", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.summarystats.Variance(sd,x, biasCorrected, false, dimensions).outputVariable(); + } + + /** + * Variance array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable variance(String name, SDVariable x, boolean biasCorrected, int... dimensions) { + SDValidation.validateNumerical("variance", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.summarystats.Variance(sd,x, biasCorrected, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic:
+ * if the input shape changes in later execution, the returned variable's shape will also be updated
+ * + * @param input Input (NUMERIC type) + * @return output A new Variable with the same (dynamic) shape as the input (NUMERIC type) + */ + public SDVariable zerosLike(SDVariable input) { + SDValidation.validateNumerical("zerosLike", "input", input); + return new org.nd4j.linalg.api.ops.impl.shape.ZerosLike(sd,input).outputVariable(); + } + + /** + * Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic:
+ * if the input shape changes in later execution, the returned variable's shape will also be updated
+ * + * @param name name May be null. Name for the output variable + * @param input Input (NUMERIC type) + * @return output A new Variable with the same (dynamic) shape as the input (NUMERIC type) + */ + public SDVariable zerosLike(String name, SDVariable input) { + SDValidation.validateNumerical("zerosLike", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ZerosLike(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java index d367e3d4a..bb9f027c0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java @@ -23,8 +23,8 @@ import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.enums.DataFormat; import org.nd4j.base.Preconditions; +import org.nd4j.enums.DataFormat; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; @@ -753,6 +753,33 @@ public class SDCNN extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices
+ * + * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param Pooling2DConfig Configuration Object + */ + public SDVariable[] maxPoolWithArgmax(SDVariable input, Pooling2DConfig Pooling2DConfig) { + SDValidation.validateNumerical("maxPoolWithArgmax", "input", input); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax(sd,input, Pooling2DConfig).outputVariables(); + } + + /** + * 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param Pooling2DConfig Configuration Object + */ + public SDVariable[] maxPoolWithArgmax(String[] names, SDVariable input, + Pooling2DConfig Pooling2DConfig) { + SDValidation.validateNumerical("maxPoolWithArgmax", "input", input); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax(sd,input, Pooling2DConfig).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + /** * 2D Convolution layer operation - max pooling 2d
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index f4a490813..ead137a57 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -2205,7 +2205,7 @@ public class SDMath extends SDOps { * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ - public SDVariable mergeAdd(SDVariable[] inputs) { + public SDVariable mergeAdd(SDVariable... inputs) { SDValidation.validateNumerical("mergeAdd", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd,inputs).outputVariable(); @@ -2219,7 +2219,7 @@ public class SDMath extends SDOps { * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ - public SDVariable mergeAdd(String name, SDVariable[] inputs) { + public SDVariable mergeAdd(String name, SDVariable... inputs) { SDValidation.validateNumerical("mergeAdd", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd,inputs).outputVariable(); @@ -2233,7 +2233,7 @@ public class SDMath extends SDOps { * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ - public SDVariable mergeAvg(SDVariable[] inputs) { + public SDVariable mergeAvg(SDVariable... 
inputs) { SDValidation.validateNumerical("mergeAvg", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); return new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd,inputs).outputVariable(); @@ -2247,7 +2247,7 @@ public class SDMath extends SDOps { * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ - public SDVariable mergeAvg(String name, SDVariable[] inputs) { + public SDVariable mergeAvg(String name, SDVariable... inputs) { SDValidation.validateNumerical("mergeAvg", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd,inputs).outputVariable(); @@ -2261,7 +2261,7 @@ public class SDMath extends SDOps { * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ - public SDVariable mergeMax(SDVariable[] inputs) { + public SDVariable mergeMax(SDVariable... inputs) { SDValidation.validateNumerical("mergeMax", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); return new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd,inputs).outputVariable(); @@ -2275,7 +2275,7 @@ public class SDMath extends SDOps { * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ - public SDVariable mergeMax(String name, SDVariable[] inputs) { + public SDVariable mergeMax(String name, SDVariable... inputs) { SDValidation.validateNumerical("mergeMax", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. 
Expected: inputs.length >= 1, got %s", inputs.length); SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd,inputs).outputVariable(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java index 6b1831de7..de8148c02 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java @@ -18,17 +18,15 @@ package org.nd4j.autodiff.samediff.ops; -import java.lang.String; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; -import lombok.NonNull; +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.GRUCellOutputs; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; @@ -43,28 +41,26 @@ public class SDRNN extends SDOps { * @param x Input, with shape [batchSize, inSize] (NUMERIC type) * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type) * @param GRUWeights Configuration Object - * @return output The cell's outputs. 
(NUMERIC type) */ - public SDVariable gru(SDVariable x, SDVariable hLast, GRUWeights GRUWeights) { + public SDVariable[] gru(SDVariable x, SDVariable hLast, GRUWeights GRUWeights) { SDValidation.validateNumerical("gru", "x", x); SDValidation.validateNumerical("gru", "hLast", hLast); - return new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariables(); } /** * The GRU cell. Does a single time step operation
* - * @param name name May be null. Name for the output variable + * @param names names May be null. Arrays of names for the output variables. * @param x Input, with shape [batchSize, inSize] (NUMERIC type) * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type) * @param GRUWeights Configuration Object - * @return output The cell's outputs. (NUMERIC type) */ - public GRUCellOutputs gru(String name, SDVariable x, SDVariable hLast, GRUWeights GRUWeights) { + public SDVariable[] gru(String[] names, SDVariable x, SDVariable hLast, GRUWeights GRUWeights) { SDValidation.validateNumerical("gru", "x", x); SDValidation.validateNumerical("gru", "hLast", hLast); - GRUCell c = new GRUCell(sd,x, hLast, GRUWeights); - return new GRUCellOutputs(c.outputVariables(name)); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); } /** @@ -75,39 +71,172 @@ public class SDRNN extends SDOps { * @param yLast revious cell output, with shape [batchSize, numUnits] (NUMERIC type) * @param LSTMWeights Configuration Object * @param LSTMConfiguration Configuration Object - * @return output The cell's outputs (NUMERIC type) */ - public LSTMCellOutputs lstmCell(SDVariable x, SDVariable cLast, SDVariable yLast, + public SDVariable[] lstmCell(SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { SDValidation.validateNumerical("lstmCell", "x", x); SDValidation.validateNumerical("lstmCell", "cLast", cLast); SDValidation.validateNumerical("lstmCell", "yLast", yLast); - LSTMBlockCell c = new LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration); - return new LSTMCellOutputs(c.outputVariables()); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariables(); } /** * The LSTM cell. 
Does a single time step operation.
* - * @param name name May be null. Name for the output variable + * @param names names May be null. Arrays of names for the output variables. * @param x Input, with shape [batchSize, inSize] (NUMERIC type) * @param cLast Previous cell state, with shape [batchSize, numUnits] (NUMERIC type) * @param yLast revious cell output, with shape [batchSize, numUnits] (NUMERIC type) * @param LSTMWeights Configuration Object * @param LSTMConfiguration Configuration Object - * @return output The cell's outputs (NUMERIC type) */ - public LSTMCellOutputs lstmCell(String name, SDVariable x, SDVariable cLast, SDVariable yLast, + public SDVariable[] lstmCell(String[] names, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { SDValidation.validateNumerical("lstmCell", "x", x); SDValidation.validateNumerical("lstmCell", "cLast", cLast); SDValidation.validateNumerical("lstmCell", "yLast", yLast); - LSTMBlockCell c = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration); - return new LSTMCellOutputs(c.outputVariables(name)); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); } /** - * The LSTM layer. Does multiple time steps.
+ * Long Short-Term Memory layer - Hochreiter 1997.
+ * SUPPORTS following data formats:\n
+ * for unidirectional: \n
+ * TNS: shapes [timeLength, numExamples, inOutSize]\n
+ * NST: shapes [numExamples, inOutSize, timeLength]\n
+ * NTS: shapes [numExamples, timeLength, inOutSize]
+ * for bidirectional:\n
+ * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n
+ * SUPPORTS following direction modes:\n
+ * FWD: forward
+ * BWD: backward
+ * BIDIR_SUM: bidirectional sum\n
+ * BIDIR_CONCAT: bidirectional concat\n
+ * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)
+ * You may use different gate configurations:
+ * specify gate/cell/out alpha/beta and numbers of activations for gate/cell/out described in activations enum\n
+ * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n
+ * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration
+ * + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type) + * @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type) + * @param maxTSLength maxTSLength with shape [batchSize] (NUMERIC type) + * @param LSTMLayerWeights Configuration Object + * @param LSTMLayerConfig Configuration Object + */ + public SDVariable[] lstmLayer(SDVariable x, SDVariable cLast, SDVariable yLast, + SDVariable maxTSLength, LSTMLayerWeights LSTMLayerWeights, LSTMLayerConfig LSTMLayerConfig) { + SDValidation.validateNumerical("lstmLayer", "x", x); + SDValidation.validateNumerical("lstmLayer", "cLast", cLast); + SDValidation.validateNumerical("lstmLayer", "yLast", yLast); + SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,x, cLast, yLast, maxTSLength, LSTMLayerWeights, LSTMLayerConfig).outputVariables(); + } + + /** + * Long Short-Term Memory layer - Hochreiter 1997.
+ * SUPPORTS following data formats:\n
+ * for unidirectional: \n
+ * TNS: shapes [timeLength, numExamples, inOutSize]\n
+ * NST: shapes [numExamples, inOutSize, timeLength]\n
+ * NTS: shapes [numExamples, timeLength, inOutSize]
+ * for bidirectional:\n
+ * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n
+ * SUPPORTS following direction modes:\n
+ * FWD: forward
+ * BWD: backward
+ * BIDIR_SUM: bidirectional sum\n
+ * BIDIR_CONCAT: bidirectional concat\n
+ * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)
+ * You may use different gate configurations:
+ * specify gate/cell/out alpha/beta and numbers of activations for gate/cell/out described in activations enum\n
+ * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n
+ * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type) + * @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type) + * @param maxTSLength maxTSLength with shape [batchSize] (NUMERIC type) + * @param LSTMLayerWeights Configuration Object + * @param LSTMLayerConfig Configuration Object + */ + public SDVariable[] lstmLayer(String[] names, SDVariable x, SDVariable cLast, SDVariable yLast, + SDVariable maxTSLength, LSTMLayerWeights LSTMLayerWeights, LSTMLayerConfig LSTMLayerConfig) { + SDValidation.validateNumerical("lstmLayer", "x", x); + SDValidation.validateNumerical("lstmLayer", "cLast", cLast); + SDValidation.validateNumerical("lstmLayer", "yLast", yLast); + SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,x, cLast, yLast, maxTSLength, LSTMLayerWeights, LSTMLayerConfig).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Long Short-Term Memory layer - Hochreiter 1997.
+ * SUPPORTS following data formats:\n
+ * for unidirectional: \n
+ * TNS: shapes [timeLength, numExamples, inOutSize]\n
+ * NST: shapes [numExamples, inOutSize, timeLength]\n
+ * NTS: shapes [numExamples, timeLength, inOutSize]
+ * for bidirectional:\n
+ * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n
+ * SUPPORTS following direction modes:\n
+ * FWD: forward
+ * BWD: backward
+ * BIDIR_SUM: bidirectional sum\n
+ * BIDIR_CONCAT: bidirectional concat\n
+ * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)
+ * You may use different gate configurations:
+ * specify gate/cell/out alpha/beta and numbers of activations for gate/cell/out described in activations enum\n
+ * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n
+ * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration
+ * + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param LSTMLayerWeights Configuration Object + * @param LSTMLayerConfig Configuration Object + */ + public SDVariable[] lstmLayer(SDVariable x, LSTMLayerWeights LSTMLayerWeights, + LSTMLayerConfig LSTMLayerConfig) { + SDValidation.validateNumerical("lstmLayer", "x", x); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,x, null, null, null, LSTMLayerWeights, LSTMLayerConfig).outputVariables(); + } + + /** + * Long Short-Term Memory layer - Hochreiter 1997.
+ * SUPPORTS following data formats:\n
+ * for unidirectional: \n
+ * TNS: shapes [timeLength, numExamples, inOutSize]\n
+ * NST: shapes [numExamples, inOutSize, timeLength]\n
+ * NTS: shapes [numExamples, timeLength, inOutSize]
+ * for bidirectional:\n
+ * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n
+ * SUPPORTS following direction modes:\n
+ * FWD: forward
+ * BWD: backward
+ * BIDIR_SUM: bidirectional sum\n
+ * BIDIR_CONCAT: bidirectional concat\n
+ * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)
+ * You may use different gate configurations:
+ * specify gate/cell/out alpha/beta and numbers of activations for gate/cell/out described in activations enum\n
+ * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n
+ * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param LSTMLayerWeights Configuration Object + * @param LSTMLayerConfig Configuration Object + */ + public SDVariable[] lstmLayer(String[] names, SDVariable x, LSTMLayerWeights LSTMLayerWeights, + LSTMLayerConfig LSTMLayerConfig) { + SDValidation.validateNumerical("lstmLayer", "x", x); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,x, null, null, null, LSTMLayerWeights, LSTMLayerConfig).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * The LSTM block
* * @param maxTSLength (NUMERIC type) * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) @@ -117,17 +246,17 @@ public class SDRNN extends SDOps { * @param LSTMConfiguration Configuration Object * @return output The layer's outputs. (NUMERIC type) */ - public SDVariable lstmLayer(SDVariable maxTSLength, SDVariable x, SDVariable cLast, + public SDVariable lstmblock(SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { - SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength); - SDValidation.validateNumerical("lstmLayer", "x", x); - SDValidation.validateNumerical("lstmLayer", "cLast", cLast); - SDValidation.validateNumerical("lstmLayer", "yLast", yLast); - return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable(); + SDValidation.validateNumerical("lstmblock", "maxTSLength", maxTSLength); + SDValidation.validateNumerical("lstmblock", "x", x); + SDValidation.validateNumerical("lstmblock", "cLast", cLast); + SDValidation.validateNumerical("lstmblock", "yLast", yLast); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable(); } /** - * The LSTM layer. Does multiple time steps.
+ * The LSTM block
* * @param name name May be null. Name for the output variable * @param maxTSLength (NUMERIC type) @@ -138,13 +267,43 @@ public class SDRNN extends SDOps { * @param LSTMConfiguration Configuration Object * @return output The layer's outputs. (NUMERIC type) */ - public SDVariable lstmLayer(String name, SDVariable maxTSLength, SDVariable x, SDVariable cLast, + public SDVariable lstmblock(String name, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { - SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength); - SDValidation.validateNumerical("lstmLayer", "x", x); - SDValidation.validateNumerical("lstmLayer", "cLast", cLast); - SDValidation.validateNumerical("lstmLayer", "yLast", yLast); - SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable(); + SDValidation.validateNumerical("lstmblock", "maxTSLength", maxTSLength); + SDValidation.validateNumerical("lstmblock", "x", x); + SDValidation.validateNumerical("lstmblock", "cLast", cLast); + SDValidation.validateNumerical("lstmblock", "yLast", yLast); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * The LSTM block
+ * + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param LSTMWeights Configuration Object + * @param LSTMConfiguration Configuration Object + * @return output The layer's outputs. (NUMERIC type) + */ + public SDVariable lstmblock(SDVariable x, LSTMWeights LSTMWeights, + LSTMConfiguration LSTMConfiguration) { + SDValidation.validateNumerical("lstmblock", "x", x); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(sd,null, x, null, null, LSTMWeights, LSTMConfiguration).outputVariable(); + } + + /** + * The LSTM block
+ * + * @param name name May be null. Name for the output variable + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param LSTMWeights Configuration Object + * @param LSTMConfiguration Configuration Object + * @return output The layer's outputs. (NUMERIC type) + */ + public SDVariable lstmblock(String name, SDVariable x, LSTMWeights LSTMWeights, + LSTMConfiguration LSTMConfiguration) { + SDValidation.validateNumerical("lstmblock", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(sd,null, x, null, null, LSTMWeights, LSTMConfiguration).outputVariable(); return sd.updateVariableNameAndReference(out, name); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/CellAct.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/CellAct.java new file mode 100644 index 000000000..f7f458ffd --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/CellAct.java @@ -0,0 +1,45 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * Activations */ +public enum CellAct { + TANH, + + RELU, + + SIGMOID, + + AFFINE, + + LEAKY_RELU, + + THRESHHOLD_RELU, + + SCALED_TAHN, + + HARD_SIGMOID, + + ELU, + + SOFTSIGN, + + SOFTPLUS +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/GateAct.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/GateAct.java new file mode 100644 index 000000000..498f825fd --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/GateAct.java @@ -0,0 +1,45 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * Activations */ +public enum GateAct { + TANH, + + RELU, + + SIGMOID, + + AFFINE, + + LEAKY_RELU, + + THRESHHOLD_RELU, + + SCALED_TAHN, + + HARD_SIGMOID, + + ELU, + + SOFTSIGN, + + SOFTPLUS +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/LSTMDataFormat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/LSTMDataFormat.java new file mode 100644 index 000000000..cd8855b05 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/LSTMDataFormat.java @@ -0,0 +1,36 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * for unidirectional: + * TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"
+ * NST: shape [numExamples, inOutSize, timeLength]
+ * NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout
+ * for bidirectional: + * T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) */ +public enum LSTMDataFormat { + TNS, + + NST, + + NTS, + + T2NS +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/LSTMDirectionMode.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/LSTMDirectionMode.java new file mode 100644 index 000000000..4732fc611 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/LSTMDirectionMode.java @@ -0,0 +1,38 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * direction
+ * FWD: 0 = fwd + * BWD: 1 = bwd + * BIDIR_SUM: 2 = bidirectional sum + * BIDIR_CONCAT: 3 = bidirectional concat + * BIDIR_EXTRA_DIM: 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) */ +public enum LSTMDirectionMode { + FWD, + + BWD, + + BIDIR_SUM, + + BIDIR_CONCAT, + + BIDIR_EXTRA_DIM +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/OutAct.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/OutAct.java new file mode 100644 index 000000000..df034a294 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/OutAct.java @@ -0,0 +1,45 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * Activations */ +public enum OutAct { + TANH, + + RELU, + + SIGMOID, + + AFFINE, + + LEAKY_RELU, + + THRESHHOLD_RELU, + + SCALED_TAHN, + + HARD_SIGMOID, + + ELU, + + SOFTSIGN, + + SOFTPLUS +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/RnnDataFormat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/RnnDataFormat.java new file mode 100644 index 000000000..8b6e2fbd6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/RnnDataFormat.java @@ -0,0 +1,32 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * The data format of the input. Input shape depends on data format (in config):
+ * TNS -> [timeSteps, batchSize, inSize]
+ * NST -> [batchSize, inSize, timeSteps]
+ * NTS -> [batchSize, timeSteps, inSize]
*/ +public enum RnnDataFormat { + TNS, + + NST, + + NTS +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 62edb778f..043a16e87 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -146,6 +146,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMCell.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer.class, + org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell.class, org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java index c60b11d23..ff139c236 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java @@ -301,24 +301,27 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { } } - protected void checkForWorkspaces(CustomOp op) { - for (val input: op.inputArguments()) + protected void checkForWorkspaces(CustomOp op, OpContext oc) { + List inArgs = oc != null ? oc.getInputArrays() : op.inputArguments(); + List outArgs = oc != null ? 
oc.getOutputArrays() : op.outputArguments(); + + for (val input: inArgs) checkWorkspace(op.opName(), input); - for (val output: op.outputArguments()) + for (val output: outArgs) checkWorkspace(op.opName(), output); } - protected void checkForWorkspaces(Op op) { - val x = op.x(); + protected void checkForWorkspaces(Op op, OpContext oc) { + val x = oc != null ? oc.getInputArray(0) : op.x(); if (x != null) checkWorkspace(op.opName(), x); - val y = op.y(); + val y = oc != null && oc.getInputArrays().size() > 1 ? oc.getInputArray(1) : op.y(); if (y != null) checkWorkspace(op.opName(), y); - val z = op.z(); + val z = oc != null ? oc.getOutputArray(0) : op.z(); if (z != null) checkWorkspace(op.opName(), z); } @@ -346,7 +349,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { OpProfiler.getInstance().processOpCall(op, tadBuffers); break; case SCOPE_PANIC: - checkForWorkspaces(op); + checkForWorkspaces(op, null); return 0L; case DISABLED: default: @@ -357,7 +360,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { } @Deprecated - public long profilingHookIn(CustomOp op) { + public long profilingHookIn(CustomOp op, OpContext oc) { switch (profilingMode) { case ALL: OpProfiler.getInstance().processOpCall(op); @@ -368,7 +371,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { OpProfiler.getInstance().processOpCall(op); break; case SCOPE_PANIC: - checkForWorkspaces(op); + checkForWorkspaces(op, oc); return 0L; case DISABLED: default: @@ -379,7 +382,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { } @Deprecated - public void profilingHookOut(Op op, long timeStart) { + public void profilingHookOut(Op op, OpContext oc, long timeStart) { switch (profilingMode) { case ALL: OpProfiler.getInstance().processStackCall(op, timeStart); @@ -392,14 +395,14 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { OpProfiler.getInstance().timeOpCall(op, timeStart); break; case 
NAN_PANIC: - OpExecutionerUtil.checkForNaN(op); + OpExecutionerUtil.checkForNaN(op, oc); break; case INF_PANIC: - OpExecutionerUtil.checkForInf(op); + OpExecutionerUtil.checkForInf(op, oc); break; case ANY_PANIC: - OpExecutionerUtil.checkForNaN(op); - OpExecutionerUtil.checkForInf(op); + OpExecutionerUtil.checkForNaN(op, oc); + OpExecutionerUtil.checkForInf(op, oc); break; case DISABLED: default: @@ -413,7 +416,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { } @Deprecated - public void profilingHookOut(CustomOp op, long timeStart) { + public void profilingHookOut(CustomOp op, OpContext oc, long timeStart) { switch (profilingMode) { case ALL: OpProfiler.getInstance().processStackCall(op, timeStart); @@ -426,14 +429,14 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { OpProfiler.getInstance().timeOpCall(op, timeStart); break; case NAN_PANIC: - OpExecutionerUtil.checkForNaN(op); + OpExecutionerUtil.checkForNaN(op, oc); break; case INF_PANIC: - OpExecutionerUtil.checkForInf(op); + OpExecutionerUtil.checkForInf(op, oc); break; case ANY_PANIC: - OpExecutionerUtil.checkForNaN(op); - OpExecutionerUtil.checkForInf(op); + OpExecutionerUtil.checkForNaN(op, oc); + OpExecutionerUtil.checkForInf(op, oc); break; case DISABLED: default: @@ -442,12 +445,15 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { } - public long profilingConfigurableHookIn(CustomOp op) { - for (val arr: op.inputArguments()) + public long profilingConfigurableHookIn(CustomOp op, OpContext oc) { + List inArgs = oc != null ? oc.getInputArrays() : op.inputArguments(); + List outArgs = oc != null ? 
oc.getOutputArrays() : op.outputArguments(); + + for (val arr: inArgs) if (arr.wasClosed()) throw new IllegalStateException("One of Input arguments was closed before call"); - for (val arr: op.outputArguments()) + for (val arr: outArgs) if (arr.wasClosed()) throw new IllegalStateException("One of Output arguments was closed before call"); @@ -460,7 +466,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { } if (OpProfiler.getInstance().getConfig().isCheckWorkspaces()) { - checkForWorkspaces(op); + checkForWorkspaces(op, oc); } return System.nanoTime(); @@ -491,14 +497,14 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { OpProfiler.getInstance().processOpCall(op, tadBuffers); } if (OpProfiler.getInstance().getConfig().isCheckWorkspaces()) { - checkForWorkspaces(op); + checkForWorkspaces(op, null); } return System.nanoTime(); } - public void profilingConfigurableHookOut(Op op, long timeStart) { + public void profilingConfigurableHookOut(Op op, OpContext oc, long timeStart) { if (OpProfiler.getInstance().getConfig() == null) return; @@ -509,10 +515,10 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { OpProfiler.getInstance().timeOpCall(op, timeStart); } if (OpProfiler.getInstance().getConfig().isCheckForNAN()) { - OpExecutionerUtil.checkForNaN(op); + OpExecutionerUtil.checkForNaN(op, oc); } if (OpProfiler.getInstance().getConfig().isCheckForINF()) { - OpExecutionerUtil.checkForInf(op); + OpExecutionerUtil.checkForInf(op, oc); } if (OpProfiler.getInstance().getConfig().isNativeStatistics()) { if (op.z() != null) { @@ -531,7 +537,7 @@ public abstract class DefaultOpExecutioner implements OpExecutioner { } } - public void profilingConfigurableHookOut(CustomOp op, long timeStart) { + public void profilingConfigurableHookOut(CustomOp op, OpContext oc, long timeStart) { if (OpProfiler.getInstance().getConfig() == null) return; @@ -542,10 +548,10 @@ public abstract class DefaultOpExecutioner implements 
OpExecutioner { OpProfiler.getInstance().timeOpCall(op, timeStart); } if (OpProfiler.getInstance().getConfig().isCheckForNAN()) { - OpExecutionerUtil.checkForNaN(op); + OpExecutionerUtil.checkForNaN(op, oc); } if (OpProfiler.getInstance().getConfig().isCheckForINF()) { - OpExecutionerUtil.checkForInf(op); + OpExecutionerUtil.checkForInf(op, oc); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutionerUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutionerUtil.java index 080825433..83421e247 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutionerUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutionerUtil.java @@ -22,12 +22,15 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.exception.ND4JOpProfilerException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.profiler.OpProfiler; +import java.util.List; + /**Utility functions for the DefaultOpExecutioner * @author Alex Black */ @@ -58,7 +61,7 @@ public class OpExecutionerUtil { } if (match > 0) - throw new ND4JOpProfilerException("P.A.N.I.C.! Op.Z() contains " + match + " NaN value(s): "); + throw new ND4JOpProfilerException("P.A.N.I.C.! 
Op.Z() contains " + match + " NaN value(s)"); } public static void checkForAny(INDArray z) { @@ -92,44 +95,52 @@ public class OpExecutionerUtil { } - public static void checkForNaN(Op op) { + public static void checkForNaN(Op op, OpContext oc) { if (!OpProfiler.getInstance().getConfig().isCheckForNAN()) return; - if (op.z() != null && !(op instanceof MatchCondition)) { - checkForNaN(op.z()); + INDArray z = oc != null ? oc.getOutputArray(0) : op.z(); + if (z != null && !(op instanceof MatchCondition)) { + checkForNaN(z); } } - public static void checkForInf(Op op) { + public static void checkForInf(Op op, OpContext oc) { if (!OpProfiler.getInstance().getConfig().isCheckForINF()) return; - if (op.z() != null && !(op instanceof MatchCondition)) { - checkForInf(op.z()); + INDArray z = oc != null ? oc.getOutputArray(0) : op.z(); + if (z != null && !(op instanceof MatchCondition)) { + checkForInf(z); } } - public static void checkForInf(CustomOp op) { + public static void checkForInf(CustomOp op, OpContext oc) { if (!OpProfiler.getInstance().getConfig().isCheckForINF()) return; - for (val input: op.inputArguments()) + List inArgs = oc != null ? oc.getInputArrays() : op.inputArguments(); + List outArgs = oc != null ? oc.getOutputArrays() : op.outputArguments(); + + for (val input: inArgs) checkForInf(input); - for (val output: op.outputArguments()) + for (val output: outArgs) checkForInf(output); } - public static void checkForNaN(CustomOp op) { + public static void checkForNaN(CustomOp op, OpContext oc) { if (!OpProfiler.getInstance().getConfig().isCheckForNAN()) return; - for (val input: op.inputArguments()) + List inArgs = oc != null ? oc.getInputArrays() : op.inputArguments(); + List outArgs = oc != null ? 
oc.getOutputArrays() : op.outputArguments(); + + for (val input: inArgs) checkForNaN(input); - for (val output: op.outputArguments()) + for (val output: outArgs) checkForNaN(output); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java index 0a7338814..5f5f6747a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java @@ -57,8 +57,12 @@ public class MaxPoolWithArgmax extends DynamicCustomOp { addArgs(); } - public MaxPoolWithArgmax(INDArray input, INDArray output,INDArray outArgMax, @NonNull Pooling2DConfig config){ - super(null, new INDArray[]{input}, new INDArray[]{output, outArgMax}); + public MaxPoolWithArgmax(@NonNull INDArray input, @NonNull Pooling2DConfig config){ + this(input, null, null, config); + } + + public MaxPoolWithArgmax(@NonNull INDArray input, INDArray output,INDArray outArgMax, @NonNull Pooling2DConfig config){ + super(null, new INDArray[]{input}, wrapFilterNull(output, outArgMax)); config.setType(Pooling2D.Pooling2DType.MAX); this.config = config; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java index cf4e87814..fdc1b40fe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java @@ -45,7 +45,7 @@ public class 
SConv2D extends Conv2D { } public SConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, - @NonNull SDVariable pointWeights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) { + SDVariable pointWeights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) { this(sameDiff, wrapFilterNull(layerInput, depthWeights, pointWeights, bias), conv2DConfig); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlock.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlock.java new file mode 100644 index 000000000..20d84a2d6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlock.java @@ -0,0 +1,144 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.layers.recurrent; + +import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +/** + * LSTM layer implemented as a single operation. + * Implementation of operation for LSTM layer with optional peep hole connections.
+ * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation and https://research.google.com/pubs/archive/43905.pdf
+ * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014.
+ * See also: https://arxiv.org/pdf/1503.04069.pdf
+ *

+ * See also {@link LSTMBlockCell} - lstmBlockCell op is used internally at C++ level for computation.
+ *
+ * Input arrays:
+ * 0: max sequence length; long/int64 scalar
+ * 1: input [seqLength, bS, inSize] at time t
+ * 2: previous/initial cell state [bS, numUnits]
+ * 3: previous/initial output [bS, numUnits]
+ * 4: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits]
+ * 5: weights - cell peephole (t-1) connections to input modulation gate, [numUnits]
+ * 6: weights - cell peephole (t-1) connections to forget gate, [numUnits]
+ * 7: weights - cell peephole (t) connections to output gate, [numUnits]
+ * 8: biases, shape [4*numUnits]
+ *
+ * Input integer arguments: set via {@link LSTMConfiguration}
+ * 0: if not zero, provide peephole connections
+ * 1: Data format - 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; 2=NTS=[mb,seqLen,size]
+ *
+ * Input float arguments: set via {@link LSTMConfiguration}
+ * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training
+ * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped
+ *

+ * Output arrays:
+ * 0: i - Input modulation gate activations, rank 3, shape as per dataFormat
+ * 1: c (cs) - Cell state (pre tanh), rank 3, shape as per dataFormat
+ * 2: f - Output - forget gate activations, rank 3, shape as per dataFormat
+ * 3: o - Output - output gate activations, rank 3, shape as per dataFormat
+ * 4: z (ci) - Output - block input, rank 3, shape as per dataFormat
+ * 5: h (co) - Cell state, post tanh, rank 3, shape as per dataFormat
+ * 6: y (h) - Current cell output, rank 3, shape as per dataFormat
+ * + * @author Alex Black + */ +public class LSTMBlock extends DynamicCustomOp { + + private LSTMConfiguration configuration; + + @Getter + private LSTMWeights weights; + + public LSTMBlock() { + } + + public LSTMBlock(@NonNull SameDiff sameDiff, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) { + super(null, sameDiff, weights.argsWithInputs(x, maxTSLength, cLast, yLast)); + this.configuration = configuration; + this.weights = weights; + addIArgument(configuration.iArgs(true)); + addTArgument(configuration.tArgs()); + } + + public LSTMBlock(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMWeights lstmWeights, LSTMConfiguration lstmConfiguration) { + super(null, null, lstmWeights.argsWithInputs(maxTSLength, x, cLast, yLast)); + this.configuration = lstmConfiguration; + this.weights = lstmWeights; + addIArgument(configuration.iArgs(true)); + addTArgument(configuration.tArgs()); + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 9, "Expected exactly 9 inputs to LSTMBlock, got %s", inputDataTypes); + //7 outputs, all of same type as input. 
Note that input 0 is max sequence length (int64), input 1 is actual input + DataType dt = inputDataTypes.get(1); + Preconditions.checkState(dt.isFPType(), "Input type 1 must be a floating point type, got %s", dt); + return Arrays.asList(dt, dt, dt, dt, dt, dt, dt); + } + + @Override + public List doDiff(List grads) { + throw new UnsupportedOperationException("Not yet implemented"); + } + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + configuration = LSTMConfiguration.builder() + .forgetBias(attributesForNode.get("forget_bias").getF()) + .clippingCellValue(attributesForNode.get("cell_clip").getF()) + .peepHole(attributesForNode.get("use_peephole").getB()) + .dataFormat(RnnDataFormat.TNS) //Always time major for TF BlockLSTM + .build(); + addIArgument(configuration.iArgs(true)); + addTArgument(configuration.tArgs()); + } + + @Override + public String opName() { + return "lstmBlock"; + } + + @Override + public Map propertiesForFunction() { + return configuration.toProperties(true); + } + + @Override + public String tensorflowName() { + return "BlockLSTM"; + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java index 59b85f500..a433b23d6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java @@ -1,5 +1,5 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -13,7 +13,6 @@ * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - package org.nd4j.linalg.api.ops.impl.layers.recurrent; import lombok.Getter; @@ -24,89 +23,103 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights; +import org.nd4j.shade.guava.primitives.Booleans; +import javax.xml.crypto.Data; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; + /** * LSTM layer implemented as a single operation. * Implementation of operation for LSTM layer with optional peep hole connections.
* S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation and https://research.google.com/pubs/archive/43905.pdf
* Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014.
* See also: https://arxiv.org/pdf/1503.04069.pdf
- *

- * See also {@link LSTMBlockCell} - lstmBlockCell op is used internally at C++ level for computation.
- *
* Input arrays:
- * 0: max sequence length; long/int64 scalar
- * 1: input [seqLength, bS, inSize] at time t
- * 2: previous/initial cell state [bS, numUnits]
- * 3: previous/initial output [bS, numUnits]
- * 4: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits]
- * 5: weights - cell peephole (t-1) connections to input modulation gate, [numUnits]
- * 6: weights - cell peephole (t-1) connections to forget gate, [numUnits]
- * 7: weights - cell peephole (t) connections to output gate, [numUnits]
- * 8: biases, shape [4*numUnits]
- *
- * Input integer arguments: set via {@link LSTMConfiguration}
- * 0: if not zero, provide peephole connections
- * 1: Data format - 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; 2=NTS=[mb,seqLen,size]
- *
- * Input float arguments: set via {@link LSTMConfiguration}
- * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training
- * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped
+ * 0: input
+ * [sL, bS, nIn] when dataFormat - TNS
+ * [bS, sL, nIn] when dataFormat - NTS
+ * [bS, nIn, sL] when dataFormat - NST
+ * 1: previous/initial cell state
+ * shapes [bS, nOut] for FWD, BWD Direction Mode
+ * shapes [2, bS, nOut] BIDIR_SUM, BIDIR_CONCAT and BIDIR_EXTRA_DIM Direction Mode
+ * 2: previous/initial output [bS, numUnits]
+ * shapes [bS, nOut] for FWD, BWD Direction Mode
+ * shapes [2, bS, nOut] BIDIR_SUM, BIDIR_CONCAT and BIDIR_EXTRA_DIM Direction Mode
+ * 3: max sequence length [bS]
+ * 4: LSTMLayerWeights - {@link LSTMLayerWeights}
+ * 5: LSTMLayerConfig - {@link LSTMLayerConfig}
*

* Output arrays:
- * 0: i - Input modulation gate activations, rank 3, shape as per dataFormat
- * 1: c (cs) - Cell state (pre tanh), rank 3, shape as per dataFormat
- * 2: f - Output - forget gate activations, rank 3, shape as per dataFormat
- * 3: o - Output - output gate activations, rank 3, shape as per dataFormat
- * 4: z (ci) - Output - block input, rank 3, shape as per dataFormat
- * 5: h (co) - Cell state, post tanh, rank 3, shape as per dataFormat
- * 6: y (h) - Current cell output, rank 3, shape as per dataFormat
- * - * @author Alex Black + * 0: output h - rank 3 or 4, depends on DirectionMode and dataFormat
+ * 1: output at last step hL - rank 3 or 4, depends on DirectionMode and dataFormat
+ * 2: cell state at last step cL - same shape as in hL
*/ public class LSTMLayer extends DynamicCustomOp { - private LSTMConfiguration configuration; + @Getter + private LSTMLayerConfig configuration; @Getter - private LSTMWeights weights; + private LSTMLayerWeights weights; + public LSTMLayer() { } - public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) { - super(null, sameDiff, weights.argsWithInputs(maxTSLength, x, cLast, yLast)); + public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, LSTMLayerWeights weights, LSTMLayerConfig configuration) { + super(null, sameDiff, weights.argsWithInputs(x, maxTSLength, cLast, yLast)); this.configuration = configuration; this.weights = weights; - addIArgument(configuration.iArgs(true)); - addTArgument(configuration.tArgs()); + addIArgument(iArgs()); + addTArgument(tArgs()); + addBArgument(bArgs(weights, maxTSLength, yLast, cLast)); + + Preconditions.checkState(this.configuration.isRetLastH() || this.configuration.isRetLastC() || this.configuration.isRetFullSequence(), + "You have to specify at least one output you want to return. 
Use isRetLastC, isRetLast and isRetFullSequence methods in LSTMLayerConfig builder to specify them"); + + } - public LSTMLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMWeights lstmWeights, LSTMConfiguration lstmConfiguration) { + public LSTMLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMLayerWeights lstmWeights, LSTMLayerConfig LSTMLayerConfig) { super(null, null, lstmWeights.argsWithInputs(maxTSLength, x, cLast, yLast)); - this.configuration = lstmConfiguration; + this.configuration = LSTMLayerConfig; this.weights = lstmWeights; - addIArgument(configuration.iArgs(true)); - addTArgument(configuration.tArgs()); + addIArgument(iArgs()); + addTArgument(tArgs()); + addBArgument(bArgs(weights, maxTSLength, yLast, cLast)); + + Preconditions.checkState(this.configuration.isRetLastH() || this.configuration.isRetLastC() || this.configuration.isRetFullSequence(), + "You have to specify at least one output you want to return. Use isRetLastC, isRetLast and isRetFullSequence methods in LSTMLayerConfig builder to specify them"); } @Override public List calculateOutputDataTypes(List inputDataTypes) { - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 9, "Expected exactly 9 inputs to LSTMLayer, got %s", inputDataTypes); + Preconditions.checkState(inputDataTypes != null && 3 <= inputDataTypes.size() && inputDataTypes.size() <= 8, "Expected amount of inputs to LSTMLayer between 3 inputs minimum (input, Wx, Wr only) or 8 maximum, got %s", inputDataTypes); //7 outputs, all of same type as input. 
Note that input 0 is max sequence length (int64), input 1 is actual input DataType dt = inputDataTypes.get(1); + ArrayList list = new ArrayList<>(); + if (configuration.isRetFullSequence()) { + + list.add(dt); + } + + if (configuration.isRetLastC()) { + + list.add(dt); + } + if (configuration.isRetLastH()){ + + list.add(dt); + } + Preconditions.checkState(dt.isFPType(), "Input type 1 must be a floating point type, got %s", dt); - return Arrays.asList(dt, dt, dt, dt, dt, dt, dt); + return list; } @Override @@ -114,31 +127,61 @@ public class LSTMLayer extends DynamicCustomOp { throw new UnsupportedOperationException("Not yet implemented"); } - @Override - public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - configuration = LSTMConfiguration.builder() - .forgetBias(attributesForNode.get("forget_bias").getF()) - .clippingCellValue(attributesForNode.get("cell_clip").getF()) - .peepHole(attributesForNode.get("use_peephole").getB()) - .dataFormat(RnnDataFormat.TNS) //Always time major for TF BlockLSTM - .build(); - addIArgument(configuration.iArgs(true)); - addTArgument(configuration.tArgs()); - } @Override public String opName() { - return "lstmBlock"; + return "lstmLayer"; } @Override public Map propertiesForFunction() { - return configuration.toProperties(true); + return configuration.toProperties(true, true); + } + + + public long[] iArgs() { + return new long[]{ + configuration.getLstmdataformat().ordinal(),// INT_ARG(0) + configuration.getDirectionMode().ordinal(), // INT_ARG(1) + configuration.getGateAct().ordinal(), // INT_ARG(2) + configuration.getOutAct().ordinal(), // INT_ARG(3) + configuration.getCellAct().ordinal() // INT_ARG(4) + + }; + } + + public double[] tArgs() { + return new double[]{this.configuration.getCellClip()}; // T_ARG(0) + } + + + public boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) { + return new boolean[]{ + weights.hasBias(), // hasBiases: B_ARG(0) + 
maxTSLength != null, // hasSeqLen: B_ARG(1) + yLast != null, // hasInitH: B_ARG(2) + cLast != null, // hasInitC: B_ARG(3) + weights.hasPH(), // hasPH: B_ARG(4) + configuration.isRetFullSequence(), //retFullSequence: B_ARG(5) + configuration.isRetLastH(), // retLastH: B_ARG(6) + configuration.isRetLastC() // retLastC: B_ARG(7) + }; + } @Override - public String tensorflowName() { - return "BlockLSTM"; + public int getNumOutputs(){ + + return Booleans.countTrue( + configuration.isRetFullSequence(), //retFullSequence: B_ARG(5) + configuration.isRetLastH(), // retLastH: B_ARG(6) + configuration.isRetLastC() // retLastC: B_ARG(7) + ); } + + + } + + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMActivations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMActivations.java new file mode 100644 index 000000000..27ebbc82f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMActivations.java @@ -0,0 +1,48 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; + + /** + * integer numbers corresponding to activations: + * 0=tanh, + * 1=relu, + * 2=sigmoid, + * 3=affine, + * 4=leaky relu, + * 5= thresholded relu, + * 6=scaled tanh, + * 7=hard sigmoid, + * 8=ELU, + * 9=softsign, + * 10=softplus + */ + public enum LSTMActivations { + //Note: ordinal (order) here matters for C++ level. Any new formats hsould be added at end + + TANH, + RELU, + SIGMOID, + AFFINE, + LEAKY_RELU, + THRESHHOLD_RELU, + SCALED_TAHN, + HARD_SIGMOID, + ELU, + SOFTSIGN, + SOFTPLUS + + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMDataFormat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMDataFormat.java new file mode 100644 index 000000000..788e87d59 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMDataFormat.java @@ -0,0 +1,41 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; + + /** + * notations
+ * for unidirectional: + * TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"
+ * NST: shape [numExamples, inOutSize, timeLength]
+ * NTS: shape [numExamples, timeLength, inOutSize]
+ * for bidirectional: + * T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) + */ + + public enum LSTMDataFormat { + //Note: ordinal (order) here matters for C++ level. Any new formats hsould be added at end + + + TNS, + NTS, + NST, + T2NS + + } + + + + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMDirectionMode.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMDirectionMode.java new file mode 100644 index 000000000..c93bc05f9 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMDirectionMode.java @@ -0,0 +1,38 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; + +/** + * direction
+ * FWD: 0 = fwd + * BWD: 1 = bwd + * BIDIR_SUM: 2 = bidirectional sum + * BIDIR_CONCAT: 3 = bidirectional concat + * BIDIR_EXTRA_DIM: 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) */ + +// const auto directionMode = INT_ARG(1); // direction: + +public enum LSTMDirectionMode { + //Note: ordinal (order) here matters for C++ level. Any new formats hsould be added at end + + + FWD, + BWD, + BIDIR_SUM, + BIDIR_CONCAT, + BIDIR_EXTRA_DIM + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java new file mode 100644 index 000000000..9901213da --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java @@ -0,0 +1,119 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; + +import lombok.Builder; +import lombok.Data; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; + +import java.util.LinkedHashMap; +import java.util.Map; + + +@Builder +@Data +public class LSTMLayerConfig { + + + /** + * notations
+ * for unidirectional: + * TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"
+ * NST: shape [numExamples, inOutSize, timeLength]
+ * NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout
+ * for bidirectional: + * T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) + */ + @Builder.Default + private LSTMDataFormat lstmdataformat = LSTMDataFormat.TNS; //INT_ARG(0) + + + /** + * direction
+ * FWD: 0 = fwd + * BWD: 1 = bwd + * BS: 2 = bidirectional sum + * BC: 3 = bidirectional concat + * BE: 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) + */ + @Builder.Default + private LSTMDirectionMode directionMode = LSTMDirectionMode.FWD; //INT_ARG(1) + + /** + * Activation for input (i), forget (f) and output (o) gates + */ + @Builder.Default + private LSTMActivations gateAct = LSTMActivations.SIGMOID; // INT_ARG(2) + + @Builder.Default + private LSTMActivations cellAct = LSTMActivations.TANH; // INT_ARG(3) + + @Builder.Default + private LSTMActivations outAct = LSTMActivations.TANH; // INT_ARG(4) + + + + + /** + * indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} + */ + @Builder.Default + private boolean retFullSequence = true; //B_ARG(5) + + /** + * indicates whether to return output at last time step only, + * in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + */ + private boolean retLastH; //B_ARG(6) + + /** + * indicates whether to return cells state at last time step only, + * in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + */ + private boolean retLastC; // B_ARG(7) + + /** + * Cell clipping value, if it = 0 then do not apply clipping + */ + @Builder.Default + private double cellClip; //T_ARG(0) + + + public Map toProperties(boolean includeLSTMDataFormat, boolean includeLSTMDirectionMode) { + Map ret = new LinkedHashMap<>(); + ret.put("gateAct", gateAct.ordinal()); + ret.put("outAct", outAct.ordinal()); + ret.put("cellAct", cellAct.ordinal()); + ret.put("retFullSequence", retFullSequence); + ret.put("retLastH", retLastH); + ret.put("retLastC", retLastC); + ret.put("cellClip", cellClip); + + if (includeLSTMDataFormat) + ret.put("LSTMDataFormat", lstmdataformat.ordinal()); + if (includeLSTMDirectionMode) + ret.put("LSTMDirectionMode", directionMode.ordinal()); + return ret; + } + +} + + + + + + diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java index a01be219f..d8a2e6e9a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java @@ -2,13 +2,18 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs; import java.util.Arrays; import java.util.List; + import lombok.AccessLevel; import lombok.Getter; import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat; +import org.nd4j.shade.guava.primitives.Booleans; /** * The outputs of a LSTM layer ({@link LSTMLayer}. @@ -16,165 +21,78 @@ import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat; @Getter public class LSTMLayerOutputs { - private RnnDataFormat dataFormat; + /** + * The LSTM layer data format ({@link LSTMDataFormat}. + */ + private LSTMDataFormat dataFormat; + /** - * Output - input modulation gate activations. - * Shape depends on data format (in layer config):
- * TNS -> [timeSteps, batchSize, numUnits]
- * NST -> [batchSize, numUnits, timeSteps]
- * NTS -> [batchSize, timeSteps, numUnits]
+ * output h: + * [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 + * [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1 + * [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2 + * [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 + * [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1 + * [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2 + * [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 + * numbers mean index in corresponding enums {@link LSTMDataFormat} and {@link LSTMDirectionMode} */ - private SDVariable i; + private SDVariable timeSeriesOutput; /** - * Activations, cell state (pre tanh). - * Shape depends on data format (in layer config):
- * TNS -> [timeSteps, batchSize, numUnits]
- * NST -> [batchSize, numUnits, timeSteps]
- * NTS -> [batchSize, timeSteps, numUnits]
+ * cell state at last step cL: + * [bS, nOut] when directionMode FWD or BWD + * 2, bS, nOut] when directionMode BIDIR_SUM, BIDIR_CONCAT or BIDIR_EXTRA_DIM */ - private SDVariable c; + private SDVariable lastCellStateOutput; /** - * Output - forget gate activations. - * Shape depends on data format (in layer config):
- * TNS -> [timeSteps, batchSize, numUnits]
- * NST -> [batchSize, numUnits, timeSteps]
- * NTS -> [batchSize, timeSteps, numUnits]
+ * output at last step hL: + * [bS, nOut] when directionMode FWD or BWD + * 2, bS, nOut] when directionMode BIDIR_SUM, BIDIR_CONCAT or BIDIR_EXTRA_DIM */ - private SDVariable f; + private SDVariable lastTimeStepOutput; - /** - * Output - output gate activations. - * Shape depends on data format (in layer config):
- * TNS -> [timeSteps, batchSize, numUnits]
- * NST -> [batchSize, numUnits, timeSteps]
- * NTS -> [batchSize, timeSteps, numUnits]
- */ - private SDVariable o; - /** - * Output - input gate activations. - * Shape depends on data format (in layer config):
- * TNS -> [timeSteps, batchSize, numUnits]
- * NST -> [batchSize, numUnits, timeSteps]
- * NTS -> [batchSize, timeSteps, numUnits]
- */ - private SDVariable z; + public LSTMLayerOutputs(SDVariable[] outputs, LSTMLayerConfig lstmLayerConfig) { + Preconditions.checkArgument(outputs.length > 0 && outputs.length <= 3, + "Must have from 1 to 3 LSTM layer outputs, got %s", outputs.length); - /** - * Cell state, post tanh. - * Shape depends on data format (in layer config):
- * TNS -> [timeSteps, batchSize, numUnits]
- * NST -> [batchSize, numUnits, timeSteps]
- * NTS -> [batchSize, timeSteps, numUnits]
- */ - private SDVariable h; + int i = 0; + timeSeriesOutput = lstmLayerConfig.isRetFullSequence() ? outputs[i++] : null; + lastTimeStepOutput = lstmLayerConfig.isRetLastH() ? outputs[i++] : null; + lastCellStateOutput = lstmLayerConfig.isRetLastC() ? outputs[i++] : null; - /** - * Current cell output. - * Shape depends on data format (in layer config):
- * TNS -> [timeSteps, batchSize, numUnits]
- * NST -> [batchSize, numUnits, timeSteps]
- * NTS -> [batchSize, timeSteps, numUnits]
- */ - private SDVariable y; - public LSTMLayerOutputs(SDVariable[] outputs, RnnDataFormat dataFormat){ - Preconditions.checkArgument(outputs.length == 7, - "Must have 7 LSTM layer outputs, got %s", outputs.length); - - i = outputs[0]; - c = outputs[1]; - f = outputs[2]; - o = outputs[3]; - z = outputs[4]; - h = outputs[5]; - y = outputs[6]; - this.dataFormat = dataFormat; + this.dataFormat = lstmLayerConfig.getLstmdataformat(); } - /** - * Get all outputs returned by the cell. - */ - public List getAllOutputs(){ - return Arrays.asList(i, c, f, o, z, h, y); - } /** - * Get y, the output of the cell for all time steps. - * - * Shape depends on data format (in layer config):
- * TNS -> [timeSteps, batchSize, numUnits]
- * NST -> [batchSize, numUnits, timeSteps]
- * NTS -> [batchSize, timeSteps, numUnits]
+ * Get h, the output of the cell for all time steps. + *

+ * Shape depends on data format defined in {@link LSTMLayerConfig }:
+ * for unidirectional: + * TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"
+ * NST: shape [numExamples, inOutSize, timeLength]
+ * NTS: shape [numExamples, timeLength, inOutSize]
+ * for bidirectional: + * T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) */ - public SDVariable getOutput(){ - return y; + public SDVariable getOutput() { + Preconditions.checkArgument(timeSeriesOutput != null, "retFullSequence was setted as false in LSTMLayerConfig"); + return timeSeriesOutput; } - /** - * Get c, the cell's state for all time steps. - * - * Shape depends on data format (in layer config):
- * TNS -> [timeSteps, batchSize, numUnits]
- * NST -> [batchSize, numUnits, timeSteps]
- * NTS -> [batchSize, timeSteps, numUnits]
- */ - public SDVariable getState(){ - return c; + public SDVariable getLastState() { + Preconditions.checkArgument(lastCellStateOutput != null, "retLastC was setted as false in LSTMLayerConfig"); + return lastCellStateOutput; } - private SDVariable lastOutput = null; - - /** - * Get y, the output of the cell, for the last time step. - * - * Has shape [batchSize, numUnits]. - */ - public SDVariable getLastOutput(){ - if(lastOutput != null) - return lastOutput; - - switch (dataFormat){ - case TNS: - lastOutput = getOutput().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all()); - break; - case NST: - lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1)); - break; - case NTS: - lastOutput = getOutput().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all()); - break; - } - return lastOutput; + public SDVariable getLastOutput() { + Preconditions.checkArgument(lastTimeStepOutput != null, "retLastH was setted as false in LSTMLayerConfig"); + return lastTimeStepOutput; } - private SDVariable lastState = null; - - /** - * Get c, the state of the cell, for the last time step. - * - * Has shape [batchSize, numUnits]. 
- */ - public SDVariable getLastState(){ - if(lastState != null) - return lastState; - - switch (dataFormat){ - case TNS: - lastState = getState().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all()); - break; - case NST: - lastState = getState().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1)); - break; - case NTS: - lastState = getState().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all()); - break; - } - return lastState; - } - - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMLayerWeights.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMLayerWeights.java new file mode 100644 index 000000000..98985df57 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMLayerWeights.java @@ -0,0 +1,99 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights; + + +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; +import org.nd4j.linalg.util.ArrayUtil; + +/** + * The weight configuration of a LSTMLayer. For {@link LSTMLayer} + * @author Alex Black + */ +@EqualsAndHashCode(callSuper = true) +@Data +@Builder +public class LSTMLayerWeights extends RNNWeights { + + /** + * Input to hidden weights with a shape of [inSize, 4*numUnits]. + * + * Input to hidden and hidden to hidden are concatenated in dimension 0, + * so the input to hidden weights are [:inSize, :] and the hidden to hidden weights are [inSize:, :]. + */ + private SDVariable weights; + private INDArray iWeights; + + /** + * hidden to hidden weights (aka "recurrent weights", with a shape of [numUnits, 4*numUnits]. + * + */ + private SDVariable rWeights; + private INDArray irWeights; + + /** + * Peephole weights, with a shape of [3*numUnits]. + */ + private SDVariable peepholeWeights; + private INDArray iPeepholeWeights; + + /** + * Input to hidden and hidden to hidden biases, with shape [4*numUnits]. + */ + private SDVariable bias; + private INDArray iBias; + + @Override + public SDVariable[] args() { + return filterNonNull(weights, rWeights, peepholeWeights, bias); + } + + @Override + public INDArray[] arrayArgs() { + return filterNonNull(iWeights, irWeights, iPeepholeWeights, iBias); + } + + @Override + public SDVariable[] argsWithInputs(SDVariable... 
inputs){ + Preconditions.checkArgument(inputs.length == 4, "Expected 4 inputs, got %s", inputs.length); //Order: x, seqLen, yLast, cLast + //lstmLayer c++ op expects: x, Wx, Wr, Wp, b, seqLen, yLast, cLast + return ArrayUtil.filterNull(inputs[0], weights, rWeights, bias, inputs[1], inputs[2], inputs[3], peepholeWeights); + } + + @Override + public INDArray[] argsWithInputs(INDArray... inputs) { + Preconditions.checkArgument(inputs.length == 4, "Expected 4 inputs, got %s", inputs.length); //Order: x, seqLen, yLast, cLast + //lstmLayer c++ op expects: x, Wx, Wr, Wp, b, seqLen, yLast, cLast + return ArrayUtil.filterNull(inputs[0], iWeights, irWeights, iBias, inputs[1], inputs[2], inputs[3], iPeepholeWeights); + } + + + public boolean hasBias() { + return (bias!=null||iBias!=null); + } + + public boolean hasPH() { + return (peepholeWeights!=null||iPeepholeWeights!=null); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java index d22478e71..30ca8ebc5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java @@ -98,6 +98,7 @@ public class Mmul extends DynamicCustomOp { addIArgument(ArrayUtil.fromBoolean(transposeX), ArrayUtil.fromBoolean(transposeY), ArrayUtil.fromBoolean(transposeZ)); + mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build(); } public Mmul(INDArray x, INDArray y) { @@ -110,6 +111,7 @@ public class Mmul extends DynamicCustomOp { addIArgument(ArrayUtil.fromBoolean(transposeX), ArrayUtil.fromBoolean(transposeY), ArrayUtil.fromBoolean(transposeZ)); + mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build(); } 
public Mmul() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/BatchMmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/BatchMmul.java index 7bccc8035..97fd5e538 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/BatchMmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/BatchMmul.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; @@ -49,6 +50,9 @@ public class BatchMmul extends DynamicCustomOp { protected int N; protected int K; + public BatchMmul(SameDiff sameDiff, SDVariable[] matricesA, SDVariable[] matricesB, boolean transposeA, boolean transposeB) { + this(sameDiff, ArrayUtils.addAll(matricesA, matricesB), transposeA, transposeB); + } public BatchMmul(SameDiff sameDiff, SDVariable[] matrices, @@ -85,6 +89,22 @@ public class BatchMmul extends DynamicCustomOp { addArgs(); } + public BatchMmul(INDArray[] matricesA, INDArray[] matricesB, boolean transposeA, boolean transposeB){ + super(ArrayUtils.addAll(matricesA, matricesB), null); + this.batchSize = matricesA.length; + + this.transposeA = transposeA ? 1 : 0; + this.transposeB = transposeB ? 1 : 0; + + long[] firstShape = matricesA[0].shape(); + long[] lastShape = matricesB[0].shape(); + + this.M = transposeA ? (int) firstShape[1]: (int) firstShape[0]; + this.N = transposeA ? (int) firstShape[0]: (int) firstShape[1]; + this.K = transposeB ? 
(int) lastShape[0]: (int) lastShape[1]; + addArgs(); + } + @Override public int getNumOutputs(){ return batchSize; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java index a239bd9ec..593531098 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java @@ -34,17 +34,12 @@ import java.util.List; @NoArgsConstructor public class GatherNd extends DynamicCustomOp { - public GatherNd(SameDiff sameDiff, SDVariable[] inputs, SDVariable[] indices) { - super(null, sameDiff, ArrayUtils.addAll(inputs, indices), false); + public GatherNd(SameDiff sameDiff, SDVariable input, SDVariable indices) { + super(null, sameDiff, new SDVariable[] {input, indices}); } - public GatherNd(SameDiff sameDiff, SDVariable input, SDVariable indices, boolean inPlace) { - super(null, sameDiff, new SDVariable[] {input, indices}, inPlace); - } - - public GatherNd(INDArray[] df, INDArray[] indices) { - addInputArgument(df); - addInputArgument(indices); + public GatherNd(INDArray df, INDArray indices) { + super(new INDArray[]{df, indices}, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java index 6fca99eae..4bc3b3f63 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java @@ -16,13 +16,16 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NonNull; import 
org.apache.commons.lang3.NotImplementedException; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -41,21 +44,27 @@ public class Linspace extends DynamicCustomOp { private DataType dataType; public Linspace(SameDiff sameDiff, DataType dataType, double start, double stop, long number) { - super(sameDiff, new SDVariable[0]); - addTArgument(start,stop); - addIArgument(number); - addDArgument(dataType); + this(sameDiff, sameDiff.constant(start), sameDiff.constant(stop), sameDiff.constant(number), dataType); } public Linspace(SameDiff sameDiff, SDVariable from, SDVariable to, SDVariable length, DataType dataType){ super(sameDiff, new SDVariable[]{from, to, length}); this.dataType = dataType; + addDArgument(dataType); } public Linspace(DataType dataType, double start, double stop, long number) { + this(dataType, Nd4j.scalar(start), Nd4j.scalar(stop), Nd4j.scalar(number)); + } + + public Linspace(DataType dataType, INDArray start, INDArray stop, INDArray number) { + this(start, stop, number, dataType); + } + + public Linspace(@NonNull INDArray start, @NonNull INDArray stop, @NonNull INDArray number, @NonNull DataType dataType) { + super(new INDArray[]{start, stop, number}, null); + this.dataType = dataType; addDArgument(dataType); - addTArgument(start, stop); - addIArgument(number); } public Linspace(){ } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MeshGrid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MeshGrid.java 
index f2c11f1ef..ce83a808c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MeshGrid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MeshGrid.java @@ -16,9 +16,11 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.ArrayList; @@ -41,6 +43,11 @@ public class MeshGrid extends DynamicCustomOp { this(sd, cartesian, inputs); } + public MeshGrid(@NonNull INDArray[] inputs, boolean cartesian){ + super(inputs, null); + addIArgument(cartesian ? 1 : 0); + } + public MeshGrid(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java index 6feace53f..2126dfe27 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java @@ -44,7 +44,6 @@ import java.util.Map; public class Reshape extends DynamicCustomOp { private long[] shape; - private String arrName; public Reshape(SameDiff sameDiff, SDVariable i_v, long[] shape) { super(null, sameDiff, new SDVariable[]{i_v}); @@ -56,6 +55,12 @@ public class Reshape extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{i_v, shape}); } + public Reshape(INDArray in, long... 
shape){ + super(new INDArray[]{in}, null); + this.shape = shape; + addIArgument(shape); + } + public Reshape(INDArray in, INDArray shape){ this(in, shape, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java index a2ca91c65..3c3baf1f6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -64,15 +65,19 @@ public class SequenceMask extends DynamicCustomOp { addDArgument(dataType); } - public SequenceMask(INDArray input, int maxLen, DataType dataType) { + public SequenceMask(@NonNull INDArray input, int maxLen, DataType dataType) { addInputArgument(input); addIArgument(maxLen); this.dataType = dataType; addDArgument(dataType); } - public SequenceMask(INDArray input, DataType dataType) { - addInputArgument(input); + public SequenceMask(@NonNull INDArray input, @NonNull DataType dataType) { + this(input, null, dataType); + } + + public SequenceMask(@NonNull INDArray input, INDArray maxLength, @NonNull DataType dataType) { + super(wrapFilterNull(input, maxLength), null); this.dataType = dataType; addDArgument(dataType); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java index c5f7cdd70..46b8f6286 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java @@ -59,6 +59,10 @@ public class Slice extends DynamicCustomOp { addIArgument(size); } + public Slice(@NonNull INDArray input, @NonNull INDArray begin, @NonNull INDArray end){ + super(new INDArray[]{input, begin, end}, null); + } + @Override public String opName() { return "slice"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java index 89c459be3..17a8beb3c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java @@ -50,7 +50,7 @@ public class Stack extends DynamicCustomOp { addArgs(); } - public Stack(INDArray input, int axis) { + public Stack(INDArray[] input, int axis) { addInputArgument(input); this.jaxis = axis; addArgs(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java index a053403af..456edfe1c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java @@ -98,10 +98,16 @@ public class StridedSlice extends DynamicCustomOp { public StridedSlice(INDArray in, int[] begin, int[] end, int[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { + this(in, ArrayUtil.toLongArray(begin), ArrayUtil.toLongArray(end), 
ArrayUtil.toLongArray(strides), + beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); + } + + public StridedSlice(INDArray in, long[] begin, long[] end, long[] strides, int beginMask, + int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { addInputArgument(in); - this.begin = ArrayUtil.toLongArray(begin); - this.end = ArrayUtil.toLongArray(end); - this.strides = ArrayUtil.toLongArray(strides); + this.begin = begin; + this.end = end; + this.strides = strides; this.beginMask = beginMask; this.endMask = endMask; this.ellipsisMask = ellipsisMask; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java index 6f9e94de0..b8d952b37 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NonNull; import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; @@ -67,6 +68,13 @@ public class Unstack extends DynamicCustomOp { addArgs(); } + public Unstack(@NonNull INDArray value, int axis, int num){ + super(new INDArray[]{value}, null); + this.jaxis = axis; + this.num = num; + addArgs(); + } + public Unstack(INDArray in, INDArray[] out, int axis){ super(null, new INDArray[]{in}, out, null, (int[])null); this.jaxis = axis; @@ -136,7 +144,8 @@ public class Unstack extends DynamicCustomOp { @Override public List doDiff(List f1) { - return Collections.singletonList(sameDiff.stack(jaxis, f1.toArray(new SDVariable[f1.size()]))); + return Collections.singletonList(sameDiff.stack(jaxis, f1.toArray(new SDVariable[0]))); + } @Override diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java index 8d0a9d0d6..b7bd0e0f6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java @@ -58,6 +58,10 @@ public class Pad extends DynamicCustomOp { this(sd, in, padding, Mode.CONSTANT, padValue); } + public Pad(@NonNull INDArray in, @NonNull INDArray padding, double padValue){ + this(in, padding, null, Mode.CONSTANT, padValue); + } + public Pad(@NonNull INDArray in, @NonNull INDArray padding, INDArray out, @NonNull Mode mode, double padValue){ super(null, new INDArray[]{in, padding}, out == null ? null : new INDArray[]{out}); Preconditions.checkState(padding.dataType().isIntType(), "Padding array must be an integer datatype, got %s", padding.dataType()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java index b64581b49..3efc13af0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java @@ -66,11 +66,8 @@ public class DynamicPartition extends DynamicCustomOp { addArgs(); } - public DynamicPartition(INDArray input, INDArray[] partitions, int numPartitions) { + public DynamicPartition(INDArray input, INDArray partitions, int numPartitions) { addInputArgument(input); - for (INDArray part : partitions) - addInputArgument(part); - addIArgument(numPartitions); 
} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ListDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ListDiff.java index d1b2fdfdb..880717fb5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ListDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ListDiff.java @@ -16,9 +16,11 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -30,10 +32,14 @@ public class ListDiff extends DynamicCustomOp { // } - public ListDiff(SameDiff sd, SDVariable x, SDVariable y){ + public ListDiff(@NonNull SameDiff sd, @NonNull SDVariable x, @NonNull SDVariable y){ super(sd, new SDVariable[]{x, y}); } + public ListDiff(@NonNull INDArray x, @NonNull INDArray y){ + super(new INDArray[]{x, y}, null); + } + @Override public String tensorflowName() { return "ListDiff"; //Note: Seems to be renamed to tf.setdiff1d in public API? 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/XwPlusB.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/XwPlusB.java index 1ae979ec7..7b8dbf209 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/XwPlusB.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/XwPlusB.java @@ -73,12 +73,8 @@ public class XwPlusB extends DynamicCustomOp { SDVariable dLdOut = gradient.get(0); SDVariable dLdb = dLdOut.sum(0); - SDVariable dLdIn = sameDiff.mmul(dLdOut, w, MMulTranspose.builder() - .transposeB(true) - .build()); - SDVariable dLdW = sameDiff.mmul(in, dLdOut, MMulTranspose.builder() - .transposeA(true) - .build()); + SDVariable dLdIn = sameDiff.mmul(dLdOut, w, false, true, false); + SDVariable dLdW = sameDiff.mmul(in, dLdOut, true, false, false); return Arrays.asList(dLdIn, dLdW, dLdb); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java index a7dd7c2b9..d588ef4a8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java @@ -28,6 +28,7 @@ import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; import org.tensorflow.framework.AttrValue; import 
org.tensorflow.framework.GraphDef; @@ -55,24 +56,11 @@ public class Cast extends BaseDynamicTransformOp { addArgs(); } -/* - @Override - public void setValueFor(Field target, Object value) { - if(value == null) { - throw new ND4JIllegalStateException("Unable to set field " + target + " using null value!"); - } - - // FIXME! - if (!(value instanceof DataType)) - return; - - try { - target.set(this, (DataType) value); - } catch (IllegalAccessException e) { - e.printStackTrace(); - } + public Cast(@NonNull INDArray arg, @NonNull DataType dataType){ + super(new INDArray[]{arg}, null); + this.typeDst = dataType; + addArgs(); } - */ @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java index ade01281c..b024659c0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; import org.tensorflow.framework.AttrValue; @@ -73,6 +74,12 @@ public class Range extends DynamicCustomOp { addDArgument(dataType); } + public Range(INDArray from, INDArray to, INDArray step, DataType dataType){ + super(new INDArray[]{from, to, step}, null); + this.dataType = dataType; + addDArgument(dataType); + } + @Override public int opNum() { diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java index cfaf00d18..64e6a96b1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java @@ -149,6 +149,60 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(in, false, dimensions)); } + /** + * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
+ * length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
+ * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.
+ * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).
+ *
+ * The result of this operation will be a batch of multiplied matrices. The
+ * result has the same length as both input batches and each output matrix is of shape (M, K).
+ * + * @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type) + * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) + * @param transposeA Whether to transpose A arrays or not + * @param transposeB Whether to transpose B arrays or not + */ + public INDArray[] batchMmul(INDArray[] inputsA, INDArray[] inputsB, boolean transposeA, + boolean transposeB) { + NDValidation.validateNumerical("batchMmul", "inputsA", inputsA); + Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); + NDValidation.validateNumerical("batchMmul", "inputsB", inputsB); + Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(inputsA, inputsB, transposeA, transposeB)); + } + + /** + * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
+ * length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
+ * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.
+ * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).
+ *
+ * The result of this operation will be a batch of multiplied matrices. The
+ * result has the same length as both input batches and each output matrix is of shape (M, K).
+ * + * @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type) + * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) + */ + public INDArray[] batchMmul(INDArray[] inputsA, INDArray... inputsB) { + NDValidation.validateNumerical("batchMmul", "inputsA", inputsA); + Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); + NDValidation.validateNumerical("batchMmul", "inputsB", inputsB); + Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(inputsA, inputsB, false, false)); + } + + /** + * Cast the array to a new datatype - for example, Integer -> Float
+ * + * @param arg Input variable to cast (NDARRAY type) + * @param datatype Datatype to cast to + * @return output Output array (after casting) (NDARRAY type) + */ + public INDArray castTo(INDArray arg, DataType datatype) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast(arg, datatype))[0]; + } + /** * Concatenate a set of inputs along the specified dimension.
* Note that inputs must have identical rank and identical dimensions, other than the dimension to stack on.
@@ -161,7 +215,7 @@ public class NDBase { * @param dimension Dimension to concatenate on * @return output (NUMERIC type) */ - public INDArray concat(INDArray[] inputs, int dimension) { + public INDArray concat(int dimension, INDArray... inputs) { NDValidation.validateNumerical("concat", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); Preconditions.checkArgument(isSameType(inputs), "Input arrays must all be the same datatype"); @@ -274,28 +328,26 @@ public class NDBase { * @param x Input variable (NUMERIC type) * @param partitions 1D input with values 0 to numPartitions-1 (INT type) * @param numPartitions Number of partitions, >= 1 - * @return output Output variables (equal in number to numPartitions) (NUMERIC type) */ - public INDArray dynamicPartition(INDArray x, INDArray[] partitions, int numPartitions) { + public INDArray[] dynamicPartition(INDArray x, INDArray partitions, int numPartitions) { NDValidation.validateNumerical("dynamicPartition", "x", x); NDValidation.validateInteger("dynamicPartition", "partitions", partitions); - Preconditions.checkArgument(partitions.length >= 1, "partitions has incorrect size/length. Expected: partitions.length >= 1, got %s", partitions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(x, partitions, numPartitions))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(x, partitions, numPartitions)); } /** * Dynamically merge the specified input arrays into a single array, using the specified indices
* - * @param x Input variables. (NUMERIC type) * @param indices Indices to use when merging. Must be >= 1, same length as input variables (INT type) + * @param x Input variables. (NUMERIC type) * @return output Merged output variable (NUMERIC type) */ - public INDArray dynamicStitch(INDArray[] x, INDArray[] indices) { - NDValidation.validateNumerical("dynamicStitch", "x", x); - Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + public INDArray dynamicStitch(INDArray[] indices, INDArray... x) { NDValidation.validateInteger("dynamicStitch", "indices", indices); Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(x, indices))[0]; + NDValidation.validateNumerical("dynamicStitch", "x", x); + Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(indices, x))[0]; } /** @@ -395,11 +447,9 @@ public class NDBase { * @param indices (NUMERIC type) * @return output (NUMERIC type) */ - public INDArray gatherNd(INDArray[] df, INDArray[] indices) { + public INDArray gatherNd(INDArray df, INDArray indices) { NDValidation.validateNumerical("gatherNd", "df", df); - Preconditions.checkArgument(df.length >= 1, "df has incorrect size/length. Expected: df.length >= 1, got %s", df.length); NDValidation.validateNumerical("gatherNd", "indices", indices); - Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. 
Expected: indices.length >= 1, got %s", indices.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.GatherNd(df, indices))[0]; } @@ -516,6 +566,23 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(dataType, start, stop, number))[0]; } + /** + * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
+ * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
+ * + * @param start Start value (NUMERIC type) + * @param stop Stop value (NUMERIC type) + * @param number Number of values to generate (LONG type) + * @param dataType Data type of the output array + * @return output INDArray with linearly spaced elements (NUMERIC type) + */ + public INDArray linspace(INDArray start, INDArray stop, INDArray number, DataType dataType) { + NDValidation.validateNumerical("linspace", "start", start); + NDValidation.validateNumerical("linspace", "stop", stop); + NDValidation.validateInteger("linspace", "number", number); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(start, stop, number, dataType))[0]; + } + /** * Less than operation: elementwise x < y
* @@ -1071,6 +1138,20 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OnesLike(input, dataType))[0]; } + /** + * Array permutation operation: permute the dimensions according to the specified permutation indices.
+ * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Permute dimensions (INT type) + * @return output Output variable (permuted input) (NUMERIC type) + */ + public INDArray permute(INDArray x, INDArray dimensions) { + NDValidation.validateNumerical("permute", "x", x); + NDValidation.validateInteger("permute", "dimensions", dimensions); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Permute(x, dimensions))[0]; + } + /** * Array permutation operation: permute the dimensions according to the specified permutation indices.
* Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
@@ -1141,6 +1222,24 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType))[0]; } + /** + * Create a new variable with a 1d array, where the values start at from and increment by step
+ * up to (but not including) limit.
+ * For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, 2.0, 2.5]
+ * + * @param from Initial/smallest value (NUMERIC type) + * @param to Largest value (exclusive) (NUMERIC type) + * @param step Step size (NUMERIC type) + * @param dataType + * @return output INDArray with the specified values (NUMERIC type) + */ + public INDArray range(INDArray from, INDArray to, INDArray step, DataType dataType) { + NDValidation.validateNumerical("range", "from", from); + NDValidation.validateNumerical("range", "to", to); + NDValidation.validateNumerical("range", "step", step); + return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType))[0]; + } + /** * Returns the rank (number of dimensions, i.e., length(shape)) of the specified INDArray as a 0D scalar variable
* @@ -1168,6 +1267,21 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(update, from, condition)); } + /** + * Element-wise replace where condition:
+ * out[i] = value if condition(update[i]) is satisfied, or
+ * out[i] = update[i] if condition(update[i]) is NOT satisfied
+ * + * @param update Source array (NUMERIC type) + * @param value Value to set at the output, if the condition is satisfied + * @param condition Condition to check on update array elements + * @return output New array with values replaced where condition is satisfied (NUMERIC type) + */ + public INDArray replaceWhere(INDArray update, double value, Condition condition) { + NDValidation.validateNumerical("replaceWhere", "update", update); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet(update, value, condition)); + } + /** * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
* input, but with the specified shape.
@@ -1183,6 +1297,21 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape))[0]; } + /** + * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
+ * input, but with the specified shape.
+ * Note that prod(shape) must match length(input) == prod(input.shape)
+ * + * @param x Input variable (NUMERIC type) + * @param shape New shape for variable (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray reshape(INDArray x, long... shape) { + NDValidation.validateNumerical("reshape", "x", x); + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape))[0]; + } + /** * Reverse the values of an array for the specified dimensions
* If input is:
@@ -1532,6 +1661,21 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0]; } + /** + * Generate a sequence mask (with values 0 or 1) based on the specified lengths
+ * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
+ * + * @param lengths Lengths of the sequences (NUMERIC type) + * @param maxLen Maximum sequence length (INT type) + * @param dataType + * @return output Output variable (NUMERIC type) + */ + public INDArray sequenceMask(INDArray lengths, INDArray maxLen, DataType dataType) { + NDValidation.validateNumerical("sequenceMask", "lengths", lengths); + NDValidation.validateInteger("sequenceMask", "maxLen", maxLen); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0]; + } + /** * see sequenceMask(String, SDVariable, SDVariable, DataType)
* @@ -1601,6 +1745,28 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size))[0]; } + /** + * Get a subset of the specified input, by specifying the first element and the size of the array.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * then slice(input, begin=[0,1], size=[2,1] will return:
+ * [b]
+ * [e]
+ * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
+ * + * @param input input Variable to get subset of (NUMERIC type) + * @param begin Beginning index. Must be same length as rank of input array (INT type) + * @param size Size of the output array. Must be same length as rank of input array (INT type) + * @return output Subset of the input (NUMERIC type) + */ + public INDArray slice(INDArray input, INDArray begin, INDArray size) { + NDValidation.validateNumerical("slice", "input", input); + NDValidation.validateInteger("slice", "begin", begin); + NDValidation.validateInteger("slice", "size", size); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size))[0]; + } + /** * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)
* @@ -1668,7 +1834,8 @@ public class NDBase { * @param axis Axis to stack on * @return output Output variable (NDARRAY type) */ - public INDArray stack(INDArray values, int axis) { + public INDArray stack(int axis, INDArray... values) { + Preconditions.checkArgument(values.length >= 1, "values has incorrect size/length. Expected: values.length >= 1, got %s", values.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Stack(values, axis))[0]; } @@ -1737,7 +1904,7 @@ public class NDBase { * @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is removed at this point. Note that begin/end/stride values must result in a size 1 output for these dimensions * @return output A subset of the input array (NUMERIC type) */ - public INDArray stridedSlice(INDArray in, int[] begin, int[] end, int[] strides, int beginMask, + public INDArray stridedSlice(INDArray in, long[] begin, long[] end, long[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { NDValidation.validateNumerical("stridedSlice", "in", in); Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); @@ -1762,7 +1929,7 @@ public class NDBase { * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) * @return output A subset of the input array (NUMERIC type) */ - public INDArray stridedSlice(INDArray in, int[] begin, int[] end, int... strides) { + public INDArray stridedSlice(INDArray in, long[] begin, long[] end, long... strides) { NDValidation.validateNumerical("stridedSlice", "in", in); Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. 
Expected: end.length >= 1, got %s", end.length); @@ -1999,6 +2166,21 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, numSegments))[0]; } + /** + * Unstack a variable of rank X into N rank X-1 variables by taking slices along the specified axis.
+ * If input has shape [a,b,c] then output has shape:
+ * axis = 0: [b,c]
+ * axis = 1: [a,c]
+ * axis = 2: [a,b]
+ * + * @param value Input variable to unstack (NDARRAY type) + * @param axis Axis to unstack on + * @param num Number of output variables + */ + public INDArray[] unstack(INDArray value, int axis, int num) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Unstack(value, axis, num)); + } + /** * Variance array reduction operation, optionally along specified dimensions
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java index cb00a28c2..1e3c89111 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.factory.ops; import static org.nd4j.linalg.factory.NDValidation.isSameType; import org.nd4j.base.Preconditions; +import org.nd4j.enums.DataFormat; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; @@ -32,7 +33,6 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; import org.nd4j.linalg.factory.NDValidation; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.enums.DataFormat; public class NDCNN { public NDCNN() { @@ -370,6 +370,18 @@ public class NDCNN { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization(input, LocalResponseNormalizationConfig))[0]; } + /** + * 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices
+ * + * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param Pooling2DConfig Configuration Object + */ + public INDArray[] maxPoolWithArgmax(INDArray input, Pooling2DConfig Pooling2DConfig) { + NDValidation.validateNumerical("maxPoolWithArgmax", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax(input, Pooling2DConfig)); + } + /** * 2D Convolution layer operation - max pooling 2d
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java index cdee59ea1..184f3edea 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java @@ -222,15 +222,12 @@ public class NDLoss { * * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param epsilon epsilon * @return output Log loss (NUMERIC type) */ - public INDArray logLoss(INDArray label, INDArray predictions, INDArray weights, double epsilon) { + public INDArray logLoss(INDArray label, INDArray predictions) { NDValidation.validateNumerical("logLoss", "label", label); NDValidation.validateNumerical("logLoss", "predictions", predictions); - NDValidation.validateNumerical("logLoss", "weights", weights); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, epsilon))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogLoss(label, predictions, null, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0))[0]; } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java index eddbe3db7..bee0da889 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java @@ -190,6 +190,58 @@ public class NDMath { return Nd4j.exec(new 
org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(x)); } + /** + * Bit shift operation
+ * + * @param x input (NUMERIC type) + * @param shift shift value (NUMERIC type) + * @return output shifted output (NUMERIC type) + */ + public INDArray bitShift(INDArray x, INDArray shift) { + NDValidation.validateNumerical("bitShift", "x", x); + NDValidation.validateNumerical("bitShift", "shift", shift); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(x, shift))[0]; + } + + /** + * Right bit shift operation
+ * + * @param x Input tensor (NUMERIC type) + * @param shift shift argument (NUMERIC type) + * @return output shifted output (NUMERIC type) + */ + public INDArray bitShiftRight(INDArray x, INDArray shift) { + NDValidation.validateNumerical("bitShiftRight", "x", x); + NDValidation.validateNumerical("bitShiftRight", "shift", shift); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(x, shift))[0]; + } + + /** + * Cyclic bit shift operation
+ * + * @param x Input tensor (NUMERIC type) + * @param shift shift argument (NUMERIC type) + * @return output shifted output (NUMERIC type) + */ + public INDArray bitShiftRotl(INDArray x, INDArray shift) { + NDValidation.validateNumerical("bitShiftRotl", "x", x); + NDValidation.validateNumerical("bitShiftRotl", "shift", shift); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(x, shift))[0]; + } + + /** + * Cyclic right shift operation
+ * + * @param x Input tensor (NUMERIC type) + * @param shift Shift argument (NUMERIC type) + * @return output Shifted output (NUMERIC type) + */ + public INDArray bitShiftRotr(INDArray x, INDArray shift) { + NDValidation.validateNumerical("bitShiftRotr", "x", x); + NDValidation.validateNumerical("bitShiftRotr", "shift", shift); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(x, shift))[0]; + } + /** * Element-wise ceiling function: out = ceil(x).
* Rounds each value up to the nearest integer value (if not already an integer)
@@ -346,13 +398,13 @@ public class NDMath { * * @param x Input variable x (NUMERIC type) * @param y Input variable y (NUMERIC type) - * @param dimensions Dimensions to calculate cosineDistance over (Size: AtLeast(min=1)) + * @param dimensions Dimensions to calculate cosineDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray cosineDistance(INDArray x, INDArray y, int... dimensions) { NDValidation.validateNumerical("cosineDistance", "x", x); NDValidation.validateNumerical("cosineDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(x, y, dimensions)); } @@ -363,13 +415,13 @@ public class NDMath { * * @param x Input variable x (NUMERIC type) * @param y Input variable y (NUMERIC type) - * @param dimensions Dimensions to calculate cosineSimilarity over (Size: AtLeast(min=1)) + * @param dimensions Dimensions to calculate cosineSimilarity over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray cosineSimilarity(INDArray x, INDArray y, int... dimensions) { NDValidation.validateNumerical("cosineSimilarity", "x", x); NDValidation.validateNumerical("cosineSimilarity", "y", y); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. 
Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(x, y, dimensions)); } @@ -501,13 +553,13 @@ public class NDMath { * * @param x Input variable x (NUMERIC type) * @param y Input variable y (NUMERIC type) - * @param dimensions Dimensions to calculate euclideanDistance over (Size: AtLeast(min=1)) + * @param dimensions Dimensions to calculate euclideanDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray euclideanDistance(INDArray x, INDArray y, int... dimensions) { NDValidation.validateNumerical("euclideanDistance", "x", x); NDValidation.validateNumerical("euclideanDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(x, y, dimensions)); } @@ -665,13 +717,13 @@ public class NDMath { * * @param x Input variable x (NUMERIC type) * @param y Input variable y (NUMERIC type) - * @param dimensions Dimensions to calculate hammingDistance over (Size: AtLeast(min=1)) + * @param dimensions Dimensions to calculate hammingDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray hammingDistance(INDArray x, INDArray y, int... dimensions) { NDValidation.validateNumerical("hammingDistance", "x", x); NDValidation.validateNumerical("hammingDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. 
Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(x, y, dimensions)); } @@ -817,13 +869,13 @@ public class NDMath { * * @param x Input variable x (NUMERIC type) * @param y Input variable y (NUMERIC type) - * @param dimensions Dimensions to calculate jaccardDistance over (Size: AtLeast(min=1)) + * @param dimensions Dimensions to calculate jaccardDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray jaccardDistance(INDArray x, INDArray y, int... dimensions) { NDValidation.validateNumerical("jaccardDistance", "x", x); NDValidation.validateNumerical("jaccardDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(x, y, dimensions)); } @@ -872,6 +924,18 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(in, keepDims, condition, dimensions)); } + /** + * Calculates difference between inputs X and Y.
+ * + * @param x Input variable X (NUMERIC type) + * @param y Input variable Y (NUMERIC type) + */ + public INDArray[] listDiff(INDArray x, INDArray y) { + NDValidation.validateNumerical("listDiff", "x", x); + NDValidation.validateNumerical("listDiff", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff(x, y)); + } + /** * Element-wise logarithm function (base e - natural logarithm): out = log(x)
* @@ -940,13 +1004,13 @@ public class NDMath { * * @param x Input variable x (NUMERIC type) * @param y Input variable y (NUMERIC type) - * @param dimensions Dimensions to calculate manhattanDistance over (Size: AtLeast(min=1)) + * @param dimensions Dimensions to calculate manhattanDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray manhattanDistance(INDArray x, INDArray y, int... dimensions) { NDValidation.validateNumerical("manhattanDistance", "x", x); NDValidation.validateNumerical("manhattanDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(x, y, dimensions)); } @@ -983,7 +1047,7 @@ public class NDMath { * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ - public INDArray mergeAdd(INDArray[] inputs) { + public INDArray mergeAdd(INDArray... inputs) { NDValidation.validateNumerical("mergeAdd", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(inputs))[0]; @@ -996,7 +1060,7 @@ public class NDMath { * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ - public INDArray mergeAvg(INDArray[] inputs) { + public INDArray mergeAvg(INDArray... inputs) { NDValidation.validateNumerical("mergeAvg", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. 
Expected: inputs.length >= 1, got %s", inputs.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(inputs))[0]; @@ -1009,12 +1073,24 @@ public class NDMath { * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ - public INDArray mergeMax(INDArray[] inputs) { + public INDArray mergeMax(INDArray... inputs) { NDValidation.validateNumerical("mergeMax", "inputs", inputs); Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeMax(inputs))[0]; } + /** + * Broadcasts parameters for evaluation on an N-D grid.
+ * + * @param inputs (NUMERIC type) + * @param cartesian + */ + public INDArray[] meshgrid(INDArray[] inputs, boolean cartesian) { + NDValidation.validateNumerical("meshgrid", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 0, "inputs has incorrect size/length. Expected: inputs.length >= 0, got %s", inputs.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(inputs, cartesian)); + } + /** * Calculate the mean and (population) variance for the input variable, for the specified axis
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java index 04a713ecf..3f9e1431a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java @@ -237,12 +237,11 @@ public class NDNN { * Alpha value is most commonly set to 0.01
* * @param x Input variable (NUMERIC type) - * @param alpha Cutoff - commonly 0.01 (NUMERIC type) + * @param alpha Cutoff - commonly 0.01 * @return output Output variable (NUMERIC type) */ - public INDArray leakyRelu(INDArray x, INDArray alpha) { + public INDArray leakyRelu(INDArray x, double alpha) { NDValidation.validateNumerical("leakyRelu", "x", x); - NDValidation.validateNumerical("leakyRelu", "alpha", alpha); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(x, alpha)); } @@ -250,12 +249,11 @@ public class NDNN { * Leaky ReLU derivative: dOut/dIn given input.
* * @param x Input variable (NUMERIC type) - * @param alpha Cutoff - commonly 0.01 (NUMERIC type) + * @param alpha Cutoff - commonly 0.01 * @return output Output variable (NUMERIC type) */ - public INDArray leakyReluDerivative(INDArray x, INDArray alpha) { + public INDArray leakyReluDerivative(INDArray x, double alpha) { NDValidation.validateNumerical("leakyReluDerivative", "x", x); - NDValidation.validateNumerical("leakyReluDerivative", "alpha", alpha); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(x, alpha)); } @@ -346,6 +344,20 @@ public class NDNN { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false))[0]; } + /** + * Padding operation
+ * + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param constant Padding constant + * @return output Padded input (NUMERIC type) + */ + public INDArray pad(INDArray input, INDArray padding, double constant) { + NDValidation.validateNumerical("pad", "input", input); + NDValidation.validateNumerical("pad", "padding", padding); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, constant))[0]; + } + /** * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:
* out[i] = in[i] if in[i] >= 0
@@ -461,6 +473,17 @@ public class NDNN { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, dimension))[0]; } + /** + * Softmax activation, along the specified dimension
+ * + * @param x Input (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray softmax(INDArray x) { + NDValidation.validateNumerical("softmax", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, -1))[0]; + } + /** * Softmax derivative function
* @@ -519,4 +542,15 @@ public class NDNN { NDValidation.validateNumerical("swish", "x", x); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(x)); } + + /** + * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray tanh(INDArray x) { + NDValidation.validateNumerical("tanh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(x)); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java index 0587aeda5..9bb7d9640 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java @@ -22,7 +22,9 @@ import static org.nd4j.linalg.factory.NDValidation.isSameType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; import org.nd4j.linalg.factory.NDValidation; @@ -38,12 +40,11 @@ public class NDRNN { * @param x Input, with shape [batchSize, inSize] (NUMERIC type) * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type) * @param GRUWeights Configuration Object - * @return output The cell's outputs. 
(NUMERIC type) */ - public INDArray gru(INDArray x, INDArray hLast, GRUWeights GRUWeights) { + public INDArray[] gru(INDArray x, INDArray hLast, GRUWeights GRUWeights) { NDValidation.validateNumerical("gru", "x", x); NDValidation.validateNumerical("gru", "hLast", hLast); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(x, hLast, GRUWeights))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(x, hLast, GRUWeights)); } /** @@ -54,18 +55,83 @@ public class NDRNN { * @param yLast revious cell output, with shape [batchSize, numUnits] (NUMERIC type) * @param LSTMWeights Configuration Object * @param LSTMConfiguration Configuration Object - * @return output The cell's outputs (NUMERIC type) */ - public INDArray lstmCell(INDArray x, INDArray cLast, INDArray yLast, LSTMWeights LSTMWeights, + public INDArray[] lstmCell(INDArray x, INDArray cLast, INDArray yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { NDValidation.validateNumerical("lstmCell", "x", x); NDValidation.validateNumerical("lstmCell", "cLast", cLast); NDValidation.validateNumerical("lstmCell", "yLast", yLast); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(x, cLast, yLast, LSTMWeights, LSTMConfiguration))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(x, cLast, yLast, LSTMWeights, LSTMConfiguration)); } /** - * The LSTM layer. Does multiple time steps.
+ * Long Short-Term Memory layer - Hochreiter 1997.
+ * SUPPORTS following data formats:
+ * for unidirectional:
+ * TNS: shapes [timeLength, numExamples, inOutSize]
+ * NST: shapes [numExamples, inOutSize, timeLength]
+ * NTS: shapes [numExamples, timeLength, inOutSize]
+ * for bidirectional:
+ * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)
+ * SUPPORTS following direction modes:
+ * FWD: forward
+ * BWD: backward
+ * BIDIR_SUM: bidirectional sum
+ * BIDIR_CONCAT: bidirectional concat
+ * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)
+ * You may use different gate configurations:
+ * specify gate/cell/out alpha/beta and numbers of activations for gate/cell/out described in activations enum
+ * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")
+ * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration
+ * + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type) + * @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type) + * @param maxTSLength maxTSLength with shape [batchSize] (NUMERIC type) + * @param LSTMLayerWeights Configuration Object + * @param LSTMLayerConfig Configuration Object + */ + public INDArray[] lstmLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, + LSTMLayerWeights LSTMLayerWeights, LSTMLayerConfig LSTMLayerConfig) { + NDValidation.validateNumerical("lstmLayer", "x", x); + NDValidation.validateNumerical("lstmLayer", "cLast", cLast); + NDValidation.validateNumerical("lstmLayer", "yLast", yLast); + NDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(x, cLast, yLast, maxTSLength, LSTMLayerWeights, LSTMLayerConfig)); + } + + /** + * Long Short-Term Memory layer - Hochreiter 1997.
+ * SUPPORTS following data formats:
+ * for unidirectional:
+ * TNS: shapes [timeLength, numExamples, inOutSize]
+ * NST: shapes [numExamples, inOutSize, timeLength]
+ * NTS: shapes [numExamples, timeLength, inOutSize]
+ * for bidirectional:
+ * T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)
+ * SUPPORTS following direction modes:
+ * FWD: forward
+ * BWD: backward
+ * BIDIR_SUM: bidirectional sum
+ * BIDIR_CONCAT: bidirectional concat
+ * BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)
+ * You may use different gate configurations:
+ * specify gate/cell/out alpha/beta and numbers of activations for gate/cell/out described in activations enum
+ * ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")
+ * Also this layer supports MKLDNN (DNNL) and cuDNN acceleration
+ * + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param LSTMLayerWeights Configuration Object + * @param LSTMLayerConfig Configuration Object + */ + public INDArray[] lstmLayer(INDArray x, LSTMLayerWeights LSTMLayerWeights, + LSTMLayerConfig LSTMLayerConfig) { + NDValidation.validateNumerical("lstmLayer", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(x, null, null, null, LSTMLayerWeights, LSTMLayerConfig)); + } + + /** + * The LSTM block
* * @param maxTSLength (NUMERIC type) * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) @@ -75,13 +141,27 @@ public class NDRNN { * @param LSTMConfiguration Configuration Object * @return output The layer's outputs. (NUMERIC type) */ - public INDArray lstmLayer(INDArray maxTSLength, INDArray x, INDArray cLast, INDArray yLast, + public INDArray lstmblock(INDArray maxTSLength, INDArray x, INDArray cLast, INDArray yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { - NDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength); - NDValidation.validateNumerical("lstmLayer", "x", x); - NDValidation.validateNumerical("lstmLayer", "cLast", cLast); - NDValidation.validateNumerical("lstmLayer", "yLast", yLast); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration))[0]; + NDValidation.validateNumerical("lstmblock", "maxTSLength", maxTSLength); + NDValidation.validateNumerical("lstmblock", "x", x); + NDValidation.validateNumerical("lstmblock", "cLast", cLast); + NDValidation.validateNumerical("lstmblock", "yLast", yLast); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration))[0]; + } + + /** + * The LSTM block
+ * + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param LSTMWeights Configuration Object + * @param LSTMConfiguration Configuration Object + * @return output The layer's outputs. (NUMERIC type) + */ + public INDArray lstmblock(INDArray x, LSTMWeights LSTMWeights, + LSTMConfiguration LSTMConfiguration) { + NDValidation.validateNumerical("lstmblock", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock(null, x, null, null, LSTMWeights, LSTMConfiguration))[0]; } /** diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index a6ccd25ed..bda208ce7 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -199,7 +199,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, null, st); return op.z(); } @@ -436,7 +436,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, null, st); return op.z(); } @@ -524,7 +524,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { long st = profilingConfigurableHookIn(op); naiveExec(op, dimension); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, null, st); return op.z(); } @@ -607,7 +607,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if 
(nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, null, st); return op.z(); } @@ -772,7 +772,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); return null; } @@ -863,7 +863,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); return null; @@ -1113,7 +1113,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); Nd4j.getExecutioner().commit(); @@ -1200,7 +1200,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, null, st); return null; } @@ -1296,7 +1296,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); return null; } @@ -1460,7 +1460,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (ret != null) ret.elementWiseStride(); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); return null; } @@ -1579,7 +1579,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - 
profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); return z; } @@ -2292,7 +2292,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public INDArray[] exec(CustomOp op, OpContext context) { - long st = profilingConfigurableHookIn(op); + long st = profilingConfigurableHookIn(op, context); val ctx = AtomicAllocator.getInstance().getDeviceContext(); ((CudaOpContext) context).setCudaStream(ctx.getOldStream(), ctx.getBufferReduction(), ctx.getBufferAllocation()); @@ -2304,7 +2304,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (status != 0) throw new RuntimeException("Op [" + op.opName() + "] execution failed"); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, context, st); if (context.getOutputArrays().isEmpty()) return new INDArray[0]; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index 7a29f71d7..f0488636f 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -236,7 +236,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (loop.lastErrorCode() != 0) throw new RuntimeException(loop.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); return getZ(op, oc); } @@ -690,7 +690,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (loop.lastErrorCode() != 0) throw new RuntimeException(loop.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); return getZ(op, oc); } @@ -774,7 +774,6 @@ public class NativeOpExecutioner 
extends DefaultOpExecutioner { if (z == null) setZ(Nd4j.create(op.resultType(), x.shape()), op, oc); -// op.setZ(Nd4j.create(op.resultType(), op.x().shape())); op.validateDataTypes(oc, experimentalMode.get()); @@ -884,7 +883,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (loop.lastErrorCode() != 0) throw new RuntimeException(loop.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); } public INDArray exec(BroadcastOp op) { @@ -1306,7 +1305,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (loop.lastErrorCode() != 0) throw new RuntimeException(loop.lastErrorMessage()); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, oc, st); return z; } @@ -2040,7 +2039,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { @Override public INDArray[] exec(CustomOp op, @NonNull OpContext context) { - long st = profilingConfigurableHookIn(op); + long st = profilingConfigurableHookIn(op, context); boolean mklOverride = false; try { if (Nd4jCpu.Environment.getInstance().isUseMKLDNN()) { @@ -2125,7 +2124,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } finally { if (mklOverride) Nd4jCpu.Environment.getInstance().setUseMKLDNN(true); - profilingConfigurableHookOut(op, st); + profilingConfigurableHookOut(op, context, st); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index eab974821..794348369 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -20,8 +20,10 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; + import lombok.extern.slf4j.Slf4j; 
import lombok.val; +import org.junit.Ignore; import org.junit.Test; import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.samediff.SDVariable; @@ -36,6 +38,12 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights; import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm; import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize; import org.nd4j.linalg.factory.Nd4j; @@ -257,7 +265,7 @@ public class LayerOpValidation extends BaseOpValidation { msg = "7 - upsampling2d, NCHW, 2x2 - " + Arrays.toString(inSizeNCHW); inSize = inSizeNCHW; in = sd.var("in", inSize); - out = sd.cnn().upsampling2d(in, 2, 2, true); + out = sd.cnn().upsampling2d(in, 2, 2, true); break; default: throw new RuntimeException(); @@ -578,8 +586,6 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable dW = sd.var("dW", depthWeightArr); SDVariable b = sd.var("b", bArr); - SDVariable[] vars = new SDVariable[]{in, dW, b}; - Conv2DConfig c = Conv2DConfig.builder() .kH(kH).kW(kW) .pH(0).pW(0) @@ -588,8 +594,8 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(false) .build(); - SDVariable out = sd.cnn().separableConv2d(in, dW, b, c); - out = sd.f().tanh(out); + SDVariable out = sd.cnn().separableConv2d(in, dW, null, b, c); + out = sd.nn().tanh("out", out); 
INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27 @@ -623,8 +629,6 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable pW = sd.var("pW", pointWeightArr); SDVariable b = sd.var("b", bArr); - //SDVariable[] vars = new SDVariable[]{in, dW, pW, b}; - Conv2DConfig c = Conv2DConfig.builder() .kH(kH).kW(kW) .pH(0).pW(0) @@ -635,7 +639,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().separableConv2d(in, dW, pW, b, c); - out = sd.nn().tanh(out); + out = sd.nn().tanh("out", out); INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (8-2+0)/1+1 = 7 @@ -675,8 +679,6 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable w = sd.var("W", wArr); SDVariable b = sd.var("b", bArr); - SDVariable[] vars = new SDVariable[]{in, w, b}; - DeConv2DConfig deconv = DeConv2DConfig.builder() .kH(kH).kW(kW) .pH(0).pW(0) @@ -685,8 +687,8 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(false) .build(); - SDVariable out = sd.f().deconv2d(vars, deconv); - out = sd.f().tanh(out); + SDVariable out = sd.cnn().deconv2d(in, w, b, deconv); + out = sd.nn().tanh("out", out); INDArray outArr = out.eval(); //Expected output size: out = (in + k + 2*p)/ s - 1 = (8 + 2+0)/1 - 1 = 9 @@ -723,7 +725,6 @@ public class LayerOpValidation extends BaseOpValidation { //Order: https://github.com/deeplearning4j/libnd4j/blob/6c41ea5528bb1f454e92a9da971de87b93ff521f/include/ops/declarable/generic/convo/conv2d.cpp#L20-L22 //in, w, b - bias is optional - SDVariable[] vars = new SDVariable[]{in, w, b}; Conv2DConfig c = Conv2DConfig.builder() .kH(kH).kW(kW) @@ -733,8 +734,8 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(false) .build(); - SDVariable out = sd.f().conv2d(vars, c); - out = sd.f().tanh(out); + SDVariable out = sd.cnn().conv2d("conv", in, w, b, c); + out = sd.nn().tanh("out", out); INDArray outArr = 
out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27 @@ -767,7 +768,7 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(true) .build(); - SDVariable[] results = sd.f().maxPoolWithArgmax(/*new String[]{"out","idx"},*/ in, pooling2DConfig); + SDVariable[] results = sd.cnn().maxPoolWithArgmax(new String[]{"out", "idx"}, in, pooling2DConfig); assertArrayEquals(inArr.shape(), results[0].eval().shape()); assertArrayEquals(inArr.shape(), results[1].eval().shape()); } @@ -797,7 +798,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable outPool = sd.cnn().maxPooling2d(in, pooling2DConfig); - SDVariable out = sd.f().tanh(/*"out",*/ outPool); + SDVariable out = sd.nn().tanh("out", outPool); INDArray outArr = out.eval(); val outShape = outArr.shape(); @@ -855,7 +856,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable outPool = sd.cnn().avgPooling2d(in, pooling2DConfig); - SDVariable out = sd.f().tanh(/*"out",*/ outPool); + SDVariable out = sd.nn().tanh("out", outPool); INDArray outArr = out.eval(); val outShape = outArr.shape(); @@ -906,7 +907,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().avgPooling3d(in, pooling3DConfig); - out = sd.f().tanh(/*"loss", */out).shape().rename("out"); + out = sd.nn().tanh("loss", out).shape().rename("out"); // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; INDArray outArr = Nd4j.createFromArray(mb, nIn, 4, 4, 4L); @@ -942,7 +943,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().maxPooling3d(in, pooling3DConfig); - out = sd.math().tanh("loss", out).shape().rename("out"); + out = sd.nn().tanh("loss", out).shape().rename("out"); sd.setLossVariables("loss"); @@ -976,8 +977,8 @@ public class LayerOpValidation extends BaseOpValidation { .paddingMode(PaddingMode.VALID) .build(); - SDVariable out = sd.cnn().conv1d(in, w, null, 
conv1DConfig); - out = sd.math().tanh("loss", out).shape().rename("out"); + SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig); + out = sd.nn().tanh("loss", out).shape().rename("out"); sd.setLossVariables("loss"); @@ -996,7 +997,7 @@ public class LayerOpValidation extends BaseOpValidation { int nOut = 4; int mb = 2; - for( int k : new int[]{2, 3}) { + for (int k : new int[]{2, 3}) { for (int sz : new int[]{3, 4, 5}) { for (int s : new int[]{1, 2}) { for (int d : new int[]{1, 2}) { @@ -1018,7 +1019,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().conv1d(in, w, b, conv1DConfig); - SDVariable loss = sd.f().tanh(out).std(true).rename("loss"); + SDVariable loss = sd.nn().tanh(out).std(true).rename("loss"); sd.setLossVariables("loss"); @@ -1039,7 +1040,7 @@ public class LayerOpValidation extends BaseOpValidation { @Test - public void testConv1dForward(){ + public void testConv1dForward() { int nIn = 2; int nOut = 1; int kernel = 3; @@ -1057,7 +1058,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable in = sd.var("in", inArr); SDVariable w = sd.var("w", wArr); - SDVariable res = sd.cnn.conv1d(in, w, null, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build()); + SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build()); INDArray expected = Nd4j.createFromArray( new double[][][]{ @@ -1113,7 +1114,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().conv3d(in, w, b, conv3DConfig); - out = sd.math().tanh("loss", out).shape().rename("out"); + out = sd.nn().tanh("loss", out).shape().rename("out"); sd.setLossVariables("loss"); @@ -1156,7 +1157,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().deconv3d(in, w, conv3DConfig); - out = sd.math().tanh("loss", out).shape().rename("out"); + out = sd.nn().tanh("loss", out).shape().rename("out"); 
sd.setLossVariables("loss"); @@ -1201,13 +1202,13 @@ public class LayerOpValidation extends BaseOpValidation { public void testLayerNorm4d() { int mb = 3; int ch = 4; - for(boolean nchw : new boolean[]{true, false}) { + for (boolean nchw : new boolean[]{true, false}) { double eps = 0.0; INDArray x = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{mb, ch, 8, 8} : new long[]{mb, 8, 8, ch}); INDArray gain4d = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch}); INDArray bias4d = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch}); INDArray mean = x.mean(true, 1, 2, 3); - INDArray std = Transforms.sqrt(x.var(false,1,2,3).addi(eps)).reshape(mb, 1, 1, 1); + INDArray std = Transforms.sqrt(x.var(false, 1, 2, 3).addi(eps)).reshape(mb, 1, 1, 1); INDArray standardized = x.sub(mean).div(std); INDArray exp = standardized.mul(gain4d).add(bias4d); @@ -1274,7 +1275,7 @@ public class LayerOpValidation extends BaseOpValidation { final INDArray standardized = random.ulike(); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); - final INDArray gain = Nd4j.rand(DataType.DOUBLE,4); + final INDArray gain = Nd4j.rand(DataType.DOUBLE, 4); final INDArray res = standardized.mulRowVector(gain); final INDArray output = Nd4j.zerosLike(res); @@ -1287,7 +1288,7 @@ public class LayerOpValidation extends BaseOpValidation { public void testLayerNormNoDeviation() { final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); for (int i = 0; i < 4; i++) { - random.putScalar(1,i, 7); + random.putScalar(1, i, 7); } final INDArray standardized = random.ulike(); @@ -1335,7 +1336,7 @@ public class LayerOpValidation extends BaseOpValidation { .paddingMode(PaddingMode.VALID) .build(); - SDVariable out = sd.cnn().conv1d(in, w, null, conv1DConfig); + SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig); } @@ -1391,16 +1392,16 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void 
testLayerNormMixedOrders(){ + public void testLayerNormMixedOrders() { Nd4j.getRandom().setSeed(12345); INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f'); INDArray gain = Nd4j.rand(DataType.DOUBLE, 8).dup('f'); INDArray bias = Nd4j.rand(DataType.DOUBLE, 8).dup('f'); - INDArray outFF = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'f'); - INDArray outCC = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'c'); - INDArray outFC = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'c'); - INDArray outCF = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'f'); + INDArray outFF = Nd4j.create(DataType.DOUBLE, new long[]{3, 8}, 'f'); + INDArray outCC = Nd4j.create(DataType.DOUBLE, new long[]{3, 8}, 'c'); + INDArray outFC = Nd4j.create(DataType.DOUBLE, new long[]{3, 8}, 'c'); + INDArray outCF = Nd4j.create(DataType.DOUBLE, new long[]{3, 8}, 'f'); //F in, F out case Nd4j.exec(DynamicCustomOp.builder("layer_norm") @@ -1441,11 +1442,11 @@ public class LayerOpValidation extends BaseOpValidation { public void testBiasAdd_nchw_nhwc() { Nd4j.getRandom().setSeed(12345); - for(boolean nchw : new boolean[]{true, false}) { + for (boolean nchw : new boolean[]{true, false}) { log.info("Starting test: {}", nchw ? "nchw" : "nhwc"); SameDiff sameDiff = SameDiff.create(); - SDVariable in = sameDiff.var("input", Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{2,4,3,3} : new long[]{2,3,3,4})); + SDVariable in = sameDiff.var("input", Nd4j.rand(DataType.DOUBLE, nchw ? 
new long[]{2, 4, 3, 3} : new long[]{2, 3, 3, 4})); SDVariable b = sameDiff.var("bias", Nd4j.rand(DataType.DOUBLE, new long[]{4})); SDVariable bAdd = sameDiff.nn.biasAdd(in, b, nchw); @@ -1453,10 +1454,10 @@ public class LayerOpValidation extends BaseOpValidation { INDArray exp = in.getArr().dup(); - if(nchw){ - exp.addi(b.getArr().reshape(1,4,1,1)); + if (nchw) { + exp.addi(b.getArr().reshape(1, 4, 1, 1)); } else { - exp.addi(b.getArr().reshape(1,1,1,4)); + exp.addi(b.getArr().reshape(1, 1, 1, 4)); } TestCase tc = new TestCase(sameDiff) @@ -1467,4 +1468,168 @@ public class LayerOpValidation extends BaseOpValidation { assertNull(err); } } + + + @Test + public void LSTMLayerTestCase1() { + + int bS = 5; + int nIn = 3; + int numUnits = 7; + int sL = 10; //small just for test + + SameDiff sd = SameDiff.create(); + + // notations: + // bS - batch size, numExamples + // sL - sequence length, number of time steps, timeLength + // nIn - input size, inOutSize + + // TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"
+ // NST: shape [numExamples, inOutSize, timeLength]
+ // NTS: shape [numExamples, timeLength, inOutSize]
+ // for bidirectional: + // T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) + + + SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, bS, nIn, sL)); + + + SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits)); + SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits)); + + LSTMLayerConfig c = LSTMLayerConfig.builder() + .lstmdataformat(LSTMDataFormat.NST) + .directionMode(LSTMDirectionMode.FWD) + .gateAct(LSTMActivations.SIGMOID) + .cellAct(LSTMActivations.TANH) + .outAct(LSTMActivations.TANH) + .retFullSequence(true) + .retLastC(true) + .retLastH(true) + .build(); + + LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer( + in, cLast, yLast, null, + LSTMLayerWeights.builder() + .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits))) + .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits))) + .peepholeWeights(sd.var("inputPeepholeWeights", Nd4j.rand(DataType.FLOAT, 3 * numUnits))) + .bias(sd.var("bias", Nd4j.rand(DataType.FLOAT, 4 * numUnits))).build(), + c), c); + + long[] out = new long[]{bS, numUnits, sL}; + long[] hL = new long[]{bS, numUnits}; + long[] cL = new long[]{bS, numUnits}; + + assertArrayEquals(out, outputs.getOutput().eval().shape()); + assertArrayEquals(hL, outputs.getLastTimeStepOutput().eval().shape()); + assertArrayEquals(cL, outputs.getLastCellStateOutput().eval().shape()); + + + } + + + @Test @Ignore //AB 2020/04/08 - https://github.com/eclipse/deeplearning4j/issues/8824 + public void LSTMLayerTestCase2() { + int bS = 5; + int nIn = 3; + int numUnits = 7; + int sL = 10; //small just for test + + SameDiff sd = SameDiff.create(); + + // notations: + // bS - batch size, numExamples + // sL - sequence length, number of time steps, timeLength + // nIn - input size, inOutSize + + // TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"
+ // NST: shape [numExamples, inOutSize, timeLength]
+ // NTS: shape [numExamples, timeLength, inOutSize]
+ // for bidirectional: + // T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) + SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, sL, bS, nIn)); + + + SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits)); + SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits)); + + LSTMLayerConfig c = LSTMLayerConfig.builder() + .lstmdataformat(LSTMDataFormat.TNS) + .directionMode(LSTMDirectionMode.FWD) + .gateAct(LSTMActivations.SIGMOID) + .cellAct(LSTMActivations.TANH) + .outAct(LSTMActivations.TANH) + .retFullSequence(true) + .retLastC(false) + .retLastH(false) + .build(); + + LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer( + in, cLast, yLast, null, + LSTMLayerWeights.builder() + .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits))) + .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits))) + .build(), + c), c); + + + long[] out = new long[]{sL, bS, numUnits}; + assertArrayEquals(out, outputs.getOutput().eval().shape()); + + } + + @Test @Ignore //AB 2020/04/08 - https://github.com/eclipse/deeplearning4j/issues/8824 + public void LSTMLayerTestCase3() { + int bS = 5; + int nIn = 3; + int numUnits = 7; + int sL = 10; //small just for test + + SameDiff sd = SameDiff.create(); + + // notations: + // bS - batch size, numExamples + // sL - sequence length, number of time steps, timeLength + // nIn - input size, inOutSize + + // TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"
+ // NST: shape [numExamples, inOutSize, timeLength]
+ // NTS: shape [numExamples, timeLength, inOutSize]
+ // for bidirectional: + // T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) + SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, bS, sL, nIn)); + + + // when directionMode >= 2 (BIDIR_CONCAT=3) + // Wx, Wr [2, nIn, 4*nOut] + // hI, cI [2, bS, nOut] + SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, 2, bS, numUnits)); + SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, 2, bS, numUnits)); + + LSTMLayerConfig c = LSTMLayerConfig.builder() + .lstmdataformat(LSTMDataFormat.NTS) + .directionMode(LSTMDirectionMode.BIDIR_CONCAT) + .gateAct(LSTMActivations.SIGMOID) + .cellAct(LSTMActivations.SOFTPLUS) + .outAct(LSTMActivations.SOFTPLUS) + .retFullSequence(true) + .retLastC(false) + .retLastH(false) + .build(); + + LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(new String[]{"out"}, + in, cLast, yLast, null, + LSTMLayerWeights.builder() + .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, 2, nIn, 4 * numUnits))) + .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, 2, numUnits, 4 * numUnits))) + .build(), + c), c); + + + long[] out = new long[]{bS, sL, 2 * numUnits}; + + assertArrayEquals(out, outputs.getOutput().eval().shape()); + } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java index 3998bc184..47f383f52 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java @@ -548,7 +548,7 @@ public class MiscOpValidation extends BaseOpValidation { INDArray arr2 = Nd4j.rand(new long[]{2, 2, 2}); SDVariable x = sameDiff.var("x", arr); SDVariable y = sameDiff.var("y", arr2); - SDVariable result = sameDiff.tensorMmul(x, y, new int[][]{{0}, {1}}); + SDVariable result = 
sameDiff.tensorMmul(x, y, new int[]{0}, new int[]{1}); assertArrayEquals(ArrayUtil.getTensorMmulShape(new long[]{2, 2, 2}, new long[]{2, 2, 2}, new int[][]{{0}, {1}}), result.eval().shape()); assertEquals(16, sameDiff.numElements()); @@ -689,13 +689,7 @@ public class MiscOpValidation extends BaseOpValidation { SDVariable a = sd.var("a", aArr); SDVariable b = sd.var("b", bArr); - MMulTranspose mt = MMulTranspose.builder() - .transposeA(transposeA) - .transposeB(transposeB) - .transposeResult(transposeResult) - .build(); - - SDVariable mmul = sd.mmul(a, b, mt); + SDVariable mmul = sd.mmul(a, b, transposeA, transposeB, transposeResult); INDArray exp = (transposeA ? aArr.transpose() : aArr); exp = exp.mmul(transposeB ? bArr.transpose() : bArr); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java index 69385f814..c47d02b04 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java @@ -70,7 +70,7 @@ public class RnnOpValidation extends BaseOpValidation { LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b) .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build(); - LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y + LSTMCellOutputs v = new LSTMCellOutputs(sd.rnn().lstmCell(x, cLast, yLast, weights, conf)); //Output order: i, c, f, o, z, h, y List toExec = new ArrayList<>(); for(SDVariable sdv : v.getAllOutputs()){ toExec.add(sdv.name()); @@ -173,7 +173,7 @@ public class RnnOpValidation extends BaseOpValidation { LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b) .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build(); - 
LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y + LSTMCellOutputs v = new LSTMCellOutputs(sd.rnn().lstmCell(x, cLast, yLast, weights, conf)); //Output order: i, c, f, o, z, h, y List toExec = new ArrayList<>(); for(SDVariable sdv : v.getAllOutputs()){ toExec.add(sdv.name()); @@ -227,7 +227,7 @@ public class RnnOpValidation extends BaseOpValidation { .cBias(bc) .build(); - List v = sd.rnn().gru("gru", x, hLast, weights).getAllOutputs(); + SDVariable[] v = sd.rnn().gru(x, hLast, weights); List toExec = new ArrayList<>(); for(SDVariable sdv : v){ toExec.add(sdv.name()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index 0cbe52479..47394de1e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -119,7 +119,7 @@ public class ShapeOpValidation extends BaseOpValidation { List failed = new ArrayList<>(); - for (int[] toShape : new int[][]{{3, 4 * 5}, {3 * 4, 5}, {1, 3 * 4 * 5}, {3 * 4 * 5, 1}}) { + for (long[] toShape : new long[][]{{3, 4 * 5}, {3 * 4, 5}, {1, 3 * 4 * 5}, {3 * 4 * 5, 1}}) { for(char order : new char[]{'c','f'}){ INDArray inArr = Nd4j.rand(DataType.DOUBLE, origShape, order).muli(100); @@ -388,10 +388,10 @@ public class ShapeOpValidation extends BaseOpValidation { @Builder(builderClassName = "Builder") @Data private static class SSCase { - private int[] shape; - private int[] begin; - private int[] end; - private int[] strides; + private long[] shape; + private long[] begin; + private long[] end; + private long[] strides; private int beginMask; private int endMask; private int ellipsisMask; @@ -400,22 +400,22 @@ public class ShapeOpValidation extends BaseOpValidation { public 
static class Builder { - public Builder shape(int... shape) { + public Builder shape(long... shape) { this.shape = shape; return this; } - public Builder begin(int... begin) { + public Builder begin(long... begin) { this.begin = begin; return this; } - public Builder end(int... end) { + public Builder end(long... end) { this.end = end; return this; } - public Builder strides(int... strides) { + public Builder strides(long... strides) { this.strides = strides; return this; } @@ -1571,7 +1571,7 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2); SDVariable x1 = sameDiff.var("x1", arr1); SDVariable x2 = sameDiff.var("x2", arr2); - SDVariable result = sameDiff.parallel_stack(new SDVariable[]{x1, x2}); + SDVariable result = sameDiff.stack(0, new SDVariable[]{x1, x2}); assertArrayEquals(new long[]{2, 3, 2}, result.eval().shape()); assertEquals(Nd4j.concat(0, arr1, arr2).reshape(2, 3, 2), result.eval()); } @@ -1661,9 +1661,9 @@ public class ShapeOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); - SDVariable slice_full = sd.stridedSlice(in, new int[]{0, 0}, new int[]{3, 4}, new int[]{1, 1}); - SDVariable subPart = sd.stridedSlice(in, new int[]{1, 2}, new int[]{3, 4}, new int[]{1, 1}); - // SDVariable subPart2 = sd.stridedSlice(in, new int[]{0, 0}, new int[]{4, 5}, new int[]{2, 2}); + SDVariable slice_full = sd.stridedSlice(in,new long[]{0, 0},new long[]{3, 4},new long[]{1, 1}); + SDVariable subPart = sd.stridedSlice(in,new long[]{1, 2},new long[]{3, 4},new long[]{1, 1}); + // SDVariable subPart2 = sd.stridedSlice(in,new long[]{0, 0},new long[]{4, 5},new long[]{2, 2}); sd.outputAll(null); @@ -1679,8 +1679,8 @@ public class ShapeOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); - SDVariable slice1 = sd.stridedSlice(in, new int[]{-999, 0}, new int[]{2, 4}, new int[]{1, 1}, 1 
<< 1, 0, 0, 0, 0); - SDVariable slice2 = sd.stridedSlice(in, new int[]{1, 0}, new int[]{-999, 4}, new int[]{1, 1}, 0, 1, 0, 0, 0); + SDVariable slice1 = sd.stridedSlice(in,new long[]{-999, 0},new long[]{2, 4},new long[]{1, 1}, 1 << 1, 0, 0, 0, 0); + SDVariable slice2 = sd.stridedSlice(in,new long[]{1, 0},new long[]{-999, 4},new long[]{1, 1}, 0, 1, 0, 0, 0); sd.outputAll(null); @@ -1695,9 +1695,9 @@ public class ShapeOpValidation extends BaseOpValidation { SDVariable in = sd.var("in", inArr); //[1:3,...] -> [1:3,:,:] - SDVariable slice = sd.stridedSlice(in, new int[]{1}, new int[]{3}, new int[]{1}, 0, 0, 1 << 1, 0, 0); + SDVariable slice = sd.stridedSlice(in,new long[]{1},new long[]{3},new long[]{1}, 0, 0, 1 << 1, 0, 0); //[1:3,...,1:4] -> [1:3,:,1:4] - SDVariable slice2 = sd.stridedSlice(in, new int[]{1, 1}, new int[]{3, 4}, new int[]{1, 1}, 0, 0, 1 << 1, 0, 0); + SDVariable slice2 = sd.stridedSlice(in,new long[]{1, 1},new long[]{3, 4},new long[]{1, 1}, 0, 0, 1 << 1, 0, 0); sd.outputAll(Collections.emptyMap()); @@ -1710,7 +1710,7 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); - SDVariable slice = sd.stridedSlice(in, new int[]{-999, 0, 0, 0}, new int[]{-999, 3, 4, 5}, new int[]{-999, 1, 1, 1}, 0, 0, 0, 1, 0); + SDVariable slice = sd.stridedSlice(in,new long[]{-999, 0, 0, 0},new long[]{-999, 3, 4, 5},new long[]{-999, 1, 1, 1}, 0, 0, 0, 1, 0); INDArray out = slice.eval(); @@ -1723,7 +1723,7 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); - SDVariable slice = sd.stridedSlice(in, new int[]{1, 1, -999, 1}, new int[]{3, 3, -999, 4}, new int[]{1, 1, -999, 1}, 0, 0, 0, 1 << 2, 0); + SDVariable slice = sd.stridedSlice(in,new long[]{1, 1, -999, 1},new long[]{3, 3, -999, 4},new 
long[]{1, 1, -999, 1}, 0, 0, 0, 1 << 2, 0); INDArray out = slice.eval(); assertArrayEquals(new long[]{2, 2, 1, 3}, slice.getArr().shape()); @@ -1735,9 +1735,9 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); - SDVariable slice = sd.stridedSlice(in, new int[]{0, 0, 0}, new int[]{-999, 4, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1); - SDVariable slice2 = sd.stridedSlice(in, new int[]{2, 0, 0}, new int[]{-999, 4, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1); - SDVariable slice3 = sd.stridedSlice(in, new int[]{1, 2, 1}, new int[]{-999, -999, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1 | 1 << 1); + SDVariable slice = sd.stridedSlice(in,new long[]{0, 0, 0},new long[]{-999, 4, 5},new long[]{1, 1, 1}, 0, 0, 0, 0, 1); + SDVariable slice2 = sd.stridedSlice(in,new long[]{2, 0, 0},new long[]{-999, 4, 5},new long[]{1, 1, 1}, 0, 0, 0, 0, 1); + SDVariable slice3 = sd.stridedSlice(in,new long[]{1, 2, 1},new long[]{-999, -999, 5},new long[]{1, 1, 1}, 0, 0, 0, 0, 1 | 1 << 1); sd.outputAll(null); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index 9be66f484..27a15b517 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -1920,7 +1920,7 @@ public class TransformOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); SDVariable sdA = sd.var("a", a); SDVariable sdB = sd.var("b", b); - SDVariable t = sd.mmul(sdA, sdB, MMulTranspose.builder().transposeA(transposeA).transposeB(transposeB).transposeResult(transposeResult).build()); + SDVariable t = sd.mmul(sdA, sdB, transposeA, transposeB, transposeResult); 
t.norm1("out"); String err = OpValidation.validate(new TestCase(sd) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index 88915e35b..3e33534e1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -759,8 +759,7 @@ public class SameDiffTests extends BaseNd4jTest { val vector = Nd4j.linspace(1, 4, 4).reshape(4, 1); val input1 = sd.var("input", matrix); val input2 = sd.var("input2", vector); - val output = sd - .mmul("output", input1, input2, MMulTranspose.builder().transposeA(true).transposeB(false).build()); + val output = sd.mmul("output", input1, input2, true, false, false); INDArray out = output.eval(); assertArrayEquals(new long[]{3, 1}, out.shape()); } @@ -2675,7 +2674,7 @@ public class SameDiffTests extends BaseNd4jTest { final long timeSteps = sdInput.getShape()[2]; SDVariable[] outputSlices = new SDVariable[(int) timeSteps]; - final SDVariable[] inputSlices = sd.unstack(new String[]{"X_0", "X_1"}, sdInput, 2); + final SDVariable[] inputSlices = sd.unstack(new String[]{"X_0", "X_1"}, sdInput, 2, 2); final val x_0 = inputSlices[0]; outputSlices[0] = x_0; @@ -2702,7 +2701,7 @@ public class SameDiffTests extends BaseNd4jTest { SameDiff sd = SameDiff.create(); final SDVariable sdInput = sd.var("input", input); - final SDVariable[] inputSlices = sd.unstack(new String[]{"X_0", "X_1"}, sdInput, 2); + final SDVariable[] inputSlices = sd.unstack(new String[]{"X_0", "X_1"}, sdInput, 2, 2); final val temp = inputSlices[0].add(inputSlices[1]).div(inputSlices[1]).mul(inputSlices[0]); final val out = temp.add(temp).add(inputSlices[1]); out.norm2("out"); @@ -3242,61 +3241,61 @@ public class SameDiffTests extends BaseNd4jTest { @Test public void testNestedIf() throws IOException { - 
SameDiff SD = SameDiff.create(); - SDVariable a = SD.var("a", Nd4j.createFromArray(2.0)); - SDVariable b = SD.var("b", Nd4j.createFromArray(5.0)); - SDVariable c = SD.var("c", Nd4j.createFromArray(9.0)); - SDVariable d = SD.var("d", Nd4j.createFromArray(-7.0)); + SameDiff sd = SameDiff.create(); + SDVariable a = sd.var("a", Nd4j.createFromArray(2.0)); + SDVariable b = sd.var("b", Nd4j.createFromArray(5.0)); + SDVariable c = sd.var("c", Nd4j.createFromArray(9.0)); + SDVariable d = sd.var("d", Nd4j.createFromArray(-7.0)); - SDVariable output = SD.ifCond("out", null, - (sd) -> a.lt(b), - (sd) -> sd.ifCond( + SDVariable output = sd.ifCond("out", null, + (s) -> a.lt(b), + (s) -> s.ifCond( (sd2) -> d.lte(0), (sd2) -> c.add(1), (sd2) -> d), - (sd) -> c.add(5)); + (s) -> c.add(5)); INDArray out = output.eval(); assertEquals(Nd4j.createFromArray(10.0), out); - SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); + sd = SameDiff.fromFlatBuffers(sd.asFlatBuffers(false)); - assertEquals(Nd4j.createFromArray(10.0), SD.output(Collections.emptyMap(), "out").get("out")); + assertEquals(Nd4j.createFromArray(10.0), sd.output(Collections.emptyMap(), "out").get("out")); } @Test public void testWhile() throws IOException { - SameDiff SD = SameDiff.create(); - SDVariable countIn = SD.constant(5); - SDVariable sumIn = SD.constant(0); + SameDiff sd = SameDiff.create(); + SDVariable countIn = sd.constant(5); + SDVariable sumIn = sd.constant(0); - SDVariable[] sum = SD.whileLoop("while_1", new SDVariable[]{countIn, sumIn}, - (sd, vars) -> vars[0].gt(0), - (sd, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add(vars[0])}); + SDVariable[] sum = sd.whileLoop("while_1", new SDVariable[]{countIn, sumIn}, + (s, vars) -> vars[0].gt(0), + (s, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add(vars[0])}); INDArray out = sum[1].eval(); assertEquals(15, out.getInt(0)); String outName = sum[1].name(); - SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); + sd = 
SameDiff.fromFlatBuffers(sd.asFlatBuffers(false)); - assertEquals(15, SD.output(Collections.emptyMap(), outName).get(outName).getInt(0)); + assertEquals(15, sd.output(Collections.emptyMap(), outName).get(outName).getInt(0)); } @Test @Ignore public void testNestedWhile() throws IOException { - SameDiff SD = SameDiff.create(); - SDVariable countIn = SD.constant(5); - SDVariable sumIn = SD.constant(0); - SDVariable sum2 = SD.constant(0); + SameDiff sd = SameDiff.create(); + SDVariable countIn = sd.constant(5); + SDVariable sumIn = sd.constant(0); + SDVariable sum2 = sd.constant(0); //TODO creating constant instead of using sum2 causes errors - SDVariable[] sum = SD.whileLoop(new SDVariable[]{countIn, sumIn}, - (sd, vars) -> vars[0].gt(0), - (sd, vars) -> new SDVariable[]{vars[0].sub(1), - vars[1].add(sd.whileLoop(new SDVariable[]{vars[0], sum2}, + SDVariable[] sum = sd.whileLoop(new SDVariable[]{countIn, sumIn}, + (s, vars) -> vars[0].gt(0), + (s, vars) -> new SDVariable[]{vars[0].sub(1), + vars[1].add(s.whileLoop(new SDVariable[]{vars[0], sum2}, (sd2, vars2) -> vars2[0].gt(0), (sd2, vars2) -> new SDVariable[]{vars2[0].sub(1), vars2[1].add(vars2[0])})[1])}); @@ -3305,23 +3304,23 @@ public class SameDiffTests extends BaseNd4jTest { String outName = sum[1].name(); - SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); + sd = SameDiff.fromFlatBuffers(sd.asFlatBuffers(false)); - assertEquals(35, SD.output(Collections.emptyMap(), outName).get(outName).getInt(0)); + assertEquals(35, sd.output(Collections.emptyMap(), outName).get(outName).getInt(0)); } @Test public void testNestedWhileIf() throws IOException { - SameDiff SD = SameDiff.create(); - SDVariable countIn = SD.constant(5); - SDVariable sumIn = SD.constant(0); - SDVariable hundred = SD.constant(100); + SameDiff sd = SameDiff.create(); + SDVariable countIn = sd.constant(5); + SDVariable sumIn = sd.constant(0); + SDVariable hundred = sd.constant(100); - SDVariable[] sum = SD.whileLoop(new SDVariable[]{countIn, 
sumIn}, - (sd, vars) -> vars[0].gte(0), - (sd, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add( - sd.ifCond((sd2) -> vars[0].eq(0), + SDVariable[] sum = sd.whileLoop(new SDVariable[]{countIn, sumIn}, + (s, vars) -> vars[0].gte(0), + (s, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add( + s.ifCond((sd2) -> vars[0].eq(0), (sd2) -> vars[0].add(100), //TODO replace with hundred and things break (sd2) -> vars[0]) )}); @@ -3331,9 +3330,9 @@ public class SameDiffTests extends BaseNd4jTest { String outName = sum[1].name(); - SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); + sd = SameDiff.fromFlatBuffers(sd.asFlatBuffers(false)); - assertEquals(115, SD.output(Collections.emptyMap(), outName).get(outName).getInt(0)); + assertEquals(115, sd.output(Collections.emptyMap(), outName).get(outName).getInt(0)); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java index 03b469e70..2c1c284bc 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java @@ -61,7 +61,7 @@ public class OpsMappingTests extends BaseNd4jTest { @Override public long getTimeoutMilliseconds() { - return 180000L; //Can be slow on some CI machines such as PPC + return 360000L; //Can be very slow on some CI machines (PPC) } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java index 867fe1611..22c6e3a52 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java @@ -29,7 +29,10 @@ import 
org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; +import org.nd4j.linalg.api.ops.impl.shape.Concat; +import org.nd4j.linalg.api.ops.impl.transforms.strict.Log; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -473,6 +476,7 @@ public class OperationProfilerTests extends BaseNd4jTest { Nd4j.exec(op); //Should trigger NaN panic fail(); } catch (Exception e){ + e.printStackTrace(); assertTrue(e.getMessage(), e.getMessage().contains("Inf")); } @@ -488,4 +492,55 @@ public class OperationProfilerTests extends BaseNd4jTest { Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForINF(false).build()); } } + + + @Test + public void testOpProfilerOpContextLegacy(){ + + for(boolean nan : new boolean[]{true, false}) { + + INDArray in = Nd4j.valueArrayOf(10, nan ? -1 : 0).castTo(DataType.FLOAT); + + Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForNAN(nan).checkForINF(!nan).build()); + + OpContext oc = Nd4j.getExecutioner().buildContext(); + oc.setInputArray(0, in); + oc.setOutputArray(0, in.ulike()); + try { + Nd4j.exec(new Log(), oc); + System.out.println(oc.getOutputArray(0)); + fail("Expected op profiler exception"); + } catch (Throwable t) { + //OK + assertTrue(t.getMessage(), t.getMessage().contains(nan ? "NaN" : "Inf")); + } + } + } + + @Test + public void testOpProfilerOpContextCustomOp(){ + + for(boolean nan : new boolean[]{true, false}) { + + INDArray in = Nd4j.create(DataType.DOUBLE, 10).assign(nan ? 
Double.NaN : Double.POSITIVE_INFINITY); + INDArray in2 = in.dup(); + + + Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().checkForNAN(nan).checkForINF(!nan).build()); + + OpContext oc = Nd4j.getExecutioner().buildContext(); + oc.setIArguments(0); + oc.setInputArray(0, in); + oc.setInputArray(1, in2); + oc.setOutputArray(0, Nd4j.create(DataType.DOUBLE, 20)); + try { + Nd4j.exec(new Concat(), oc); + System.out.println(oc.getOutputArray(0)); + fail("Expected op profiler exception"); + } catch (Throwable t) { + //OK + assertTrue(t.getMessage(), t.getMessage().contains(nan ? "NaN" : "Inf")); + } + } + } } diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java index 35e5607a2..39c09e627 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java @@ -3579,4 +3579,19 @@ public class ArrayUtil { } return false; } + + public static T[] filterNull(T... in){ + int count = 0; + for( int i=0; i