Various fixes (DL4J, ND4J) (#147)

* Import fixes, IsMax dtype calc, small test fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* SubsamplingLayer fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* DL4J - SpaceToBatch layer updates

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-08-22 16:16:03 +10:00 committed by GitHub
parent ca7e5593ec
commit 9c2bfc9863
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 36 additions and 21 deletions

View File

@ -298,7 +298,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int inputDepth = 1; int inputDepth = 1;
int[] kernel = {2, 2}; int[] kernel = {2, 2};
int[] blocks = {1, 1}; int[] blocks = {2, 2};
String[] activations = {"sigmoid", "tanh"}; String[] activations = {"sigmoid", "tanh"};
SubsamplingLayer.PoolingType[] poolingTypes = SubsamplingLayer.PoolingType[] poolingTypes =
@ -309,8 +309,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) { for (int minibatchSize : minibatchSizes) {
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
INDArray labels = Nd4j.zeros(minibatchSize, nOut); INDArray labels = Nd4j.zeros(4 * minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) { for (int i = 0; i < 4 * minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0); labels.putScalar(new int[]{i, i % nOut}, 1.0);
} }
@ -318,11 +318,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
new NeuralNetConfiguration.Builder() new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE) .dataType(DataType.DOUBLE)
.updater(new NoOp()).weightInit(new NormalDistribution(0, 1)) .updater(new NoOp()).weightInit(new NormalDistribution(0, 1))
.list().layer(new ConvolutionLayer.Builder(kernel).nIn(inputDepth) .list()
.nOut(3).build())//output: (5-2+0)/1+1 = 4 .layer(new ConvolutionLayer.Builder(kernel).nIn(inputDepth).nOut(3).build())
.layer(new SpaceToBatchLayer.Builder(blocks).build()) //trivial space to batch .layer(new SpaceToBatchLayer.Builder(blocks).build()) //trivial space to batch
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(4 * 4 * 3) .activation(Activation.SOFTMAX)
.nOut(nOut).build()) .nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)) .setInputType(InputType.convolutionalFlat(height, width, inputDepth))
.build(); .build();

View File

@ -100,12 +100,15 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
Gradient gradient = new DefaultGradient(); Gradient gradient = new DefaultGradient();
CustomOp op = DynamicCustomOp.builder("batch_to_space") INDArray epsilonNHWC = epsilon.permute(0, 2, 3, 1);
.addInputs(epsilon, getBlocksArray(), getPaddingArray()) INDArray outEpsilonNHWC = outEpsilon.permute(0, 2, 3, 1);
.addOutputs(outEpsilon)
CustomOp op = DynamicCustomOp.builder("batch_to_space_nd")
.addInputs(epsilonNHWC, getBlocksArray(), getPaddingArray())
.addOutputs(outEpsilonNHWC)
.callInplace(false) .callInplace(false)
.build(); .build();
Nd4j.getExecutioner().exec(op); Nd4j.exec(op);
outEpsilon = backpropDropOutIfPresent(outEpsilon); outEpsilon = backpropDropOutIfPresent(outEpsilon);
return new Pair<>(gradient, outEpsilon); return new Pair<>(gradient, outEpsilon);
@ -143,11 +146,14 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
INDArray out = workspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), new long[]{outMiniBatch, depth, outH, outW}, 'c'); INDArray out = workspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), new long[]{outMiniBatch, depth, outH, outW}, 'c');
CustomOp op = DynamicCustomOp.builder("space_to_batch") INDArray inNHWC = input.permute(0, 2, 3, 1);
.addInputs(input, getBlocksArray(), getPaddingArray()) INDArray outNHWC = out.permute(0, 2, 3, 1);
.addOutputs(out)
CustomOp op = DynamicCustomOp.builder("space_to_batch_nd")
.addInputs(inNHWC, getBlocksArray(), getPaddingArray())
.addOutputs(outNHWC)
.build(); .build();
Nd4j.getExecutioner().exec(op); Nd4j.exec(op);
return out; return out;
} }

View File

@ -172,7 +172,7 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
b = DynamicCustomOp.builder("maxpool2d_bp"); b = DynamicCustomOp.builder("maxpool2d_bp");
break; break;
case AVG: case AVG:
b = DynamicCustomOp.builder("maxpool2d_bp"); b = DynamicCustomOp.builder("avgpool2d_bp");
if(layerConf().isAvgPoolIncludePadInDivisor()){ if(layerConf().isAvgPoolIncludePadInDivisor()){
//Mostly this is a legacy case - beta4 and earlier models. //Mostly this is a legacy case - beta4 and earlier models.
extra = 1; //Divide by "number present" excluding padding extra = 1; //Divide by "number present" excluding padding

View File

@ -350,9 +350,12 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2.class, org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.Assign.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Assign.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace.class, org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention.class, org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionBp.class, org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionBp.class,
@ -396,8 +399,11 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence.class, org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseV2.class, org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseV2.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax.class, org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatch.class, org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatch.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatchND.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.StandardizeBp.class, org.nd4j.linalg.api.ops.impl.transforms.custom.StandardizeBp.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.Svd.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Svd.class,
@ -457,6 +463,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.ModBpOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp.class,
@ -559,11 +566,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.random.impl.ProbablisticMerge.class, org.nd4j.linalg.api.ops.random.impl.ProbablisticMerge.class,
org.nd4j.linalg.api.ops.random.impl.Range.class, org.nd4j.linalg.api.ops.random.impl.Range.class,
org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution.class, org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution.class,
org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class, org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class
org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits.class
); );

View File

@ -79,4 +79,10 @@ public class IsMax extends DynamicCustomOp {
public List<SDVariable> doDiff(List<SDVariable> f1) { public List<SDVariable> doDiff(List<SDVariable> f1) {
return Collections.singletonList(f().zerosLike(arg())); return Collections.singletonList(f().zerosLike(arg()));
} }
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
//Also supports other types if say float array is provided as output array
return Collections.singletonList(DataType.BOOL);
}
} }

View File

@ -1249,7 +1249,7 @@ public class TransformOpValidation extends BaseOpValidation {
case 2: case 2:
//TODO: IsMax supports both bool and float out: https://github.com/deeplearning4j/deeplearning4j/issues/6872 //TODO: IsMax supports both bool and float out: https://github.com/deeplearning4j/deeplearning4j/issues/6872
inArr = Nd4j.create(new double[]{-3,5,0,2}); inArr = Nd4j.create(new double[]{-3,5,0,2});
exp = Nd4j.create(new boolean[]{false,true,false,false}).castTo(DataType.DOUBLE); exp = Nd4j.create(new boolean[]{false,true,false,false});
out = sd.math().isMax(in); out = sd.math().isMax(in);
break; break;
case 3: case 3: