Small DL4J/SameDiff fixes (#70)
* More mask fixes + remove debugging println

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small batch norm derivative fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>
parent f29f19e9e9
commit 06e4f5f96e
@@ -216,8 +216,14 @@ public abstract class AbstractSameDiffLayer extends Layer {
             return Nd4j.ones(input.dataType(), input.size(0), 1);
         } else if(input.rank() == 3){
             return Nd4j.ones(input.dataType(), input.size(0), input.size(2)); //mask: [mb, length] vs. input [mb, nIn, length]
+        } else if(input.rank() == 4){
+            //CNN style - return [mb, 1, 1, 1] for broadcast...
+            return Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1);
+        } else if(input.rank() == 5){
+            //CNN3D style - return [mb, 1, 1, 1, 1] for broadcast...
+            return Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1, 1);
         } else {
-            throw new IllegalStateException("When using masking with rank 4+ inputs, the onesMaskForInput method must be implemented, " +
+            throw new IllegalStateException("When using masking with rank 1 or 6+ inputs, the onesMaskForInput method must be implemented, " +
                     "in order to determine the correct mask shape for this layer");
         }
     }
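Not part of the commit itself, but as a sanity check of the convention the new rank-4 branch encodes, here is a minimal ND4J sketch (assuming the DataType-aware Nd4j.rand/Nd4j.ones factory overloads of recent releases) showing that the default [mb, 1, 1, 1] all-ones mask broadcasts over [mb, channels, height, width] activations without changing them:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class OnesMaskShapeSketch {
        public static void main(String[] args) {
            // Rank-4 CNN activations: [minibatch, channels, height, width] (shapes are illustrative)
            INDArray input = Nd4j.rand(DataType.FLOAT, 2, 3, 4, 4);

            // Default "no masking" mask in the [mb, 1, 1, 1] layout from the hunk above
            INDArray mask = Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1);

            // Broadcasting the all-ones mask over the activations is a no-op
            INDArray masked = input.mul(mask.broadcast(input.shape()));
            System.out.println(masked.equalsWithEps(input, 1e-6)); // expected: true
        }
    }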
@@ -166,7 +166,6 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
         sameDiff.clearPlaceholders(true);
         sameDiff.clearOpInputs();

-        System.out.println(dLdIn);
         return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS
     }
@@ -162,11 +162,11 @@ public class BatchNorm extends DynamicCustomOp {
     @Override
     public List<SDVariable> doDiff(List<SDVariable> f1) {
         List<SDVariable> ret = new ArrayList<>();
-        List<SDVariable> inputs = new ArrayList<>();
-        inputs.addAll(Arrays.asList(args()));
+        List<SDVariable> inputs = new ArrayList<>(Arrays.asList(args()));
         inputs.add(f1.get(0));
         BatchNormDerivative batchNormDerivative = BatchNormDerivative.derivativeBuilder()
                 .sameDiff(sameDiff)
+                .inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
                 .applyGamma(applyGamma)
                 .applyBeta(applyBeta)
                 .epsilon(epsilon)
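For context on what the op being differentiated computes (and on the roles of the gamma, beta and epsilon values the derivative builder carries), here is a rough ND4J sketch of the textbook batch-norm forward pass over a [minibatch, nOut] activation. This is illustrative only, assumes the DataType-aware Nd4j factory overloads, and is not the actual BatchNorm kernel:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.ops.transforms.Transforms;

    public class BatchNormForwardSketch {
        public static void main(String[] args) {
            double eps = 1e-5;                              // same role as the builder's epsilon(...)
            INDArray x = Nd4j.rand(DataType.FLOAT, 8, 4);   // [minibatch, nOut] activations
            INDArray gamma = Nd4j.ones(DataType.FLOAT, 4);  // scale (applyGamma)
            INDArray beta = Nd4j.zeros(DataType.FLOAT, 4);  // shift (applyBeta)

            INDArray mean = x.mean(0);                      // per-feature mean
            INDArray var = x.var(false, 0);                 // per-feature (population) variance
            INDArray xHat = x.subRowVector(mean)
                             .divRowVector(Transforms.sqrt(var.add(eps)));
            INDArray out = xHat.mulRowVector(gamma).addRowVector(beta);

            System.out.println(out.shapeInfoToString());
        }
    }

The derivative op wired up above then needs exactly these forward inputs plus the incoming activation gradient, which is why the inputs list is seeded from args() and f1.get(0) before being passed via inputFunctions(...).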