[WIP] Java wrappers (#126)
* shift/rshift/rotl/rotr java/sd wrappers Signed-off-by: raver119 <raver119@gmail.com> * few additional wrappers Signed-off-by: raver119 <raver119@gmail.com> * minor naming tweak Signed-off-by: raver119 <raver119@gmail.com>master
parent
7fbc4b0933
commit
3ba616a161
|
@ -229,44 +229,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm;
|
|||
import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.Assign;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.Fill;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttentionBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatch;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.StandardizeBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.Trace;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.*;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin;
|
||||
|
@ -289,25 +252,8 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative;
|
|||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.TruncateDivOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.AddBpOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
|
||||
|
@ -1317,6 +1263,21 @@ public class DifferentialFunctionFactory {
|
|||
return new Xor(sameDiff(), ix, iy).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable shift(SDVariable ix, int shift) {
|
||||
return new ShiftBits(sameDiff(), ix, shift).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable rshift(SDVariable ix, int shift) {
|
||||
return new RShiftBits(sameDiff(), ix, shift).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable rotl(SDVariable ix, int shift) {
|
||||
return new CyclicShiftBits(sameDiff(), ix, shift).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable rotr(SDVariable ix, int shift) {
|
||||
return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable eq(SDVariable iX, SDVariable i_y) {
|
||||
return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable();
|
||||
|
@ -2231,6 +2192,10 @@ public class DifferentialFunctionFactory {
|
|||
return Arrays.asList(new MulBpOp(sameDiff(), x, y, grad).outputVariables());
|
||||
}
|
||||
|
||||
public List<SDVariable> modBp(SDVariable x, SDVariable y, SDVariable grad) {
|
||||
return Arrays.asList(new ModBpOp(sameDiff(), x, y, grad).outputVariables());
|
||||
}
|
||||
|
||||
|
||||
public SDVariable muli(SDVariable differentialFunction, SDVariable i_v) {
|
||||
validateDifferentialFunctionsameDiff(differentialFunction);
|
||||
|
@ -2238,6 +2203,10 @@ public class DifferentialFunctionFactory {
|
|||
|
||||
}
|
||||
|
||||
public SDVariable mod(SDVariable differentialFunction, SDVariable i_v) {
|
||||
validateDifferentialFunctionsameDiff(differentialFunction);
|
||||
return new ModOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable div(SDVariable differentialFunction, SDVariable i_v) {
|
||||
validateDifferentialFunctionsameDiff(differentialFunction);
|
||||
|
|
|
@ -804,6 +804,34 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
|||
return sameDiff.updateVariableNameAndReference(result, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Floor division operation: elementwise {@code this // x}<br>
|
||||
* If this and x variables have equal shape, the output shape is the same as the inputs.<br>
|
||||
* Supports broadcasting: if this and x have different shapes and are broadcastable, the output shape is broadcast.
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param x Variable to perform operation with
|
||||
* @return Output (result) SDVariable
|
||||
*/
|
||||
public SDVariable fdiv(String name, SDVariable x) {
|
||||
val result = sameDiff.f().floorDiv(this, x);
|
||||
return sameDiff.updateVariableNameAndReference(result, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Modulo operation: elementwise {@code this / x}<br>
|
||||
* If this and x variables have equal shape, the output shape is the same as the inputs.<br>
|
||||
* Supports broadcasting: if this and x have different shapes and are broadcastable, the output shape is broadcast.
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param x Variable to perform operation with
|
||||
* @return Output (result) SDVariable
|
||||
*/
|
||||
public SDVariable mod(String name, SDVariable x) {
|
||||
val result = sameDiff.f().mod(this, x);
|
||||
return sameDiff.updateVariableNameAndReference(result, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #mul(String, double)}
|
||||
*/
|
||||
|
|
|
@ -2421,6 +2421,58 @@ public class SDMath extends SDOps {
|
|||
return updateVariableNameAndReference(result, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Shift integer bits to the left, i.e. var << 4
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param x Input 1
|
||||
* @return Output SDVariable with shifted bits
|
||||
*/
|
||||
public SDVariable bitShift(String name, SDVariable x, int shift) {
|
||||
validateInteger("shift_bits", x);
|
||||
SDVariable result = f().shift(x, shift);
|
||||
return updateVariableNameAndReference(result, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Shift integer bits to the right, i.e. var >> 4
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param x Input 1
|
||||
* @return Output SDVariable with shifted bits
|
||||
*/
|
||||
public SDVariable bitShiftRight(String name, SDVariable x, int shift) {
|
||||
validateInteger("rshift_bits", x);
|
||||
SDVariable result = f().rshift(x, shift);
|
||||
return updateVariableNameAndReference(result, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param x Input 1
|
||||
* @return Output SDVariable with shifted bits
|
||||
*/
|
||||
public SDVariable bitRotl(String name, SDVariable x, int shift) {
|
||||
validateInteger("cyclic_shift_bits", x);
|
||||
SDVariable result = f().rotl(x, shift);
|
||||
return updateVariableNameAndReference(result, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param x Input 1
|
||||
* @return Output SDVariable with shifted bits
|
||||
*/
|
||||
public SDVariable bitRotr(String name, SDVariable x, int shift) {
|
||||
validateInteger("cyclic_rshift_bits", x);
|
||||
SDVariable result = f().rotr(x, shift);
|
||||
return updateVariableNameAndReference(result, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x))
|
||||
*
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Element-wise roll operation, rolls bits to the left, <<
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class CyclicRShiftBits extends BaseDynamicTransformOp {
|
||||
|
||||
public CyclicRShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
|
||||
super(sameDiff, new SDVariable[] {x} ,false);
|
||||
this.addIArgument(shift);
|
||||
}
|
||||
|
||||
public CyclicRShiftBits(INDArray input, int shift, INDArray output) {
|
||||
super(new INDArray[]{input}, new INDArray[]{output});
|
||||
this.addIArgument(shift);
|
||||
}
|
||||
|
||||
public CyclicRShiftBits(INDArray input, int shift) {
|
||||
this(input, shift,null);
|
||||
}
|
||||
|
||||
public CyclicRShiftBits() {}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "cyclic_rshift_bits";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
throw new UnsupportedOperationException("Not yet implemented: " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
|
||||
return Collections.singletonList(dataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Element-wise roll operation, rolls bits to the left, <<
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class CyclicShiftBits extends BaseDynamicTransformOp {
|
||||
|
||||
public CyclicShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
|
||||
super(sameDiff, new SDVariable[] {x} ,false);
|
||||
this.addIArgument(shift);
|
||||
}
|
||||
|
||||
public CyclicShiftBits(INDArray input, int shift, INDArray output) {
|
||||
super(new INDArray[]{input}, new INDArray[]{output});
|
||||
this.addIArgument(shift);
|
||||
}
|
||||
|
||||
public CyclicShiftBits(INDArray input, int shift) {
|
||||
this(input, shift,null);
|
||||
}
|
||||
|
||||
public CyclicShiftBits() {}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "cyclic_shift_bits";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
throw new UnsupportedOperationException("Not yet implemented: " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
|
||||
return Collections.singletonList(dataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Element-wise shift operation, shift bits to the right, >>
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class RShiftBits extends BaseDynamicTransformOp {
|
||||
|
||||
public RShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
|
||||
super(sameDiff, new SDVariable[] {x} ,false);
|
||||
this.addIArgument(shift);
|
||||
}
|
||||
|
||||
public RShiftBits(INDArray input, int shift, INDArray output) {
|
||||
super(new INDArray[]{input}, new INDArray[]{output});
|
||||
this.addIArgument(shift);
|
||||
}
|
||||
|
||||
public RShiftBits(INDArray input, int shift) {
|
||||
this(input, shift,null);
|
||||
}
|
||||
|
||||
public RShiftBits() {}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "rshift_bits";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
throw new UnsupportedOperationException("Not yet implemented: " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
|
||||
return Collections.singletonList(dataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Element-wise shift operation, shift bits to the left, <<
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class ShiftBits extends BaseDynamicTransformOp {
|
||||
|
||||
public ShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
|
||||
super(sameDiff, new SDVariable[] {x} ,false);
|
||||
this.addIArgument(shift);
|
||||
}
|
||||
|
||||
public ShiftBits(INDArray input, int shift, INDArray output) {
|
||||
super(new INDArray[]{input}, new INDArray[]{output});
|
||||
this.addIArgument(shift);
|
||||
}
|
||||
|
||||
public ShiftBits(INDArray input, int shift) {
|
||||
this(input, shift,null);
|
||||
}
|
||||
|
||||
public ShiftBits() {}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "shift_bits";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
throw new UnsupportedOperationException("Not yet implemented: " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
|
||||
return Collections.singletonList(dataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,69 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Modulo operation
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class ModOp extends BaseDynamicTransformOp {
|
||||
public static final String OP_NAME = "mod";
|
||||
|
||||
public ModOp() {}
|
||||
|
||||
public ModOp(SameDiff sameDiff, SDVariable[] args, boolean inPlace) {
|
||||
super(sameDiff, args, inPlace);
|
||||
}
|
||||
|
||||
public ModOp(INDArray first, INDArray second, INDArray result){
|
||||
this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result});
|
||||
}
|
||||
|
||||
public ModOp(INDArray[] inputs, INDArray[] outputs) {
|
||||
super(inputs, outputs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return OP_NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
return "Mod";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "mod";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
return f().modBp(larg(), rarg(), i_v.get(0));
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
||||
/**
|
||||
* Modulo backprop operation. Supports 'undoing' of auto broadcast as applied in div op forward pass
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class ModBpOp extends BaseArithmeticBackpropOp {
|
||||
|
||||
public ModBpOp() {}
|
||||
|
||||
public ModBpOp(SameDiff sameDiff, SDVariable x, SDVariable y, SDVariable eps) {
|
||||
super(sameDiff, x,y,eps);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "mod_bp";
|
||||
}
|
||||
}
|
|
@ -11131,7 +11131,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
// REGISTER_C(NAME)
|
||||
// nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) {
|
||||
// auto shapeList = SHAPELIST();
|
||||
// for (int e = 0; e < block.width(); e++) {
|
||||
// auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs();
|
||||
// for (int e = 0; e < opLimit; e++) {
|
||||
// auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e)));
|
||||
// shapeList->push_back(newshape);
|
||||
// }
|
||||
|
@ -11168,7 +11169,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
// REGISTER_C(NAME)
|
||||
// nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) {
|
||||
// auto shapeList = SHAPELIST();
|
||||
// for (int e = 0; e < block.width(); e++) {
|
||||
// auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs();
|
||||
// for (int e = 0; e < opLimit; e++) {
|
||||
// Nd4jLong* newshape;
|
||||
// COPY_SHAPE(inputShape->at(0), newshape);
|
||||
// shapeList->push_back(CONSTANT(newshape));
|
||||
|
@ -11191,7 +11193,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
// REGISTER_C(NAME)
|
||||
// nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) {
|
||||
// auto shapeList = SHAPELIST();
|
||||
// for (int e = 0; e < block.width(); e++) {
|
||||
// auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs();
|
||||
// for (int e = 0; e < opLimit; e++) {
|
||||
// auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e)));
|
||||
// shapeList->push_back(newshape);
|
||||
// }
|
||||
|
|
Loading…
Reference in New Issue