ResizeBicubic added (#117)

* ResizeBicubic added
Some fixes.

* Test fixed

* Narrowed argument type changed to boolean

* Clean up
master
Alexander Stoyakin 2019-12-09 09:25:39 +02:00 committed by Alex Black
parent cea68c18f1
commit 927d591421
10 changed files with 159 additions and 45 deletions

View File

@ -52,12 +52,11 @@ namespace nd4j {
if (block.getIArguments() && block.getIArguments()->size())
numBits = INT_ARG(0);
bool narrowed = false;
//INT_ARG(1);
if (block.getIArguments()->size() == 2) {
numBits = INT_ARG(0);
narrowed = INT_ARG(1);
REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars: Number of bits for quatization should be in between 2 and 16, but %i was given.", numBits);
if (block.getBArguments() && block.getBArguments()->size()) {
narrowed = B_ARG(0);
}
REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars: Number of \
bits for quantization should be in between 2 and 16, but %i was given.", numBits);
helpers::fakeQuantWithMinMaxVars(x, min, max, numBits, narrowed, output);
return ND4J_STATUS_OK;
}

View File

@ -2612,8 +2612,9 @@ public class DifferentialFunctionFactory {
return new DrawBoundingBoxes(sameDiff, boxes, colors).outputVariable();
}
public SDVariable fakeQuantWithMinMaxVarsPerChannel(SDVariable x, SDVariable min, SDVariable max) {
return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max).outputVariable();
public SDVariable fakeQuantWithMinMaxVarsPerChannel(SDVariable x, SDVariable min, SDVariable max,
int num_bits, boolean narrow) {
return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max,num_bits,narrow).outputVariable();
}
public SDVariable betainc( SDVariable a, SDVariable b, SDVariable x) {

View File

@ -87,6 +87,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression.class,
org.nd4j.linalg.api.ops.impl.image.NonMaxSuppressionV3.class,
org.nd4j.linalg.api.ops.impl.image.ResizeBilinear.class,
org.nd4j.linalg.api.ops.impl.image.ResizeBicubic.class,
org.nd4j.linalg.api.ops.impl.image.ResizeNearestNeighbor.class,
org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex.class,
org.nd4j.linalg.api.ops.impl.indexaccum.IAMax.class,

View File

@ -21,30 +21,46 @@ import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Collections;
import java.util.List;
import java.util.Map;
public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp {
protected boolean narrowRange;
protected int numBits;
public FakeQuantWithMinMaxVarsPerChannel() {}
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max) {
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max, int num_bits, boolean narrow) {
Preconditions.checkArgument(min.isVector() && max.isVector() &&
min.length() == max.length(),
"FakeQuantWithMinMaxVarsPerChannel: min and max should be 1D tensors with the same length");
inputArguments.add(x);
inputArguments.add(min);
inputArguments.add(max);
addInputArgument(x,min,max);
addIArgument(num_bits);
addBArgument(narrow);
}
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max,
INDArray output) {
this(x,min,max);
outputArguments.add(output);
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max, int num_bits) {
this(x, min, max, num_bits, false);
}
public FakeQuantWithMinMaxVarsPerChannel(SameDiff sameDiff, SDVariable x, SDVariable min, SDVariable max) {
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max, boolean narrow) {
this(x, min, max, 8, narrow);
}
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max) {
this(x, min, max, 8, false);
}
public FakeQuantWithMinMaxVarsPerChannel(SameDiff sameDiff, SDVariable x, SDVariable min, SDVariable max,
int num_bits, boolean narrow) {
super("", sameDiff, new SDVariable[]{x, min, max});
addIArgument(num_bits);
addBArgument(narrow);
}
@Override
@ -57,6 +73,18 @@ public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp {
return "FakeQuantWithMinMaxVarsPerChannel";
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
if(attributesForNode.containsKey("narrow_range")){
this.narrowRange = attributesForNode.get("narrow_range").getB();
}
if(attributesForNode.containsKey("num_bits")) {
this.numBits = (int) attributesForNode.get("num_bits").getI();
}
addIArgument(numBits);
addBArgument(narrowRange);
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 inputs, got %s", inputDataTypes);

View File

@ -0,0 +1,82 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit, K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.image;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* ResizeBicubic op wrapper
* @author Alexander Stoyakin
*/
@NoArgsConstructor
public class ResizeBicubic extends DynamicCustomOp {
protected boolean alignCorners = false;
protected boolean alignPixelCenters = false;
public ResizeBicubic(@NonNull INDArray image, INDArray size, boolean alignCorners, boolean alignPixelCenters) {
addInputArgument(image, size);
addBArgument(alignCorners, alignPixelCenters);
}
public ResizeBicubic(@NonNull SameDiff sameDiff, @NonNull SDVariable image,
SDVariable size, boolean alignCorners, boolean alignPixelCenters) {
super(sameDiff, new SDVariable[]{image, size});
addBArgument(alignCorners, alignPixelCenters);
}
@Override
public String opName() {
return "resize_bicubic";
}
@Override
public String tensorflowName() {
return "ResizeBicubic";
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
this.alignCorners = attributesForNode.get("align_corners").getB();
this.alignPixelCenters = attributesForNode.get("half_pixel_centers").getB();
addBArgument(alignCorners, alignPixelCenters);
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2),
"Expected 1 or 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
return Collections.singletonList(Nd4j.defaultFloatingPointType());
}
}

View File

@ -4,6 +4,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@ -37,11 +38,21 @@ public class FakeQuantWithMinMaxArgs extends DynamicCustomOp {
addArgs();
}
public FakeQuantWithMinMaxArgs(INDArray x, INDArray min, INDArray max, int num_bits, boolean narrow) {
Preconditions.checkArgument(min.isVector() && max.isVector() &&
min.length() == max.length(),
"FakeQuantWithMinMaxArgs: min and max should be 1D tensors with the same length");
addInputArgument(x,min,max);
addIArgument(num_bits);
addBArgument(narrow);
}
public FakeQuantWithMinMaxArgs(){ }
protected void addArgs(){
iArguments.clear();
addIArgument(numBits, narrowRange ? 1 : 0);
addIArgument(numBits);
addBArgument(narrowRange);
addTArgument(min, max);
}

View File

@ -4,6 +4,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@ -33,11 +34,22 @@ public class FakeQuantWithMinMaxVars extends DynamicCustomOp {
addArgs();
}
public FakeQuantWithMinMaxVars(INDArray x, INDArray min, INDArray max, int num_bits, boolean narrow) {
Preconditions.checkArgument(min.isVector() && max.isVector() &&
min.length() == max.length(),
"FakeQuantWithMinMaxVars: min and max should be 1D tensors with the same length");
addInputArgument(x,min,max);
addIArgument(num_bits);
addBArgument(narrow);
}
public FakeQuantWithMinMaxVars(){ }
protected void addArgs(){
iArguments.clear();
addIArgument(numBits, narrowRange ? 1 : 0);
bArguments.clear();
addIArgument(numBits);
addBArgument(narrowRange);
}
@Override

View File

@ -57,8 +57,8 @@ public class DivOp extends BaseDynamicTransformOp {
}
@Override
public String tensorflowName() {
return "Div";
public String[] tensorflowNames() {
return new String[]{"Div","RealDiv"};
}

View File

@ -111,29 +111,17 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
// 2019/11/15 - missing dtype argument in nd4j, tests are useless https://github.com/eclipse/deeplearning4j/issues/8398
"zeros_like/rank2_float32_dtype_int.*",
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402
"fake_quant/min_max_args_per_channel.*",
// Suggesting TF 1.15 bug
"non_max_suppression_v2/float16.*",
// 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8450
"betainc.*",
// 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8452
"polygamma.*",
// 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8453
"roll/.*",
// 11.26.2019 failing https://github.com/eclipse/deeplearning4j/issues/8455
"matrix_band_part/.*",
// 11.28.2019 failing https://github.com/eclipse/deeplearning4j/issues/8458
"adjust_hue/.*",
// 11.28.2019 failing https://github.com/eclipse/deeplearning4j/issues/8459
"adjust_saturation/.*"
// 05.12.2019 failing https://github.com/eclipse/deeplearning4j/issues/8507
"resize_bicubic/int32.*"
};
/* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have

View File

@ -943,16 +943,9 @@ public class CustomOpsTests extends BaseNd4jTest {
0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5);
INDArray out = Nd4j.createUninitialized(x.shape());
val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max,out);
val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max);
Nd4j.exec(op);
assertEquals(expected, out);
/*TF: [[ 0.7801, 0.5966, 0.7260, 0.2320, 0.5084],
[ 0.1800, 0.5046, 0.8684, 0.3513, 0.5084],
[ 0.0877, 0.5966, 0.6600, 0.3513, 0.1604]]
SD: [[ 0.7770, 0.5969, 0.7232, 0.2310, 0.5098],
[ 0.1793, 0.5053, 0.8685, 0.3500, 0.5098],
[ 0.0874, 0.5969, 0.6574, 0.3500, 0.1597]]*/
}
@Test
@ -1036,13 +1029,12 @@ public class CustomOpsTests extends BaseNd4jTest {
INDArray min = Nd4j.createFromArray(new float[]{-63.65f});
INDArray max = Nd4j.createFromArray(new float[]{0.1f});
INDArray output = Nd4j.createUninitialized(DataType.FLOAT, 1,2,3,1);
INDArray expected = Nd4j.createFromArray(new float[]{-63.75f, -63.75f, -63.5f, -63.5f, 0.f, 0.f}).
reshape(1,2,3,1);
Nd4j.exec(new FakeQuantWithMinMaxVarsPerChannel(x,min,max,output));
INDArray[] output = Nd4j.exec(new FakeQuantWithMinMaxVarsPerChannel(x,min,max));
assertEquals(expected, output);
assertEquals(expected, output[0]);
}
@Test