diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java index 43847244b..ce7d779dc 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java @@ -34,8 +34,8 @@ import static org.junit.Assert.assertEquals; public class RecordConverterTest { @Test public void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() { - INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 3}, DataType.FLOAT); - INDArray feature2 = Nd4j.create(new double[]{11, .7, -1.3, 4}, new long[]{1, 3}, DataType.FLOAT); + INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT); + INDArray feature2 = Nd4j.create(new double[]{11, .7, -1.3, 4}, new long[]{1, 4}, DataType.FLOAT); INDArray label1 = Nd4j.create(new double[]{0, 0, 1, 0}, new long[]{1, 4}, DataType.FLOAT); INDArray label2 = Nd4j.create(new double[]{0, 1, 0, 0}, new long[]{1, 4}, DataType.FLOAT); DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)), diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/floordiv.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/floordiv.cpp index 9925b554f..10d898522 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/floordiv.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/floordiv.cpp @@ -67,14 +67,9 @@ namespace nd4j { auto gradX = OUTPUT_VARIABLE(0); auto gradY = OUTPUT_VARIABLE(1); - gradY->assign(x); - std::unique_ptr ySq(y->dup()); - ySq->applyTransform(transform::Square, nullptr); - gradY->applyPairwiseTransform(pairwise::FloorDiv, ySq.get(), gradY, nullptr); - gradY->applyPairwiseTransform(pairwise::Multiply, epsNext, gradY, nullptr); - gradY->applyTransform(transform::Neg, nullptr); - gradX->assign(epsNext); - //gradX->applyPairwiseTransform(pairwise::FloorDiv, y, gradX, nullptr); + gradY->assign(0.0f); + gradX->assign(0.0f); + return Status::OK(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 2e6ed9ca1..b3abf7b00 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -167,6 +167,8 @@ public class SameDiff extends SDBaseOps { public final SDRNN rnn = new SDRNN(this); /** Op creator object for loss function operations */ public final SDLoss loss = new SDLoss(this); + /** Op creator object for image operations */ + public final SDImage image = new SDImage(this); /** Op creator object for math operations */ public SDMath math(){ @@ -198,6 +200,10 @@ public class SameDiff extends SDBaseOps { return loss; } + /** Op creator object for image operations */ + public SDImage image(){ + return image; + } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java new file mode 100644 index 000000000..f7166ab5e --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java @@ -0,0 +1,61 @@ +package org.nd4j.autodiff.samediff.ops; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ops.impl.image.CropAndResize; +import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches; +import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression; + +/** + * @author Alex Black + */ +public class SDImage extends SDOps { + public SDImage(SameDiff sameDiff) { + super(sameDiff); + } + + /** + * Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size. + * + * @param name May be null. Name for the output variable. + * @param image Input image, with shape [batch, height, width, channels] + * @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 + * @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] + * @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] + * @param method Image resize method + * @param extrapolationValue Used for extrapolation, when applicable. 0.0 should be used for the default + * @return Cropped and resized images + */ + public SDVariable cropAndResize(String name, SDVariable image, SDVariable cropBoxes, SDVariable boxIndices, SDVariable cropOutSize, + CropAndResize.Method method, double extrapolationValue) { + SDVariable out = new CropAndResize(sd, image, cropBoxes, boxIndices, cropOutSize, method, extrapolationValue).outputVariable(); + return updateVariableNameAndReference(out, name); + } + + /** + * Given an input image, extract out image patches (of size kSizes - h x w) and place them in the depth dimension. + * + * @param name Map be null. Name for the output variable + * @param image Input image to extract image patches from - shape [batch, height, width, channels] + * @param kSizes Kernel size - size of the image patches, [height, width] + * @param strides Stride in the input dimension for extracting image patches, [stride_height, stride_width] + * @param rates Usually [1,1]. Equivalent to dilation rate in dilated convolutions - how far apart the output pixels + * in the patches should be, in the input. A dilation of [a,b] means every {@code a}th pixel is taken + * along the height/rows dimension, and every {@code b}th pixel is take along the width/columns dimension + * @param sameMode Padding algorithm. If true: use Same padding + * @return The extracted image patches + */ + public SDVariable extractImagePatches(String name, SDVariable image, @NonNull int[] kSizes, + @NonNull int[] strides, @NonNull int[] rates, boolean sameMode) { + SDVariable out = new ExtractImagePatches(sd, image, kSizes, strides, rates, sameMode).outputVariable(); + return updateVariableNameAndReference(out, name); + } + + + public SDVariable nonMaxSuppression(String name, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize, + @NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold){ + SDVariable out = new NonMaxSuppression(sd, boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable(); + return updateVariableNameAndReference(out, name); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index af7829eab..5ea26b0ab 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -33,7 +33,6 @@ import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOpDescriptor; -import org.nd4j.linalg.api.ops.DefaultOpConverter; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces; import org.nd4j.linalg.api.ops.custom.BarnesHutGains; @@ -838,7 +837,6 @@ public class OpValidation { //Exclude misc DynamicCustomOp.class, GradientBackwardsMarker.class, - DefaultOpConverter.class, EqualsWithEps.class, FreeGridOp.class, MergeSum.class, //Redundant; we use MergeAdd in samediff instead diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 723b77511..c7f8cbd64 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -41,7 +41,6 @@ public class ImportClassMapping { private static final Map ONNX_OP_NAME_MAP = new HashMap<>(); private static final List> fnClasses = Arrays.>asList( - org.nd4j.linalg.api.ops.DefaultOpConverter.class, org.nd4j.linalg.api.ops.DynamicCustomOp.class, org.nd4j.linalg.api.ops.NoOp.class, org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DefaultOpConverter.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DefaultOpConverter.java deleted file mode 100644 index b6774113d..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DefaultOpConverter.java +++ /dev/null @@ -1,55 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.imports.NoOpNameFoundException; - -import java.util.List; - -public class DefaultOpConverter extends BaseOp { - private static DefaultOpConverter INSTANCE = new DefaultOpConverter(); - public static DefaultOpConverter getInstance() { - return INSTANCE; - } - - @Override - public List doDiff(List f1) { - return null; - } - - @Override - public int opNum() { - return 0; - } - - @Override - public String opName() { - return "defaultop"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java index 972bf08ef..fa8401bf0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java @@ -135,14 +135,23 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { bArguments = new ArrayList<>(); } + /** + * Initialize this operation for execution (pre created ndarrays) + * + * @param inputs the inputs + * @param outputs the outputs of the op, may be null + */ + public DynamicCustomOp(INDArray[] inputs, INDArray[] outputs) { + this(null, inputs, outputs); + } + /** * Initialize this operation for execution (pre created ndarrays) * - * @param opName the operation opName to use - * for invocation + * @param opName the operation opName to use for invocation * @param inputs the inputs - * @param outputs the outputs of the op + * @param outputs the outputs of the op, may be null */ public DynamicCustomOp(String opName, INDArray[] inputs, INDArray[] outputs) { this(opName, inputs, outputs, Lists.newArrayList(), Lists.newArrayList()); @@ -600,6 +609,10 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { } + protected static INDArray[] wrapOrNull(INDArray in){ + return in == null ? null : new INDArray[]{in}; + } + public static class DynamicCustomOpsBuilder { protected String opName; protected int numInputs; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java index 875cdd48f..aed50c987 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java @@ -17,11 +17,13 @@ package org.nd4j.linalg.api.ops.impl.broadcast; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.util.ArrayUtil; @@ -42,6 +44,10 @@ public class BiasAdd extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {input, bias}, false); } + public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output){ + super(new INDArray[]{input, bias}, wrapOrNull(output)); + } + @Override public String opName() { return "biasadd"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAddGrad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAddGrad.java index c49c0d008..0d6ced083 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAddGrad.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAddGrad.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.broadcast; +import lombok.NonNull; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -35,6 +36,10 @@ public class BiasAddGrad extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{input, bias, gradient}); } + public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient, INDArray output){ + super(new INDArray[]{input, bias, gradient}, wrapOrNull(output)); + } + public BiasAddGrad() {} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java index 3e49f4684..4ede302dd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java @@ -16,11 +16,14 @@ package org.nd4j.linalg.api.ops.impl.image; +import lombok.NoArgsConstructor; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; @@ -33,11 +36,32 @@ import java.util.*; * CropAndResize Op * @author Alex Black */ +@NoArgsConstructor public class CropAndResize extends DynamicCustomOp { public enum Method {BILINEAR, NEAREST}; protected Method method = Method.BILINEAR; protected double extrapolationValue = 0.0; + public CropAndResize(@NonNull SameDiff sameDiff, @NonNull SDVariable image, @NonNull SDVariable cropBoxes, @NonNull SDVariable boxIndices, + @NonNull SDVariable cropOutSize, @NonNull Method method, double extrapolationValue){ + super(sameDiff, new SDVariable[]{image, cropBoxes, boxIndices, cropOutSize}); + this.method = method; + this.extrapolationValue = extrapolationValue; + addArgs(); + } + + public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices, + @NonNull INDArray cropOutSize, @NonNull Method method, double extrapolationValue){ + super(new INDArray[]{image, cropBoxes, boxIndices, cropOutSize}, null); + Preconditions.checkArgument(image.rank() == 4, "Input image must be rank 4 with shape [batch, height, width, channels], got %ndShape", image); + Preconditions.checkArgument(cropBoxes.rank() == 2 && cropBoxes.size(1) == 4, "Crop boxes must be rank 4 with shape [num_boxes, 5], got %ndShape", cropBoxes); + Preconditions.checkArgument(boxIndices.rank() == 1 && cropBoxes.size(0) == boxIndices.size(0), + "Box indices must be rank 1 array with shape [num_boxes] (same as cropBoxes.size(0), got array with shape %ndShape", boxIndices); + this.method = method; + this.extrapolationValue = extrapolationValue; + addArgs(); + } + @Override public String opName() { return "crop_and_resize"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java index 3436972ee..62194c044 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; @@ -54,9 +55,27 @@ public class ExtractImagePatches extends DynamicCustomOp { Preconditions.checkState(kSizes.length == 2, "Expected exactly 2 kernel sizes, got %s", kSizes); Preconditions.checkState(strides.length == 2, "Expected exactly 2 strides, got %s", strides); Preconditions.checkState(rates.length == 2, "Expected exactly 2 rate values, got %s", rates); + this.kSizes = kSizes; + this.strides = strides; + this.rates = rates; this.isSameMode = sameMode; + addArgs(); } + public ExtractImagePatches(@NonNull INDArray input, @NonNull int[] kSizes, + @NonNull int[] strides, @NonNull int[] rates, boolean sameMode){ + super(new INDArray[]{input}, null); + Preconditions.checkState(kSizes.length == 2, "Expected exactly 2 kernel sizes, got %s", kSizes); + Preconditions.checkState(strides.length == 2, "Expected exactly 2 strides, got %s", strides); + Preconditions.checkState(rates.length == 2, "Expected exactly 2 rate values, got %s", rates); + this.kSizes = kSizes; + this.strides = strides; + this.rates = rates; + this.isSameMode = sameMode; + addArgs(); + } + + @Override public String opName() { return "extract_image_patches"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java index ae94f20bb..d7161cf5f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.image; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -27,7 +28,7 @@ import java.util.Collections; import java.util.List; /** - * IdentityN op wrapper + * Non max suppression * * @author raver119@gmail.com */ @@ -35,8 +36,9 @@ public class NonMaxSuppression extends DynamicCustomOp { public NonMaxSuppression() {} - public NonMaxSuppression(SameDiff sameDiff, SDVariable[] input) { - super(null, sameDiff, input, false); + public NonMaxSuppression(SameDiff sameDiff, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize, + @NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold) { + super(null, sameDiff, new SDVariable[]{boxes, scores, maxOutSize, iouThreshold, scoreThreshold}, false); } @Override diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpConstructorTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpConstructorTests.java index baef89a1a..6644caa4b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpConstructorTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpConstructorTests.java @@ -4,6 +4,7 @@ import org.junit.Test; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.ops.NoOp; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.reflections.Reflections; @@ -14,7 +15,7 @@ import org.reflections.util.FilterBuilder; import java.lang.reflect.Constructor; import java.lang.reflect.Modifier; -import java.util.Set; +import java.util.*; import static org.junit.Assert.assertEquals; @@ -24,6 +25,18 @@ public class OpConstructorTests extends BaseNd4jTest { super(backend); } + //Ignore individual classes + protected Set> exclude = new HashSet<>( + Arrays.asList( + NoOp.class + ) + ); + + //Ignore whole sets of classes based on regex + protected String[] ignoreRegexes = new String[]{ + "org\\.nd4j\\.linalg\\.api\\.ops\\.impl\\.controlflow\\..*" + }; + @Test public void checkForINDArrayConstructors() throws Exception { /* @@ -38,11 +51,24 @@ public class OpConstructorTests extends BaseNd4jTest { Set> classSet = f.getSubTypesOf(DifferentialFunction.class); int count = 0; + List> classes = new ArrayList<>(); for(Class c : classSet){ if(Modifier.isAbstract(c.getModifiers()) || Modifier.isInterface(c.getModifiers()) || c == SDVariable.class || ILossFunction.class.isAssignableFrom(c)) continue; -// System.out.println(c.getName()); + if(exclude.contains(c)) + continue; + + String cn = c.getName(); + boolean ignored = false; + for(String s : ignoreRegexes ){ + if(cn.matches(s)){ + ignored = true; + break; + } + } + if(ignored) + continue; Constructor[] constructors = c.getConstructors(); boolean foundINDArray = false; @@ -56,12 +82,22 @@ public class OpConstructorTests extends BaseNd4jTest { } if(!foundINDArray){ - System.out.println("No INDArray constructor: " + c.getName()); - count++; + classes.add(c); } } - assertEquals(0, count); + if(!classes.isEmpty()){ + Collections.sort(classes, new Comparator>() { + @Override + public int compare(Class o1, Class o2) { + return o1.getName().compareTo(o2.getName()); + } + }); + for(Class c : classes){ + System.out.println("No INDArray constructor: " + c.getName()); + } + } + assertEquals("Found " + classes.size() + " (non-ignored) op classes with no INDArray/INDArray[] constructors", 0, classes.size()); }