Merge pull request #9222 from eclipse/ag_ctc_loss
Add ctc loss from KonduitAI PR, add missing java bitsmaster
commit
ea3e450941
|
@ -159,7 +159,7 @@
|
|||
<artifactId>maven-surefire-plugin</artifactId>
|
||||
<version>${maven-surefire-plugin.version}</version>
|
||||
<configuration>
|
||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> "</argLine>
|
||||
|
||||
<!--
|
||||
By default: Surefire will set the classpath based on the manifest. Because tests are not included
|
||||
|
|
|
@ -167,7 +167,7 @@
|
|||
|
||||
For testing large zoo models, this may not be enough (so comment it out).
|
||||
-->
|
||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> "</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
|
@ -230,7 +230,7 @@
|
|||
-->
|
||||
<useSystemClassLoader>true</useSystemClassLoader>
|
||||
<useManifestOnlyJar>false</useManifestOnlyJar>
|
||||
<argLine> -Dfile.encoding=UTF-8 -Xmx8g -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> -Dfile.encoding=UTF-8 -Xmx8g "</argLine>
|
||||
<includes>
|
||||
<!-- Default setting only runs tests that start/end with "Test" -->
|
||||
<include>*.java</include>
|
||||
|
@ -331,7 +331,7 @@
|
|||
|
||||
For testing large zoo models, this may not be enough (so comment it out).
|
||||
-->
|
||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> "</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2021 Deeplearning4j Contributors
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*******************************************************************************/
|
||||
|
||||
//
|
||||
// @author AbdelRauf
|
||||
//
|
||||
|
||||
|
||||
#include <system/op_boilerplate.h>
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <ops/declarable/helpers/ctcLoss.h>
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(ctc_loss, 4, 1, false, 0, 1) {
|
||||
|
||||
auto targetLabels = INPUT_VARIABLE(0);
|
||||
auto logitInput = INPUT_VARIABLE(1);
|
||||
auto targetLabelLengths = INPUT_VARIABLE(2);
|
||||
auto logitInputLengths = INPUT_VARIABLE(3);
|
||||
auto outputLosses = OUTPUT_VARIABLE(0);
|
||||
|
||||
int blankIndex = INT_ARG(0);
|
||||
|
||||
REQUIRE_TRUE(targetLabels->rankOf()==2, 0, "CtcLoss: target labels fails to meet rank requirement (batch_size, max_label_sequence_length): %i == 2 ", targetLabels->rankOf());
|
||||
REQUIRE_TRUE(logitInput->rankOf()==3, 0, "CtcLoss: logit Input fails to meet rank requirement (batch_size, frames, classes): %i == 3 ", logitInput->rankOf());
|
||||
REQUIRE_TRUE(targetLabelLengths->rankOf()==1, 0, "CtcLoss: target label length fails to meet rank requirement (batch_size): %i == 1 ", targetLabelLengths->rankOf());
|
||||
REQUIRE_TRUE(logitInputLengths->rankOf()==1, 0, "CtcLoss: logit Input lengths fails to meet rank requirement (batch_size): %i == 1 ", logitInputLengths->rankOf());
|
||||
|
||||
int batchSize0 = targetLabels->sizeAt(0);
|
||||
int batchSize1 = logitInput->sizeAt(0);
|
||||
int batchSize2 = targetLabelLengths->sizeAt(0);
|
||||
int batchSize3 = logitInputLengths->sizeAt(0);
|
||||
int batchSize4 = outputLosses->sizeAt(0);
|
||||
|
||||
bool check_batches = (batchSize0 == batchSize1) && (batchSize2 == batchSize3);
|
||||
check_batches = check_batches && (batchSize0 == batchSize4) && (batchSize0 == batchSize2);
|
||||
|
||||
REQUIRE_TRUE(check_batches, 0, "CtcLoss: All batch sizes should be equal %i", batchSize0);
|
||||
REQUIRE_TRUE(outputLosses->isSameShape(targetLabelLengths), 0, "CtcLoss: wrong shape of output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(targetLabelLengths).c_str(), ShapeUtils::shapeAsString(outputLosses).c_str());
|
||||
|
||||
auto emptyGradients = NDArrayFactory::empty<float>();
|
||||
sd::ops::helpers::ctcLoss(block, *logitInput, *targetLabels, *logitInputLengths, *targetLabelLengths, *outputLosses, emptyGradients, blankIndex);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_TYPES(ctc_loss) {
|
||||
getOpDescriptor()->setAllowedInputTypes({ALL_INDICES})
|
||||
->setAllowedInputTypes(1,{ALL_FLOATS})
|
||||
->setAllowedOutputTypes({ALL_FLOATS});
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_SHAPE_FN(ctc_loss) {
|
||||
auto yShapeInfo = inputShape->at(1);
|
||||
auto zShapeInfo = inputShape->at(2);
|
||||
|
||||
auto dtype = ArrayOptions::dataType(yShapeInfo);
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(zShapeInfo, dtype)));
|
||||
}
|
||||
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(ctc_loss_grad, 4, 1, false, 0, 1) {
|
||||
|
||||
auto targetLabels = INPUT_VARIABLE(0);
|
||||
auto logitInput = INPUT_VARIABLE(1);
|
||||
auto targetLabelLengths = INPUT_VARIABLE(2);
|
||||
auto logitInputLengths = INPUT_VARIABLE(3);
|
||||
auto outputGradients = OUTPUT_VARIABLE(0);
|
||||
|
||||
int blankIndex = INT_ARG(0);
|
||||
|
||||
REQUIRE_TRUE(targetLabels->rankOf()==2, 0, "CtcLoss: target labels fails to meet rank requirement (batch_size, max_label_sequence_length): %i == 2 ", targetLabels->rankOf());
|
||||
REQUIRE_TRUE(logitInput->rankOf()==3, 0, "CtcLoss: logit Input fails to meet rank requirement (batch_size, frames, classes): %i == 3 ", logitInput->rankOf());
|
||||
REQUIRE_TRUE(targetLabelLengths->rankOf()==1, 0, "CtcLoss: target label length fails to meet rank requirement (batch_size): %i == 1 ", targetLabelLengths->rankOf());
|
||||
REQUIRE_TRUE(logitInputLengths->rankOf()==1, 0, "CtcLoss: logit Input lengths fails to meet rank requirement (batch_size): %i == 1 ", logitInputLengths->rankOf());
|
||||
|
||||
int batchSize0 = targetLabels->sizeAt(0);
|
||||
int batchSize1 = logitInput->sizeAt(0);
|
||||
int batchSize2 = targetLabelLengths->sizeAt(0);
|
||||
int batchSize3 = logitInputLengths->sizeAt(0);
|
||||
int batchSize4 = outputGradients->sizeAt(0);
|
||||
|
||||
bool check_batches = (batchSize0 == batchSize1) && (batchSize2 == batchSize3);
|
||||
check_batches = check_batches && (batchSize0 == batchSize4) && (batchSize0 == batchSize2);
|
||||
|
||||
REQUIRE_TRUE(check_batches, 0, "CtcLoss Gradient: All batch sizes should be equal %i", batchSize0);
|
||||
REQUIRE_TRUE(outputGradients->isSameShape(logitInput), 0, "CtcLoss Gradient: wrong shape of output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(logitInput).c_str(), ShapeUtils::shapeAsString(outputGradients).c_str());
|
||||
|
||||
auto emptyLoss = NDArrayFactory::empty<float>();
|
||||
sd::ops::helpers::ctcLoss(block, *logitInput, *targetLabels, *logitInputLengths, *targetLabelLengths, emptyLoss, *outputGradients, blankIndex);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_TYPES(ctc_loss_grad) {
|
||||
getOpDescriptor()->setAllowedInputTypes({ALL_INDICES})
|
||||
->setAllowedInputTypes(1,{ALL_FLOATS})
|
||||
->setAllowedOutputTypes({ALL_FLOATS});
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_SHAPE_FN(ctc_loss_grad) {
|
||||
auto yShapeInfo = inputShape->at(1);
|
||||
auto dtype = ArrayOptions::dataType(yShapeInfo);
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(yShapeInfo, dtype)));
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
|
@ -360,6 +360,27 @@ namespace ops {
|
|||
DECLARE_CUSTOM_OP(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Implementation of CTC loss function
|
||||
*
|
||||
* Input arrays:
|
||||
* 0: labels - labels NDArray {BATCH_LEN, MAX_TARGET_LEN}, type integer
|
||||
* 1: logits - logits NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN }. log softmax of rnn output. It should include a blank label as well, type float
|
||||
* 2: targetLabelLengths - Length of label sequence in labels NDArray {BATCH_LEN}, type integer
|
||||
* 3: logitsLengths - Length of input sequence in logits NDArray {BATCH_LEN}, type integer
|
||||
*
|
||||
*
|
||||
* Input integer arguments:
|
||||
* 0: blank index - index of the blank label in logits
|
||||
*
|
||||
* Output array:
|
||||
* 0: loss values, type float. NDArray {BATCH_LEN} negative log probabilities of loss
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_ctc_loss)
|
||||
DECLARE_CUSTOM_OP(ctc_loss, 4, 1, false, 0, 1);
|
||||
DECLARE_CUSTOM_OP(ctc_loss_grad, 4, 1, false, 0, 1);
|
||||
#endif
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,507 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2021 Deeplearning4j Contributors
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*******************************************************************************/
|
||||
|
||||
//
|
||||
// @author AbdelRauf
|
||||
//
|
||||
|
||||
#include <type_traits>
|
||||
#include <cmath>
|
||||
#include <stdexcept>
|
||||
#include <memory>
|
||||
#include <execution/Threads.h>
|
||||
#include <execution/ThreadPool.h>
|
||||
#include <helpers/LoopsCoordsHelper.h>
|
||||
#include <ops/declarable/helpers/ctcLoss.h>
|
||||
|
||||
namespace sd
|
||||
{
|
||||
namespace ops
|
||||
{
|
||||
namespace helpers
|
||||
{
|
||||
|
||||
//choose ptr[index*element_stride]
|
||||
template <bool Strided, typename Type>
|
||||
typename std::enable_if<Strided == true, Type &>::type
|
||||
element(Type *ptr, int index, int element_stride)
|
||||
{
|
||||
return ptr[index * element_stride];
|
||||
}
|
||||
|
||||
//choose ptr[index] assuming element_stride is 1
|
||||
template <bool Strided, typename Type>
|
||||
typename std::enable_if<Strided == false, Type &>::type
|
||||
element(Type *ptr, int index, int element_stride)
|
||||
{
|
||||
return ptr[index];
|
||||
}
|
||||
|
||||
template <bool IsLogPStrided = false, bool IsLblStrided = false, typename Type, typename IndexType>
|
||||
Type forward(Type *alphaPtr, const Nd4jLong &incA, const Type *logP, const Nd4jLong &incP, const IndexType *lbl, const Nd4jLong &lenSB, const Nd4jLong &lenT, const int &blankIndex, int elwiseP = 1, int elwiseS = 1)
|
||||
{
|
||||
Type negInf = -DataTypeUtils::infOrMax<Type>();
|
||||
//initialize alphas at t=0
|
||||
alphaPtr[0] = element<IsLogPStrided>(logP, blankIndex, elwiseP);
|
||||
//alphaPtr[1] =logP[lbl[0]];
|
||||
alphaPtr[1] = element<IsLogPStrided>(logP, *lbl, elwiseP);
|
||||
//the rest initialization was skipped
|
||||
//as its assumed the array already were initialized with negative infinity
|
||||
//move to the next frame
|
||||
Type *alphaPrevPtr = alphaPtr;
|
||||
alphaPtr += incA;
|
||||
logP += incP;
|
||||
|
||||
auto startX = lenSB - 2 * lenT;
|
||||
//process the rest
|
||||
for (auto t = 1; t < lenT; t++)
|
||||
{
|
||||
|
||||
//start = max(0,L-2*(T-t))
|
||||
auto s = startX + 2 * t;
|
||||
s = s > 0 ? s : 0;
|
||||
for (; s < lenSB; s++)
|
||||
{
|
||||
auto ind = s / 2; //our real index
|
||||
//we force blanks for even indexes
|
||||
//strided version of lbl[ind] => element<IsLblStrided>(lbl, ind, elwiseS)
|
||||
auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS);
|
||||
// {t-1,s}
|
||||
Type alphaS = alphaPrevPtr[s];
|
||||
Type alphaS_1 = s > 0 ? alphaPrevPtr[s - 1] : negInf;
|
||||
Type cMax = std::max(alphaS, alphaS_1);
|
||||
//logP[currentInd] or logP[currentInd*elwiseP]
|
||||
auto currentProb = element<IsLogPStrided>(logP, currentInd, elwiseP);
|
||||
// if blank or the same as previous
|
||||
if (s > 1 && currentInd != blankIndex && currentInd != element<IsLblStrided>(lbl, ind - 1, elwiseS))
|
||||
{
|
||||
Type alphaS_2 = alphaPrevPtr[s - 2];
|
||||
cMax = std::max(cMax, alphaS_2);
|
||||
if (cMax == negInf)
|
||||
cMax = 0;
|
||||
alphaPtr[s] = std::log(std::exp(alphaS - cMax) + std::exp(alphaS_1 - cMax) + std::exp(alphaS_2 - cMax)) + cMax + currentProb;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (cMax == negInf)
|
||||
cMax = 0;
|
||||
alphaPtr[s] = std::log(std::exp(alphaS - cMax) + std::exp(alphaS_1 - cMax)) + cMax + currentProb;
|
||||
}
|
||||
}
|
||||
|
||||
//store t-1 alpha Ptr
|
||||
alphaPrevPtr = alphaPtr;
|
||||
logP += incP;
|
||||
alphaPtr += incA;
|
||||
}
|
||||
auto logP0 = alphaPrevPtr[lenSB - 1];
|
||||
auto logP1 = alphaPrevPtr[lenSB - 2];
|
||||
auto cMax = std::max(logP0, logP1);
|
||||
return -(std::log(std::exp(logP0 - cMax) + std::exp(logP1 - cMax)) + cMax);
|
||||
}
|
||||
|
||||
//#undef CALCULATE_ALL_IN_ONE_FRAME_LOOP
|
||||
|
||||
template <bool IsLogPStrided = false, bool IsLblStrided = false, bool isGradStrided = false, typename Type, typename IndexType = int>
|
||||
void backwardAndGrad(Type forwardLogLoss, Type *alphaPtr, Type *bettaPtr, int incA,const Type *logP, int incP, Type *gradPtr, int incG,const IndexType *lbl,
|
||||
const Nd4jLong &lenS, const Nd4jLong &lenT, const Nd4jLong &lenK, const int &blankIndex,
|
||||
int elwiseP = 1, int elwiseS = 1, int elwiseG = 1)
|
||||
{
|
||||
|
||||
Type negInf = -DataTypeUtils::infOrMax<Type>();
|
||||
Nd4jLong lenSB = 2 * lenS + 1;
|
||||
auto origBetta = bettaPtr;
|
||||
auto origLogP = logP;
|
||||
//move to the last frame
|
||||
bettaPtr += (lenT - 1) * incA;
|
||||
logP += (lenT - 1) * incP;
|
||||
|
||||
//initialize bettas at t=lenT
|
||||
bettaPtr[lenSB - 1] = element<IsLogPStrided>(logP, blankIndex, elwiseP);
|
||||
auto lblIndex = element<IsLblStrided>(lbl, lenS - 1, elwiseS);
|
||||
bettaPtr[lenSB - 2] = element<IsLogPStrided>(logP, lblIndex, elwiseP); // logP[lbl[lenS - 1]];
|
||||
|
||||
#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
|
||||
//move to the last
|
||||
gradPtr += (lenT - 1) * incG;
|
||||
alphaPtr += (lenT - 1) * incA;
|
||||
for (auto s = lenSB - 1; s >= 0; s--)
|
||||
{
|
||||
auto ind = s / 2; //our real index
|
||||
//we forced blanks for even indexes
|
||||
auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS);
|
||||
//alpha(s)*betta(s) in log scale but still store in alpha to save memory
|
||||
auto alphaBettaS = alphaPtr[s] + bettaPtr[s];
|
||||
|
||||
//sum (alpha(s)*betta(s) ) over real indexes
|
||||
auto ¤tGrad = element<isGradStrided>(gradPtr, currentInd, elwiseG); // gradPtr[currentInd];
|
||||
if (currentGrad == negInf)
|
||||
{
|
||||
currentGrad = alphaBettaS;
|
||||
}
|
||||
else
|
||||
{
|
||||
Type cMax = std::max(currentGrad, alphaBettaS);
|
||||
currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax;
|
||||
}
|
||||
}
|
||||
for (int k = 0; k < lenK; k++)
|
||||
{
|
||||
//compute the rest grad
|
||||
|
||||
// prob(t,k) - grad(k) / ((prob(t,k)*Z) )
|
||||
|
||||
// p2= grad(k) / (prob(t,k)*Z )
|
||||
//in logscale . plus we have Z as -logLoss
|
||||
// auto p2 = std::exp(gradPtr[k] + forwardLogLoss - logP[k]);
|
||||
// gradPtr[k] = std::exp(logP[k]) - p2;
|
||||
auto currentProb = element<IsLogPStrided>(logP, k, elwiseP);
|
||||
auto ¤tGrad = element<isGradStrided>(gradPtr, k, elwiseG);
|
||||
auto p2 = std::exp(currentGrad + forwardLogLoss - currentProb);
|
||||
currentGrad = std::exp(currentProb) - p2;
|
||||
}
|
||||
gradPtr -= incG;
|
||||
alphaPtr -= incA;
|
||||
#endif
|
||||
|
||||
auto bettaPrevPtr = bettaPtr;
|
||||
bettaPtr -= incA;
|
||||
logP -= incP;
|
||||
//process the rest
|
||||
for (auto t = lenT - 2; t >= 0; t--)
|
||||
{
|
||||
|
||||
#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
|
||||
auto end = lenSB - 1;
|
||||
#else
|
||||
auto end = std::min(2 * t + 2, lenSB - 1);
|
||||
#endif
|
||||
for (auto s = end; s >= 0; s--)
|
||||
{
|
||||
auto ind = s / 2; //our real index
|
||||
//we forced blanks for even indexes
|
||||
auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS); //lbl[ind];
|
||||
// {t-1,s}
|
||||
Type bettaS = bettaPrevPtr[s];
|
||||
Type bettaS_1 = s < lenSB - 1 ? bettaPrevPtr[s + 1] : negInf;
|
||||
Type cMax = std::max(bettaS, bettaS_1);
|
||||
//logP[currentInd]
|
||||
auto currentProb = element<IsLogPStrided>(logP, currentInd, elwiseP);
|
||||
// if blank or the same as previous
|
||||
if (s < lenSB - 2 && currentInd != blankIndex && currentInd != element<IsLblStrided>(lbl, ind + 1, elwiseS))
|
||||
{
|
||||
Type bettaS_2 = bettaPrevPtr[s + 2];
|
||||
cMax = std::max(cMax, bettaS_2);
|
||||
if (cMax == negInf)
|
||||
cMax = 0;
|
||||
bettaPtr[s] = std::log(std::exp(bettaS - cMax) + std::exp(bettaS_1 - cMax) + std::exp(bettaS_2 - cMax)) + cMax + currentProb;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (cMax == negInf)
|
||||
cMax = 0;
|
||||
bettaPtr[s] = std::log(std::exp(bettaS - cMax) + std::exp(bettaS_1 - cMax)) + cMax + currentProb;
|
||||
}
|
||||
|
||||
#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
|
||||
//alpha(s)*betta(s) in log scale but still store in alpha to save memory
|
||||
auto alphaBettaS = alphaPtr[s] + bettaPtr[s];
|
||||
|
||||
//sum (alpha(s)*betta(s) ) over real indexes
|
||||
auto ¤tGrad = element<isGradStrided>(gradPtr, currentInd, elwiseG); // gradPtr[currentInd];
|
||||
if (currentGrad == negInf)
|
||||
{
|
||||
currentGrad = alphaBettaS;
|
||||
}
|
||||
else
|
||||
{
|
||||
Type cMax = std::max(currentGrad, alphaBettaS);
|
||||
currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax;
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
|
||||
for (int k = 0; k < lenK; k++)
|
||||
{
|
||||
//compute the rest grad
|
||||
|
||||
// prob(t,k) - grad(k) / ((prob(t,k)*Z) )
|
||||
|
||||
// p2= grad(k) / (prob(t,k)*Z )
|
||||
//in logscale . plus we have Z as -logLoss
|
||||
// auto p2 = std::exp(gradPtr[k] + forwardLogLoss - logP[k]);
|
||||
// gradPtr[k] = std::exp(logP[k]) - p2;
|
||||
auto currentProb = element<IsLogPStrided>(logP, k, elwiseP);
|
||||
auto ¤tGrad = element<isGradStrided>(gradPtr, k, elwiseG);
|
||||
auto p2 = std::exp(currentGrad + forwardLogLoss - currentProb);
|
||||
currentGrad = std::exp(currentProb) - p2;
|
||||
}
|
||||
alphaPtr -= incA;
|
||||
gradPtr -= incG;
|
||||
#endif
|
||||
|
||||
bettaPrevPtr = bettaPtr;
|
||||
bettaPtr -= incA;
|
||||
logP -= incP;
|
||||
}
|
||||
|
||||
auto logBP0 = bettaPrevPtr[0];
|
||||
auto logBP1 = bettaPrevPtr[1];
|
||||
auto bcMax = std::max(logBP0, logBP1);
|
||||
auto blogLoss = -(std::log(std::exp(logBP0 - bcMax) + std::exp(logBP1 - bcMax)) + bcMax);
|
||||
|
||||
#if !defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
|
||||
//alpha*betta
|
||||
bettaPtr = origBetta;
|
||||
logP = origLogP;
|
||||
|
||||
for (int t = 0; t < lenT; t++)
|
||||
{
|
||||
|
||||
for (int s = 0; s < lenSB; s++)
|
||||
{
|
||||
auto ind = s / 2; //our real index
|
||||
//we forced blanks for even indexes
|
||||
auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS); //lbl[ind];
|
||||
//alpha(s)*betta(s) in log scale but still store in alpha to save memory
|
||||
auto alphaBettaS = alphaPtr[s] + bettaPtr[s];
|
||||
|
||||
//sum (alpha(s)*betta(s) ) over real indexes
|
||||
auto ¤tGrad = element<isGradStrided>(gradPtr, currentInd, elwiseG); // gradPtr[currentInd];
|
||||
if (currentGrad == negInf)
|
||||
{
|
||||
currentGrad = alphaBettaS;
|
||||
}
|
||||
else
|
||||
{
|
||||
Type cMax = std::max(currentGrad, alphaBettaS);
|
||||
currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax;
|
||||
}
|
||||
//alphaPtr[s] = alphaBettaS;
|
||||
}
|
||||
|
||||
PRAGMA_OMP_SIMD
|
||||
for (int k = 0; k < lenK; k++)
|
||||
{
|
||||
//compute the rest grad
|
||||
|
||||
// prob(t,k) - grad(k) / ((prob(t,k)*Z) )
|
||||
|
||||
// p2= grad(k) / (prob(t,k)*Z )
|
||||
//in logscale . plus we have Z as -logLoss
|
||||
// auto p2 = std::exp(gradPtr[k] + forwardLogLoss - logP[k]);
|
||||
// gradPtr[k] = std::exp(logP[k]) - p2;
|
||||
auto currentProb = element<IsLogPStrided>(logP, k, elwiseP);
|
||||
auto ¤tGrad = element<isGradStrided>(gradPtr, k, elwiseG);
|
||||
auto p2 = std::exp(currentGrad + forwardLogLoss - currentProb);
|
||||
currentGrad = std::exp(currentProb) - p2;
|
||||
}
|
||||
|
||||
gradPtr += incG;
|
||||
bettaPtr += incA;
|
||||
alphaPtr += incA;
|
||||
logP += incP;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates ctc loss and fills gradients
|
||||
* @param logP logits matrix(lenT,lenK) pointer (log soft max input of rnn)
|
||||
* @param incP stride of logits for the next time frame
|
||||
* @param gradPtr gradient for output
|
||||
* @param incG stride of the gradient for the next time frame
|
||||
* @param lbl target label
|
||||
* @param lenT frame length
|
||||
* @param lenK class length
|
||||
* @param lenS target label length
|
||||
* @param blankIndex index of the blank label in logit class
|
||||
*/
|
||||
template <bool IsLogPStrided = true, bool IsLblStrided = true, bool IsGradStrided = true, typename Type, typename IndexType>
|
||||
Type unitLossAndGrad(const Type *logP, int incP, Type *gradPtr, int incG,const IndexType *lbl, int lenT, int lenK, int lenS, int blankIndex,
|
||||
int elwiseP = 1, int elwiseS = 1, int elwiseG = 1)
|
||||
{
|
||||
|
||||
auto lenSB = 2 * lenS + 1;
|
||||
//create temp Array for holding bettaArr [lenT,lenSB]
|
||||
//create temp Array for holding alphaArr [lenT,lenSB]
|
||||
int bufferC = gradPtr ? 2 : 1;
|
||||
NDArray bufferArr = NDArrayFactory::create<Type>('c', {bufferC, lenT, lenSB});
|
||||
auto bufferPtr = bufferArr.bufferAsT<Type>();
|
||||
auto incA = bufferArr.stridesOf()[1];
|
||||
auto bettaBufferPtr = bufferPtr + bufferArr.stridesOf()[0];
|
||||
Type negInf = -DataTypeUtils::infOrMax<Type>();
|
||||
|
||||
#if 1
|
||||
if (gradPtr)
|
||||
{
|
||||
if (elwiseG == 1)
|
||||
{
|
||||
PRAGMA_OMP_SIMD
|
||||
for (int i = 0; i < lenK * lenT; i++)
|
||||
{
|
||||
gradPtr[i] = negInf;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
auto tempPtr = gradPtr;
|
||||
for (int i = 0; i < lenT; i++)
|
||||
{
|
||||
for (int j = 0; j < lenK; j++)
|
||||
element<false>(tempPtr, j, elwiseG) = negInf;
|
||||
tempPtr += incG;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// set all vals to neginf
|
||||
PRAGMA_OMP_SIMD
|
||||
for (int i = 0; i < bufferC * lenSB * lenT; i++)
|
||||
{
|
||||
bufferPtr[i] = negInf;
|
||||
}
|
||||
|
||||
//forward
|
||||
Type logLoss = forward<IsLogPStrided, IsLblStrided>(bufferPtr, incA, logP, incP, lbl, lenSB, lenT, blankIndex, elwiseP, elwiseS);
|
||||
//backward and gradient if gradptr supplied
|
||||
if (gradPtr)
|
||||
backwardAndGrad<IsLogPStrided, IsLblStrided, IsGradStrided>(logLoss, bufferPtr, bettaBufferPtr, incA, logP, incP, gradPtr, incG, lbl, lenS, lenT, lenK, blankIndex, elwiseP, elwiseS, elwiseG);
|
||||
return logLoss;
|
||||
}
|
||||
|
||||
template <typename Type, typename IndexType>
|
||||
void
|
||||
ctc_loss_(const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex)
|
||||
{
|
||||
// lenT - input length of T
|
||||
// lenS - lenght of sequence
|
||||
// lenSB - length with blanks
|
||||
auto lenBatch = logits.shapeOf()[0];
|
||||
|
||||
auto maxLenT = logits.shapeOf()[1];
|
||||
auto lenK = logits.shapeOf()[2];
|
||||
auto maxLenS = targetLabels.shapeOf()[1];
|
||||
|
||||
// get probability bufer and tagetLabels buffer
|
||||
auto logP = logits.bufferAsT<Type>();
|
||||
auto lblPtr = targetLabels.bufferAsT<IndexType>();
|
||||
|
||||
auto lenTPtr = logitsLengths.bufferAsT<IndexType>();
|
||||
auto lenSPtr = targetLabelLengths.bufferAsT<IndexType>();
|
||||
|
||||
auto batchLbl = targetLabels.stridesOf()[0];
|
||||
auto batchP = logits.stridesOf()[0];
|
||||
auto incP = logits.stridesOf()[1];
|
||||
|
||||
auto elwiseSLen = targetLabelLengths.stridesOf()[0];
|
||||
auto elwiseT = logitsLengths.stridesOf()[0];
|
||||
auto elwiseS = targetLabels.stridesOf()[1];
|
||||
auto elwiseP = logits.stridesOf()[2];
|
||||
|
||||
int elwiseLL = 0;
|
||||
Type *logLossPtr = nullptr;
|
||||
if (!logLosses.isEmpty()){
|
||||
elwiseLL = logLosses.stridesOf()[0];
|
||||
logLossPtr = logLosses.bufferAsT<Type>();
|
||||
}
|
||||
|
||||
auto func = [logP, batchP, incP, elwiseP, lenK, lenTPtr, lenSPtr, logLossPtr, lblPtr, maxLenT, maxLenS,
|
||||
batchLbl, blankIndex, elwiseT, elwiseLL, elwiseSLen, elwiseS, &gradients](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
|
||||
Type *gradPtr = nullptr;
|
||||
Type resultLoss;
|
||||
int batchG, incG, elwiseG;
|
||||
if (!gradients.isEmpty())
|
||||
{
|
||||
batchG = gradients.stridesOf()[0];
|
||||
incG = gradients.stridesOf()[1];
|
||||
elwiseG = gradients.stridesOf()[2];
|
||||
gradPtr = gradients.bufferAsT<Type>() + start * batchG;
|
||||
}else{
|
||||
elwiseG=1;
|
||||
}
|
||||
auto logPtr = logP + start * batchP;
|
||||
auto tempLblPtr = lblPtr + start * batchLbl;
|
||||
|
||||
if (elwiseP == 1 && elwiseS == 1 && elwiseG == 1)
|
||||
{
|
||||
//choose ews one
|
||||
for (int batchIndex = start; batchIndex < stop; batchIndex += increment)
|
||||
{
|
||||
auto lenT = lenTPtr[batchIndex * elwiseT];
|
||||
auto lenS = lenSPtr[batchIndex * elwiseSLen];
|
||||
lenT = lenT > maxLenT ? maxLenT : lenT;
|
||||
lenS = lenS > maxLenS ? maxLenS : lenS;
|
||||
if (lenS <= 0 || lenT <= 0)
|
||||
{
|
||||
resultLoss = -DataTypeUtils::infOrMax<Type>();
|
||||
}
|
||||
else
|
||||
{
|
||||
if (lenS > lenT)
|
||||
lenS = lenT;
|
||||
resultLoss = unitLossAndGrad<false, false, false, Type, IndexType>(logPtr, incP, gradPtr, incG, tempLblPtr, lenT, lenK, lenS, blankIndex);
|
||||
}
|
||||
if (gradPtr) gradPtr += batchG;
|
||||
if (logLossPtr) logLossPtr[batchIndex * elwiseLL] = resultLoss;
|
||||
logPtr += batchP;
|
||||
tempLblPtr += batchLbl;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
//slow strided case for all 3
|
||||
for (int batchIndex = start; batchIndex < stop; batchIndex += increment)
|
||||
{
|
||||
auto lenT = lenTPtr[batchIndex * elwiseT];
|
||||
auto lenS = lenSPtr[batchIndex * elwiseSLen];
|
||||
lenT = lenT > maxLenT ? maxLenT : lenT;
|
||||
lenS = lenS > maxLenS ? maxLenS : lenS;
|
||||
if (lenS <= 0 || lenT <= 0)
|
||||
{
|
||||
resultLoss = -DataTypeUtils::infOrMax<Type>();
|
||||
}
|
||||
else
|
||||
{
|
||||
if (lenS > lenT)
|
||||
lenS = lenT;
|
||||
resultLoss = unitLossAndGrad<true,true,true,Type, IndexType>(logPtr, incP, gradPtr, incG, tempLblPtr, lenT, lenK, lenS, blankIndex, elwiseP, elwiseS, elwiseG);
|
||||
}
|
||||
if (gradPtr) gradPtr += batchG;
|
||||
if (logLossPtr) logLossPtr[batchIndex * elwiseLL] = resultLoss;
|
||||
logPtr += batchP;
|
||||
tempLblPtr += batchLbl;
|
||||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(func, 0, lenBatch, 1);
|
||||
}
|
||||
|
||||
void ctcLoss(graph::Context& block, const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex){
|
||||
|
||||
BUILD_DOUBLE_SELECTOR(logits.dataType(), targetLabels.dataType(), ctc_loss_, (logits, targetLabels, logitsLengths, targetLabelLengths, logLosses, gradients, blankIndex), FLOAT_TYPES, INDEXING_TYPES);
|
||||
}
|
||||
|
||||
|
||||
BUILD_DOUBLE_TEMPLATE(template void ctc_loss_, (const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex), FLOAT_TYPES, INDEXING_TYPES);
|
||||
|
||||
|
||||
} // namespace helpers
|
||||
} // namespace ops
|
||||
} // namespace sd
|
|
@ -0,0 +1,55 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2021 Deeplearning4j Contributors
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*******************************************************************************/
|
||||
|
||||
//
|
||||
// @author AbdelRauf
|
||||
//
|
||||
|
||||
#ifndef LIBND4J_HELPERS_CTCLOSS_H
|
||||
#define LIBND4J_HELPERS_CTCLOSS_H
|
||||
|
||||
#include <ops/declarable/helpers/helpers.h>
|
||||
#include <graph/Context.h>
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
/**
|
||||
* @brief Implementation of CTC loss function
|
||||
* References:
|
||||
Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
|
||||
with Recurrent Neural Networks:
|
||||
[Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
|
||||
([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
|
||||
*
|
||||
* @param block Context
|
||||
* @param logits NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN }. log softmax of rnn output. It should include a blank label as well.
|
||||
* @param targetLabels NDArray {BATCH_LEN, MAX_TARGET_LEN}
|
||||
* @param logitsLengths NDArray {BATCH_LEN} Length of input sequence in logits
|
||||
* @param targetLabelLengths NDArray {BATCH_LEN} Length of label sequence in labels
|
||||
* @param logLosses NDArray {BATCH_LEN} or EMPTY. if empty it will be skipped. negative log probabilities of loss
|
||||
* @param gradients NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN } or EMPTY. gradients
|
||||
* @param blankIndex index of the blank label in logits
|
||||
*/
|
||||
void ctcLoss(graph::Context& block, const NDArray &logitsInput, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#endif // LIBND4J_ADDBIAS_H
|
|
@ -0,0 +1,225 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2021 Deeplearning4j Contributors
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*******************************************************************************/
|
||||
//
|
||||
// @author AbdelRauf
|
||||
//
|
||||
|
||||
#include "cudnnUtils.h"
|
||||
#include <array/NDArrayFactory.h>
|
||||
#include <vector>
|
||||
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
|
||||
|
||||
|
||||
template<typename Op, typename ...Args>
|
||||
void callCudnnIfNoErr(cudnnStatus_t &err, Op op, Args&&... args){
|
||||
if(err==CUDNN_STATUS_SUCCESS){
|
||||
err = op(std::forward<Args>(args)...);
|
||||
if(err){
|
||||
nd4j_printf("Cudnn error code %s\n",cudnnGetErrorString(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T* bufferInHost( const NDArray &array) {
|
||||
array.syncToHost();
|
||||
return reinterpret_cast<const T*>(array.buffer());
|
||||
}
|
||||
|
||||
std::vector<int> getConcatTargets(const NDArray &targetLabels, const NDArray &targetLabelLengths){
|
||||
//concatenate target labels
|
||||
const int32_t *tlabels = bufferInHost<int32_t>(targetLabels);
|
||||
const int32_t *tlens =bufferInHost<int32_t>(targetLabelLengths);
|
||||
int32_t nextOffset = targetLabels.strideAt(0);
|
||||
int32_t elStride = targetLabels.strideAt(1);
|
||||
int32_t batchCount = targetLabelLengths.lengthOf();
|
||||
std::vector<int> labels;
|
||||
labels.resize(targetLabels.lengthOf());
|
||||
int j=0;
|
||||
if(targetLabels.ews()){
|
||||
for(int i=0; i<batchCount;i++){
|
||||
int count = tlens[i];
|
||||
for( int k=0;k<count;k++){
|
||||
labels[j] = tlabels[k];
|
||||
j++;
|
||||
}
|
||||
tlabels+=nextOffset;
|
||||
}
|
||||
}else{
|
||||
for(int i=0; i<batchCount;i++){
|
||||
int count = tlens[i];
|
||||
for( int k=0;k<count;k++){
|
||||
labels[j] = tlabels[k*elStride];
|
||||
j++;
|
||||
}
|
||||
tlabels+=nextOffset;
|
||||
}
|
||||
}
|
||||
return labels;
|
||||
}
|
||||
|
||||
cudnnStatus_t cudnnCtcLoss(const LaunchContext &context, const NDArray &probs, const int32_t* targetLabelsPtr, const NDArray& probInputLengthes,
|
||||
const NDArray &targetLabelLengths, NDArray &ctcLosses, NDArray &grads){
|
||||
const int dims[] = {(int)probs.sizeAt(0), (int)probs.sizeAt(1), (int)probs.sizeAt(2)};
|
||||
const int strides[] = {(int)probs.strideAt(0), (int)probs.strideAt(1), (int)probs.strideAt(2)};
|
||||
auto handle = reinterpret_cast<cudnnHandle_t *>(context.getCuDnnHandle());
|
||||
cudnnStatus_t err = CUDNN_STATUS_SUCCESS;
|
||||
callCudnnIfNoErr(err, cudnnSetStream, *handle, *context.getCudaStream());
|
||||
|
||||
cudnnCTCLossDescriptor_t ctcLossDesc;
|
||||
cudnnTensorDescriptor_t probsDesc = nullptr;
|
||||
cudnnTensorDescriptor_t gradsDesc = nullptr;
|
||||
callCudnnIfNoErr(err, cudnnCreateCTCLossDescriptor, &ctcLossDesc);
|
||||
callCudnnIfNoErr(err, cudnnSetCTCLossDescriptorEx, ctcLossDesc, CUDNN_DATA_FLOAT, CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_PROPAGATE_NAN);
|
||||
callCudnnIfNoErr(err, cudnnCreateTensorDescriptor, &probsDesc);
|
||||
callCudnnIfNoErr(err, cudnnSetTensorNdDescriptor, probsDesc, cudnnDataType(probs.dataType()), probs.rankOf() , dims, strides);
|
||||
if(!grads.isEmpty()){
|
||||
const int gradStrides[] = {(int)grads.strideAt(0), (int)grads.strideAt(1), (int)grads.strideAt(2)};
|
||||
callCudnnIfNoErr(err, cudnnCreateTensorDescriptor, &gradsDesc);
|
||||
callCudnnIfNoErr(err, cudnnSetTensorNdDescriptor, gradsDesc, cudnnDataType(grads.dataType()), grads.rankOf() , dims, gradStrides);
|
||||
}
|
||||
|
||||
size_t tempWorkSpaceSize=0;
|
||||
callCudnnIfNoErr(err,cudnnGetCTCLossWorkspaceSize, *handle, probsDesc, gradsDesc,
|
||||
targetLabelsPtr,
|
||||
bufferInHost<int32_t>(targetLabelLengths),
|
||||
bufferInHost<int32_t>(probInputLengthes),
|
||||
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
|
||||
ctcLossDesc, &tempWorkSpaceSize);
|
||||
|
||||
// Allocate temp tempWorkspace buffer
|
||||
void *tempWorkSpace = nullptr;
|
||||
cudaMalloc(&tempWorkSpace, tempWorkSpaceSize);
|
||||
|
||||
NDArray::prepareSpecialUse({&ctcLosses, &grads}, {&probs});
|
||||
callCudnnIfNoErr(err, cudnnCTCLoss,*handle,
|
||||
probsDesc,
|
||||
probs.specialBuffer(),
|
||||
targetLabelsPtr,
|
||||
bufferInHost<int32_t>(targetLabelLengths),
|
||||
bufferInHost<int32_t>(probInputLengthes),
|
||||
ctcLosses.specialBuffer(),
|
||||
gradsDesc,
|
||||
grads.specialBuffer(),
|
||||
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
|
||||
ctcLossDesc,
|
||||
tempWorkSpace,
|
||||
tempWorkSpaceSize);
|
||||
|
||||
NDArray::registerSpecialUse({&ctcLosses, &grads}, {&probs});
|
||||
|
||||
cudaFree(tempWorkSpace);
|
||||
callCudnnIfNoErr(err, cudnnDestroyTensorDescriptor,probsDesc);
|
||||
if(gradsDesc) callCudnnIfNoErr(err, cudnnDestroyTensorDescriptor,gradsDesc);
|
||||
callCudnnIfNoErr(err, cudnnDestroyCTCLossDescriptor,ctcLossDesc);
|
||||
return err;
|
||||
}
|
||||
|
||||
PLATFORM_IMPL(ctc_loss, ENGINE_CUDA) {
|
||||
auto targetLabels = INPUT_VARIABLE(0);
|
||||
auto logitInput = INPUT_VARIABLE(1);
|
||||
auto targetLabelLengths = INPUT_VARIABLE(2);
|
||||
auto logitInputLengths = INPUT_VARIABLE(3);
|
||||
auto outputLosses = OUTPUT_VARIABLE(0);
|
||||
auto context = block.launchContext();
|
||||
//in Cudnn Batch is in the middle dimension
|
||||
logitInput->permutei({1,0,2});
|
||||
//in Cudnn targets are concantenated instead of batched as matrix
|
||||
auto labels = getConcatTargets(*targetLabels, *targetLabelLengths);
|
||||
const int32_t *ldata= labels.data();
|
||||
auto emptyGrads= NDArrayFactory::empty<float>();
|
||||
auto err = cudnnCtcLoss(*context, *logitInput, ldata, *logitInputLengths, *targetLabelLengths, *outputLosses, emptyGrads);
|
||||
if(err!=CUDNN_STATUS_SUCCESS) throw sd::cuda_exception::build("ctc_loss CUDNN call failure ", err);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
bool checkLabelLength(const NDArray &labelLengthArr){
|
||||
//check label lengthes
|
||||
auto lenBatch = labelLengthArr.lengthOf();
|
||||
for(int i=0; i < lenBatch; i++){
|
||||
// The labelLengths is greater than 256.
|
||||
if(labelLengthArr.e<int32_t>(i)>256) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(ctc_loss, ENGINE_CUDA) {
|
||||
auto targetLabels = INPUT_VARIABLE(0);
|
||||
auto logitInput = INPUT_VARIABLE(1);
|
||||
auto targetLabelLengths = INPUT_VARIABLE(2);
|
||||
auto logitInputLengths = INPUT_VARIABLE(3);
|
||||
auto outputLosses = OUTPUT_VARIABLE(0);
|
||||
int blankIndex = INT_ARG(0);
|
||||
|
||||
auto dTypeInput = logitInput->dataType();
|
||||
auto intType = targetLabelLengths->dataType();
|
||||
auto dTypeOutput = outputLosses->dataType();
|
||||
|
||||
bool is_supported = blankIndex==0 && intType == DataType::INT32 && dTypeInput == DataType::FLOAT32;
|
||||
is_supported = is_supported && outputLosses->ews() && targetLabelLengths->ews() && targetLabels->ews() && logitInputLengths->ews();
|
||||
is_supported = is_supported && checkLabelLength<int32_t>(*targetLabelLengths);
|
||||
return is_supported;
|
||||
}
|
||||
|
||||
PLATFORM_IMPL(ctc_loss_grad, ENGINE_CUDA) {
|
||||
auto targetLabels = INPUT_VARIABLE(0);
|
||||
auto logitInput = INPUT_VARIABLE(1);
|
||||
auto targetLabelLengths = INPUT_VARIABLE(2);
|
||||
auto logitInputLengths = INPUT_VARIABLE(3);
|
||||
auto outputGradients = OUTPUT_VARIABLE(0);
|
||||
auto context = block.launchContext();
|
||||
//in Cudnn Batch is in the middle dimension
|
||||
logitInput->permutei({1,0,2});
|
||||
outputGradients->permutei({1,0,2});
|
||||
//in Cudnn targets are concantenated instead of batched as matrix
|
||||
auto labels = getConcatTargets(*targetLabels, *targetLabelLengths);
|
||||
const int32_t * ldata= labels.data();
|
||||
auto tempLosses = NDArrayFactory::create<float>('c', {logitInputLengths->sizeAt(0)});
|
||||
auto err = cudnnCtcLoss(*context, *logitInput, ldata, *logitInputLengths, *targetLabelLengths, tempLosses, *outputGradients);
|
||||
if(err!=CUDNN_STATUS_SUCCESS) throw sd::cuda_exception::build("ctc_loss CUDNN call failure ", err);
|
||||
//restore grads shape from {T, BATCH, C} -> {BATCHS, T, C}
|
||||
outputGradients->permutei({1,0,2});
|
||||
//tempLosses.printIndexedBuffer("tempLosses");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(ctc_loss_grad, ENGINE_CUDA) {
|
||||
auto targetLabels = INPUT_VARIABLE(0);
|
||||
auto logitInput = INPUT_VARIABLE(1);
|
||||
auto targetLabelLengths = INPUT_VARIABLE(2);
|
||||
auto logitInputLengths = INPUT_VARIABLE(3);
|
||||
auto outputGrads = OUTPUT_VARIABLE(0);
|
||||
int blankIndex = INT_ARG(0);
|
||||
|
||||
auto dTypeInput = logitInput->dataType();
|
||||
auto intType = targetLabelLengths->dataType();
|
||||
auto dTypeOutput = outputGrads->dataType();
|
||||
|
||||
bool is_supported = blankIndex==0 && intType == DataType::INT32 && dTypeInput == DataType::FLOAT32;
|
||||
is_supported = is_supported && outputGrads->ews() && targetLabelLengths->ews() && targetLabels->ews() && logitInputLengths->ews();
|
||||
is_supported = is_supported && checkLabelLength<int32_t>(*targetLabelLengths);
|
||||
return is_supported;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -60,6 +60,9 @@ namespace platforms {
|
|||
DECLARE_PLATFORM(maxpool3dnew, ENGINE_CUDA);
|
||||
DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CUDA);
|
||||
|
||||
DECLARE_PLATFORM(ctc_loss, ENGINE_CUDA);
|
||||
DECLARE_PLATFORM(ctc_loss_grad, ENGINE_CUDA);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
FORCEINLINE cudnnDataType_t cudnnDataType(sd::DataType dataType) {
|
||||
switch (dataType) {
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -5411,12 +5411,14 @@ public final class TensorNamespace {
|
|||
* Serializations can either use one of the fields above, or use this
|
||||
* raw bytes field. The only exception is the string case, where one is
|
||||
* required to store the content in the repeated bytes string_data field.
|
||||
*
|
||||
* When this raw_data field is used to store tensor value, elements MUST
|
||||
* be stored in as fixed-width, little-endian order.
|
||||
* Floating-point data types MUST be stored in IEEE 754 format.
|
||||
* Complex64 elements must be written as two consecutive FLOAT values, real component first.
|
||||
* Complex128 elements must be written as two consecutive DOUBLE values, real component first.
|
||||
* Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
|
||||
*
|
||||
* Note: the advantage of specific field rather than the raw_data field is
|
||||
* that in some cases (e.g. int data), protobuf does a better packing via
|
||||
* variable length storage, and may lead to smaller binary footprint.
|
||||
|
@ -5655,6 +5657,7 @@ public final class TensorNamespace {
|
|||
/**
|
||||
* <pre>
|
||||
* Tensors
|
||||
*
|
||||
* A serialized tensor value.
|
||||
* </pre>
|
||||
*
|
||||
|
@ -7010,12 +7013,14 @@ public final class TensorNamespace {
|
|||
* Serializations can either use one of the fields above, or use this
|
||||
* raw bytes field. The only exception is the string case, where one is
|
||||
* required to store the content in the repeated bytes string_data field.
|
||||
*
|
||||
* When this raw_data field is used to store tensor value, elements MUST
|
||||
* be stored in as fixed-width, little-endian order.
|
||||
* Floating-point data types MUST be stored in IEEE 754 format.
|
||||
* Complex64 elements must be written as two consecutive FLOAT values, real component first.
|
||||
* Complex128 elements must be written as two consecutive DOUBLE values, real component first.
|
||||
* Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
|
||||
*
|
||||
* Note: the advantage of specific field rather than the raw_data field is
|
||||
* that in some cases (e.g. int data), protobuf does a better packing via
|
||||
* variable length storage, and may lead to smaller binary footprint.
|
||||
|
@ -7766,6 +7771,7 @@ public final class TensorNamespace {
|
|||
/**
|
||||
* <pre>
|
||||
* Tensors
|
||||
*
|
||||
* A serialized tensor value.
|
||||
* </pre>
|
||||
*
|
||||
|
@ -9080,12 +9086,14 @@ public final class TensorNamespace {
|
|||
* Serializations can either use one of the fields above, or use this
|
||||
* raw bytes field. The only exception is the string case, where one is
|
||||
* required to store the content in the repeated bytes string_data field.
|
||||
*
|
||||
* When this raw_data field is used to store tensor value, elements MUST
|
||||
* be stored in as fixed-width, little-endian order.
|
||||
* Floating-point data types MUST be stored in IEEE 754 format.
|
||||
* Complex64 elements must be written as two consecutive FLOAT values, real component first.
|
||||
* Complex128 elements must be written as two consecutive DOUBLE values, real component first.
|
||||
* Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
|
||||
*
|
||||
* Note: the advantage of specific field rather than the raw_data field is
|
||||
* that in some cases (e.g. int data), protobuf does a better packing via
|
||||
* variable length storage, and may lead to smaller binary footprint.
|
||||
|
@ -9102,12 +9110,14 @@ public final class TensorNamespace {
|
|||
* Serializations can either use one of the fields above, or use this
|
||||
* raw bytes field. The only exception is the string case, where one is
|
||||
* required to store the content in the repeated bytes string_data field.
|
||||
*
|
||||
* When this raw_data field is used to store tensor value, elements MUST
|
||||
* be stored in as fixed-width, little-endian order.
|
||||
* Floating-point data types MUST be stored in IEEE 754 format.
|
||||
* Complex64 elements must be written as two consecutive FLOAT values, real component first.
|
||||
* Complex128 elements must be written as two consecutive DOUBLE values, real component first.
|
||||
* Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
|
||||
*
|
||||
* Note: the advantage of specific field rather than the raw_data field is
|
||||
* that in some cases (e.g. int data), protobuf does a better packing via
|
||||
* variable length storage, and may lead to smaller binary footprint.
|
||||
|
@ -9130,12 +9140,14 @@ public final class TensorNamespace {
|
|||
* Serializations can either use one of the fields above, or use this
|
||||
* raw bytes field. The only exception is the string case, where one is
|
||||
* required to store the content in the repeated bytes string_data field.
|
||||
*
|
||||
* When this raw_data field is used to store tensor value, elements MUST
|
||||
* be stored in as fixed-width, little-endian order.
|
||||
* Floating-point data types MUST be stored in IEEE 754 format.
|
||||
* Complex64 elements must be written as two consecutive FLOAT values, real component first.
|
||||
* Complex128 elements must be written as two consecutive DOUBLE values, real component first.
|
||||
* Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
|
||||
*
|
||||
* Note: the advantage of specific field rather than the raw_data field is
|
||||
* that in some cases (e.g. int data), protobuf does a better packing via
|
||||
* variable length storage, and may lead to smaller binary footprint.
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.loss;
|
||||
|
||||
import org.nd4j.autodiff.loss.LossReduce;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class CtcLoss extends BaseLoss {
|
||||
|
||||
|
||||
public CtcLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels){
|
||||
super(sameDiff, lossReduce, predictions, weights, labels);
|
||||
}
|
||||
|
||||
public CtcLoss(SameDiff sameDiff, SDVariable label, SDVariable predictions, SDVariable weights,
|
||||
LossReduce lossReduce) {
|
||||
this(sameDiff, lossReduce, predictions, weights, label);
|
||||
}
|
||||
|
||||
public CtcLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
|
||||
super(lossReduce, predictions, weights, labels);
|
||||
}
|
||||
|
||||
public CtcLoss(){ }
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "ctc_loss";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> grad){
|
||||
//No external gradient
|
||||
//Args are: predictions, weights, label
|
||||
return new AbsoluteDifferenceLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.loss.bp;
|
||||
|
||||
import org.nd4j.autodiff.loss.LossReduce;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class CtcLossBp extends BaseLossBp {
|
||||
|
||||
|
||||
public CtcLossBp(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels){
|
||||
super(sameDiff, lossReduce, predictions, weights, labels);
|
||||
}
|
||||
|
||||
public CtcLossBp(){ }
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "ctc_loss_grad";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> grad){
|
||||
throw new UnsupportedOperationException("Differentiation of " + getClass().getName() + " not supported");
|
||||
}
|
||||
|
||||
}
|
|
@ -391,7 +391,7 @@
|
|||
|
||||
For testing large zoo models, this may not be enough (so comment it out).
|
||||
-->
|
||||
<argLine>-Dfile.encoding=UTF-8 -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine>-Dfile.encoding=UTF-8 "</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
|
@ -107,7 +107,7 @@
|
|||
<include>*.java</include>
|
||||
<include>**/*.java</include>
|
||||
</includes>
|
||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> "</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
|
@ -99,7 +99,7 @@
|
|||
<include>*.java</include>
|
||||
<include>**/*.java</include>
|
||||
</includes>
|
||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> "</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
|
@ -125,7 +125,7 @@
|
|||
|
||||
For testing large zoo models, this may not be enough (so comment it out).
|
||||
-->
|
||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> "</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
|
@ -103,7 +103,7 @@
|
|||
|
||||
For testing large zoo models, this may not be enough (so comment it out).
|
||||
-->
|
||||
<argLine> -Dfile.encoding=UTF-8 -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> -Dfile.encoding=UTF-8 "</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
|
@ -159,7 +159,7 @@
|
|||
|
||||
For testing large zoo models, this may not be enough (so comment it out).
|
||||
-->
|
||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> "</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
|
@ -114,7 +114,7 @@
|
|||
|
||||
For testing large zoo models, this may not be enough (so comment it out).
|
||||
-->
|
||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> "</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
Loading…
Reference in New Issue