Various fixes (DL4J, ND4J) (#147)

* Import fixes, IsMax dtype calc, small test fix

  Signed-off-by: AlexDBlack <blacka101@gmail.com>

* SubsamplingLayer fix

  Signed-off-by: AlexDBlack <blacka101@gmail.com>

* DL4J - SpaceToBatch layer updates

  Signed-off-by: AlexDBlack <blacka101@gmail.com>

master
parent ca7e5593ec
commit 9c2bfc9863
@@ -298,7 +298,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
         int inputDepth = 1;

         int[] kernel = {2, 2};
-        int[] blocks = {1, 1};
+        int[] blocks = {2, 2};

         String[] activations = {"sigmoid", "tanh"};
         SubsamplingLayer.PoolingType[] poolingTypes =
@@ -309,8 +309,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
         for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
             for (int minibatchSize : minibatchSizes) {
                 INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
-                INDArray labels = Nd4j.zeros(minibatchSize, nOut);
-                for (int i = 0; i < minibatchSize; i++) {
+                INDArray labels = Nd4j.zeros(4 * minibatchSize, nOut);
+                for (int i = 0; i < 4 * minibatchSize; i++) {
                     labels.putScalar(new int[]{i, i % nOut}, 1.0);
                 }

@@ -318,11 +318,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                 new NeuralNetConfiguration.Builder()
                         .dataType(DataType.DOUBLE)
                         .updater(new NoOp()).weightInit(new NormalDistribution(0, 1))
-                        .list().layer(new ConvolutionLayer.Builder(kernel).nIn(inputDepth)
-                                .nOut(3).build())//output: (5-2+0)/1+1 = 4
+                        .list()
+                        .layer(new ConvolutionLayer.Builder(kernel).nIn(inputDepth).nOut(3).build())
                         .layer(new SpaceToBatchLayer.Builder(blocks).build()) //trivial space to batch
                         .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
-                                .activation(Activation.SOFTMAX).nIn(4 * 4 * 3)
+                                .activation(Activation.SOFTMAX)
                                 .nOut(nOut).build())
                         .setInputType(InputType.convolutionalFlat(height, width, inputDepth))
                         .build();

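Note on the test change above: space-to-batch with blocks = {2, 2} moves blockH * blockW = 4 spatial tiles into the batch dimension, so every layer after it sees a minibatch four times larger; that is why the labels array grows to 4 * minibatchSize rows. A minimal arithmetic sketch (plain Java; variable names are illustrative):

    public class SpaceToBatchBatchSize {
        public static void main(String[] args) {
            int minibatchSize = 2;                 // one of the test's minibatch sizes
            int[] blocks = {2, 2};                 // as in the updated test
            // space_to_batch multiplies the batch dimension by blocks[0] * blocks[1]:
            int outMiniBatch = minibatchSize * blocks[0] * blocks[1];
            System.out.println(outMiniBatch);      // 8, i.e. 4 * minibatchSize
        }
    }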
@@ -100,12 +100,15 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer

         Gradient gradient = new DefaultGradient();

-        CustomOp op = DynamicCustomOp.builder("batch_to_space")
-                .addInputs(epsilon, getBlocksArray(), getPaddingArray())
-                .addOutputs(outEpsilon)
+        INDArray epsilonNHWC = epsilon.permute(0, 2, 3, 1);
+        INDArray outEpsilonNHWC = outEpsilon.permute(0, 2, 3, 1);
+
+        CustomOp op = DynamicCustomOp.builder("batch_to_space_nd")
+                .addInputs(epsilonNHWC, getBlocksArray(), getPaddingArray())
+                .addOutputs(outEpsilonNHWC)
                 .callInplace(false)
                 .build();
-        Nd4j.getExecutioner().exec(op);
+        Nd4j.exec(op);

         outEpsilon = backpropDropOutIfPresent(outEpsilon);
         return new Pair<>(gradient, outEpsilon);
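Note on the permute calls above: the libnd4j batch_to_space_nd op consumes NHWC data, while DL4J stores activations in NCHW. INDArray.permute returns a view over the same buffer, so executing the op on the permuted views fills the original NCHW outEpsilon without an extra copy. A small self-contained sketch of that view behavior (shapes and values are illustrative):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class PermuteViewDemo {
        public static void main(String[] args) {
            INDArray nchw = Nd4j.zeros(2, 3, 4, 4);         // [mb, channels, h, w]
            INDArray nhwc = nchw.permute(0, 2, 3, 1);       // view: [mb, h, w, channels]
            nhwc.putScalar(new int[]{0, 1, 2, 0}, 7.0);     // write through the view
            System.out.println(nchw.getDouble(0, 0, 1, 2)); // 7.0 - same underlying buffer
        }
    }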
@@ -143,11 +146,14 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer

         INDArray out = workspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), new long[]{outMiniBatch, depth, outH, outW}, 'c');

-        CustomOp op = DynamicCustomOp.builder("space_to_batch")
-                .addInputs(input, getBlocksArray(), getPaddingArray())
-                .addOutputs(out)
+        INDArray inNHWC = input.permute(0, 2, 3, 1);
+        INDArray outNHWC = out.permute(0, 2, 3, 1);
+
+        CustomOp op = DynamicCustomOp.builder("space_to_batch_nd")
+                .addInputs(inNHWC, getBlocksArray(), getPaddingArray())
+                .addOutputs(outNHWC)
                 .build();
-        Nd4j.getExecutioner().exec(op);
+        Nd4j.exec(op);

         return out;
     }

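The forward pass applies the same NHWC-view trick; the preallocated out array's shape follows the standard space_to_batch shape rule. A hedged sketch of that arithmetic (the padding layout is an assumption for illustration):

    public class SpaceToBatchShape {
        public static void main(String[] args) {
            int mb = 2, depth = 3, h = 4, w = 4;
            int[] blocks = {2, 2};
            int[][] padding = {{0, 0}, {0, 0}}; // [[top, bottom], [left, right]] - illustrative
            // Padded spatial dims are divided by the block sizes; the batch
            // dimension absorbs the factor blocks[0] * blocks[1]:
            int outH = (h + padding[0][0] + padding[0][1]) / blocks[0];   // 2
            int outW = (w + padding[1][0] + padding[1][1]) / blocks[1];   // 2
            int outMiniBatch = mb * blocks[0] * blocks[1];                // 8
            System.out.printf("out: [%d, %d, %d, %d]%n", outMiniBatch, depth, outH, outW);
        }
    }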
@@ -172,7 +172,7 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
                 b = DynamicCustomOp.builder("maxpool2d_bp");
                 break;
             case AVG:
-                b = DynamicCustomOp.builder("maxpool2d_bp");
+                b = DynamicCustomOp.builder("avgpool2d_bp");
                 if(layerConf().isAvgPoolIncludePadInDivisor()){
                     //Mostly this is a legacy case - beta4 and earlier models.
                     extra = 1; //Divide by "number present" excluding padding
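The SubsamplingLayer change fixes a copy-paste bug: the AVG branch previously built maxpool2d_bp, so average pooling silently backpropagated max-pool gradients. A hedged sketch of the intended selection (this mirrors the corrected switch; it is illustrative, not the full DL4J code, and the PNORM op name is an assumption):

    // Illustrative only: which libnd4j backprop op each pooling type should use.
    static String poolingBpOpName(SubsamplingLayer.PoolingType type) {
        switch (type) {
            case MAX:   return "maxpool2d_bp";
            case AVG:   return "avgpool2d_bp";    // was "maxpool2d_bp" before this fix
            case PNORM: return "pnormpool2d_bp";  // assumed name, for illustration
            default:    throw new IllegalStateException("Unsupported pooling type: " + type);
        }
    }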
@@ -350,9 +350,12 @@ public class ImportClassMapping {
                 org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2.class,
                 org.nd4j.linalg.api.ops.impl.transforms.custom.Assign.class,
                 org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace.class,
+                org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND.class,
                 org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class,
                 org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class,
                 org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class,
+                org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits.class,
+                org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits.class,
                 org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D.class,
                 org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention.class,
                 org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionBp.class,

@@ -396,8 +399,11 @@ public class ImportClassMapping {
                 org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse.class,
                 org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence.class,
                 org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseV2.class,
+                org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits.class,
+                org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits.class,
                 org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax.class,
                 org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatch.class,
+                org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatchND.class,
                 org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize.class,
                 org.nd4j.linalg.api.ops.impl.transforms.custom.StandardizeBp.class,
                 org.nd4j.linalg.api.ops.impl.transforms.custom.Svd.class,

@@ -457,6 +463,7 @@ public class ImportClassMapping {
                 org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp.class,
                 org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp.class,
                 org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp.class,
+                org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.ModBpOp.class,
                 org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp.class,
                 org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp.class,
                 org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp.class,

@@ -559,11 +566,7 @@ public class ImportClassMapping {
                 org.nd4j.linalg.api.ops.random.impl.ProbablisticMerge.class,
                 org.nd4j.linalg.api.ops.random.impl.Range.class,
                 org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution.class,
-                org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class,
-                org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits.class,
-                org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits.class,
-                org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits.class,
-                org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits.class
+                org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class
         );

@@ -79,4 +79,10 @@ public class IsMax extends DynamicCustomOp {
     public List<SDVariable> doDiff(List<SDVariable> f1) {
         return Collections.singletonList(f().zerosLike(arg()));
     }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        //Also supports other types if say float array is provided as output array
+        return Collections.singletonList(DataType.BOOL);
+    }
 }

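With calculateOutputDataTypes in place, IsMax declares a BOOL output by default (per the comment, a float output array is still accepted if one is provided), which is what the test change below relies on. A short usage sketch mirroring the test's SameDiff call:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class IsMaxDtypeDemo {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", Nd4j.createFromArray(-3.0, 5.0, 0.0, 2.0));
            SDVariable out = sd.math().isMax(in);   // declared output dtype: BOOL
            INDArray result = out.eval();           // [false, true, false, false]
            System.out.println(result.dataType() + " " + result);
        }
    }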
@@ -1249,7 +1249,7 @@ public class TransformOpValidation extends BaseOpValidation {
             case 2:
                 //TODO: IsMax supports both bool and float out: https://github.com/deeplearning4j/deeplearning4j/issues/6872
                 inArr = Nd4j.create(new double[]{-3,5,0,2});
-                exp = Nd4j.create(new boolean[]{false,true,false,false}).castTo(DataType.DOUBLE);
+                exp = Nd4j.create(new boolean[]{false,true,false,false});
                 out = sd.math().isMax(in);
                 break;
             case 3: