Add SameDiff.bitwise namespace (#232)
* #8196 add SameDiff.bitwise namespace Signed-off-by: AlexDBlack <blacka101@gmail.com> * Add BitsHammingDistance, add remaining bitwise ops to bitwise namespace Signed-off-by: AlexDBlack <blacka101@gmail.com> * fix Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
548044a1e2
commit
03c52ef9dd
|
@ -1288,6 +1288,22 @@ public class DifferentialFunctionFactory {
|
||||||
return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable();
|
return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SDVariable bitwiseHammingDist(SDVariable x, SDVariable y) {
|
||||||
|
return new BitsHammingDistance(sameDiff(), x, y).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable bitwiseAnd(SDVariable x, SDVariable y){
|
||||||
|
return new BitwiseAnd(sameDiff(), x, y).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable bitwiseOr(SDVariable x, SDVariable y){
|
||||||
|
return new BitwiseOr(sameDiff(), x, y).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable bitwiseXor(SDVariable x, SDVariable y){
|
||||||
|
return new BitwiseXor(sameDiff(), x, y).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
public SDVariable eq(SDVariable iX, SDVariable i_y) {
|
public SDVariable eq(SDVariable iX, SDVariable i_y) {
|
||||||
return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable();
|
return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
|
@ -188,6 +188,11 @@ public class SameDiff extends SDBaseOps {
|
||||||
*/
|
*/
|
||||||
public final SDImage image = new SDImage(this);
|
public final SDImage image = new SDImage(this);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Op creator object for bitwise operations
|
||||||
|
*/
|
||||||
|
public final SDBitwise bitwise = new SDBitwise(this);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Op creator object for math operations
|
* Op creator object for math operations
|
||||||
*/
|
*/
|
||||||
|
@ -237,6 +242,13 @@ public class SameDiff extends SDBaseOps {
|
||||||
return image;
|
return image;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Op creator object for bitwise operations
|
||||||
|
*/
|
||||||
|
public SDBitwise bitwise(){
|
||||||
|
return bitwise;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* For import, many times we have variables
|
* For import, many times we have variables
|
||||||
|
|
|
@ -0,0 +1,205 @@
|
||||||
|
package org.nd4j.autodiff.samediff.ops;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
|
||||||
|
import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public class SDBitwise extends SDOps {
|
||||||
|
public SDBitwise(SameDiff sameDiff) {
|
||||||
|
super(sameDiff);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #leftShift(String, SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
public SDVariable leftShift(@NonNull SDVariable x, @NonNull SDVariable y){
|
||||||
|
return leftShift(null, x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise left shift operation. Supports broadcasting.
|
||||||
|
*
|
||||||
|
* @param name Name of the output variable. May be null.
|
||||||
|
* @param x Input to be bit shifted (must be an integer type)
|
||||||
|
* @param y Amount to shift elements of x array (must be an integer type)
|
||||||
|
* @return Bitwise shifted input x
|
||||||
|
*/
|
||||||
|
public SDVariable leftShift(String name, SDVariable x, SDVariable y){
|
||||||
|
validateInteger("bitwise left shift", x);
|
||||||
|
validateInteger("bitwise left shift", y);
|
||||||
|
|
||||||
|
SDVariable ret = f().shift(x, y);
|
||||||
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #rightShift(String, SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
public SDVariable rightShift(SDVariable x, SDVariable y){
|
||||||
|
return rightShift(null, x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise right shift operation. Supports broadcasting.
|
||||||
|
*
|
||||||
|
* @param name Name of the output variable. May be null.
|
||||||
|
* @param x Input to be bit shifted (must be an integer type)
|
||||||
|
* @param y Amount to shift elements of x array (must be an integer type)
|
||||||
|
* @return Bitwise shifted input x
|
||||||
|
*/
|
||||||
|
public SDVariable rightShift(String name, SDVariable x, SDVariable y){
|
||||||
|
validateInteger("bitwise right shift", x);
|
||||||
|
validateInteger("bitwise right shift", y);
|
||||||
|
|
||||||
|
SDVariable ret = f().rshift(x, y);
|
||||||
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #leftShiftCyclic(String, SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
public SDVariable leftShiftCyclic(SDVariable x, SDVariable y){
|
||||||
|
return leftShiftCyclic(null, x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise left cyclical shift operation. Supports broadcasting.
|
||||||
|
* Unlike {@link #leftShift(String, SDVariable, SDVariable)} the bits will "wrap around":
|
||||||
|
* {@code leftShiftCyclic(01110000, 2) -> 11000001}
|
||||||
|
*
|
||||||
|
* @param name Name of the output variable. May be null.
|
||||||
|
* @param x Input to be bit shifted (must be an integer type)
|
||||||
|
* @param y Amount to shift elements of x array (must be an integer type)
|
||||||
|
* @return Bitwise cyclic shifted input x
|
||||||
|
*/
|
||||||
|
public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y){
|
||||||
|
validateInteger("bitwise left shift (cyclic)", x);
|
||||||
|
validateInteger("bitwise left shift (cyclic)", y);
|
||||||
|
|
||||||
|
SDVariable ret = f().rotl(x, y);
|
||||||
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #rightShiftCyclic(String, SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
public SDVariable rightShiftCyclic(SDVariable x, SDVariable y){
|
||||||
|
return rightShiftCyclic(null, x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise right cyclical shift operation. Supports broadcasting.
|
||||||
|
* Unlike {@link #rightShift(String, SDVariable, SDVariable)} the bits will "wrap around":
|
||||||
|
* {@code rightShiftCyclic(00001110, 2) -> 10000011}
|
||||||
|
*
|
||||||
|
* @param name Name of the output variable. May be null.
|
||||||
|
* @param x Input to be bit shifted (must be an integer type)
|
||||||
|
* @param y Amount to shift elements of x array (must be an integer type)
|
||||||
|
* @return Bitwise cyclic shifted input x
|
||||||
|
*/
|
||||||
|
public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y){
|
||||||
|
validateInteger("bitwise right shift (cyclic)", x);
|
||||||
|
validateInteger("bitwise right shift (cyclic)", y);
|
||||||
|
|
||||||
|
SDVariable ret = f().rotr(x, y);
|
||||||
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #bitsHammingDistance(String, SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
public SDVariable bitsHammingDistance(SDVariable x, SDVariable y){
|
||||||
|
return bitsHammingDistance(null, x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise Hamming distance reduction over all elements of both input arrays.<br>
|
||||||
|
* For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
|
||||||
|
*
|
||||||
|
* @param name Name of the output variable. May be null.
|
||||||
|
* @param x First input array. Must be integer type.
|
||||||
|
* @param y First input array. Must be integer type, same type as x
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y){
|
||||||
|
validateInteger("bitwise hamming distance", x);
|
||||||
|
validateInteger("bitwise hamming distance", y);
|
||||||
|
|
||||||
|
SDVariable ret = f().bitwiseHammingDist(x, y);
|
||||||
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #and(String, SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
public SDVariable and(SDVariable x, SDVariable y){
|
||||||
|
return and(null, x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise AND operation. Supports broadcasting.
|
||||||
|
*
|
||||||
|
* @param name Name of the output variable. May be null.
|
||||||
|
* @param x First input array. Must be integer type.
|
||||||
|
* @param y First input array. Must be integer type, same type as x
|
||||||
|
* @return Bitwise AND array
|
||||||
|
*/
|
||||||
|
public SDVariable and(String name, SDVariable x, SDVariable y){
|
||||||
|
validateInteger("bitwise AND", x);
|
||||||
|
validateInteger("bitwise AND", y);
|
||||||
|
|
||||||
|
SDVariable ret = f().bitwiseAnd(x, y);
|
||||||
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #or(String, SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
public SDVariable or(SDVariable x, SDVariable y){
|
||||||
|
return or(null, x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise OR operation. Supports broadcasting.
|
||||||
|
*
|
||||||
|
* @param name Name of the output variable. May be null.
|
||||||
|
* @param x First input array. Must be integer type.
|
||||||
|
* @param y First input array. Must be integer type, same type as x
|
||||||
|
* @return Bitwise OR array
|
||||||
|
*/
|
||||||
|
public SDVariable or(String name, SDVariable x, SDVariable y){
|
||||||
|
validateInteger("bitwise OR", x);
|
||||||
|
validateInteger("bitwise OR", y);
|
||||||
|
|
||||||
|
SDVariable ret = f().bitwiseOr(x, y);
|
||||||
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #xor(String, SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
public SDVariable xor(SDVariable x, SDVariable y){
|
||||||
|
return xor(null, x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise XOR operation (exclusive OR). Supports broadcasting.
|
||||||
|
*
|
||||||
|
* @param name Name of the output variable. May be null.
|
||||||
|
* @param x First input array. Must be integer type.
|
||||||
|
* @param y First input array. Must be integer type, same type as x
|
||||||
|
* @return Bitwise XOR array
|
||||||
|
*/
|
||||||
|
public SDVariable xor(String name, SDVariable x, SDVariable y){
|
||||||
|
validateInteger("bitwise XOR", x);
|
||||||
|
validateInteger("bitwise XOR", y);
|
||||||
|
|
||||||
|
SDVariable ret = f().bitwiseXor(x, y);
|
||||||
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
}
|
||||||
|
}
|
|
@ -353,6 +353,7 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class,
|
org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class,
|
org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class,
|
org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd.class,
|
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor.class,
|
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr.class,
|
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr.class,
|
||||||
|
|
|
@ -0,0 +1,37 @@
|
||||||
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class BitsHammingDistance extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public BitsHammingDistance(){ }
|
||||||
|
|
||||||
|
public BitsHammingDistance(@NonNull SameDiff sd, @NonNull SDVariable x, @NonNull SDVariable y){
|
||||||
|
super(sd, new SDVariable[]{x, y});
|
||||||
|
}
|
||||||
|
|
||||||
|
public BitsHammingDistance(@NonNull INDArray x, @NonNull INDArray y){
|
||||||
|
super(new INDArray[]{x, y}, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "bits_hamming_distance";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||||
|
Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected 2 input datatypes, got %s", dataTypes);
|
||||||
|
Preconditions.checkState(dataTypes.get(0).isIntType() && dataTypes.get(1).isIntType(), "Input datatypes must be integer type, got %s", dataTypes);
|
||||||
|
return Collections.singletonList(DataType.LONG);
|
||||||
|
}
|
||||||
|
}
|
|
@ -61,7 +61,7 @@ public class CyclicRShiftBits extends BaseDynamicTransformOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -61,7 +61,7 @@ public class CyclicShiftBits extends BaseDynamicTransformOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -61,7 +61,7 @@ public class RShiftBits extends BaseDynamicTransformOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -61,7 +61,7 @@ public class ShiftBits extends BaseDynamicTransformOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue