Various pre-release fixes (#111)

* Various fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix default dtypes for MaxPoolWithArgmax

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-12-05 14:20:03 +11:00 committed by GitHub
parent 63ed202057
commit 2052ce7026
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 12 additions and 17 deletions

View File

@ -46,7 +46,7 @@ namespace nd4j {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(0, DataType::INHERIT) ->setAllowedOutputTypes(0, DataType::INHERIT)
->setAllowedOutputTypes(1, DataType::INT64); ->setAllowedOutputTypes(1, {ALL_INTS});
} }

View File

@ -103,6 +103,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1DDerivative.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2DDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2DDerivative.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D.class,

View File

@ -60,7 +60,7 @@ public class NonMaxSuppression extends DynamicCustomOp {
@Override @Override
public String[] tensorflowNames() { public String[] tensorflowNames() {
return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2","NonMaxSuppressionV3","NonMaxSuppressionV4"}; return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2"};
} }
@Override @Override

View File

@ -204,7 +204,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp {
if(attributesForNode.containsKey("argmax")) { if(attributesForNode.containsKey("argmax")) {
outputType = TFGraphMapper.convertType(attributesForNode.get("argmax").getType()); outputType = TFGraphMapper.convertType(attributesForNode.get("argmax").getType());
} else { } else {
outputType = DataType.UINT32; outputType = DataType.LONG;
} }
} }
@ -278,7 +278,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp {
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes); Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes);
List<DataType> result = new ArrayList<>(); List<DataType> result = new ArrayList<>();
result.add(inputDataTypes.get(0)); result.add(inputDataTypes.get(0));
result.add(outputType == null ? DataType.UINT32 : outputType); result.add(outputType == null ? DataType.INT : outputType);
return result; return result;
} }
} }

View File

@ -760,7 +760,7 @@ public class LayerOpValidation extends BaseOpValidation {
.isSameMode(true) .isSameMode(true)
.build(); .build();
SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"",""}, in, pooling2DConfig); SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"out","idx"}, in, pooling2DConfig);
assertArrayEquals(inArr.shape(), results[0].eval().shape()); assertArrayEquals(inArr.shape(), results[0].eval().shape());
assertArrayEquals(inArr.shape(), results[1].eval().shape()); assertArrayEquals(inArr.shape(), results[1].eval().shape());
} }
@ -1050,7 +1050,7 @@ public class LayerOpValidation extends BaseOpValidation {
SDVariable in = sd.var("in", inArr); SDVariable in = sd.var("in", inArr);
SDVariable w = sd.var("w", wArr); SDVariable w = sd.var("w", wArr);
SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).build()); SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build());
INDArray expected = Nd4j.createFromArray( INDArray expected = Nd4j.createFromArray(
new double[][][]{ new double[][][]{

View File

@ -23,13 +23,7 @@ import static org.junit.Assert.fail;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
public class ConvConfigTests { public class ConvConfigTests {
@ -489,24 +483,24 @@ public class ConvConfigTests {
@Test @Test
public void testConv1D(){ public void testConv1D(){
Conv1DConfig.builder().k(2).build(); Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build();
try{ try{
Conv1DConfig.builder().k(0).build(); Conv1DConfig.builder().k(0).paddingMode(PaddingMode.SAME).build();
fail(); fail();
} catch (IllegalArgumentException e){ } catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel")); assertTrue(e.getMessage().contains("Kernel"));
} }
try{ try{
Conv1DConfig.builder().k(4).s(-2).build(); Conv1DConfig.builder().k(4).s(-2).paddingMode(PaddingMode.SAME).build();
fail(); fail();
} catch (IllegalArgumentException e){ } catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride")); assertTrue(e.getMessage().contains("Stride"));
} }
try{ try{
Conv1DConfig.builder().k(3).p(-2).build(); Conv1DConfig.builder().k(3).p(-2).paddingMode(PaddingMode.SAME).build();
fail(); fail();
} catch (IllegalArgumentException e){ } catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding")); assertTrue(e.getMessage().contains("Padding"));