Add ctc loss from KonduitAI PR, add missing java bits
parent
b7e433a22a
commit
c3f04caef4
|
@ -159,7 +159,7 @@
|
|||
<artifactId>maven-surefire-plugin</artifactId>
|
||||
<version>${maven-surefire-plugin.version}</version>
|
||||
<configuration>
|
||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> "</argLine>
|
||||
|
||||
<!--
|
||||
By default: Surefire will set the classpath based on the manifest. Because tests are not included
|
||||
|
|
|
@ -167,7 +167,7 @@
|
|||
|
||||
For testing large zoo models, this may not be enough (so comment it out).
|
||||
-->
|
||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> "</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
|
@ -230,7 +230,7 @@
|
|||
-->
|
||||
<useSystemClassLoader>true</useSystemClassLoader>
|
||||
<useManifestOnlyJar>false</useManifestOnlyJar>
|
||||
<argLine> -Dfile.encoding=UTF-8 -Xmx8g -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> -Dfile.encoding=UTF-8 -Xmx8g "</argLine>
|
||||
<includes>
|
||||
<!-- Default setting only runs tests that start/end with "Test" -->
|
||||
<include>*.java</include>
|
||||
|
@ -331,7 +331,7 @@
|
|||
|
||||
For testing large zoo models, this may not be enough (so comment it out).
|
||||
-->
|
||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> "</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2021 Deeplearning4j Contributors
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*******************************************************************************/
|
||||
|
||||
//
|
||||
// @author AbdelRauf
|
||||
//
|
||||
|
||||
|
||||
#include <system/op_boilerplate.h>
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <ops/declarable/helpers/ctcLoss.h>
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(ctc_loss, 4, 1, false, 0, 1) {
|
||||
|
||||
auto targetLabels = INPUT_VARIABLE(0);
|
||||
auto logitInput = INPUT_VARIABLE(1);
|
||||
auto targetLabelLengths = INPUT_VARIABLE(2);
|
||||
auto logitInputLengths = INPUT_VARIABLE(3);
|
||||
auto outputLosses = OUTPUT_VARIABLE(0);
|
||||
|
||||
int blankIndex = INT_ARG(0);
|
||||
|
||||
REQUIRE_TRUE(targetLabels->rankOf()==2, 0, "CtcLoss: target labels fails to meet rank requirement (batch_size, max_label_sequence_length): %i == 2 ", targetLabels->rankOf());
|
||||
REQUIRE_TRUE(logitInput->rankOf()==3, 0, "CtcLoss: logit Input fails to meet rank requirement (batch_size, frames, classes): %i == 3 ", logitInput->rankOf());
|
||||
REQUIRE_TRUE(targetLabelLengths->rankOf()==1, 0, "CtcLoss: target label length fails to meet rank requirement (batch_size): %i == 1 ", targetLabelLengths->rankOf());
|
||||
REQUIRE_TRUE(logitInputLengths->rankOf()==1, 0, "CtcLoss: logit Input lengths fails to meet rank requirement (batch_size): %i == 1 ", logitInputLengths->rankOf());
|
||||
|
||||
int batchSize0 = targetLabels->sizeAt(0);
|
||||
int batchSize1 = logitInput->sizeAt(0);
|
||||
int batchSize2 = targetLabelLengths->sizeAt(0);
|
||||
int batchSize3 = logitInputLengths->sizeAt(0);
|
||||
int batchSize4 = outputLosses->sizeAt(0);
|
||||
|
||||
bool check_batches = (batchSize0 == batchSize1) && (batchSize2 == batchSize3);
|
||||
check_batches = check_batches && (batchSize0 == batchSize4) && (batchSize0 == batchSize2);
|
||||
|
||||
REQUIRE_TRUE(check_batches, 0, "CtcLoss: All batch sizes should be equal %i", batchSize0);
|
||||
REQUIRE_TRUE(outputLosses->isSameShape(targetLabelLengths), 0, "CtcLoss: wrong shape of output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(targetLabelLengths).c_str(), ShapeUtils::shapeAsString(outputLosses).c_str());
|
||||
|
||||
auto emptyGradients = NDArrayFactory::empty<float>();
|
||||
sd::ops::helpers::ctcLoss(block, *logitInput, *targetLabels, *logitInputLengths, *targetLabelLengths, *outputLosses, emptyGradients, blankIndex);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_TYPES(ctc_loss) {
|
||||
getOpDescriptor()->setAllowedInputTypes({ALL_INDICES})
|
||||
->setAllowedInputTypes(1,{ALL_FLOATS})
|
||||
->setAllowedOutputTypes({ALL_FLOATS});
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_SHAPE_FN(ctc_loss) {
|
||||
auto yShapeInfo = inputShape->at(1);
|
||||
auto zShapeInfo = inputShape->at(2);
|
||||
|
||||
auto dtype = ArrayOptions::dataType(yShapeInfo);
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(zShapeInfo, dtype)));
|
||||
}
|
||||
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(ctc_loss_grad, 4, 1, false, 0, 1) {
|
||||
|
||||
auto targetLabels = INPUT_VARIABLE(0);
|
||||
auto logitInput = INPUT_VARIABLE(1);
|
||||
auto targetLabelLengths = INPUT_VARIABLE(2);
|
||||
auto logitInputLengths = INPUT_VARIABLE(3);
|
||||
auto outputGradients = OUTPUT_VARIABLE(0);
|
||||
|
||||
int blankIndex = INT_ARG(0);
|
||||
|
||||
REQUIRE_TRUE(targetLabels->rankOf()==2, 0, "CtcLoss: target labels fails to meet rank requirement (batch_size, max_label_sequence_length): %i == 2 ", targetLabels->rankOf());
|
||||
REQUIRE_TRUE(logitInput->rankOf()==3, 0, "CtcLoss: logit Input fails to meet rank requirement (batch_size, frames, classes): %i == 3 ", logitInput->rankOf());
|
||||
REQUIRE_TRUE(targetLabelLengths->rankOf()==1, 0, "CtcLoss: target label length fails to meet rank requirement (batch_size): %i == 1 ", targetLabelLengths->rankOf());
|
||||
REQUIRE_TRUE(logitInputLengths->rankOf()==1, 0, "CtcLoss: logit Input lengths fails to meet rank requirement (batch_size): %i == 1 ", logitInputLengths->rankOf());
|
||||
|
||||
int batchSize0 = targetLabels->sizeAt(0);
|
||||
int batchSize1 = logitInput->sizeAt(0);
|
||||
int batchSize2 = targetLabelLengths->sizeAt(0);
|
||||
int batchSize3 = logitInputLengths->sizeAt(0);
|
||||
int batchSize4 = outputGradients->sizeAt(0);
|
||||
|
||||
bool check_batches = (batchSize0 == batchSize1) && (batchSize2 == batchSize3);
|
||||
check_batches = check_batches && (batchSize0 == batchSize4) && (batchSize0 == batchSize2);
|
||||
|
||||
REQUIRE_TRUE(check_batches, 0, "CtcLoss Gradient: All batch sizes should be equal %i", batchSize0);
|
||||
REQUIRE_TRUE(outputGradients->isSameShape(logitInput), 0, "CtcLoss Gradient: wrong shape of output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(logitInput).c_str(), ShapeUtils::shapeAsString(outputGradients).c_str());
|
||||
|
||||
auto emptyLoss = NDArrayFactory::empty<float>();
|
||||
sd::ops::helpers::ctcLoss(block, *logitInput, *targetLabels, *logitInputLengths, *targetLabelLengths, emptyLoss, *outputGradients, blankIndex);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_TYPES(ctc_loss_grad) {
|
||||
getOpDescriptor()->setAllowedInputTypes({ALL_INDICES})
|
||||
->setAllowedInputTypes(1,{ALL_FLOATS})
|
||||
->setAllowedOutputTypes({ALL_FLOATS});
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_SHAPE_FN(ctc_loss_grad) {
|
||||
auto yShapeInfo = inputShape->at(1);
|
||||
auto dtype = ArrayOptions::dataType(yShapeInfo);
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(yShapeInfo, dtype)));
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
|
@ -360,6 +360,27 @@ namespace ops {
|
|||
DECLARE_CUSTOM_OP(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Implementation of CTC loss function
|
||||
*
|
||||
* Input arrays:
|
||||
* 0: labels - labels NDArray {BATCH_LEN, MAX_TARGET_LEN}, type integer
|
||||
* 1: logits - logits NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN }. log softmax of rnn output. It should include a blank label as well, type float
|
||||
* 2: targetLabelLengths - Length of label sequence in labels NDArray {BATCH_LEN}, type integer
|
||||
* 3: logitsLengths - Length of input sequence in logits NDArray {BATCH_LEN}, type integer
|
||||
*
|
||||
*
|
||||
* Input integer arguments:
|
||||
* 0: blank index - index of the blank label in logits
|
||||
*
|
||||
* Output array:
|
||||
* 0: loss values, type float. NDArray {BATCH_LEN} negative log probabilities of loss
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_ctc_loss)
|
||||
DECLARE_CUSTOM_OP(ctc_loss, 4, 1, false, 0, 1);
|
||||
DECLARE_CUSTOM_OP(ctc_loss_grad, 4, 1, false, 0, 1);
|
||||
#endif
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,507 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2021 Deeplearning4j Contributors
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*******************************************************************************/
|
||||
|
||||
//
|
||||
// @author AbdelRauf
|
||||
//
|
||||
|
||||
#include <type_traits>
|
||||
#include <cmath>
|
||||
#include <stdexcept>
|
||||
#include <memory>
|
||||
#include <execution/Threads.h>
|
||||
#include <execution/ThreadPool.h>
|
||||
#include <helpers/LoopsCoordsHelper.h>
|
||||
#include <ops/declarable/helpers/ctcLoss.h>
|
||||
|
||||
namespace sd
|
||||
{
|
||||
namespace ops
|
||||
{
|
||||
namespace helpers
|
||||
{
|
||||
|
||||
//choose ptr[index*element_stride]
|
||||
template <bool Strided, typename Type>
|
||||
typename std::enable_if<Strided == true, Type &>::type
|
||||
element(Type *ptr, int index, int element_stride)
|
||||
{
|
||||
return ptr[index * element_stride];
|
||||
}
|
||||
|
||||
//choose ptr[index] assuming element_stride is 1
|
||||
template <bool Strided, typename Type>
|
||||
typename std::enable_if<Strided == false, Type &>::type
|
||||
element(Type *ptr, int index, int element_stride)
|
||||
{
|
||||
return ptr[index];
|
||||
}
|
||||
|
||||
template <bool IsLogPStrided = false, bool IsLblStrided = false, typename Type, typename IndexType>
|
||||
Type forward(Type *alphaPtr, const Nd4jLong &incA, const Type *logP, const Nd4jLong &incP, const IndexType *lbl, const Nd4jLong &lenSB, const Nd4jLong &lenT, const int &blankIndex, int elwiseP = 1, int elwiseS = 1)
|
||||
{
|
||||
Type negInf = -DataTypeUtils::infOrMax<Type>();
|
||||
//initialize alphas at t=0
|
||||
alphaPtr[0] = element<IsLogPStrided>(logP, blankIndex, elwiseP);
|
||||
//alphaPtr[1] =logP[lbl[0]];
|
||||
alphaPtr[1] = element<IsLogPStrided>(logP, *lbl, elwiseP);
|
||||
//the rest initialization was skipped
|
||||
//as its assumed the array already were initialized with negative infinity
|
||||
//move to the next frame
|
||||
Type *alphaPrevPtr = alphaPtr;
|
||||
alphaPtr += incA;
|
||||
logP += incP;
|
||||
|
||||
auto startX = lenSB - 2 * lenT;
|
||||
//process the rest
|
||||
for (auto t = 1; t < lenT; t++)
|
||||
{
|
||||
|
||||
//start = max(0,L-2*(T-t))
|
||||
auto s = startX + 2 * t;
|
||||
s = s > 0 ? s : 0;
|
||||
for (; s < lenSB; s++)
|
||||
{
|
||||
auto ind = s / 2; //our real index
|
||||
//we force blanks for even indexes
|
||||
//strided version of lbl[ind] => element<IsLblStrided>(lbl, ind, elwiseS)
|
||||
auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS);
|
||||
// {t-1,s}
|
||||
Type alphaS = alphaPrevPtr[s];
|
||||
Type alphaS_1 = s > 0 ? alphaPrevPtr[s - 1] : negInf;
|
||||
Type cMax = std::max(alphaS, alphaS_1);
|
||||
//logP[currentInd] or logP[currentInd*elwiseP]
|
||||
auto currentProb = element<IsLogPStrided>(logP, currentInd, elwiseP);
|
||||
// if blank or the same as previous
|
||||
if (s > 1 && currentInd != blankIndex && currentInd != element<IsLblStrided>(lbl, ind - 1, elwiseS))
|
||||
{
|
||||
Type alphaS_2 = alphaPrevPtr[s - 2];
|
||||
cMax = std::max(cMax, alphaS_2);
|
||||
if (cMax == negInf)
|
||||
cMax = 0;
|
||||
alphaPtr[s] = std::log(std::exp(alphaS - cMax) + std::exp(alphaS_1 - cMax) + std::exp(alphaS_2 - cMax)) + cMax + currentProb;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (cMax == negInf)
|
||||
cMax = 0;
|
||||
alphaPtr[s] = std::log(std::exp(alphaS - cMax) + std::exp(alphaS_1 - cMax)) + cMax + currentProb;
|
||||
}
|
||||
}
|
||||
|
||||
//store t-1 alpha Ptr
|
||||
alphaPrevPtr = alphaPtr;
|
||||
logP += incP;
|
||||
alphaPtr += incA;
|
||||
}
|
||||
auto logP0 = alphaPrevPtr[lenSB - 1];
|
||||
auto logP1 = alphaPrevPtr[lenSB - 2];
|
||||
auto cMax = std::max(logP0, logP1);
|
||||
return -(std::log(std::exp(logP0 - cMax) + std::exp(logP1 - cMax)) + cMax);
|
||||
}
|
||||
|
||||
//#undef CALCULATE_ALL_IN_ONE_FRAME_LOOP
|
||||
|
||||
template <bool IsLogPStrided = false, bool IsLblStrided = false, bool isGradStrided = false, typename Type, typename IndexType = int>
|
||||
void backwardAndGrad(Type forwardLogLoss, Type *alphaPtr, Type *bettaPtr, int incA,const Type *logP, int incP, Type *gradPtr, int incG,const IndexType *lbl,
|
||||
const Nd4jLong &lenS, const Nd4jLong &lenT, const Nd4jLong &lenK, const int &blankIndex,
|
||||
int elwiseP = 1, int elwiseS = 1, int elwiseG = 1)
|
||||
{
|
||||
|
||||
Type negInf = -DataTypeUtils::infOrMax<Type>();
|
||||
Nd4jLong lenSB = 2 * lenS + 1;
|
||||
auto origBetta = bettaPtr;
|
||||
auto origLogP = logP;
|
||||
//move to the last frame
|
||||
bettaPtr += (lenT - 1) * incA;
|
||||
logP += (lenT - 1) * incP;
|
||||
|
||||
//initialize bettas at t=lenT
|
||||
bettaPtr[lenSB - 1] = element<IsLogPStrided>(logP, blankIndex, elwiseP);
|
||||
auto lblIndex = element<IsLblStrided>(lbl, lenS - 1, elwiseS);
|
||||
bettaPtr[lenSB - 2] = element<IsLogPStrided>(logP, lblIndex, elwiseP); // logP[lbl[lenS - 1]];
|
||||
|
||||
#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
|
||||
//move to the last
|
||||
gradPtr += (lenT - 1) * incG;
|
||||
alphaPtr += (lenT - 1) * incA;
|
||||
for (auto s = lenSB - 1; s >= 0; s--)
|
||||
{
|
||||
auto ind = s / 2; //our real index
|
||||
//we forced blanks for even indexes
|
||||
auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS);
|
||||
//alpha(s)*betta(s) in log scale but still store in alpha to save memory
|
||||
auto alphaBettaS = alphaPtr[s] + bettaPtr[s];
|
||||
|
||||
//sum (alpha(s)*betta(s) ) over real indexes
|
||||
auto ¤tGrad = element<isGradStrided>(gradPtr, currentInd, elwiseG); // gradPtr[currentInd];
|
||||
if (currentGrad == negInf)
|
||||
{
|
||||
currentGrad = alphaBettaS;
|
||||
}
|
||||
else
|
||||
{
|
||||
Type cMax = std::max(currentGrad, alphaBettaS);
|
||||
currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax;
|
||||
}
|
||||
}
|
||||
for (int k = 0; k < lenK; k++)
|
||||
{
|
||||
//compute the rest grad
|
||||
|
||||
// prob(t,k) - grad(k) / ((prob(t,k)*Z) )
|
||||
|
||||
// p2= grad(k) / (prob(t,k)*Z )
|
||||
//in logscale . plus we have Z as -logLoss
|
||||
// auto p2 = std::exp(gradPtr[k] + forwardLogLoss - logP[k]);
|
||||
// gradPtr[k] = std::exp(logP[k]) - p2;
|
||||
auto currentProb = element<IsLogPStrided>(logP, k, elwiseP);
|
||||
auto ¤tGrad = element<isGradStrided>(gradPtr, k, elwiseG);
|
||||
auto p2 = std::exp(currentGrad + forwardLogLoss - currentProb);
|
||||
currentGrad = std::exp(currentProb) - p2;
|
||||
}
|
||||
gradPtr -= incG;
|
||||
alphaPtr -= incA;
|
||||
#endif
|
||||
|
||||
auto bettaPrevPtr = bettaPtr;
|
||||
bettaPtr -= incA;
|
||||
logP -= incP;
|
||||
//process the rest
|
||||
for (auto t = lenT - 2; t >= 0; t--)
|
||||
{
|
||||
|
||||
#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
|
||||
auto end = lenSB - 1;
|
||||
#else
|
||||
auto end = std::min(2 * t + 2, lenSB - 1);
|
||||
#endif
|
||||
for (auto s = end; s >= 0; s--)
|
||||
{
|
||||
auto ind = s / 2; //our real index
|
||||
//we forced blanks for even indexes
|
||||
auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS); //lbl[ind];
|
||||
// {t-1,s}
|
||||
Type bettaS = bettaPrevPtr[s];
|
||||
Type bettaS_1 = s < lenSB - 1 ? bettaPrevPtr[s + 1] : negInf;
|
||||
Type cMax = std::max(bettaS, bettaS_1);
|
||||
//logP[currentInd]
|
||||
auto currentProb = element<IsLogPStrided>(logP, currentInd, elwiseP);
|
||||
// if blank or the same as previous
|
||||
if (s < lenSB - 2 && currentInd != blankIndex && currentInd != element<IsLblStrided>(lbl, ind + 1, elwiseS))
|
||||
{
|
||||
Type bettaS_2 = bettaPrevPtr[s + 2];
|
||||
cMax = std::max(cMax, bettaS_2);
|
||||
if (cMax == negInf)
|
||||
cMax = 0;
|
||||
bettaPtr[s] = std::log(std::exp(bettaS - cMax) + std::exp(bettaS_1 - cMax) + std::exp(bettaS_2 - cMax)) + cMax + currentProb;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (cMax == negInf)
|
||||
cMax = 0;
|
||||
bettaPtr[s] = std::log(std::exp(bettaS - cMax) + std::exp(bettaS_1 - cMax)) + cMax + currentProb;
|
||||
}
|
||||
|
||||
#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
|
||||
//alpha(s)*betta(s) in log scale but still store in alpha to save memory
|
||||
auto alphaBettaS = alphaPtr[s] + bettaPtr[s];
|
||||
|
||||
//sum (alpha(s)*betta(s) ) over real indexes
|
||||
auto ¤tGrad = element<isGradStrided>(gradPtr, currentInd, elwiseG); // gradPtr[currentInd];
|
||||
if (currentGrad == negInf)
|
||||
{
|
||||
currentGrad = alphaBettaS;
|
||||
}
|
||||
else
|
||||
{
|
||||
Type cMax = std::max(currentGrad, alphaBettaS);
|
||||
currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax;
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
|
||||
for (int k = 0; k < lenK; k++)
|
||||
{
|
||||
//compute the rest grad
|
||||
|
||||
// prob(t,k) - grad(k) / ((prob(t,k)*Z) )
|
||||
|
||||
// p2= grad(k) / (prob(t,k)*Z )
|
||||
//in logscale . plus we have Z as -logLoss
|
||||
// auto p2 = std::exp(gradPtr[k] + forwardLogLoss - logP[k]);
|
||||
// gradPtr[k] = std::exp(logP[k]) - p2;
|
||||
auto currentProb = element<IsLogPStrided>(logP, k, elwiseP);
|
||||
auto ¤tGrad = element<isGradStrided>(gradPtr, k, elwiseG);
|
||||
auto p2 = std::exp(currentGrad + forwardLogLoss - currentProb);
|
||||
currentGrad = std::exp(currentProb) - p2;
|
||||
}
|
||||
alphaPtr -= incA;
|
||||
gradPtr -= incG;
|
||||
#endif
|
||||
|
||||
bettaPrevPtr = bettaPtr;
|
||||
bettaPtr -= incA;
|
||||
logP -= incP;
|
||||
}
|
||||
|
||||
auto logBP0 = bettaPrevPtr[0];
|
||||
auto logBP1 = bettaPrevPtr[1];
|
||||
auto bcMax = std::max(logBP0, logBP1);
|
||||
auto blogLoss = -(std::log(std::exp(logBP0 - bcMax) + std::exp(logBP1 - bcMax)) + bcMax);
|
||||
|
||||
#if !defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
|
||||
//alpha*betta
|
||||
bettaPtr = origBetta;
|
||||
logP = origLogP;
|
||||
|
||||
for (int t = 0; t < lenT; t++)
|
||||
{
|
||||
|
||||
for (int s = 0; s < lenSB; s++)
|
||||
{
|
||||
auto ind = s / 2; //our real index
|
||||
//we forced blanks for even indexes
|
||||
auto currentInd = (s % 2 == 0) ? blankIndex : element<IsLblStrided>(lbl, ind, elwiseS); //lbl[ind];
|
||||
//alpha(s)*betta(s) in log scale but still store in alpha to save memory
|
||||
auto alphaBettaS = alphaPtr[s] + bettaPtr[s];
|
||||
|
||||
//sum (alpha(s)*betta(s) ) over real indexes
|
||||
auto ¤tGrad = element<isGradStrided>(gradPtr, currentInd, elwiseG); // gradPtr[currentInd];
|
||||
if (currentGrad == negInf)
|
||||
{
|
||||
currentGrad = alphaBettaS;
|
||||
}
|
||||
else
|
||||
{
|
||||
Type cMax = std::max(currentGrad, alphaBettaS);
|
||||
currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax;
|
||||
}
|
||||
//alphaPtr[s] = alphaBettaS;
|
||||
}
|
||||
|
||||
PRAGMA_OMP_SIMD
|
||||
for (int k = 0; k < lenK; k++)
|
||||
{
|
||||
//compute the rest grad
|
||||
|
||||
// prob(t,k) - grad(k) / ((prob(t,k)*Z) )
|
||||
|
||||
// p2= grad(k) / (prob(t,k)*Z )
|
||||
//in logscale . plus we have Z as -logLoss
|
||||
// auto p2 = std::exp(gradPtr[k] + forwardLogLoss - logP[k]);
|
||||
// gradPtr[k] = std::exp(logP[k]) - p2;
|
||||
auto currentProb = element<IsLogPStrided>(logP, k, elwiseP);
|
||||
auto ¤tGrad = element<isGradStrided>(gradPtr, k, elwiseG);
|
||||
auto p2 = std::exp(currentGrad + forwardLogLoss - currentProb);
|
||||
currentGrad = std::exp(currentProb) - p2;
|
||||
}
|
||||
|
||||
gradPtr += incG;
|
||||
bettaPtr += incA;
|
||||
alphaPtr += incA;
|
||||
logP += incP;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates ctc loss and fills gradients
|
||||
* @param logP logits matrix(lenT,lenK) pointer (log soft max input of rnn)
|
||||
* @param incP stride of logits for the next time frame
|
||||
* @param gradPtr gradient for output
|
||||
* @param incG stride of the gradient for the next time frame
|
||||
* @param lbl target label
|
||||
* @param lenT frame length
|
||||
* @param lenK class length
|
||||
* @param lenS target label length
|
||||
* @param blankIndex index of the blank label in logit class
|
||||
*/
|
||||
template <bool IsLogPStrided = true, bool IsLblStrided = true, bool IsGradStrided = true, typename Type, typename IndexType>
|
||||
Type unitLossAndGrad(const Type *logP, int incP, Type *gradPtr, int incG,const IndexType *lbl, int lenT, int lenK, int lenS, int blankIndex,
|
||||
int elwiseP = 1, int elwiseS = 1, int elwiseG = 1)
|
||||
{
|
||||
|
||||
auto lenSB = 2 * lenS + 1;
|
||||
//create temp Array for holding bettaArr [lenT,lenSB]
|
||||
//create temp Array for holding alphaArr [lenT,lenSB]
|
||||
int bufferC = gradPtr ? 2 : 1;
|
||||
NDArray bufferArr = NDArrayFactory::create<Type>('c', {bufferC, lenT, lenSB});
|
||||
auto bufferPtr = bufferArr.bufferAsT<Type>();
|
||||
auto incA = bufferArr.stridesOf()[1];
|
||||
auto bettaBufferPtr = bufferPtr + bufferArr.stridesOf()[0];
|
||||
Type negInf = -DataTypeUtils::infOrMax<Type>();
|
||||
|
||||
#if 1
|
||||
if (gradPtr)
|
||||
{
|
||||
if (elwiseG == 1)
|
||||
{
|
||||
PRAGMA_OMP_SIMD
|
||||
for (int i = 0; i < lenK * lenT; i++)
|
||||
{
|
||||
gradPtr[i] = negInf;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
auto tempPtr = gradPtr;
|
||||
for (int i = 0; i < lenT; i++)
|
||||
{
|
||||
for (int j = 0; j < lenK; j++)
|
||||
element<false>(tempPtr, j, elwiseG) = negInf;
|
||||
tempPtr += incG;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// set all vals to neginf
|
||||
PRAGMA_OMP_SIMD
|
||||
for (int i = 0; i < bufferC * lenSB * lenT; i++)
|
||||
{
|
||||
bufferPtr[i] = negInf;
|
||||
}
|
||||
|
||||
//forward
|
||||
Type logLoss = forward<IsLogPStrided, IsLblStrided>(bufferPtr, incA, logP, incP, lbl, lenSB, lenT, blankIndex, elwiseP, elwiseS);
|
||||
//backward and gradient if gradptr supplied
|
||||
if (gradPtr)
|
||||
backwardAndGrad<IsLogPStrided, IsLblStrided, IsGradStrided>(logLoss, bufferPtr, bettaBufferPtr, incA, logP, incP, gradPtr, incG, lbl, lenS, lenT, lenK, blankIndex, elwiseP, elwiseS, elwiseG);
|
||||
return logLoss;
|
||||
}
|
||||
|
||||
template <typename Type, typename IndexType>
|
||||
void
|
||||
ctc_loss_(const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex)
|
||||
{
|
||||
// lenT - input length of T
|
||||
// lenS - lenght of sequence
|
||||
// lenSB - length with blanks
|
||||
auto lenBatch = logits.shapeOf()[0];
|
||||
|
||||
auto maxLenT = logits.shapeOf()[1];
|
||||
auto lenK = logits.shapeOf()[2];
|
||||
auto maxLenS = targetLabels.shapeOf()[1];
|
||||
|
||||
// get probability bufer and tagetLabels buffer
|
||||
auto logP = logits.bufferAsT<Type>();
|
||||
auto lblPtr = targetLabels.bufferAsT<IndexType>();
|
||||
|
||||
auto lenTPtr = logitsLengths.bufferAsT<IndexType>();
|
||||
auto lenSPtr = targetLabelLengths.bufferAsT<IndexType>();
|
||||
|
||||
auto batchLbl = targetLabels.stridesOf()[0];
|
||||
auto batchP = logits.stridesOf()[0];
|
||||
auto incP = logits.stridesOf()[1];
|
||||
|
||||
auto elwiseSLen = targetLabelLengths.stridesOf()[0];
|
||||
auto elwiseT = logitsLengths.stridesOf()[0];
|
||||
auto elwiseS = targetLabels.stridesOf()[1];
|
||||
auto elwiseP = logits.stridesOf()[2];
|
||||
|
||||
int elwiseLL = 0;
|
||||
Type *logLossPtr = nullptr;
|
||||
if (!logLosses.isEmpty()){
|
||||
elwiseLL = logLosses.stridesOf()[0];
|
||||
logLossPtr = logLosses.bufferAsT<Type>();
|
||||
}
|
||||
|
||||
auto func = [logP, batchP, incP, elwiseP, lenK, lenTPtr, lenSPtr, logLossPtr, lblPtr, maxLenT, maxLenS,
|
||||
batchLbl, blankIndex, elwiseT, elwiseLL, elwiseSLen, elwiseS, &gradients](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
|
||||
Type *gradPtr = nullptr;
|
||||
Type resultLoss;
|
||||
int batchG, incG, elwiseG;
|
||||
if (!gradients.isEmpty())
|
||||
{
|
||||
batchG = gradients.stridesOf()[0];
|
||||
incG = gradients.stridesOf()[1];
|
||||
elwiseG = gradients.stridesOf()[2];
|
||||
gradPtr = gradients.bufferAsT<Type>() + start * batchG;
|
||||
}else{
|
||||
elwiseG=1;
|
||||
}
|
||||
auto logPtr = logP + start * batchP;
|
||||
auto tempLblPtr = lblPtr + start * batchLbl;
|
||||
|
||||
if (elwiseP == 1 && elwiseS == 1 && elwiseG == 1)
|
||||
{
|
||||
//choose ews one
|
||||
for (int batchIndex = start; batchIndex < stop; batchIndex += increment)
|
||||
{
|
||||
auto lenT = lenTPtr[batchIndex * elwiseT];
|
||||
auto lenS = lenSPtr[batchIndex * elwiseSLen];
|
||||
lenT = lenT > maxLenT ? maxLenT : lenT;
|
||||
lenS = lenS > maxLenS ? maxLenS : lenS;
|
||||
if (lenS <= 0 || lenT <= 0)
|
||||
{
|
||||
resultLoss = -DataTypeUtils::infOrMax<Type>();
|
||||
}
|
||||
else
|
||||
{
|
||||
if (lenS > lenT)
|
||||
lenS = lenT;
|
||||
resultLoss = unitLossAndGrad<false, false, false, Type, IndexType>(logPtr, incP, gradPtr, incG, tempLblPtr, lenT, lenK, lenS, blankIndex);
|
||||
}
|
||||
if (gradPtr) gradPtr += batchG;
|
||||
if (logLossPtr) logLossPtr[batchIndex * elwiseLL] = resultLoss;
|
||||
logPtr += batchP;
|
||||
tempLblPtr += batchLbl;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
//slow strided case for all 3
|
||||
for (int batchIndex = start; batchIndex < stop; batchIndex += increment)
|
||||
{
|
||||
auto lenT = lenTPtr[batchIndex * elwiseT];
|
||||
auto lenS = lenSPtr[batchIndex * elwiseSLen];
|
||||
lenT = lenT > maxLenT ? maxLenT : lenT;
|
||||
lenS = lenS > maxLenS ? maxLenS : lenS;
|
||||
if (lenS <= 0 || lenT <= 0)
|
||||
{
|
||||
resultLoss = -DataTypeUtils::infOrMax<Type>();
|
||||
}
|
||||
else
|
||||
{
|
||||
if (lenS > lenT)
|
||||
lenS = lenT;
|
||||
resultLoss = unitLossAndGrad<true,true,true,Type, IndexType>(logPtr, incP, gradPtr, incG, tempLblPtr, lenT, lenK, lenS, blankIndex, elwiseP, elwiseS, elwiseG);
|
||||
}
|
||||
if (gradPtr) gradPtr += batchG;
|
||||
if (logLossPtr) logLossPtr[batchIndex * elwiseLL] = resultLoss;
|
||||
logPtr += batchP;
|
||||
tempLblPtr += batchLbl;
|
||||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(func, 0, lenBatch, 1);
|
||||
}
|
||||
|
||||
void ctcLoss(graph::Context& block, const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex){
|
||||
|
||||
BUILD_DOUBLE_SELECTOR(logits.dataType(), targetLabels.dataType(), ctc_loss_, (logits, targetLabels, logitsLengths, targetLabelLengths, logLosses, gradients, blankIndex), FLOAT_TYPES, INDEXING_TYPES);
|
||||
}
|
||||
|
||||
|
||||
BUILD_DOUBLE_TEMPLATE(template void ctc_loss_, (const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex), FLOAT_TYPES, INDEXING_TYPES);
|
||||
|
||||
|
||||
} // namespace helpers
|
||||
} // namespace ops
|
||||
} // namespace sd
|
|
@ -0,0 +1,55 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2021 Deeplearning4j Contributors
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*******************************************************************************/
|
||||
|
||||
//
|
||||
// @author AbdelRauf
|
||||
//
|
||||
|
||||
#ifndef LIBND4J_HELPERS_CTCLOSS_H
|
||||
#define LIBND4J_HELPERS_CTCLOSS_H
|
||||
|
||||
#include <ops/declarable/helpers/helpers.h>
|
||||
#include <graph/Context.h>
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
/**
|
||||
* @brief Implementation of CTC loss function
|
||||
* References:
|
||||
Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
|
||||
with Recurrent Neural Networks:
|
||||
[Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
|
||||
([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
|
||||
*
|
||||
* @param block Context
|
||||
* @param logits NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN }. log softmax of rnn output. It should include a blank label as well.
|
||||
* @param targetLabels NDArray {BATCH_LEN, MAX_TARGET_LEN}
|
||||
* @param logitsLengths NDArray {BATCH_LEN} Length of input sequence in logits
|
||||
* @param targetLabelLengths NDArray {BATCH_LEN} Length of label sequence in labels
|
||||
* @param logLosses NDArray {BATCH_LEN} or EMPTY. if empty it will be skipped. negative log probabilities of loss
|
||||
* @param gradients NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN } or EMPTY. gradients
|
||||
* @param blankIndex index of the blank label in logits
|
||||
*/
|
||||
void ctcLoss(graph::Context& block, const NDArray &logitsInput, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#endif // LIBND4J_ADDBIAS_H
|
|
@ -0,0 +1,225 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2021 Deeplearning4j Contributors
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*******************************************************************************/
|
||||
//
|
||||
// @author AbdelRauf
|
||||
//
|
||||
|
||||
#include "cudnnUtils.h"
|
||||
#include <array/NDArrayFactory.h>
|
||||
#include <vector>
|
||||
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
|
||||
|
||||
|
||||
template<typename Op, typename ...Args>
|
||||
void callCudnnIfNoErr(cudnnStatus_t &err, Op op, Args&&... args){
|
||||
if(err==CUDNN_STATUS_SUCCESS){
|
||||
err = op(std::forward<Args>(args)...);
|
||||
if(err){
|
||||
nd4j_printf("Cudnn error code %s\n",cudnnGetErrorString(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T* bufferInHost( const NDArray &array) {
|
||||
array.syncToHost();
|
||||
return reinterpret_cast<const T*>(array.buffer());
|
||||
}
|
||||
|
||||
std::vector<int> getConcatTargets(const NDArray &targetLabels, const NDArray &targetLabelLengths){
|
||||
//concatenate target labels
|
||||
const int32_t *tlabels = bufferInHost<int32_t>(targetLabels);
|
||||
const int32_t *tlens =bufferInHost<int32_t>(targetLabelLengths);
|
||||
int32_t nextOffset = targetLabels.strideAt(0);
|
||||
int32_t elStride = targetLabels.strideAt(1);
|
||||
int32_t batchCount = targetLabelLengths.lengthOf();
|
||||
std::vector<int> labels;
|
||||
labels.resize(targetLabels.lengthOf());
|
||||
int j=0;
|
||||
if(targetLabels.ews()){
|
||||
for(int i=0; i<batchCount;i++){
|
||||
int count = tlens[i];
|
||||
for( int k=0;k<count;k++){
|
||||
labels[j] = tlabels[k];
|
||||
j++;
|
||||
}
|
||||
tlabels+=nextOffset;
|
||||
}
|
||||
}else{
|
||||
for(int i=0; i<batchCount;i++){
|
||||
int count = tlens[i];
|
||||
for( int k=0;k<count;k++){
|
||||
labels[j] = tlabels[k*elStride];
|
||||
j++;
|
||||
}
|
||||
tlabels+=nextOffset;
|
||||
}
|
||||
}
|
||||
return labels;
|
||||
}
|
||||
|
||||
cudnnStatus_t cudnnCtcLoss(const LaunchContext &context, const NDArray &probs, const int32_t* targetLabelsPtr, const NDArray& probInputLengthes,
|
||||
const NDArray &targetLabelLengths, NDArray &ctcLosses, NDArray &grads){
|
||||
const int dims[] = {(int)probs.sizeAt(0), (int)probs.sizeAt(1), (int)probs.sizeAt(2)};
|
||||
const int strides[] = {(int)probs.strideAt(0), (int)probs.strideAt(1), (int)probs.strideAt(2)};
|
||||
auto handle = reinterpret_cast<cudnnHandle_t *>(context.getCuDnnHandle());
|
||||
cudnnStatus_t err = CUDNN_STATUS_SUCCESS;
|
||||
callCudnnIfNoErr(err, cudnnSetStream, *handle, *context.getCudaStream());
|
||||
|
||||
cudnnCTCLossDescriptor_t ctcLossDesc;
|
||||
cudnnTensorDescriptor_t probsDesc = nullptr;
|
||||
cudnnTensorDescriptor_t gradsDesc = nullptr;
|
||||
callCudnnIfNoErr(err, cudnnCreateCTCLossDescriptor, &ctcLossDesc);
|
||||
callCudnnIfNoErr(err, cudnnSetCTCLossDescriptorEx, ctcLossDesc, CUDNN_DATA_FLOAT, CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_PROPAGATE_NAN);
|
||||
callCudnnIfNoErr(err, cudnnCreateTensorDescriptor, &probsDesc);
|
||||
callCudnnIfNoErr(err, cudnnSetTensorNdDescriptor, probsDesc, cudnnDataType(probs.dataType()), probs.rankOf() , dims, strides);
|
||||
if(!grads.isEmpty()){
|
||||
const int gradStrides[] = {(int)grads.strideAt(0), (int)grads.strideAt(1), (int)grads.strideAt(2)};
|
||||
callCudnnIfNoErr(err, cudnnCreateTensorDescriptor, &gradsDesc);
|
||||
callCudnnIfNoErr(err, cudnnSetTensorNdDescriptor, gradsDesc, cudnnDataType(grads.dataType()), grads.rankOf() , dims, gradStrides);
|
||||
}
|
||||
|
||||
size_t tempWorkSpaceSize=0;
|
||||
callCudnnIfNoErr(err,cudnnGetCTCLossWorkspaceSize, *handle, probsDesc, gradsDesc,
|
||||
targetLabelsPtr,
|
||||
bufferInHost<int32_t>(targetLabelLengths),
|
||||
bufferInHost<int32_t>(probInputLengthes),
|
||||
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
|
||||
ctcLossDesc, &tempWorkSpaceSize);
|
||||
|
||||
// Allocate temp tempWorkspace buffer
|
||||
void *tempWorkSpace = nullptr;
|
||||
cudaMalloc(&tempWorkSpace, tempWorkSpaceSize);
|
||||
|
||||
NDArray::prepareSpecialUse({&ctcLosses, &grads}, {&probs});
|
||||
callCudnnIfNoErr(err, cudnnCTCLoss,*handle,
|
||||
probsDesc,
|
||||
probs.specialBuffer(),
|
||||
targetLabelsPtr,
|
||||
bufferInHost<int32_t>(targetLabelLengths),
|
||||
bufferInHost<int32_t>(probInputLengthes),
|
||||
ctcLosses.specialBuffer(),
|
||||
gradsDesc,
|
||||
grads.specialBuffer(),
|
||||
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
|
||||
ctcLossDesc,
|
||||
tempWorkSpace,
|
||||
tempWorkSpaceSize);
|
||||
|
||||
NDArray::registerSpecialUse({&ctcLosses, &grads}, {&probs});
|
||||
|
||||
cudaFree(tempWorkSpace);
|
||||
callCudnnIfNoErr(err, cudnnDestroyTensorDescriptor,probsDesc);
|
||||
if(gradsDesc) callCudnnIfNoErr(err, cudnnDestroyTensorDescriptor,gradsDesc);
|
||||
callCudnnIfNoErr(err, cudnnDestroyCTCLossDescriptor,ctcLossDesc);
|
||||
return err;
|
||||
}
|
||||
|
||||
PLATFORM_IMPL(ctc_loss, ENGINE_CUDA) {
|
||||
auto targetLabels = INPUT_VARIABLE(0);
|
||||
auto logitInput = INPUT_VARIABLE(1);
|
||||
auto targetLabelLengths = INPUT_VARIABLE(2);
|
||||
auto logitInputLengths = INPUT_VARIABLE(3);
|
||||
auto outputLosses = OUTPUT_VARIABLE(0);
|
||||
auto context = block.launchContext();
|
||||
//in Cudnn Batch is in the middle dimension
|
||||
logitInput->permutei({1,0,2});
|
||||
//in Cudnn targets are concantenated instead of batched as matrix
|
||||
auto labels = getConcatTargets(*targetLabels, *targetLabelLengths);
|
||||
const int32_t *ldata= labels.data();
|
||||
auto emptyGrads= NDArrayFactory::empty<float>();
|
||||
auto err = cudnnCtcLoss(*context, *logitInput, ldata, *logitInputLengths, *targetLabelLengths, *outputLosses, emptyGrads);
|
||||
if(err!=CUDNN_STATUS_SUCCESS) throw sd::cuda_exception::build("ctc_loss CUDNN call failure ", err);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
bool checkLabelLength(const NDArray &labelLengthArr){
|
||||
//check label lengthes
|
||||
auto lenBatch = labelLengthArr.lengthOf();
|
||||
for(int i=0; i < lenBatch; i++){
|
||||
// The labelLengths is greater than 256.
|
||||
if(labelLengthArr.e<int32_t>(i)>256) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(ctc_loss, ENGINE_CUDA) {
|
||||
auto targetLabels = INPUT_VARIABLE(0);
|
||||
auto logitInput = INPUT_VARIABLE(1);
|
||||
auto targetLabelLengths = INPUT_VARIABLE(2);
|
||||
auto logitInputLengths = INPUT_VARIABLE(3);
|
||||
auto outputLosses = OUTPUT_VARIABLE(0);
|
||||
int blankIndex = INT_ARG(0);
|
||||
|
||||
auto dTypeInput = logitInput->dataType();
|
||||
auto intType = targetLabelLengths->dataType();
|
||||
auto dTypeOutput = outputLosses->dataType();
|
||||
|
||||
bool is_supported = blankIndex==0 && intType == DataType::INT32 && dTypeInput == DataType::FLOAT32;
|
||||
is_supported = is_supported && outputLosses->ews() && targetLabelLengths->ews() && targetLabels->ews() && logitInputLengths->ews();
|
||||
is_supported = is_supported && checkLabelLength<int32_t>(*targetLabelLengths);
|
||||
return is_supported;
|
||||
}
|
||||
|
||||
PLATFORM_IMPL(ctc_loss_grad, ENGINE_CUDA) {
|
||||
auto targetLabels = INPUT_VARIABLE(0);
|
||||
auto logitInput = INPUT_VARIABLE(1);
|
||||
auto targetLabelLengths = INPUT_VARIABLE(2);
|
||||
auto logitInputLengths = INPUT_VARIABLE(3);
|
||||
auto outputGradients = OUTPUT_VARIABLE(0);
|
||||
auto context = block.launchContext();
|
||||
//in Cudnn Batch is in the middle dimension
|
||||
logitInput->permutei({1,0,2});
|
||||
outputGradients->permutei({1,0,2});
|
||||
//in Cudnn targets are concantenated instead of batched as matrix
|
||||
auto labels = getConcatTargets(*targetLabels, *targetLabelLengths);
|
||||
const int32_t * ldata= labels.data();
|
||||
auto tempLosses = NDArrayFactory::create<float>('c', {logitInputLengths->sizeAt(0)});
|
||||
auto err = cudnnCtcLoss(*context, *logitInput, ldata, *logitInputLengths, *targetLabelLengths, tempLosses, *outputGradients);
|
||||
if(err!=CUDNN_STATUS_SUCCESS) throw sd::cuda_exception::build("ctc_loss CUDNN call failure ", err);
|
||||
//restore grads shape from {T, BATCH, C} -> {BATCHS, T, C}
|
||||
outputGradients->permutei({1,0,2});
|
||||
//tempLosses.printIndexedBuffer("tempLosses");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PLATFORM_CHECK(ctc_loss_grad, ENGINE_CUDA) {
|
||||
auto targetLabels = INPUT_VARIABLE(0);
|
||||
auto logitInput = INPUT_VARIABLE(1);
|
||||
auto targetLabelLengths = INPUT_VARIABLE(2);
|
||||
auto logitInputLengths = INPUT_VARIABLE(3);
|
||||
auto outputGrads = OUTPUT_VARIABLE(0);
|
||||
int blankIndex = INT_ARG(0);
|
||||
|
||||
auto dTypeInput = logitInput->dataType();
|
||||
auto intType = targetLabelLengths->dataType();
|
||||
auto dTypeOutput = outputGrads->dataType();
|
||||
|
||||
bool is_supported = blankIndex==0 && intType == DataType::INT32 && dTypeInput == DataType::FLOAT32;
|
||||
is_supported = is_supported && outputGrads->ews() && targetLabelLengths->ews() && targetLabels->ews() && logitInputLengths->ews();
|
||||
is_supported = is_supported && checkLabelLength<int32_t>(*targetLabelLengths);
|
||||
return is_supported;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -60,6 +60,9 @@ namespace platforms {
|
|||
DECLARE_PLATFORM(maxpool3dnew, ENGINE_CUDA);
|
||||
DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CUDA);
|
||||
|
||||
DECLARE_PLATFORM(ctc_loss, ENGINE_CUDA);
|
||||
DECLARE_PLATFORM(ctc_loss_grad, ENGINE_CUDA);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
FORCEINLINE cudnnDataType_t cudnnDataType(sd::DataType dataType) {
|
||||
switch (dataType) {
|
||||
|
|
|
@ -21,8 +21,7 @@
|
|||
#include <helpers/helper_hash.h>
|
||||
#include <array/NDArray.h>
|
||||
#include <array/NDArrayList.h>
|
||||
|
||||
|
||||
#include <numeric>
|
||||
using namespace sd;
|
||||
using namespace sd::graph;
|
||||
|
||||
|
@ -4179,3 +4178,208 @@ TEST_F(DeclarableOpsTests2, lstmCell_test12) {
|
|||
|
||||
|
||||
}
|
||||
|
||||
#if !defined(__CUDABLAS__) || defined(HAVE_CUDNN)
|
||||
TEST_F(DeclarableOpsTests2, ctc_loss_test1) {
|
||||
constexpr int FRAME_LEN = 6 ;
|
||||
constexpr int CLASS_LEN = 5 ;
|
||||
constexpr int BATCH_LEN = 4 ;
|
||||
constexpr int MIN_TARGET_LEN = 2;
|
||||
constexpr int MAX_TARGET_LEN = 4;
|
||||
|
||||
#if defined(HAVE_CUDNN)
|
||||
//cudnn blankindex should be 0
|
||||
constexpr int BLANK_INDEX=0;
|
||||
#else
|
||||
constexpr int BLANK_INDEX=CLASS_LEN-1;
|
||||
#endif
|
||||
//logits were generated using numpy random and applying log softmax
|
||||
//[ctc_loss.py](https://gist.github.com/quickwritereader/ca9858be201fd857348826a56e2bebc4)
|
||||
auto logits = NDArrayFactory::create<float>('c', {BATCH_LEN, FRAME_LEN, CLASS_LEN },
|
||||
{-1.52900087f, -1.7423916f, -1.79369985f, -1.68980741f, -1.35771429f,
|
||||
-2.08261997f, -1.65483307f, -1.31878488f, -1.38940393f, -1.78624192f,
|
||||
-1.83125744f, -1.28989651f, -1.86882736f, -1.51760877f, -1.65575026f,
|
||||
-1.59030191f, -2.09045484f, -2.01113821f, -1.31159853f, -1.3120046f,
|
||||
-1.45263472f, -1.52268525f, -1.6567962f, -2.06986454f, -1.46546941f,
|
||||
-1.25549694f, -1.86336982f, -1.64691575f, -1.69584239f, -1.69374889f,
|
||||
-1.62384788f, -1.53256338f, -1.47943003f, -1.9953089f, -1.49995189f,
|
||||
-1.58914748f, -2.14294273f, -1.89989005f, -1.26397295f, -1.40048678f,
|
||||
-1.52242117f, -1.79940303f, -1.86987214f, -1.41871056f, -1.51299132f,
|
||||
-1.41772259f, -1.27648263f, -1.87029582f, -1.71325761f, -1.93542947f,
|
||||
-1.4372372f, -1.72814911f, -1.18767571f, -1.85569031f, -2.09127332f,
|
||||
-1.99591619f, -1.17070749f, -1.91569048f, -1.66127429f, -1.52865783f,
|
||||
-1.39319926f, -2.19674832f, -1.69619098f, -1.37916537f, -1.58285964f,
|
||||
-1.85456282f, -1.91027747f, -1.35265643f, -1.76707679f, -1.32405154f,
|
||||
-1.70063352f, -1.82894304f, -1.81275811f, -1.76677183f, -1.13084056f,
|
||||
-2.01507311f, -1.50622804f, -1.55902412f, -1.4076143f, -1.66137954f,
|
||||
-1.72469437f, -1.74285619f, -1.72109242f, -1.54947478f, -1.36444454f,
|
||||
-1.78795939f, -1.62871901f, -1.43244094f, -1.83058005f, -1.43770547f,
|
||||
-1.3577647f, -1.81454222f, -1.58227661f, -1.89836191f, -1.49373763f,
|
||||
-1.52027507f, -1.41807732f, -1.54481537f, -1.86538837f, -1.76619851f,
|
||||
-1.64547283f, -1.58328753f, -1.58442673f, -1.65941447f, -1.57762943f,
|
||||
-1.54091641f, -1.76747862f, -1.56063854f, -1.76235545f, -1.45495771f,
|
||||
-1.37294933f, -1.75871646f, -1.38392315f, -1.62238305f, -2.06866473f,
|
||||
-1.98087487f, -1.49880371f, -2.14268396f, -1.22969736f, -1.47432277f
|
||||
});
|
||||
|
||||
auto logits_length = NDArrayFactory::create<int>('c', {BATCH_LEN}, {FRAME_LEN,FRAME_LEN,FRAME_LEN,FRAME_LEN});
|
||||
std::vector<int> target ={2, 2, 2, 0, 1, 1, 0, 0, 1, 2, 2, 3, 0, 2, 1, 2};
|
||||
#if defined(HAVE_CUDNN)
|
||||
//for cudnn blank index is -. therefore our targets cant be 0
|
||||
for(int i=0;i<target.size();i++){
|
||||
target[i]=target[i]+1;
|
||||
}
|
||||
#endif
|
||||
auto labels = NDArrayFactory::create<int>('c',{BATCH_LEN, MAX_TARGET_LEN}, target );
|
||||
|
||||
auto labels_len = NDArrayFactory::create<int>('c', {BATCH_LEN}, {MIN_TARGET_LEN,MIN_TARGET_LEN +1, MAX_TARGET_LEN, MIN_TARGET_LEN +1});
|
||||
|
||||
#if defined(HAVE_CUDNN)
|
||||
auto expected = NDArrayFactory::create<float>('c', {BATCH_LEN}, {6.088762f, 5.9546056f, 7.5806675f, 5.5532417f});
|
||||
#else
|
||||
auto expected = NDArrayFactory::create<float>('c', {BATCH_LEN}, {6.0661564f, 6.4285727f, 7.7180986f, 4.936057f});
|
||||
#endif
|
||||
sd::ops::ctc_loss op;
|
||||
|
||||
//logits.printIndexedBuffer("logits");
|
||||
//labels.printIndexedBuffer("labels");
|
||||
|
||||
auto results = op.evaluate({&labels, &logits, &labels_len, &logits_length}, {}, {BLANK_INDEX});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
||||
|
||||
auto *loss = results.at(0);
|
||||
|
||||
//loss->printIndexedBuffer("loss");
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(loss));
|
||||
ASSERT_TRUE(expected.equalsTo(loss));
|
||||
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests2, ctc_loss_grad_test1) {
|
||||
constexpr int FRAME_LEN = 6 ;
|
||||
constexpr int CLASS_LEN = 5 ;
|
||||
constexpr int BATCH_LEN = 4 ;
|
||||
constexpr int MAX_TARGET_LEN = 4;
|
||||
constexpr int MIN_TARGET_LEN = 2;
|
||||
#if defined(HAVE_CUDNN)
|
||||
//cudnn blankindex should be 0
|
||||
constexpr int BLANK_INDEX=0;
|
||||
#else
|
||||
constexpr int BLANK_INDEX=CLASS_LEN-1;
|
||||
#endif
|
||||
//logits were generated using numpy random and applying log softmax
|
||||
//[ctc_loss.py](https://gist.github.com/quickwritereader/ca9858be201fd857348826a56e2bebc4)
|
||||
auto logits = NDArrayFactory::create<float>('c', {BATCH_LEN, FRAME_LEN, CLASS_LEN },
|
||||
{-1.52900087f, -1.7423916f, -1.79369985f, -1.68980741f, -1.35771429f,
|
||||
-2.08261997f, -1.65483307f, -1.31878488f, -1.38940393f, -1.78624192f,
|
||||
-1.83125744f, -1.28989651f, -1.86882736f, -1.51760877f, -1.65575026f,
|
||||
-1.59030191f, -2.09045484f, -2.01113821f, -1.31159853f, -1.3120046f,
|
||||
-1.45263472f, -1.52268525f, -1.6567962f, -2.06986454f, -1.46546941f,
|
||||
-1.25549694f, -1.86336982f, -1.64691575f, -1.69584239f, -1.69374889f,
|
||||
-1.62384788f, -1.53256338f, -1.47943003f, -1.9953089f, -1.49995189f,
|
||||
-1.58914748f, -2.14294273f, -1.89989005f, -1.26397295f, -1.40048678f,
|
||||
-1.52242117f, -1.79940303f, -1.86987214f, -1.41871056f, -1.51299132f,
|
||||
-1.41772259f, -1.27648263f, -1.87029582f, -1.71325761f, -1.93542947f,
|
||||
-1.4372372f, -1.72814911f, -1.18767571f, -1.85569031f, -2.09127332f,
|
||||
-1.99591619f, -1.17070749f, -1.91569048f, -1.66127429f, -1.52865783f,
|
||||
-1.39319926f, -2.19674832f, -1.69619098f, -1.37916537f, -1.58285964f,
|
||||
-1.85456282f, -1.91027747f, -1.35265643f, -1.76707679f, -1.32405154f,
|
||||
-1.70063352f, -1.82894304f, -1.81275811f, -1.76677183f, -1.13084056f,
|
||||
-2.01507311f, -1.50622804f, -1.55902412f, -1.4076143f, -1.66137954f,
|
||||
-1.72469437f, -1.74285619f, -1.72109242f, -1.54947478f, -1.36444454f,
|
||||
-1.78795939f, -1.62871901f, -1.43244094f, -1.83058005f, -1.43770547f,
|
||||
-1.3577647f, -1.81454222f, -1.58227661f, -1.89836191f, -1.49373763f,
|
||||
-1.52027507f, -1.41807732f, -1.54481537f, -1.86538837f, -1.76619851f,
|
||||
-1.64547283f, -1.58328753f, -1.58442673f, -1.65941447f, -1.57762943f,
|
||||
-1.54091641f, -1.76747862f, -1.56063854f, -1.76235545f, -1.45495771f,
|
||||
-1.37294933f, -1.75871646f, -1.38392315f, -1.62238305f, -2.06866473f,
|
||||
-1.98087487f, -1.49880371f, -2.14268396f, -1.22969736f, -1.47432277f
|
||||
});
|
||||
|
||||
auto logits_length = NDArrayFactory::create<int>('c', {BATCH_LEN}, {FRAME_LEN,FRAME_LEN,FRAME_LEN,FRAME_LEN});
|
||||
std::vector<int> target ={2, 2, 2, 0, 1, 1, 0, 0, 1, 2, 2, 3, 0, 2, 1, 2};
|
||||
#if defined(HAVE_CUDNN)
|
||||
//for cudnn blank index is 0. therefore our targets cant be 0
|
||||
for(int i=0;i<target.size();i++){
|
||||
target[i]=target[i]+1;
|
||||
}
|
||||
#endif
|
||||
auto labels = NDArrayFactory::create<int>('c',{BATCH_LEN, MAX_TARGET_LEN}, target );
|
||||
auto labels_len = NDArrayFactory::create<int>('c', {BATCH_LEN}, {MIN_TARGET_LEN, MIN_TARGET_LEN +1, MAX_TARGET_LEN, MIN_TARGET_LEN +1});
|
||||
#if defined(HAVE_CUDNN)
|
||||
//results for blank Index=0
|
||||
auto expected = NDArrayFactory::create<float>('c', {BATCH_LEN, FRAME_LEN, CLASS_LEN},
|
||||
{
|
||||
-0.2673936f, 0.17510113f, 0.16634358f, -0.33129925f, 0.2572481f,
|
||||
-0.17626494f, 0.19112396f, 0.2674601f, -0.44990796f, 0.1675888f,
|
||||
-0.33695614f, 0.27529928f, 0.1543045f, -0.28359637f, 0.19094874f,
|
||||
-0.26243734f, 0.1236309f, 0.13383625f, -0.26430953f, 0.26927972f,
|
||||
-0.33964074f, 0.21812534f, 0.1907491f, -0.3002034f, 0.23096953f,
|
||||
-0.200618f, 0.15514892f, 0.19264314f, -0.3310032f, 0.18382908f,
|
||||
-0.04921098f, 0.21598133f, -0.52588296f, 0.13597165f, 0.22314091f,
|
||||
-0.38300496f, 0.11730913f, -0.2633105f, 0.2825293f, 0.24647695f,
|
||||
-0.34686768f, 0.16539758f, -0.280806f, 0.24202588f, 0.22025016f,
|
||||
-0.21347934f, 0.19306758f, -0.304228f, 0.18027757f, 0.14436226f,
|
||||
0.02692442f, -0.08318196f, -0.2236172f, 0.15634498f, 0.12352975f,
|
||||
0.03155032f, -0.5855137f, 0.14724013f, 0.18989684f, 0.2168265f,
|
||||
0.10374172f, 0.11116405f, -0.67208123f, 0.25178862f, 0.20538692f,
|
||||
0.09189357f, 0.14803931f, 0.00725803f, -0.5132462f, 0.2660552f,
|
||||
-0.4309733f, 0.16058321f, 0.16320339f, -0.21557501f, 0.32276183f,
|
||||
-0.32850766f, 0.2217448f, 0.21034124f, -0.2934553f, 0.18987685f,
|
||||
0.06212101f, 0.1750198f, 0.17887063f, -0.38780046f, -0.02821094f,
|
||||
0.05002825f, 0.19618073f, 0.23872548f, 0.16032055f, -0.64525515f,
|
||||
-0.19972575f, -0.38012666f, 0.20550671f, 0.14981383f, 0.22453187f,
|
||||
-0.02966774f, -0.34505254f, 0.21335125f, -0.00961271f, 0.17098173f,
|
||||
-0.04058227f, -0.03726651f, 0.16733989f, -0.295955f, 0.20646395f,
|
||||
-0.05670565f, 0.12657055f, -0.00966609f, -0.2936089f, 0.23341022f,
|
||||
-0.01142454f, 0.17226583f, -0.2727364f, -0.01445916f, 0.12635438f,
|
||||
-0.23244353f, 0.22339724f, -0.5122685f, 0.29238105f, 0.2289337f
|
||||
});
|
||||
#else
|
||||
auto expected = NDArrayFactory::create<float>('c', {BATCH_LEN, FRAME_LEN, CLASS_LEN},
|
||||
{
|
||||
0.21675213f, 0.17510113f, -0.27113008f, 0.18455505f, -0.30527824f,
|
||||
0.12460334f, 0.19112396f, -0.44803357f, 0.24922381f, -0.11691755f,
|
||||
0.16021198f, 0.27529928f, -0.28298444f, 0.21923551f, -0.37176234f,
|
||||
0.20386407f, 0.1236309f, -0.15528734f, 0.2693891f, -0.44159663f,
|
||||
0.23395306f, 0.21812534f, -0.36457074f, 0.12620285f, -0.21371071f,
|
||||
0.28493422f, 0.15514892f, -0.4384392f, 0.18344463f, -0.18508859f,
|
||||
0.19713868f, -0.61835873f, 0.22776747f, 0.13597165f, 0.05748086f,
|
||||
0.20409954f, -0.17006806f, 0.14958507f, 0.2825293f, -0.46614605f,
|
||||
0.218183f, -0.28762838f, 0.15414338f, 0.24202588f, -0.32672384f,
|
||||
0.09618269f, -0.40792802f, 0.15407808f, 0.18027757f, -0.02261038f,
|
||||
-0.40063405f, -0.04311697f, 0.3049292f, 0.15634498f, -0.01752307f,
|
||||
-0.43639395f, 0.31014743f, 0.14724013f, 0.18989684f, -0.21089047f,
|
||||
0.24827974f, -0.8280775f, 0.1833807f, 0.25178862f, 0.1446285f,
|
||||
0.15652135f, 0.05439584f, -0.5887033f, 0.17083165f, 0.20695446f,
|
||||
0.1825678f, 0.1605832f, -0.04697506f, 0.17088373f, -0.4670597f,
|
||||
0.13331066f, 0.2217448f, -0.46589473f, 0.24472642f, -0.13388708f,
|
||||
0.17822751f, 0.1750198f, -0.27072078f, -0.15830047f, 0.07577389f,
|
||||
0.16730122f, 0.19618073f, 0.23872548f, -0.618405f, 0.01619747f,
|
||||
-0.41614607f, 0.16291247f, 0.20550671f, 0.14981383f, -0.10208681f,
|
||||
-0.32300252f, 0.2421792f, -0.01448151f, 0.15483606f, -0.05953133f,
|
||||
-0.03524604f, 0.1660878f, -0.24423766f, 0.19025035f, -0.07685445f,
|
||||
0.1546654f, 0.00699046f, -0.26606354f, 0.17164008f, -0.06723261f,
|
||||
0.2533586f, -0.31069174f, -0.07983261f, 0.19742766f, -0.06026195f,
|
||||
0.1379485f, -0.47723943f, 0.11733948f, 0.29238105f, -0.07042958
|
||||
});
|
||||
#endif
|
||||
sd::ops::ctc_loss_grad op;
|
||||
|
||||
auto results = op.evaluate({&labels, &logits, &labels_len, &logits_length}, {}, {BLANK_INDEX});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
||||
|
||||
auto *gradient = results.at(0);
|
||||
|
||||
//gradient->printIndexedBuffer("gradient");
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(gradient));
|
||||
ASSERT_TRUE(expected.equalsTo(gradient, 1.e-06));
|
||||
|
||||
}
|
||||
|
||||
#endif
|
|
@ -5411,12 +5411,14 @@ public final class TensorNamespace {
|
|||
* Serializations can either use one of the fields above, or use this
|
||||
* raw bytes field. The only exception is the string case, where one is
|
||||
* required to store the content in the repeated bytes string_data field.
|
||||
*
|
||||
* When this raw_data field is used to store tensor value, elements MUST
|
||||
* be stored in as fixed-width, little-endian order.
|
||||
* Floating-point data types MUST be stored in IEEE 754 format.
|
||||
* Complex64 elements must be written as two consecutive FLOAT values, real component first.
|
||||
* Complex128 elements must be written as two consecutive DOUBLE values, real component first.
|
||||
* Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
|
||||
*
|
||||
* Note: the advantage of specific field rather than the raw_data field is
|
||||
* that in some cases (e.g. int data), protobuf does a better packing via
|
||||
* variable length storage, and may lead to smaller binary footprint.
|
||||
|
@ -5655,6 +5657,7 @@ public final class TensorNamespace {
|
|||
/**
|
||||
* <pre>
|
||||
* Tensors
|
||||
*
|
||||
* A serialized tensor value.
|
||||
* </pre>
|
||||
*
|
||||
|
@ -7010,12 +7013,14 @@ public final class TensorNamespace {
|
|||
* Serializations can either use one of the fields above, or use this
|
||||
* raw bytes field. The only exception is the string case, where one is
|
||||
* required to store the content in the repeated bytes string_data field.
|
||||
*
|
||||
* When this raw_data field is used to store tensor value, elements MUST
|
||||
* be stored in as fixed-width, little-endian order.
|
||||
* Floating-point data types MUST be stored in IEEE 754 format.
|
||||
* Complex64 elements must be written as two consecutive FLOAT values, real component first.
|
||||
* Complex128 elements must be written as two consecutive DOUBLE values, real component first.
|
||||
* Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
|
||||
*
|
||||
* Note: the advantage of specific field rather than the raw_data field is
|
||||
* that in some cases (e.g. int data), protobuf does a better packing via
|
||||
* variable length storage, and may lead to smaller binary footprint.
|
||||
|
@ -7766,6 +7771,7 @@ public final class TensorNamespace {
|
|||
/**
|
||||
* <pre>
|
||||
* Tensors
|
||||
*
|
||||
* A serialized tensor value.
|
||||
* </pre>
|
||||
*
|
||||
|
@ -9080,12 +9086,14 @@ public final class TensorNamespace {
|
|||
* Serializations can either use one of the fields above, or use this
|
||||
* raw bytes field. The only exception is the string case, where one is
|
||||
* required to store the content in the repeated bytes string_data field.
|
||||
*
|
||||
* When this raw_data field is used to store tensor value, elements MUST
|
||||
* be stored in as fixed-width, little-endian order.
|
||||
* Floating-point data types MUST be stored in IEEE 754 format.
|
||||
* Complex64 elements must be written as two consecutive FLOAT values, real component first.
|
||||
* Complex128 elements must be written as two consecutive DOUBLE values, real component first.
|
||||
* Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
|
||||
*
|
||||
* Note: the advantage of specific field rather than the raw_data field is
|
||||
* that in some cases (e.g. int data), protobuf does a better packing via
|
||||
* variable length storage, and may lead to smaller binary footprint.
|
||||
|
@ -9102,12 +9110,14 @@ public final class TensorNamespace {
|
|||
* Serializations can either use one of the fields above, or use this
|
||||
* raw bytes field. The only exception is the string case, where one is
|
||||
* required to store the content in the repeated bytes string_data field.
|
||||
*
|
||||
* When this raw_data field is used to store tensor value, elements MUST
|
||||
* be stored in as fixed-width, little-endian order.
|
||||
* Floating-point data types MUST be stored in IEEE 754 format.
|
||||
* Complex64 elements must be written as two consecutive FLOAT values, real component first.
|
||||
* Complex128 elements must be written as two consecutive DOUBLE values, real component first.
|
||||
* Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
|
||||
*
|
||||
* Note: the advantage of specific field rather than the raw_data field is
|
||||
* that in some cases (e.g. int data), protobuf does a better packing via
|
||||
* variable length storage, and may lead to smaller binary footprint.
|
||||
|
@ -9130,12 +9140,14 @@ public final class TensorNamespace {
|
|||
* Serializations can either use one of the fields above, or use this
|
||||
* raw bytes field. The only exception is the string case, where one is
|
||||
* required to store the content in the repeated bytes string_data field.
|
||||
*
|
||||
* When this raw_data field is used to store tensor value, elements MUST
|
||||
* be stored in as fixed-width, little-endian order.
|
||||
* Floating-point data types MUST be stored in IEEE 754 format.
|
||||
* Complex64 elements must be written as two consecutive FLOAT values, real component first.
|
||||
* Complex128 elements must be written as two consecutive DOUBLE values, real component first.
|
||||
* Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
|
||||
*
|
||||
* Note: the advantage of specific field rather than the raw_data field is
|
||||
* that in some cases (e.g. int data), protobuf does a better packing via
|
||||
* variable length storage, and may lead to smaller binary footprint.
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.loss;
|
||||
|
||||
import org.nd4j.autodiff.loss.LossReduce;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class CtcLoss extends BaseLoss {
|
||||
|
||||
|
||||
public CtcLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels){
|
||||
super(sameDiff, lossReduce, predictions, weights, labels);
|
||||
}
|
||||
|
||||
public CtcLoss(SameDiff sameDiff, SDVariable label, SDVariable predictions, SDVariable weights,
|
||||
LossReduce lossReduce) {
|
||||
this(sameDiff, lossReduce, predictions, weights, label);
|
||||
}
|
||||
|
||||
public CtcLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
|
||||
super(lossReduce, predictions, weights, labels);
|
||||
}
|
||||
|
||||
public CtcLoss(){ }
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "ctc_loss";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> grad){
|
||||
//No external gradient
|
||||
//Args are: predictions, weights, label
|
||||
return new AbsoluteDifferenceLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.loss.bp;
|
||||
|
||||
import org.nd4j.autodiff.loss.LossReduce;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class CtcLossBp extends BaseLossBp {
|
||||
|
||||
|
||||
public CtcLossBp(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels){
|
||||
super(sameDiff, lossReduce, predictions, weights, labels);
|
||||
}
|
||||
|
||||
public CtcLossBp(){ }
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "ctc_loss_grad";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> grad){
|
||||
throw new UnsupportedOperationException("Differentiation of " + getClass().getName() + " not supported");
|
||||
}
|
||||
|
||||
}
|
|
@ -391,7 +391,7 @@
|
|||
|
||||
For testing large zoo models, this may not be enough (so comment it out).
|
||||
-->
|
||||
<argLine>-Dfile.encoding=UTF-8 -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine>-Dfile.encoding=UTF-8 "</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
|
@ -107,7 +107,7 @@
|
|||
<include>*.java</include>
|
||||
<include>**/*.java</include>
|
||||
</includes>
|
||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> "</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
|
@ -99,7 +99,7 @@
|
|||
<include>*.java</include>
|
||||
<include>**/*.java</include>
|
||||
</includes>
|
||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> "</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
|
@ -125,7 +125,7 @@
|
|||
|
||||
For testing large zoo models, this may not be enough (so comment it out).
|
||||
-->
|
||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> "</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
|
@ -103,7 +103,7 @@
|
|||
|
||||
For testing large zoo models, this may not be enough (so comment it out).
|
||||
-->
|
||||
<argLine> -Dfile.encoding=UTF-8 -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> -Dfile.encoding=UTF-8 "</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
|
@ -159,7 +159,7 @@
|
|||
|
||||
For testing large zoo models, this may not be enough (so comment it out).
|
||||
-->
|
||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> "</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
|
@ -114,7 +114,7 @@
|
|||
|
||||
For testing large zoo models, this may not be enough (so comment it out).
|
||||
-->
|
||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
<argLine> "</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
Loading…
Reference in New Issue