Small DL4J/SameDiff fixes (#70)

* More mask fixes + remove debugging println

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small batch norm derivative fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-07-20 12:37:34 +10:00 committed by AlexDBlack
parent f29f19e9e9
commit 06e4f5f96e
3 changed files with 9 additions and 4 deletions

View File

@ -216,8 +216,14 @@ public abstract class AbstractSameDiffLayer extends Layer {
return Nd4j.ones(input.dataType(), input.size(0), 1);
} else if(input.rank() == 3){
return Nd4j.ones(input.dataType(), input.size(0), input.size(2)); //mask: [mb, length] vs. input [mb, nIn, length]
} else if(input.rank() == 4){
//CNN style - return [mb, 1, 1, 1] for broadcast...
return Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1);
} else if(input.rank() == 5){
//CNN3D style - return [mb, 1, 1, 1, 1] for broadcast...
return Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1, 1);
} else {
throw new IllegalStateException("When using masking with rank 4+ inputs, the onesMaskForInput method must be implemented, " +
throw new IllegalStateException("When using masking with rank 1 or 6+ inputs, the onesMaskForInput method must be implemented, " +
"in order to determine the correct mask shape for this layer");
}
}

View File

@ -166,7 +166,6 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();
System.out.println(dLdIn);
return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS
}

View File

@ -162,11 +162,11 @@ public class BatchNorm extends DynamicCustomOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
List<SDVariable> ret = new ArrayList<>();
List<SDVariable> inputs = new ArrayList<>();
inputs.addAll(Arrays.asList(args()));
List<SDVariable> inputs = new ArrayList<>(Arrays.asList(args()));
inputs.add(f1.get(0));
BatchNormDerivative batchNormDerivative = BatchNormDerivative.derivativeBuilder()
.sameDiff(sameDiff)
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
.applyGamma(applyGamma)
.applyBeta(applyBeta)
.epsilon(epsilon)