Various pre-release fixes (#111)

* Various fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix default dtypes for MaxPoolWithArgmax

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-12-05 14:20:03 +11:00 committed by GitHub
parent 63ed202057
commit 2052ce7026
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 12 additions and 17 deletions

View File

@ -46,7 +46,7 @@ namespace nd4j {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(0, DataType::INHERIT)
->setAllowedOutputTypes(1, DataType::INT64);
->setAllowedOutputTypes(1, {ALL_INTS});
}

View File

@ -103,6 +103,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1DDerivative.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2DDerivative.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D.class,

View File

@ -60,7 +60,7 @@ public class NonMaxSuppression extends DynamicCustomOp {
@Override
public String[] tensorflowNames() {
return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2","NonMaxSuppressionV3","NonMaxSuppressionV4"};
return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2"};
}
@Override

View File

@ -204,7 +204,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp {
if(attributesForNode.containsKey("argmax")) {
outputType = TFGraphMapper.convertType(attributesForNode.get("argmax").getType());
} else {
outputType = DataType.UINT32;
outputType = DataType.LONG;
}
}
@ -278,7 +278,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp {
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes);
List<DataType> result = new ArrayList<>();
result.add(inputDataTypes.get(0));
result.add(outputType == null ? DataType.UINT32 : outputType);
result.add(outputType == null ? DataType.INT : outputType);
return result;
}
}

View File

@ -760,7 +760,7 @@ public class LayerOpValidation extends BaseOpValidation {
.isSameMode(true)
.build();
SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"",""}, in, pooling2DConfig);
SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"out","idx"}, in, pooling2DConfig);
assertArrayEquals(inArr.shape(), results[0].eval().shape());
assertArrayEquals(inArr.shape(), results[1].eval().shape());
}
@ -1050,7 +1050,7 @@ public class LayerOpValidation extends BaseOpValidation {
SDVariable in = sd.var("in", inArr);
SDVariable w = sd.var("w", wArr);
SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).build());
SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build());
INDArray expected = Nd4j.createFromArray(
new double[][][]{

View File

@ -23,13 +23,7 @@ import static org.junit.Assert.fail;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
public class ConvConfigTests {
@ -489,24 +483,24 @@ public class ConvConfigTests {
@Test
public void testConv1D(){
Conv1DConfig.builder().k(2).build();
Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build();
try{
Conv1DConfig.builder().k(0).build();
Conv1DConfig.builder().k(0).paddingMode(PaddingMode.SAME).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel"));
}
try{
Conv1DConfig.builder().k(4).s(-2).build();
Conv1DConfig.builder().k(4).s(-2).paddingMode(PaddingMode.SAME).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride"));
}
try{
Conv1DConfig.builder().k(3).p(-2).build();
Conv1DConfig.builder().k(3).p(-2).paddingMode(PaddingMode.SAME).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding"));