cavis/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitVarScalingNormalFanAvg.java
Alex Black 09a827fb6d
Fixes and pre-release QA (#51)
* #8395 Keras import - support scaled identity weight init

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More Keras scaled weight init fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8352 Deprecate duplicate SamplingDataSetIterator class

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Remove /O2 optimization for faster CUDA build

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Tweak regression test precision for CUDA

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix edge cases for buffer creation

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Update MKLDNN validation tests to new helper enable/disable settings

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Delete debugging class

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* MKLDNN test - add proper skip for CUDA backend

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Align WeightInitUtil with weight init classes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix for SameDiff test layers weight init when using IWeightInit classes

Signed-off-by: AlexDBlack <blacka101@gmail.com>
2019-11-16 17:04:29 +11:00

55 lines
1.8 KiB
Java

/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.weights;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
import org.nd4j.linalg.factory.Nd4j;
/**
 * Weight initializer that samples from a truncated Gaussian distribution with mean 0 and
 * variance 1.0/((fanIn + fanOut)/2), i.e. 2.0/(fanIn + fanOut). When a scale is configured,
 * the variance becomes 2.0 * scale/(fanIn + fanOut).
 *
 * @author Adam Gibson
 */
@Data
@NoArgsConstructor
public class WeightInitVarScalingNormalFanAvg implements IWeightInit {

    /** Optional multiplier applied to the variance; {@code null} selects the unscaled formula. */
    private Double scale;

    /**
     * Creates an initializer with an explicit variance multiplier.
     *
     * @param scale multiplier for the variance, or {@code null} for the unscaled formula
     */
    public WeightInitVarScalingNormalFanAvg(Double scale){
        this.scale = scale;
    }

    /**
     * Computes the standard deviation from the fan-in/fan-out sum (and the optional scale),
     * executes a {@link TruncatedNormalDistribution} op on {@code paramView} with mean 0.0 and
     * that standard deviation, and returns {@code paramView} reshaped to the requested layout.
     *
     * @param fanIn     fan-in used in the variance denominator
     * @param fanOut    fan-out used in the variance denominator
     * @param shape     shape passed to the final reshape
     * @param order     array ordering character passed to the final reshape
     * @param paramView array handed to the sampling op (presumably filled in place - confirm
     *                  against the op's contract)
     * @return {@code paramView.reshape(order, shape)}
     */
    @Override
    public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
        final double fanSum = fanIn + fanOut;
        // scale is only unboxed in the non-null branch; operation order is kept as
        // (2.0 * scale) / fanSum so the floating-point result is unchanged.
        final double variance = (scale == null) ? 2.0 / fanSum : 2.0 * scale / fanSum;
        final double std = Math.sqrt(variance);

        Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
        return paramView.reshape(order, shape);
    }
}