Small number of fixes + cleanup + some missing op methods + constructors (#100)

* Remove unused op class - DefaultOpConverter

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Add SDImage class; INDArray constructor additions

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Floordiv

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small polish to image methods

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small DataVec test fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-08-05 22:31:46 +10:00 committed by GitHub
parent 923ab15583
commit b8846113bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 188 additions and 79 deletions

View File

@ -34,8 +34,8 @@ import static org.junit.Assert.assertEquals;
public class RecordConverterTest {
@Test
public void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() {
INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 3}, DataType.FLOAT);
INDArray feature2 = Nd4j.create(new double[]{11, .7, -1.3, 4}, new long[]{1, 3}, DataType.FLOAT);
INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT);
INDArray feature2 = Nd4j.create(new double[]{11, .7, -1.3, 4}, new long[]{1, 4}, DataType.FLOAT);
INDArray label1 = Nd4j.create(new double[]{0, 0, 1, 0}, new long[]{1, 4}, DataType.FLOAT);
INDArray label2 = Nd4j.create(new double[]{0, 1, 0, 0}, new long[]{1, 4}, DataType.FLOAT);
DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)),

View File

@ -67,14 +67,9 @@ namespace nd4j {
auto gradX = OUTPUT_VARIABLE(0);
auto gradY = OUTPUT_VARIABLE(1);
gradY->assign(x);
std::unique_ptr<NDArray> ySq(y->dup());
ySq->applyTransform(transform::Square, nullptr);
gradY->applyPairwiseTransform(pairwise::FloorDiv, ySq.get(), gradY, nullptr);
gradY->applyPairwiseTransform(pairwise::Multiply, epsNext, gradY, nullptr);
gradY->applyTransform(transform::Neg, nullptr);
gradX->assign(epsNext);
//gradX->applyPairwiseTransform(pairwise::FloorDiv, y, gradX, nullptr);
gradY->assign(0.0f);
gradX->assign(0.0f);
return Status::OK();
}

View File

@ -167,6 +167,8 @@ public class SameDiff extends SDBaseOps {
public final SDRNN rnn = new SDRNN(this);
/** Op creator object for loss function operations */
public final SDLoss loss = new SDLoss(this);
/** Op creator object for image operations */
public final SDImage image = new SDImage(this);
/** Op creator object for math operations */
public SDMath math(){
@ -198,6 +200,10 @@ public class SameDiff extends SDBaseOps {
return loss;
}
/** Op creator object for image operations */
public SDImage image(){
return image;
}
/**

View File

@ -0,0 +1,61 @@
package org.nd4j.autodiff.samediff.ops;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches;
import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
/**
* @author Alex Black
*/
public class SDImage extends SDOps {
public SDImage(SameDiff sameDiff) {
super(sameDiff);
}
/**
* Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.
*
* @param name May be null. Name for the output variable.
* @param image Input image, with shape [batch, height, width, channels]
* @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1
* @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes]
* @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth]
* @param method Image resize method
* @param extrapolationValue Used for extrapolation, when applicable. 0.0 should be used for the default
* @return Cropped and resized images
*/
public SDVariable cropAndResize(String name, SDVariable image, SDVariable cropBoxes, SDVariable boxIndices, SDVariable cropOutSize,
CropAndResize.Method method, double extrapolationValue) {
SDVariable out = new CropAndResize(sd, image, cropBoxes, boxIndices, cropOutSize, method, extrapolationValue).outputVariable();
return updateVariableNameAndReference(out, name);
}
/**
* Given an input image, extract out image patches (of size kSizes - h x w) and place them in the depth dimension.
*
* @param name Map be null. Name for the output variable
* @param image Input image to extract image patches from - shape [batch, height, width, channels]
* @param kSizes Kernel size - size of the image patches, [height, width]
* @param strides Stride in the input dimension for extracting image patches, [stride_height, stride_width]
* @param rates Usually [1,1]. Equivalent to dilation rate in dilated convolutions - how far apart the output pixels
* in the patches should be, in the input. A dilation of [a,b] means every {@code a}th pixel is taken
* along the height/rows dimension, and every {@code b}th pixel is take along the width/columns dimension
* @param sameMode Padding algorithm. If true: use Same padding
* @return The extracted image patches
*/
public SDVariable extractImagePatches(String name, SDVariable image, @NonNull int[] kSizes,
@NonNull int[] strides, @NonNull int[] rates, boolean sameMode) {
SDVariable out = new ExtractImagePatches(sd, image, kSizes, strides, rates, sameMode).outputVariable();
return updateVariableNameAndReference(out, name);
}
public SDVariable nonMaxSuppression(String name, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize,
@NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold){
SDVariable out = new NonMaxSuppression(sd, boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable();
return updateVariableNameAndReference(out, name);
}
}

View File

@ -33,7 +33,6 @@ import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.DefaultOpConverter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces;
import org.nd4j.linalg.api.ops.custom.BarnesHutGains;
@ -838,7 +837,6 @@ public class OpValidation {
//Exclude misc
DynamicCustomOp.class,
GradientBackwardsMarker.class,
DefaultOpConverter.class,
EqualsWithEps.class,
FreeGridOp.class,
MergeSum.class, //Redundant; we use MergeAdd in samediff instead

View File

@ -41,7 +41,6 @@ public class ImportClassMapping {
private static final Map<String, DifferentialFunction> ONNX_OP_NAME_MAP = new HashMap<>();
private static final List<Class<?>> fnClasses = Arrays.<Class<?>>asList(
org.nd4j.linalg.api.ops.DefaultOpConverter.class,
org.nd4j.linalg.api.ops.DynamicCustomOp.class,
org.nd4j.linalg.api.ops.NoOp.class,
org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class,

View File

@ -1,55 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.imports.NoOpNameFoundException;
import java.util.List;
public class DefaultOpConverter extends BaseOp {
private static DefaultOpConverter INSTANCE = new DefaultOpConverter();
public static DefaultOpConverter getInstance() {
return INSTANCE;
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return null;
}
@Override
public int opNum() {
return 0;
}
@Override
public String opName() {
return "defaultop";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
}

View File

@ -135,14 +135,23 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
bArguments = new ArrayList<>();
}
/**
* Initialize this operation for execution (pre created ndarrays)
*
* @param inputs the inputs
* @param outputs the outputs of the op, may be null
*/
public DynamicCustomOp(INDArray[] inputs, INDArray[] outputs) {
this(null, inputs, outputs);
}
/**
* Initialize this operation for execution (pre created ndarrays)
*
* @param opName the operation opName to use
* for invocation
* @param opName the operation opName to use for invocation
* @param inputs the inputs
* @param outputs the outputs of the op
* @param outputs the outputs of the op, may be null
*/
public DynamicCustomOp(String opName, INDArray[] inputs, INDArray[] outputs) {
this(opName, inputs, outputs, Lists.<Double>newArrayList(), Lists.<Integer>newArrayList());
@ -600,6 +609,10 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
}
protected static INDArray[] wrapOrNull(INDArray in){
return in == null ? null : new INDArray[]{in};
}
public static class DynamicCustomOpsBuilder {
protected String opName;
protected int numInputs;

View File

@ -17,11 +17,13 @@
package org.nd4j.linalg.api.ops.impl.broadcast;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
@ -42,6 +44,10 @@ public class BiasAdd extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[] {input, bias}, false);
}
public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output){
super(new INDArray[]{input, bias}, wrapOrNull(output));
}
@Override
public String opName() {
return "biasadd";

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.broadcast;
import lombok.NonNull;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -35,6 +36,10 @@ public class BiasAddGrad extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[]{input, bias, gradient});
}
public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, bias, gradient}, wrapOrNull(output));
}
public BiasAddGrad() {}
@Override

View File

@ -16,11 +16,14 @@
package org.nd4j.linalg.api.ops.impl.image;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
@ -33,11 +36,32 @@ import java.util.*;
* CropAndResize Op
* @author Alex Black
*/
@NoArgsConstructor
public class CropAndResize extends DynamicCustomOp {
public enum Method {BILINEAR, NEAREST};
protected Method method = Method.BILINEAR;
protected double extrapolationValue = 0.0;
public CropAndResize(@NonNull SameDiff sameDiff, @NonNull SDVariable image, @NonNull SDVariable cropBoxes, @NonNull SDVariable boxIndices,
@NonNull SDVariable cropOutSize, @NonNull Method method, double extrapolationValue){
super(sameDiff, new SDVariable[]{image, cropBoxes, boxIndices, cropOutSize});
this.method = method;
this.extrapolationValue = extrapolationValue;
addArgs();
}
public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices,
@NonNull INDArray cropOutSize, @NonNull Method method, double extrapolationValue){
super(new INDArray[]{image, cropBoxes, boxIndices, cropOutSize}, null);
Preconditions.checkArgument(image.rank() == 4, "Input image must be rank 4 with shape [batch, height, width, channels], got %ndShape", image);
Preconditions.checkArgument(cropBoxes.rank() == 2 && cropBoxes.size(1) == 4, "Crop boxes must be rank 4 with shape [num_boxes, 5], got %ndShape", cropBoxes);
Preconditions.checkArgument(boxIndices.rank() == 1 && cropBoxes.size(0) == boxIndices.size(0),
"Box indices must be rank 1 array with shape [num_boxes] (same as cropBoxes.size(0), got array with shape %ndShape", boxIndices);
this.method = method;
this.extrapolationValue = extrapolationValue;
addArgs();
}
@Override
public String opName() {
return "crop_and_resize";

View File

@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
@ -54,9 +55,27 @@ public class ExtractImagePatches extends DynamicCustomOp {
Preconditions.checkState(kSizes.length == 2, "Expected exactly 2 kernel sizes, got %s", kSizes);
Preconditions.checkState(strides.length == 2, "Expected exactly 2 strides, got %s", strides);
Preconditions.checkState(rates.length == 2, "Expected exactly 2 rate values, got %s", rates);
this.kSizes = kSizes;
this.strides = strides;
this.rates = rates;
this.isSameMode = sameMode;
addArgs();
}
public ExtractImagePatches(@NonNull INDArray input, @NonNull int[] kSizes,
@NonNull int[] strides, @NonNull int[] rates, boolean sameMode){
super(new INDArray[]{input}, null);
Preconditions.checkState(kSizes.length == 2, "Expected exactly 2 kernel sizes, got %s", kSizes);
Preconditions.checkState(strides.length == 2, "Expected exactly 2 strides, got %s", strides);
Preconditions.checkState(rates.length == 2, "Expected exactly 2 rate values, got %s", rates);
this.kSizes = kSizes;
this.strides = strides;
this.rates = rates;
this.isSameMode = sameMode;
addArgs();
}
@Override
public String opName() {
return "extract_image_patches";

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.image;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
@ -27,7 +28,7 @@ import java.util.Collections;
import java.util.List;
/**
* IdentityN op wrapper
* Non max suppression
*
* @author raver119@gmail.com
*/
@ -35,8 +36,9 @@ public class NonMaxSuppression extends DynamicCustomOp {
public NonMaxSuppression() {}
public NonMaxSuppression(SameDiff sameDiff, SDVariable[] input) {
super(null, sameDiff, input, false);
public NonMaxSuppression(SameDiff sameDiff, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize,
@NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold) {
super(null, sameDiff, new SDVariable[]{boxes, scores, maxOutSize, iouThreshold, scoreThreshold}, false);
}
@Override

View File

@ -4,6 +4,7 @@ import org.junit.Test;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ops.NoOp;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.reflections.Reflections;
@ -14,7 +15,7 @@ import org.reflections.util.FilterBuilder;
import java.lang.reflect.Constructor;
import java.lang.reflect.Modifier;
import java.util.Set;
import java.util.*;
import static org.junit.Assert.assertEquals;
@ -24,6 +25,18 @@ public class OpConstructorTests extends BaseNd4jTest {
super(backend);
}
//Ignore individual classes
protected Set<Class<?>> exclude = new HashSet<>(
Arrays.asList(
NoOp.class
)
);
//Ignore whole sets of classes based on regex
protected String[] ignoreRegexes = new String[]{
"org\\.nd4j\\.linalg\\.api\\.ops\\.impl\\.controlflow\\..*"
};
@Test
public void checkForINDArrayConstructors() throws Exception {
/*
@ -38,11 +51,24 @@ public class OpConstructorTests extends BaseNd4jTest {
Set<Class<? extends DifferentialFunction>> classSet = f.getSubTypesOf(DifferentialFunction.class);
int count = 0;
List<Class<?>> classes = new ArrayList<>();
for(Class<?> c : classSet){
if(Modifier.isAbstract(c.getModifiers()) || Modifier.isInterface(c.getModifiers()) || c == SDVariable.class || ILossFunction.class.isAssignableFrom(c))
continue;
// System.out.println(c.getName());
if(exclude.contains(c))
continue;
String cn = c.getName();
boolean ignored = false;
for(String s : ignoreRegexes ){
if(cn.matches(s)){
ignored = true;
break;
}
}
if(ignored)
continue;
Constructor<?>[] constructors = c.getConstructors();
boolean foundINDArray = false;
@ -56,12 +82,22 @@ public class OpConstructorTests extends BaseNd4jTest {
}
if(!foundINDArray){
System.out.println("No INDArray constructor: " + c.getName());
count++;
classes.add(c);
}
}
assertEquals(0, count);
if(!classes.isEmpty()){
Collections.sort(classes, new Comparator<Class<?>>() {
@Override
public int compare(Class<?> o1, Class<?> o2) {
return o1.getName().compareTo(o2.getName());
}
});
for(Class<?> c : classes){
System.out.println("No INDArray constructor: " + c.getName());
}
}
assertEquals("Found " + classes.size() + " (non-ignored) op classes with no INDArray/INDArray[] constructors", 0, classes.size());
}