100 lines
3.3 KiB
C++
100 lines
3.3 KiB
C++
|
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
|
||
|
|
||
|
//
// @author Yurii Shyrma, created on 26.02.2018
//
|
||
|
|
||
|
|
||
|
#include<ops/declarable/helpers/addBias.h>
|
||
|
|
||
|
namespace nd4j {
|
||
|
namespace ops {
|
||
|
namespace helpers {
|
||
|
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
template <typename X, typename Y>
|
||
|
static void addBias_(NDArray& input, const NDArray& bias, const bool isNCHW) {
|
||
|
|
||
|
// input [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
|
||
|
// bias [oC]
|
||
|
|
||
|
X* inBuff = input.bufferAsT<X>();
|
||
|
const Y* biasBuff = bias.bufferAsT<Y>();
|
||
|
|
||
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||
|
bS = input.sizeAt(0);
|
||
|
const Nd4jLong stride0 = input.stridesOf()[0];
|
||
|
const Nd4jLong stride1 = input.stridesOf()[1];
|
||
|
const Nd4jLong stride2 = input.stridesOf()[2];
|
||
|
|
||
|
uint biasShapeInfoCast[MAX_RANK];
|
||
|
bool canCastBias = nd4j::DataTypeUtils::castShapeInfo(bias.getShapeInfo(), biasShapeInfoCast);
|
||
|
|
||
|
if(isNCHW) {
|
||
|
|
||
|
oC = input.sizeAt(1);
|
||
|
oH = input.sizeAt(2);
|
||
|
oW = input.sizeAt(3);
|
||
|
|
||
|
const int oHoW = oH*oW;
|
||
|
|
||
|
PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(2)
|
||
|
for (int i = 0; i < bS; ++i) {
|
||
|
for (int c = 0; c < oC; ++c) {
|
||
|
|
||
|
auto biasOffset = shape::indexOffset(c, bias.getShapeInfo(), biasShapeInfoCast, oC, canCastBias);
|
||
|
auto inOffset = i * stride0 + c * stride1;
|
||
|
|
||
|
PRAGMA_OMP_SIMD
|
||
|
for (uint k = 0; k < oHoW; ++k)
|
||
|
inBuff[inOffset + k] += static_cast<X>(biasBuff[biasOffset]);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
else {
|
||
|
|
||
|
oC = input.sizeAt(3);
|
||
|
oH = input.sizeAt(1);
|
||
|
oW = input.sizeAt(2);
|
||
|
|
||
|
PRAGMA_OMP_PARALLEL_FOR
|
||
|
for (int i = 0; i < bS*oH*oW; ++i) {
|
||
|
|
||
|
PRAGMA_OMP_SIMD
|
||
|
for (int c = 0; c < oC; ++c) {
|
||
|
auto biasOffset = shape::indexOffset(c, bias.getShapeInfo(), biasShapeInfoCast, oC, canCastBias);
|
||
|
inBuff[i * oC + c] += static_cast<X>(biasBuff[biasOffset]);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//////////////////////////////////////////////////////////////////////////
|
||
|
/**
 * Public entry point: adds bias to input in place.
 *
 * Forwards (input, bias, isNCHW) to the templated addBias_ through
 * BUILD_DOUBLE_SELECTOR, keyed on input.dataType() and bias.dataType().
 * NOTE(review): the macro is defined elsewhere; presumably it picks the
 * addBias_<X, Y> instantiation matching the two runtime data types among
 * FLOAT_TYPES x FLOAT_TYPES - confirm against its definition.
 *
 * @param input  array that is updated in place
 * @param bias   bias array
 * @param isNCHW layout flag passed through unchanged to addBias_
 */
void addBias(NDArray& input, const NDArray& bias, const bool isNCHW) {
    BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(),
                          addBias_, (input, bias, isNCHW),
                          FLOAT_TYPES, FLOAT_TYPES);
}
|
||
|
|
||
|
|
||
|
// NOTE(review): BUILD_DOUBLE_TEMPLATE is defined elsewhere; presumably it emits explicit
// instantiations of addBias_ for each pairing of the two FLOAT_TYPES lists - confirm against the macro.
BUILD_DOUBLE_TEMPLATE(template void addBias_, (NDArray& input, const NDArray& bias, const bool isNCHW), FLOAT_TYPES, FLOAT_TYPES);
|
||
|
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|