From 3ba616a161ee2637402d847a7fdb191f06c445bd Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 16 Aug 2019 13:14:26 +0300 Subject: [PATCH] [WIP] Java wrappers (#126) * shift/rshift/rotl/rotr java/sd wrappers Signed-off-by: raver119 * few additional wrappers Signed-off-by: raver119 * minor naming tweak Signed-off-by: raver119 --- .../DifferentialFunctionFactory.java | 83 ++++++------------- .../nd4j/autodiff/samediff/SDVariable.java | 28 +++++++ .../nd4j/autodiff/samediff/ops/SDMath.java | 52 ++++++++++++ .../transforms/custom/CyclicRShiftBits.java | 80 ++++++++++++++++++ .../transforms/custom/CyclicShiftBits.java | 80 ++++++++++++++++++ .../impl/transforms/custom/RShiftBits.java | 80 ++++++++++++++++++ .../ops/impl/transforms/custom/ShiftBits.java | 80 ++++++++++++++++++ .../transforms/pairwise/arithmetic/ModOp.java | 69 +++++++++++++++ .../pairwise/arithmetic/bp/ModBpOp.java | 39 +++++++++ .../java/org/nd4j/nativeblas/Nd4jCpu.java | 9 +- 10 files changed, 540 insertions(+), 60 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/ModOp.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/ModBpOp.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 34800ca07..c03ff276a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -229,44 +229,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm; import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet; -import org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Assign; -import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace; -import org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd; -import org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D; -import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention; -import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionBp; -import org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition; -import org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch; -import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Fill; -import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan; -import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual; -import org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation; -import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing; -import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor; -import org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing; -import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm; -import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp; -import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan; -import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual; -import org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff; -import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax; -import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant; -import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse; -import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag; -import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention; -import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttentionBp; -import org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse; -import org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence; -import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; -import org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatch; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize; -import org.nd4j.linalg.api.ops.impl.transforms.custom.StandardizeBp; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Trace; -import org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB; +import org.nd4j.linalg.api.ops.impl.transforms.custom.*; import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean; import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin; @@ -289,25 +252,8 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.TruncateDivOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.AddBpOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor; @@ -1317,6 +1263,21 @@ public class DifferentialFunctionFactory { return new Xor(sameDiff(), ix, iy).outputVariable(); } + public SDVariable shift(SDVariable ix, int shift) { + return new ShiftBits(sameDiff(), ix, shift).outputVariable(); + } + + public SDVariable rshift(SDVariable ix, int shift) { + return new RShiftBits(sameDiff(), ix, shift).outputVariable(); + } + + public SDVariable rotl(SDVariable ix, int shift) { + return new CyclicShiftBits(sameDiff(), ix, shift).outputVariable(); + } + + public SDVariable rotr(SDVariable ix, int shift) { + return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable(); + } public SDVariable eq(SDVariable iX, SDVariable i_y) { return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable(); @@ -2231,6 +2192,10 @@ public class DifferentialFunctionFactory { return Arrays.asList(new MulBpOp(sameDiff(), x, y, grad).outputVariables()); } + public List modBp(SDVariable x, SDVariable y, SDVariable grad) { + return Arrays.asList(new ModBpOp(sameDiff(), x, y, grad).outputVariables()); + } + public SDVariable muli(SDVariable differentialFunction, SDVariable i_v) { validateDifferentialFunctionsameDiff(differentialFunction); @@ -2238,6 +2203,10 @@ public class DifferentialFunctionFactory { } + public SDVariable mod(SDVariable differentialFunction, SDVariable i_v) { + validateDifferentialFunctionsameDiff(differentialFunction); + return new ModOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); + } public SDVariable div(SDVariable differentialFunction, SDVariable i_v) { validateDifferentialFunctionsameDiff(differentialFunction); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index f618b1186..64749da1e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -804,6 +804,34 @@ public class SDVariable extends DifferentialFunction implements Serializable { return sameDiff.updateVariableNameAndReference(result, name); } + /** + * Floor division operation: elementwise {@code this // x}
+ * If this and x variables have equal shape, the output shape is the same as the inputs.
+ * Supports broadcasting: if this and x have different shapes and are broadcastable, the output shape is broadcast. + * + * @param name Name of the output variable + * @param x Variable to perform operation with + * @return Output (result) SDVariable + */ + public SDVariable fdiv(String name, SDVariable x) { + val result = sameDiff.f().floorDiv(this, x); + return sameDiff.updateVariableNameAndReference(result, name); + } + + /** + * Modulo operation: elementwise {@code this / x}
+ * If this and x variables have equal shape, the output shape is the same as the inputs.
+ * Supports broadcasting: if this and x have different shapes and are broadcastable, the output shape is broadcast. + * + * @param name Name of the output variable + * @param x Variable to perform operation with + * @return Output (result) SDVariable + */ + public SDVariable mod(String name, SDVariable x) { + val result = sameDiff.f().mod(this, x); + return sameDiff.updateVariableNameAndReference(result, name); + } + /** * See {@link #mul(String, double)} */ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index 5543db3c9..70eaa5cd9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -2421,6 +2421,58 @@ public class SDMath extends SDOps { return updateVariableNameAndReference(result, name); } + /** + * Shift integer bits to the left, i.e. var << 4 + * + * @param name Name of the output variable + * @param x Input 1 + * @return Output SDVariable with shifted bits + */ + public SDVariable bitShift(String name, SDVariable x, int shift) { + validateInteger("shift_bits", x); + SDVariable result = f().shift(x, shift); + return updateVariableNameAndReference(result, name); + } + + /** + * Shift integer bits to the right, i.e. var >> 4 + * + * @param name Name of the output variable + * @param x Input 1 + * @return Output SDVariable with shifted bits + */ + public SDVariable bitShiftRight(String name, SDVariable x, int shift) { + validateInteger("rshift_bits", x); + SDVariable result = f().rshift(x, shift); + return updateVariableNameAndReference(result, name); + } + + /** + * Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4) + * + * @param name Name of the output variable + * @param x Input 1 + * @return Output SDVariable with shifted bits + */ + public SDVariable bitRotl(String name, SDVariable x, int shift) { + validateInteger("cyclic_shift_bits", x); + SDVariable result = f().rotl(x, shift); + return updateVariableNameAndReference(result, name); + } + + /** + * Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4) + * + * @param name Name of the output variable + * @param x Input 1 + * @return Output SDVariable with shifted bits + */ + public SDVariable bitRotr(String name, SDVariable x, int shift) { + validateInteger("cyclic_rshift_bits", x); + SDVariable result = f().rotr(x, shift); + return updateVariableNameAndReference(result, name); + } + /** * Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x)) * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java new file mode 100644 index 000000000..318a7dc02 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java @@ -0,0 +1,80 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.Collections; +import java.util.List; + +/** + * Element-wise roll operation, rolls bits to the left, << + * + * @author raver119@gmail.com + */ +public class CyclicRShiftBits extends BaseDynamicTransformOp { + + public CyclicRShiftBits(SameDiff sameDiff, SDVariable x, int shift) { + super(sameDiff, new SDVariable[] {x} ,false); + this.addIArgument(shift); + } + + public CyclicRShiftBits(INDArray input, int shift, INDArray output) { + super(new INDArray[]{input}, new INDArray[]{output}); + this.addIArgument(shift); + } + + public CyclicRShiftBits(INDArray input, int shift) { + this(input, shift,null); + } + + public CyclicRShiftBits() {} + + @Override + public String opName() { + return "cyclic_rshift_bits"; + } + + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + @Override + public String tensorflowName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + + @Override + public List doDiff(List i_v) { + throw new UnsupportedOperationException("Not yet implemented: " + opName()); + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0)); + return Collections.singletonList(dataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java new file mode 100644 index 000000000..b4291c5df --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java @@ -0,0 +1,80 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.Collections; +import java.util.List; + +/** + * Element-wise roll operation, rolls bits to the left, << + * + * @author raver119@gmail.com + */ +public class CyclicShiftBits extends BaseDynamicTransformOp { + + public CyclicShiftBits(SameDiff sameDiff, SDVariable x, int shift) { + super(sameDiff, new SDVariable[] {x} ,false); + this.addIArgument(shift); + } + + public CyclicShiftBits(INDArray input, int shift, INDArray output) { + super(new INDArray[]{input}, new INDArray[]{output}); + this.addIArgument(shift); + } + + public CyclicShiftBits(INDArray input, int shift) { + this(input, shift,null); + } + + public CyclicShiftBits() {} + + @Override + public String opName() { + return "cyclic_shift_bits"; + } + + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + @Override + public String tensorflowName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + + @Override + public List doDiff(List i_v) { + throw new UnsupportedOperationException("Not yet implemented: " + opName()); + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0)); + return Collections.singletonList(dataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java new file mode 100644 index 000000000..80697efa3 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java @@ -0,0 +1,80 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.Collections; +import java.util.List; + +/** + * Element-wise shift operation, shift bits to the right, >> + * + * @author raver119@gmail.com + */ +public class RShiftBits extends BaseDynamicTransformOp { + + public RShiftBits(SameDiff sameDiff, SDVariable x, int shift) { + super(sameDiff, new SDVariable[] {x} ,false); + this.addIArgument(shift); + } + + public RShiftBits(INDArray input, int shift, INDArray output) { + super(new INDArray[]{input}, new INDArray[]{output}); + this.addIArgument(shift); + } + + public RShiftBits(INDArray input, int shift) { + this(input, shift,null); + } + + public RShiftBits() {} + + @Override + public String opName() { + return "rshift_bits"; + } + + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + @Override + public String tensorflowName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + + @Override + public List doDiff(List i_v) { + throw new UnsupportedOperationException("Not yet implemented: " + opName()); + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0)); + return Collections.singletonList(dataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java new file mode 100644 index 000000000..8c652f72d --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java @@ -0,0 +1,80 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.Collections; +import java.util.List; + +/** + * Element-wise shift operation, shift bits to the left, << + * + * @author raver119@gmail.com + */ +public class ShiftBits extends BaseDynamicTransformOp { + + public ShiftBits(SameDiff sameDiff, SDVariable x, int shift) { + super(sameDiff, new SDVariable[] {x} ,false); + this.addIArgument(shift); + } + + public ShiftBits(INDArray input, int shift, INDArray output) { + super(new INDArray[]{input}, new INDArray[]{output}); + this.addIArgument(shift); + } + + public ShiftBits(INDArray input, int shift) { + this(input, shift,null); + } + + public ShiftBits() {} + + @Override + public String opName() { + return "shift_bits"; + } + + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + @Override + public String tensorflowName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + + @Override + public List doDiff(List i_v) { + throw new UnsupportedOperationException("Not yet implemented: " + opName()); + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0)); + return Collections.singletonList(dataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/ModOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/ModOp.java new file mode 100644 index 000000000..289333f96 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/ModOp.java @@ -0,0 +1,69 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.List; + +/** + * Modulo operation + * + * @author raver119@gmail.com + */ +public class ModOp extends BaseDynamicTransformOp { + public static final String OP_NAME = "mod"; + + public ModOp() {} + + public ModOp(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { + super(sameDiff, args, inPlace); + } + + public ModOp(INDArray first, INDArray second, INDArray result){ + this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result}); + } + + public ModOp(INDArray[] inputs, INDArray[] outputs) { + super(inputs, outputs); + } + + @Override + public String opName() { + return OP_NAME; + } + + @Override + public String onnxName() { + return "Mod"; + } + + @Override + public String tensorflowName() { + return "mod"; + } + + @Override + public List doDiff(List i_v) { + return f().modBp(larg(), rarg(), i_v.get(0)); + } + + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/ModBpOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/ModBpOp.java new file mode 100644 index 000000000..a9c401dd7 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/ModBpOp.java @@ -0,0 +1,39 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; + +/** + * Modulo backprop operation. Supports 'undoing' of auto broadcast as applied in div op forward pass + * + * @author raver119@gmail.com + */ +public class ModBpOp extends BaseArithmeticBackpropOp { + + public ModBpOp() {} + + public ModBpOp(SameDiff sameDiff, SDVariable x, SDVariable y, SDVariable eps) { + super(sameDiff, x,y,eps); + } + + @Override + public String opName() { + return "mod_bp"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 65c8f4b9f..e0d53a66f 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -11131,7 +11131,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // REGISTER_C(NAME) // nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { // auto shapeList = SHAPELIST(); -// for (int e = 0; e < block.width(); e++) { +// auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs(); +// for (int e = 0; e < opLimit; e++) { // auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e))); // shapeList->push_back(newshape); // } @@ -11168,7 +11169,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // REGISTER_C(NAME) // nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { // auto shapeList = SHAPELIST(); -// for (int e = 0; e < block.width(); e++) { +// auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs(); +// for (int e = 0; e < opLimit; e++) { // Nd4jLong* newshape; // COPY_SHAPE(inputShape->at(0), newshape); // shapeList->push_back(CONSTANT(newshape)); @@ -11191,7 +11193,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // REGISTER_C(NAME) // nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { // auto shapeList = SHAPELIST(); -// for (int e = 0; e < block.width(); e++) { +// auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs(); +// for (int e = 0; e < opLimit; e++) { // auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e))); // shapeList->push_back(newshape); // }