105 lines
3.5 KiB
Java
105 lines
3.5 KiB
Java
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
|
|
|
|
package org.deeplearning4j.nn.params;
|
|
|
|
import java.util.List;
|
|
import java.util.Map;
|
|
import org.deeplearning4j.nn.api.AbstractParamInitializer;
|
|
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
|
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayerConfiguration;
|
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
|
|
public class WrapperLayerParamInitializer extends AbstractParamInitializer {
|
|
|
|
private static final WrapperLayerParamInitializer INSTANCE = new WrapperLayerParamInitializer();
|
|
|
|
public static WrapperLayerParamInitializer getInstance(){
|
|
return INSTANCE;
|
|
}
|
|
|
|
private WrapperLayerParamInitializer(){
|
|
|
|
}
|
|
|
|
@Override
|
|
public long numParams(LayerConfiguration layer) {
|
|
LayerConfiguration l = underlying(layer);
|
|
return l.initializer().numParams(l);
|
|
}
|
|
|
|
@Override
|
|
public List<String> paramKeys(LayerConfiguration layer) {
|
|
LayerConfiguration l = underlying(layer);
|
|
return l.initializer().paramKeys(l);
|
|
}
|
|
|
|
@Override
|
|
public List<String> weightKeys(LayerConfiguration layer) {
|
|
LayerConfiguration l = underlying(layer);
|
|
return l.initializer().weightKeys(l);
|
|
}
|
|
|
|
@Override
|
|
public List<String> biasKeys(LayerConfiguration layer) {
|
|
LayerConfiguration l = underlying(layer);
|
|
return l.initializer().biasKeys(l);
|
|
}
|
|
|
|
@Override
|
|
public boolean isWeightParam(LayerConfiguration layer, String key) {
|
|
LayerConfiguration l = underlying(layer);
|
|
return l.initializer().isWeightParam(layer, key);
|
|
}
|
|
|
|
@Override
|
|
public boolean isBiasParam(LayerConfiguration layer, String key) {
|
|
LayerConfiguration l = underlying(layer);
|
|
return l.initializer().isBiasParam(layer, key);
|
|
}
|
|
|
|
@Override
|
|
public Map<String, INDArray> init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) {
|
|
LayerConfiguration orig = conf;
|
|
LayerConfiguration l = underlying(conf);
|
|
|
|
Map<String,INDArray> m = l.initializer().init(conf, paramsView, initializeParams);
|
|
|
|
return m;
|
|
}
|
|
|
|
@Override
|
|
public Map<String, INDArray> getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) {
|
|
LayerConfiguration orig = conf;
|
|
LayerConfiguration l = underlying(conf);
|
|
|
|
Map<String,INDArray> m = l.initializer().getGradientsFromFlattened(conf, gradientView);
|
|
|
|
return m;
|
|
}
|
|
|
|
private LayerConfiguration underlying(LayerConfiguration layer){
|
|
while (layer instanceof BaseWrapperLayerConfiguration) {
|
|
layer = ((BaseWrapperLayerConfiguration)layer).getUnderlying();
|
|
}
|
|
return layer;
|
|
}
|
|
}
|