/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author Yurii Shyrma, created on 26.02.2018 // #include namespace nd4j { namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// template static void addBias_(NDArray& input, const NDArray& bias, const bool isNCHW) { // input [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) // bias [oC] X* inBuff = input.bufferAsT(); const Y* biasBuff = bias.bufferAsT(); int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; bS = input.sizeAt(0); const Nd4jLong stride0 = input.stridesOf()[0]; const Nd4jLong stride1 = input.stridesOf()[1]; const Nd4jLong stride2 = input.stridesOf()[2]; uint biasShapeInfoCast[MAX_RANK]; bool canCastBias = nd4j::DataTypeUtils::castShapeInfo(bias.getShapeInfo(), biasShapeInfoCast); if(isNCHW) { oC = input.sizeAt(1); oH = input.sizeAt(2); oW = input.sizeAt(3); const int oHoW = oH*oW; PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(2) for (int i = 0; i < bS; ++i) { for (int c = 0; c < oC; ++c) { auto biasOffset = shape::indexOffset(c, bias.getShapeInfo(), biasShapeInfoCast, oC, canCastBias); auto inOffset = i * stride0 + c * stride1; PRAGMA_OMP_SIMD for (uint k = 0; k < oHoW; ++k) inBuff[inOffset + k] += static_cast(biasBuff[biasOffset]); } } } else { oC = input.sizeAt(3); oH = input.sizeAt(1); oW = input.sizeAt(2); PRAGMA_OMP_PARALLEL_FOR for (int i = 0; i < bS*oH*oW; ++i) { PRAGMA_OMP_SIMD for (int c = 0; c < oC; ++c) { auto biasOffset = shape::indexOffset(c, bias.getShapeInfo(), biasShapeInfoCast, oC, canCastBias); inBuff[i * oC + c] += static_cast(biasBuff[biasOffset]); } } } } ////////////////////////////////////////////////////////////////////////// void addBias(NDArray& input, const NDArray& bias, const bool isNCHW) { BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(), addBias_, (input, bias, isNCHW), FLOAT_TYPES, FLOAT_TYPES); } BUILD_DOUBLE_TEMPLATE(template void addBias_, (NDArray& input, const NDArray& bias, const bool isNCHW), FLOAT_TYPES, FLOAT_TYPES); } } }