[WIP] Java wrappers (#126)

* shift/rshift/rotl/rotr java/sd wrappers

Signed-off-by: raver119 <raver119@gmail.com>

* few additional wrappers

Signed-off-by: raver119 <raver119@gmail.com>

* minor naming tweak

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-16 13:14:26 +03:00 committed by GitHub
parent 7fbc4b0933
commit 3ba616a161
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 540 additions and 60 deletions

View File

@ -229,44 +229,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm;
import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Assign;
import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace;
import org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd;
import org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch;
import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Fill;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttentionBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatch;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
import org.nd4j.linalg.api.ops.impl.transforms.custom.StandardizeBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Trace;
import org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB;
import org.nd4j.linalg.api.ops.impl.transforms.custom.*;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin;
@ -289,25 +252,8 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.TruncateDivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.AddBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
@ -1317,6 +1263,21 @@ public class DifferentialFunctionFactory {
return new Xor(sameDiff(), ix, iy).outputVariable();
}
public SDVariable shift(SDVariable ix, int shift) {
return new ShiftBits(sameDiff(), ix, shift).outputVariable();
}
public SDVariable rshift(SDVariable ix, int shift) {
return new RShiftBits(sameDiff(), ix, shift).outputVariable();
}
public SDVariable rotl(SDVariable ix, int shift) {
return new CyclicShiftBits(sameDiff(), ix, shift).outputVariable();
}
public SDVariable rotr(SDVariable ix, int shift) {
return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable();
}
public SDVariable eq(SDVariable iX, SDVariable i_y) {
return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable();
@ -2231,6 +2192,10 @@ public class DifferentialFunctionFactory {
return Arrays.asList(new MulBpOp(sameDiff(), x, y, grad).outputVariables());
}
public List<SDVariable> modBp(SDVariable x, SDVariable y, SDVariable grad) {
return Arrays.asList(new ModBpOp(sameDiff(), x, y, grad).outputVariables());
}
public SDVariable muli(SDVariable differentialFunction, SDVariable i_v) {
validateDifferentialFunctionsameDiff(differentialFunction);
@ -2238,6 +2203,10 @@ public class DifferentialFunctionFactory {
}
public SDVariable mod(SDVariable differentialFunction, SDVariable i_v) {
validateDifferentialFunctionsameDiff(differentialFunction);
return new ModOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable();
}
public SDVariable div(SDVariable differentialFunction, SDVariable i_v) {
validateDifferentialFunctionsameDiff(differentialFunction);

View File

@ -804,6 +804,34 @@ public class SDVariable extends DifferentialFunction implements Serializable {
return sameDiff.updateVariableNameAndReference(result, name);
}
/**
* Floor division operation: elementwise {@code this // x}<br>
* If this and x variables have equal shape, the output shape is the same as the inputs.<br>
* Supports broadcasting: if this and x have different shapes and are broadcastable, the output shape is broadcast.
*
* @param name Name of the output variable
* @param x Variable to perform operation with
* @return Output (result) SDVariable
*/
public SDVariable fdiv(String name, SDVariable x) {
val result = sameDiff.f().floorDiv(this, x);
return sameDiff.updateVariableNameAndReference(result, name);
}
/**
* Modulo operation: elementwise {@code this / x}<br>
* If this and x variables have equal shape, the output shape is the same as the inputs.<br>
* Supports broadcasting: if this and x have different shapes and are broadcastable, the output shape is broadcast.
*
* @param name Name of the output variable
* @param x Variable to perform operation with
* @return Output (result) SDVariable
*/
public SDVariable mod(String name, SDVariable x) {
val result = sameDiff.f().mod(this, x);
return sameDiff.updateVariableNameAndReference(result, name);
}
/**
* See {@link #mul(String, double)}
*/

View File

@ -2421,6 +2421,58 @@ public class SDMath extends SDOps {
return updateVariableNameAndReference(result, name);
}
/**
* Shift integer bits to the left, i.e. var << 4
*
* @param name Name of the output variable
* @param x Input 1
* @return Output SDVariable with shifted bits
*/
public SDVariable bitShift(String name, SDVariable x, int shift) {
validateInteger("shift_bits", x);
SDVariable result = f().shift(x, shift);
return updateVariableNameAndReference(result, name);
}
/**
* Shift integer bits to the right, i.e. var >> 4
*
* @param name Name of the output variable
* @param x Input 1
* @return Output SDVariable with shifted bits
*/
public SDVariable bitShiftRight(String name, SDVariable x, int shift) {
validateInteger("rshift_bits", x);
SDVariable result = f().rshift(x, shift);
return updateVariableNameAndReference(result, name);
}
/**
* Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)
*
* @param name Name of the output variable
* @param x Input 1
* @return Output SDVariable with shifted bits
*/
public SDVariable bitRotl(String name, SDVariable x, int shift) {
validateInteger("cyclic_shift_bits", x);
SDVariable result = f().rotl(x, shift);
return updateVariableNameAndReference(result, name);
}
/**
* Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)
*
* @param name Name of the output variable
* @param x Input 1
* @return Output SDVariable with shifted bits
*/
public SDVariable bitRotr(String name, SDVariable x, int shift) {
validateInteger("cyclic_rshift_bits", x);
SDVariable result = f().rotr(x, shift);
return updateVariableNameAndReference(result, name);
}
/**
* Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x))
*

View File

@ -0,0 +1,80 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.Collections;
import java.util.List;
/**
* Element-wise roll operation, rolls bits to the left, <<
*
* @author raver119@gmail.com
*/
public class CyclicRShiftBits extends BaseDynamicTransformOp {
public CyclicRShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
super(sameDiff, new SDVariable[] {x} ,false);
this.addIArgument(shift);
}
public CyclicRShiftBits(INDArray input, int shift, INDArray output) {
super(new INDArray[]{input}, new INDArray[]{output});
this.addIArgument(shift);
}
public CyclicRShiftBits(INDArray input, int shift) {
this(input, shift,null);
}
public CyclicRShiftBits() {}
@Override
public String opName() {
return "cyclic_rshift_bits";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
throw new UnsupportedOperationException("Not yet implemented: " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@ -0,0 +1,80 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.Collections;
import java.util.List;
/**
* Element-wise roll operation, rolls bits to the left, <<
*
* @author raver119@gmail.com
*/
public class CyclicShiftBits extends BaseDynamicTransformOp {
public CyclicShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
super(sameDiff, new SDVariable[] {x} ,false);
this.addIArgument(shift);
}
public CyclicShiftBits(INDArray input, int shift, INDArray output) {
super(new INDArray[]{input}, new INDArray[]{output});
this.addIArgument(shift);
}
public CyclicShiftBits(INDArray input, int shift) {
this(input, shift,null);
}
public CyclicShiftBits() {}
@Override
public String opName() {
return "cyclic_shift_bits";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
throw new UnsupportedOperationException("Not yet implemented: " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@ -0,0 +1,80 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.Collections;
import java.util.List;
/**
* Element-wise shift operation, shift bits to the right, >>
*
* @author raver119@gmail.com
*/
public class RShiftBits extends BaseDynamicTransformOp {
public RShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
super(sameDiff, new SDVariable[] {x} ,false);
this.addIArgument(shift);
}
public RShiftBits(INDArray input, int shift, INDArray output) {
super(new INDArray[]{input}, new INDArray[]{output});
this.addIArgument(shift);
}
public RShiftBits(INDArray input, int shift) {
this(input, shift,null);
}
public RShiftBits() {}
@Override
public String opName() {
return "rshift_bits";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
throw new UnsupportedOperationException("Not yet implemented: " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@ -0,0 +1,80 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.Collections;
import java.util.List;
/**
* Element-wise shift operation, shift bits to the left, <<
*
* @author raver119@gmail.com
*/
public class ShiftBits extends BaseDynamicTransformOp {
public ShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
super(sameDiff, new SDVariable[] {x} ,false);
this.addIArgument(shift);
}
public ShiftBits(INDArray input, int shift, INDArray output) {
super(new INDArray[]{input}, new INDArray[]{output});
this.addIArgument(shift);
}
public ShiftBits(INDArray input, int shift) {
this(input, shift,null);
}
public ShiftBits() {}
@Override
public String opName() {
return "shift_bits";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
throw new UnsupportedOperationException("Not yet implemented: " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@ -0,0 +1,69 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.List;
/**
* Modulo operation
*
* @author raver119@gmail.com
*/
public class ModOp extends BaseDynamicTransformOp {
public static final String OP_NAME = "mod";
public ModOp() {}
public ModOp(SameDiff sameDiff, SDVariable[] args, boolean inPlace) {
super(sameDiff, args, inPlace);
}
public ModOp(INDArray first, INDArray second, INDArray result){
this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result});
}
public ModOp(INDArray[] inputs, INDArray[] outputs) {
super(inputs, outputs);
}
@Override
public String opName() {
return OP_NAME;
}
@Override
public String onnxName() {
return "Mod";
}
@Override
public String tensorflowName() {
return "mod";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
return f().modBp(larg(), rarg(), i_v.get(0));
}
}

View File

@ -0,0 +1,39 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
/**
* Modulo backprop operation. Supports 'undoing' of auto broadcast as applied in div op forward pass
*
* @author raver119@gmail.com
*/
public class ModBpOp extends BaseArithmeticBackpropOp {
public ModBpOp() {}
public ModBpOp(SameDiff sameDiff, SDVariable x, SDVariable y, SDVariable eps) {
super(sameDiff, x,y,eps);
}
@Override
public String opName() {
return "mod_bp";
}
}

View File

@ -11131,7 +11131,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// REGISTER_C(NAME)
// nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) {
// auto shapeList = SHAPELIST();
// for (int e = 0; e < block.width(); e++) {
// auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs();
// for (int e = 0; e < opLimit; e++) {
// auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e)));
// shapeList->push_back(newshape);
// }
@ -11168,7 +11169,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// REGISTER_C(NAME)
// nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) {
// auto shapeList = SHAPELIST();
// for (int e = 0; e < block.width(); e++) {
// auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs();
// for (int e = 0; e < opLimit; e++) {
// Nd4jLong* newshape;
// COPY_SHAPE(inputShape->at(0), newshape);
// shapeList->push_back(CONSTANT(newshape));
@ -11191,7 +11193,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// REGISTER_C(NAME)
// nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) {
// auto shapeList = SHAPELIST();
// for (int e = 0; e < block.width(); e++) {
// auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs();
// for (int e = 0; e < opLimit; e++) {
// auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e)));
// shapeList->push_back(newshape);
// }