diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
index 3086b0f1b..621dac941 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
@@ -1288,6 +1288,22 @@ public class DifferentialFunctionFactory {
return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable();
}
+ public SDVariable bitwiseHammingDist(SDVariable x, SDVariable y) {
+ return new BitsHammingDistance(sameDiff(), x, y).outputVariable();
+ }
+
+ public SDVariable bitwiseAnd(SDVariable x, SDVariable y){
+ return new BitwiseAnd(sameDiff(), x, y).outputVariable();
+ }
+
+ public SDVariable bitwiseOr(SDVariable x, SDVariable y){
+ return new BitwiseOr(sameDiff(), x, y).outputVariable();
+ }
+
+ public SDVariable bitwiseXor(SDVariable x, SDVariable y){
+ return new BitwiseXor(sameDiff(), x, y).outputVariable();
+ }
+
public SDVariable eq(SDVariable iX, SDVariable i_y) {
return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable();
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
index e09ceda75..0b5a4c03f 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
@@ -188,6 +188,11 @@ public class SameDiff extends SDBaseOps {
*/
public final SDImage image = new SDImage(this);
+ /**
+ * Op creator object for bitwise operations
+ */
+ public final SDBitwise bitwise = new SDBitwise(this);
+
/**
* Op creator object for math operations
*/
@@ -237,6 +242,13 @@ public class SameDiff extends SDBaseOps {
return image;
}
+ /**
+ * Op creator object for bitwise operations
+ */
+ public SDBitwise bitwise(){
+ return bitwise;
+ }
+
/**
* For import, many times we have variables
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java
new file mode 100644
index 000000000..0857b2b42
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java
@@ -0,0 +1,205 @@
+package org.nd4j.autodiff.samediff.ops;
+
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+
+import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
+
+/**
+ *
+ */
+public class SDBitwise extends SDOps {
+ public SDBitwise(SameDiff sameDiff) {
+ super(sameDiff);
+ }
+
+ /**
+ * See {@link #leftShift(String, SDVariable, SDVariable)}
+ */
+ public SDVariable leftShift(@NonNull SDVariable x, @NonNull SDVariable y){
+ return leftShift(null, x, y);
+ }
+
+ /**
+ * Bitwise left shift operation. Supports broadcasting.
+ *
+ * @param name Name of the output variable. May be null.
+ * @param x Input to be bit shifted (must be an integer type)
+ * @param y Amount to shift elements of x array (must be an integer type)
+ * @return Bitwise shifted input x
+ */
+ public SDVariable leftShift(String name, SDVariable x, SDVariable y){
+ validateInteger("bitwise left shift", x);
+ validateInteger("bitwise left shift", y);
+
+ SDVariable ret = f().shift(x, y);
+ return updateVariableNameAndReference(ret, name);
+ }
+
+ /**
+ * See {@link #rightShift(String, SDVariable, SDVariable)}
+ */
+ public SDVariable rightShift(SDVariable x, SDVariable y){
+ return rightShift(null, x, y);
+ }
+
+ /**
+ * Bitwise right shift operation. Supports broadcasting.
+ *
+ * @param name Name of the output variable. May be null.
+ * @param x Input to be bit shifted (must be an integer type)
+ * @param y Amount to shift elements of x array (must be an integer type)
+ * @return Bitwise shifted input x
+ */
+ public SDVariable rightShift(String name, SDVariable x, SDVariable y){
+ validateInteger("bitwise right shift", x);
+ validateInteger("bitwise right shift", y);
+
+ SDVariable ret = f().rshift(x, y);
+ return updateVariableNameAndReference(ret, name);
+ }
+
+ /**
+ * See {@link #leftShiftCyclic(String, SDVariable, SDVariable)}
+ */
+ public SDVariable leftShiftCyclic(SDVariable x, SDVariable y){
+ return leftShiftCyclic(null, x, y);
+ }
+
+ /**
+ * Bitwise left cyclical shift operation. Supports broadcasting.
+ * Unlike {@link #leftShift(String, SDVariable, SDVariable)} the bits will "wrap around":
+ * {@code leftShiftCyclic(01110000, 2) -> 11000001}
+ *
+ * @param name Name of the output variable. May be null.
+ * @param x Input to be bit shifted (must be an integer type)
+ * @param y Amount to shift elements of x array (must be an integer type)
+ * @return Bitwise cyclic shifted input x
+ */
+ public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y){
+ validateInteger("bitwise left shift (cyclic)", x);
+ validateInteger("bitwise left shift (cyclic)", y);
+
+ SDVariable ret = f().rotl(x, y);
+ return updateVariableNameAndReference(ret, name);
+ }
+
+ /**
+ * See {@link #rightShiftCyclic(String, SDVariable, SDVariable)}
+ */
+ public SDVariable rightShiftCyclic(SDVariable x, SDVariable y){
+ return rightShiftCyclic(null, x, y);
+ }
+
+ /**
+ * Bitwise right cyclical shift operation. Supports broadcasting.
+ * Unlike {@link #rightShift(String, SDVariable, SDVariable)} the bits will "wrap around":
+ * {@code rightShiftCyclic(00001110, 2) -> 10000011}
+ *
+ * @param name Name of the output variable. May be null.
+ * @param x Input to be bit shifted (must be an integer type)
+ * @param y Amount to shift elements of x array (must be an integer type)
+ * @return Bitwise cyclic shifted input x
+ */
+ public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y){
+ validateInteger("bitwise right shift (cyclic)", x);
+ validateInteger("bitwise right shift (cyclic)", y);
+
+ SDVariable ret = f().rotr(x, y);
+ return updateVariableNameAndReference(ret, name);
+ }
+
+ /**
+ * See {@link #bitsHammingDistance(String, SDVariable, SDVariable)}
+ */
+ public SDVariable bitsHammingDistance(SDVariable x, SDVariable y){
+ return bitsHammingDistance(null, x, y);
+ }
+
+ /**
+ * Bitwise Hamming distance reduction over all elements of both input arrays.
+ * For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
+ *
+ * @param name Name of the output variable. May be null.
+ * @param x First input array. Must be integer type.
+ * @param y First input array. Must be integer type, same type as x
+ * @return
+ */
+ public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y){
+ validateInteger("bitwise hamming distance", x);
+ validateInteger("bitwise hamming distance", y);
+
+ SDVariable ret = f().bitwiseHammingDist(x, y);
+ return updateVariableNameAndReference(ret, name);
+ }
+
+ /**
+ * See {@link #and(String, SDVariable, SDVariable)}
+ */
+ public SDVariable and(SDVariable x, SDVariable y){
+ return and(null, x, y);
+ }
+
+ /**
+ * Bitwise AND operation. Supports broadcasting.
+ *
+ * @param name Name of the output variable. May be null.
+ * @param x First input array. Must be integer type.
+ * @param y First input array. Must be integer type, same type as x
+ * @return Bitwise AND array
+ */
+ public SDVariable and(String name, SDVariable x, SDVariable y){
+ validateInteger("bitwise AND", x);
+ validateInteger("bitwise AND", y);
+
+ SDVariable ret = f().bitwiseAnd(x, y);
+ return updateVariableNameAndReference(ret, name);
+ }
+
+ /**
+ * See {@link #or(String, SDVariable, SDVariable)}
+ */
+ public SDVariable or(SDVariable x, SDVariable y){
+ return or(null, x, y);
+ }
+
+ /**
+ * Bitwise OR operation. Supports broadcasting.
+ *
+ * @param name Name of the output variable. May be null.
+ * @param x First input array. Must be integer type.
+ * @param y First input array. Must be integer type, same type as x
+ * @return Bitwise OR array
+ */
+ public SDVariable or(String name, SDVariable x, SDVariable y){
+ validateInteger("bitwise OR", x);
+ validateInteger("bitwise OR", y);
+
+ SDVariable ret = f().bitwiseOr(x, y);
+ return updateVariableNameAndReference(ret, name);
+ }
+
+ /**
+ * See {@link #xor(String, SDVariable, SDVariable)}
+ */
+ public SDVariable xor(SDVariable x, SDVariable y){
+ return xor(null, x, y);
+ }
+
+ /**
+ * Bitwise XOR operation (exclusive OR). Supports broadcasting.
+ *
+ * @param name Name of the output variable. May be null.
+ * @param x First input array. Must be integer type.
+ * @param y First input array. Must be integer type, same type as x
+ * @return Bitwise XOR array
+ */
+ public SDVariable xor(String name, SDVariable x, SDVariable y){
+ validateInteger("bitwise XOR", x);
+ validateInteger("bitwise XOR", y);
+
+ SDVariable ret = f().bitwiseXor(x, y);
+ return updateVariableNameAndReference(ret, name);
+ }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
index 3a89b7339..19b534a97 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
@@ -353,6 +353,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class,
+ org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr.class,
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitsHammingDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitsHammingDistance.java
new file mode 100644
index 000000000..1fa749830
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitsHammingDistance.java
@@ -0,0 +1,37 @@
+package org.nd4j.linalg.api.ops.impl.transforms.custom;
+
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Collections;
+import java.util.List;
+
+public class BitsHammingDistance extends DynamicCustomOp {
+
+ public BitsHammingDistance(){ }
+
+ public BitsHammingDistance(@NonNull SameDiff sd, @NonNull SDVariable x, @NonNull SDVariable y){
+ super(sd, new SDVariable[]{x, y});
+ }
+
+ public BitsHammingDistance(@NonNull INDArray x, @NonNull INDArray y){
+ super(new INDArray[]{x, y}, null);
+ }
+
+ @Override
+ public String opName() {
+ return "bits_hamming_distance";
+ }
+
+ @Override
+ public List calculateOutputDataTypes(List dataTypes){
+ Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected 2 input datatypes, got %s", dataTypes);
+ Preconditions.checkState(dataTypes.get(0).isIntType() && dataTypes.get(1).isIntType(), "Input datatypes must be integer type, got %s", dataTypes);
+ return Collections.singletonList(DataType.LONG);
+ }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java
index 3a9173654..a8b4ebbb0 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java
@@ -61,7 +61,7 @@ public class CyclicRShiftBits extends BaseDynamicTransformOp {
@Override
public String tensorflowName() {
- throw new NoOpNameFoundException("No onnx op opName found for " + opName());
+ throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java
index 20b6f6955..ea7ae1715 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java
@@ -61,7 +61,7 @@ public class CyclicShiftBits extends BaseDynamicTransformOp {
@Override
public String tensorflowName() {
- throw new NoOpNameFoundException("No onnx op opName found for " + opName());
+ throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java
index 4435615f5..3cc03d12b 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java
@@ -61,7 +61,7 @@ public class RShiftBits extends BaseDynamicTransformOp {
@Override
public String tensorflowName() {
- throw new NoOpNameFoundException("No onnx op opName found for " + opName());
+ throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java
index 5501324f2..a9eebb14e 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java
@@ -61,7 +61,7 @@ public class ShiftBits extends BaseDynamicTransformOp {
@Override
public String tensorflowName() {
- throw new NoOpNameFoundException("No onnx op opName found for " + opName());
+ throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
}