From 5e152c0d9a9b3029c3539a6d23efebdae605475c Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Mon, 2 Dec 2019 12:23:06 +0200 Subject: [PATCH] TF import tests - adding missing operations (#65) * Add and fix mappings. * Intermediate * Added and fixed some mappings * Added op * Missing constructors added. * Added new mappings * SDImage wrappers and minor tweaks. * Added missing constructor * Some corrections * Cleanup * Small fixes * Ops wrappers * Minor fixes. * Max Pooling * MaxPoolWithArgmax * Some fixes * Ignores for failures * Some ops fixed. * Some fixes * Missing package added * Some fixes * Ignored tests fixed. * Some fixes * Merge master * bitcast fix Signed-off-by: raver119 * Bitcast fixed --- .../DifferentialFunctionFactory.java | 30 ++ .../nd4j/autodiff/samediff/ops/SDBitwise.java | 12 + .../nd4j/autodiff/samediff/ops/SDImage.java | 63 +++- .../nd4j/autodiff/samediff/ops/SDMath.java | 52 ++++ .../org/nd4j/autodiff/samediff/ops/SDNN.java | 32 ++ .../converters/ImportClassMapping.java | 14 +- .../linalg/api/ops/custom/AdjustContrast.java | 21 +- .../api/ops/custom/AdjustContrastV2.java | 20 +- .../nd4j/linalg/api/ops/custom/AdjustHue.java | 69 +++++ .../api/ops/custom/AdjustSaturation.java | 68 +++++ .../api/ops/custom/BaseAdjustContrast.java | 22 +- .../nd4j/linalg/api/ops/custom/BetaInc.java | 67 +++++ .../nd4j/linalg/api/ops/custom/BitCast.java | 24 +- .../api/ops/custom/CompareAndBitpack.java | 21 +- .../linalg/api/ops/custom/DivideNoNan.java | 21 +- .../api/ops/custom/DrawBoundingBoxes.java | 23 +- .../FakeQuantWithMinMaxVarsPerChannel.java | 23 +- .../linalg/api/ops/custom/FusedBatchNorm.java | 64 ++++ .../linalg/api/ops/custom/MatrixBandPart.java | 65 ++++ .../nd4j/linalg/api/ops/custom/Polygamma.java | 66 ++++ .../linalg/api/ops/custom/RandomCrop.java | 61 ++++ .../org/nd4j/linalg/api/ops/custom/Roll.java | 64 ++++ .../linalg/api/ops/custom/ToggleBits.java | 64 ++++ .../api/ops/impl/image/NonMaxSuppression.java | 9 +- .../layers/convolution/MaxPoolWithArgmax.java | 284 ++++++++++++++++++ .../impl/layers/convolution/MaxPooling2D.java | 4 +- .../ops/impl/transforms/clip/ClipByValue.java | 2 +- .../impl/transforms/custom/RShiftBits.java | 8 +- .../ops/impl/transforms/custom/ShiftBits.java | 8 +- .../transforms/custom/UniqueWithCounts.java | 4 +- .../pairwise/arithmetic/CopyOp.java | 4 +- .../transforms/pairwise/arithmetic/ModOp.java | 2 +- .../impl/transforms/pairwise/bool/Not.java | 8 +- .../api/ops/impl/transforms/strict/GELU.java | 13 - .../random/custom/DistributionUniform.java | 14 +- .../linalg/api/ops/random/impl/DropOut.java | 12 - .../opvalidation/LayerOpValidation.java | 29 ++ .../TFGraphs/TFGraphTestAllSameDiff.java | 30 +- .../nd4j/linalg/custom/CustomOpsTests.java | 234 +++++++++++++++ 39 files changed, 1545 insertions(+), 86 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustHue.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustSaturation.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BetaInc.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java create mode 100644 
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Polygamma.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RandomCrop.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Roll.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ToggleBits.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index abea31459..7f59d24e4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -2616,6 +2616,36 @@ public class DifferentialFunctionFactory { return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max).outputVariable(); } + public SDVariable betainc( SDVariable a, SDVariable b, SDVariable x) { + return new BetaInc(sameDiff, a, b, x).outputVariable(); + } + + public SDVariable[] fusedBatchNorm(SDVariable x, SDVariable scale, SDVariable offset, + SDVariable dataFormat, SDVariable isTraining) { + return new FusedBatchNorm(sameDiff,x,scale,offset,dataFormat,isTraining).outputVariables(); + } + + public SDVariable matrixBandPart(SDVariable input, SDVariable minLower, SDVariable maxUpper) { + return new MatrixBandPart(sameDiff,input,minLower,maxUpper).outputVariable(); + } + + public SDVariable[] maxPoolWithArgmaxs(SDVariable x, Pooling2DConfig pooling2DConfig) { + return new MaxPoolWithArgmax(sameDiff, x, pooling2DConfig).outputVariables(); + } + + public SDVariable polygamma(SDVariable n, SDVariable x) { + return new Polygamma(sameDiff, n,x).outputVariable(); + } + + public SDVariable roll(SDVariable input, SDVariable shift) { + return new Roll(sameDiff, input, shift).outputVariable(); + } + + public SDVariable toggleBits(SDVariable x) { + return new ToggleBits(sameDiff, x).outputVariable(); + } + + public String toString() { return "DifferentialFunctionFactory{methodNames=" + methodNames + "}"; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java index 0857b2b42..a255afbc3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java @@ -202,4 +202,16 @@ public class SDBitwise extends SDOps { SDVariable ret = f().bitwiseXor(x, y); return updateVariableNameAndReference(ret, name); } + + /** + * Flip bits + * + * @param name Name of the output variable + * @param x input array + * @return array after flipping each input bit + */ + public SDVariable toggleBits(String name, SDVariable x) { + SDVariable res = f().toggleBits(x); + return updateVariableNameAndReference(res, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java index f7166ab5e..bf71a665e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java @@ -3,6 +3,10 @@ package org.nd4j.autodiff.samediff.ops; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ops.custom.AdjustContrast; +import org.nd4j.linalg.api.ops.custom.AdjustHue; +import org.nd4j.linalg.api.ops.custom.AdjustSaturation; +import org.nd4j.linalg.api.ops.custom.RandomCrop; import org.nd4j.linalg.api.ops.impl.image.CropAndResize; import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches; import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression; @@ -52,10 +56,67 @@ public class SDImage extends SDOps { return updateVariableNameAndReference(out, name); } - + /** + * Greedily selects a subset of bounding boxes in descending order of score + * @param name Might be null. Name for the output variable + * @param boxes 2D array of shape [num_boxes,4] + * @param scores vector of shape [num_boxes] + * @param maxOutSize scalar representing the maximum number of boxes to be selected + * @param iouThreshold float - threshold for deciding whether boxes overlap too much with respect to IOU + * @param scoreThreshold float - threshold for deciding when to remove boxes based on score + * @return vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size + */ public SDVariable nonMaxSuppression(String name, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize, @NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold){ SDVariable out = new NonMaxSuppression(sd, boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable(); return updateVariableNameAndReference(out, name); } + + /** + * Adjusts contrast of RGB or grayscale images. + * @param name name for the output variable + * @param in images to adjust. 3D shape or higher. + * @param factor float multiplier for adjusting contrast. 
+ * @return Contrast-adjusted image + */ + public SDVariable adjustContrast(String name, @NonNull SDVariable in, @NonNull SDVariable factor) { + SDVariable out = new AdjustContrast(sd, in, factor).outputVariable(); + return updateVariableNameAndReference(out, name); + } + + /** + * Adjust saturation of RGB images + * @param name name for the output variable + * @param in RGB image as 3D array + * @param factor factor for saturation + * @return adjusted image + */ + public SDVariable adjustSaturation(String name, @NonNull SDVariable in, @NonNull SDVariable factor) { + SDVariable out = new AdjustSaturation(sd, in, factor).outputVariable(); + return updateVariableNameAndReference(out, name); + } + + /** + * Adjust hue of RGB image + * @param name name for the output variable + * @param in RGB image as 3D array + * @param delta value to add to hue channel + * @return adjusted image + */ + public SDVariable adjustHue(String name, @NonNull SDVariable in, @NonNull SDVariable delta) { + SDVariable out = new AdjustHue(sd, in, delta).outputVariable(); + return updateVariableNameAndReference(out, name); + } + + /** + * Randomly crops image + * @param name name for the output variable + * @param input input array + * @param shape shape for crop + * @return cropped array + */ + public SDVariable randomCrop(String name, @NonNull SDVariable input, @NonNull SDVariable shape) { + SDVariable out = new RandomCrop(sd, input, shape).outputVariable(); + return updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index 10fc0b44a..0d0da022e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -2496,5 +2496,57 @@ public class SDMath extends SDOps { return updateVariableNameAndReference(res, name); } + /** + * Compute the regularized incomplete beta integral + * + * @param name Name of the output variable + * @param a input array + * @param b input array + * @param x input array + * @return array + */ + public SDVariable betainc(String name,SDVariable a,SDVariable b,SDVariable x) { + SDVariable res = f().betainc(a,b,x); + return updateVariableNameAndReference(res, name); + } + /** + * Copy a tensor setting everything outside a central band in each innermost matrix. + * + * @param name Name of the output variable + * @param input Rank k array + * @param minLower Number of subdiagonals to keep. + * @param maxUpper Number of superdiagonals to keep. + * @return Rank k array of the same shape as input. 
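A usage sketch for the new SDImage wrappers above (not part of the patch). It assumes SameDiff exposes this class through the sd.image() accessor and that SDVariable.eval() is available; the image shape and adjustment factors are arbitrary placeholders.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.factory.Nd4j;

    public class SDImageSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            // Placeholder 8x8 RGB image in HWC layout, values in [0, 1]
            SDVariable img = sd.var("img", Nd4j.rand(DataType.FLOAT, 8, 8, 3));
            SDVariable contrast = sd.image().adjustContrast("contrast", img, sd.constant(Nd4j.scalar(2.0f)));
            SDVariable hue = sd.image().adjustHue("hue", img, sd.constant(Nd4j.scalar(0.1f)));
            SDVariable sat = sd.image().adjustSaturation("sat", img, sd.constant(Nd4j.scalar(1.5f)));
            System.out.println(contrast.eval());   // contrast-adjusted image, same shape as the input
            System.out.println(hue.eval());
            System.out.println(sat.eval());
        }
    }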
+ */ + public SDVariable matrixBandPart(String name, SDVariable input, SDVariable minLower, SDVariable maxUpper) { + SDVariable res = f().matrixBandPart(input,minLower,maxUpper); + return updateVariableNameAndReference(res, name); + } + + /** + * Polygamma function + * + * @param name Name of the output variable + * @param n array + * @param x array + * @return array + */ + public SDVariable polygamma(String name, SDVariable n, SDVariable x) { + SDVariable res = f().polygamma(n,x); + return updateVariableNameAndReference(res, name); + } + + /** + * Rolls the elements of input + * + * @param name Name of the output variable + * @param input array + * @param shift number of places to shift elements + * @return array + */ + public SDVariable roll(String name, SDVariable input, SDVariable shift) { + SDVariable res = f().roll(input,shift); + return updateVariableNameAndReference(res, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java index 7b1cc5768..63aab3f33 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java @@ -19,6 +19,7 @@ package org.nd4j.autodiff.samediff.ops; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; import org.nd4j.linalg.api.ops.impl.transforms.Pad; import org.nd4j.linalg.factory.Nd4j; @@ -1032,4 +1033,35 @@ public class SDNN extends SDOps { ); } } + + /** + * Max pooling on the input and outputs both max values and indices + * + * @param name Name of the output variable + * @param x input array + * @return output array and argmax array + */ + public SDVariable[] maxPoolWithArgmax(String[] names, SDVariable x, Pooling2DConfig pooling2DConfig) { + SDVariable[] res = f().maxPoolWithArgmaxs(x, pooling2DConfig); + return sd.updateVariableNamesAndReferences(res, names); + } + + /** + * Batch normalization + * + * @param name Name of the output variable + * @param x 4D array + * @param scale vector for scaling factor of normalized x + * @param offset vector to shift to the normalized x + * @param dataFormat integer scalar - data format + * @param isTraining boolean scalar - is training mode + * @return y: 4D array + * batch_mean: vector + * batch_var: vector + */ + public SDVariable[] fusedBatchNorm(String[] names, SDVariable x, SDVariable scale, SDVariable offset, + SDVariable dataFormat, SDVariable isTraining) { + SDVariable[] res = f().fusedBatchNorm(x,scale,offset,dataFormat,isTraining); + return sd.updateVariableNamesAndReferences(res, names); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index b1de641b4..5b60ac0b4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -46,7 +46,6 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class, org.nd4j.linalg.api.ops.custom.BarnesHutGains.class, 
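A sketch of the new SDMath wrappers (not part of the patch), assuming the sd.math() accessor; the shapes, band widths and shift value are placeholders, and whether a single scalar shift (with no axis) is accepted at runtime is left to the underlying roll op.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.factory.Nd4j;

    public class SDMathSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, 4, 4));
            // Keep the main diagonal plus one sub- and one super-diagonal; zero the rest
            SDVariable band = sd.math().matrixBandPart("band", in,
                    sd.constant(Nd4j.scalar(1)), sd.constant(Nd4j.scalar(1)));
            // Roll all elements of the input by two positions
            SDVariable rolled = sd.math().roll("rolled", in, sd.constant(Nd4j.scalar(2)));
            System.out.println(band.eval());
            System.out.println(rolled.eval());
        }
    }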
org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize.class, - org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes.class, org.nd4j.linalg.api.ops.custom.KnnMinDistance.class, org.nd4j.linalg.api.ops.custom.SpTreeCell.class, org.nd4j.linalg.api.ops.custom.Flatten.class, @@ -122,6 +121,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalizationDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling3D.class, + org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3DDerivative.class, @@ -589,7 +589,17 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.custom.BitCast.class, org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class, org.nd4j.linalg.api.ops.custom.DivideNoNan.class, - org.nd4j.linalg.api.ops.custom.FakeQuantWithMinMaxVarsPerChannel.class + org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes.class, + org.nd4j.linalg.api.ops.custom.FakeQuantWithMinMaxVarsPerChannel.class, + org.nd4j.linalg.api.ops.custom.AdjustSaturation.class, + org.nd4j.linalg.api.ops.custom.AdjustHue.class, + org.nd4j.linalg.api.ops.custom.FusedBatchNorm.class, + org.nd4j.linalg.api.ops.custom.BetaInc.class, + org.nd4j.linalg.api.ops.custom.MatrixBandPart.class, + org.nd4j.linalg.api.ops.custom.Polygamma.class, + org.nd4j.linalg.api.ops.custom.RandomCrop.class, + org.nd4j.linalg.api.ops.custom.Roll.class, + org.nd4j.linalg.api.ops.custom.ToggleBits.class ); static { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java index 2d0ac235f..68daf6788 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java @@ -1,5 +1,22 @@ + +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.linalg.api.ops.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -14,11 +31,11 @@ public class AdjustContrast extends BaseAdjustContrast { public AdjustContrast() {super();} - public AdjustContrast(INDArray in, double factor, INDArray out) { + public AdjustContrast(@NonNull INDArray in, double factor, INDArray out) { super(in, factor, out); } - public AdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor) { + public AdjustContrast(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) { super(sameDiff,new SDVariable[]{in,factor}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java index 9ebb3ea6f..71c752485 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java @@ -1,5 +1,21 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.linalg.api.ops.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -14,11 +30,11 @@ public class AdjustContrastV2 extends BaseAdjustContrast { public AdjustContrastV2() {super();} - public AdjustContrastV2(INDArray in, double factor, INDArray out) { + public AdjustContrastV2(@NonNull INDArray in, double factor, INDArray out) { super(in, factor, out); } - public AdjustContrastV2(SameDiff sameDiff, SDVariable in, SDVariable factor) { + public AdjustContrastV2(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) { super( sameDiff,new SDVariable[]{in,factor}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustHue.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustHue.java new file mode 100644 index 000000000..e1a5b0a7a --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustHue.java @@ -0,0 +1,69 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +public class AdjustHue extends DynamicCustomOp { + public AdjustHue() {} + + public AdjustHue(@NonNull INDArray in, double delta, INDArray out) { + this(in, delta); + if (out != null) { + outputArguments.add(out); + } + } + + public AdjustHue(@NonNull INDArray in, double delta) { + Preconditions.checkArgument(in.rank() >= 3, + "AdjustSaturation: op expects rank of input array to be >= 3, but got %s instead", in.rank()); + Preconditions.checkArgument(-1.0 <= delta && delta <= 1.0, "AdjustHue: parameter delta must be within [-1, 1] interval," + + " but got %s instead", delta); + inputArguments.add(in); + + addTArgument(delta); + } + + public AdjustHue(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) { + super(sameDiff,new SDVariable[]{in,factor}); + } + + @Override + public String opName() { + return "adjust_hue"; + } + + @Override + public String tensorflowName() { + return "AdjustHue"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustSaturation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustSaturation.java new file mode 100644 index 000000000..e9f1f90c8 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustSaturation.java @@ -0,0 +1,68 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +public class AdjustSaturation extends DynamicCustomOp { + + public AdjustSaturation() {} + + public AdjustSaturation(@NonNull INDArray in, double factor, INDArray out) { + this(in, factor); + if (out != null) { + outputArguments.add(out); + } + } + + public AdjustSaturation(@NonNull INDArray in, double factor) { + Preconditions.checkArgument(in.rank() >= 3, + "AdjustSaturation: op expects rank of input array to be >= 3, but got %s instead", in.rank()); + inputArguments.add(in); + + addTArgument(factor); + } + + public AdjustSaturation(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) { + super(sameDiff, new SDVariable[]{in, factor}); + } + + @Override + public String opName() { + return "adjust_saturation"; + } + + @Override + public String tensorflowName() { + return "AdjustSaturation"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java index 25cddd741..a5e296043 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java @@ -1,5 +1,21 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
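An eager-mode sketch for the AdjustHue and AdjustSaturation constructors above (not part of the patch); the delta and factor values are placeholders, and Nd4j.exec(CustomOp) is assumed to infer the output array when none is supplied.

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.custom.AdjustHue;
    import org.nd4j.linalg.api.ops.custom.AdjustSaturation;
    import org.nd4j.linalg.factory.Nd4j;

    public class AdjustColorSketch {
        public static void main(String[] args) {
            INDArray rgb = Nd4j.rand(DataType.FLOAT, 8, 8, 3);                  // HWC image in [0, 1]
            INDArray hueShifted = Nd4j.exec(new AdjustHue(rgb, 0.2))[0];        // delta must lie in [-1, 1]
            INDArray saturated = Nd4j.exec(new AdjustSaturation(rgb, 2.0))[0];  // factor > 1 increases saturation
            System.out.println(hueShifted.shapeInfoToString());
            System.out.println(saturated.shapeInfoToString());
        }
    }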
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.linalg.api.ops.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -14,16 +30,16 @@ public abstract class BaseAdjustContrast extends DynamicCustomOp { public BaseAdjustContrast() { } - public BaseAdjustContrast(INDArray in, double factor, INDArray out) { + public BaseAdjustContrast(@NonNull INDArray in, double factor, INDArray out) { Preconditions.checkArgument(in.rank() >= 3, - String.format("AdjustContrast: op expects rank of input array to be >= 3, but got %d instead", in.rank())); + "AdjustContrast: op expects rank of input array to be >= 3, but got %s instead", in.rank()); inputArguments.add(in); outputArguments.add(out); addTArgument(factor); } - public BaseAdjustContrast(SameDiff sameDiff, SDVariable[] vars) { + public BaseAdjustContrast(@NonNull SameDiff sameDiff, @NonNull SDVariable[] vars) { super("", sameDiff, vars); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BetaInc.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BetaInc.java new file mode 100644 index 000000000..ce45869cc --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BetaInc.java @@ -0,0 +1,67 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +public class BetaInc extends DynamicCustomOp { + + public BetaInc() {} + + public BetaInc(@NonNull INDArray a_input, @NonNull INDArray b_input, @NonNull INDArray x_input, + INDArray output) { + addInputArgument(a_input, b_input, x_input); + if (output != null) { + addOutputArgument(output); + } + } + + public BetaInc(@NonNull INDArray a_input, @NonNull INDArray b_input, @NonNull INDArray x_input) { + inputArguments.add(a_input); + inputArguments.add(b_input); + inputArguments.add(x_input); + } + + public BetaInc(@NonNull SameDiff sameDiff, @NonNull SDVariable a, @NonNull SDVariable b, @NonNull SDVariable x) { + super(sameDiff, new SDVariable[]{a,b,x}); + } + + @Override + public String opName() { + return "betainc"; + } + + @Override + public String tensorflowName() { + return "Betainc"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java index ebae33fce..cafc228f2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java @@ -1,3 +1,18 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
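A sketch of running the new BetaInc op eagerly (not part of the patch); shapes and values are placeholders chosen so that a and b are positive and x lies in [0, 1].

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.custom.BetaInc;
    import org.nd4j.linalg.factory.Nd4j;

    public class BetaIncSketch {
        public static void main(String[] args) {
            INDArray a = Nd4j.rand(DataType.FLOAT, 3, 3).addi(0.5);   // keep a strictly positive
            INDArray b = Nd4j.rand(DataType.FLOAT, 3, 3).addi(0.5);   // keep b strictly positive
            INDArray x = Nd4j.rand(DataType.FLOAT, 3, 3);             // x in [0, 1]
            INDArray out = Nd4j.createUninitialized(DataType.FLOAT, 3, 3);
            Nd4j.exec(new BetaInc(a, b, x, out));                     // regularized incomplete beta I_x(a, b)
            System.out.println(out);
        }
    }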
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.linalg.api.ops.custom; import lombok.val; @@ -20,6 +35,8 @@ import java.util.Map; public class BitCast extends DynamicCustomOp { public BitCast() {} + private DataType dtype; + public BitCast(INDArray in, DataType dataType, INDArray out) { this(in, dataType.toInt(), out); } @@ -28,6 +45,8 @@ public class BitCast extends DynamicCustomOp { inputArguments.add(in); outputArguments.add(out); iArguments.add(Long.valueOf(dataType)); + + dtype = DataType.fromInt(dataType); } public BitCast(INDArray in, DataType dataType) { @@ -37,6 +56,7 @@ public class BitCast extends DynamicCustomOp { public BitCast(INDArray in, int dataType) { inputArguments.add(in); iArguments.add(Long.valueOf(dataType)); + dtype = DataType.fromInt(dataType); } public BitCast(SameDiff sameDiff, SDVariable in, SDVariable dataType) { @@ -49,6 +69,8 @@ public class BitCast extends DynamicCustomOp { val t = nodeDef.getAttrOrDefault("type", null); val type = ArrayOptionsHelper.convertToDataType(t.getType()); addIArgument(type.toInt()); + + dtype = type; } @Override @@ -65,6 +87,6 @@ public class BitCast extends DynamicCustomOp { public List calculateOutputDataTypes(List inputDataTypes){ int n = args().length; Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); - return Collections.singletonList(inputDataTypes.get(0)); + return Collections.singletonList(dtype); } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java index d69c73da4..e8285fe9b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java @@ -1,3 +1,18 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
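A sketch illustrating the BitCast data-type fix above (not part of the patch): calculateOutputDataTypes now reports the requested target type instead of the input type. Bitcasting FLOAT (4 bytes) to INT (4 bytes) keeps the shape; the input values are placeholders.

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.custom.BitCast;
    import org.nd4j.linalg.factory.Nd4j;

    public class BitCastSketch {
        public static void main(String[] args) {
            INDArray in = Nd4j.rand(DataType.FLOAT, 2, 4);
            // Reinterpret the raw bits of each float as a 32-bit integer; shape is unchanged
            INDArray out = Nd4j.exec(new BitCast(in, DataType.INT))[0];
            System.out.println(out.dataType());   // expected: INT
        }
    }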
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.linalg.api.ops.custom; import org.nd4j.autodiff.samediff.SDVariable; @@ -9,9 +24,13 @@ import org.nd4j.linalg.factory.Nd4j; public class CompareAndBitpack extends DynamicCustomOp { public CompareAndBitpack() {} - public CompareAndBitpack(INDArray in, double threshold, INDArray out) { + public CompareAndBitpack(INDArray in, double threshold) { inputArguments.add(in); inputArguments.add(Nd4j.scalar(threshold)); + } + + public CompareAndBitpack(INDArray in, double threshold, INDArray out) { + this(in, threshold); outputArguments.add(out); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java index 801384bfd..af62c8443 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java @@ -1,3 +1,18 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.linalg.api.ops.custom; import org.apache.commons.math3.analysis.function.Divide; @@ -16,9 +31,13 @@ public class DivideNoNan extends DynamicCustomOp { public DivideNoNan() { } - public DivideNoNan(INDArray in1, INDArray in2, INDArray out) { + public DivideNoNan(INDArray in1, INDArray in2) { inputArguments.add(in1); inputArguments.add(in2); + } + + public DivideNoNan(INDArray in1, INDArray in2, INDArray out) { + this(in1,in2); outputArguments.add(out); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java index 57551c84c..b92a6f8f9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java @@ -1,3 +1,18 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
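A sketch of the DivideNoNan op with the constructors refactored above (not part of the patch): elements with a zero denominator are expected to produce 0 rather than NaN or Inf, mirroring the TF op it maps to. Values are placeholders.

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.custom.DivideNoNan;
    import org.nd4j.linalg.factory.Nd4j;

    public class DivideNoNanSketch {
        public static void main(String[] args) {
            INDArray num = Nd4j.createFromArray(1.0f, 2.0f, 3.0f);
            INDArray den = Nd4j.createFromArray(2.0f, 0.0f, 3.0f);
            INDArray out = num.ulike();
            Nd4j.exec(new DivideNoNan(num, den, out));
            System.out.println(out);   // expected: [0.5, 0.0, 1.0]
        }
    }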
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.linalg.api.ops.custom; import org.nd4j.autodiff.samediff.SDVariable; @@ -13,11 +28,15 @@ import java.util.List; public class DrawBoundingBoxes extends DynamicCustomOp { public DrawBoundingBoxes() {} - public DrawBoundingBoxes(INDArray images, INDArray boxes, INDArray colors, - INDArray output) { + public DrawBoundingBoxes(INDArray images, INDArray boxes, INDArray colors) { inputArguments.add(images); inputArguments.add(boxes); inputArguments.add(colors); + } + + public DrawBoundingBoxes(INDArray images, INDArray boxes, INDArray colors, + INDArray output) { + this(images, boxes, colors); outputArguments.add(output); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java index ef150843d..c63cd3b56 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java @@ -1,3 +1,18 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.linalg.api.ops.custom; import org.nd4j.autodiff.samediff.SDVariable; @@ -13,14 +28,18 @@ import java.util.List; public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp { public FakeQuantWithMinMaxVarsPerChannel() {} - public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max, - INDArray output) { + public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max) { Preconditions.checkArgument(min.isVector() && max.isVector() && min.length() == max.length(), "FakeQuantWithMinMaxVarsPerChannel: min and max should be 1D tensors with the same length"); inputArguments.add(x); inputArguments.add(min); inputArguments.add(max); + } + + public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max, + INDArray output) { + this(x,min,max); outputArguments.add(output); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java new file mode 100644 index 000000000..691e5d43f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java @@ -0,0 +1,64 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +public class FusedBatchNorm extends DynamicCustomOp { + + public FusedBatchNorm() {} + + public FusedBatchNorm(@NonNull INDArray x, @NonNull INDArray scale, @NonNull INDArray offset, + int dataFormat, int isTraining, + INDArray yOut, INDArray batchMeanOut, INDArray batchMeanVar) { + addInputArgument(x, scale, offset); + addIArgument(dataFormat, isTraining); + if (yOut != null && batchMeanOut != null && batchMeanVar != null) { + addOutputArgument(yOut, batchMeanOut, batchMeanVar); + } + } + + public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale, @NonNull SDVariable offset, + @NonNull SDVariable dataFormat, @NonNull SDVariable isTraining) { + super("", sameDiff, new SDVariable[]{x, scale, offset, dataFormat, isTraining}); + } + + @Override + public String opName() { + return "fused_batch_norm"; + } + + @Override + public String tensorflowName() { + return "FusedBatchNormV2"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java new file mode 100644 index 000000000..46d29608e --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java @@ -0,0 +1,65 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
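A sketch of the FusedBatchNorm array constructor above (not part of the patch). The meaning of the integer flags is assumed here to follow the TF op (dataFormat 0 for NHWC, isTraining 1 for training mode); shapes are placeholders.

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.custom.FusedBatchNorm;
    import org.nd4j.linalg.factory.Nd4j;

    public class FusedBatchNormSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.rand(DataType.FLOAT, 2, 4, 4, 3);          // NHWC input
            INDArray scale = Nd4j.ones(DataType.FLOAT, 3);               // per-channel gamma
            INDArray offset = Nd4j.zeros(DataType.FLOAT, 3);             // per-channel beta
            INDArray y = x.ulike();
            INDArray mean = Nd4j.createUninitialized(DataType.FLOAT, 3);
            INDArray variance = Nd4j.createUninitialized(DataType.FLOAT, 3);
            Nd4j.exec(new FusedBatchNorm(x, scale, offset, 0, 1, y, mean, variance));
            System.out.println(mean);       // per-channel batch mean
            System.out.println(variance);   // per-channel batch variance
        }
    }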
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +public class MatrixBandPart extends DynamicCustomOp { + + public MatrixBandPart() {} + + public MatrixBandPart(@NonNull INDArray input, int minLower, int maxUpper) { + Preconditions.checkArgument(input.rank() >= 2, "MatrixBandPart: Input rank should be 2 or higher"); + long N = input.size(-2); + long M = input.size(-1); + Preconditions.checkArgument(minLower > -N && minLower < N, "MatrixBandPart: lower diagonal count %s should be less than %s", + minLower, N); + Preconditions.checkArgument(maxUpper > -M && maxUpper < M, "MatrixBandPart: upper diagonal count %s should be less than %s.", + maxUpper, M); + addInputArgument(input); + addIArgument(minLower, maxUpper); + } + + public MatrixBandPart(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable minLower, SDVariable maxUpper) { + super("", sameDiff, new SDVariable[]{input, minLower, maxUpper}); + } + + @Override + public String opName() { + return "matrix_band_part"; + } + + @Override + public String tensorflowName() { + return "MatrixBandPart"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Polygamma.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Polygamma.java new file mode 100644 index 000000000..3b528eb62 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Polygamma.java @@ -0,0 +1,66 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
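A sketch of the MatrixBandPart array constructor above (not part of the patch), applied to a 4x4 matrix; the band widths are placeholders.

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.custom.MatrixBandPart;
    import org.nd4j.linalg.factory.Nd4j;

    public class MatrixBandPartSketch {
        public static void main(String[] args) {
            INDArray m = Nd4j.rand(DataType.FLOAT, 4, 4);
            // Keep the main diagonal and one sub-diagonal; entries outside the band are expected to be zeroed
            INDArray banded = Nd4j.exec(new MatrixBandPart(m, 1, 0))[0];
            System.out.println(banded);
        }
    }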
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +public class Polygamma extends DynamicCustomOp { + + public Polygamma() {} + + public Polygamma(@NonNull INDArray n, @NonNull INDArray x) { + Preconditions.checkArgument(n.shape() != x.shape(), + "Polygamma: n and x must have the same shapes"); + addInputArgument(n,x); + } + + public Polygamma(@NonNull INDArray n, @NonNull INDArray x, INDArray output) { + this(n,x); + if (output != null) { + addOutputArgument(output); + } + } + + public Polygamma(@NonNull SameDiff sameDiff, @NonNull SDVariable n, @NonNull SDVariable x) { + super("", sameDiff, new SDVariable[]{n ,x}); + } + + @Override + public String opName() { + return "polygamma"; + } + + @Override + public String tensorflowName() { + return "Polygamma"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RandomCrop.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RandomCrop.java new file mode 100644 index 000000000..1f3f2e3ea --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RandomCrop.java @@ -0,0 +1,61 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.rng.Random; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class RandomCrop extends DynamicCustomOp { + + public RandomCrop() {} + + public RandomCrop(@NonNull INDArray input, @NonNull INDArray shape) { + Preconditions.checkArgument(shape.isVector(),"RandomCrop:Shape tensor should be a vector"); + Preconditions.checkArgument(input.rank() == shape.length(), "RandomCrop:The length of the shape vector is not match input rank"); + addInputArgument(input, shape); + } + + public RandomCrop(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable shape) { + super("", sameDiff, new SDVariable[]{input, shape}); + } + + @Override + public String opName() { + return "random_crop"; + } + + @Override + public String tensorflowName() { + return "RandomCrop"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null /*&& inputDataTypes.size() == 4*/, + "Expected 4 input datatypes for %s, got %s", getClass(), inputDataTypes); + return Collections.singletonList(DataType.FLOAT); //TF import: always returns float32... + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Roll.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Roll.java new file mode 100644 index 000000000..9ce7aa641 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Roll.java @@ -0,0 +1,64 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +public class Roll extends DynamicCustomOp { + + public Roll() {} + + public Roll(@NonNull INDArray input, @NonNull INDArray axes, @NonNull INDArray shifts) { + Preconditions.checkArgument(axes.rank() == shifts.rank(), "Roll: shifts and axes should be the same rank"); + Preconditions.checkArgument(axes.length() == shifts.length(), "Roll: shifts and axes should be the same length"); + addInputArgument(input, axes, shifts); + } + + public Roll(@NonNull INDArray input, int shift) { + addInputArgument(input); + addIArgument(shift); + } + + public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable shift) { + super("", sameDiff, new SDVariable[]{input,shift}); + } + + @Override + public String opName() { + return "roll"; + } + + @Override + public String tensorflowName() { + return "Roll"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ToggleBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ToggleBits.java new file mode 100644 index 000000000..641cb4117 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ToggleBits.java @@ -0,0 +1,64 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
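A sketch of the two Roll array constructors above (not part of the patch). The linear-roll expectation follows the TF semantics of the op this maps to; the per-axis form simply follows the (input, axes, shifts) argument order of the constructor and its result is not asserted here.

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.custom.Roll;
    import org.nd4j.linalg.factory.Nd4j;

    public class RollSketch {
        public static void main(String[] args) {
            INDArray in = Nd4j.createFromArray(1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f);
            // Linear roll by 2: expected [7, 8, 1, 2, 3, 4, 5, 6]
            System.out.println(Nd4j.exec(new Roll(in, 2))[0]);

            // Per-axis roll: shift axis 0 by 1 and axis 1 by 2
            INDArray m = Nd4j.createFromArray(new float[][]{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
            INDArray axes = Nd4j.createFromArray(0, 1);
            INDArray shifts = Nd4j.createFromArray(1, 2);
            System.out.println(Nd4j.exec(new Roll(m, axes, shifts))[0]);
        }
    }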
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +public class ToggleBits extends DynamicCustomOp { + + public ToggleBits() {} + + public ToggleBits(@NonNull INDArray input, INDArray output) { + this(input); + if (output != null) { + addOutputArgument(output); + } + } + + public ToggleBits(@NonNull INDArray input) { + addInputArgument(input); + } + + public ToggleBits(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + super("", sameDiff, new SDVariable[]{input}); + } + + @Override + public String opName() { + return "toggle_bits"; + } + + @Override + public String tensorflowName() { + return "Invert"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java index d7161cf5f..75b82dc29 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; @@ -41,6 +42,12 @@ public class NonMaxSuppression extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{boxes, scores, maxOutSize, iouThreshold, scoreThreshold}, false); } + public NonMaxSuppression(INDArray boxes, INDArray scores, int maxOutSize, double iouThreshold, double scoreThreshold) { + addInputArgument(boxes,scores); + addIArgument(maxOutSize); + addTArgument(iouThreshold, scoreThreshold); + } + @Override public String onnxName() { throw new NoOpNameFoundException("No onnx name found for shape " + opName()); @@ -53,7 +60,7 @@ public class NonMaxSuppression extends DynamicCustomOp { @Override public String[] tensorflowNames() { - return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2"}; + return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2","NonMaxSuppressionV3","NonMaxSuppressionV4"}; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java new file mode 100644 index 000000000..58602d85e --- /dev/null +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java @@ -0,0 +1,284 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.layers.convolution; + +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import onnx.Onnx; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.descriptors.properties.PropertyMapping; +import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; +import org.nd4j.linalg.util.ArrayUtil; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; + +import java.lang.reflect.Field; +import java.util.*; + +@Slf4j +@Getter +public class MaxPoolWithArgmax extends DynamicCustomOp { + + protected Pooling2DConfig config; + protected DataType outputType; + + public MaxPoolWithArgmax() { + } + + @Builder(builderMethodName = "sameDiffBuilder") + @SuppressWarnings("Used in lombok") + public MaxPoolWithArgmax(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) { + super(null, sameDiff, new SDVariable[]{input}, false); + + config.setType(Pooling2D.Pooling2DType.MAX); + this.config = config; + addArgs(); + } + + public MaxPoolWithArgmax(INDArray input, INDArray output,INDArray outArgMax, @NonNull Pooling2DConfig config){ + super(null, new INDArray[]{input}, new INDArray[]{output, outArgMax}); + config.setType(Pooling2D.Pooling2DType.MAX); + + this.config = config; + addArgs(); + } + + @Override + public boolean isConfigProperties() { + return true; + } + + @Override + public String configFieldName() { + return "config"; + } + + + @Override + public Map propertiesForFunction() { + if(config == null && iArguments.size() > 0){ + //Perhaps loaded from FlatBuffers - hence we have IArgs but not Config object + config = Pooling2DConfig.builder() + .kH(iArguments.get(0)) + .kW(iArguments.get(1)) + .sH(iArguments.get(2)) + .sW(iArguments.get(3)) + .pH(iArguments.get(4)) + .pW(iArguments.get(5)) + .dH(iArguments.get(6)) + .dW(iArguments.get(7)) + .isSameMode(iArguments.get(8) == 1) + .extra(iArguments.get(9)) + .isNHWC(iArguments.get(10) == 1) + .type(Pooling2D.Pooling2DType.MAX) + .build(); + } + return config.toProperties(); + } + + private void addArgs() { + addIArgument(config.getKH(), + config.getKW(), + config.getSH(), + config.getSW(), + config.getPH(), + config.getPW(), + config.getDH(), + config.getDW(), + 
ArrayUtil.fromBoolean(config.isSameMode()), + (int) config.getExtra(), + ArrayUtil.fromBoolean(config.isNHWC()) + ); + + } + + + public String getPoolingPrefix() { + return "max"; + } + + @Override + public List doDiff(List f1) { + List ret = new ArrayList<>(); + List inputs = new ArrayList<>(); + inputs.addAll(Arrays.asList(args())); + inputs.add(f1.get(0)); + Pooling2DDerivative pooling2DDerivative = Pooling2DDerivative.derivativeBuilder() + .inputs(inputs.toArray(new SDVariable[inputs.size()])) + .sameDiff(sameDiff) + .config(config) + .build(); + ret.addAll(Arrays.asList(pooling2DDerivative.outputVariables())); + return ret; + } + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + val aStrides = nodeDef.getAttrOrThrow("strides"); + val tfStrides = aStrides.getList().getIList(); + + val aKernels = nodeDef.getAttrOrThrow("ksize"); + val tfKernels = aKernels.getList().getIList(); + + int sH = 0; + int sW = 0; + + int pH = 0; + int pW = 0; + + int kH = 0; + int kW = 0; + + val aPadding = nodeDef.getAttrOrThrow("padding"); + val padding = aPadding.getList().getIList(); + + val paddingMode = aPadding.getS().toStringUtf8().replaceAll("\"", ""); + + boolean isSameMode = paddingMode.equalsIgnoreCase("SAME"); + + String data_format = "nhwc"; + if (nodeDef.containsAttr("data_format")) { + val attr = nodeDef.getAttrOrThrow("data_format"); + + data_format = attr.getS().toStringUtf8().toLowerCase(); + } + + if (data_format.equalsIgnoreCase("nhwc")) { + sH = tfStrides.get(1).intValue(); + sW = tfStrides.get(2).intValue(); + + kH = tfKernels.get(1).intValue(); + kW = tfKernels.get(2).intValue(); + + pH = padding.size() > 0 ? padding.get(1).intValue() : 0; + pW = padding.size() > 0 ? padding.get(2).intValue() : 0; + } else { + sH = tfStrides.get(2).intValue(); + sW = tfStrides.get(3).intValue(); + + kH = tfKernels.get(2).intValue(); + kW = tfKernels.get(3).intValue(); + + pH = padding.size() > 0 ? padding.get(2).intValue() : 0; + pW = padding.size() > 0 ? 
padding.get(3).intValue() : 0; + } + + Pooling2DConfig pooling2DConfig = Pooling2DConfig.builder() + .sH(sH) + .sW(sW) + .type(Pooling2D.Pooling2DType.MAX) + .isSameMode(isSameMode) + .kH(kH) + .kW(kW) + .pH(pH) + .pW(pW) + .isNHWC(data_format.equalsIgnoreCase("nhwc")) + .extra(1.0) // averaging only for non-padded values + .build(); + this.config = pooling2DConfig; + addArgs(); + if(attributesForNode.containsKey("argmax")) { + outputType = TFGraphMapper.convertType(attributesForNode.get("argmax").getType()); + } else { + outputType = DataType.UINT32; + } + } + + @Override + public Map> mappingsForFunction() { + Map> ret = new HashMap<>(); + Map map = new HashMap<>(); + val strideMapping = PropertyMapping.builder() + .tfAttrName("strides") + .onnxAttrName("strides") + .propertyNames(new String[]{"sW", "sH"}) + .build(); + + val paddingMapping = PropertyMapping.builder() + .onnxAttrName("padding") + .tfAttrName("padding") + .propertyNames(new String[]{"pH", "pW"}) + .build(); + + val kernelMapping = PropertyMapping.builder() + .propertyNames(new String[]{"kH", "kW"}) + .tfInputPosition(1) + .onnxAttrName("ksize") + .build(); + + val dilationMapping = PropertyMapping.builder() + .onnxAttrName("dilations") + .propertyNames(new String[]{"dW", "dH"}) + .tfAttrName("rates") + .build(); + + + //data_format + val dataFormatMapping = PropertyMapping.builder() + .propertyNames(new String[]{"isNHWC"}) + .tfAttrName("data_format") + .build(); + + map.put("sW", strideMapping); + map.put("sH", strideMapping); + map.put("kH", kernelMapping); + map.put("kW", kernelMapping); + map.put("dW", dilationMapping); + map.put("dH", dilationMapping); + map.put("pH", paddingMapping); + map.put("pW", paddingMapping); + map.put("isNHWC", dataFormatMapping); + + ret.put(onnxName(), map); + ret.put(tensorflowName(), map); + return ret; + } + + @Override + public String opName() { + return "max_pool_with_argmax"; + } + + @Override + public String onnxName() { + return "MaxPoolWithArgmax"; + } + + @Override + public String tensorflowName() { + return "MaxPoolWithArgmax"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes); + List result = new ArrayList<>(); + result.add(inputDataTypes.get(0)); + result.add(outputType == null ? 
DataType.UINT32 : outputType); + return result; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java index 09e928d2f..ad7984f2b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java @@ -293,8 +293,8 @@ public class MaxPooling2D extends DynamicCustomOp { } @Override - public String tensorflowName() { - return "MaxPool"; + public String[] tensorflowNames() { + return new String[]{"MaxPool","MaxPoolV2"}; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java index 3927ba2bc..99f65f46d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java @@ -68,7 +68,7 @@ public class ClipByValue extends DynamicCustomOp { @Override public String opName() { - return "clipbyvalue"; + return "ClipByValue"; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java index 3cc03d12b..6e87a05c0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java @@ -53,15 +53,9 @@ public class RShiftBits extends BaseDynamicTransformOp { return "rshift_bits"; } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - @Override public String tensorflowName() { - throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName()); + return "RightShift"; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java index a9eebb14e..038cca54b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java @@ -53,15 +53,9 @@ public class ShiftBits extends BaseDynamicTransformOp { return "shift_bits"; } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - @Override public String tensorflowName() { - throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName()); + return "LeftShift"; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/UniqueWithCounts.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/UniqueWithCounts.java index 3f3bdbe74..74a7397a4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/UniqueWithCounts.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/UniqueWithCounts.java @@ -46,8 +46,8 @@ public class UniqueWithCounts extends DynamicCustomOp { } @Override - public String tensorflowName() { - return "UniqueWithCounts"; + public String[] tensorflowNames() { + return new String[]{"UniqueWithCounts","UniqueWithCountsV2"}; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/CopyOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/CopyOp.java index 5397108c6..3ee75d23d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/CopyOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/CopyOp.java @@ -77,8 +77,8 @@ public class CopyOp extends BaseTransformSameOp { } @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); + public String[] tensorflowNames() { + return new String[]{"Copy","DeepCopy","CopyHost"}; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/ModOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/ModOp.java index 289333f96..46d477310 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/ModOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/ModOp.java @@ -57,7 +57,7 @@ public class ModOp extends BaseDynamicTransformOp { @Override public String tensorflowName() { - return "mod"; + return "Mod"; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java index b49a89200..95bd0bf41 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java @@ -66,13 +66,7 @@ public class Not extends BaseTransformBoolOp { public String onnxName() { return "Not"; } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("Tensorflow name not found for " + opName()); - //return "Not"; - } - + @Override public List doDiff(List f1) { return Collections.singletonList(f().zerosLike(arg())); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java index ec91a98e6..b33ea8b8f 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java @@ -59,19 +59,6 @@ public class GELU extends BaseTransformStrictOp { return "gelu"; } - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() - { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - //return "GELU"; - } - - @Override public List doDiff(List i_v) { SDVariable ret = f().geluDerivative(arg(), false).mul(i_v.get(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java index ecc76a1b2..682d7c230 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java @@ -24,6 +24,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -71,7 +72,12 @@ public class DistributionUniform extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - AttrValue v = attributesForNode.get("dtype"); + AttrValue vDtype = attributesForNode.get("dtype"); + AttrValue vTout = attributesForNode.get("Tout"); + if (vDtype == null && vTout == null) { + throw new ND4JIllegalStateException("Unable to find output data type for node " + nodeDef.getName()); + } + AttrValue v = vDtype == null ? 
vTout : vDtype; dataType = TFGraphMapper.convertType(v.getType()); addIArgument(dataType.toInt()); addTArgument(0.0, 1.0); //TF version is hardcoded 0 to 1 @@ -92,8 +98,8 @@ public class DistributionUniform extends DynamicCustomOp { } @Override - public String tensorflowName() { - return "RandomUniform"; + public String[] tensorflowNames() { + return new String[]{"RandomUniform","RandomUniformInt"}; } @Override @@ -103,7 +109,7 @@ public class DistributionUniform extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null /*&& inputDataTypes.size() == 1*/, "Expected input datatypes for %s, got %s", getClass(), inputDataTypes); //Input data type specifies the shape if(dataType != null){ return Collections.singletonList(dataType); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOut.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOut.java index 32a823ac1..742e28113 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOut.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOut.java @@ -65,18 +65,6 @@ public class DropOut extends BaseRandomOp { return "dropout"; } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op name found for: " + getClass().getName()); - //return opName(); - } - @Override public Type opType() { return Type.RANDOM ; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index 6f4acd079..9dd529399 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -736,6 +736,35 @@ public class LayerOpValidation extends BaseOpValidation { // sd.execBackwards(); // TODO: test failing here } + @Test + public void testMaxPoolingArgMax() { + Nd4j.getRandom().setSeed(12345); + int nIn = 3; + int kH = 2; + int kW = 2; + + int mb = 3; + int imgH = 8; + int imgW = 8; + + SameDiff sd = SameDiff.create(); + INDArray inArr = Nd4j.rand(new int[]{mb, nIn, imgH, imgW}); + + SDVariable in = sd.var("in", inArr); + + Pooling2DConfig pooling2DConfig = Pooling2DConfig.builder() + .kH(kH).kW(kW) + .pH(0).pW(0) + .sH(1).sW(1) + .dH(1).dW(1) + .isSameMode(true) + .build(); + + SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"",""}, in, pooling2DConfig); + assertArrayEquals(inArr.shape(), results[0].eval().shape()); + assertArrayEquals(inArr.shape(), results[1].eval().shape()); + } + @Test public void testMaxPooling2dBasic() { Nd4j.getRandom().setSeed(12345); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index 277bb8a83..ec65d71df 100644 --- 
a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -76,8 +76,6 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a "adjust_contrast/.*", //Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965 "bincount/.*", - // Failing 2019/11/15 https://github.com/eclipse/deeplearning4j/issues/8400 - "bitcast/.*", // Failing 2019/11/14 https://github.com/eclipse/deeplearning4j/issues/8393 "is_strictly_increasing/emptyArrayTest/.*", @@ -116,20 +114,32 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a // 2019/11/15 - missing dtype argument in nd4j, tests are useless https://github.com/eclipse/deeplearning4j/issues/8398 "zeros_like/rank2_float32_dtype_int.*", - // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8399 - "crop_and_resize.*", - - // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8401 - "draw_bounding_boxes.*", - // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402 "fake_quant/min_max_args_per_channel.*", // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8403 "resize_bilinear/int32.*", - // Suggesting TF 1.15 bug - see https://github.com/eclipse/deeplearning4j/issues/8449 - "non_max_suppression_v2/float16.*" + // Suggesting TF 1.15 bug + "non_max_suppression_v2/float16.*", + + // 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8450 + "betainc.*", + + // 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8452 + "polygamma.*", + + // 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8453 + "roll/.*", + + // 11.26.2019 failing https://github.com/eclipse/deeplearning4j/issues/8455 + "matrix_band_part/.*", + + // 11.28.2019 failing https://github.com/eclipse/deeplearning4j/issues/8458 + "adjust_hue/.*", + + // 11.28.2019 failing https://github.com/eclipse/deeplearning4j/issues/8459 + "adjust_saturation/.*" }; /* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index c01cf1942..fbb1ddb85 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -32,7 +32,10 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpStatus; import org.nd4j.linalg.api.ops.impl.controlflow.Where; import org.nd4j.linalg.api.ops.impl.image.CropAndResize; +import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression; import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear; +import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; import org.nd4j.linalg.api.ops.impl.shape.Create; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; @@ -53,6 +56,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import static java.lang.Float.NaN; import static org.junit.Assert.*; /** @@ -867,6 +871,26 @@ public class CustomOpsTests extends 
BaseNd4jTest { assertArrayEquals(new long[]{1,10}, lsd.get(0).getShape()); } + @Test + public void testAdjustSaturation() { + INDArray in = Nd4j.createFromArray(new double[]{50,100,78, 118.5,220,112.5,190,163.5,230, 255,128.5,134}).reshape(2,2,3); + INDArray out = Nd4j.create(in.shape()); + INDArray expected = Nd4j.createFromArray(new double[]{0,100,56, 17,220,5, 150,97,230, 255,2,13}).reshape(2,2,3); + + Nd4j.exec(new AdjustSaturation(in, 2.0, out)); + assertEquals(expected, out); + } + + @Test + public void testAdjustHue() { + INDArray in = Nd4j.createFromArray(new double[]{0,100,56, 17,220,5, 150,97,230, 255,2,13}).reshape(2,2,3); + INDArray out = Nd4j.create(in.shape()); + INDArray expected = Nd4j.createFromArray(new double[]{100,0,44, 208,5,220, 177,230,97, 2,255,244}).reshape(2,2,3); + + Nd4j.exec(new AdjustHue(in, 0.5, out)); + assertEquals(expected, out); + } + @Test public void testBitCast() { INDArray in = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 8).reshape(2,2,2); @@ -1088,6 +1112,216 @@ public class CustomOpsTests extends BaseNd4jTest { assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape()); } + @Test + public void testBetaInc() { + Nd4j.getRandom().setSeed(10); + INDArray a = Nd4j.linspace(DataType.BFLOAT16, 0.1, 0.1, 9).reshape(3,3); + INDArray b = Nd4j.linspace(DataType.BFLOAT16, 0.1, 0.1, 9).reshape(3,3); + INDArray x = Nd4j.linspace(DataType.BFLOAT16, 0.1, 0.1, 9).reshape(3,3); + INDArray expected = Nd4j.createFromArray(new float[]{0.4121f, 0.3926f, 0.4082f, + 0.4414f, 0.5000f, 0.5703f, + 0.6562f, 0.7656f, 0.8828f}).reshape(3,3); + + BetaInc op = new BetaInc(a,b,x); + INDArray[] out = Nd4j.exec(op); + assertArrayEquals(expected.shape(), out[0].shape()); + for (int i = 0; i < 3; ++i) + assertArrayEquals(expected.toDoubleMatrix()[i], out[0].toDoubleMatrix()[i], 1e-4); + } + + @Test + public void testFusedBatchNorm() { + INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*2*3*4).reshape(2,2,3,4); + INDArray scale = Nd4j.create(DataType.DOUBLE, 4); + scale.assign(0.5); + INDArray offset = Nd4j.create(DataType.DOUBLE, 4); + offset.assign(2.0); + + INDArray y = Nd4j.createUninitialized(DataType.DOUBLE, x.shape()); + INDArray batchMean = Nd4j.create(4); + INDArray batchVar = Nd4j.create(4); + + FusedBatchNorm op = new FusedBatchNorm(x,scale,offset,0,1, + y, batchMean, batchVar); + + INDArray expectedY = Nd4j.createFromArray(new double[]{1.20337462, 1.20337462, 1.20337462, + 1.20337462, 1.34821558, 1.34821558, 1.34821558, 1.34821558, 1.49305654, 1.49305654, + 1.49305654, 1.49305654, 1.63789749, 1.63789749, 1.63789749, 1.63789749, 1.78273857, + 1.78273857, 1.78273857, 1.78273857, 1.92757952, 1.92757952, 1.92757952, 1.92757952, + 2.0724206 , 2.0724206 , 2.0724206 , 2.0724206 , 2.21726155, 2.21726155, 2.21726155, + 2.21726155, 2.36210251, 2.36210251, 2.36210251, 2.36210251, 2.50694346, 2.50694346, + 2.50694346, 2.50694346, 2.65178442, 2.65178442, 2.65178442, 2.65178442, 2.79662538, + 2.79662538, 2.79662538, 2.79662538}).reshape(x.shape()); + INDArray expectedBatchMean = Nd4j.createFromArray(new double[]{23., 24., 25., 26.}); + INDArray expectedBatchVar = Nd4j.createFromArray(new double[]{208.00001526, 208.00001526, 208.00001526, 208.00001526}); + Nd4j.exec(op); + assertArrayEquals(expectedY.shape(), y.shape()); + assertArrayEquals(expectedBatchMean.shape(), batchMean.shape()); + assertArrayEquals(expectedBatchVar.shape(), batchVar.shape()); + } + + @Test + public void testMatrixBandPart() { + INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3); + 
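// [Editor's note - descriptive comment, not part of the patch] For reference:
// matrix_band_part(x, numLower, numUpper) keeps element (i, j) of each innermost
// matrix iff (numLower < 0 || i - j <= numLower) && (numUpper < 0 || j - i <= numUpper)
// and zeroes everything else. With numLower = numUpper = 1 below, only the main
// diagonal plus the first sub- and super-diagonals survive, which is what the
// commented-out putScalar() adjustments to 'expected' would encode.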
val op = new MatrixBandPart(x,1,1); + INDArray expected = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3); + /*expected.putScalar(0, 0, 2, 0.); + expected.putScalar(1, 0, 2, 0.); + expected.putScalar(0, 2, 0, 0.); + expected.putScalar(1, 2, 0, 0.);*/ + + INDArray[] out = Nd4j.exec(op); + assertEquals(expected, x); + } + + @Test + public void testPolygamma() { + INDArray n = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3,3); + INDArray x = Nd4j.create(DataType.FLOAT, 3,3); + x.assign(0.5); + INDArray expected = Nd4j.createFromArray(new float[]{4.934802f, -16.828796f, 97.409088f, -771.474243f, + 7691.113770f, -92203.460938f, 1290440.250000f, -20644900.000000f, 3.71595e+08f}).reshape(3,3); + INDArray output = Nd4j.create(DataType.FLOAT, expected.shape()); + val op = new Polygamma(x,n,output); + Nd4j.exec(op); + assertEquals(expected, output); + } + + @Test + public void testRandomCrop() { + INDArray x = Nd4j.createFromArray(new double[]{1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }).reshape(2,2,4); + INDArray shape = Nd4j.createFromArray(new int[] {1,2,3}); + val op = new RandomCrop(x, shape); + INDArray[] res = Nd4j.exec(op); + } + + @Test + public void testRoll() { + INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}). + reshape(2,2,4,2); + + INDArray expected = Nd4j.createFromArray(new double[]{ 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, + 21.41, 21.42, 22.11, 22.12 + }).reshape(x.shape()); + val op = new Roll(x, 6); + INDArray[] res = Nd4j.exec(op); + assertEquals(expected, res[0]); + } + + @Test + public void testToggleBits() { + INDArray input = Nd4j.createFromArray(new int[]{2,2}); + INDArray expected = Nd4j.createFromArray(new int[]{-3,-3}); + ToggleBits op = new ToggleBits(input); + val result = Nd4j.exec(op); + assertEquals(expected, result[0]); + } + + @Ignore("AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8449") + @Test + public void testNonMaxSuppression() { + INDArray boxes = Nd4j.createFromArray(new float[] {0.8115f, 0.4121f, 0.0771f, 0.4863f, + 0.7412f, 0.7607f, 0.1543f, 0.5479f, + 0.8223f, 0.2246f, 0.0049f, 0.6465f}).reshape(3,4); + INDArray scores = Nd4j.createFromArray(new float[]{0.0029f, 0.8135f, 0.4873f}); + val op = new NonMaxSuppression(boxes,scores,2,0.5,0.5); + val res = Nd4j.exec(op); + assertEquals(new long[]{1}, res[0].shape()); + } + + @Test + public void testMatrixBand() { + INDArray input = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f, + 0.7271f,0.1804f,0.5056f,0.8925f, + 0.5461f,0.9234f,0.0856f,0.7938f}).reshape(3,4); + MatrixBandPart op = new MatrixBandPart(input,1,-1); + List lsd = op.calculateOutputShape(); + assertEquals(1, lsd.size()); + } + + @Ignore("Failed AS 11.26.2019 - https://github.com/eclipse/deeplearning4j/issues/8450") + @Test + public void testBetaInc1() { + INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f}); + INDArray b = Nd4j.createFromArray(new float[]{0.7717f, 0.9281f, 0.9846f, 0.4838f}); + INDArray c = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f}); + BetaInc op = new BetaInc(a,b,c); + INDArray[] ret = Nd4j.exec(op); + INDArray 
expected = Nd4j.createFromArray(new float[]{0.9122f, 0.6344f, 0.8983f, 0.6245f}); + assertEquals(expected, ret[0]); + } + + @Ignore("Failure AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8452") + @Test + public void testPolygamma1() { + INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f, + 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f}).reshape(3,4); + INDArray b = Nd4j.createFromArray(new float[]{0.7717f, 0.9281f, 0.9846f, 0.4838f, + 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f, 0.3948f, 0.9493f, 0.8600f}).reshape(3,4); + INDArray expected = Nd4j.createFromArray(new float[]{NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN, }).reshape(3,4); + Polygamma op = new Polygamma(a,b); + INDArray[] ret = Nd4j.exec(op); + assertEquals(expected, ret[0]); + } + + @Ignore("Failure AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8453") + @Test + public void testRoll1() { + INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f}); + Roll op = new Roll(a,Nd4j.scalar(2),Nd4j.scalar(0)); + INDArray[] ret = Nd4j.exec(op); + INDArray expected = Nd4j.createFromArray(new float[]{0.7244f, 0.2309f, 0.7788f, 0.8012f}); + assertEquals(expected, ret[0]); + } + + @Test + public void testAdjustHueShape(){ + INDArray image = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, 0.5461f, + 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f, + 0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, + 0.1800f, 0.6750f, 0.2246f, 0.0509f, 0.4601f, 0.8284f, + 0.2354f, 0.9752f, 0.8361f, 0.2585f, 0.4189f, 0.7028f, + 0.7679f, 0.5373f, 0.7234f, 0.2690f, 0.0062f, 0.0327f, + 0.0644f, 0.8428f, 0.7494f, 0.0755f, 0.6245f, 0.3491f, + 0.5793f, 0.5730f, 0.1822f, 0.6420f, 0.9143f, 0.3019f, + 0.3574f, 0.1704f, 0.8395f, 0.5468f, 0.0744f, 0.9011f, + 0.6574f, 0.4124f, 0.2445f, 0.4248f, 0.5219f, 0.6952f, + 0.4900f, 0.2158f, 0.9549f, 0.1386f, 0.1544f, 0.5365f, + 0.0134f, 0.4163f, 0.1456f, 0.4109f, 0.2484f, 0.3330f, + 0.2974f, 0.6636f, 0.3808f, 0.8664f, 0.1896f, 0.7530f, + 0.7215f, 0.6612f, 0.7270f, 0.5704f, 0.2666f, 0.7453f, + 0.0444f, 0.3024f, 0.4850f, 0.7982f, 0.0965f, 0.7843f, + 0.5075f, 0.0844f, 0.8370f, 0.6103f, 0.4604f, 0.6087f, + 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f, + 0.7821f, 0.3505f, 0.5040f, 0.1180f, 0.8307f, 0.1817f, + 0.8442f, 0.5074f, 0.4471f, 0.5105f, 0.6666f, 0.2576f, + 0.2341f, 0.6801f, 0.2652f, 0.5394f, 0.4690f, 0.6146f, + 0.1210f, 0.2576f, 0.0769f, 0.4643f, 0.1628f, 0.2026f, + 0.3774f, 0.0506f, 0.3462f, 0.5720f, 0.0838f, 0.4228f, + 0.0588f, 0.5362f, 0.4756f, 0.2530f, 0.1778f, 0.0751f, + 0.8977f, 0.3648f, 0.3065f, 0.4739f, 0.7014f, 0.4473f, + 0.5171f, 0.1744f, 0.3487f, 0.7759f, 0.9491f, 0.2072f, + 0.2182f, 0.6520f, 0.3092f, 0.9545f, 0.1881f, 0.9579f, + 0.1785f, 0.9636f, 0.4830f, 0.6569f, 0.3353f, 0.9997f, + 0.5869f, 0.5747f, 0.0238f, 0.2943f, 0.5248f, 0.5879f, + 0.7266f, 0.1965f, 0.9167f, 0.9726f, 0.9206f, 0.0519f, + 0.2997f, 0.0039f, 0.7652f, 0.5498f, 0.3794f, 0.3791f, + 0.3528f, 0.2873f, 0.8082f, 0.4732f, 0.4399f, 0.6606f, + 0.5991f, 0.0034f, 0.4874f}).reshape(8,8,3); + + AdjustHue op = new AdjustHue(image, 0.2f); + INDArray[] res = Nd4j.exec(op); + System.out.println(res[0]); + List lsd = op.calculateOutputShape(); + assertEquals(1, lsd.size()); + assertArrayEquals(new long[]{8, 8, 3}, lsd.get(0).getShape()); + } + @Test public void testBitCastShape_3(){ val x = Nd4j.createFromArray(new int[]{1, 2, 3, 4, 5, 6, 7, 8}).reshape(1, 4, 2);