cavis/libnd4j/include/ops/declarable/helpers/cpu/ctcLoss.cpp

507 lines
24 KiB
C++
Raw Normal View History

/*******************************************************************************
* Copyright (c) 2021 Deeplearning4j Contributors
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*******************************************************************************/
//
// @author AbdelRauf
//
#include <type_traits>
#include <cmath>
#include <stdexcept>
#include <memory>
#include <execution/Threads.h>
#include <execution/ThreadPool.h>
#include <helpers/LoopsCoordsHelper.h>
#include <ops/declarable/helpers/ctcLoss.h>
namespace sd
{
namespace ops
{
namespace helpers
{
/**
 * Strided element access: yields a reference to ptr[index * element_stride].
 * Picked by SFINAE when Strided is true.
 */
template <bool Strided, typename Type>
typename std::enable_if<Strided, Type &>::type
element(Type *ptr, int index, int element_stride)
{
    Type *const target = ptr + index * element_stride;
    return *target;
}
/**
 * Unit-stride element access: yields a reference to ptr[index].
 * Picked by SFINAE when Strided is false; element_stride is ignored and only
 * exists so that both overloads can be called with the same arguments.
 */
template <bool Strided, typename Type>
typename std::enable_if<!Strided, Type &>::type
element(Type *ptr, int index, int element_stride)
{
    (void)element_stride;
    return *(ptr + index);
}
/**
 * CTC forward (alpha) recursion in log space for a single sequence.
 *
 * alpha(t, s) accumulates the log-probability of all label prefixes ending in state s of the
 * blank-extended label at frame t (blank at even s, lbl[s / 2] at odd s). Only the band
 * s >= lenSB - 2 * (lenT - t) is computed: lower states can no longer reach one of the two
 * final states, so they cannot contribute to the loss.
 *
 * @param alphaPtr   alpha buffer of lenT frames x lenSB states, frame stride incA. It is expected
 *                   to be pre-filled with negative infinity; entries outside the band are not written.
 * @param incA       stride between consecutive frames of alphaPtr
 * @param logP       log-probabilities of frame 0; class k is logP[k] (logP[k * elwiseP] when IsLogPStrided)
 * @param incP       stride between consecutive frames of logP
 * @param lbl        target labels; label i is lbl[i] (lbl[i * elwiseS] when IsLblStrided)
 * @param lenSB      blank-extended label length (2 * lenS + 1)
 * @param lenT       number of frames to process
 * @param blankIndex class index of the blank label
 * @return -log p(lbl | logP). When neither final state is reachable (e.g. repeated labels
 *         need more frames than lenT) the result is -negInf (+infinity) instead of NaN.
 */
template <bool IsLogPStrided = false, bool IsLblStrided = false, typename Type, typename IndexType>
Type forward(Type *alphaPtr, const Nd4jLong &incA, const Type *logP, const Nd4jLong &incP, const IndexType *lbl, const Nd4jLong &lenSB, const Nd4jLong &lenT, const int &blankIndex, int elwiseP = 1, int elwiseS = 1)
{
    Type negInf = -DataTypeUtils::infOrMax<Type>();
    //initialize alphas at t=0: a path may start with a blank or with the first label
    alphaPtr[0] = element<IsLogPStrided>(logP, blankIndex, elwiseP);
    //alphaPtr[1] =logP[lbl[0]];
    alphaPtr[1] = element<IsLogPStrided>(logP, *lbl, elwiseP);
    //the rest of frame 0 is not initialized here:
    //the buffer is assumed to be pre-filled with negative infinity
    //move to the next frame
    Type *alphaPrevPtr = alphaPtr;
    alphaPtr += incA;
    logP += incP;
    auto startX = lenSB - 2 * lenT;
    //process the rest
    for (auto t = 1; t < lenT; t++)
    {
        //start = max(0,L-2*(T-t))
        auto s = startX + 2 * t;
        s = s > 0 ? s : 0;
        for (; s < lenSB; s++)
        {
            auto ind = s / 2; //our real index
            //we force blanks for even indexes
            //strided version of lbl[ind] => element<IsLblStrided>(lbl, ind, elwiseS)
            auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS);
            // {t-1,s}
            Type alphaS = alphaPrevPtr[s];
            Type alphaS_1 = s > 0 ? alphaPrevPtr[s - 1] : negInf;
            Type cMax = std::max(alphaS, alphaS_1);
            //logP[currentInd] or logP[currentInd*elwiseP]
            auto currentProb = element<IsLogPStrided>(logP, currentInd, elwiseP);
            // the skip transition from s-2 is only allowed for a label that differs from the previous label
            if (s > 1 && currentInd != blankIndex && currentInd != element<IsLblStrided>(lbl, ind - 1, elwiseS))
            {
                Type alphaS_2 = alphaPrevPtr[s - 2];
                cMax = std::max(cMax, alphaS_2);
                //all predecessors unreachable: keep the log-sum at -inf instead of NaN
                if (cMax == negInf)
                    cMax = 0;
                alphaPtr[s] = std::log(std::exp(alphaS - cMax) + std::exp(alphaS_1 - cMax) + std::exp(alphaS_2 - cMax)) + cMax + currentProb;
            }
            else
            {
                if (cMax == negInf)
                    cMax = 0;
                alphaPtr[s] = std::log(std::exp(alphaS - cMax) + std::exp(alphaS_1 - cMax)) + cMax + currentProb;
            }
        }
        //store t-1 alpha Ptr
        alphaPrevPtr = alphaPtr;
        logP += incP;
        alphaPtr += incA;
    }
    //a valid path ends either in the last label or in the trailing blank
    auto logP0 = alphaPrevPtr[lenSB - 1];
    auto logP1 = alphaPrevPtr[lenSB - 2];
    auto cMax = std::max(logP0, logP1);
    //neither final state reachable: avoid (-inf) - (-inf) = NaN; with cMax = 0 the
    //log-sum below becomes log(0) = -inf and the returned loss is +inf
    if (cMax == negInf)
        cMax = 0;
    return -(std::log(std::exp(logP0 - cMax) + std::exp(logP1 - cMax)) + cMax);
}
// CALCULATE_ALL_IN_ONE_FRAME_LOOP selects an alternative layout of backwardAndGrad that folds the
// alpha*betta accumulation into the backward recursion itself (without upper-band pruning of betta).
// It is not defined anywhere in this file, so the default path is the separate gradient pass at the
// end of the function. NOTE(review): confirm the build never defines it.
//#undef CALCULATE_ALL_IN_ONE_FRAME_LOOP
/**
 * CTC backward (betta) recursion in log space for a single sequence, followed by the gradient.
 *
 * Both alpha (from forward()) and betta include the emission of frame t, so for every frame t and
 * class k the gradient is computed as
 *     grad(t,k) = prob(t,k) - sum_{s: class(s)==k} alpha(t,s)*betta(t,s) / (prob(t,k) * Z)
 * with Z = exp(-forwardLogLoss). This is presumably the gradient w.r.t. the pre-softmax
 * activations, assuming logP holds log-softmax outputs — TODO confirm against the op definition.
 *
 * @param forwardLogLoss negative log-likelihood returned by forward()
 * @param alphaPtr       alphas filled by forward(), lenT frames x (2*lenS+1) states, frame stride incA
 * @param bettaPtr       betta work buffer with the same layout as alphaPtr; expected to be pre-filled
 *                       with negative infinity
 * @param incA           stride between consecutive frames of alphaPtr/bettaPtr
 * @param logP           log-probabilities of frame 0, frame stride incP
 * @param gradPtr        gradient output of frame 0, frame stride incG; every [t,k] entry is expected
 *                       to start at negative infinity because it is used as a log-sum accumulator
 * @param lbl            target labels
 * @param lenS           target label length (without blanks)
 * @param lenT           number of frames
 * @param lenK           number of classes
 * @param blankIndex     class index of the blank label
 * @param elwiseP, elwiseS, elwiseG element strides used when the matching template flag is set
 */
template <bool IsLogPStrided = false, bool IsLblStrided = false, bool isGradStrided = false, typename Type, typename IndexType = int>
void backwardAndGrad(Type forwardLogLoss, Type *alphaPtr, Type *bettaPtr, int incA, const Type *logP, int incP, Type *gradPtr, int incG, const IndexType *lbl,
                     const Nd4jLong &lenS, const Nd4jLong &lenT, const Nd4jLong &lenK, const int &blankIndex,
                     int elwiseP = 1, int elwiseS = 1, int elwiseG = 1)
{
    Type negInf = -DataTypeUtils::infOrMax<Type>();
    Nd4jLong lenSB = 2 * lenS + 1;
    auto origBetta = bettaPtr;
    auto origLogP = logP;
    //move to the last frame
    bettaPtr += (lenT - 1) * incA;
    logP += (lenT - 1) * incP;
    //initialize bettas at the last frame t=lenT-1: a path ends with the trailing blank or the last label
    bettaPtr[lenSB - 1] = element<IsLogPStrided>(logP, blankIndex, elwiseP);
    auto lblIndex = element<IsLblStrided>(lbl, lenS - 1, elwiseS);
    bettaPtr[lenSB - 2] = element<IsLogPStrided>(logP, lblIndex, elwiseP); // logP[lbl[lenS - 1]];
#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
    //move to the last
    gradPtr += (lenT - 1) * incG;
    alphaPtr += (lenT - 1) * incA;
    for (auto s = lenSB - 1; s >= 0; s--)
    {
        auto ind = s / 2; //our real index
        //we forced blanks for even indexes
        auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS);
        //alpha(s)*betta(s) in log scale but still store in alpha to save memory
        auto alphaBettaS = alphaPtr[s] + bettaPtr[s];
        //sum (alpha(s)*betta(s) ) over real indexes
        auto &currentGrad = element<isGradStrided>(gradPtr, currentInd, elwiseG); // gradPtr[currentInd];
        if (currentGrad == negInf)
        {
            currentGrad = alphaBettaS;
        }
        else
        {
            Type cMax = std::max(currentGrad, alphaBettaS);
            currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax;
        }
    }
    for (int k = 0; k < lenK; k++)
    {
        //compute the rest grad
        // prob(t,k) - grad(k) / ((prob(t,k)*Z) )
        // p2= grad(k) / (prob(t,k)*Z )
        //in logscale . plus we have Z as -logLoss
        // auto p2 = std::exp(gradPtr[k] + forwardLogLoss - logP[k]);
        // gradPtr[k] = std::exp(logP[k]) - p2;
        auto currentProb = element<IsLogPStrided>(logP, k, elwiseP);
        auto &currentGrad = element<isGradStrided>(gradPtr, k, elwiseG);
        auto p2 = std::exp(currentGrad + forwardLogLoss - currentProb);
        currentGrad = std::exp(currentProb) - p2;
    }
    gradPtr -= incG;
    alphaPtr -= incA;
#endif
    auto bettaPrevPtr = bettaPtr;
    bettaPtr -= incA;
    logP -= incP;
    //process the rest, backwards in time
    for (auto t = lenT - 2; t >= 0; t--)
    {
#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
        auto end = lenSB - 1;
#else
        //states above end cannot be reached from the start by frame t (their alphas stay -inf),
        //so their bettas never contribute to the gradient and are left at -inf
        auto end = std::min(2 * t + 2, lenSB - 1);
#endif
        for (auto s = end; s >= 0; s--)
        {
            auto ind = s / 2; //our real index
            //we forced blanks for even indexes
            auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS); //lbl[ind];
            // {t+1,s}
            Type bettaS = bettaPrevPtr[s];
            Type bettaS_1 = s < lenSB - 1 ? bettaPrevPtr[s + 1] : negInf;
            Type cMax = std::max(bettaS, bettaS_1);
            //logP[currentInd]
            auto currentProb = element<IsLogPStrided>(logP, currentInd, elwiseP);
            // the skip transition from s+2 is only allowed for a label that differs from the next label
            if (s < lenSB - 2 && currentInd != blankIndex && currentInd != element<IsLblStrided>(lbl, ind + 1, elwiseS))
            {
                Type bettaS_2 = bettaPrevPtr[s + 2];
                cMax = std::max(cMax, bettaS_2);
                //all successors unreachable: keep the log-sum at -inf instead of NaN
                if (cMax == negInf)
                    cMax = 0;
                bettaPtr[s] = std::log(std::exp(bettaS - cMax) + std::exp(bettaS_1 - cMax) + std::exp(bettaS_2 - cMax)) + cMax + currentProb;
            }
            else
            {
                if (cMax == negInf)
                    cMax = 0;
                bettaPtr[s] = std::log(std::exp(bettaS - cMax) + std::exp(bettaS_1 - cMax)) + cMax + currentProb;
            }
#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
            //alpha(s)*betta(s) in log scale but still store in alpha to save memory
            auto alphaBettaS = alphaPtr[s] + bettaPtr[s];
            //sum (alpha(s)*betta(s) ) over real indexes
            auto &currentGrad = element<isGradStrided>(gradPtr, currentInd, elwiseG); // gradPtr[currentInd];
            if (currentGrad == negInf)
            {
                currentGrad = alphaBettaS;
            }
            else
            {
                Type cMax = std::max(currentGrad, alphaBettaS);
                currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax;
            }
#endif
        }
#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
        for (int k = 0; k < lenK; k++)
        {
            //compute the rest grad
            // prob(t,k) - grad(k) / ((prob(t,k)*Z) )
            // p2= grad(k) / (prob(t,k)*Z )
            //in logscale . plus we have Z as -logLoss
            // auto p2 = std::exp(gradPtr[k] + forwardLogLoss - logP[k]);
            // gradPtr[k] = std::exp(logP[k]) - p2;
            auto currentProb = element<IsLogPStrided>(logP, k, elwiseP);
            auto &currentGrad = element<isGradStrided>(gradPtr, k, elwiseG);
            auto p2 = std::exp(currentGrad + forwardLogLoss - currentProb);
            currentGrad = std::exp(currentProb) - p2;
        }
        alphaPtr -= incA;
        gradPtr -= incG;
#endif
        bettaPrevPtr = bettaPtr;
        bettaPtr -= incA;
        logP -= incP;
    }
#if !defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
    //gradient pass, forward in time: accumulate log(sum alpha*betta) per class, then convert it
    bettaPtr = origBetta;
    logP = origLogP;
    for (int t = 0; t < lenT; t++)
    {
        for (int s = 0; s < lenSB; s++)
        {
            auto ind = s / 2; //our real index
            //we forced blanks for even indexes
            auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS); //lbl[ind];
            //alpha(s)*betta(s) in log scale
            auto alphaBettaS = alphaPtr[s] + bettaPtr[s];
            //log-sum-exp of alpha(s)*betta(s) over all states s that carry class currentInd
            auto &currentGrad = element<isGradStrided>(gradPtr, currentInd, elwiseG); // gradPtr[currentInd];
            if (currentGrad == negInf)
            {
                currentGrad = alphaBettaS;
            }
            else
            {
                Type cMax = std::max(currentGrad, alphaBettaS);
                currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax;
            }
        }
        PRAGMA_OMP_SIMD
        for (int k = 0; k < lenK; k++)
        {
            //compute the rest grad
            // prob(t,k) - grad(k) / ((prob(t,k)*Z) )
            // p2= grad(k) / (prob(t,k)*Z )
            //in logscale . plus we have Z as -logLoss
            // auto p2 = std::exp(gradPtr[k] + forwardLogLoss - logP[k]);
            // gradPtr[k] = std::exp(logP[k]) - p2;
            auto currentProb = element<IsLogPStrided>(logP, k, elwiseP);
            auto &currentGrad = element<isGradStrided>(gradPtr, k, elwiseG);
            auto p2 = std::exp(currentGrad + forwardLogLoss - currentProb);
            currentGrad = std::exp(currentProb) - p2;
        }
        gradPtr += incG;
        bettaPtr += incA;
        alphaPtr += incA;
        logP += incP;
    }
#endif
}
/**
* Calculates ctc loss and fills gradients
* @param logP logits matrix(lenT,lenK) pointer (log soft max input of rnn)
* @param incP stride of logits for the next time frame
* @param gradPtr gradient for output
* @param incG stride of the gradient for the next time frame
* @param lbl target label
* @param lenT frame length
* @param lenK class length
* @param lenS target label length
* @param blankIndex index of the blank label in logit class
*/
template <bool IsLogPStrided = true, bool IsLblStrided = true, bool IsGradStrided = true, typename Type, typename IndexType>
Type unitLossAndGrad(const Type *logP, int incP, Type *gradPtr, int incG,const IndexType *lbl, int lenT, int lenK, int lenS, int blankIndex,
int elwiseP = 1, int elwiseS = 1, int elwiseG = 1)
{
auto lenSB = 2 * lenS + 1;
//create temp Array for holding bettaArr [lenT,lenSB]
//create temp Array for holding alphaArr [lenT,lenSB]
int bufferC = gradPtr ? 2 : 1;
NDArray bufferArr = NDArrayFactory::create<Type>('c', {bufferC, lenT, lenSB});
auto bufferPtr = bufferArr.bufferAsT<Type>();
auto incA = bufferArr.stridesOf()[1];
auto bettaBufferPtr = bufferPtr + bufferArr.stridesOf()[0];
Type negInf = -DataTypeUtils::infOrMax<Type>();
#if 1
if (gradPtr)
{
if (elwiseG == 1)
{
PRAGMA_OMP_SIMD
for (int i = 0; i < lenK * lenT; i++)
{
gradPtr[i] = negInf;
}
}
else
{
auto tempPtr = gradPtr;
for (int i = 0; i < lenT; i++)
{
for (int j = 0; j < lenK; j++)
element<false>(tempPtr, j, elwiseG) = negInf;
tempPtr += incG;
}
}
}
#endif
// set all vals to neginf
PRAGMA_OMP_SIMD
for (int i = 0; i < bufferC * lenSB * lenT; i++)
{
bufferPtr[i] = negInf;
}
//forward
Type logLoss = forward<IsLogPStrided, IsLblStrided>(bufferPtr, incA, logP, incP, lbl, lenSB, lenT, blankIndex, elwiseP, elwiseS);
//backward and gradient if gradptr supplied
if (gradPtr)
backwardAndGrad<IsLogPStrided, IsLblStrided, IsGradStrided>(logLoss, bufferPtr, bettaBufferPtr, incA, logP, incP, gradPtr, incG, lbl, lenS, lenT, lenK, blankIndex, elwiseP, elwiseS, elwiseG);
return logLoss;
}
template <typename Type, typename IndexType>
void
ctc_loss_(const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex)
{
// lenT - input length of T
// lenS - lenght of sequence
// lenSB - length with blanks
auto lenBatch = logits.shapeOf()[0];
auto maxLenT = logits.shapeOf()[1];
auto lenK = logits.shapeOf()[2];
auto maxLenS = targetLabels.shapeOf()[1];
// get probability bufer and tagetLabels buffer
auto logP = logits.bufferAsT<Type>();
auto lblPtr = targetLabels.bufferAsT<IndexType>();
auto lenTPtr = logitsLengths.bufferAsT<IndexType>();
auto lenSPtr = targetLabelLengths.bufferAsT<IndexType>();
auto batchLbl = targetLabels.stridesOf()[0];
auto batchP = logits.stridesOf()[0];
auto incP = logits.stridesOf()[1];
auto elwiseSLen = targetLabelLengths.stridesOf()[0];
auto elwiseT = logitsLengths.stridesOf()[0];
auto elwiseS = targetLabels.stridesOf()[1];
auto elwiseP = logits.stridesOf()[2];
int elwiseLL = 0;
Type *logLossPtr = nullptr;
if (!logLosses.isEmpty()){
elwiseLL = logLosses.stridesOf()[0];
logLossPtr = logLosses.bufferAsT<Type>();
}
auto func = [logP, batchP, incP, elwiseP, lenK, lenTPtr, lenSPtr, logLossPtr, lblPtr, maxLenT, maxLenS,
batchLbl, blankIndex, elwiseT, elwiseLL, elwiseSLen, elwiseS, &gradients](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
Type *gradPtr = nullptr;
Type resultLoss;
int batchG, incG, elwiseG;
if (!gradients.isEmpty())
{
batchG = gradients.stridesOf()[0];
incG = gradients.stridesOf()[1];
elwiseG = gradients.stridesOf()[2];
gradPtr = gradients.bufferAsT<Type>() + start * batchG;
}else{
elwiseG=1;
}
auto logPtr = logP + start * batchP;
auto tempLblPtr = lblPtr + start * batchLbl;
if (elwiseP == 1 && elwiseS == 1 && elwiseG == 1)
{
//choose ews one
for (int batchIndex = start; batchIndex < stop; batchIndex += increment)
{
auto lenT = lenTPtr[batchIndex * elwiseT];
auto lenS = lenSPtr[batchIndex * elwiseSLen];
lenT = lenT > maxLenT ? maxLenT : lenT;
lenS = lenS > maxLenS ? maxLenS : lenS;
if (lenS <= 0 || lenT <= 0)
{
resultLoss = -DataTypeUtils::infOrMax<Type>();
}
else
{
if (lenS > lenT)
lenS = lenT;
resultLoss = unitLossAndGrad<false, false, false, Type, IndexType>(logPtr, incP, gradPtr, incG, tempLblPtr, lenT, lenK, lenS, blankIndex);
}
if (gradPtr) gradPtr += batchG;
if (logLossPtr) logLossPtr[batchIndex * elwiseLL] = resultLoss;
logPtr += batchP;
tempLblPtr += batchLbl;
}
}
else
{
//slow strided case for all 3
for (int batchIndex = start; batchIndex < stop; batchIndex += increment)
{
auto lenT = lenTPtr[batchIndex * elwiseT];
auto lenS = lenSPtr[batchIndex * elwiseSLen];
lenT = lenT > maxLenT ? maxLenT : lenT;
lenS = lenS > maxLenS ? maxLenS : lenS;
if (lenS <= 0 || lenT <= 0)
{
resultLoss = -DataTypeUtils::infOrMax<Type>();
}
else
{
if (lenS > lenT)
lenS = lenT;
resultLoss = unitLossAndGrad<true,true,true,Type, IndexType>(logPtr, incP, gradPtr, incG, tempLblPtr, lenT, lenK, lenS, blankIndex, elwiseP, elwiseS, elwiseG);
}
if (gradPtr) gradPtr += batchG;
if (logLossPtr) logLossPtr[batchIndex * elwiseLL] = resultLoss;
logPtr += batchP;
tempLblPtr += batchLbl;
}
}
};
samediff::Threads::parallel_for(func, 0, lenBatch, 1);
}
/**
 * CPU entry point for the CTC loss helper: dispatches ctc_loss_ on the data types of
 * logits (FLOAT_TYPES) and targetLabels (INDEXING_TYPES). `block` is not used here.
 */
void ctcLoss(graph::Context &block, const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths,
             const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex)
{
    const auto logitsType = logits.dataType();
    const auto labelsType = targetLabels.dataType();
    BUILD_DOUBLE_SELECTOR(logitsType, labelsType, ctc_loss_,
                          (logits, targetLabels, logitsLengths, targetLabelLengths, logLosses, gradients, blankIndex),
                          FLOAT_TYPES, INDEXING_TYPES);
}
// Explicit instantiations of ctc_loss_ for every (FLOAT_TYPES x INDEXING_TYPES) pair,
// i.e. the same type lists that ctcLoss dispatches on via BUILD_DOUBLE_SELECTOR.
BUILD_DOUBLE_TEMPLATE(template void ctc_loss_, (const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex), FLOAT_TYPES, INDEXING_TYPES);
} // namespace helpers
} // namespace ops
} // namespace sd