parent
36db761917
commit
30b51f8085
|
@ -296,7 +296,7 @@ public class GlobalPoolingLayer extends AbstractLayer<org.deeplearning4j.nn.conf
|
|||
|
||||
switch (poolingType) {
|
||||
case MAX:
|
||||
INDArray isMax = Nd4j.exec(new IsMax(inputArray.dup(), poolDim))[0];
|
||||
INDArray isMax = Nd4j.exec(new IsMax(inputArray, inputArray.ulike(), poolDim))[0];
|
||||
return Nd4j.getExecutioner().exec(new BroadcastMulOp(isMax, epsilon, isMax, broadcastDims));
|
||||
case AVG:
|
||||
//if out = avg(in,dims) then dL/dIn = 1/N * dL/dOut
|
||||
|
|
|
@ -639,7 +639,7 @@ public class ConvolutionUtils {
|
|||
|
||||
INDArray output = Nd4j.createUninitialized(new int[]{(int)in.size(0), 1, outH, 1}, 'c');
|
||||
|
||||
DynamicCustomOp op = new MaxPooling2D(in, output, Pooling2DConfig.builder()
|
||||
DynamicCustomOp op = new MaxPooling2D(reshaped4d, output, Pooling2DConfig.builder()
|
||||
.kH(k[0]).kW(k[1])
|
||||
.sH(s[0]).sW(s[1])
|
||||
.pH(pad == null ? 0 : pad[0]).pW(pad == null ? 0 : pad[1])
|
||||
|
|
|
@ -136,7 +136,7 @@ public class MaskedReductionUtil {
|
|||
Nd4j.getExecutioner().exec(new BroadcastAddOp(input, negInfMask, withInf, 0, 2));
|
||||
//At this point: all the masked out steps have value -inf, hence can't be the output of the MAX op
|
||||
|
||||
INDArray isMax = Nd4j.exec(new IsMax(withInf, 2))[0];
|
||||
INDArray isMax = Nd4j.exec(new IsMax(withInf, withInf.ulike(), 2))[0];
|
||||
|
||||
return Nd4j.getExecutioner().exec(new BroadcastMulOp(isMax, epsilon2d, isMax, 0, 1));
|
||||
case AVG:
|
||||
|
@ -296,7 +296,7 @@ public class MaskedReductionUtil {
|
|||
Nd4j.getExecutioner().exec(new BroadcastAddOp(input, negInfMask, withInf, dimensions));
|
||||
//At this point: all the masked out steps have value -inf, hence can't be the output of the MAX op
|
||||
|
||||
INDArray isMax = Nd4j.exec(new IsMax(withInf, 2, 3))[0];
|
||||
INDArray isMax = Nd4j.exec(new IsMax(withInf, withInf.ulike(), 2, 3))[0];
|
||||
|
||||
return Nd4j.getExecutioner().exec(new BroadcastMulOp(isMax, epsilon2d, isMax, 0, 1));
|
||||
case AVG:
|
||||
|
|
|
@ -22,7 +22,6 @@ import org.junit.After;
|
|||
import org.junit.Before;
|
||||
import org.junit.Rule;
|
||||
import org.junit.rules.TestName;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||
|
|
Loading…
Reference in New Issue