Add SameDiff.bitwise namespace (#232)

* #8196 add SameDiff.bitwise namespace

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Add BitsHammingDistance, add remaining bitwise ops to bitwise namespace

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-09-04 22:34:31 +10:00 committed by GitHub
parent 548044a1e2
commit 03c52ef9dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 275 additions and 4 deletions

View File

@ -1288,6 +1288,22 @@ public class DifferentialFunctionFactory {
return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable(); return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable();
} }
public SDVariable bitwiseHammingDist(SDVariable x, SDVariable y) {
return new BitsHammingDistance(sameDiff(), x, y).outputVariable();
}
public SDVariable bitwiseAnd(SDVariable x, SDVariable y){
return new BitwiseAnd(sameDiff(), x, y).outputVariable();
}
public SDVariable bitwiseOr(SDVariable x, SDVariable y){
return new BitwiseOr(sameDiff(), x, y).outputVariable();
}
public SDVariable bitwiseXor(SDVariable x, SDVariable y){
return new BitwiseXor(sameDiff(), x, y).outputVariable();
}
public SDVariable eq(SDVariable iX, SDVariable i_y) { public SDVariable eq(SDVariable iX, SDVariable i_y) {
return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable(); return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable();
} }

View File

@ -188,6 +188,11 @@ public class SameDiff extends SDBaseOps {
*/ */
public final SDImage image = new SDImage(this); public final SDImage image = new SDImage(this);
/**
* Op creator object for bitwise operations
*/
public final SDBitwise bitwise = new SDBitwise(this);
/** /**
* Op creator object for math operations * Op creator object for math operations
*/ */
@ -237,6 +242,13 @@ public class SameDiff extends SDBaseOps {
return image; return image;
} }
/**
* Op creator object for bitwise operations
*/
public SDBitwise bitwise(){
return bitwise;
}
/** /**
* For import, many times we have variables * For import, many times we have variables

View File

@ -0,0 +1,205 @@
package org.nd4j.autodiff.samediff.ops;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
/**
*
*/
public class SDBitwise extends SDOps {
public SDBitwise(SameDiff sameDiff) {
super(sameDiff);
}
/**
* See {@link #leftShift(String, SDVariable, SDVariable)}
*/
public SDVariable leftShift(@NonNull SDVariable x, @NonNull SDVariable y){
return leftShift(null, x, y);
}
/**
* Bitwise left shift operation. Supports broadcasting.
*
* @param name Name of the output variable. May be null.
* @param x Input to be bit shifted (must be an integer type)
* @param y Amount to shift elements of x array (must be an integer type)
* @return Bitwise shifted input x
*/
public SDVariable leftShift(String name, SDVariable x, SDVariable y){
validateInteger("bitwise left shift", x);
validateInteger("bitwise left shift", y);
SDVariable ret = f().shift(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #rightShift(String, SDVariable, SDVariable)}
*/
public SDVariable rightShift(SDVariable x, SDVariable y){
return rightShift(null, x, y);
}
/**
* Bitwise right shift operation. Supports broadcasting.
*
* @param name Name of the output variable. May be null.
* @param x Input to be bit shifted (must be an integer type)
* @param y Amount to shift elements of x array (must be an integer type)
* @return Bitwise shifted input x
*/
public SDVariable rightShift(String name, SDVariable x, SDVariable y){
validateInteger("bitwise right shift", x);
validateInteger("bitwise right shift", y);
SDVariable ret = f().rshift(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #leftShiftCyclic(String, SDVariable, SDVariable)}
*/
public SDVariable leftShiftCyclic(SDVariable x, SDVariable y){
return leftShiftCyclic(null, x, y);
}
/**
* Bitwise left cyclical shift operation. Supports broadcasting.
* Unlike {@link #leftShift(String, SDVariable, SDVariable)} the bits will "wrap around":
* {@code leftShiftCyclic(01110000, 2) -> 11000001}
*
* @param name Name of the output variable. May be null.
* @param x Input to be bit shifted (must be an integer type)
* @param y Amount to shift elements of x array (must be an integer type)
* @return Bitwise cyclic shifted input x
*/
public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y){
validateInteger("bitwise left shift (cyclic)", x);
validateInteger("bitwise left shift (cyclic)", y);
SDVariable ret = f().rotl(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #rightShiftCyclic(String, SDVariable, SDVariable)}
*/
public SDVariable rightShiftCyclic(SDVariable x, SDVariable y){
return rightShiftCyclic(null, x, y);
}
/**
* Bitwise right cyclical shift operation. Supports broadcasting.
* Unlike {@link #rightShift(String, SDVariable, SDVariable)} the bits will "wrap around":
* {@code rightShiftCyclic(00001110, 2) -> 10000011}
*
* @param name Name of the output variable. May be null.
* @param x Input to be bit shifted (must be an integer type)
* @param y Amount to shift elements of x array (must be an integer type)
* @return Bitwise cyclic shifted input x
*/
public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y){
validateInteger("bitwise right shift (cyclic)", x);
validateInteger("bitwise right shift (cyclic)", y);
SDVariable ret = f().rotr(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #bitsHammingDistance(String, SDVariable, SDVariable)}
*/
public SDVariable bitsHammingDistance(SDVariable x, SDVariable y){
return bitsHammingDistance(null, x, y);
}
/**
* Bitwise Hamming distance reduction over all elements of both input arrays.<br>
* For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
*
* @param name Name of the output variable. May be null.
* @param x First input array. Must be integer type.
* @param y First input array. Must be integer type, same type as x
* @return
*/
public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y){
validateInteger("bitwise hamming distance", x);
validateInteger("bitwise hamming distance", y);
SDVariable ret = f().bitwiseHammingDist(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #and(String, SDVariable, SDVariable)}
*/
public SDVariable and(SDVariable x, SDVariable y){
return and(null, x, y);
}
/**
* Bitwise AND operation. Supports broadcasting.
*
* @param name Name of the output variable. May be null.
* @param x First input array. Must be integer type.
* @param y First input array. Must be integer type, same type as x
* @return Bitwise AND array
*/
public SDVariable and(String name, SDVariable x, SDVariable y){
validateInteger("bitwise AND", x);
validateInteger("bitwise AND", y);
SDVariable ret = f().bitwiseAnd(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #or(String, SDVariable, SDVariable)}
*/
public SDVariable or(SDVariable x, SDVariable y){
return or(null, x, y);
}
/**
* Bitwise OR operation. Supports broadcasting.
*
* @param name Name of the output variable. May be null.
* @param x First input array. Must be integer type.
* @param y First input array. Must be integer type, same type as x
* @return Bitwise OR array
*/
public SDVariable or(String name, SDVariable x, SDVariable y){
validateInteger("bitwise OR", x);
validateInteger("bitwise OR", y);
SDVariable ret = f().bitwiseOr(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #xor(String, SDVariable, SDVariable)}
*/
public SDVariable xor(SDVariable x, SDVariable y){
return xor(null, x, y);
}
/**
* Bitwise XOR operation (exclusive OR). Supports broadcasting.
*
* @param name Name of the output variable. May be null.
* @param x First input array. Must be integer type.
* @param y First input array. Must be integer type, same type as x
* @return Bitwise XOR array
*/
public SDVariable xor(String name, SDVariable x, SDVariable y){
validateInteger("bitwise XOR", x);
validateInteger("bitwise XOR", y);
SDVariable ret = f().bitwiseXor(x, y);
return updateVariableNameAndReference(ret, name);
}
}

View File

@ -353,6 +353,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd.class, org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor.class, org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr.class, org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr.class,

View File

@ -0,0 +1,37 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
public class BitsHammingDistance extends DynamicCustomOp {
public BitsHammingDistance(){ }
public BitsHammingDistance(@NonNull SameDiff sd, @NonNull SDVariable x, @NonNull SDVariable y){
super(sd, new SDVariable[]{x, y});
}
public BitsHammingDistance(@NonNull INDArray x, @NonNull INDArray y){
super(new INDArray[]{x, y}, null);
}
@Override
public String opName() {
return "bits_hamming_distance";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected 2 input datatypes, got %s", dataTypes);
Preconditions.checkState(dataTypes.get(0).isIntType() && dataTypes.get(1).isIntType(), "Input datatypes must be integer type, got %s", dataTypes);
return Collections.singletonList(DataType.LONG);
}
}

View File

@ -61,7 +61,7 @@ public class CyclicRShiftBits extends BaseDynamicTransformOp {
@Override @Override
public String tensorflowName() { public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName()); throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
} }

View File

@ -61,7 +61,7 @@ public class CyclicShiftBits extends BaseDynamicTransformOp {
@Override @Override
public String tensorflowName() { public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName()); throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
} }

View File

@ -61,7 +61,7 @@ public class RShiftBits extends BaseDynamicTransformOp {
@Override @Override
public String tensorflowName() { public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName()); throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
} }

View File

@ -61,7 +61,7 @@ public class ShiftBits extends BaseDynamicTransformOp {
@Override @Override
public String tensorflowName() { public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName()); throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
} }