diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 3086b0f1b..621dac941 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -1288,6 +1288,22 @@ public class DifferentialFunctionFactory { return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable(); } + public SDVariable bitwiseHammingDist(SDVariable x, SDVariable y) { + return new BitsHammingDistance(sameDiff(), x, y).outputVariable(); + } + + public SDVariable bitwiseAnd(SDVariable x, SDVariable y){ + return new BitwiseAnd(sameDiff(), x, y).outputVariable(); + } + + public SDVariable bitwiseOr(SDVariable x, SDVariable y){ + return new BitwiseOr(sameDiff(), x, y).outputVariable(); + } + + public SDVariable bitwiseXor(SDVariable x, SDVariable y){ + return new BitwiseXor(sameDiff(), x, y).outputVariable(); + } + public SDVariable eq(SDVariable iX, SDVariable i_y) { return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index e09ceda75..0b5a4c03f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -188,6 +188,11 @@ public class SameDiff extends SDBaseOps { */ public final SDImage image = new SDImage(this); + /** + * Op creator object for bitwise operations + */ + public final SDBitwise bitwise = new SDBitwise(this); + /** * Op creator object for math operations */ @@ -237,6 +242,13 @@ public class SameDiff extends SDBaseOps { return image; } + /** + * Op creator object for bitwise operations + */ + public SDBitwise bitwise(){ + return bitwise; + } + /** * For import, many times we have variables diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java new file mode 100644 index 000000000..0857b2b42 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java @@ -0,0 +1,205 @@ +package org.nd4j.autodiff.samediff.ops; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; + +import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger; + +/** + * + */ +public class SDBitwise extends SDOps { + public SDBitwise(SameDiff sameDiff) { + super(sameDiff); + } + + /** + * See {@link #leftShift(String, SDVariable, SDVariable)} + */ + public SDVariable leftShift(@NonNull SDVariable x, @NonNull SDVariable y){ + return leftShift(null, x, y); + } + + /** + * Bitwise left shift operation. Supports broadcasting. + * + * @param name Name of the output variable. May be null. + * @param x Input to be bit shifted (must be an integer type) + * @param y Amount to shift elements of x array (must be an integer type) + * @return Bitwise shifted input x + */ + public SDVariable leftShift(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise left shift", x); + validateInteger("bitwise left shift", y); + + SDVariable ret = f().shift(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #rightShift(String, SDVariable, SDVariable)} + */ + public SDVariable rightShift(SDVariable x, SDVariable y){ + return rightShift(null, x, y); + } + + /** + * Bitwise right shift operation. Supports broadcasting. + * + * @param name Name of the output variable. May be null. + * @param x Input to be bit shifted (must be an integer type) + * @param y Amount to shift elements of x array (must be an integer type) + * @return Bitwise shifted input x + */ + public SDVariable rightShift(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise right shift", x); + validateInteger("bitwise right shift", y); + + SDVariable ret = f().rshift(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #leftShiftCyclic(String, SDVariable, SDVariable)} + */ + public SDVariable leftShiftCyclic(SDVariable x, SDVariable y){ + return leftShiftCyclic(null, x, y); + } + + /** + * Bitwise left cyclical shift operation. Supports broadcasting. + * Unlike {@link #leftShift(String, SDVariable, SDVariable)} the bits will "wrap around": + * {@code leftShiftCyclic(01110000, 2) -> 11000001} + * + * @param name Name of the output variable. May be null. + * @param x Input to be bit shifted (must be an integer type) + * @param y Amount to shift elements of x array (must be an integer type) + * @return Bitwise cyclic shifted input x + */ + public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise left shift (cyclic)", x); + validateInteger("bitwise left shift (cyclic)", y); + + SDVariable ret = f().rotl(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #rightShiftCyclic(String, SDVariable, SDVariable)} + */ + public SDVariable rightShiftCyclic(SDVariable x, SDVariable y){ + return rightShiftCyclic(null, x, y); + } + + /** + * Bitwise right cyclical shift operation. Supports broadcasting. + * Unlike {@link #rightShift(String, SDVariable, SDVariable)} the bits will "wrap around": + * {@code rightShiftCyclic(00001110, 2) -> 10000011} + * + * @param name Name of the output variable. May be null. + * @param x Input to be bit shifted (must be an integer type) + * @param y Amount to shift elements of x array (must be an integer type) + * @return Bitwise cyclic shifted input x + */ + public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise right shift (cyclic)", x); + validateInteger("bitwise right shift (cyclic)", y); + + SDVariable ret = f().rotr(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #bitsHammingDistance(String, SDVariable, SDVariable)} + */ + public SDVariable bitsHammingDistance(SDVariable x, SDVariable y){ + return bitsHammingDistance(null, x, y); + } + + /** + * Bitwise Hamming distance reduction over all elements of both input arrays.
+ * For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1) + * + * @param name Name of the output variable. May be null. + * @param x First input array. Must be integer type. + * @param y First input array. Must be integer type, same type as x + * @return + */ + public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise hamming distance", x); + validateInteger("bitwise hamming distance", y); + + SDVariable ret = f().bitwiseHammingDist(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #and(String, SDVariable, SDVariable)} + */ + public SDVariable and(SDVariable x, SDVariable y){ + return and(null, x, y); + } + + /** + * Bitwise AND operation. Supports broadcasting. + * + * @param name Name of the output variable. May be null. + * @param x First input array. Must be integer type. + * @param y First input array. Must be integer type, same type as x + * @return Bitwise AND array + */ + public SDVariable and(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise AND", x); + validateInteger("bitwise AND", y); + + SDVariable ret = f().bitwiseAnd(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #or(String, SDVariable, SDVariable)} + */ + public SDVariable or(SDVariable x, SDVariable y){ + return or(null, x, y); + } + + /** + * Bitwise OR operation. Supports broadcasting. + * + * @param name Name of the output variable. May be null. + * @param x First input array. Must be integer type. + * @param y First input array. Must be integer type, same type as x + * @return Bitwise OR array + */ + public SDVariable or(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise OR", x); + validateInteger("bitwise OR", y); + + SDVariable ret = f().bitwiseOr(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #xor(String, SDVariable, SDVariable)} + */ + public SDVariable xor(SDVariable x, SDVariable y){ + return xor(null, x, y); + } + + /** + * Bitwise XOR operation (exclusive OR). Supports broadcasting. + * + * @param name Name of the output variable. May be null. + * @param x First input array. Must be integer type. + * @param y First input array. Must be integer type, same type as x + * @return Bitwise XOR array + */ + public SDVariable xor(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise XOR", x); + validateInteger("bitwise XOR", y); + + SDVariable ret = f().bitwiseXor(x, y); + return updateVariableNameAndReference(ret, name); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 3a89b7339..19b534a97 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -353,6 +353,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance.class, org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd.class, org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor.class, org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitsHammingDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitsHammingDistance.java new file mode 100644 index 000000000..1fa749830 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitsHammingDistance.java @@ -0,0 +1,37 @@ +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +public class BitsHammingDistance extends DynamicCustomOp { + + public BitsHammingDistance(){ } + + public BitsHammingDistance(@NonNull SameDiff sd, @NonNull SDVariable x, @NonNull SDVariable y){ + super(sd, new SDVariable[]{x, y}); + } + + public BitsHammingDistance(@NonNull INDArray x, @NonNull INDArray y){ + super(new INDArray[]{x, y}, null); + } + + @Override + public String opName() { + return "bits_hamming_distance"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected 2 input datatypes, got %s", dataTypes); + Preconditions.checkState(dataTypes.get(0).isIntType() && dataTypes.get(1).isIntType(), "Input datatypes must be integer type, got %s", dataTypes); + return Collections.singletonList(DataType.LONG); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java index 3a9173654..a8b4ebbb0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java @@ -61,7 +61,7 @@ public class CyclicRShiftBits extends BaseDynamicTransformOp { @Override public String tensorflowName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java index 20b6f6955..ea7ae1715 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java @@ -61,7 +61,7 @@ public class CyclicShiftBits extends BaseDynamicTransformOp { @Override public String tensorflowName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java index 4435615f5..3cc03d12b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java @@ -61,7 +61,7 @@ public class RShiftBits extends BaseDynamicTransformOp { @Override public String tensorflowName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java index 5501324f2..a9eebb14e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java @@ -61,7 +61,7 @@ public class ShiftBits extends BaseDynamicTransformOp { @Override public String tensorflowName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName()); }