/*******************************************************************************
 * Copyright (c) 2021 Deeplearning4j Contributors
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 *******************************************************************************/

//
// @author AbdelRauf
//

#include <type_traits>
#include <cmath>
#include <stdexcept>
#include <memory>
#include <execution/Threads.h>
#include <execution/ThreadPool.h>
#include <helpers/LoopsCoordsHelper.h>
#include <ops/declarable/helpers/ctcLoss.h>

namespace sd
{
namespace ops
{
namespace helpers
{

//choose ptr[index*element_stride]
template <bool Strided, typename Type>
typename std::enable_if<Strided == true, Type &>::type
element(Type *ptr, int index, int element_stride)
{
    return ptr[index * element_stride];
}

//choose ptr[index] assuming element_stride is 1
template <bool Strided, typename Type>
typename std::enable_if<Strided == false, Type &>::type
element(Type *ptr, int index, int element_stride)
{
    return ptr[index];
}
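
// The two overloads above are selected at compile time via std::enable_if on the Strided flag,
// so the index * element_stride multiplication is only paid for when a strided view is actually
// in use. A minimal illustrative call (hypothetical values, not taken from this file):
//
//   float buf[8] = {0, 1, 2, 3, 4, 5, 6, 7};
//   element<true>(buf, 2, 3);    // reads buf[2 * 3] == buf[6]
//   element<false>(buf, 2, 3);   // ignores the stride argument and reads buf[2]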

template <bool IsLogPStrided = false, bool IsLblStrided = false, typename Type, typename IndexType>
Type forward(Type *alphaPtr, const Nd4jLong &incA, const Type *logP, const Nd4jLong &incP, const IndexType *lbl, const Nd4jLong &lenSB, const Nd4jLong &lenT, const int &blankIndex, int elwiseP = 1, int elwiseS = 1)
{
    Type negInf = -DataTypeUtils::infOrMax<Type>();
    //initialize alphas at t=0
    alphaPtr[0] = element<IsLogPStrided>(logP, blankIndex, elwiseP);
    //alphaPtr[1] = logP[lbl[0]];
    alphaPtr[1] = element<IsLogPStrided>(logP, *lbl, elwiseP);
    //the rest of the initialization is skipped,
    //as the array is assumed to have been filled with negative infinity already
    //move to the next frame
    Type *alphaPrevPtr = alphaPtr;
    alphaPtr += incA;
    logP += incP;

    auto startX = lenSB - 2 * lenT;
    //process the rest
    for (auto t = 1; t < lenT; t++)
    {

        //start = max(0, L - 2*(T - t))
        auto s = startX + 2 * t;
        s = s > 0 ? s : 0;
        for (; s < lenSB; s++)
        {
            auto ind = s / 2; //our real index
            //we force blanks for even indexes
            //strided version of lbl[ind] => element<IsLblStrided>(lbl, ind, elwiseS)
            auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS);
            // {t-1,s}
            Type alphaS = alphaPrevPtr[s];
            Type alphaS_1 = s > 0 ? alphaPrevPtr[s - 1] : negInf;
            Type cMax = std::max(alphaS, alphaS_1);
            //logP[currentInd] or logP[currentInd*elwiseP]
            auto currentProb = element<IsLogPStrided>(logP, currentInd, elwiseP);
            //include the {t-1, s-2} path unless the label is blank or repeats the previous label
            if (s > 1 && currentInd != blankIndex && currentInd != element<IsLblStrided>(lbl, ind - 1, elwiseS))
            {
                Type alphaS_2 = alphaPrevPtr[s - 2];
                cMax = std::max(cMax, alphaS_2);
                if (cMax == negInf)
                    cMax = 0;
                alphaPtr[s] = std::log(std::exp(alphaS - cMax) + std::exp(alphaS_1 - cMax) + std::exp(alphaS_2 - cMax)) + cMax + currentProb;
            }
            else
            {
                if (cMax == negInf)
                    cMax = 0;
                alphaPtr[s] = std::log(std::exp(alphaS - cMax) + std::exp(alphaS_1 - cMax)) + cMax + currentProb;
            }
        }

        //store t-1 alpha Ptr
        alphaPrevPtr = alphaPtr;
        logP += incP;
        alphaPtr += incA;
    }
    auto logP0 = alphaPrevPtr[lenSB - 1];
    auto logP1 = alphaPrevPtr[lenSB - 2];
    auto cMax = std::max(logP0, logP1);
    return -(std::log(std::exp(logP0 - cMax) + std::exp(logP1 - cMax)) + cMax);
}
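
// Note on the recurrence above: with the label extended to length lenSB = 2*lenS + 1 by interleaving
// blanks, the loop computes the standard CTC forward variables in log space,
//   alpha_t(s) = logsumexp(alpha_{t-1}(s), alpha_{t-1}(s-1)[, alpha_{t-1}(s-2)]) + logP_t(label(s)),
// where the alpha_{t-1}(s-2) term is added only when label(s) is neither the blank nor equal to
// label(s-2). Subtracting cMax before exponentiating is the usual log-sum-exp trick that keeps the
// summation numerically stable; the returned value is the negative log likelihood assembled from the
// last two extended positions at the final frame.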

//#undef CALCULATE_ALL_IN_ONE_FRAME_LOOP
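// When CALCULATE_ALL_IN_ONE_FRAME_LOOP is defined, gradient contributions are accumulated inside
// the same backward (beta) frame loop in backwardAndGrad(); when it is not defined, a separate
// alpha*betta pass over all frames runs at the end of that function instead. The commented-out
// #undef above is left as a marker for that toggle.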

template <bool IsLogPStrided = false, bool IsLblStrided = false, bool isGradStrided = false, typename Type, typename IndexType = int>
void backwardAndGrad(Type forwardLogLoss, Type *alphaPtr, Type *bettaPtr, int incA, const Type *logP, int incP, Type *gradPtr, int incG, const IndexType *lbl,
                     const Nd4jLong &lenS, const Nd4jLong &lenT, const Nd4jLong &lenK, const int &blankIndex,
                     int elwiseP = 1, int elwiseS = 1, int elwiseG = 1)
{

    Type negInf = -DataTypeUtils::infOrMax<Type>();
    Nd4jLong lenSB = 2 * lenS + 1;
    auto origBetta = bettaPtr;
    auto origLogP = logP;
    //move to the last frame
    bettaPtr += (lenT - 1) * incA;
    logP += (lenT - 1) * incP;

    //initialize bettas at t=lenT
    bettaPtr[lenSB - 1] = element<IsLogPStrided>(logP, blankIndex, elwiseP);
    auto lblIndex = element<IsLblStrided>(lbl, lenS - 1, elwiseS);
    bettaPtr[lenSB - 2] = element<IsLogPStrided>(logP, lblIndex, elwiseP); // logP[lbl[lenS - 1]];

#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
    //move to the last frame
    gradPtr += (lenT - 1) * incG;
    alphaPtr += (lenT - 1) * incA;
    for (auto s = lenSB - 1; s >= 0; s--)
    {
        auto ind = s / 2; //our real index
        //we force blanks for even indexes
        auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS);
        //alpha(s)*betta(s) in log scale
        auto alphaBettaS = alphaPtr[s] + bettaPtr[s];

        //sum (alpha(s)*betta(s)) over the extended positions that map to the same real label
        auto &currentGrad = element<isGradStrided>(gradPtr, currentInd, elwiseG); // gradPtr[currentInd];
        if (currentGrad == negInf)
        {
            currentGrad = alphaBettaS;
        }
        else
        {
            Type cMax = std::max(currentGrad, alphaBettaS);
            currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax;
        }
    }
    for (int k = 0; k < lenK; k++)
    {
        //compute the remaining gradient:
        //   grad(k) = prob(t,k) - grad(k) / (prob(t,k) * Z)
        //   p2      = grad(k) / (prob(t,k) * Z)
        //in log scale; log(Z) is available as -forwardLogLoss
        // auto p2 = std::exp(gradPtr[k] + forwardLogLoss - logP[k]);
        // gradPtr[k] = std::exp(logP[k]) - p2;
        auto currentProb = element<IsLogPStrided>(logP, k, elwiseP);
        auto &currentGrad = element<isGradStrided>(gradPtr, k, elwiseG);
        auto p2 = std::exp(currentGrad + forwardLogLoss - currentProb);
        currentGrad = std::exp(currentProb) - p2;
    }
    gradPtr -= incG;
    alphaPtr -= incA;
#endif

    auto bettaPrevPtr = bettaPtr;
    bettaPtr -= incA;
    logP -= incP;
    //process the rest
    for (auto t = lenT - 2; t >= 0; t--)
    {

#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
        auto end = lenSB - 1;
#else
        auto end = std::min(2 * t + 2, lenSB - 1);
#endif
        for (auto s = end; s >= 0; s--)
        {
            auto ind = s / 2; //our real index
            //we force blanks for even indexes
            auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS); //lbl[ind];
            // {t+1,s}
            Type bettaS = bettaPrevPtr[s];
            Type bettaS_1 = s < lenSB - 1 ? bettaPrevPtr[s + 1] : negInf;
            Type cMax = std::max(bettaS, bettaS_1);
            //logP[currentInd]
            auto currentProb = element<IsLogPStrided>(logP, currentInd, elwiseP);
            //include the {t+1, s+2} path unless the label is blank or repeats the next label
            if (s < lenSB - 2 && currentInd != blankIndex && currentInd != element<IsLblStrided>(lbl, ind + 1, elwiseS))
            {
                Type bettaS_2 = bettaPrevPtr[s + 2];
                cMax = std::max(cMax, bettaS_2);
                if (cMax == negInf)
                    cMax = 0;
                bettaPtr[s] = std::log(std::exp(bettaS - cMax) + std::exp(bettaS_1 - cMax) + std::exp(bettaS_2 - cMax)) + cMax + currentProb;
            }
            else
            {
                if (cMax == negInf)
                    cMax = 0;
                bettaPtr[s] = std::log(std::exp(bettaS - cMax) + std::exp(bettaS_1 - cMax)) + cMax + currentProb;
            }

#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
            //alpha(s)*betta(s) in log scale
            auto alphaBettaS = alphaPtr[s] + bettaPtr[s];

            //sum (alpha(s)*betta(s)) over the extended positions that map to the same real label
            auto &currentGrad = element<isGradStrided>(gradPtr, currentInd, elwiseG); // gradPtr[currentInd];
            if (currentGrad == negInf)
            {
                currentGrad = alphaBettaS;
            }
            else
            {
                Type cMax = std::max(currentGrad, alphaBettaS);
                currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax;
            }

#endif
        }

#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
        for (int k = 0; k < lenK; k++)
        {
            //compute the remaining gradient:
            //   grad(k) = prob(t,k) - grad(k) / (prob(t,k) * Z)
            //   p2      = grad(k) / (prob(t,k) * Z)
            //in log scale; log(Z) is available as -forwardLogLoss
            // auto p2 = std::exp(gradPtr[k] + forwardLogLoss - logP[k]);
            // gradPtr[k] = std::exp(logP[k]) - p2;
            auto currentProb = element<IsLogPStrided>(logP, k, elwiseP);
            auto &currentGrad = element<isGradStrided>(gradPtr, k, elwiseG);
            auto p2 = std::exp(currentGrad + forwardLogLoss - currentProb);
            currentGrad = std::exp(currentProb) - p2;
        }
        alphaPtr -= incA;
        gradPtr -= incG;
#endif

        bettaPrevPtr = bettaPtr;
        bettaPtr -= incA;
        logP -= incP;
    }

    auto logBP0 = bettaPrevPtr[0];
    auto logBP1 = bettaPrevPtr[1];
    auto bcMax = std::max(logBP0, logBP1);
    auto blogLoss = -(std::log(std::exp(logBP0 - bcMax) + std::exp(logBP1 - bcMax)) + bcMax);
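    // blogLoss is the same negative log likelihood recomputed from the betas at t = 0; it mirrors
    // forwardLogLoss and is not used further below, but it is a convenient quantity to compare
    // against the forward result when debugging.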

#if !defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
    //alpha*betta pass over all frames
    bettaPtr = origBetta;
    logP = origLogP;

    for (int t = 0; t < lenT; t++)
    {

        for (int s = 0; s < lenSB; s++)
        {
            auto ind = s / 2; //our real index
            //we force blanks for even indexes
            auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS); //lbl[ind];
            //alpha(s)*betta(s) in log scale; could also be stored back into alpha to save memory
            auto alphaBettaS = alphaPtr[s] + bettaPtr[s];

            //sum (alpha(s)*betta(s)) over the extended positions that map to the same real label
            auto &currentGrad = element<isGradStrided>(gradPtr, currentInd, elwiseG); // gradPtr[currentInd];
            if (currentGrad == negInf)
            {
                currentGrad = alphaBettaS;
            }
            else
            {
                Type cMax = std::max(currentGrad, alphaBettaS);
                currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax;
            }
            //alphaPtr[s] = alphaBettaS;
        }

        PRAGMA_OMP_SIMD
        for (int k = 0; k < lenK; k++)
        {
            //compute the remaining gradient:
            //   grad(k) = prob(t,k) - grad(k) / (prob(t,k) * Z)
            //   p2      = grad(k) / (prob(t,k) * Z)
            //in log scale; log(Z) is available as -forwardLogLoss
            // auto p2 = std::exp(gradPtr[k] + forwardLogLoss - logP[k]);
            // gradPtr[k] = std::exp(logP[k]) - p2;
            auto currentProb = element<IsLogPStrided>(logP, k, elwiseP);
            auto &currentGrad = element<isGradStrided>(gradPtr, k, elwiseG);
            auto p2 = std::exp(currentGrad + forwardLogLoss - currentProb);
            currentGrad = std::exp(currentProb) - p2;
        }

        gradPtr += incG;
        bettaPtr += incA;
        alphaPtr += incA;
        logP += incP;
    }
#endif
}
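
// Gradient note: after the passes above, for every frame t and class k the gradient buffer holds the
// standard CTC result
//   grad(t,k) = y(t,k) - (1 / (y(t,k) * Z)) * sum_{s : label(s) = k} alpha_t(s) * betta_t(s),
// where y(t,k) = exp(logP(t,k)) and Z = exp(-forwardLogLoss) is the total path probability. The
// per-label sums are accumulated with log-sum-exp (with -infinity marking an empty accumulator), and
// only the final subtraction leaves log space.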

/**
 * Calculates the ctc loss and fills the gradients
 * @param logP pointer to the logits matrix (lenT, lenK) (log softmax of the network output)
 * @param incP stride of logits for the next time frame
 * @param gradPtr output buffer for the gradient; pass nullptr to skip the backward pass
 * @param incG stride of the gradient for the next time frame
 * @param lbl target label
 * @param lenT number of time frames
 * @param lenK number of classes
 * @param lenS target label length
 * @param blankIndex index of the blank label among the logit classes
 * @param elwiseP element-wise stride across classes within a logits frame
 * @param elwiseS element-wise stride between target label entries
 * @param elwiseG element-wise stride across classes within a gradient frame
 * @return ctc loss (negative log likelihood) for this single item
 */
template <bool IsLogPStrided = true, bool IsLblStrided = true, bool IsGradStrided = true, typename Type, typename IndexType>
Type unitLossAndGrad(const Type *logP, int incP, Type *gradPtr, int incG, const IndexType *lbl, int lenT, int lenK, int lenS, int blankIndex,
                     int elwiseP = 1, int elwiseS = 1, int elwiseG = 1)
{

    auto lenSB = 2 * lenS + 1;
    //create one temp array holding alphaArr [lenT, lenSB] and,
    //when gradients are requested, a second one for bettaArr [lenT, lenSB]
    int bufferC = gradPtr ? 2 : 1;
    NDArray bufferArr = NDArrayFactory::create<Type>('c', {bufferC, lenT, lenSB});
    auto bufferPtr = bufferArr.bufferAsT<Type>();
    auto incA = bufferArr.stridesOf()[1];
    auto bettaBufferPtr = bufferPtr + bufferArr.stridesOf()[0];
    Type negInf = -DataTypeUtils::infOrMax<Type>();

#if 1
    if (gradPtr)
    {
        if (elwiseG == 1)
        {
            PRAGMA_OMP_SIMD
            for (int i = 0; i < lenK * lenT; i++)
            {
                gradPtr[i] = negInf;
            }
        }
        else
        {
            auto tempPtr = gradPtr;
            for (int i = 0; i < lenT; i++)
            {
                for (int j = 0; j < lenK; j++)
                    element<true>(tempPtr, j, elwiseG) = negInf; // strided store: tempPtr[j * elwiseG]
                tempPtr += incG;
            }
        }
    }
#endif

    //set all alpha/betta values to negative infinity
    PRAGMA_OMP_SIMD
    for (int i = 0; i < bufferC * lenSB * lenT; i++)
    {
        bufferPtr[i] = negInf;
    }

    //forward
    Type logLoss = forward<IsLogPStrided, IsLblStrided>(bufferPtr, incA, logP, incP, lbl, lenSB, lenT, blankIndex, elwiseP, elwiseS);
    //backward and gradient if gradPtr is supplied
    if (gradPtr)
        backwardAndGrad<IsLogPStrided, IsLblStrided, IsGradStrided>(logLoss, bufferPtr, bettaBufferPtr, incA, logP, incP, gradPtr, incG, lbl, lenS, lenT, lenK, blankIndex, elwiseP, elwiseS, elwiseG);
    return logLoss;
}
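
// unitLossAndGrad allocates a temporary [bufferC, lenT, 2*lenS + 1] NDArray for the alpha (and, when
// gradients are requested, betta) tables, initialized to -infinity so that the initializations skipped
// inside forward()/backwardAndGrad() are well defined. A usage sketch for a single contiguous item
// (hypothetical buffers, mirroring the ews branch of ctc_loss_ below):
//
//   float loss = unitLossAndGrad<false, false, false, float, int>(
//       logP /*lenT x lenK log softmax*/, lenK /*incP*/, grad /*lenT x lenK*/, lenK /*incG*/,
//       lbl /*lenS labels*/, lenT, lenK, lenS, blankIndex);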

template <typename Type, typename IndexType>
void
ctc_loss_(const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex)
{
    // lenT - number of input time frames
    // lenS - length of the target sequence
    // lenSB - length with blanks
    auto lenBatch = logits.shapeOf()[0];

    auto maxLenT = logits.shapeOf()[1];
    auto lenK = logits.shapeOf()[2];
    auto maxLenS = targetLabels.shapeOf()[1];

    // get probability buffer and targetLabels buffer
    auto logP = logits.bufferAsT<Type>();
    auto lblPtr = targetLabels.bufferAsT<IndexType>();

    auto lenTPtr = logitsLengths.bufferAsT<IndexType>();
    auto lenSPtr = targetLabelLengths.bufferAsT<IndexType>();

    auto batchLbl = targetLabels.stridesOf()[0];
    auto batchP = logits.stridesOf()[0];
    auto incP = logits.stridesOf()[1];

    auto elwiseSLen = targetLabelLengths.stridesOf()[0];
    auto elwiseT = logitsLengths.stridesOf()[0];
    auto elwiseS = targetLabels.stridesOf()[1];
    auto elwiseP = logits.stridesOf()[2];

    int elwiseLL = 0;
    Type *logLossPtr = nullptr;
    if (!logLosses.isEmpty())
    {
        elwiseLL = logLosses.stridesOf()[0];
        logLossPtr = logLosses.bufferAsT<Type>();
    }

    auto func = [logP, batchP, incP, elwiseP, lenK, lenTPtr, lenSPtr, logLossPtr, lblPtr, maxLenT, maxLenS,
                 batchLbl, blankIndex, elwiseT, elwiseLL, elwiseSLen, elwiseS, &gradients](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
        Type *gradPtr = nullptr;
        Type resultLoss;
        int batchG, incG, elwiseG;
        if (!gradients.isEmpty())
        {
            batchG = gradients.stridesOf()[0];
            incG = gradients.stridesOf()[1];
            elwiseG = gradients.stridesOf()[2];
            gradPtr = gradients.bufferAsT<Type>() + start * batchG;
        }
        else
        {
            batchG = 0;
            incG = 0;
            elwiseG = 1;
        }
        auto logPtr = logP + start * batchP;
        auto tempLblPtr = lblPtr + start * batchLbl;

        if (elwiseP == 1 && elwiseS == 1 && elwiseG == 1)
        {
            //choose the ews (element-wise stride 1) version
            for (int batchIndex = start; batchIndex < stop; batchIndex += increment)
            {
                auto lenT = lenTPtr[batchIndex * elwiseT];
                auto lenS = lenSPtr[batchIndex * elwiseSLen];
                lenT = lenT > maxLenT ? maxLenT : lenT;
                lenS = lenS > maxLenS ? maxLenS : lenS;
                if (lenS <= 0 || lenT <= 0)
                {
                    resultLoss = -DataTypeUtils::infOrMax<Type>();
                }
                else
                {
                    if (lenS > lenT)
                        lenS = lenT;
                    resultLoss = unitLossAndGrad<false, false, false, Type, IndexType>(logPtr, incP, gradPtr, incG, tempLblPtr, lenT, lenK, lenS, blankIndex);
                }
                if (gradPtr) gradPtr += batchG;
                if (logLossPtr) logLossPtr[batchIndex * elwiseLL] = resultLoss;
                logPtr += batchP;
                tempLblPtr += batchLbl;
            }
        }
        else
        {
            //slower strided case for all three buffers
            for (int batchIndex = start; batchIndex < stop; batchIndex += increment)
            {
                auto lenT = lenTPtr[batchIndex * elwiseT];
                auto lenS = lenSPtr[batchIndex * elwiseSLen];
                lenT = lenT > maxLenT ? maxLenT : lenT;
                lenS = lenS > maxLenS ? maxLenS : lenS;
                if (lenS <= 0 || lenT <= 0)
                {
                    resultLoss = -DataTypeUtils::infOrMax<Type>();
                }
                else
                {
                    if (lenS > lenT)
                        lenS = lenT;
                    resultLoss = unitLossAndGrad<true, true, true, Type, IndexType>(logPtr, incP, gradPtr, incG, tempLblPtr, lenT, lenK, lenS, blankIndex, elwiseP, elwiseS, elwiseG);
                }
                if (gradPtr) gradPtr += batchG;
                if (logLossPtr) logLossPtr[batchIndex * elwiseLL] = resultLoss;
                logPtr += batchP;
                tempLblPtr += batchLbl;
            }
        }
    };
    samediff::Threads::parallel_for(func, 0, lenBatch, 1);
}

void ctcLoss(graph::Context &block, const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex)
{

    BUILD_DOUBLE_SELECTOR(logits.dataType(), targetLabels.dataType(), ctc_loss_, (logits, targetLabels, logitsLengths, targetLabelLengths, logLosses, gradients, blankIndex), FLOAT_TYPES, INDEXING_TYPES);
}
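
// Illustrative driver sketch (assumed shapes, not part of this file): logits are expected as
// log softmax values of shape [batch, maxT, classes], target labels as integer indices of shape
// [batch, maxS], and the two length arrays carry one entry per batch item. Something along these
// lines would exercise the helper:
//
//   auto logits    = NDArrayFactory::create<float>('c', {batch, maxT, classes});
//   auto labels    = NDArrayFactory::create<int>('c', {batch, maxS});
//   auto logitLens = NDArrayFactory::create<int>('c', {batch});
//   auto labelLens = NDArrayFactory::create<int>('c', {batch});
//   auto losses    = NDArrayFactory::create<float>('c', {batch});
//   auto grads     = NDArrayFactory::create<float>('c', {batch, maxT, classes});
//   ctcLoss(block, logits, labels, logitLens, labelLens, losses, grads, blankIndex);
//
// where block is the graph::Context supplied by the calling op.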

BUILD_DOUBLE_TEMPLATE(template void ctc_loss_, (const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex), FLOAT_TYPES, INDEXING_TYPES);

} // namespace helpers
} // namespace ops
} // namespace sd