cavis/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/BatchNormalizationParamInitializer.java

182 lines
7.3 KiB
Java
Raw Normal View History

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
package org.deeplearning4j.nn.params;
import lombok.val;
import org.deeplearning4j.nn.api.AbstractParamInitializer;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.*;
public class BatchNormalizationParamInitializer extends AbstractParamInitializer {
2019-06-06 15:21:15 +03:00
private static final BatchNormalizationParamInitializer INSTANCE = new BatchNormalizationParamInitializer();
public static BatchNormalizationParamInitializer getInstance() {
return INSTANCE;
}
// Keys used for the per-parameter views in the maps returned by init(...) and
// getGradientsFromFlattened(...).

// Key for the gamma view; a view is only created when gamma/beta are not locked.
public static final String GAMMA = "gamma";
// Key for the beta view; a view is only created when gamma/beta are not locked.
public static final String BETA = "beta";
// Key for the global mean estimate view; always present.
public static final String GLOBAL_MEAN = "mean";
// Key for the global variance estimate view; used when the layer does not use log-std.
public static final String GLOBAL_VAR = "var";
// Key for the global log10(stdev) estimate view; used instead of GLOBAL_VAR when the layer uses log-std.
public static final String GLOBAL_LOG_STD = "log10stdev";
@Override
public long numParams(LayerConfiguration l) {
2019-06-06 15:21:15 +03:00
BatchNormalization layer = (BatchNormalization) l;
//Parameters in batch norm:
//gamma, beta, global mean estimate, global variance estimate
// latter 2 are treated as parameters, which greatly simplifies spark training and model serialization
if (layer.isLockGammaBeta()) {
//Special case: gamma and beta are fixed values for all outputs -> no parameters for gamma and beta in this case
return 2 * layer.getNOut();
} else {
//Standard case: gamma and beta are learned per output; additional 2*nOut for global mean/variance estimate
return 4 * layer.getNOut();
}
}
@Override
public List<String> paramKeys(LayerConfiguration layer) {
2019-06-06 15:21:15 +03:00
if(((BatchNormalization)layer).isUseLogStd()){
return Arrays.asList(GAMMA, BETA, GLOBAL_MEAN, GLOBAL_LOG_STD);
} else {
return Arrays.asList(GAMMA, BETA, GLOBAL_MEAN, GLOBAL_VAR);
}
}
@Override
public List<String> weightKeys(LayerConfiguration layer) {
2019-06-06 15:21:15 +03:00
return Collections.emptyList();
}
@Override
public List<String> biasKeys(LayerConfiguration layer) {
2019-06-06 15:21:15 +03:00
return Collections.emptyList();
}
@Override
public boolean isWeightParam(LayerConfiguration layer, String key) {
2019-06-06 15:21:15 +03:00
return false;
}
@Override
public boolean isBiasParam(LayerConfiguration layer, String key) {
2019-06-06 15:21:15 +03:00
return false;
}
@Override
public Map<String, INDArray> init(LayerConfiguration conf, INDArray paramView, boolean initializeParams) {
2019-06-06 15:21:15 +03:00
Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
// TODO setup for RNN
BatchNormalization layer = (BatchNormalization) conf;
2019-06-06 15:21:15 +03:00
val nOut = layer.getNOut();
long meanOffset = 0;
if (!layer.isLockGammaBeta()) { //No gamma/beta parameters when gamma/beta are locked
INDArray gammaView = paramView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nOut));
INDArray betaView = paramView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nOut, 2 * nOut));
params.put(GAMMA, createGamma(conf, gammaView, initializeParams));
conf.getNetConfiguration().addNetWideVariable(GAMMA);
2019-06-06 15:21:15 +03:00
params.put(BETA, createBeta(conf, betaView, initializeParams));
conf.getNetConfiguration().addNetWideVariable(BETA);
2019-06-06 15:21:15 +03:00
meanOffset = 2 * nOut;
}
INDArray globalMeanView =
paramView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(meanOffset, meanOffset + nOut));
INDArray globalVarView = paramView.get(NDArrayIndex.interval(0,0,true),
NDArrayIndex.interval(meanOffset + nOut, meanOffset + 2 * nOut));
if (initializeParams) {
globalMeanView.assign(0);
if(layer.isUseLogStd()){
//Global log stdev: assign 0.0 as initial value (s=sqrt(v), and log10(s) = log10(sqrt(v)) -> log10(1) = 0
globalVarView.assign(0);
} else {
//Global variance view: assign 1.0 as initial value
globalVarView.assign(1);
}
}
params.put(GLOBAL_MEAN, globalMeanView);
conf.getNetConfiguration().addNetWideVariable(GLOBAL_MEAN);
2019-06-06 15:21:15 +03:00
if(layer.isUseLogStd()){
params.put(GLOBAL_LOG_STD, globalVarView);
conf.getNetConfiguration().addNetWideVariable(GLOBAL_LOG_STD);
2019-06-06 15:21:15 +03:00
} else {
params.put(GLOBAL_VAR, globalVarView);
conf.getNetConfiguration().addNetWideVariable(GLOBAL_VAR);
2019-06-06 15:21:15 +03:00
}
return params;
}
@Override
public Map<String, INDArray> getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) {
BatchNormalization layer = (BatchNormalization) conf;
2019-06-06 15:21:15 +03:00
val nOut = layer.getNOut();
Map<String, INDArray> out = new LinkedHashMap<>();
long meanOffset = 0;
if (!layer.isLockGammaBeta()) {
INDArray gammaView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nOut));
INDArray betaView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nOut, 2 * nOut));
out.put(GAMMA, gammaView);
out.put(BETA, betaView);
meanOffset = 2 * nOut;
}
out.put(GLOBAL_MEAN,
gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(meanOffset, meanOffset + nOut)));
if(layer.isUseLogStd()){
out.put(GLOBAL_LOG_STD, gradientView.get(NDArrayIndex.interval(0,0,true),
NDArrayIndex.interval(meanOffset + nOut, meanOffset + 2 * nOut)));
} else {
out.put(GLOBAL_VAR, gradientView.get(NDArrayIndex.interval(0,0,true),
NDArrayIndex.interval(meanOffset + nOut, meanOffset + 2 * nOut)));
}
return out;
}
private INDArray createBeta(LayerConfiguration conf, INDArray betaView, boolean initializeParams) {
BatchNormalization layer = (BatchNormalization) conf;
2019-06-06 15:21:15 +03:00
if (initializeParams)
betaView.assign(layer.getBeta());
return betaView;
}
private INDArray createGamma(LayerConfiguration conf, INDArray gammaView, boolean initializeParams) {
BatchNormalization layer = (BatchNormalization) conf;
2019-06-06 15:21:15 +03:00
if (initializeParams)
gammaView.assign(layer.getGamma());
return gammaView;
}
}