Small number of fixes + cleanup + some missing op methods + constructors (#100)
* Remove unused op class - DefaultOpConverter Signed-off-by: AlexDBlack <blacka101@gmail.com> * Add SDImage class; INDArray constructor additions Signed-off-by: AlexDBlack <blacka101@gmail.com> * Floordiv Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small polish to image methods Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small DataVec test fix Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
923ab15583
commit
b8846113bd
|
@ -34,8 +34,8 @@ import static org.junit.Assert.assertEquals;
|
||||||
public class RecordConverterTest {
|
public class RecordConverterTest {
|
||||||
@Test
|
@Test
|
||||||
public void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() {
|
public void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() {
|
||||||
INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 3}, DataType.FLOAT);
|
INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT);
|
||||||
INDArray feature2 = Nd4j.create(new double[]{11, .7, -1.3, 4}, new long[]{1, 3}, DataType.FLOAT);
|
INDArray feature2 = Nd4j.create(new double[]{11, .7, -1.3, 4}, new long[]{1, 4}, DataType.FLOAT);
|
||||||
INDArray label1 = Nd4j.create(new double[]{0, 0, 1, 0}, new long[]{1, 4}, DataType.FLOAT);
|
INDArray label1 = Nd4j.create(new double[]{0, 0, 1, 0}, new long[]{1, 4}, DataType.FLOAT);
|
||||||
INDArray label2 = Nd4j.create(new double[]{0, 1, 0, 0}, new long[]{1, 4}, DataType.FLOAT);
|
INDArray label2 = Nd4j.create(new double[]{0, 1, 0, 0}, new long[]{1, 4}, DataType.FLOAT);
|
||||||
DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)),
|
DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)),
|
||||||
|
|
|
@ -67,14 +67,9 @@ namespace nd4j {
|
||||||
auto gradX = OUTPUT_VARIABLE(0);
|
auto gradX = OUTPUT_VARIABLE(0);
|
||||||
auto gradY = OUTPUT_VARIABLE(1);
|
auto gradY = OUTPUT_VARIABLE(1);
|
||||||
|
|
||||||
gradY->assign(x);
|
gradY->assign(0.0f);
|
||||||
std::unique_ptr<NDArray> ySq(y->dup());
|
gradX->assign(0.0f);
|
||||||
ySq->applyTransform(transform::Square, nullptr);
|
|
||||||
gradY->applyPairwiseTransform(pairwise::FloorDiv, ySq.get(), gradY, nullptr);
|
|
||||||
gradY->applyPairwiseTransform(pairwise::Multiply, epsNext, gradY, nullptr);
|
|
||||||
gradY->applyTransform(transform::Neg, nullptr);
|
|
||||||
gradX->assign(epsNext);
|
|
||||||
//gradX->applyPairwiseTransform(pairwise::FloorDiv, y, gradX, nullptr);
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -167,6 +167,8 @@ public class SameDiff extends SDBaseOps {
|
||||||
public final SDRNN rnn = new SDRNN(this);
|
public final SDRNN rnn = new SDRNN(this);
|
||||||
/** Op creator object for loss function operations */
|
/** Op creator object for loss function operations */
|
||||||
public final SDLoss loss = new SDLoss(this);
|
public final SDLoss loss = new SDLoss(this);
|
||||||
|
/** Op creator object for image operations */
|
||||||
|
public final SDImage image = new SDImage(this);
|
||||||
|
|
||||||
/** Op creator object for math operations */
|
/** Op creator object for math operations */
|
||||||
public SDMath math(){
|
public SDMath math(){
|
||||||
|
@ -198,6 +200,10 @@ public class SameDiff extends SDBaseOps {
|
||||||
return loss;
|
return loss;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Op creator object for image operations */
|
||||||
|
public SDImage image(){
|
||||||
|
return image;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
package org.nd4j.autodiff.samediff.ops;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
public class SDImage extends SDOps {
|
||||||
|
public SDImage(SameDiff sameDiff) {
|
||||||
|
super(sameDiff);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.
|
||||||
|
*
|
||||||
|
* @param name May be null. Name for the output variable.
|
||||||
|
* @param image Input image, with shape [batch, height, width, channels]
|
||||||
|
* @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1
|
||||||
|
* @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes]
|
||||||
|
* @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth]
|
||||||
|
* @param method Image resize method
|
||||||
|
* @param extrapolationValue Used for extrapolation, when applicable. 0.0 should be used for the default
|
||||||
|
* @return Cropped and resized images
|
||||||
|
*/
|
||||||
|
public SDVariable cropAndResize(String name, SDVariable image, SDVariable cropBoxes, SDVariable boxIndices, SDVariable cropOutSize,
|
||||||
|
CropAndResize.Method method, double extrapolationValue) {
|
||||||
|
SDVariable out = new CropAndResize(sd, image, cropBoxes, boxIndices, cropOutSize, method, extrapolationValue).outputVariable();
|
||||||
|
return updateVariableNameAndReference(out, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Given an input image, extract out image patches (of size kSizes - h x w) and place them in the depth dimension.
|
||||||
|
*
|
||||||
|
* @param name Map be null. Name for the output variable
|
||||||
|
* @param image Input image to extract image patches from - shape [batch, height, width, channels]
|
||||||
|
* @param kSizes Kernel size - size of the image patches, [height, width]
|
||||||
|
* @param strides Stride in the input dimension for extracting image patches, [stride_height, stride_width]
|
||||||
|
* @param rates Usually [1,1]. Equivalent to dilation rate in dilated convolutions - how far apart the output pixels
|
||||||
|
* in the patches should be, in the input. A dilation of [a,b] means every {@code a}th pixel is taken
|
||||||
|
* along the height/rows dimension, and every {@code b}th pixel is take along the width/columns dimension
|
||||||
|
* @param sameMode Padding algorithm. If true: use Same padding
|
||||||
|
* @return The extracted image patches
|
||||||
|
*/
|
||||||
|
public SDVariable extractImagePatches(String name, SDVariable image, @NonNull int[] kSizes,
|
||||||
|
@NonNull int[] strides, @NonNull int[] rates, boolean sameMode) {
|
||||||
|
SDVariable out = new ExtractImagePatches(sd, image, kSizes, strides, rates, sameMode).outputVariable();
|
||||||
|
return updateVariableNameAndReference(out, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public SDVariable nonMaxSuppression(String name, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize,
|
||||||
|
@NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold){
|
||||||
|
SDVariable out = new NonMaxSuppression(sd, boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable();
|
||||||
|
return updateVariableNameAndReference(out, name);
|
||||||
|
}
|
||||||
|
}
|
|
@ -33,7 +33,6 @@ import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser;
|
||||||
import org.nd4j.linalg.api.iter.NdIndexIterator;
|
import org.nd4j.linalg.api.iter.NdIndexIterator;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
|
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
|
||||||
import org.nd4j.linalg.api.ops.DefaultOpConverter;
|
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces;
|
import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces;
|
||||||
import org.nd4j.linalg.api.ops.custom.BarnesHutGains;
|
import org.nd4j.linalg.api.ops.custom.BarnesHutGains;
|
||||||
|
@ -838,7 +837,6 @@ public class OpValidation {
|
||||||
//Exclude misc
|
//Exclude misc
|
||||||
DynamicCustomOp.class,
|
DynamicCustomOp.class,
|
||||||
GradientBackwardsMarker.class,
|
GradientBackwardsMarker.class,
|
||||||
DefaultOpConverter.class,
|
|
||||||
EqualsWithEps.class,
|
EqualsWithEps.class,
|
||||||
FreeGridOp.class,
|
FreeGridOp.class,
|
||||||
MergeSum.class, //Redundant; we use MergeAdd in samediff instead
|
MergeSum.class, //Redundant; we use MergeAdd in samediff instead
|
||||||
|
|
|
@ -41,7 +41,6 @@ public class ImportClassMapping {
|
||||||
private static final Map<String, DifferentialFunction> ONNX_OP_NAME_MAP = new HashMap<>();
|
private static final Map<String, DifferentialFunction> ONNX_OP_NAME_MAP = new HashMap<>();
|
||||||
|
|
||||||
private static final List<Class<?>> fnClasses = Arrays.<Class<?>>asList(
|
private static final List<Class<?>> fnClasses = Arrays.<Class<?>>asList(
|
||||||
org.nd4j.linalg.api.ops.DefaultOpConverter.class,
|
|
||||||
org.nd4j.linalg.api.ops.DynamicCustomOp.class,
|
org.nd4j.linalg.api.ops.DynamicCustomOp.class,
|
||||||
org.nd4j.linalg.api.ops.NoOp.class,
|
org.nd4j.linalg.api.ops.NoOp.class,
|
||||||
org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class,
|
org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class,
|
||||||
|
|
|
@ -1,55 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops;
|
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class DefaultOpConverter extends BaseOp {
|
|
||||||
private static DefaultOpConverter INSTANCE = new DefaultOpConverter();
|
|
||||||
public static DefaultOpConverter getInstance() {
|
|
||||||
return INSTANCE;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int opNum() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String opName() {
|
|
||||||
return "defaultop";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -135,14 +135,23 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||||
bArguments = new ArrayList<>();
|
bArguments = new ArrayList<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialize this operation for execution (pre created ndarrays)
|
||||||
|
*
|
||||||
|
* @param inputs the inputs
|
||||||
|
* @param outputs the outputs of the op, may be null
|
||||||
|
*/
|
||||||
|
public DynamicCustomOp(INDArray[] inputs, INDArray[] outputs) {
|
||||||
|
this(null, inputs, outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initialize this operation for execution (pre created ndarrays)
|
* Initialize this operation for execution (pre created ndarrays)
|
||||||
*
|
*
|
||||||
* @param opName the operation opName to use
|
* @param opName the operation opName to use for invocation
|
||||||
* for invocation
|
|
||||||
* @param inputs the inputs
|
* @param inputs the inputs
|
||||||
* @param outputs the outputs of the op
|
* @param outputs the outputs of the op, may be null
|
||||||
*/
|
*/
|
||||||
public DynamicCustomOp(String opName, INDArray[] inputs, INDArray[] outputs) {
|
public DynamicCustomOp(String opName, INDArray[] inputs, INDArray[] outputs) {
|
||||||
this(opName, inputs, outputs, Lists.<Double>newArrayList(), Lists.<Integer>newArrayList());
|
this(opName, inputs, outputs, Lists.<Double>newArrayList(), Lists.<Integer>newArrayList());
|
||||||
|
@ -600,6 +609,10 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected static INDArray[] wrapOrNull(INDArray in){
|
||||||
|
return in == null ? null : new INDArray[]{in};
|
||||||
|
}
|
||||||
|
|
||||||
public static class DynamicCustomOpsBuilder {
|
public static class DynamicCustomOpsBuilder {
|
||||||
protected String opName;
|
protected String opName;
|
||||||
protected int numInputs;
|
protected int numInputs;
|
||||||
|
|
|
@ -17,11 +17,13 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.broadcast;
|
package org.nd4j.linalg.api.ops.impl.broadcast;
|
||||||
|
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
|
@ -42,6 +44,10 @@ public class BiasAdd extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[] {input, bias}, false);
|
super(null, sameDiff, new SDVariable[] {input, bias}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output){
|
||||||
|
super(new INDArray[]{input, bias}, wrapOrNull(output));
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "biasadd";
|
return "biasadd";
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.broadcast;
|
package org.nd4j.linalg.api.ops.impl.broadcast;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -35,6 +36,10 @@ public class BiasAddGrad extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[]{input, bias, gradient});
|
super(null, sameDiff, new SDVariable[]{input, bias, gradient});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient, INDArray output){
|
||||||
|
super(new INDArray[]{input, bias, gradient}, wrapOrNull(output));
|
||||||
|
}
|
||||||
|
|
||||||
public BiasAddGrad() {}
|
public BiasAddGrad() {}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -16,11 +16,14 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.image;
|
package org.nd4j.linalg.api.ops.impl.image;
|
||||||
|
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
|
@ -33,11 +36,32 @@ import java.util.*;
|
||||||
* CropAndResize Op
|
* CropAndResize Op
|
||||||
* @author Alex Black
|
* @author Alex Black
|
||||||
*/
|
*/
|
||||||
|
@NoArgsConstructor
|
||||||
public class CropAndResize extends DynamicCustomOp {
|
public class CropAndResize extends DynamicCustomOp {
|
||||||
public enum Method {BILINEAR, NEAREST};
|
public enum Method {BILINEAR, NEAREST};
|
||||||
protected Method method = Method.BILINEAR;
|
protected Method method = Method.BILINEAR;
|
||||||
protected double extrapolationValue = 0.0;
|
protected double extrapolationValue = 0.0;
|
||||||
|
|
||||||
|
public CropAndResize(@NonNull SameDiff sameDiff, @NonNull SDVariable image, @NonNull SDVariable cropBoxes, @NonNull SDVariable boxIndices,
|
||||||
|
@NonNull SDVariable cropOutSize, @NonNull Method method, double extrapolationValue){
|
||||||
|
super(sameDiff, new SDVariable[]{image, cropBoxes, boxIndices, cropOutSize});
|
||||||
|
this.method = method;
|
||||||
|
this.extrapolationValue = extrapolationValue;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices,
|
||||||
|
@NonNull INDArray cropOutSize, @NonNull Method method, double extrapolationValue){
|
||||||
|
super(new INDArray[]{image, cropBoxes, boxIndices, cropOutSize}, null);
|
||||||
|
Preconditions.checkArgument(image.rank() == 4, "Input image must be rank 4 with shape [batch, height, width, channels], got %ndShape", image);
|
||||||
|
Preconditions.checkArgument(cropBoxes.rank() == 2 && cropBoxes.size(1) == 4, "Crop boxes must be rank 4 with shape [num_boxes, 5], got %ndShape", cropBoxes);
|
||||||
|
Preconditions.checkArgument(boxIndices.rank() == 1 && cropBoxes.size(0) == boxIndices.size(0),
|
||||||
|
"Box indices must be rank 1 array with shape [num_boxes] (same as cropBoxes.size(0), got array with shape %ndShape", boxIndices);
|
||||||
|
this.method = method;
|
||||||
|
this.extrapolationValue = extrapolationValue;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "crop_and_resize";
|
return "crop_and_resize";
|
||||||
|
|
|
@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
|
@ -54,9 +55,27 @@ public class ExtractImagePatches extends DynamicCustomOp {
|
||||||
Preconditions.checkState(kSizes.length == 2, "Expected exactly 2 kernel sizes, got %s", kSizes);
|
Preconditions.checkState(kSizes.length == 2, "Expected exactly 2 kernel sizes, got %s", kSizes);
|
||||||
Preconditions.checkState(strides.length == 2, "Expected exactly 2 strides, got %s", strides);
|
Preconditions.checkState(strides.length == 2, "Expected exactly 2 strides, got %s", strides);
|
||||||
Preconditions.checkState(rates.length == 2, "Expected exactly 2 rate values, got %s", rates);
|
Preconditions.checkState(rates.length == 2, "Expected exactly 2 rate values, got %s", rates);
|
||||||
|
this.kSizes = kSizes;
|
||||||
|
this.strides = strides;
|
||||||
|
this.rates = rates;
|
||||||
this.isSameMode = sameMode;
|
this.isSameMode = sameMode;
|
||||||
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ExtractImagePatches(@NonNull INDArray input, @NonNull int[] kSizes,
|
||||||
|
@NonNull int[] strides, @NonNull int[] rates, boolean sameMode){
|
||||||
|
super(new INDArray[]{input}, null);
|
||||||
|
Preconditions.checkState(kSizes.length == 2, "Expected exactly 2 kernel sizes, got %s", kSizes);
|
||||||
|
Preconditions.checkState(strides.length == 2, "Expected exactly 2 strides, got %s", strides);
|
||||||
|
Preconditions.checkState(rates.length == 2, "Expected exactly 2 rate values, got %s", rates);
|
||||||
|
this.kSizes = kSizes;
|
||||||
|
this.strides = strides;
|
||||||
|
this.rates = rates;
|
||||||
|
this.isSameMode = sameMode;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "extract_image_patches";
|
return "extract_image_patches";
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.image;
|
package org.nd4j.linalg.api.ops.impl.image;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
|
@ -27,7 +28,7 @@ import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* IdentityN op wrapper
|
* Non max suppression
|
||||||
*
|
*
|
||||||
* @author raver119@gmail.com
|
* @author raver119@gmail.com
|
||||||
*/
|
*/
|
||||||
|
@ -35,8 +36,9 @@ public class NonMaxSuppression extends DynamicCustomOp {
|
||||||
|
|
||||||
public NonMaxSuppression() {}
|
public NonMaxSuppression() {}
|
||||||
|
|
||||||
public NonMaxSuppression(SameDiff sameDiff, SDVariable[] input) {
|
public NonMaxSuppression(SameDiff sameDiff, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize,
|
||||||
super(null, sameDiff, input, false);
|
@NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold) {
|
||||||
|
super(null, sameDiff, new SDVariable[]{boxes, scores, maxOutSize, iouThreshold, scoreThreshold}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -4,6 +4,7 @@ import org.junit.Test;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
|
import org.nd4j.linalg.api.ops.NoOp;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.lossfunctions.ILossFunction;
|
import org.nd4j.linalg.lossfunctions.ILossFunction;
|
||||||
import org.reflections.Reflections;
|
import org.reflections.Reflections;
|
||||||
|
@ -14,7 +15,7 @@ import org.reflections.util.FilterBuilder;
|
||||||
|
|
||||||
import java.lang.reflect.Constructor;
|
import java.lang.reflect.Constructor;
|
||||||
import java.lang.reflect.Modifier;
|
import java.lang.reflect.Modifier;
|
||||||
import java.util.Set;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
@ -24,6 +25,18 @@ public class OpConstructorTests extends BaseNd4jTest {
|
||||||
super(backend);
|
super(backend);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//Ignore individual classes
|
||||||
|
protected Set<Class<?>> exclude = new HashSet<>(
|
||||||
|
Arrays.asList(
|
||||||
|
NoOp.class
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
//Ignore whole sets of classes based on regex
|
||||||
|
protected String[] ignoreRegexes = new String[]{
|
||||||
|
"org\\.nd4j\\.linalg\\.api\\.ops\\.impl\\.controlflow\\..*"
|
||||||
|
};
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void checkForINDArrayConstructors() throws Exception {
|
public void checkForINDArrayConstructors() throws Exception {
|
||||||
/*
|
/*
|
||||||
|
@ -38,11 +51,24 @@ public class OpConstructorTests extends BaseNd4jTest {
|
||||||
Set<Class<? extends DifferentialFunction>> classSet = f.getSubTypesOf(DifferentialFunction.class);
|
Set<Class<? extends DifferentialFunction>> classSet = f.getSubTypesOf(DifferentialFunction.class);
|
||||||
|
|
||||||
int count = 0;
|
int count = 0;
|
||||||
|
List<Class<?>> classes = new ArrayList<>();
|
||||||
for(Class<?> c : classSet){
|
for(Class<?> c : classSet){
|
||||||
if(Modifier.isAbstract(c.getModifiers()) || Modifier.isInterface(c.getModifiers()) || c == SDVariable.class || ILossFunction.class.isAssignableFrom(c))
|
if(Modifier.isAbstract(c.getModifiers()) || Modifier.isInterface(c.getModifiers()) || c == SDVariable.class || ILossFunction.class.isAssignableFrom(c))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
// System.out.println(c.getName());
|
if(exclude.contains(c))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
String cn = c.getName();
|
||||||
|
boolean ignored = false;
|
||||||
|
for(String s : ignoreRegexes ){
|
||||||
|
if(cn.matches(s)){
|
||||||
|
ignored = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(ignored)
|
||||||
|
continue;
|
||||||
|
|
||||||
Constructor<?>[] constructors = c.getConstructors();
|
Constructor<?>[] constructors = c.getConstructors();
|
||||||
boolean foundINDArray = false;
|
boolean foundINDArray = false;
|
||||||
|
@ -56,12 +82,22 @@ public class OpConstructorTests extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
if(!foundINDArray){
|
if(!foundINDArray){
|
||||||
System.out.println("No INDArray constructor: " + c.getName());
|
classes.add(c);
|
||||||
count++;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(0, count);
|
if(!classes.isEmpty()){
|
||||||
|
Collections.sort(classes, new Comparator<Class<?>>() {
|
||||||
|
@Override
|
||||||
|
public int compare(Class<?> o1, Class<?> o2) {
|
||||||
|
return o1.getName().compareTo(o2.getName());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
for(Class<?> c : classes){
|
||||||
|
System.out.println("No INDArray constructor: " + c.getName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assertEquals("Found " + classes.size() + " (non-ignored) op classes with no INDArray/INDArray[] constructors", 0, classes.size());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue