Small number of fixes + cleanup + some missing op methods + constructors (#100)
* Remove unused op class - DefaultOpConverter Signed-off-by: AlexDBlack <blacka101@gmail.com> * Add SDImage class; INDArray constructor additions Signed-off-by: AlexDBlack <blacka101@gmail.com> * Floordiv Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small polish to image methods Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small DataVec test fix Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
923ab15583
commit
b8846113bd
|
@ -34,8 +34,8 @@ import static org.junit.Assert.assertEquals;
|
|||
public class RecordConverterTest {
|
||||
@Test
|
||||
public void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() {
|
||||
INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 3}, DataType.FLOAT);
|
||||
INDArray feature2 = Nd4j.create(new double[]{11, .7, -1.3, 4}, new long[]{1, 3}, DataType.FLOAT);
|
||||
INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT);
|
||||
INDArray feature2 = Nd4j.create(new double[]{11, .7, -1.3, 4}, new long[]{1, 4}, DataType.FLOAT);
|
||||
INDArray label1 = Nd4j.create(new double[]{0, 0, 1, 0}, new long[]{1, 4}, DataType.FLOAT);
|
||||
INDArray label2 = Nd4j.create(new double[]{0, 1, 0, 0}, new long[]{1, 4}, DataType.FLOAT);
|
||||
DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)),
|
||||
|
|
|
@ -67,14 +67,9 @@ namespace nd4j {
|
|||
auto gradX = OUTPUT_VARIABLE(0);
|
||||
auto gradY = OUTPUT_VARIABLE(1);
|
||||
|
||||
gradY->assign(x);
|
||||
std::unique_ptr<NDArray> ySq(y->dup());
|
||||
ySq->applyTransform(transform::Square, nullptr);
|
||||
gradY->applyPairwiseTransform(pairwise::FloorDiv, ySq.get(), gradY, nullptr);
|
||||
gradY->applyPairwiseTransform(pairwise::Multiply, epsNext, gradY, nullptr);
|
||||
gradY->applyTransform(transform::Neg, nullptr);
|
||||
gradX->assign(epsNext);
|
||||
//gradX->applyPairwiseTransform(pairwise::FloorDiv, y, gradX, nullptr);
|
||||
gradY->assign(0.0f);
|
||||
gradX->assign(0.0f);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -167,6 +167,8 @@ public class SameDiff extends SDBaseOps {
|
|||
public final SDRNN rnn = new SDRNN(this);
|
||||
/** Op creator object for loss function operations */
|
||||
public final SDLoss loss = new SDLoss(this);
|
||||
/** Op creator object for image operations */
|
||||
public final SDImage image = new SDImage(this);
|
||||
|
||||
/** Op creator object for math operations */
|
||||
public SDMath math(){
|
||||
|
@ -198,6 +200,10 @@ public class SameDiff extends SDBaseOps {
|
|||
return loss;
|
||||
}
|
||||
|
||||
/** Op creator object for image operations */
|
||||
public SDImage image(){
|
||||
return image;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
package org.nd4j.autodiff.samediff.ops;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
|
||||
import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches;
|
||||
import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
|
||||
|
||||
/**
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class SDImage extends SDOps {
|
||||
public SDImage(SameDiff sameDiff) {
|
||||
super(sameDiff);
|
||||
}
|
||||
|
||||
/**
|
||||
* Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.
|
||||
*
|
||||
* @param name May be null. Name for the output variable.
|
||||
* @param image Input image, with shape [batch, height, width, channels]
|
||||
* @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1
|
||||
* @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes]
|
||||
* @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth]
|
||||
* @param method Image resize method
|
||||
* @param extrapolationValue Used for extrapolation, when applicable. 0.0 should be used for the default
|
||||
* @return Cropped and resized images
|
||||
*/
|
||||
public SDVariable cropAndResize(String name, SDVariable image, SDVariable cropBoxes, SDVariable boxIndices, SDVariable cropOutSize,
|
||||
CropAndResize.Method method, double extrapolationValue) {
|
||||
SDVariable out = new CropAndResize(sd, image, cropBoxes, boxIndices, cropOutSize, method, extrapolationValue).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Given an input image, extract out image patches (of size kSizes - h x w) and place them in the depth dimension.
|
||||
*
|
||||
* @param name Map be null. Name for the output variable
|
||||
* @param image Input image to extract image patches from - shape [batch, height, width, channels]
|
||||
* @param kSizes Kernel size - size of the image patches, [height, width]
|
||||
* @param strides Stride in the input dimension for extracting image patches, [stride_height, stride_width]
|
||||
* @param rates Usually [1,1]. Equivalent to dilation rate in dilated convolutions - how far apart the output pixels
|
||||
* in the patches should be, in the input. A dilation of [a,b] means every {@code a}th pixel is taken
|
||||
* along the height/rows dimension, and every {@code b}th pixel is take along the width/columns dimension
|
||||
* @param sameMode Padding algorithm. If true: use Same padding
|
||||
* @return The extracted image patches
|
||||
*/
|
||||
public SDVariable extractImagePatches(String name, SDVariable image, @NonNull int[] kSizes,
|
||||
@NonNull int[] strides, @NonNull int[] rates, boolean sameMode) {
|
||||
SDVariable out = new ExtractImagePatches(sd, image, kSizes, strides, rates, sameMode).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
|
||||
public SDVariable nonMaxSuppression(String name, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize,
|
||||
@NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold){
|
||||
SDVariable out = new NonMaxSuppression(sd, boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
}
|
|
@ -33,7 +33,6 @@ import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser;
|
|||
import org.nd4j.linalg.api.iter.NdIndexIterator;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
|
||||
import org.nd4j.linalg.api.ops.DefaultOpConverter;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces;
|
||||
import org.nd4j.linalg.api.ops.custom.BarnesHutGains;
|
||||
|
@ -838,7 +837,6 @@ public class OpValidation {
|
|||
//Exclude misc
|
||||
DynamicCustomOp.class,
|
||||
GradientBackwardsMarker.class,
|
||||
DefaultOpConverter.class,
|
||||
EqualsWithEps.class,
|
||||
FreeGridOp.class,
|
||||
MergeSum.class, //Redundant; we use MergeAdd in samediff instead
|
||||
|
|
|
@ -41,7 +41,6 @@ public class ImportClassMapping {
|
|||
private static final Map<String, DifferentialFunction> ONNX_OP_NAME_MAP = new HashMap<>();
|
||||
|
||||
private static final List<Class<?>> fnClasses = Arrays.<Class<?>>asList(
|
||||
org.nd4j.linalg.api.ops.DefaultOpConverter.class,
|
||||
org.nd4j.linalg.api.ops.DynamicCustomOp.class,
|
||||
org.nd4j.linalg.api.ops.NoOp.class,
|
||||
org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class,
|
||||
|
|
|
@ -1,55 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class DefaultOpConverter extends BaseOp {
|
||||
private static DefaultOpConverter INSTANCE = new DefaultOpConverter();
|
||||
public static DefaultOpConverter getInstance() {
|
||||
return INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int opNum() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "defaultop";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
|
||||
}
|
||||
|
||||
}
|
|
@ -135,14 +135,23 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
|||
bArguments = new ArrayList<>();
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize this operation for execution (pre created ndarrays)
|
||||
*
|
||||
* @param inputs the inputs
|
||||
* @param outputs the outputs of the op, may be null
|
||||
*/
|
||||
public DynamicCustomOp(INDArray[] inputs, INDArray[] outputs) {
|
||||
this(null, inputs, outputs);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Initialize this operation for execution (pre created ndarrays)
|
||||
*
|
||||
* @param opName the operation opName to use
|
||||
* for invocation
|
||||
* @param opName the operation opName to use for invocation
|
||||
* @param inputs the inputs
|
||||
* @param outputs the outputs of the op
|
||||
* @param outputs the outputs of the op, may be null
|
||||
*/
|
||||
public DynamicCustomOp(String opName, INDArray[] inputs, INDArray[] outputs) {
|
||||
this(opName, inputs, outputs, Lists.<Double>newArrayList(), Lists.<Integer>newArrayList());
|
||||
|
@ -600,6 +609,10 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
|||
|
||||
}
|
||||
|
||||
protected static INDArray[] wrapOrNull(INDArray in){
|
||||
return in == null ? null : new INDArray[]{in};
|
||||
}
|
||||
|
||||
public static class DynamicCustomOpsBuilder {
|
||||
protected String opName;
|
||||
protected int numInputs;
|
||||
|
|
|
@ -17,11 +17,13 @@
|
|||
package org.nd4j.linalg.api.ops.impl.broadcast;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.val;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
@ -42,6 +44,10 @@ public class BiasAdd extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[] {input, bias}, false);
|
||||
}
|
||||
|
||||
public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output){
|
||||
super(new INDArray[]{input, bias}, wrapOrNull(output));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "biasadd";
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.broadcast;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -35,6 +36,10 @@ public class BiasAddGrad extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{input, bias, gradient});
|
||||
}
|
||||
|
||||
public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient, INDArray output){
|
||||
super(new INDArray[]{input, bias, gradient}, wrapOrNull(output));
|
||||
}
|
||||
|
||||
public BiasAddGrad() {}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -16,11 +16,14 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.image;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
|
@ -33,11 +36,32 @@ import java.util.*;
|
|||
* CropAndResize Op
|
||||
* @author Alex Black
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
public class CropAndResize extends DynamicCustomOp {
|
||||
public enum Method {BILINEAR, NEAREST};
|
||||
protected Method method = Method.BILINEAR;
|
||||
protected double extrapolationValue = 0.0;
|
||||
|
||||
public CropAndResize(@NonNull SameDiff sameDiff, @NonNull SDVariable image, @NonNull SDVariable cropBoxes, @NonNull SDVariable boxIndices,
|
||||
@NonNull SDVariable cropOutSize, @NonNull Method method, double extrapolationValue){
|
||||
super(sameDiff, new SDVariable[]{image, cropBoxes, boxIndices, cropOutSize});
|
||||
this.method = method;
|
||||
this.extrapolationValue = extrapolationValue;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices,
|
||||
@NonNull INDArray cropOutSize, @NonNull Method method, double extrapolationValue){
|
||||
super(new INDArray[]{image, cropBoxes, boxIndices, cropOutSize}, null);
|
||||
Preconditions.checkArgument(image.rank() == 4, "Input image must be rank 4 with shape [batch, height, width, channels], got %ndShape", image);
|
||||
Preconditions.checkArgument(cropBoxes.rank() == 2 && cropBoxes.size(1) == 4, "Crop boxes must be rank 4 with shape [num_boxes, 5], got %ndShape", cropBoxes);
|
||||
Preconditions.checkArgument(boxIndices.rank() == 1 && cropBoxes.size(0) == boxIndices.size(0),
|
||||
"Box indices must be rank 1 array with shape [num_boxes] (same as cropBoxes.size(0), got array with shape %ndShape", boxIndices);
|
||||
this.method = method;
|
||||
this.extrapolationValue = extrapolationValue;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "crop_and_resize";
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
|
|||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
|
@ -54,9 +55,27 @@ public class ExtractImagePatches extends DynamicCustomOp {
|
|||
Preconditions.checkState(kSizes.length == 2, "Expected exactly 2 kernel sizes, got %s", kSizes);
|
||||
Preconditions.checkState(strides.length == 2, "Expected exactly 2 strides, got %s", strides);
|
||||
Preconditions.checkState(rates.length == 2, "Expected exactly 2 rate values, got %s", rates);
|
||||
this.kSizes = kSizes;
|
||||
this.strides = strides;
|
||||
this.rates = rates;
|
||||
this.isSameMode = sameMode;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public ExtractImagePatches(@NonNull INDArray input, @NonNull int[] kSizes,
|
||||
@NonNull int[] strides, @NonNull int[] rates, boolean sameMode){
|
||||
super(new INDArray[]{input}, null);
|
||||
Preconditions.checkState(kSizes.length == 2, "Expected exactly 2 kernel sizes, got %s", kSizes);
|
||||
Preconditions.checkState(strides.length == 2, "Expected exactly 2 strides, got %s", strides);
|
||||
Preconditions.checkState(rates.length == 2, "Expected exactly 2 rate values, got %s", rates);
|
||||
this.kSizes = kSizes;
|
||||
this.strides = strides;
|
||||
this.rates = rates;
|
||||
this.isSameMode = sameMode;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "extract_image_patches";
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.image;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
|
@ -27,7 +28,7 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
|
||||
/**
|
||||
* IdentityN op wrapper
|
||||
* Non max suppression
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
|
@ -35,8 +36,9 @@ public class NonMaxSuppression extends DynamicCustomOp {
|
|||
|
||||
public NonMaxSuppression() {}
|
||||
|
||||
public NonMaxSuppression(SameDiff sameDiff, SDVariable[] input) {
|
||||
super(null, sameDiff, input, false);
|
||||
public NonMaxSuppression(SameDiff sameDiff, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize,
|
||||
@NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold) {
|
||||
super(null, sameDiff, new SDVariable[]{boxes, scores, maxOutSize, iouThreshold, scoreThreshold}, false);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -4,6 +4,7 @@ import org.junit.Test;
|
|||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.api.ops.NoOp;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.lossfunctions.ILossFunction;
|
||||
import org.reflections.Reflections;
|
||||
|
@ -14,7 +15,7 @@ import org.reflections.util.FilterBuilder;
|
|||
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.util.Set;
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
|
@ -24,6 +25,18 @@ public class OpConstructorTests extends BaseNd4jTest {
|
|||
super(backend);
|
||||
}
|
||||
|
||||
//Ignore individual classes
|
||||
protected Set<Class<?>> exclude = new HashSet<>(
|
||||
Arrays.asList(
|
||||
NoOp.class
|
||||
)
|
||||
);
|
||||
|
||||
//Ignore whole sets of classes based on regex
|
||||
protected String[] ignoreRegexes = new String[]{
|
||||
"org\\.nd4j\\.linalg\\.api\\.ops\\.impl\\.controlflow\\..*"
|
||||
};
|
||||
|
||||
@Test
|
||||
public void checkForINDArrayConstructors() throws Exception {
|
||||
/*
|
||||
|
@ -38,11 +51,24 @@ public class OpConstructorTests extends BaseNd4jTest {
|
|||
Set<Class<? extends DifferentialFunction>> classSet = f.getSubTypesOf(DifferentialFunction.class);
|
||||
|
||||
int count = 0;
|
||||
List<Class<?>> classes = new ArrayList<>();
|
||||
for(Class<?> c : classSet){
|
||||
if(Modifier.isAbstract(c.getModifiers()) || Modifier.isInterface(c.getModifiers()) || c == SDVariable.class || ILossFunction.class.isAssignableFrom(c))
|
||||
continue;
|
||||
|
||||
// System.out.println(c.getName());
|
||||
if(exclude.contains(c))
|
||||
continue;
|
||||
|
||||
String cn = c.getName();
|
||||
boolean ignored = false;
|
||||
for(String s : ignoreRegexes ){
|
||||
if(cn.matches(s)){
|
||||
ignored = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(ignored)
|
||||
continue;
|
||||
|
||||
Constructor<?>[] constructors = c.getConstructors();
|
||||
boolean foundINDArray = false;
|
||||
|
@ -56,12 +82,22 @@ public class OpConstructorTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
if(!foundINDArray){
|
||||
System.out.println("No INDArray constructor: " + c.getName());
|
||||
count++;
|
||||
classes.add(c);
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, count);
|
||||
if(!classes.isEmpty()){
|
||||
Collections.sort(classes, new Comparator<Class<?>>() {
|
||||
@Override
|
||||
public int compare(Class<?> o1, Class<?> o2) {
|
||||
return o1.getName().compareTo(o2.getName());
|
||||
}
|
||||
});
|
||||
for(Class<?> c : classes){
|
||||
System.out.println("No INDArray constructor: " + c.getName());
|
||||
}
|
||||
}
|
||||
assertEquals("Found " + classes.size() + " (non-ignored) op classes with no INDArray/INDArray[] constructors", 0, classes.size());
|
||||
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue