Various pre-release fixes (#111)
* Various fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix default dtypes for MaxPoolWithArgmax Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
parent
63ed202057
commit
2052ce7026
|
@ -46,7 +46,7 @@ namespace nd4j {
|
|||
getOpDescriptor()
|
||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||
->setAllowedOutputTypes(0, DataType::INHERIT)
|
||||
->setAllowedOutputTypes(1, DataType::INT64);
|
||||
->setAllowedOutputTypes(1, {ALL_INTS});
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -103,6 +103,7 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1DDerivative.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2DDerivative.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D.class,
|
||||
|
|
|
@ -60,7 +60,7 @@ public class NonMaxSuppression extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2","NonMaxSuppressionV3","NonMaxSuppressionV4"};
|
||||
return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2"};
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -204,7 +204,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp {
|
|||
if(attributesForNode.containsKey("argmax")) {
|
||||
outputType = TFGraphMapper.convertType(attributesForNode.get("argmax").getType());
|
||||
} else {
|
||||
outputType = DataType.UINT32;
|
||||
outputType = DataType.LONG;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -278,7 +278,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp {
|
|||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes);
|
||||
List<DataType> result = new ArrayList<>();
|
||||
result.add(inputDataTypes.get(0));
|
||||
result.add(outputType == null ? DataType.UINT32 : outputType);
|
||||
result.add(outputType == null ? DataType.INT : outputType);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -760,7 +760,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
.isSameMode(true)
|
||||
.build();
|
||||
|
||||
SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"",""}, in, pooling2DConfig);
|
||||
SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"out","idx"}, in, pooling2DConfig);
|
||||
assertArrayEquals(inArr.shape(), results[0].eval().shape());
|
||||
assertArrayEquals(inArr.shape(), results[1].eval().shape());
|
||||
}
|
||||
|
@ -1050,7 +1050,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
SDVariable in = sd.var("in", inArr);
|
||||
SDVariable w = sd.var("w", wArr);
|
||||
|
||||
SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).build());
|
||||
SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build());
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(
|
||||
new double[][][]{
|
||||
|
|
|
@ -23,13 +23,7 @@ import static org.junit.Assert.fail;
|
|||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
|
||||
|
||||
public class ConvConfigTests {
|
||||
|
||||
|
@ -489,24 +483,24 @@ public class ConvConfigTests {
|
|||
|
||||
@Test
|
||||
public void testConv1D(){
|
||||
Conv1DConfig.builder().k(2).build();
|
||||
Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build();
|
||||
|
||||
try{
|
||||
Conv1DConfig.builder().k(0).build();
|
||||
Conv1DConfig.builder().k(0).paddingMode(PaddingMode.SAME).build();
|
||||
fail();
|
||||
} catch (IllegalArgumentException e){
|
||||
assertTrue(e.getMessage().contains("Kernel"));
|
||||
}
|
||||
|
||||
try{
|
||||
Conv1DConfig.builder().k(4).s(-2).build();
|
||||
Conv1DConfig.builder().k(4).s(-2).paddingMode(PaddingMode.SAME).build();
|
||||
fail();
|
||||
} catch (IllegalArgumentException e){
|
||||
assertTrue(e.getMessage().contains("Stride"));
|
||||
}
|
||||
|
||||
try{
|
||||
Conv1DConfig.builder().k(3).p(-2).build();
|
||||
Conv1DConfig.builder().k(3).p(-2).paddingMode(PaddingMode.SAME).build();
|
||||
fail();
|
||||
} catch (IllegalArgumentException e){
|
||||
assertTrue(e.getMessage().contains("Padding"));
|
||||
|
|
Loading…
Reference in New Issue