Add ctc loss from KonduitAI PR, add missing java bits

master
agibsonccc 2021-03-11 14:22:34 +09:00
parent b7e433a22a
commit c3f04caef4
20 changed files with 1424 additions and 155 deletions

View File

@ -159,7 +159,7 @@
<artifactId>maven-surefire-plugin</artifactId>
<version>${maven-surefire-plugin.version}</version>
<configuration>
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
<argLine> "</argLine>
<!--
By default: Surefire will set the classpath based on the manifest. Because tests are not included

View File

@ -167,7 +167,7 @@
For testing large zoo models, this may not be enough (so comment it out).
-->
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
<argLine> "</argLine>
</configuration>
</plugin>
</plugins>

View File

@ -230,7 +230,7 @@
-->
<useSystemClassLoader>true</useSystemClassLoader>
<useManifestOnlyJar>false</useManifestOnlyJar>
<argLine> -Dfile.encoding=UTF-8 -Xmx8g -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
<argLine> -Dfile.encoding=UTF-8 -Xmx8g "</argLine>
<includes>
<!-- Default setting only runs tests that start/end with "Test" -->
<include>*.java</include>
@ -331,7 +331,7 @@
For testing large zoo models, this may not be enough (so comment it out).
-->
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
<argLine> "</argLine>
</configuration>
</plugin>
</plugins>

View File

@ -0,0 +1,134 @@
/*******************************************************************************
* Copyright (c) 2021 Deeplearning4j Contributors
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*******************************************************************************/
//
// @author AbdelRauf
//
#include <system/op_boilerplate.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/ctcLoss.h>
namespace sd {
namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(ctc_loss, 4, 1, false, 0, 1) {
auto targetLabels = INPUT_VARIABLE(0);
auto logitInput = INPUT_VARIABLE(1);
auto targetLabelLengths = INPUT_VARIABLE(2);
auto logitInputLengths = INPUT_VARIABLE(3);
auto outputLosses = OUTPUT_VARIABLE(0);
int blankIndex = INT_ARG(0);
REQUIRE_TRUE(targetLabels->rankOf()==2, 0, "CtcLoss: target labels fails to meet rank requirement (batch_size, max_label_sequence_length): %i == 2 ", targetLabels->rankOf());
REQUIRE_TRUE(logitInput->rankOf()==3, 0, "CtcLoss: logit Input fails to meet rank requirement (batch_size, frames, classes): %i == 3 ", logitInput->rankOf());
REQUIRE_TRUE(targetLabelLengths->rankOf()==1, 0, "CtcLoss: target label length fails to meet rank requirement (batch_size): %i == 1 ", targetLabelLengths->rankOf());
REQUIRE_TRUE(logitInputLengths->rankOf()==1, 0, "CtcLoss: logit Input lengths fails to meet rank requirement (batch_size): %i == 1 ", logitInputLengths->rankOf());
int batchSize0 = targetLabels->sizeAt(0);
int batchSize1 = logitInput->sizeAt(0);
int batchSize2 = targetLabelLengths->sizeAt(0);
int batchSize3 = logitInputLengths->sizeAt(0);
int batchSize4 = outputLosses->sizeAt(0);
bool check_batches = (batchSize0 == batchSize1) && (batchSize2 == batchSize3);
check_batches = check_batches && (batchSize0 == batchSize4) && (batchSize0 == batchSize2);
REQUIRE_TRUE(check_batches, 0, "CtcLoss: All batch sizes should be equal %i", batchSize0);
REQUIRE_TRUE(outputLosses->isSameShape(targetLabelLengths), 0, "CtcLoss: wrong shape of output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(targetLabelLengths).c_str(), ShapeUtils::shapeAsString(outputLosses).c_str());
auto emptyGradients = NDArrayFactory::empty<float>();
sd::ops::helpers::ctcLoss(block, *logitInput, *targetLabels, *logitInputLengths, *targetLabelLengths, *outputLosses, emptyGradients, blankIndex);
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(ctc_loss) {
getOpDescriptor()->setAllowedInputTypes({ALL_INDICES})
->setAllowedInputTypes(1,{ALL_FLOATS})
->setAllowedOutputTypes({ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(ctc_loss) {
auto yShapeInfo = inputShape->at(1);
auto zShapeInfo = inputShape->at(2);
auto dtype = ArrayOptions::dataType(yShapeInfo);
return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(zShapeInfo, dtype)));
}
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(ctc_loss_grad, 4, 1, false, 0, 1) {
auto targetLabels = INPUT_VARIABLE(0);
auto logitInput = INPUT_VARIABLE(1);
auto targetLabelLengths = INPUT_VARIABLE(2);
auto logitInputLengths = INPUT_VARIABLE(3);
auto outputGradients = OUTPUT_VARIABLE(0);
int blankIndex = INT_ARG(0);
REQUIRE_TRUE(targetLabels->rankOf()==2, 0, "CtcLoss: target labels fails to meet rank requirement (batch_size, max_label_sequence_length): %i == 2 ", targetLabels->rankOf());
REQUIRE_TRUE(logitInput->rankOf()==3, 0, "CtcLoss: logit Input fails to meet rank requirement (batch_size, frames, classes): %i == 3 ", logitInput->rankOf());
REQUIRE_TRUE(targetLabelLengths->rankOf()==1, 0, "CtcLoss: target label length fails to meet rank requirement (batch_size): %i == 1 ", targetLabelLengths->rankOf());
REQUIRE_TRUE(logitInputLengths->rankOf()==1, 0, "CtcLoss: logit Input lengths fails to meet rank requirement (batch_size): %i == 1 ", logitInputLengths->rankOf());
int batchSize0 = targetLabels->sizeAt(0);
int batchSize1 = logitInput->sizeAt(0);
int batchSize2 = targetLabelLengths->sizeAt(0);
int batchSize3 = logitInputLengths->sizeAt(0);
int batchSize4 = outputGradients->sizeAt(0);
bool check_batches = (batchSize0 == batchSize1) && (batchSize2 == batchSize3);
check_batches = check_batches && (batchSize0 == batchSize4) && (batchSize0 == batchSize2);
REQUIRE_TRUE(check_batches, 0, "CtcLoss Gradient: All batch sizes should be equal %i", batchSize0);
REQUIRE_TRUE(outputGradients->isSameShape(logitInput), 0, "CtcLoss Gradient: wrong shape of output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(logitInput).c_str(), ShapeUtils::shapeAsString(outputGradients).c_str());
auto emptyLoss = NDArrayFactory::empty<float>();
sd::ops::helpers::ctcLoss(block, *logitInput, *targetLabels, *logitInputLengths, *targetLabelLengths, emptyLoss, *outputGradients, blankIndex);
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(ctc_loss_grad) {
getOpDescriptor()->setAllowedInputTypes({ALL_INDICES})
->setAllowedInputTypes(1,{ALL_FLOATS})
->setAllowedOutputTypes({ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(ctc_loss_grad) {
auto yShapeInfo = inputShape->at(1);
auto dtype = ArrayOptions::dataType(yShapeInfo);
return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(yShapeInfo, dtype)));
}
}
}

View File

@ -360,6 +360,27 @@ namespace ops {
DECLARE_CUSTOM_OP(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false, 0, 0);
#endif
/**
* Implementation of CTC loss function
*
* Input arrays:
* 0: labels - labels NDArray {BATCH_LEN, MAX_TARGET_LEN}, type integer
* 1: logits - logits NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN}, type float. Log softmax of the RNN output; it should include a blank label as well
* 2: targetLabelLengths - Length of label sequence in labels NDArray {BATCH_LEN}, type integer
* 3: logitsLengths - Length of input sequence in logits NDArray {BATCH_LEN}, type integer
*
*
* Input integer arguments:
* 0: blank index - index of the blank label in logits
*
* Output array:
* 0: loss values - NDArray {BATCH_LEN}, type float. Negative log probabilities of the loss
*/
#if NOT_EXCLUDED(OP_ctc_loss)
DECLARE_CUSTOM_OP(ctc_loss, 4, 1, false, 0, 1);
DECLARE_CUSTOM_OP(ctc_loss_grad, 4, 1, false, 0, 1);
#endif
}
}
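For orientation, here is a minimal sketch of how the declared op might be exercised through libnd4j's DeclarableOp evaluate API. It is illustrative only and not part of this commit; the function name and all concrete values are assumptions, while the input order, shapes and the single blank-index integer argument follow the documentation above.

// sketch: invoking sd::ops::ctc_loss directly
#include <ops/declarable/CustomOperations.h>
#include <array/NDArrayFactory.h>
#include <cmath>

void ctcLossUsageSketch() {
    const int batch = 2, frames = 6, classes = 5, maxLabelLen = 3;
    // 0: labels {BATCH_LEN, MAX_TARGET_LEN}, integer type
    auto labels   = NDArrayFactory::create<int>('c', {batch, maxLabelLen}, {1, 2, 1,  2, 3, 0});
    // 1: logits {BATCH_LEN, FRAME_LEN, CLASS_LEN}, log-softmax values including the blank class
    auto logits   = NDArrayFactory::create<float>('c', {batch, frames, classes});
    logits.assign(std::log(1.0f / classes)); // uniform log-probabilities, purely for illustration
    // 2: target label lengths, 3: logit (frame) lengths, both {BATCH_LEN}
    auto labelLen = NDArrayFactory::create<int>('c', {batch}, {3, 2});
    auto logitLen = NDArrayFactory::create<int>('c', {batch}, {6, 6});

    sd::ops::ctc_loss op;
    // the single integer argument is the blank label index
    auto result = op.evaluate({&labels, &logits, &labelLen, &logitLen}, {}, {0});
    auto losses = result.at(0); // {BATCH_LEN} negative log probabilities
    losses->printIndexedBuffer("ctc losses");
}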

View File

@ -0,0 +1,507 @@
/*******************************************************************************
* Copyright (c) 2021 Deeplearning4j Contributors
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*******************************************************************************/
//
// @author AbdelRauf
//
#include <type_traits>
#include <cmath>
#include <stdexcept>
#include <memory>
#include <execution/Threads.h>
#include <execution/ThreadPool.h>
#include <helpers/LoopsCoordsHelper.h>
#include <ops/declarable/helpers/ctcLoss.h>
namespace sd
{
namespace ops
{
namespace helpers
{
//choose ptr[index*element_stride]
template <bool Strided, typename Type>
typename std::enable_if<Strided == true, Type &>::type
element(Type *ptr, int index, int element_stride)
{
return ptr[index * element_stride];
}
//choose ptr[index] assuming element_stride is 1
template <bool Strided, typename Type>
typename std::enable_if<Strided == false, Type &>::type
element(Type *ptr, int index, int element_stride)
{
return ptr[index];
}
template <bool IsLogPStrided = false, bool IsLblStrided = false, typename Type, typename IndexType>
Type forward(Type *alphaPtr, const Nd4jLong &incA, const Type *logP, const Nd4jLong &incP, const IndexType *lbl, const Nd4jLong &lenSB, const Nd4jLong &lenT, const int &blankIndex, int elwiseP = 1, int elwiseS = 1)
{
Type negInf = -DataTypeUtils::infOrMax<Type>();
//initialize alphas at t=0
alphaPtr[0] = element<IsLogPStrided>(logP, blankIndex, elwiseP);
//alphaPtr[1] =logP[lbl[0]];
alphaPtr[1] = element<IsLogPStrided>(logP, *lbl, elwiseP);
//the rest of the initialization is skipped,
//as the array is assumed to have been pre-filled with negative infinity
//move to the next frame
Type *alphaPrevPtr = alphaPtr;
alphaPtr += incA;
logP += incP;
auto startX = lenSB - 2 * lenT;
//process the rest
for (auto t = 1; t < lenT; t++)
{
//start = max(0,L-2*(T-t))
auto s = startX + 2 * t;
s = s > 0 ? s : 0;
for (; s < lenSB; s++)
{
auto ind = s / 2; //our real index
//we force blanks for even indexes
//strided version of lbl[ind] => element<IsLblStrided>(lbl, ind, elwiseS)
auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS);
// {t-1,s}
Type alphaS = alphaPrevPtr[s];
Type alphaS_1 = s > 0 ? alphaPrevPtr[s - 1] : negInf;
Type cMax = std::max(alphaS, alphaS_1);
//logP[currentInd] or logP[currentInd*elwiseP]
auto currentProb = element<IsLogPStrided>(logP, currentInd, elwiseP);
// if blank or the same as previous
if (s > 1 && currentInd != blankIndex && currentInd != element<IsLblStrided>(lbl, ind - 1, elwiseS))
{
Type alphaS_2 = alphaPrevPtr[s - 2];
cMax = std::max(cMax, alphaS_2);
if (cMax == negInf)
cMax = 0;
alphaPtr[s] = std::log(std::exp(alphaS - cMax) + std::exp(alphaS_1 - cMax) + std::exp(alphaS_2 - cMax)) + cMax + currentProb;
}
else
{
if (cMax == negInf)
cMax = 0;
alphaPtr[s] = std::log(std::exp(alphaS - cMax) + std::exp(alphaS_1 - cMax)) + cMax + currentProb;
}
}
//store t-1 alpha Ptr
alphaPrevPtr = alphaPtr;
logP += incP;
alphaPtr += incA;
}
auto logP0 = alphaPrevPtr[lenSB - 1];
auto logP1 = alphaPrevPtr[lenSB - 2];
auto cMax = std::max(logP0, logP1);
return -(std::log(std::exp(logP0 - cMax) + std::exp(logP1 - cMax)) + cMax);
}
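The recursion above adds probabilities in log space; subtracting cMax before exponentiating is the usual log-sum-exp stabilization, and the cMax == negInf guard keeps the update from producing NaN when every incoming path is still impossible. A standalone sketch of the same pattern (illustrative, not taken from the commit):

#include <cmath>
#include <algorithm>
#include <limits>

// logAdd(a, b) == log(exp(a) + exp(b)), computed without overflow/underflow
// by factoring out the largest term, exactly as the alpha/betta updates above do
template <typename T>
T logAdd(T a, T b) {
    T cMax = std::max(a, b);
    if (cMax == -std::numeric_limits<T>::infinity())
        cMax = 0; // both terms are -inf: avoid NaN from (-inf) - (-inf)
    return std::log(std::exp(a - cMax) + std::exp(b - cMax)) + cMax;
}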
//#undef CALCULATE_ALL_IN_ONE_FRAME_LOOP
template <bool IsLogPStrided = false, bool IsLblStrided = false, bool isGradStrided = false, typename Type, typename IndexType = int>
void backwardAndGrad(Type forwardLogLoss, Type *alphaPtr, Type *bettaPtr, int incA,const Type *logP, int incP, Type *gradPtr, int incG,const IndexType *lbl,
const Nd4jLong &lenS, const Nd4jLong &lenT, const Nd4jLong &lenK, const int &blankIndex,
int elwiseP = 1, int elwiseS = 1, int elwiseG = 1)
{
Type negInf = -DataTypeUtils::infOrMax<Type>();
Nd4jLong lenSB = 2 * lenS + 1;
auto origBetta = bettaPtr;
auto origLogP = logP;
//move to the last frame
bettaPtr += (lenT - 1) * incA;
logP += (lenT - 1) * incP;
//initialize bettas at t=lenT
bettaPtr[lenSB - 1] = element<IsLogPStrided>(logP, blankIndex, elwiseP);
auto lblIndex = element<IsLblStrided>(lbl, lenS - 1, elwiseS);
bettaPtr[lenSB - 2] = element<IsLogPStrided>(logP, lblIndex, elwiseP); // logP[lbl[lenS - 1]];
#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
//move to the last
gradPtr += (lenT - 1) * incG;
alphaPtr += (lenT - 1) * incA;
for (auto s = lenSB - 1; s >= 0; s--)
{
auto ind = s / 2; //our real index
//we forced blanks for even indexes
auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS);
//alpha(s)*betta(s) in log scale but still store in alpha to save memory
auto alphaBettaS = alphaPtr[s] + bettaPtr[s];
//sum (alpha(s)*betta(s) ) over real indexes
auto &currentGrad = element<isGradStrided>(gradPtr, currentInd, elwiseG); // gradPtr[currentInd];
if (currentGrad == negInf)
{
currentGrad = alphaBettaS;
}
else
{
Type cMax = std::max(currentGrad, alphaBettaS);
currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax;
}
}
for (int k = 0; k < lenK; k++)
{
//compute the rest grad
// prob(t,k) - grad(k) / ((prob(t,k)*Z) )
// p2= grad(k) / (prob(t,k)*Z )
//in log scale; log(Z) is available as -logLoss
// auto p2 = std::exp(gradPtr[k] + forwardLogLoss - logP[k]);
// gradPtr[k] = std::exp(logP[k]) - p2;
auto currentProb = element<IsLogPStrided>(logP, k, elwiseP);
auto &currentGrad = element<isGradStrided>(gradPtr, k, elwiseG);
auto p2 = std::exp(currentGrad + forwardLogLoss - currentProb);
currentGrad = std::exp(currentProb) - p2;
}
gradPtr -= incG;
alphaPtr -= incA;
#endif
auto bettaPrevPtr = bettaPtr;
bettaPtr -= incA;
logP -= incP;
//process the rest
for (auto t = lenT - 2; t >= 0; t--)
{
#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
auto end = lenSB - 1;
#else
auto end = std::min(2 * t + 2, lenSB - 1);
#endif
for (auto s = end; s >= 0; s--)
{
auto ind = s / 2; //our real index
//we forced blanks for even indexes
auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS); //lbl[ind];
// {t-1,s}
Type bettaS = bettaPrevPtr[s];
Type bettaS_1 = s < lenSB - 1 ? bettaPrevPtr[s + 1] : negInf;
Type cMax = std::max(bettaS, bettaS_1);
//logP[currentInd]
auto currentProb = element<IsLogPStrided>(logP, currentInd, elwiseP);
// if blank or the same as previous
if (s < lenSB - 2 && currentInd != blankIndex && currentInd != element<IsLblStrided>(lbl, ind + 1, elwiseS))
{
Type bettaS_2 = bettaPrevPtr[s + 2];
cMax = std::max(cMax, bettaS_2);
if (cMax == negInf)
cMax = 0;
bettaPtr[s] = std::log(std::exp(bettaS - cMax) + std::exp(bettaS_1 - cMax) + std::exp(bettaS_2 - cMax)) + cMax + currentProb;
}
else
{
if (cMax == negInf)
cMax = 0;
bettaPtr[s] = std::log(std::exp(bettaS - cMax) + std::exp(bettaS_1 - cMax)) + cMax + currentProb;
}
#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
//alpha(s)*betta(s) in log scale but still store in alpha to save memory
auto alphaBettaS = alphaPtr[s] + bettaPtr[s];
//sum (alpha(s)*betta(s) ) over real indexes
auto &currentGrad = element<isGradStrided>(gradPtr, currentInd, elwiseG); // gradPtr[currentInd];
if (currentGrad == negInf)
{
currentGrad = alphaBettaS;
}
else
{
Type cMax = std::max(currentGrad, alphaBettaS);
currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax;
}
#endif
}
#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
for (int k = 0; k < lenK; k++)
{
//compute the rest grad
// prob(t,k) - grad(k) / ((prob(t,k)*Z) )
// p2= grad(k) / (prob(t,k)*Z )
//in log scale; log(Z) is available as -logLoss
// auto p2 = std::exp(gradPtr[k] + forwardLogLoss - logP[k]);
// gradPtr[k] = std::exp(logP[k]) - p2;
auto currentProb = element<IsLogPStrided>(logP, k, elwiseP);
auto &currentGrad = element<isGradStrided>(gradPtr, k, elwiseG);
auto p2 = std::exp(currentGrad + forwardLogLoss - currentProb);
currentGrad = std::exp(currentProb) - p2;
}
alphaPtr -= incA;
gradPtr -= incG;
#endif
bettaPrevPtr = bettaPtr;
bettaPtr -= incA;
logP -= incP;
}
auto logBP0 = bettaPrevPtr[0];
auto logBP1 = bettaPrevPtr[1];
auto bcMax = std::max(logBP0, logBP1);
auto blogLoss = -(std::log(std::exp(logBP0 - bcMax) + std::exp(logBP1 - bcMax)) + bcMax);
#if !defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
//alpha*betta
bettaPtr = origBetta;
logP = origLogP;
for (int t = 0; t < lenT; t++)
{
for (int s = 0; s < lenSB; s++)
{
auto ind = s / 2; //our real index
//we forced blanks for even indexes
auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS); //lbl[ind];
//alpha(s)*betta(s) in log scale but still store in alpha to save memory
auto alphaBettaS = alphaPtr[s] + bettaPtr[s];
//sum (alpha(s)*betta(s) ) over real indexes
auto &currentGrad = element<isGradStrided>(gradPtr, currentInd, elwiseG); // gradPtr[currentInd];
if (currentGrad == negInf)
{
currentGrad = alphaBettaS;
}
else
{
Type cMax = std::max(currentGrad, alphaBettaS);
currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax;
}
//alphaPtr[s] = alphaBettaS;
}
PRAGMA_OMP_SIMD
for (int k = 0; k < lenK; k++)
{
//compute the rest grad
// prob(t,k) - grad(k) / ((prob(t,k)*Z) )
// p2= grad(k) / (prob(t,k)*Z )
//in log scale; log(Z) is available as -logLoss
// auto p2 = std::exp(gradPtr[k] + forwardLogLoss - logP[k]);
// gradPtr[k] = std::exp(logP[k]) - p2;
auto currentProb = element<IsLogPStrided>(logP, k, elwiseP);
auto &currentGrad = element<isGradStrided>(gradPtr, k, elwiseG);
auto p2 = std::exp(currentGrad + forwardLogLoss - currentProb);
currentGrad = std::exp(currentProb) - p2;
}
gradPtr += incG;
bettaPtr += incA;
alphaPtr += incA;
logP += incP;
}
#endif
}
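Written out, the accumulation above computes, per time frame t and class k, what the code comments describe as prob(t,k) - grad(k)/(prob(t,k)*Z). With alpha and beta as defined in this file (each already containing the frame probability at time t), Z = e^{-loss} the total label probability, and y_t^k recovered from the log-probabilities, this matches the standard CTC gradient with respect to the pre-softmax activations (cf. Graves et al., 2006); the formula below is given for reference only:

\[
\mathrm{grad}_t^{k} \;=\; y_t^{k} \;-\; \frac{1}{y_t^{k}\, Z}\sum_{s \,:\, \mathrm{lab}(s) = k} \alpha_t(s)\,\beta_t(s),
\qquad y_t^{k} = e^{\log P_t^{k}}
\]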
/**
* Calculates ctc loss and fills gradients
* @param logP logits matrix(lenT,lenK) pointer (log soft max input of rnn)
* @param incP stride of logits for the next time frame
* @param gradPtr gradient for output
* @param incG stride of the gradient for the next time frame
* @param lbl target label
* @param lenT frame length
* @param lenK class length
* @param lenS target label length
* @param blankIndex index of the blank label in logit class
*/
template <bool IsLogPStrided = true, bool IsLblStrided = true, bool IsGradStrided = true, typename Type, typename IndexType>
Type unitLossAndGrad(const Type *logP, int incP, Type *gradPtr, int incG,const IndexType *lbl, int lenT, int lenK, int lenS, int blankIndex,
int elwiseP = 1, int elwiseS = 1, int elwiseG = 1)
{
auto lenSB = 2 * lenS + 1;
//create temp Array for holding bettaArr [lenT,lenSB]
//create temp Array for holding alphaArr [lenT,lenSB]
int bufferC = gradPtr ? 2 : 1;
NDArray bufferArr = NDArrayFactory::create<Type>('c', {bufferC, lenT, lenSB});
auto bufferPtr = bufferArr.bufferAsT<Type>();
auto incA = bufferArr.stridesOf()[1];
auto bettaBufferPtr = bufferPtr + bufferArr.stridesOf()[0];
Type negInf = -DataTypeUtils::infOrMax<Type>();
#if 1
if (gradPtr)
{
if (elwiseG == 1)
{
PRAGMA_OMP_SIMD
for (int i = 0; i < lenK * lenT; i++)
{
gradPtr[i] = negInf;
}
}
else
{
auto tempPtr = gradPtr;
for (int i = 0; i < lenT; i++)
{
for (int j = 0; j < lenK; j++)
element<false>(tempPtr, j, elwiseG) = negInf;
tempPtr += incG;
}
}
}
#endif
// set all vals to neginf
PRAGMA_OMP_SIMD
for (int i = 0; i < bufferC * lenSB * lenT; i++)
{
bufferPtr[i] = negInf;
}
//forward
Type logLoss = forward<IsLogPStrided, IsLblStrided>(bufferPtr, incA, logP, incP, lbl, lenSB, lenT, blankIndex, elwiseP, elwiseS);
//backward and gradient if gradptr supplied
if (gradPtr)
backwardAndGrad<IsLogPStrided, IsLblStrided, IsGradStrided>(logLoss, bufferPtr, bettaBufferPtr, incA, logP, incP, gradPtr, incG, lbl, lenS, lenT, lenK, blankIndex, elwiseP, elwiseS, elwiseG);
return logLoss;
}
template <typename Type, typename IndexType>
void
ctc_loss_(const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex)
{
// lenT - length of the input (frame) sequence
// lenS - length of the target label sequence
// lenSB - target length including blanks (2*lenS + 1)
auto lenBatch = logits.shapeOf()[0];
auto maxLenT = logits.shapeOf()[1];
auto lenK = logits.shapeOf()[2];
auto maxLenS = targetLabels.shapeOf()[1];
// get probability buffer and targetLabels buffer
auto logP = logits.bufferAsT<Type>();
auto lblPtr = targetLabels.bufferAsT<IndexType>();
auto lenTPtr = logitsLengths.bufferAsT<IndexType>();
auto lenSPtr = targetLabelLengths.bufferAsT<IndexType>();
auto batchLbl = targetLabels.stridesOf()[0];
auto batchP = logits.stridesOf()[0];
auto incP = logits.stridesOf()[1];
auto elwiseSLen = targetLabelLengths.stridesOf()[0];
auto elwiseT = logitsLengths.stridesOf()[0];
auto elwiseS = targetLabels.stridesOf()[1];
auto elwiseP = logits.stridesOf()[2];
int elwiseLL = 0;
Type *logLossPtr = nullptr;
if (!logLosses.isEmpty()){
elwiseLL = logLosses.stridesOf()[0];
logLossPtr = logLosses.bufferAsT<Type>();
}
auto func = [logP, batchP, incP, elwiseP, lenK, lenTPtr, lenSPtr, logLossPtr, lblPtr, maxLenT, maxLenS,
batchLbl, blankIndex, elwiseT, elwiseLL, elwiseSLen, elwiseS, &gradients](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
Type *gradPtr = nullptr;
Type resultLoss;
int batchG, incG, elwiseG;
if (!gradients.isEmpty())
{
batchG = gradients.stridesOf()[0];
incG = gradients.stridesOf()[1];
elwiseG = gradients.stridesOf()[2];
gradPtr = gradients.bufferAsT<Type>() + start * batchG;
}else{
elwiseG=1;
}
auto logPtr = logP + start * batchP;
auto tempLblPtr = lblPtr + start * batchLbl;
if (elwiseP == 1 && elwiseS == 1 && elwiseG == 1)
{
//fast path: element-wise stride of 1 for logits, labels and gradients
for (int batchIndex = start; batchIndex < stop; batchIndex += increment)
{
auto lenT = lenTPtr[batchIndex * elwiseT];
auto lenS = lenSPtr[batchIndex * elwiseSLen];
lenT = lenT > maxLenT ? maxLenT : lenT;
lenS = lenS > maxLenS ? maxLenS : lenS;
if (lenS <= 0 || lenT <= 0)
{
resultLoss = -DataTypeUtils::infOrMax<Type>();
}
else
{
if (lenS > lenT)
lenS = lenT;
resultLoss = unitLossAndGrad<false, false, false, Type, IndexType>(logPtr, incP, gradPtr, incG, tempLblPtr, lenT, lenK, lenS, blankIndex);
}
if (gradPtr) gradPtr += batchG;
if (logLossPtr) logLossPtr[batchIndex * elwiseLL] = resultLoss;
logPtr += batchP;
tempLblPtr += batchLbl;
}
}
else
{
//slow strided case for all 3
for (int batchIndex = start; batchIndex < stop; batchIndex += increment)
{
auto lenT = lenTPtr[batchIndex * elwiseT];
auto lenS = lenSPtr[batchIndex * elwiseSLen];
lenT = lenT > maxLenT ? maxLenT : lenT;
lenS = lenS > maxLenS ? maxLenS : lenS;
if (lenS <= 0 || lenT <= 0)
{
resultLoss = -DataTypeUtils::infOrMax<Type>();
}
else
{
if (lenS > lenT)
lenS = lenT;
resultLoss = unitLossAndGrad<true,true,true,Type, IndexType>(logPtr, incP, gradPtr, incG, tempLblPtr, lenT, lenK, lenS, blankIndex, elwiseP, elwiseS, elwiseG);
}
if (gradPtr) gradPtr += batchG;
if (logLossPtr) logLossPtr[batchIndex * elwiseLL] = resultLoss;
logPtr += batchP;
tempLblPtr += batchLbl;
}
}
};
samediff::Threads::parallel_for(func, 0, lenBatch, 1);
}
void ctcLoss(graph::Context& block, const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex){
BUILD_DOUBLE_SELECTOR(logits.dataType(), targetLabels.dataType(), ctc_loss_, (logits, targetLabels, logitsLengths, targetLabelLengths, logLosses, gradients, blankIndex), FLOAT_TYPES, INDEXING_TYPES);
}
BUILD_DOUBLE_TEMPLATE(template void ctc_loss_, (const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex), FLOAT_TYPES, INDEXING_TYPES);
} // namespace helpers
} // namespace ops
} // namespace sd

View File

@ -0,0 +1,55 @@
/*******************************************************************************
* Copyright (c) 2021 Deeplearning4j Contributors
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*******************************************************************************/
//
// @author AbdelRauf
//
#ifndef LIBND4J_HELPERS_CTCLOSS_H
#define LIBND4J_HELPERS_CTCLOSS_H
#include <ops/declarable/helpers/helpers.h>
#include <graph/Context.h>
namespace sd {
namespace ops {
namespace helpers {
/**
* @brief Implementation of CTC loss function
* References:
Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
with Recurrent Neural Networks:
[Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
*
* @param block Context
* @param logits NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN}, type float. Log softmax of the RNN output; it should include a blank label as well.
* @param targetLabels NDArray {BATCH_LEN, MAX_TARGET_LEN}
* @param logitsLengths NDArray {BATCH_LEN} Length of input sequence in logits
* @param targetLabelLengths NDArray {BATCH_LEN} Length of label sequence in labels
* @param logLosses NDArray {BATCH_LEN} or EMPTY; if empty it will be skipped. Negative log probabilities of the loss
* @param gradients NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN} or EMPTY; if empty, gradient computation is skipped
* @param blankIndex index of the blank label in logits
*/
void ctcLoss(graph::Context& block, const NDArray &logitsInput, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex);
}
}
}
#endif // LIBND4J_HELPERS_CTCLOSS_H
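Both ctc_loss and ctc_loss_grad funnel into this helper: the forward op passes an empty gradients array and the backprop op passes an empty logLosses array, so either output can be skipped. A hedged sketch of a direct call, assuming a valid graph::Context named block, arrays shaped as documented above, and an assumed batchLen variable:

// sketch: loss-only call, gradients skipped via the empty-array convention
auto logLosses = NDArrayFactory::create<float>('c', {batchLen}); // per-example losses
auto gradients = NDArrayFactory::empty<float>();                 // empty => gradient computation is skipped
sd::ops::helpers::ctcLoss(block, logits, targetLabels, logitsLengths, targetLabelLengths,
                          logLosses, gradients, /*blankIndex=*/ 0);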

View File

@ -0,0 +1,225 @@
/*******************************************************************************
* Copyright (c) 2021 Deeplearning4j Contributors
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*******************************************************************************/
//
// @author AbdelRauf
//
#include "cudnnUtils.h"
#include <array/NDArrayFactory.h>
#include <vector>
namespace sd {
namespace ops {
namespace platforms {
template<typename Op, typename ...Args>
void callCudnnIfNoErr(cudnnStatus_t &err, Op op, Args&&... args){
if(err==CUDNN_STATUS_SUCCESS){
err = op(std::forward<Args>(args)...);
if(err){
nd4j_printf("Cudnn error code %s\n",cudnnGetErrorString(err));
}
}
}
template <typename T>
const T* bufferInHost( const NDArray &array) {
array.syncToHost();
return reinterpret_cast<const T*>(array.buffer());
}
std::vector<int> getConcatTargets(const NDArray &targetLabels, const NDArray &targetLabelLengths){
//concatenate target labels
const int32_t *tlabels = bufferInHost<int32_t>(targetLabels);
const int32_t *tlens =bufferInHost<int32_t>(targetLabelLengths);
int32_t nextOffset = targetLabels.strideAt(0);
int32_t elStride = targetLabels.strideAt(1);
int32_t batchCount = targetLabelLengths.lengthOf();
std::vector<int> labels;
labels.resize(targetLabels.lengthOf());
int j=0;
if(targetLabels.ews()){
for(int i=0; i<batchCount;i++){
int count = tlens[i];
for( int k=0;k<count;k++){
labels[j] = tlabels[k];
j++;
}
tlabels+=nextOffset;
}
}else{
for(int i=0; i<batchCount;i++){
int count = tlens[i];
for( int k=0;k<count;k++){
labels[j] = tlabels[k*elStride];
j++;
}
tlabels+=nextOffset;
}
}
return labels;
}
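getConcatTargets flattens the {BATCH_LEN, MAX_TARGET_LEN} label matrix into the packed layout cuDNN expects: for each row only the first targetLabelLengths[i] entries are kept, and the rows are appended back to back. A small illustrative example (all values made up):

// a 2 x 4 label matrix with per-row lengths {3, 2}
auto targetLabels       = NDArrayFactory::create<int32_t>('c', {2, 4}, {3, 1, 4, 0,  2, 5, 0, 0});
auto targetLabelLengths = NDArrayFactory::create<int32_t>('c', {2},    {3, 2});
auto packed = getConcatTargets(targetLabels, targetLabelLengths); // -> {3, 1, 4, 2, 5}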
cudnnStatus_t cudnnCtcLoss(const LaunchContext &context, const NDArray &probs, const int32_t* targetLabelsPtr, const NDArray& probInputLengthes,
const NDArray &targetLabelLengths, NDArray &ctcLosses, NDArray &grads){
const int dims[] = {(int)probs.sizeAt(0), (int)probs.sizeAt(1), (int)probs.sizeAt(2)};
const int strides[] = {(int)probs.strideAt(0), (int)probs.strideAt(1), (int)probs.strideAt(2)};
auto handle = reinterpret_cast<cudnnHandle_t *>(context.getCuDnnHandle());
cudnnStatus_t err = CUDNN_STATUS_SUCCESS;
callCudnnIfNoErr(err, cudnnSetStream, *handle, *context.getCudaStream());
cudnnCTCLossDescriptor_t ctcLossDesc;
cudnnTensorDescriptor_t probsDesc = nullptr;
cudnnTensorDescriptor_t gradsDesc = nullptr;
callCudnnIfNoErr(err, cudnnCreateCTCLossDescriptor, &ctcLossDesc);
callCudnnIfNoErr(err, cudnnSetCTCLossDescriptorEx, ctcLossDesc, CUDNN_DATA_FLOAT, CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_PROPAGATE_NAN);
callCudnnIfNoErr(err, cudnnCreateTensorDescriptor, &probsDesc);
callCudnnIfNoErr(err, cudnnSetTensorNdDescriptor, probsDesc, cudnnDataType(probs.dataType()), probs.rankOf() , dims, strides);
if(!grads.isEmpty()){
const int gradStrides[] = {(int)grads.strideAt(0), (int)grads.strideAt(1), (int)grads.strideAt(2)};
callCudnnIfNoErr(err, cudnnCreateTensorDescriptor, &gradsDesc);
callCudnnIfNoErr(err, cudnnSetTensorNdDescriptor, gradsDesc, cudnnDataType(grads.dataType()), grads.rankOf() , dims, gradStrides);
}
size_t tempWorkSpaceSize=0;
callCudnnIfNoErr(err,cudnnGetCTCLossWorkspaceSize, *handle, probsDesc, gradsDesc,
targetLabelsPtr,
bufferInHost<int32_t>(targetLabelLengths),
bufferInHost<int32_t>(probInputLengthes),
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
ctcLossDesc, &tempWorkSpaceSize);
// Allocate a temporary workspace buffer
void *tempWorkSpace = nullptr;
cudaMalloc(&tempWorkSpace, tempWorkSpaceSize);
NDArray::prepareSpecialUse({&ctcLosses, &grads}, {&probs});
callCudnnIfNoErr(err, cudnnCTCLoss,*handle,
probsDesc,
probs.specialBuffer(),
targetLabelsPtr,
bufferInHost<int32_t>(targetLabelLengths),
bufferInHost<int32_t>(probInputLengthes),
ctcLosses.specialBuffer(),
gradsDesc,
grads.specialBuffer(),
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
ctcLossDesc,
tempWorkSpace,
tempWorkSpaceSize);
NDArray::registerSpecialUse({&ctcLosses, &grads}, {&probs});
cudaFree(tempWorkSpace);
callCudnnIfNoErr(err, cudnnDestroyTensorDescriptor,probsDesc);
if(gradsDesc) callCudnnIfNoErr(err, cudnnDestroyTensorDescriptor,gradsDesc);
callCudnnIfNoErr(err, cudnnDestroyCTCLossDescriptor,ctcLossDesc);
return err;
}
PLATFORM_IMPL(ctc_loss, ENGINE_CUDA) {
auto targetLabels = INPUT_VARIABLE(0);
auto logitInput = INPUT_VARIABLE(1);
auto targetLabelLengths = INPUT_VARIABLE(2);
auto logitInputLengths = INPUT_VARIABLE(3);
auto outputLosses = OUTPUT_VARIABLE(0);
auto context = block.launchContext();
//in cuDNN the batch dimension is in the middle: {FRAME_LEN, BATCH_LEN, CLASS_LEN}
logitInput->permutei({1,0,2});
//in cuDNN targets are concatenated into one vector instead of batched as a matrix
auto labels = getConcatTargets(*targetLabels, *targetLabelLengths);
const int32_t *ldata= labels.data();
auto emptyGrads= NDArrayFactory::empty<float>();
auto err = cudnnCtcLoss(*context, *logitInput, ldata, *logitInputLengths, *targetLabelLengths, *outputLosses, emptyGrads);
if(err!=CUDNN_STATUS_SUCCESS) throw sd::cuda_exception::build("ctc_loss CUDNN call failure ", err);
return Status::OK();
}
template<typename T>
bool checkLabelLength(const NDArray &labelLengthArr){
//check label lengths
auto lenBatch = labelLengthArr.lengthOf();
for(int i=0; i < lenBatch; i++){
// reject if any label length is greater than 256
if(labelLengthArr.e<int32_t>(i)>256) return false;
}
return true;
}
PLATFORM_CHECK(ctc_loss, ENGINE_CUDA) {
auto targetLabels = INPUT_VARIABLE(0);
auto logitInput = INPUT_VARIABLE(1);
auto targetLabelLengths = INPUT_VARIABLE(2);
auto logitInputLengths = INPUT_VARIABLE(3);
auto outputLosses = OUTPUT_VARIABLE(0);
int blankIndex = INT_ARG(0);
auto dTypeInput = logitInput->dataType();
auto intType = targetLabelLengths->dataType();
auto dTypeOutput = outputLosses->dataType();
bool is_supported = blankIndex==0 && intType == DataType::INT32 && dTypeInput == DataType::FLOAT32;
is_supported = is_supported && outputLosses->ews() && targetLabelLengths->ews() && targetLabels->ews() && logitInputLengths->ews();
is_supported = is_supported && checkLabelLength<int32_t>(*targetLabelLengths);
return is_supported;
}
PLATFORM_IMPL(ctc_loss_grad, ENGINE_CUDA) {
auto targetLabels = INPUT_VARIABLE(0);
auto logitInput = INPUT_VARIABLE(1);
auto targetLabelLengths = INPUT_VARIABLE(2);
auto logitInputLengths = INPUT_VARIABLE(3);
auto outputGradients = OUTPUT_VARIABLE(0);
auto context = block.launchContext();
//in cuDNN the batch dimension is in the middle: {FRAME_LEN, BATCH_LEN, CLASS_LEN}
logitInput->permutei({1,0,2});
outputGradients->permutei({1,0,2});
//in cuDNN targets are concatenated into one vector instead of batched as a matrix
auto labels = getConcatTargets(*targetLabels, *targetLabelLengths);
const int32_t * ldata= labels.data();
auto tempLosses = NDArrayFactory::create<float>('c', {logitInputLengths->sizeAt(0)});
auto err = cudnnCtcLoss(*context, *logitInput, ldata, *logitInputLengths, *targetLabelLengths, tempLosses, *outputGradients);
if(err!=CUDNN_STATUS_SUCCESS) throw sd::cuda_exception::build("ctc_loss CUDNN call failure ", err);
//restore grads shape from {T, BATCH, C} -> {BATCH, T, C}
outputGradients->permutei({1,0,2});
//tempLosses.printIndexedBuffer("tempLosses");
return Status::OK();
}
PLATFORM_CHECK(ctc_loss_grad, ENGINE_CUDA) {
auto targetLabels = INPUT_VARIABLE(0);
auto logitInput = INPUT_VARIABLE(1);
auto targetLabelLengths = INPUT_VARIABLE(2);
auto logitInputLengths = INPUT_VARIABLE(3);
auto outputGrads = OUTPUT_VARIABLE(0);
int blankIndex = INT_ARG(0);
auto dTypeInput = logitInput->dataType();
auto intType = targetLabelLengths->dataType();
auto dTypeOutput = outputGrads->dataType();
bool is_supported = blankIndex==0 && intType == DataType::INT32 && dTypeInput == DataType::FLOAT32;
is_supported = is_supported && outputGrads->ews() && targetLabelLengths->ews() && targetLabels->ews() && logitInputLengths->ews();
is_supported = is_supported && checkLabelLength<int32_t>(*targetLabelLengths);
return is_supported;
}
}
}
}

View File

@ -60,6 +60,9 @@ namespace platforms {
DECLARE_PLATFORM(maxpool3dnew, ENGINE_CUDA);
DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CUDA);
DECLARE_PLATFORM(ctc_loss, ENGINE_CUDA);
DECLARE_PLATFORM(ctc_loss_grad, ENGINE_CUDA);
//////////////////////////////////////////////////////////////////////////
FORCEINLINE cudnnDataType_t cudnnDataType(sd::DataType dataType) {
switch (dataType) {

File diff suppressed because it is too large.

View File

@ -5411,12 +5411,14 @@ public final class TensorNamespace {
* Serializations can either use one of the fields above, or use this
* raw bytes field. The only exception is the string case, where one is
* required to store the content in the repeated bytes string_data field.
*
* When this raw_data field is used to store tensor value, elements MUST
* be stored in as fixed-width, little-endian order.
* Floating-point data types MUST be stored in IEEE 754 format.
* Complex64 elements must be written as two consecutive FLOAT values, real component first.
* Complex128 elements must be written as two consecutive DOUBLE values, real component first.
* Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
*
* Note: the advantage of specific field rather than the raw_data field is
* that in some cases (e.g. int data), protobuf does a better packing via
* variable length storage, and may lead to smaller binary footprint.
@ -5655,6 +5657,7 @@ public final class TensorNamespace {
/**
* <pre>
* Tensors
*
* A serialized tensor value.
* </pre>
*
@ -7010,12 +7013,14 @@ public final class TensorNamespace {
* Serializations can either use one of the fields above, or use this
* raw bytes field. The only exception is the string case, where one is
* required to store the content in the repeated bytes string_data field.
*
* When this raw_data field is used to store tensor value, elements MUST
* be stored in as fixed-width, little-endian order.
* Floating-point data types MUST be stored in IEEE 754 format.
* Complex64 elements must be written as two consecutive FLOAT values, real component first.
* Complex128 elements must be written as two consecutive DOUBLE values, real component first.
* Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
*
* Note: the advantage of specific field rather than the raw_data field is
* that in some cases (e.g. int data), protobuf does a better packing via
* variable length storage, and may lead to smaller binary footprint.
@ -7766,6 +7771,7 @@ public final class TensorNamespace {
/**
* <pre>
* Tensors
*
* A serialized tensor value.
* </pre>
*
@ -9080,12 +9086,14 @@ public final class TensorNamespace {
* Serializations can either use one of the fields above, or use this
* raw bytes field. The only exception is the string case, where one is
* required to store the content in the repeated bytes string_data field.
*
* When this raw_data field is used to store tensor value, elements MUST
* be stored in as fixed-width, little-endian order.
* Floating-point data types MUST be stored in IEEE 754 format.
* Complex64 elements must be written as two consecutive FLOAT values, real component first.
* Complex128 elements must be written as two consecutive DOUBLE values, real component first.
* Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
*
* Note: the advantage of specific field rather than the raw_data field is
* that in some cases (e.g. int data), protobuf does a better packing via
* variable length storage, and may lead to smaller binary footprint.
@ -9102,12 +9110,14 @@ public final class TensorNamespace {
* Serializations can either use one of the fields above, or use this
* raw bytes field. The only exception is the string case, where one is
* required to store the content in the repeated bytes string_data field.
*
* When this raw_data field is used to store tensor value, elements MUST
* be stored in as fixed-width, little-endian order.
* Floating-point data types MUST be stored in IEEE 754 format.
* Complex64 elements must be written as two consecutive FLOAT values, real component first.
* Complex128 elements must be written as two consecutive DOUBLE values, real component first.
* Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
*
* Note: the advantage of specific field rather than the raw_data field is
* that in some cases (e.g. int data), protobuf does a better packing via
* variable length storage, and may lead to smaller binary footprint.
@ -9130,12 +9140,14 @@ public final class TensorNamespace {
* Serializations can either use one of the fields above, or use this
* raw bytes field. The only exception is the string case, where one is
* required to store the content in the repeated bytes string_data field.
*
* When this raw_data field is used to store tensor value, elements MUST
* be stored in as fixed-width, little-endian order.
* Floating-point data types MUST be stored in IEEE 754 format.
* Complex64 elements must be written as two consecutive FLOAT values, real component first.
* Complex128 elements must be written as two consecutive DOUBLE values, real component first.
* Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
*
* Note: the advantage of specific field rather than the raw_data field is
* that in some cases (e.g. int data), protobuf does a better packing via
* variable length storage, and may lead to smaller binary footprint.

View File

@ -0,0 +1,60 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.linalg.api.ops.impl.loss;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.loss.bp.CtcLossBp;
import java.util.List;
public class CtcLoss extends BaseLoss {
public CtcLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels){
super(sameDiff, lossReduce, predictions, weights, labels);
}
public CtcLoss(SameDiff sameDiff, SDVariable label, SDVariable predictions, SDVariable weights,
LossReduce lossReduce) {
this(sameDiff, lossReduce, predictions, weights, label);
}
public CtcLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
super(lossReduce, predictions, weights, labels);
}
public CtcLoss(){ }
@Override
public String opName() {
return "ctc_loss";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> grad){
//No external gradient
//Args are: predictions, weights, label
return new CtcLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs();
}
}

View File

@ -0,0 +1,48 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.linalg.api.ops.impl.loss.bp;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import java.util.List;
public class CtcLossBp extends BaseLossBp {
public CtcLossBp(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels){
super(sameDiff, lossReduce, predictions, weights, labels);
}
public CtcLossBp(){ }
@Override
public String opName() {
return "ctc_loss_grad";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> grad){
throw new UnsupportedOperationException("Differentiation of " + getClass().getName() + " not supported");
}
}

View File

@ -391,7 +391,7 @@
For testing large zoo models, this may not be enough (so comment it out).
-->
<argLine>-Dfile.encoding=UTF-8 -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
<argLine>-Dfile.encoding=UTF-8 "</argLine>
</configuration>
</plugin>
</plugins>

View File

@ -107,7 +107,7 @@
<include>*.java</include>
<include>**/*.java</include>
</includes>
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
<argLine> "</argLine>
</configuration>
</plugin>
</plugins>

View File

@ -99,7 +99,7 @@
<include>*.java</include>
<include>**/*.java</include>
</includes>
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
<argLine> "</argLine>
</configuration>
</plugin>
</plugins>

View File

@ -125,7 +125,7 @@
For testing large zoo models, this may not be enough (so comment it out).
-->
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
<argLine> "</argLine>
</configuration>
</plugin>
</plugins>

View File

@ -103,7 +103,7 @@
For testing large zoo models, this may not be enough (so comment it out).
-->
<argLine> -Dfile.encoding=UTF-8 -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
<argLine> -Dfile.encoding=UTF-8 "</argLine>
</configuration>
</plugin>
</plugins>

View File

@ -159,7 +159,7 @@
For testing large zoo models, this may not be enough (so comment it out).
-->
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
<argLine> "</argLine>
</configuration>
</plugin>
</plugins>

View File

@ -114,7 +114,7 @@
For testing large zoo models, this may not be enough (so comment it out).
-->
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
<argLine> "</argLine>
</configuration>
</plugin>
</plugins>