From 191bda3228cba86c8b74143e6708ff28fdb120b5 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Mon, 20 Apr 2020 16:57:00 +1000 Subject: [PATCH] Base namespace (#287) * wip Signed-off-by: Robert Altena * up to assign operation. Signed-off-by: Robert Altena * fix Imax, IMin. Signed-off-by: Robert Altena * concat. Signed-off-by: Robert Altena * dynamicPartition Signed-off-by: Robert Altena * new ops up to gte. Signed-off-by: Robert Altena * wip Signed-off-by: Robert Altena * updated review items. Signed-off-by: Robert Altena * up to matchCondition. Signed-off-by: Robert Altena * wip Signed-off-by: Robert Altena * up to OneHot. Signed-off-by: Robert Altena * wip. up to permute. Signed-off-by: Robert Altena * wip. up to rank. Signed-off-by: Robert Altena * wip. up to scatterMul. Signed-off-by: Robert Altena * resolving code review issues. Signed-off-by: Robert Altena * wip. inclides UnsortedSegment ops. Signed-off-by: Robert Altena * wip. up to stridedSlice. Signed-off-by: Robert Altena * fix stridedSlice. Signed-off-by: Robert Altena * first pass of SDBaseops.kt complete. Signed-off-by: Robert Altena * fix review items. Signed-off-by: Robert Altena * wip Signed-off-by: Robert Altena * put branch in compilable state. Signed-off-by: Robert Altena * add NDBaseTest. fix dynamicPartition signature. failed fix of assign. Signed-off-by: Robert Altena * make tests public. Signed-off-by: Robert Altena * adds tests up to invertedPermutation. Signed-off-by: Robert Altena * fix ScalarEquals, Assign. Signed-off-by: Robert Altena * wip Signed-off-by: Robert Altena * updates NDBaseTest. Signed-off-by: Robert Altena * updates 'check' comments based on test pass/fail. Signed-off-by: Robert Altena * fix scalar ops. Update tests, Signed-off-by: Robert Altena * dev-tools review items. wip. Signed-off-by: Robert Altena * dev-tools code review items. Signed-off-by: Robert Altena * Test fixes Signed-off-by: Alex Black * complete review items. Signed-off-by: Robert Altena * Comment for logged issue; fix test case Signed-off-by: Alex Black * Fixes Signed-off-by: Alex Black * More fixes * wip Signed-off-by: Robert Altena * undo changes to Nd4jCpu.java Signed-off-by: Robert Altena * update tests. Signed-off-by: Robert Altena * Fixes and regenerate Signed-off-by: AlexDBlack * Small test fixes Signed-off-by: AlexDBlack * small fixes to tests. Signed-off-by: Robert Altena * Cleanup Signed-off-by: Alex Black * Fixes Signed-off-by: Alex Black * Small CUDAExecutioner fix Signed-off-by: Alex Black * Fixes Signed-off-by: Alex Black * Small CudaExecutioner fix Signed-off-by: Alex Black * Another small CudaExecutioner fix Signed-off-by: Alex Black * Another small CudaExecutioner fix Signed-off-by: Alex Black Co-authored-by: Robert Altena --- .../linalg/api/ops/impl/indexaccum/IMax.java | 9 +- .../linalg/api/ops/impl/indexaccum/IMin.java | 1 + .../nd4j/linalg/api/ops/impl/reduce/Mmul.java | 5 +- .../api/ops/impl/reduce/TensorMmul.java | 57 +- .../api/ops/impl/reduce/floating/NormMax.java | 1 - .../ops/impl/reduce/floating/SquaredNorm.java | 4 + .../impl/reduce/longer/MatchCondition.java | 10 +- .../linalg/api/ops/impl/reduce/same/Sum.java | 1 - .../impl/scalar/comparison/ScalarEquals.java | 2 +- .../scalar/comparison/ScalarGreaterThan.java | 4 +- .../comparison/ScalarGreaterThanOrEqual.java | 4 +- .../scalar/comparison/ScalarLessThan.java | 4 +- .../comparison/ScalarLessThanOrEqual.java | 4 +- .../scalar/comparison/ScalarNotEquals.java | 3 +- .../api/ops/impl/scatter/ScatterAdd.java | 9 +- .../api/ops/impl/scatter/ScatterDiv.java | 8 +- .../api/ops/impl/scatter/ScatterMax.java | 9 +- .../api/ops/impl/scatter/ScatterMin.java | 9 +- .../api/ops/impl/scatter/ScatterMul.java | 9 +- .../api/ops/impl/scatter/ScatterSub.java | 8 +- .../api/ops/impl/scatter/ScatterUpdate.java | 8 +- .../linalg/api/ops/impl/shape/Concat.java | 2 +- .../linalg/api/ops/impl/shape/ExpandDims.java | 11 +- .../linalg/api/ops/impl/shape/GatherNd.java | 4 + .../linalg/api/ops/impl/shape/Linspace.java | 1 + .../linalg/api/ops/impl/shape/OneHot.java | 17 +- .../api/ops/impl/shape/ParallelStack.java | 5 + .../linalg/api/ops/impl/shape/Permute.java | 11 +- .../linalg/api/ops/impl/shape/Reshape.java | 8 +- .../nd4j/linalg/api/ops/impl/shape/Size.java | 5 +- .../nd4j/linalg/api/ops/impl/shape/Slice.java | 6 +- .../api/ops/impl/shape/StridedSlice.java | 1 + .../nd4j/linalg/api/ops/impl/shape/Tile.java | 13 +- .../linalg/api/ops/impl/shape/Transpose.java | 4 +- .../bool/MatchConditionTransform.java | 2 - .../comparison/CompareAndReplace.java | 2 +- .../ops/impl/transforms/custom/Assign.java | 5 + .../transforms/custom/DynamicPartition.java | 15 +- .../impl/transforms/custom/DynamicStitch.java | 11 +- .../ops/impl/transforms/custom/EqualTo.java | 8 +- .../api/ops/impl/transforms/custom/Fill.java | 13 +- .../impl/transforms/custom/GreaterThan.java | 8 +- .../transforms/custom/GreaterThanOrEqual.java | 4 +- .../transforms/custom/IsNumericTensor.java | 4 +- .../ops/impl/transforms/custom/LessThan.java | 8 +- .../transforms/custom/LessThanOrEqual.java | 8 +- .../api/ops/impl/transforms/custom/Max.java | 8 +- .../api/ops/impl/transforms/custom/Min.java | 8 +- .../impl/transforms/custom/NotEqualTo.java | 8 +- .../transforms/custom/ReverseSequence.java | 16 +- .../transforms/custom/segment/SegmentMax.java | 4 +- .../custom/segment/SegmentMean.java | 8 +- .../transforms/custom/segment/SegmentMin.java | 8 +- .../custom/segment/SegmentProd.java | 8 +- .../transforms/custom/segment/SegmentSum.java | 8 +- .../ops/impl/transforms/same/Identity.java | 2 +- .../segment/UnsortedSegmentMax.java | 9 +- .../segment/UnsortedSegmentMean.java | 5 +- .../segment/UnsortedSegmentMin.java | 5 +- .../segment/UnsortedSegmentProd.java | 5 +- .../segment/UnsortedSegmentSqrtN.java | 12 +- .../segment/UnsortedSegmentSum.java | 5 +- .../org/nd4j/linalg/factory/NDValidation.java | 30 +- .../java/org/nd4j/linalg/factory/Nd4j.java | 6 +- .../nd4j/linalg/indexing/BooleanIndexing.java | 4 +- .../nd4j/linalg/learning/AdaGradUpdater.java | 1 + .../nd4j/linalg/learning/NadamUpdater.java | 2 - .../ops/executioner/CudaExecutioner.java | 15 +- .../nativecpu/ops/NativeOpExecutioner.java | 4 +- .../nd4j/linalg/factory/ops/NDBaseTest.java | 1030 +++++++++++++++++ .../linalg/indexing/BooleanIndexingTest.java | 16 +- 71 files changed, 1338 insertions(+), 234 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDBaseTest.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java index c01be78f9..127239bc7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java @@ -45,15 +45,14 @@ public class IMax extends BaseIndexAccumulation { super(x, z, dimensions); } - public IMax(INDArray x, boolean keepDims, int... dimensions) { - super(x, keepDims, dimensions); - - } - public IMax(INDArray x, int... dimensions) { super(x, null, dimensions); } + public IMax(INDArray x, boolean keepDims, int... dimensions) { + super(x, null, dimensions); + this.keepDims = keepDims; + } @Override public int opNum() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java index e668f1ee0..a459e8c9c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java @@ -53,6 +53,7 @@ public class IMin extends BaseIndexAccumulation { } + @Override public int opNum() { return 1; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java index 479c794c2..d4bec6ac2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java @@ -141,8 +141,8 @@ public class Mmul extends DynamicCustomOp { boolean transposeZ) { super(null,sameDiff,new SDVariable[]{x,y}); addIArgument(ArrayUtil.fromBoolean(transposeX), - ArrayUtil.fromBoolean(transposeY), - ArrayUtil.fromBoolean(transposeZ)); + ArrayUtil.fromBoolean(transposeY), + ArrayUtil.fromBoolean(transposeZ)); addTArgument(alpha, beta); mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build(); @@ -306,4 +306,3 @@ public class Mmul extends DynamicCustomOp { return Collections.singletonList(dataTypes.get(0)); } } - diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java index f58347492..820df18ab 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java @@ -51,6 +51,35 @@ public class TensorMmul extends DynamicCustomOp { protected boolean addedEdges; protected MMulTranspose mMulTranspose; + + public TensorMmul(INDArray x, INDArray y, int[][] axes) { + this(x,y,axes[0], axes[1], false, false, false); + } + + /** + * Initialize with the given + * input, pairwise transform, result, and number + * of elements + * + * @param x the input + * @param y the pairwise transform + * @param z the result + */ + public TensorMmul(INDArray x, INDArray y, INDArray z, int[][] axes) { + this(x, y, axes[0], axes[1], false, false, false); + } + + public TensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dimensionsY, + boolean transposeX, boolean transposeY, boolean transposeZ) { + super(null,new INDArray[]{x, y},null); + this.axes = new int[][]{dimensionsX, dimensionsY}; + addIArgument(dimensionsX.length); + addIArgument(dimensionsX); + addIArgument(dimensionsY.length); + addIArgument(dimensionsY); + addBArgument(transposeX, transposeY, transposeZ); + } + public TensorMmul(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, @@ -229,34 +258,6 @@ public class TensorMmul extends DynamicCustomOp { return sameDiff.reshape(ret, aPlusB); } - - public TensorMmul(INDArray x, INDArray y, int[][] axes) { - super(null,new INDArray[]{x, y},null); - this.axes = axes; - this.extraArgs = new Object[] {axes}; - } - - /** - * Initialize with the given - * input, pairwise transform, result, and number - * of elements - * - * @param x the input - * @param y the pairwise transform - * @param z the result - */ - public TensorMmul(INDArray x, INDArray y, INDArray z, int[][] axes) { - super(null,new INDArray[]{x, y, z},null); - this.axes = axes; - } - - public TensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dimensionsY, - boolean transposeX, boolean transposeY, boolean transposeZ) { - super(null,new INDArray[]{x, y},null); - this.axes = new int[][]{dimensionsX, dimensionsY}; - addBArgument(transposeX, transposeY, transposeZ); - } - @Override public String opName() { return "tensordot"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java index a7cf398f5..ea3fd140d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java @@ -48,7 +48,6 @@ public class NormMax extends BaseReduceFloatOp { super(x, null, z, dimensions); } - public NormMax(INDArray x, int... dimensions) { super(x, dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java index 2af86c181..f80a712d3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java @@ -48,6 +48,10 @@ public class SquaredNorm extends BaseReduceFloatOp { public SquaredNorm(){} + public SquaredNorm(INDArray x, int... dimensions){ + super(x, dimensions); + } + @Override public int opNum() { return 7; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/MatchCondition.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/MatchCondition.java index 0fb4db830..f2f097aa9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/MatchCondition.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/MatchCondition.java @@ -52,11 +52,15 @@ public class MatchCondition extends BaseReduceLongOp { public MatchCondition() {} - public MatchCondition(INDArray x, Condition condition, int... dimensions) { this(x, Nd4j.EPS_THRESHOLD, condition, dimensions); } + public MatchCondition(INDArray x, Condition condition, boolean keepDims, int... dimensions) { + this(x, Nd4j.EPS_THRESHOLD, condition, dimensions); + this.keepDims = keepDims; + } + public MatchCondition(INDArray x, double eps, Condition condition, int... dimensions) { super(x); this.compare = condition.getValue(); @@ -68,10 +72,6 @@ public class MatchCondition extends BaseReduceLongOp { defineDimensions(dimensions); } - public MatchCondition(INDArray in, Condition condition, boolean keepDim, int... dimensions) { - this(in, condition, dimensions); - } - @Override public int opNum() { return 2; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java index e6fa79bb0..f2c0b1d40 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java @@ -41,7 +41,6 @@ public class Sum extends BaseReduceSameOp { super(sameDiff, i_v, i_v2, dimensions); } - public Sum() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEquals.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEquals.java index ad2aa9b50..bafe8db88 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEquals.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEquals.java @@ -40,7 +40,7 @@ public class ScalarEquals extends BaseScalarBoolOp { } public ScalarEquals(INDArray x, Number num) { - super(x, num); + this(x, null, num); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThan.java index eafbcbc1a..524e66baa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThan.java @@ -19,9 +19,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import org.nd4j.linalg.api.ops.BaseScalarOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; import java.util.List; @@ -56,7 +58,7 @@ public class ScalarGreaterThan extends BaseScalarBoolOp { } public ScalarGreaterThan(INDArray x, Number num) { - super(x, num); + this(x, null, num); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThanOrEqual.java index 0948c01ab..09c001dda 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThanOrEqual.java @@ -19,9 +19,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import org.nd4j.linalg.api.ops.BaseScalarOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; import java.util.List; @@ -40,7 +42,7 @@ public class ScalarGreaterThanOrEqual extends BaseScalarBoolOp { } public ScalarGreaterThanOrEqual(INDArray x, Number num) { - super(x, num); + this(x, null, num); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java index 6f72490a1..740d05a79 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java @@ -18,9 +18,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import org.nd4j.linalg.api.ops.BaseScalarOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; import java.util.List; @@ -39,7 +41,7 @@ public class ScalarLessThan extends BaseScalarBoolOp { } public ScalarLessThan(INDArray x, Number num) { - super(x, num); + this(x, null, num); } public ScalarLessThan(SameDiff sameDiff, SDVariable i_v, Number scalar, boolean inPlace) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java index 6c9a3a893..343051ec6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java @@ -19,9 +19,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import org.nd4j.linalg.api.ops.BaseScalarOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; import java.util.List; @@ -49,7 +51,7 @@ public class ScalarLessThanOrEqual extends BaseScalarBoolOp { } public ScalarLessThanOrEqual(INDArray x, Number num) { - super(x, num); + this(x, null, num); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNotEquals.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNotEquals.java index f050b686e..52f4b7a99 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNotEquals.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNotEquals.java @@ -41,10 +41,9 @@ public class ScalarNotEquals extends BaseScalarBoolOp { } public ScalarNotEquals(INDArray x, Number num) { - super(x, num); + this(x, null, num); } - public ScalarNotEquals(SameDiff sameDiff, SDVariable i_v, Number scalar) { super(sameDiff, i_v, scalar); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java index 1846ab8f8..1c524ccf1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.scatter; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -44,12 +45,12 @@ public class ScatterAdd extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } - public ScatterAdd(INDArray ref, INDArray indices, INDArray updates) { - addInputArgument(ref, indices, updates); - } - public ScatterAdd(){} + public ScatterAdd(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){ + super(new INDArray[]{ref, indices, update}, null); + } + @Override public String opName() { return "scatter_add"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java index 75badc9c2..c2993eb23 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.scatter; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -44,11 +45,12 @@ public class ScatterDiv extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } - public ScatterDiv(INDArray ref, INDArray indices, INDArray updates) { - addInputArgument(ref, indices, updates); + public ScatterDiv() {} + + public ScatterDiv(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){ + super(new INDArray[]{ref, indices, update}, null); } - public ScatterDiv() {} @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java index 2dead9742..12fe3380a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.scatter; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -42,12 +43,12 @@ public class ScatterMax extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } - public ScatterMax(INDArray ref, INDArray indices, INDArray updates) { - addInputArgument(ref, indices, updates); - } - public ScatterMax() {} + public ScatterMax(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){ + super(new INDArray[]{ref, indices, update}, null); + } + @Override public String opName() { return "scatter_max"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java index 4af8a2cd2..91a49487e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.scatter; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -42,12 +43,12 @@ public class ScatterMin extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } - public ScatterMin(INDArray ref, INDArray indices, INDArray updates) { - addInputArgument(ref, indices, updates); - } - public ScatterMin() {} + public ScatterMin(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){ + super(new INDArray[]{ref, indices, update}, null); + } + @Override public String opName() { return "scatter_min"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java index 48e1a00bd..705b85d3d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.scatter; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -44,12 +45,12 @@ public class ScatterMul extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } - public ScatterMul(INDArray ref, INDArray indices, INDArray updates) { - addInputArgument(ref, indices, updates); - } - public ScatterMul() {} + public ScatterMul(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){ + super(new INDArray[]{ref, indices, update}, null); + } + @Override public String opName() { return "scatter_mul"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java index f66f7d689..15e6d5ac2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java @@ -44,12 +44,12 @@ public class ScatterSub extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } - public ScatterSub(INDArray ref, INDArray indices, INDArray updates) { - addInputArgument(ref, indices, updates); - } - public ScatterSub() {} + public ScatterSub(INDArray ref, INDArray indices, INDArray update){ + super(new INDArray[]{ref, indices, update}, null); + } + @Override public String opName() { return "scatter_sub"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java index c5644faa5..2e87af624 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java @@ -54,12 +54,12 @@ public class ScatterUpdate extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } - public ScatterUpdate(INDArray ref, INDArray indices, INDArray updates) { - addInputArgument(ref, indices, updates); - } - public ScatterUpdate(){} + public ScatterUpdate(INDArray ref, INDArray indices, INDArray update){ + super(new INDArray[]{ref, indices, update}, null); + } + @Override public String opName() { return "scatter_upd"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java index bddcef970..85cf62247 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java @@ -44,7 +44,7 @@ public class Concat extends DynamicCustomOp { } public Concat(int concatDimension, INDArray... arrays) { - super(null, arrays, new INDArray[0]); + super(null, arrays, null); this.concatDimension = concatDimension; addIArgument(concatDimension); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java index f0a6f436a..c51207d64 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java @@ -67,15 +67,16 @@ public class ExpandDims extends DynamicCustomOp { super(null, inputs, outputs); } - public ExpandDims(INDArray input, int axis) { - addInputArgument(input); - addIArgument(axis); - } - public ExpandDims(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(null, sameDiff, args, inPlace); } + public ExpandDims(INDArray x, int axis){ + super(new INDArray[]{x}, null); + this.jaxis = axis; + addIArgument(axis); + } + @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { val targetNode = TFGraphMapper.getNodeWithNameFromGraph(graph, nodeDef.getInput(1)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java index 593531098..7ac8429c8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java @@ -42,6 +42,10 @@ public class GatherNd extends DynamicCustomOp { super(new INDArray[]{df, indices}, null); } + public GatherNd(INDArray[] inputs, INDArray[] outputs){ + super(inputs, outputs); + } + @Override public String opName() { return "gather_nd"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java index ede375163..f83d61c0d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java index affc603e9..c08dcb1d6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java @@ -67,8 +67,7 @@ public class OneHot extends DynamicCustomOp { } public OneHot(INDArray indices, int depth) { - addInputArgument(indices); - addIArgument(depth); + this(indices, null, depth, 0, 1.0, 0.0); } public OneHot(INDArray indices, INDArray output, int depth, int axis, double on, double off) { @@ -80,14 +79,16 @@ public class OneHot extends DynamicCustomOp { addArgs(); } - public OneHot(INDArray indices, int depth, int axis, double on, double off, DataType dataType) { - addInputArgument(indices); - addIArgument(depth, axis); - addTArgument(on, off); - addDArgument(dataType); + public OneHot(INDArray indices, int depth, int axis, double on, double off) { + this(indices, null, depth, axis, on, off); } - + public OneHot(INDArray indices, int depth, int axis, double on, double off, DataType dataType) { + this(indices, null, depth, axis, on, off); + this.outputType = dataType; + if (outputType != null) + addDArgument(outputType); + } protected void addArgs() { addIArgument(jaxis); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java index 1827a6589..b4d71b40f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -44,6 +45,10 @@ public class ParallelStack extends DynamicCustomOp { super(null, sameDiff, values, false); } + public ParallelStack(INDArray[] inputs){ + super(inputs, null); + } + @Override public String opName() { return "parallel_stack"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java index 870a72a2b..85b9e2bb4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java @@ -55,15 +55,16 @@ public class Permute extends Transpose { addIArgument(permuteDims); } - public Permute(INDArray input, int... permuteDims){ - addInputArgument(input); - addIArgument(permuteDims); - } - public Permute(SameDiff sd, SDVariable input, SDVariable permuteDims){ super(sd, input, permuteDims); } + public Permute(INDArray input, int... permuteDims){ + super(input, null); + this.permuteDims = permuteDims; + addIArgument(permuteDims); + } + public Permute() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java index 6c5c0f9d6..47960fad3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NoArgsConstructor; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; @@ -41,6 +42,7 @@ import java.util.Map; * @author Adam Gibson */ @Slf4j +@NoArgsConstructor public class Reshape extends DynamicCustomOp { private long[] shape; @@ -61,15 +63,13 @@ public class Reshape extends DynamicCustomOp { addIArgument(shape); } - public Reshape(INDArray in, INDArray shape){ - this(in, shape, null); - } public Reshape(@NonNull INDArray in, @NonNull INDArray shape, INDArray out){ super(null, new INDArray[]{in, shape}, wrapOrNull(out), null, (List)null); } - public Reshape() { + public Reshape(INDArray in, INDArray shape){ + this(in, shape, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java index ce3ce9cae..acec28f68 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java @@ -48,11 +48,10 @@ public class Size extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {input}, false); } - public Size(INDArray in) { - addInputArgument(in); + public Size(INDArray in){ + super(new INDArray[] {in}, null); } - @Override public String onnxName() { throw new NoOpNameFoundException("No onnx name found for shape " + opName()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java index b9bcff540..effe95b2b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java @@ -54,8 +54,10 @@ public class Slice extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{input, begin, end}); } - public Slice(INDArray in, int[] begin, int... size) { - addInputArgument(in); + public Slice(INDArray input, int[] begin, int... size){ + super(new INDArray[] {input}, null); + this.begin = begin; + this.size = size; addIArgument(begin); addIArgument(size); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java index 33c79e217..53deb43ca 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NoArgsConstructor; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java index e90e31427..687342ed8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java @@ -67,13 +67,16 @@ public class Tile extends DynamicCustomOp { this(inputs,outputs,axis,false); } - public Tile(INDArray x, INDArray repeat) { - addInputArgument(x, repeat); + public Tile(INDArray x, INDArray repeat){ + super(null, new INDArray[] {x, repeat}, null); + this.jaxis = null; } - public Tile(INDArray x, int... repeat) { - addInputArgument(x); - addIArgument(repeat); + public Tile(INDArray inputs, int... axis){ + super(null, new INDArray[] {inputs}, null); + this.jaxis = axis; + this.is_static_reps = true; + addArguments(); } public Tile() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java index 95215b686..ea4096f63 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java @@ -60,8 +60,8 @@ public class Transpose extends DynamicCustomOp { super(null, new INDArray[]{input}, result == null ? null : new INDArray[]{result}, null, (List) null); } - public Transpose(INDArray input) { - addInputArgument(input); + public Transpose(INDArray input){ + this(input, null); } public Transpose() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java index dea1c9c3b..b3810ba15 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java @@ -62,12 +62,10 @@ public class MatchConditionTransform extends BaseTransformBoolOp { this(x, z, Nd4j.EPS_THRESHOLD, condition); } - public MatchConditionTransform(INDArray x, @NonNull Condition condition) { this(x, null, Nd4j.EPS_THRESHOLD, condition); } - public MatchConditionTransform(INDArray x, INDArray z, double eps, @NonNull Condition condition) { super(x, null, z); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/CompareAndReplace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/CompareAndReplace.java index 80ab7fd35..9c5b54c72 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/CompareAndReplace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/CompareAndReplace.java @@ -69,7 +69,7 @@ public class CompareAndReplace extends BaseTransformSameOp { * @param condition */ public CompareAndReplace(INDArray x, INDArray y, Condition condition) { - this(x, y, x, condition); + this(x, y, null, condition); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java index ca466ae34..e9ce44c57 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java @@ -24,6 +24,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -46,6 +47,10 @@ public class Assign extends DynamicCustomOp { super(null,inputs, outputs); } + public Assign(INDArray x, INDArray y ) { + this( new INDArray[]{y ,x},new INDArray[]{y}); // TODO: Still check. y cannot be null, must be same shape as x. + } + @Override public void addIArgument(int... arg) { super.addIArgument(arg); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java index 718120bf7..db11a6a6c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -67,11 +68,19 @@ public class DynamicPartition extends DynamicCustomOp { addArgs(); } - public DynamicPartition(INDArray input, INDArray partitions, int numPartitions) { - addInputArgument(input); - addIArgument(numPartitions); + public DynamicPartition(@NonNull INDArray input, @NonNull INDArray partitions, int numPartitions) { + super(new INDArray[]{input, partitions}, null); + this.numPartitions = numPartitions; + addArgs(); } + public DynamicPartition(INDArray x, INDArray [] partitions, int numPartitions){ + //TODO; This needs fixing. + super(new INDArray[]{x}, null); + // this.partitions = partitions; + this.numPartitions = numPartitions; + addArgs(); + } @Override public List doDiff(List i_v) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java index 60b2bf942..8c94c3e54 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.apache.commons.lang3.ArrayUtils; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -61,14 +62,8 @@ public class DynamicStitch extends DynamicCustomOp { this.numPartitions = inputs.length; } - public DynamicStitch(INDArray[] inputs, INDArray[] indices) { - for (INDArray input : inputs) { - addInputArgument(input); - } - - for (INDArray index : indices) { - addInputArgument(index); - } + public DynamicStitch(@NonNull INDArray[] indices, @NonNull INDArray[] inputs) { + super(ArrayUtils.addAll(indices, inputs), null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java index 0d1214c9a..e58609f2a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java @@ -48,14 +48,14 @@ public class EqualTo extends BaseDynamicTransformOp { super(inputs, outputs); } - public EqualTo( INDArray x, INDArray y) { - addInputArgument(x, y); - } - public EqualTo(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } + public EqualTo(INDArray x, INDArray y){ + this(new INDArray[]{x, y}, null); + } + @Override public String opName() { return "equals"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java index 73f221f35..0e39eb77b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.exception.ND4JIllegalStateException; +import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -55,19 +56,21 @@ public class Fill extends DynamicCustomOp { super(null,sameDiff, new SDVariable[] {shape}, false); this.value = value; this.outputDataType = outputDataType; + this.outputDataType = outputDataType; addArgs(); } + public Fill(INDArray shape, DataType outputDataType, double value) { + super(new INDArray[]{shape, Nd4j.scalar(outputDataType, value)}, null); + this.value = value; + this.outputDataType = outputDataType; + } + public Fill(INDArray shape, INDArray result, double value) { super(null, shape, result, Collections.singletonList(value), null); this.value = value; } - public Fill(INDArray shape, DataType dataType, double value) { - super(null, shape, null, Collections.singletonList(value), null); - this.value = value; - } - public Fill(INDArray shape, INDArray value, INDArray result) { super(null, new INDArray[]{shape, value}, new INDArray[]{result}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java index 6a1ecc2cf..e4b28e56a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java @@ -49,14 +49,14 @@ public class GreaterThan extends BaseDynamicTransformOp { super(inputs, outputs); } - public GreaterThan( INDArray x, INDArray y) { - addInputArgument(x,y); - } - public GreaterThan(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } + public GreaterThan(INDArray x, INDArray y){ + this(new INDArray[]{x, y}, null); + } + @Override public String opName() { return "greater"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java index dfb7fe8dd..0ebea5e9c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java @@ -52,9 +52,9 @@ public class GreaterThanOrEqual extends BaseDynamicTransformOp { this(new INDArray[]{x, y}, new INDArray[]{z}); } - public GreaterThanOrEqual(INDArray x, INDArray y) { - this(new INDArray[]{x,y}, null); + public GreaterThanOrEqual(INDArray x, INDArray y){ + this(new INDArray[]{x, y}, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java index 88c0a84ba..4f1804480 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java @@ -45,8 +45,8 @@ public class IsNumericTensor extends DynamicCustomOp { super(null, inputs, outputs); } - public IsNumericTensor(INDArray input) { - addInputArgument(input); + public IsNumericTensor(INDArray inputs) { + super( new INDArray[] {inputs}, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java index b1a38e0ff..0445b58c4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java @@ -49,14 +49,14 @@ public class LessThan extends BaseDynamicTransformOp { super(inputs, outputs); } - public LessThan( INDArray x, INDArray y) { - addInputArgument(x,y); - } - public LessThan(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } + public LessThan(INDArray x, INDArray y){ + this(new INDArray[]{x, y}, null); + } + @Override public String opName() { return "less"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java index 0ca6bf7e6..06e03335c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java @@ -48,14 +48,14 @@ public class LessThanOrEqual extends BaseDynamicTransformOp { super(inputs, outputs); } - public LessThanOrEqual( INDArray x, INDArray y) { - addInputArgument(x,y); - } - public LessThanOrEqual(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } + public LessThanOrEqual(INDArray x, INDArray y){ + this(new INDArray[]{x, y}, null); + } + @Override public String opName() { return "less_equal"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java index d2451c0f8..0197f0c79 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java @@ -48,12 +48,12 @@ public class Max extends BaseDynamicTransformOp { super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out}); } - public Max( INDArray[] inputs, INDArray[] outputs) { - super(inputs, outputs); + public Max( INDArray first, INDArray second){ + this(first, second, null); } - public Max( INDArray x, INDArray y) { - addInputArgument(x,y); + public Max( INDArray[] inputs, INDArray[] outputs) { + super(inputs, outputs); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java index c195178c2..cf0cf9c58 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java @@ -48,12 +48,12 @@ public class Min extends BaseDynamicTransformOp { super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out}); } - public Min( INDArray[] inputs, INDArray[] outputs) { - super(inputs, outputs); + public Min( INDArray first, INDArray second){ + this(first, second, null); } - public Min( INDArray x, INDArray y) { - addInputArgument(x,y); + public Min( INDArray[] inputs, INDArray[] outputs) { + super(inputs, outputs); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java index 69d724a7e..ba2e36ea2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java @@ -48,14 +48,14 @@ public class NotEqualTo extends BaseDynamicTransformOp { super(inputs, outputs); } - public NotEqualTo( INDArray x, INDArray y) { - addInputArgument(x,y); - } - public NotEqualTo(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } + public NotEqualTo(INDArray x, INDArray y){ + this(new INDArray[]{x, y}, null); + } + @Override public String opName() { return "not_equals"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java index 50332daf6..f7494c618 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java @@ -59,6 +59,17 @@ public class ReverseSequence extends DynamicCustomOp { addArguments(); } + public ReverseSequence(INDArray x, INDArray seq_lengths, int seqDim, int batchDim){ + super(new INDArray[]{x, seq_lengths}, null); + this.seqDim = seqDim; + this.batchDim = batchDim; + addArguments(); + } + + public ReverseSequence(INDArray x, INDArray seq_lengths){ + this(x, seq_lengths, 1, 0); + } + private void addArguments(){ addIArgument(seqDim); addIArgument(batchDim); @@ -67,11 +78,6 @@ public class ReverseSequence extends DynamicCustomOp { public ReverseSequence() { } - public ReverseSequence(INDArray x, INDArray seq_lengths, int seqDim, int batchDim) { - addInputArgument(x, seq_lengths); - addIArgument(seqDim, batchDim); - } - @Override public String opName() { return "reverse_sequence"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java index c98cd7d5b..217062e09 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java @@ -39,8 +39,8 @@ public class SegmentMax extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } - public SegmentMax(INDArray data, INDArray segmentIds) { - addInputArgument(data, segmentIds); + public SegmentMax(INDArray data, INDArray segmentIds){ + super(new INDArray[]{data, segmentIds}, null); } public SegmentMax(){ } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java index eca108b2a..79e887001 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java @@ -39,12 +39,12 @@ public class SegmentMean extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } - public SegmentMean(INDArray data, INDArray segmentIds) { - addInputArgument(data, segmentIds); - } - public SegmentMean(){ } + public SegmentMean(INDArray data, INDArray segmentIds){ + super(new INDArray[]{data, segmentIds}, null); + } + @Override public String opName(){ return "segment_mean"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java index f070dc8d1..367d6ee1a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java @@ -39,12 +39,12 @@ public class SegmentMin extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } - public SegmentMin(INDArray data, INDArray segmentIds) { - addInputArgument(data, segmentIds); - } - public SegmentMin(){ } + public SegmentMin(INDArray data, INDArray segmentIds){ + super(new INDArray[]{data, segmentIds}, null); + } + @Override public String opName(){ return "segment_min"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java index 71d0dd2c3..ddd719045 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java @@ -39,12 +39,12 @@ public class SegmentProd extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } - public SegmentProd(INDArray data, INDArray segmentIds) { - addInputArgument(data, segmentIds); - } - public SegmentProd(){ } + public SegmentProd(INDArray data, INDArray segmentIds){ + super(new INDArray[]{data, segmentIds}, null); + } + @Override public String opName(){ return "segment_prod"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java index a74aded65..9f6269848 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java @@ -39,12 +39,12 @@ public class SegmentSum extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } - public SegmentSum(INDArray data, INDArray segmentIds) { - addInputArgument(data, segmentIds); - } - public SegmentSum(){ } + public SegmentSum(INDArray data, INDArray segmentIds){ + super(new INDArray[]{data, segmentIds}, null); + } + @Override public String opName(){ return "segment_sum"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java index dcee02131..6918b8d2c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java @@ -42,7 +42,7 @@ public class Identity extends BaseDynamicTransformOp { } public Identity(INDArray x){ - addInputArgument(x); + super(new INDArray[]{x}, null); } public Identity(){ } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java index aeb543ea8..6f0432ba5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java @@ -41,13 +41,14 @@ public class UnsortedSegmentMax extends DynamicCustomOp { addIArgument(numSegments); } - public UnsortedSegmentMax(INDArray data, INDArray segmentIds, int numSegments) { - addInputArgument(data, segmentIds); + public UnsortedSegmentMax(){ } + + public UnsortedSegmentMax(INDArray data, INDArray segmentIds, int numSegments){ + super(new INDArray[]{data, segmentIds}, null); + this.numSegments = numSegments; addIArgument(numSegments); } - public UnsortedSegmentMax(){ } - @Override public String opName(){ return "unsorted_segment_max"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java index e869a84eb..ef39f1c91 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java @@ -45,8 +45,9 @@ public class UnsortedSegmentMean extends DynamicCustomOp { addIArgument(numSegments); } - public UnsortedSegmentMean(INDArray data, INDArray segmentIds, int numSegments) { - addInputArgument(data, segmentIds); + public UnsortedSegmentMean(INDArray data, INDArray segmentIds, int numSegments){ + super(new INDArray[]{data, segmentIds}, null); + this.numSegments = numSegments; addIArgument(numSegments); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java index fd0f5fd05..6dc6e7737 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java @@ -45,8 +45,9 @@ public class UnsortedSegmentMin extends DynamicCustomOp { addIArgument(numSegments); } - public UnsortedSegmentMin(INDArray data, INDArray segmentIds, int numSegments) { - addInputArgument(data, segmentIds); + public UnsortedSegmentMin(INDArray data, INDArray segmentIds, int numSegments){ + super(new INDArray[]{data, segmentIds}, null); + this.numSegments = numSegments; addIArgument(numSegments); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java index 12ec63222..f753fe6dc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java @@ -45,8 +45,9 @@ public class UnsortedSegmentProd extends DynamicCustomOp { addIArgument(numSegments); } - public UnsortedSegmentProd(INDArray data, INDArray segmentIds, int numSegments) { - addInputArgument(data, segmentIds); + public UnsortedSegmentProd(INDArray data, INDArray segmentIds, int numSegments){ + super(new INDArray[]{data, segmentIds}, null); + this.numSegments = numSegments; addIArgument(numSegments); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java index 9d7aceb96..ea5285f12 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java @@ -39,18 +39,18 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp { private int numSegments; - public UnsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments) { - addInputArgument(data, segmentIds); - addIArgument(numSegments); - this.numSegments = numSegments; - } - public UnsortedSegmentSqrtN(SameDiff sameDiff, SDVariable data, SDVariable segmentIds, int numSegments) { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); this.numSegments = numSegments; addIArgument(numSegments); } + public UnsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments){ + super(new INDArray[]{data, segmentIds}, null); + this.numSegments = numSegments; + addIArgument(numSegments); + } + @Override public String opName(){ return "unsorted_segment_sqrt_n"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java index 5e5cfd12e..d0d5b095b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java @@ -46,8 +46,9 @@ public class UnsortedSegmentSum extends DynamicCustomOp { addIArgument(numSegments); } - public UnsortedSegmentSum(INDArray data, INDArray segmentIds, int numSegments) { - addInputArgument(data, segmentIds); + public UnsortedSegmentSum(INDArray data, INDArray segmentIds, int numSegments){ + super(new INDArray[]{data, segmentIds}, null); + this.numSegments = numSegments; addIArgument(numSegments); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java index 4986b8277..e9dd0f840 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java @@ -130,13 +130,20 @@ public class NDValidation { " type; got array with non-integer data type " + v.dataType()); } - public static void validateInteger(String opName, String inputName, INDArray[] vars) { - for (INDArray v : vars) { - if (v == null) - return; - if (!v.dataType().isIntType()) + /** + * Validate that the operation is being applied on an integer type INDArray [] + * + * @param opName Operation name to print in the exception + * @param inputName Name of the input to the op to validate + * @param v Variable to validate datatype for (input to operation) + */ + public static void validateInteger(String opName, String inputName, INDArray [] v) { + if (v == null) + return; + for (int i = 0; i < v.length; i++) { + if (!v[i].dataType().isIntType()) throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer" + - " type; got array with non-integer data type " + v.dataType()); + " type; got array with non-integer data type member" + v[i].dataType()); } } @@ -246,11 +253,12 @@ public class NDValidation { } public static boolean isSameType(INDArray[] x) { - DataType firstDataType = x[0].dataType(); - if (x.length > 1) { - for (int i = 1; i < x.length; ++i) { - if (firstDataType != x[i].dataType()) - return false; + if(x.length == 0) + return true; + DataType first = x[0].dataType(); + for( int i=1; i