diff --git a/datavec/pom.xml b/datavec/pom.xml index 2556c9782..1ec358c4b 100644 --- a/datavec/pom.xml +++ b/datavec/pom.xml @@ -159,7 +159,7 @@ maven-surefire-plugin ${maven-surefire-plugin.version} - -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes" + " - -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes" + " diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml index 7a6ce9ef5..7b112e23f 100644 --- a/deeplearning4j/pom.xml +++ b/deeplearning4j/pom.xml @@ -230,7 +230,7 @@ --> true false - -Dfile.encoding=UTF-8 -Xmx8g -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes" + -Dfile.encoding=UTF-8 -Xmx8g " *.java @@ -331,7 +331,7 @@ For testing large zoo models, this may not be enough (so comment it out). --> - -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes" + " diff --git a/libnd4j/include/ops/declarable/generic/loss/ctcLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/ctcLoss.cpp new file mode 100644 index 000000000..d37c16233 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/loss/ctcLoss.cpp @@ -0,0 +1,134 @@ +/******************************************************************************* + * Copyright (c) 2021 Deeplearning4j Contributors + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + *******************************************************************************/ + +// +// @author AbdelRauf +// + + +#include +#include +#include + +namespace sd { +namespace ops { + + +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(ctc_loss, 4, 1, false, 0, 1) { + + auto targetLabels = INPUT_VARIABLE(0); + auto logitInput = INPUT_VARIABLE(1); + auto targetLabelLengths = INPUT_VARIABLE(2); + auto logitInputLengths = INPUT_VARIABLE(3); + auto outputLosses = OUTPUT_VARIABLE(0); + + int blankIndex = INT_ARG(0); + + REQUIRE_TRUE(targetLabels->rankOf()==2, 0, "CtcLoss: target labels fails to meet rank requirement (batch_size, max_label_sequence_length): %i == 2 ", targetLabels->rankOf()); + REQUIRE_TRUE(logitInput->rankOf()==3, 0, "CtcLoss: logit Input fails to meet rank requirement (batch_size, frames, classes): %i == 3 ", logitInput->rankOf()); + REQUIRE_TRUE(targetLabelLengths->rankOf()==1, 0, "CtcLoss: target label length fails to meet rank requirement (batch_size): %i == 1 ", targetLabelLengths->rankOf()); + REQUIRE_TRUE(logitInputLengths->rankOf()==1, 0, "CtcLoss: logit Input lengths fails to meet rank requirement (batch_size): %i == 1 ", logitInputLengths->rankOf()); + + int batchSize0 = targetLabels->sizeAt(0); + int batchSize1 = logitInput->sizeAt(0); + int batchSize2 = targetLabelLengths->sizeAt(0); + int batchSize3 = logitInputLengths->sizeAt(0); + int batchSize4 = outputLosses->sizeAt(0); + + bool check_batches = (batchSize0 == batchSize1) && (batchSize2 == batchSize3); + check_batches = check_batches && (batchSize0 == batchSize4) && (batchSize0 == batchSize2); + + REQUIRE_TRUE(check_batches, 0, "CtcLoss: All batch sizes should be equal %i", batchSize0); + REQUIRE_TRUE(outputLosses->isSameShape(targetLabelLengths), 0, "CtcLoss: wrong shape of output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(targetLabelLengths).c_str(), ShapeUtils::shapeAsString(outputLosses).c_str()); + + auto emptyGradients = NDArrayFactory::empty(); + sd::ops::helpers::ctcLoss(block, *logitInput, *targetLabels, *logitInputLengths, *targetLabelLengths, *outputLosses, emptyGradients, blankIndex); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +DECLARE_TYPES(ctc_loss) { + getOpDescriptor()->setAllowedInputTypes({ALL_INDICES}) + ->setAllowedInputTypes(1,{ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +////////////////////////////////////////////////////////////////////////// +DECLARE_SHAPE_FN(ctc_loss) { + auto yShapeInfo = inputShape->at(1); + auto zShapeInfo = inputShape->at(2); + + auto dtype = ArrayOptions::dataType(yShapeInfo); + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(zShapeInfo, dtype))); +} + + + +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(ctc_loss_grad, 4, 1, false, 0, 1) { + + auto targetLabels = INPUT_VARIABLE(0); + auto logitInput = INPUT_VARIABLE(1); + auto targetLabelLengths = INPUT_VARIABLE(2); + auto logitInputLengths = INPUT_VARIABLE(3); + auto outputGradients = OUTPUT_VARIABLE(0); + + int blankIndex = INT_ARG(0); + + REQUIRE_TRUE(targetLabels->rankOf()==2, 0, "CtcLoss: target labels fails to meet rank requirement (batch_size, max_label_sequence_length): %i == 2 ", targetLabels->rankOf()); + REQUIRE_TRUE(logitInput->rankOf()==3, 0, "CtcLoss: logit Input fails to meet rank requirement (batch_size, frames, classes): %i == 3 
", logitInput->rankOf()); + REQUIRE_TRUE(targetLabelLengths->rankOf()==1, 0, "CtcLoss: target label length fails to meet rank requirement (batch_size): %i == 1 ", targetLabelLengths->rankOf()); + REQUIRE_TRUE(logitInputLengths->rankOf()==1, 0, "CtcLoss: logit Input lengths fails to meet rank requirement (batch_size): %i == 1 ", logitInputLengths->rankOf()); + + int batchSize0 = targetLabels->sizeAt(0); + int batchSize1 = logitInput->sizeAt(0); + int batchSize2 = targetLabelLengths->sizeAt(0); + int batchSize3 = logitInputLengths->sizeAt(0); + int batchSize4 = outputGradients->sizeAt(0); + + bool check_batches = (batchSize0 == batchSize1) && (batchSize2 == batchSize3); + check_batches = check_batches && (batchSize0 == batchSize4) && (batchSize0 == batchSize2); + + REQUIRE_TRUE(check_batches, 0, "CtcLoss Gradient: All batch sizes should be equal %i", batchSize0); + REQUIRE_TRUE(outputGradients->isSameShape(logitInput), 0, "CtcLoss Gradient: wrong shape of output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(logitInput).c_str(), ShapeUtils::shapeAsString(outputGradients).c_str()); + + auto emptyLoss = NDArrayFactory::empty(); + sd::ops::helpers::ctcLoss(block, *logitInput, *targetLabels, *logitInputLengths, *targetLabelLengths, emptyLoss, *outputGradients, blankIndex); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +DECLARE_TYPES(ctc_loss_grad) { + getOpDescriptor()->setAllowedInputTypes({ALL_INDICES}) + ->setAllowedInputTypes(1,{ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +////////////////////////////////////////////////////////////////////////// +DECLARE_SHAPE_FN(ctc_loss_grad) { + auto yShapeInfo = inputShape->at(1); + auto dtype = ArrayOptions::dataType(yShapeInfo); + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(yShapeInfo, dtype))); + +} + + + +} +} diff --git a/libnd4j/include/ops/declarable/headers/loss.h b/libnd4j/include/ops/declarable/headers/loss.h index 2e4018096..303470792 100644 --- a/libnd4j/include/ops/declarable/headers/loss.h +++ b/libnd4j/include/ops/declarable/headers/loss.h @@ -360,6 +360,27 @@ namespace ops { DECLARE_CUSTOM_OP(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false, 0, 0); #endif + /** + * Implementation of CTC loss function + * + * Input arrays: + * 0: labels - labels NDArray {BATCH_LEN, MAX_TARGET_LEN}, type integer + * 1: logits - logits NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN }. log softmax of rnn output. It should include a blank label as well, type float + * 2: targetLabelLengths - Length of label sequence in labels NDArray {BATCH_LEN}, type integer + * 3: logitsLengths - Length of input sequence in logits NDArray {BATCH_LEN}, type integer + * + * + * Input integer arguments: + * 0: blank index - index of the blank label in logits + * + * Output array: + * 0: loss values, type float. 
NDArray {BATCH_LEN} negative log probabilities of loss + */ + #if NOT_EXCLUDED(OP_ctc_loss) + DECLARE_CUSTOM_OP(ctc_loss, 4, 1, false, 0, 1); + DECLARE_CUSTOM_OP(ctc_loss_grad, 4, 1, false, 0, 1); + #endif + } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/ctcLoss.cpp b/libnd4j/include/ops/declarable/helpers/cpu/ctcLoss.cpp new file mode 100644 index 000000000..1920167f2 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/ctcLoss.cpp @@ -0,0 +1,507 @@ +/******************************************************************************* + * Copyright (c) 2021 Deeplearning4j Contributors + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + *******************************************************************************/ + +// +// @author AbdelRauf +// + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace sd +{ + namespace ops + { + namespace helpers + { + + //choose ptr[index*element_stride] + template + typename std::enable_if::type + element(Type *ptr, int index, int element_stride) + { + return ptr[index * element_stride]; + } + + //choose ptr[index] assuming element_stride is 1 + template + typename std::enable_if::type + element(Type *ptr, int index, int element_stride) + { + return ptr[index]; + } + + template + Type forward(Type *alphaPtr, const Nd4jLong &incA, const Type *logP, const Nd4jLong &incP, const IndexType *lbl, const Nd4jLong &lenSB, const Nd4jLong &lenT, const int &blankIndex, int elwiseP = 1, int elwiseS = 1) + { + Type negInf = -DataTypeUtils::infOrMax(); + //initialize alphas at t=0 + alphaPtr[0] = element(logP, blankIndex, elwiseP); + //alphaPtr[1] =logP[lbl[0]]; + alphaPtr[1] = element(logP, *lbl, elwiseP); + //the rest initialization was skipped + //as its assumed the array already were initialized with negative infinity + //move to the next frame + Type *alphaPrevPtr = alphaPtr; + alphaPtr += incA; + logP += incP; + + auto startX = lenSB - 2 * lenT; + //process the rest + for (auto t = 1; t < lenT; t++) + { + + //start = max(0,L-2*(T-t)) + auto s = startX + 2 * t; + s = s > 0 ? s : 0; + for (; s < lenSB; s++) + { + auto ind = s / 2; //our real index + //we force blanks for even indexes + //strided version of lbl[ind] => element(lbl, ind, elwiseS) + auto currentInd = (s % 2 == 0) ? blankIndex : element(lbl, ind, elwiseS); + // {t-1,s} + Type alphaS = alphaPrevPtr[s]; + Type alphaS_1 = s > 0 ? 
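// A sketch of the recurrence this loop evaluates, in illustrative notation (logaddexp is
// shorthand here, not a function defined in this file). Over the blank-extended label
// sequence l' of length lenSB = 2*lenS + 1, the forward pass computes, in log space,
//   alpha_t(s) = logaddexp(alpha_{t-1}(s), alpha_{t-1}(s-1) [, alpha_{t-1}(s-2)]) + logP_t(l'_s)
// where the alpha_{t-1}(s-2) term participates only when l'_s is neither the blank nor equal
// to l'_{s-2}. The cMax subtraction below is the max-shifted form of logaddexp,
//   logaddexp(a, b) = log(exp(a - m) + exp(b - m)) + m,  with m = max(a, b),
// which keeps exp() from overflowing or underflowing for large-magnitude log probabilities.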
alphaPrevPtr[s - 1] : negInf; + Type cMax = std::max(alphaS, alphaS_1); + //logP[currentInd] or logP[currentInd*elwiseP] + auto currentProb = element(logP, currentInd, elwiseP); + // if blank or the same as previous + if (s > 1 && currentInd != blankIndex && currentInd != element(lbl, ind - 1, elwiseS)) + { + Type alphaS_2 = alphaPrevPtr[s - 2]; + cMax = std::max(cMax, alphaS_2); + if (cMax == negInf) + cMax = 0; + alphaPtr[s] = std::log(std::exp(alphaS - cMax) + std::exp(alphaS_1 - cMax) + std::exp(alphaS_2 - cMax)) + cMax + currentProb; + } + else + { + if (cMax == negInf) + cMax = 0; + alphaPtr[s] = std::log(std::exp(alphaS - cMax) + std::exp(alphaS_1 - cMax)) + cMax + currentProb; + } + } + + //store t-1 alpha Ptr + alphaPrevPtr = alphaPtr; + logP += incP; + alphaPtr += incA; + } + auto logP0 = alphaPrevPtr[lenSB - 1]; + auto logP1 = alphaPrevPtr[lenSB - 2]; + auto cMax = std::max(logP0, logP1); + return -(std::log(std::exp(logP0 - cMax) + std::exp(logP1 - cMax)) + cMax); + } + +//#undef CALCULATE_ALL_IN_ONE_FRAME_LOOP + + template + void backwardAndGrad(Type forwardLogLoss, Type *alphaPtr, Type *bettaPtr, int incA,const Type *logP, int incP, Type *gradPtr, int incG,const IndexType *lbl, + const Nd4jLong &lenS, const Nd4jLong &lenT, const Nd4jLong &lenK, const int &blankIndex, + int elwiseP = 1, int elwiseS = 1, int elwiseG = 1) + { + + Type negInf = -DataTypeUtils::infOrMax(); + Nd4jLong lenSB = 2 * lenS + 1; + auto origBetta = bettaPtr; + auto origLogP = logP; + //move to the last frame + bettaPtr += (lenT - 1) * incA; + logP += (lenT - 1) * incP; + + //initialize bettas at t=lenT + bettaPtr[lenSB - 1] = element(logP, blankIndex, elwiseP); + auto lblIndex = element(lbl, lenS - 1, elwiseS); + bettaPtr[lenSB - 2] = element(logP, lblIndex, elwiseP); // logP[lbl[lenS - 1]]; + +#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP) + //move to the last + gradPtr += (lenT - 1) * incG; + alphaPtr += (lenT - 1) * incA; + for (auto s = lenSB - 1; s >= 0; s--) + { + auto ind = s / 2; //our real index + //we forced blanks for even indexes + auto currentInd = (s % 2 == 0) ? blankIndex : element(lbl, ind, elwiseS); + //alpha(s)*betta(s) in log scale but still store in alpha to save memory + auto alphaBettaS = alphaPtr[s] + bettaPtr[s]; + + //sum (alpha(s)*betta(s) ) over real indexes + auto ¤tGrad = element(gradPtr, currentInd, elwiseG); // gradPtr[currentInd]; + if (currentGrad == negInf) + { + currentGrad = alphaBettaS; + } + else + { + Type cMax = std::max(currentGrad, alphaBettaS); + currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax; + } + } + for (int k = 0; k < lenK; k++) + { + //compute the rest grad + + // prob(t,k) - grad(k) / ((prob(t,k)*Z) ) + + // p2= grad(k) / (prob(t,k)*Z ) + //in logscale . 
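// Sketch of the conversion done in the per-class loop below (illustrative notation):
// gradPtr[k] first accumulates log( sum_{s : l'_s = k} alpha_t(s) * beta_t(s) ) via logaddexp.
// With Z = p(l|x) = exp(-forwardLogLoss) and y_t(k) = exp(logP_t(k)), the loop then forms
//   p2   = exp(grad + forwardLogLoss - logP) = (sum alpha_t(s)*beta_t(s)) / (Z * y_t(k))
//   grad = exp(logP) - p2 = y_t(k) - (1 / (Z * y_t(k))) * sum_{s : l'_s = k} alpha_t(s)*beta_t(s)
// which appears to correspond to the CTC gradient w.r.t. the pre-softmax activations in
// Graves et al. (2006).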
plus we have Z as -logLoss + // auto p2 = std::exp(gradPtr[k] + forwardLogLoss - logP[k]); + // gradPtr[k] = std::exp(logP[k]) - p2; + auto currentProb = element(logP, k, elwiseP); + auto ¤tGrad = element(gradPtr, k, elwiseG); + auto p2 = std::exp(currentGrad + forwardLogLoss - currentProb); + currentGrad = std::exp(currentProb) - p2; + } + gradPtr -= incG; + alphaPtr -= incA; +#endif + + auto bettaPrevPtr = bettaPtr; + bettaPtr -= incA; + logP -= incP; + //process the rest + for (auto t = lenT - 2; t >= 0; t--) + { + +#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP) + auto end = lenSB - 1; +#else + auto end = std::min(2 * t + 2, lenSB - 1); +#endif + for (auto s = end; s >= 0; s--) + { + auto ind = s / 2; //our real index + //we forced blanks for even indexes + auto currentInd = (s % 2 == 0) ? blankIndex : element(lbl, ind, elwiseS); //lbl[ind]; + // {t-1,s} + Type bettaS = bettaPrevPtr[s]; + Type bettaS_1 = s < lenSB - 1 ? bettaPrevPtr[s + 1] : negInf; + Type cMax = std::max(bettaS, bettaS_1); + //logP[currentInd] + auto currentProb = element(logP, currentInd, elwiseP); + // if blank or the same as previous + if (s < lenSB - 2 && currentInd != blankIndex && currentInd != element(lbl, ind + 1, elwiseS)) + { + Type bettaS_2 = bettaPrevPtr[s + 2]; + cMax = std::max(cMax, bettaS_2); + if (cMax == negInf) + cMax = 0; + bettaPtr[s] = std::log(std::exp(bettaS - cMax) + std::exp(bettaS_1 - cMax) + std::exp(bettaS_2 - cMax)) + cMax + currentProb; + } + else + { + if (cMax == negInf) + cMax = 0; + bettaPtr[s] = std::log(std::exp(bettaS - cMax) + std::exp(bettaS_1 - cMax)) + cMax + currentProb; + } + +#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP) + //alpha(s)*betta(s) in log scale but still store in alpha to save memory + auto alphaBettaS = alphaPtr[s] + bettaPtr[s]; + + //sum (alpha(s)*betta(s) ) over real indexes + auto ¤tGrad = element(gradPtr, currentInd, elwiseG); // gradPtr[currentInd]; + if (currentGrad == negInf) + { + currentGrad = alphaBettaS; + } + else + { + Type cMax = std::max(currentGrad, alphaBettaS); + currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax; + } + +#endif + } + +#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP) + for (int k = 0; k < lenK; k++) + { + //compute the rest grad + + // prob(t,k) - grad(k) / ((prob(t,k)*Z) ) + + // p2= grad(k) / (prob(t,k)*Z ) + //in logscale . plus we have Z as -logLoss + // auto p2 = std::exp(gradPtr[k] + forwardLogLoss - logP[k]); + // gradPtr[k] = std::exp(logP[k]) - p2; + auto currentProb = element(logP, k, elwiseP); + auto ¤tGrad = element(gradPtr, k, elwiseG); + auto p2 = std::exp(currentGrad + forwardLogLoss - currentProb); + currentGrad = std::exp(currentProb) - p2; + } + alphaPtr -= incA; + gradPtr -= incG; +#endif + + bettaPrevPtr = bettaPtr; + bettaPtr -= incA; + logP -= incP; + } + + auto logBP0 = bettaPrevPtr[0]; + auto logBP1 = bettaPrevPtr[1]; + auto bcMax = std::max(logBP0, logBP1); + auto blogLoss = -(std::log(std::exp(logBP0 - bcMax) + std::exp(logBP1 - bcMax)) + bcMax); + +#if !defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP) + //alpha*betta + bettaPtr = origBetta; + logP = origLogP; + + for (int t = 0; t < lenT; t++) + { + + for (int s = 0; s < lenSB; s++) + { + auto ind = s / 2; //our real index + //we forced blanks for even indexes + auto currentInd = (s % 2 == 0) ? 
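// The backward pass below mirrors the forward one (illustrative notation): over the
// blank-extended sequence l',
//   beta_t(s) = logaddexp(beta_{t+1}(s), beta_{t+1}(s+1) [, beta_{t+1}(s+2)]) + logP_t(l'_s)
// with the beta_{t+1}(s+2) term included only when l'_s is neither the blank nor equal to
// l'_{s+2}; the same cMax shift is used for numerical stability.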
blankIndex : element(lbl, ind, elwiseS); //lbl[ind]; + //alpha(s)*betta(s) in log scale but still store in alpha to save memory + auto alphaBettaS = alphaPtr[s] + bettaPtr[s]; + + //sum (alpha(s)*betta(s) ) over real indexes + auto ¤tGrad = element(gradPtr, currentInd, elwiseG); // gradPtr[currentInd]; + if (currentGrad == negInf) + { + currentGrad = alphaBettaS; + } + else + { + Type cMax = std::max(currentGrad, alphaBettaS); + currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax; + } + //alphaPtr[s] = alphaBettaS; + } + + PRAGMA_OMP_SIMD + for (int k = 0; k < lenK; k++) + { + //compute the rest grad + + // prob(t,k) - grad(k) / ((prob(t,k)*Z) ) + + // p2= grad(k) / (prob(t,k)*Z ) + //in logscale . plus we have Z as -logLoss + // auto p2 = std::exp(gradPtr[k] + forwardLogLoss - logP[k]); + // gradPtr[k] = std::exp(logP[k]) - p2; + auto currentProb = element(logP, k, elwiseP); + auto ¤tGrad = element(gradPtr, k, elwiseG); + auto p2 = std::exp(currentGrad + forwardLogLoss - currentProb); + currentGrad = std::exp(currentProb) - p2; + } + + gradPtr += incG; + bettaPtr += incA; + alphaPtr += incA; + logP += incP; + } +#endif + } + +/** + * Calculates ctc loss and fills gradients + * @param logP logits matrix(lenT,lenK) pointer (log soft max input of rnn) + * @param incP stride of logits for the next time frame + * @param gradPtr gradient for output + * @param incG stride of the gradient for the next time frame + * @param lbl target label + * @param lenT frame length + * @param lenK class length + * @param lenS target label length + * @param blankIndex index of the blank label in logit class +*/ + template + Type unitLossAndGrad(const Type *logP, int incP, Type *gradPtr, int incG,const IndexType *lbl, int lenT, int lenK, int lenS, int blankIndex, + int elwiseP = 1, int elwiseS = 1, int elwiseG = 1) + { + + auto lenSB = 2 * lenS + 1; + //create temp Array for holding bettaArr [lenT,lenSB] + //create temp Array for holding alphaArr [lenT,lenSB] + int bufferC = gradPtr ? 
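// Workspace layout: a single temporary NDArray of shape {bufferC, lenT, lenSB} is allocated,
// where the first [lenT, lenSB] plane holds alpha and, only when gradients are requested
// (gradPtr != nullptr, so bufferC == 2), a second plane holds beta. Both planes, and the
// gradient buffer when present, are filled with -infinity below before the passes run,
// so per-sample scratch memory is bufferC * lenT * (2*lenS + 1) elements.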
2 : 1; + NDArray bufferArr = NDArrayFactory::create('c', {bufferC, lenT, lenSB}); + auto bufferPtr = bufferArr.bufferAsT(); + auto incA = bufferArr.stridesOf()[1]; + auto bettaBufferPtr = bufferPtr + bufferArr.stridesOf()[0]; + Type negInf = -DataTypeUtils::infOrMax(); + +#if 1 + if (gradPtr) + { + if (elwiseG == 1) + { + PRAGMA_OMP_SIMD + for (int i = 0; i < lenK * lenT; i++) + { + gradPtr[i] = negInf; + } + } + else + { + auto tempPtr = gradPtr; + for (int i = 0; i < lenT; i++) + { + for (int j = 0; j < lenK; j++) + element(tempPtr, j, elwiseG) = negInf; + tempPtr += incG; + } + } + } +#endif + + // set all vals to neginf + PRAGMA_OMP_SIMD + for (int i = 0; i < bufferC * lenSB * lenT; i++) + { + bufferPtr[i] = negInf; + } + + //forward + Type logLoss = forward(bufferPtr, incA, logP, incP, lbl, lenSB, lenT, blankIndex, elwiseP, elwiseS); + //backward and gradient if gradptr supplied + if (gradPtr) + backwardAndGrad(logLoss, bufferPtr, bettaBufferPtr, incA, logP, incP, gradPtr, incG, lbl, lenS, lenT, lenK, blankIndex, elwiseP, elwiseS, elwiseG); + return logLoss; + } + + template + void + ctc_loss_(const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex) + { + // lenT - input length of T + // lenS - lenght of sequence + // lenSB - length with blanks + auto lenBatch = logits.shapeOf()[0]; + + auto maxLenT = logits.shapeOf()[1]; + auto lenK = logits.shapeOf()[2]; + auto maxLenS = targetLabels.shapeOf()[1]; + + // get probability bufer and tagetLabels buffer + auto logP = logits.bufferAsT(); + auto lblPtr = targetLabels.bufferAsT(); + + auto lenTPtr = logitsLengths.bufferAsT(); + auto lenSPtr = targetLabelLengths.bufferAsT(); + + auto batchLbl = targetLabels.stridesOf()[0]; + auto batchP = logits.stridesOf()[0]; + auto incP = logits.stridesOf()[1]; + + auto elwiseSLen = targetLabelLengths.stridesOf()[0]; + auto elwiseT = logitsLengths.stridesOf()[0]; + auto elwiseS = targetLabels.stridesOf()[1]; + auto elwiseP = logits.stridesOf()[2]; + + int elwiseLL = 0; + Type *logLossPtr = nullptr; + if (!logLosses.isEmpty()){ + elwiseLL = logLosses.stridesOf()[0]; + logLossPtr = logLosses.bufferAsT(); + } + + auto func = [logP, batchP, incP, elwiseP, lenK, lenTPtr, lenSPtr, logLossPtr, lblPtr, maxLenT, maxLenS, + batchLbl, blankIndex, elwiseT, elwiseLL, elwiseSLen, elwiseS, &gradients](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void { + Type *gradPtr = nullptr; + Type resultLoss; + int batchG, incG, elwiseG; + if (!gradients.isEmpty()) + { + batchG = gradients.stridesOf()[0]; + incG = gradients.stridesOf()[1]; + elwiseG = gradients.stridesOf()[2]; + gradPtr = gradients.bufferAsT() + start * batchG; + }else{ + elwiseG=1; + } + auto logPtr = logP + start * batchP; + auto tempLblPtr = lblPtr + start * batchLbl; + + if (elwiseP == 1 && elwiseS == 1 && elwiseG == 1) + { + //choose ews one + for (int batchIndex = start; batchIndex < stop; batchIndex += increment) + { + auto lenT = lenTPtr[batchIndex * elwiseT]; + auto lenS = lenSPtr[batchIndex * elwiseSLen]; + lenT = lenT > maxLenT ? maxLenT : lenT; + lenS = lenS > maxLenS ? 
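// Lengths are clamped to the logits/label dimensions here; non-positive lengths skip the
// computation and store -DataTypeUtils::infOrMax<Type>() as the loss, and a target longer
// than the frame count is capped to lenT below, since a CTC alignment cannot emit more
// labels than there are frames.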
maxLenS : lenS; + if (lenS <= 0 || lenT <= 0) + { + resultLoss = -DataTypeUtils::infOrMax(); + } + else + { + if (lenS > lenT) + lenS = lenT; + resultLoss = unitLossAndGrad(logPtr, incP, gradPtr, incG, tempLblPtr, lenT, lenK, lenS, blankIndex); + } + if (gradPtr) gradPtr += batchG; + if (logLossPtr) logLossPtr[batchIndex * elwiseLL] = resultLoss; + logPtr += batchP; + tempLblPtr += batchLbl; + } + } + else + { + //slow strided case for all 3 + for (int batchIndex = start; batchIndex < stop; batchIndex += increment) + { + auto lenT = lenTPtr[batchIndex * elwiseT]; + auto lenS = lenSPtr[batchIndex * elwiseSLen]; + lenT = lenT > maxLenT ? maxLenT : lenT; + lenS = lenS > maxLenS ? maxLenS : lenS; + if (lenS <= 0 || lenT <= 0) + { + resultLoss = -DataTypeUtils::infOrMax(); + } + else + { + if (lenS > lenT) + lenS = lenT; + resultLoss = unitLossAndGrad(logPtr, incP, gradPtr, incG, tempLblPtr, lenT, lenK, lenS, blankIndex, elwiseP, elwiseS, elwiseG); + } + if (gradPtr) gradPtr += batchG; + if (logLossPtr) logLossPtr[batchIndex * elwiseLL] = resultLoss; + logPtr += batchP; + tempLblPtr += batchLbl; + } + } + }; + samediff::Threads::parallel_for(func, 0, lenBatch, 1); + } + + void ctcLoss(graph::Context& block, const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex){ + + BUILD_DOUBLE_SELECTOR(logits.dataType(), targetLabels.dataType(), ctc_loss_, (logits, targetLabels, logitsLengths, targetLabelLengths, logLosses, gradients, blankIndex), FLOAT_TYPES, INDEXING_TYPES); + } + + + BUILD_DOUBLE_TEMPLATE(template void ctc_loss_, (const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex), FLOAT_TYPES, INDEXING_TYPES); + + + } // namespace helpers + } // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/ctcLoss.h b/libnd4j/include/ops/declarable/helpers/ctcLoss.h new file mode 100644 index 000000000..320442456 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/ctcLoss.h @@ -0,0 +1,55 @@ +/******************************************************************************* + * Copyright (c) 2021 Deeplearning4j Contributors + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + *******************************************************************************/ + +// +// @author AbdelRauf +// + +#ifndef LIBND4J_HELPERS_CTCLOSS_H +#define LIBND4J_HELPERS_CTCLOSS_H + +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + + /** + * @brief Implementation of CTC loss function + * References: + Connectionist Temporal Classification - Labeling Unsegmented Sequence Data + with Recurrent Neural Networks: + [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891) + ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf)) + * + * @param block Context + * @param logits NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN }. log softmax of rnn output. It should include a blank label as well. + * @param targetLabels NDArray {BATCH_LEN, MAX_TARGET_LEN} + * @param logitsLengths NDArray {BATCH_LEN} Length of input sequence in logits + * @param targetLabelLengths NDArray {BATCH_LEN} Length of label sequence in labels + * @param logLosses NDArray {BATCH_LEN} or EMPTY. if empty it will be skipped. negative log probabilities of loss + * @param gradients NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN } or EMPTY. gradients + * @param blankIndex index of the blank label in logits + */ + void ctcLoss(graph::Context& block, const NDArray &logitsInput, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex); + +} +} +} + + +#endif // LIBND4J_ADDBIAS_H \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/ctcLoss.cu b/libnd4j/include/ops/declarable/helpers/cuda/ctcLoss.cu new file mode 100644 index 000000000..ca5f9b842 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/ctcLoss.cu @@ -0,0 +1,225 @@ +/******************************************************************************* + * Copyright (c) 2021 Deeplearning4j Contributors + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + *******************************************************************************/ + // + // @author AbdelRauf + // + +#include "cudnnUtils.h" +#include +#include + + +namespace sd { +namespace ops { +namespace platforms { + + + + template + void callCudnnIfNoErr(cudnnStatus_t &err, Op op, Args&&... 
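// callCudnnIfNoErr chains cuDNN calls: op is invoked only while err is still
// CUDNN_STATUS_SUCCESS, so after the first failure the remaining setup calls in the
// sequence become no-ops, the failing status is preserved, and its message is printed via
// cudnnGetErrorString. Illustrative usage, mirroring calls made later in this file
// (variable names here are placeholders):
//   cudnnStatus_t err = CUDNN_STATUS_SUCCESS;
//   callCudnnIfNoErr(err, cudnnSetStream, handle, stream);          // executed
//   callCudnnIfNoErr(err, cudnnCreateTensorDescriptor, &probsDesc); // skipped if err != success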
args){ + if(err==CUDNN_STATUS_SUCCESS){ + err = op(std::forward(args)...); + if(err){ + nd4j_printf("Cudnn error code %s\n",cudnnGetErrorString(err)); + } + } + } + + template + const T* bufferInHost( const NDArray &array) { + array.syncToHost(); + return reinterpret_cast(array.buffer()); + } + + std::vector getConcatTargets(const NDArray &targetLabels, const NDArray &targetLabelLengths){ + //concatenate target labels + const int32_t *tlabels = bufferInHost(targetLabels); + const int32_t *tlens =bufferInHost(targetLabelLengths); + int32_t nextOffset = targetLabels.strideAt(0); + int32_t elStride = targetLabels.strideAt(1); + int32_t batchCount = targetLabelLengths.lengthOf(); + std::vector labels; + labels.resize(targetLabels.lengthOf()); + int j=0; + if(targetLabels.ews()){ + for(int i=0; i(context.getCuDnnHandle()); + cudnnStatus_t err = CUDNN_STATUS_SUCCESS; + callCudnnIfNoErr(err, cudnnSetStream, *handle, *context.getCudaStream()); + + cudnnCTCLossDescriptor_t ctcLossDesc; + cudnnTensorDescriptor_t probsDesc = nullptr; + cudnnTensorDescriptor_t gradsDesc = nullptr; + callCudnnIfNoErr(err, cudnnCreateCTCLossDescriptor, &ctcLossDesc); + callCudnnIfNoErr(err, cudnnSetCTCLossDescriptorEx, ctcLossDesc, CUDNN_DATA_FLOAT, CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_PROPAGATE_NAN); + callCudnnIfNoErr(err, cudnnCreateTensorDescriptor, &probsDesc); + callCudnnIfNoErr(err, cudnnSetTensorNdDescriptor, probsDesc, cudnnDataType(probs.dataType()), probs.rankOf() , dims, strides); + if(!grads.isEmpty()){ + const int gradStrides[] = {(int)grads.strideAt(0), (int)grads.strideAt(1), (int)grads.strideAt(2)}; + callCudnnIfNoErr(err, cudnnCreateTensorDescriptor, &gradsDesc); + callCudnnIfNoErr(err, cudnnSetTensorNdDescriptor, gradsDesc, cudnnDataType(grads.dataType()), grads.rankOf() , dims, gradStrides); + } + + size_t tempWorkSpaceSize=0; + callCudnnIfNoErr(err,cudnnGetCTCLossWorkspaceSize, *handle, probsDesc, gradsDesc, + targetLabelsPtr, + bufferInHost(targetLabelLengths), + bufferInHost(probInputLengthes), + CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, + ctcLossDesc, &tempWorkSpaceSize); + + // Allocate temp tempWorkspace buffer + void *tempWorkSpace = nullptr; + cudaMalloc(&tempWorkSpace, tempWorkSpaceSize); + + NDArray::prepareSpecialUse({&ctcLosses, &grads}, {&probs}); + callCudnnIfNoErr(err, cudnnCTCLoss,*handle, + probsDesc, + probs.specialBuffer(), + targetLabelsPtr, + bufferInHost(targetLabelLengths), + bufferInHost(probInputLengthes), + ctcLosses.specialBuffer(), + gradsDesc, + grads.specialBuffer(), + CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, + ctcLossDesc, + tempWorkSpace, + tempWorkSpaceSize); + + NDArray::registerSpecialUse({&ctcLosses, &grads}, {&probs}); + + cudaFree(tempWorkSpace); + callCudnnIfNoErr(err, cudnnDestroyTensorDescriptor,probsDesc); + if(gradsDesc) callCudnnIfNoErr(err, cudnnDestroyTensorDescriptor,gradsDesc); + callCudnnIfNoErr(err, cudnnDestroyCTCLossDescriptor,ctcLossDesc); + return err; + } + + PLATFORM_IMPL(ctc_loss, ENGINE_CUDA) { + auto targetLabels = INPUT_VARIABLE(0); + auto logitInput = INPUT_VARIABLE(1); + auto targetLabelLengths = INPUT_VARIABLE(2); + auto logitInputLengths = INPUT_VARIABLE(3); + auto outputLosses = OUTPUT_VARIABLE(0); + auto context = block.launchContext(); + //in Cudnn Batch is in the middle dimension + logitInput->permutei({1,0,2}); + //in Cudnn targets are concantenated instead of batched as matrix + auto labels = getConcatTargets(*targetLabels, *targetLabelLengths); + const int32_t *ldata= labels.data(); + auto emptyGrads= NDArrayFactory::empty(); + auto err = 
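// Notes on the cuDNN path, as enforced by the PLATFORM_CHECK implementations further down:
// it is selected only when blankIndex == 0, labels/lengths are INT32, logits are FLOAT32,
// all arrays have an elementwise stride (ews), and every target label length is <= 256.
// cuDNN also expects time-major probabilities, hence the permutei({1,0,2}) to
// {FRAME_LEN, BATCH_LEN, CLASS_LEN}, and a single concatenated label vector, built by
// getConcatTargets above.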
cudnnCtcLoss(*context, *logitInput, ldata, *logitInputLengths, *targetLabelLengths, *outputLosses, emptyGrads); + if(err!=CUDNN_STATUS_SUCCESS) throw sd::cuda_exception::build("ctc_loss CUDNN call failure ", err); + return Status::OK(); + } + + template + bool checkLabelLength(const NDArray &labelLengthArr){ + //check label lengthes + auto lenBatch = labelLengthArr.lengthOf(); + for(int i=0; i < lenBatch; i++){ + // The labelLengths is greater than 256. + if(labelLengthArr.e(i)>256) return false; + } + return true; + } + + PLATFORM_CHECK(ctc_loss, ENGINE_CUDA) { + auto targetLabels = INPUT_VARIABLE(0); + auto logitInput = INPUT_VARIABLE(1); + auto targetLabelLengths = INPUT_VARIABLE(2); + auto logitInputLengths = INPUT_VARIABLE(3); + auto outputLosses = OUTPUT_VARIABLE(0); + int blankIndex = INT_ARG(0); + + auto dTypeInput = logitInput->dataType(); + auto intType = targetLabelLengths->dataType(); + auto dTypeOutput = outputLosses->dataType(); + + bool is_supported = blankIndex==0 && intType == DataType::INT32 && dTypeInput == DataType::FLOAT32; + is_supported = is_supported && outputLosses->ews() && targetLabelLengths->ews() && targetLabels->ews() && logitInputLengths->ews(); + is_supported = is_supported && checkLabelLength(*targetLabelLengths); + return is_supported; + } + + PLATFORM_IMPL(ctc_loss_grad, ENGINE_CUDA) { + auto targetLabels = INPUT_VARIABLE(0); + auto logitInput = INPUT_VARIABLE(1); + auto targetLabelLengths = INPUT_VARIABLE(2); + auto logitInputLengths = INPUT_VARIABLE(3); + auto outputGradients = OUTPUT_VARIABLE(0); + auto context = block.launchContext(); + //in Cudnn Batch is in the middle dimension + logitInput->permutei({1,0,2}); + outputGradients->permutei({1,0,2}); + //in Cudnn targets are concantenated instead of batched as matrix + auto labels = getConcatTargets(*targetLabels, *targetLabelLengths); + const int32_t * ldata= labels.data(); + auto tempLosses = NDArrayFactory::create('c', {logitInputLengths->sizeAt(0)}); + auto err = cudnnCtcLoss(*context, *logitInput, ldata, *logitInputLengths, *targetLabelLengths, tempLosses, *outputGradients); + if(err!=CUDNN_STATUS_SUCCESS) throw sd::cuda_exception::build("ctc_loss CUDNN call failure ", err); + //restore grads shape from {T, BATCH, C} -> {BATCHS, T, C} + outputGradients->permutei({1,0,2}); + //tempLosses.printIndexedBuffer("tempLosses"); + return Status::OK(); + } + + PLATFORM_CHECK(ctc_loss_grad, ENGINE_CUDA) { + auto targetLabels = INPUT_VARIABLE(0); + auto logitInput = INPUT_VARIABLE(1); + auto targetLabelLengths = INPUT_VARIABLE(2); + auto logitInputLengths = INPUT_VARIABLE(3); + auto outputGrads = OUTPUT_VARIABLE(0); + int blankIndex = INT_ARG(0); + + auto dTypeInput = logitInput->dataType(); + auto intType = targetLabelLengths->dataType(); + auto dTypeOutput = outputGrads->dataType(); + + bool is_supported = blankIndex==0 && intType == DataType::INT32 && dTypeInput == DataType::FLOAT32; + is_supported = is_supported && outputGrads->ews() && targetLabelLengths->ews() && targetLabels->ews() && logitInputLengths->ews(); + is_supported = is_supported && checkLabelLength(*targetLabelLengths); + return is_supported; + } + +} +} +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h index 103dd0f5d..cfb822428 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h @@ -60,6 +60,9 @@ namespace platforms { 
DECLARE_PLATFORM(maxpool3dnew, ENGINE_CUDA); DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CUDA); + DECLARE_PLATFORM(ctc_loss, ENGINE_CUDA); + DECLARE_PLATFORM(ctc_loss_grad, ENGINE_CUDA); + ////////////////////////////////////////////////////////////////////////// FORCEINLINE cudnnDataType_t cudnnDataType(sd::DataType dataType) { switch (dataType) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp index 74bce1ac1..7670f1844 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp @@ -21,8 +21,7 @@ #include #include #include - - +#include using namespace sd; using namespace sd::graph; @@ -127,7 +126,7 @@ TEST_F(DeclarableOpsTests2, gather_5) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////// @@ -167,7 +166,7 @@ TEST_F(DeclarableOpsTests2, gather_7) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////// @@ -233,7 +232,7 @@ TEST_F(DeclarableOpsTests2, gather_11) { ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////// @@ -251,7 +250,7 @@ TEST_F(DeclarableOpsTests2, gather_12) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////// @@ -284,7 +283,7 @@ TEST_F(DeclarableOpsTests2, gather_13) { ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////// @@ -321,7 +320,7 @@ TEST_F(DeclarableOpsTests2, BroadcastGradientArgs_1) { auto result = op.evaluate({&input, &indices}, {}, {}); ASSERT_EQ(ND4J_STATUS_KERNEL_FAILURE, result.status()); - + } TEST_F(DeclarableOpsTests2, NLP_Cbow_Test_1) { @@ -373,7 +372,7 @@ TEST_F(DeclarableOpsTests2, NLP_Cbow_Test_1) { ASSERT_EQ(exp1, row_s1_4); ASSERT_EQ(exp1, row_s1_5); ASSERT_EQ(exp2, row_s1_6); - + } TEST_F(DeclarableOpsTests2, Test_Squeeze_1) { @@ -390,7 +389,7 @@ TEST_F(DeclarableOpsTests2, Test_Squeeze_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -408,7 +407,7 @@ TEST_F(DeclarableOpsTests2, Test_Squeeze_2) { ASSERT_TRUE(exp->isSameShape(z)); ASSERT_TRUE(exp->equalsTo(z)); - + delete exp; } @@ -426,7 +425,7 @@ TEST_F(DeclarableOpsTests2, Test_FloorMod_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests2, Test_FloorDiv_1) { @@ -718,7 +717,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_8) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == 0.f); - + } @@ -743,7 +742,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_9) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == 60.); - + } @@ -768,7 +767,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_10) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == 60.f); - + } @@ -793,7 +792,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_11) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == 1.f); - + } @@ -818,7 +817,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_12) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == 0.f); - + } @@ -843,7 +842,7 @@ TEST_F(DeclarableOpsTests2, 
absolute_difference_loss_test_13) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == 1.f); - + } @@ -870,7 +869,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_14) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == 1.f); - + } @@ -895,7 +894,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_15) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == 2.f); - + } @@ -924,7 +923,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_16) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 2.01667, 1e-5); - + } @@ -957,7 +956,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_17) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 1.93333, 1e-5); - + } @@ -990,7 +989,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_18) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 1.93333f, 1e-5); - + } @@ -1016,7 +1015,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_19) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == 1.); - + } @@ -1041,7 +1040,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_20) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == 1.); - + } @@ -1066,7 +1065,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_21) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == 1.f); - + } @@ -1091,7 +1090,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_22) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == 0.); - + } @@ -1128,7 +1127,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_23) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 0.965517, 1e-5); - + } @@ -1154,7 +1153,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test1) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -1180,7 +1179,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test2) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -1207,7 +1206,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test3) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } //////////////////////////////////////////////////////////////////// @@ -1232,7 +1231,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test4) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -1257,7 +1256,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test5) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == -71.); - + } @@ -1282,7 +1281,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test6) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == -71.f); - + } @@ -1307,7 +1306,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test7) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == -69.f); - + } @@ -1332,7 +1331,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test8) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == -24.f); - + } @@ -1357,7 +1356,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test9) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == -24.); - + } @@ -1384,7 +1383,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test10) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == -32.); - + } @@ -1410,7 +1409,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test1) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } 
/////////////////////////////////////////////////////////////////// @@ -1435,7 +1434,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test2) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -1460,7 +1459,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test3) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -1484,7 +1483,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test4) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == 83.); - + } /////////////////////////////////////////////////////////////////// @@ -1508,7 +1507,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test5) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == 83.); - + } /////////////////////////////////////////////////////////////////// @@ -1532,7 +1531,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test6) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == 83.); - + } /////////////////////////////////////////////////////////////////// @@ -1556,7 +1555,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test7) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 6.91667, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -1580,7 +1579,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test8) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 6.91667, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -1604,7 +1603,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test9) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 6.91667, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -1628,7 +1627,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test10) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 3.45833, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -1652,7 +1651,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test11) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 3.45833, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -1680,7 +1679,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test12) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 3.975, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -1704,7 +1703,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test13) { ASSERT_TRUE(result->isScalar()); ASSERT_TRUE(result->e(0) == 0.); - + } /////////////////////////////////////////////////////////////////// @@ -1729,7 +1728,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test1) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -1754,7 +1753,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test2) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -1779,7 +1778,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test3) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -1803,7 +1802,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test4) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 13.44, 1e-5); - + } 
/////////////////////////////////////////////////////////////////// @@ -1827,7 +1826,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test5) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 13.44, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -1851,7 +1850,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test6) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 1.12, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -1875,7 +1874,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test7) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 1.12, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -1903,7 +1902,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test8) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 1.3, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -1927,7 +1926,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test9) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 0.56, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -1951,7 +1950,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test10) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 0.56, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -1979,7 +1978,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test11) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 0.65, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2004,7 +2003,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test1) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -2029,7 +2028,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test2) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -2054,7 +2053,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test3) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -2078,7 +2077,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test4) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), -113.886429, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2102,7 +2101,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test5) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), -113.886429, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2126,7 +2125,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test6) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), -113.886429, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2150,7 +2149,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test7) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), -9.490536, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2174,7 +2173,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test8) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), -9.490536, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2198,7 +2197,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test9) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), -9.490536, 1e-5); - + } 
/////////////////////////////////////////////////////////////////// @@ -2226,7 +2225,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test10) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), -12.443609, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2250,7 +2249,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test11) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), -4.745268, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2274,7 +2273,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test12) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), -4.745268, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2302,7 +2301,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test13) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), -6.221805, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2322,7 +2321,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test1) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -2342,7 +2341,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test2) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -2362,7 +2361,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test3) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -2381,7 +2380,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test4) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 60.74394998193965, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2400,7 +2399,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test5) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 15.189082270182983, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2419,7 +2418,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test6) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 13.568564090650312, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2438,7 +2437,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test7) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 198.318201904499, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2457,7 +2456,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test8) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 10.709003499121707, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2476,7 +2475,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test9) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 17.686067864414472, 1e-5); - + } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test1) { @@ -2500,7 +2499,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test1) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -2525,7 +2524,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test2) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } 
/////////////////////////////////////////////////////////////////// @@ -2550,7 +2549,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test3) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -2579,7 +2578,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test4) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -2603,7 +2602,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test5) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 612.5, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2627,7 +2626,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test6) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 612.5, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2651,7 +2650,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test7) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 612.5, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2679,7 +2678,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test8) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 608.75, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2703,7 +2702,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test9) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 51.041668, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2727,7 +2726,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test10) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 51.041668, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2751,7 +2750,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test11) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 51.041668, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2778,7 +2777,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test12) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 88.541664, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2802,7 +2801,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test13) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 25.520834, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2826,7 +2825,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test14) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 25.520834, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2850,7 +2849,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test15) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 25.520834, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2877,7 +2876,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test16) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 44.270832, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -2901,7 +2900,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test1) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -2925,7 +2924,7 @@ TEST_F(DeclarableOpsTests2, 
sigm_cross_entropy_loss_test2) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -2949,7 +2948,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test3) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -2973,7 +2972,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test4) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -2996,7 +2995,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test5) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 11.2187976837, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3019,7 +3018,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test6) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 11.2187976837, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3042,7 +3041,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test7) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 11.2187976837, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3065,7 +3064,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test8) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 10.2187976837, 1e-5); - + } @@ -3092,7 +3091,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test9) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 6.06840181351, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3115,7 +3114,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test10) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 0.934899806976, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3138,7 +3137,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test11) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 0.934899806976, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3161,7 +3160,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test12) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 0.851566493511, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3187,7 +3186,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test13) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 1.01140034199, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3210,7 +3209,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test14) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 0.467449903488, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3233,7 +3232,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test15) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 0.467449903488, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3256,7 +3255,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test16) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 0.425783246756, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3282,7 +3281,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test17) { 
ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 0.505700170994, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3305,7 +3304,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test1) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -3329,7 +3328,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test2) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -3353,7 +3352,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test3) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -3377,7 +3376,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test4) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -3401,7 +3400,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test5) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -3424,7 +3423,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test6) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), 8.55521392822, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3447,7 +3446,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test7) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), -6.37014198303, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3470,7 +3469,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test8) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), -6.37014198303, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3493,7 +3492,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test9) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), -6.37014198303, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3516,7 +3515,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test10) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), -2.12338066101, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3539,7 +3538,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test11) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), -1.06169033051, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3562,7 +3561,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test12) { ASSERT_TRUE(result->isScalar()); ASSERT_NEAR(result->e(0), -2.18880319595, 1e-5); - + } /////////////////////////////////////////////////////////////////// @@ -3586,7 +3585,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test13) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } @@ -3612,7 +3611,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test14) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -3636,7 +3635,7 @@ TEST_F(DeclarableOpsTests2, 
softmax_cross_entropy_loss_test15) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } /////////////////////////////////////////////////////////////////// @@ -3681,7 +3680,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test1) { ASSERT_TRUE(expCt.isSameShape(ct)); ASSERT_TRUE(expCt.equalsTo(ct)); - + } /////////////////////////////////////////////////////////////////// @@ -3726,7 +3725,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test2) { ASSERT_TRUE(expCt.isSameShape(ct)); ASSERT_TRUE(expCt.equalsTo(ct)); - + } /////////////////////////////////////////////////////////////////// @@ -3771,7 +3770,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test3) { ASSERT_TRUE(expCt.isSameShape(ct)); ASSERT_TRUE(expCt.equalsTo(ct)); - + } /////////////////////////////////////////////////////////////////// @@ -3816,7 +3815,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test4) { ASSERT_TRUE(expCt.isSameShape(ct)); ASSERT_TRUE(expCt.equalsTo(ct)); - + } /////////////////////////////////////////////////////////////////// @@ -3861,7 +3860,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test5) { ASSERT_TRUE(expCt.isSameShape(ct)); ASSERT_TRUE(expCt.equalsTo(ct)); - + } /////////////////////////////////////////////////////////////////// @@ -3906,7 +3905,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test6) { ASSERT_TRUE(expCt.isSameShape(ct)); ASSERT_TRUE(expCt.equalsTo(ct)); - + } /////////////////////////////////////////////////////////////////// @@ -3951,7 +3950,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test7) { ASSERT_TRUE(expCt.isSameShape(ct)); ASSERT_TRUE(expCt.equalsTo(ct)); - + } @@ -3997,7 +3996,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test8) { ASSERT_TRUE(expCt.isSameShape(ct)); ASSERT_TRUE(expCt.equalsTo(ct,1e-4)); - + } /////////////////////////////////////////////////////////////////// @@ -4042,7 +4041,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test9) { ASSERT_TRUE(expCt.isSameShape(ct)); ASSERT_TRUE(expCt.equalsTo(ct)); - + } /////////////////////////////////////////////////////////////////// @@ -4087,7 +4086,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test10) { ASSERT_TRUE(expCt.isSameShape(ct)); ASSERT_TRUE(expCt.equalsTo(ct)); - + } /////////////////////////////////////////////////////////////////// @@ -4132,7 +4131,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test11) { ASSERT_TRUE(expCt.isSameShape(ct)); ASSERT_TRUE(expCt.equalsTo(ct)); - + } /////////////////////////////////////////////////////////////////// @@ -4177,5 +4176,210 @@ TEST_F(DeclarableOpsTests2, lstmCell_test12) { ASSERT_TRUE(expCt.isSameShape(ct)); ASSERT_TRUE(expCt.equalsTo(ct)); - + } + +#if !defined(__CUDABLAS__) || defined(HAVE_CUDNN) +TEST_F(DeclarableOpsTests2, ctc_loss_test1) { + constexpr int FRAME_LEN = 6 ; + constexpr int CLASS_LEN = 5 ; + constexpr int BATCH_LEN = 4 ; + constexpr int MIN_TARGET_LEN = 2; + constexpr int MAX_TARGET_LEN = 4; + +#if defined(HAVE_CUDNN) +//cudnn blankindex should be 0 + constexpr int BLANK_INDEX=0; +#else + constexpr int BLANK_INDEX=CLASS_LEN-1; +#endif + //logits were generated using numpy random and applying log softmax + //[ctc_loss.py](https://gist.github.com/quickwritereader/ca9858be201fd857348826a56e2bebc4) + auto logits = NDArrayFactory::create('c', {BATCH_LEN, FRAME_LEN, CLASS_LEN }, + {-1.52900087f, -1.7423916f, -1.79369985f, -1.68980741f, -1.35771429f, + -2.08261997f, -1.65483307f, -1.31878488f, -1.38940393f, -1.78624192f, + -1.83125744f, -1.28989651f, -1.86882736f, -1.51760877f, -1.65575026f, + -1.59030191f, -2.09045484f, -2.01113821f, -1.31159853f, -1.3120046f, + 
-1.45263472f, -1.52268525f, -1.6567962f, -2.06986454f, -1.46546941f, + -1.25549694f, -1.86336982f, -1.64691575f, -1.69584239f, -1.69374889f, + -1.62384788f, -1.53256338f, -1.47943003f, -1.9953089f, -1.49995189f, + -1.58914748f, -2.14294273f, -1.89989005f, -1.26397295f, -1.40048678f, + -1.52242117f, -1.79940303f, -1.86987214f, -1.41871056f, -1.51299132f, + -1.41772259f, -1.27648263f, -1.87029582f, -1.71325761f, -1.93542947f, + -1.4372372f, -1.72814911f, -1.18767571f, -1.85569031f, -2.09127332f, + -1.99591619f, -1.17070749f, -1.91569048f, -1.66127429f, -1.52865783f, + -1.39319926f, -2.19674832f, -1.69619098f, -1.37916537f, -1.58285964f, + -1.85456282f, -1.91027747f, -1.35265643f, -1.76707679f, -1.32405154f, + -1.70063352f, -1.82894304f, -1.81275811f, -1.76677183f, -1.13084056f, + -2.01507311f, -1.50622804f, -1.55902412f, -1.4076143f, -1.66137954f, + -1.72469437f, -1.74285619f, -1.72109242f, -1.54947478f, -1.36444454f, + -1.78795939f, -1.62871901f, -1.43244094f, -1.83058005f, -1.43770547f, + -1.3577647f, -1.81454222f, -1.58227661f, -1.89836191f, -1.49373763f, + -1.52027507f, -1.41807732f, -1.54481537f, -1.86538837f, -1.76619851f, + -1.64547283f, -1.58328753f, -1.58442673f, -1.65941447f, -1.57762943f, + -1.54091641f, -1.76747862f, -1.56063854f, -1.76235545f, -1.45495771f, + -1.37294933f, -1.75871646f, -1.38392315f, -1.62238305f, -2.06866473f, + -1.98087487f, -1.49880371f, -2.14268396f, -1.22969736f, -1.47432277f + }); + + auto logits_length = NDArrayFactory::create('c', {BATCH_LEN}, {FRAME_LEN,FRAME_LEN,FRAME_LEN,FRAME_LEN}); + std::vector target ={2, 2, 2, 0, 1, 1, 0, 0, 1, 2, 2, 3, 0, 2, 1, 2}; +#if defined(HAVE_CUDNN) + //for cudnn blank index is -. therefore our targets cant be 0 + for(int i=0;i('c',{BATCH_LEN, MAX_TARGET_LEN}, target ); + + auto labels_len = NDArrayFactory::create('c', {BATCH_LEN}, {MIN_TARGET_LEN,MIN_TARGET_LEN +1, MAX_TARGET_LEN, MIN_TARGET_LEN +1}); + +#if defined(HAVE_CUDNN) + auto expected = NDArrayFactory::create('c', {BATCH_LEN}, {6.088762f, 5.9546056f, 7.5806675f, 5.5532417f}); +#else + auto expected = NDArrayFactory::create('c', {BATCH_LEN}, {6.0661564f, 6.4285727f, 7.7180986f, 4.936057f}); +#endif + sd::ops::ctc_loss op; + + //logits.printIndexedBuffer("logits"); + //labels.printIndexedBuffer("labels"); + + auto results = op.evaluate({&labels, &logits, &labels_len, &logits_length}, {}, {BLANK_INDEX}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *loss = results.at(0); + + //loss->printIndexedBuffer("loss"); + + ASSERT_TRUE(expected.isSameShape(loss)); + ASSERT_TRUE(expected.equalsTo(loss)); + +} + + +TEST_F(DeclarableOpsTests2, ctc_loss_grad_test1) { + constexpr int FRAME_LEN = 6 ; + constexpr int CLASS_LEN = 5 ; + constexpr int BATCH_LEN = 4 ; + constexpr int MAX_TARGET_LEN = 4; + constexpr int MIN_TARGET_LEN = 2; +#if defined(HAVE_CUDNN) +//cudnn blankindex should be 0 + constexpr int BLANK_INDEX=0; +#else + constexpr int BLANK_INDEX=CLASS_LEN-1; +#endif + //logits were generated using numpy random and applying log softmax + //[ctc_loss.py](https://gist.github.com/quickwritereader/ca9858be201fd857348826a56e2bebc4) + auto logits = NDArrayFactory::create('c', {BATCH_LEN, FRAME_LEN, CLASS_LEN }, + {-1.52900087f, -1.7423916f, -1.79369985f, -1.68980741f, -1.35771429f, + -2.08261997f, -1.65483307f, -1.31878488f, -1.38940393f, -1.78624192f, + -1.83125744f, -1.28989651f, -1.86882736f, -1.51760877f, -1.65575026f, + -1.59030191f, -2.09045484f, -2.01113821f, -1.31159853f, -1.3120046f, + -1.45263472f, -1.52268525f, -1.6567962f, -2.06986454f, -1.46546941f, + 
-1.25549694f, -1.86336982f, -1.64691575f, -1.69584239f, -1.69374889f, + -1.62384788f, -1.53256338f, -1.47943003f, -1.9953089f, -1.49995189f, + -1.58914748f, -2.14294273f, -1.89989005f, -1.26397295f, -1.40048678f, + -1.52242117f, -1.79940303f, -1.86987214f, -1.41871056f, -1.51299132f, + -1.41772259f, -1.27648263f, -1.87029582f, -1.71325761f, -1.93542947f, + -1.4372372f, -1.72814911f, -1.18767571f, -1.85569031f, -2.09127332f, + -1.99591619f, -1.17070749f, -1.91569048f, -1.66127429f, -1.52865783f, + -1.39319926f, -2.19674832f, -1.69619098f, -1.37916537f, -1.58285964f, + -1.85456282f, -1.91027747f, -1.35265643f, -1.76707679f, -1.32405154f, + -1.70063352f, -1.82894304f, -1.81275811f, -1.76677183f, -1.13084056f, + -2.01507311f, -1.50622804f, -1.55902412f, -1.4076143f, -1.66137954f, + -1.72469437f, -1.74285619f, -1.72109242f, -1.54947478f, -1.36444454f, + -1.78795939f, -1.62871901f, -1.43244094f, -1.83058005f, -1.43770547f, + -1.3577647f, -1.81454222f, -1.58227661f, -1.89836191f, -1.49373763f, + -1.52027507f, -1.41807732f, -1.54481537f, -1.86538837f, -1.76619851f, + -1.64547283f, -1.58328753f, -1.58442673f, -1.65941447f, -1.57762943f, + -1.54091641f, -1.76747862f, -1.56063854f, -1.76235545f, -1.45495771f, + -1.37294933f, -1.75871646f, -1.38392315f, -1.62238305f, -2.06866473f, + -1.98087487f, -1.49880371f, -2.14268396f, -1.22969736f, -1.47432277f + }); + + auto logits_length = NDArrayFactory::create('c', {BATCH_LEN}, {FRAME_LEN,FRAME_LEN,FRAME_LEN,FRAME_LEN}); + std::vector target ={2, 2, 2, 0, 1, 1, 0, 0, 1, 2, 2, 3, 0, 2, 1, 2}; +#if defined(HAVE_CUDNN) + //for cudnn blank index is 0. therefore our targets cant be 0 + for(int i=0;i('c',{BATCH_LEN, MAX_TARGET_LEN}, target ); + auto labels_len = NDArrayFactory::create('c', {BATCH_LEN}, {MIN_TARGET_LEN, MIN_TARGET_LEN +1, MAX_TARGET_LEN, MIN_TARGET_LEN +1}); +#if defined(HAVE_CUDNN) +//results for blank Index=0 + auto expected = NDArrayFactory::create('c', {BATCH_LEN, FRAME_LEN, CLASS_LEN}, + { + -0.2673936f, 0.17510113f, 0.16634358f, -0.33129925f, 0.2572481f, + -0.17626494f, 0.19112396f, 0.2674601f, -0.44990796f, 0.1675888f, + -0.33695614f, 0.27529928f, 0.1543045f, -0.28359637f, 0.19094874f, + -0.26243734f, 0.1236309f, 0.13383625f, -0.26430953f, 0.26927972f, + -0.33964074f, 0.21812534f, 0.1907491f, -0.3002034f, 0.23096953f, + -0.200618f, 0.15514892f, 0.19264314f, -0.3310032f, 0.18382908f, + -0.04921098f, 0.21598133f, -0.52588296f, 0.13597165f, 0.22314091f, + -0.38300496f, 0.11730913f, -0.2633105f, 0.2825293f, 0.24647695f, + -0.34686768f, 0.16539758f, -0.280806f, 0.24202588f, 0.22025016f, + -0.21347934f, 0.19306758f, -0.304228f, 0.18027757f, 0.14436226f, + 0.02692442f, -0.08318196f, -0.2236172f, 0.15634498f, 0.12352975f, + 0.03155032f, -0.5855137f, 0.14724013f, 0.18989684f, 0.2168265f, + 0.10374172f, 0.11116405f, -0.67208123f, 0.25178862f, 0.20538692f, + 0.09189357f, 0.14803931f, 0.00725803f, -0.5132462f, 0.2660552f, + -0.4309733f, 0.16058321f, 0.16320339f, -0.21557501f, 0.32276183f, + -0.32850766f, 0.2217448f, 0.21034124f, -0.2934553f, 0.18987685f, + 0.06212101f, 0.1750198f, 0.17887063f, -0.38780046f, -0.02821094f, + 0.05002825f, 0.19618073f, 0.23872548f, 0.16032055f, -0.64525515f, + -0.19972575f, -0.38012666f, 0.20550671f, 0.14981383f, 0.22453187f, + -0.02966774f, -0.34505254f, 0.21335125f, -0.00961271f, 0.17098173f, + -0.04058227f, -0.03726651f, 0.16733989f, -0.295955f, 0.20646395f, + -0.05670565f, 0.12657055f, -0.00966609f, -0.2936089f, 0.23341022f, + -0.01142454f, 0.17226583f, -0.2727364f, -0.01445916f, 0.12635438f, + -0.23244353f, 
0.22339724f, -0.5122685f, 0.29238105f, 0.2289337f + }); +#else + auto expected = NDArrayFactory::create('c', {BATCH_LEN, FRAME_LEN, CLASS_LEN}, + { + 0.21675213f, 0.17510113f, -0.27113008f, 0.18455505f, -0.30527824f, + 0.12460334f, 0.19112396f, -0.44803357f, 0.24922381f, -0.11691755f, + 0.16021198f, 0.27529928f, -0.28298444f, 0.21923551f, -0.37176234f, + 0.20386407f, 0.1236309f, -0.15528734f, 0.2693891f, -0.44159663f, + 0.23395306f, 0.21812534f, -0.36457074f, 0.12620285f, -0.21371071f, + 0.28493422f, 0.15514892f, -0.4384392f, 0.18344463f, -0.18508859f, + 0.19713868f, -0.61835873f, 0.22776747f, 0.13597165f, 0.05748086f, + 0.20409954f, -0.17006806f, 0.14958507f, 0.2825293f, -0.46614605f, + 0.218183f, -0.28762838f, 0.15414338f, 0.24202588f, -0.32672384f, + 0.09618269f, -0.40792802f, 0.15407808f, 0.18027757f, -0.02261038f, + -0.40063405f, -0.04311697f, 0.3049292f, 0.15634498f, -0.01752307f, + -0.43639395f, 0.31014743f, 0.14724013f, 0.18989684f, -0.21089047f, + 0.24827974f, -0.8280775f, 0.1833807f, 0.25178862f, 0.1446285f, + 0.15652135f, 0.05439584f, -0.5887033f, 0.17083165f, 0.20695446f, + 0.1825678f, 0.1605832f, -0.04697506f, 0.17088373f, -0.4670597f, + 0.13331066f, 0.2217448f, -0.46589473f, 0.24472642f, -0.13388708f, + 0.17822751f, 0.1750198f, -0.27072078f, -0.15830047f, 0.07577389f, + 0.16730122f, 0.19618073f, 0.23872548f, -0.618405f, 0.01619747f, + -0.41614607f, 0.16291247f, 0.20550671f, 0.14981383f, -0.10208681f, + -0.32300252f, 0.2421792f, -0.01448151f, 0.15483606f, -0.05953133f, + -0.03524604f, 0.1660878f, -0.24423766f, 0.19025035f, -0.07685445f, + 0.1546654f, 0.00699046f, -0.26606354f, 0.17164008f, -0.06723261f, + 0.2533586f, -0.31069174f, -0.07983261f, 0.19742766f, -0.06026195f, + 0.1379485f, -0.47723943f, 0.11733948f, 0.29238105f, -0.07042958 + }); +#endif + sd::ops::ctc_loss_grad op; + + auto results = op.evaluate({&labels, &logits, &labels_len, &logits_length}, {}, {BLANK_INDEX}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *gradient = results.at(0); + + //gradient->printIndexedBuffer("gradient"); + + ASSERT_TRUE(expected.isSameShape(gradient)); + ASSERT_TRUE(expected.equalsTo(gradient, 1.e-06)); + +} + +#endif \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/ir/TensorNamespace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/ir/TensorNamespace.java index 434bda3a8..a0899ddbe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/ir/TensorNamespace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/ir/TensorNamespace.java @@ -5411,12 +5411,14 @@ public final class TensorNamespace { * Serializations can either use one of the fields above, or use this * raw bytes field. The only exception is the string case, where one is * required to store the content in the repeated bytes string_data field. + * * When this raw_data field is used to store tensor value, elements MUST * be stored in as fixed-width, little-endian order. * Floating-point data types MUST be stored in IEEE 754 format. * Complex64 elements must be written as two consecutive FLOAT values, real component first. * Complex128 elements must be written as two consecutive DOUBLE values, real component first. * Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + * * Note: the advantage of specific field rather than the raw_data field is * that in some cases (e.g. 
int data), protobuf does a better packing via * variable length storage, and may lead to smaller binary footprint. @@ -5655,6 +5657,7 @@ public final class TensorNamespace { /** * <pre>
    * Tensors
+   *
    * A serialized tensor value.
    * </pre>
* @@ -7010,12 +7013,14 @@ public final class TensorNamespace { * Serializations can either use one of the fields above, or use this * raw bytes field. The only exception is the string case, where one is * required to store the content in the repeated bytes string_data field. + * * When this raw_data field is used to store tensor value, elements MUST * be stored in as fixed-width, little-endian order. * Floating-point data types MUST be stored in IEEE 754 format. * Complex64 elements must be written as two consecutive FLOAT values, real component first. * Complex128 elements must be written as two consecutive DOUBLE values, real component first. * Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + * * Note: the advantage of specific field rather than the raw_data field is * that in some cases (e.g. int data), protobuf does a better packing via * variable length storage, and may lead to smaller binary footprint. @@ -7766,6 +7771,7 @@ public final class TensorNamespace { /** *
      * Tensors
+     *
      * A serialized tensor value.
      * </pre>
* @@ -9080,12 +9086,14 @@ public final class TensorNamespace { * Serializations can either use one of the fields above, or use this * raw bytes field. The only exception is the string case, where one is * required to store the content in the repeated bytes string_data field. + * * When this raw_data field is used to store tensor value, elements MUST * be stored in as fixed-width, little-endian order. * Floating-point data types MUST be stored in IEEE 754 format. * Complex64 elements must be written as two consecutive FLOAT values, real component first. * Complex128 elements must be written as two consecutive DOUBLE values, real component first. * Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + * * Note: the advantage of specific field rather than the raw_data field is * that in some cases (e.g. int data), protobuf does a better packing via * variable length storage, and may lead to smaller binary footprint. @@ -9102,12 +9110,14 @@ public final class TensorNamespace { * Serializations can either use one of the fields above, or use this * raw bytes field. The only exception is the string case, where one is * required to store the content in the repeated bytes string_data field. + * * When this raw_data field is used to store tensor value, elements MUST * be stored in as fixed-width, little-endian order. * Floating-point data types MUST be stored in IEEE 754 format. * Complex64 elements must be written as two consecutive FLOAT values, real component first. * Complex128 elements must be written as two consecutive DOUBLE values, real component first. * Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + * * Note: the advantage of specific field rather than the raw_data field is * that in some cases (e.g. int data), protobuf does a better packing via * variable length storage, and may lead to smaller binary footprint. @@ -9130,12 +9140,14 @@ public final class TensorNamespace { * Serializations can either use one of the fields above, or use this * raw bytes field. The only exception is the string case, where one is * required to store the content in the repeated bytes string_data field. + * * When this raw_data field is used to store tensor value, elements MUST * be stored in as fixed-width, little-endian order. * Floating-point data types MUST be stored in IEEE 754 format. * Complex64 elements must be written as two consecutive FLOAT values, real component first. * Complex128 elements must be written as two consecutive DOUBLE values, real component first. * Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + * * Note: the advantage of specific field rather than the raw_data field is * that in some cases (e.g. int data), protobuf does a better packing via * variable length storage, and may lead to smaller binary footprint. 
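The raw_data contract spelled out in the generated comments above (fixed-width little-endian storage, IEEE 754 floats, complex values as consecutive real/imaginary pairs, one byte per boolean) is easy to get wrong when packing tensors by hand. As a minimal illustrative sketch using only java.nio, with hypothetical class and method names that are not part of this patch:

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Hypothetical helper showing the raw_data packing rules described above.
public class RawDataPackingSketch {

    // Floats are written as 4-byte IEEE 754 values in little-endian order.
    static byte[] packFloats(float[] values) {
        ByteBuffer buffer = ByteBuffer.allocate(values.length * Float.BYTES)
                .order(ByteOrder.LITTLE_ENDIAN);
        for (float v : values) {
            buffer.putFloat(v);
        }
        return buffer.array();
    }

    // Booleans are one byte per element: 0x01 for true, 0x00 for false.
    static byte[] packBooleans(boolean[] values) {
        byte[] out = new byte[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = values[i] ? (byte) 1 : (byte) 0;
        }
        return out;
    }

    public static void main(String[] args) {
        // Two floats -> 8 bytes of raw_data.
        System.out.println(packFloats(new float[]{1.0f, -2.5f}).length);
    }
}

As the same comments note, the typed repeated fields can still be preferable for small integer tensors, since protobuf's variable-length packing may yield a smaller footprint than this raw byte layout.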
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CtcLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CtcLoss.java new file mode 100644 index 000000000..0b21c8aa0 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CtcLoss.java @@ -0,0 +1,60 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.loss; + +import org.nd4j.autodiff.loss.LossReduce; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp; + +import java.util.List; + +public class CtcLoss extends BaseLoss { + + + public CtcLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels){ + super(sameDiff, lossReduce, predictions, weights, labels); + } + + public CtcLoss(SameDiff sameDiff, SDVariable label, SDVariable predictions, SDVariable weights, + LossReduce lossReduce) { + this(sameDiff, lossReduce, predictions, weights, label); + } + + public CtcLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){ + super(lossReduce, predictions, weights, labels); + } + + public CtcLoss(){ } + + @Override + public String opName() { + return "ctc_loss"; + } + + @Override + public List doDiff(List grad){ + //No external gradient + //Args are: predictions, weights, label + return new AbsoluteDifferenceLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs(); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/CtcLossBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/CtcLossBp.java new file mode 100644 index 000000000..bc5b8461d --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/CtcLossBp.java @@ -0,0 +1,48 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. 
+ * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.api.ops.impl.loss.bp; + +import org.nd4j.autodiff.loss.LossReduce; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; + +import java.util.List; + +public class CtcLossBp extends BaseLossBp { + + + public CtcLossBp(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels){ + super(sameDiff, lossReduce, predictions, weights, labels); + } + + public CtcLossBp(){ } + + @Override + public String opName() { + return "ctc_loss_grad"; + } + + @Override + public List doDiff(List grad){ + throw new UnsupportedOperationException("Differentiation of " + getClass().getName() + " not supported"); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-tests/pom.xml b/nd4j/nd4j-backends/nd4j-tests/pom.xml index a68e9c8e7..31cb14fb5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/pom.xml +++ b/nd4j/nd4j-backends/nd4j-tests/pom.xml @@ -391,7 +391,7 @@ For testing large zoo models, this may not be enough (so comment it out). --> - -Dfile.encoding=UTF-8 -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes" + -Dfile.encoding=UTF-8 " diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml index bec969be9..de219f99b 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml @@ -107,7 +107,7 @@ *.java **/*.java
- -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes" + " diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml index 6acfd5409..aa6f52514 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml @@ -99,7 +99,7 @@ *.java **/*.java - -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes" + " diff --git a/nd4j/nd4j-serde/nd4j-aeron/pom.xml b/nd4j/nd4j-serde/nd4j-aeron/pom.xml index 7978240a2..a79bf1d18 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/pom.xml +++ b/nd4j/nd4j-serde/nd4j-aeron/pom.xml @@ -125,7 +125,7 @@ For testing large zoo models, this may not be enough (so comment it out). --> - -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes" + " diff --git a/nd4j/nd4j-serde/nd4j-arrow/pom.xml b/nd4j/nd4j-serde/nd4j-arrow/pom.xml index 0ae3372d2..6ebcd12c8 100644 --- a/nd4j/nd4j-serde/nd4j-arrow/pom.xml +++ b/nd4j/nd4j-serde/nd4j-arrow/pom.xml @@ -103,7 +103,7 @@ For testing large zoo models, this may not be enough (so comment it out). --> - -Dfile.encoding=UTF-8 -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes" + -Dfile.encoding=UTF-8 " diff --git a/nd4j/nd4j-serde/nd4j-kryo/pom.xml b/nd4j/nd4j-serde/nd4j-kryo/pom.xml index b515d1583..b4bac2e13 100644 --- a/nd4j/nd4j-serde/nd4j-kryo/pom.xml +++ b/nd4j/nd4j-serde/nd4j-kryo/pom.xml @@ -159,7 +159,7 @@ For testing large zoo models, this may not be enough (so comment it out). --> - -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes" + " diff --git a/python4j/python4j-numpy/pom.xml b/python4j/python4j-numpy/pom.xml index 16a0687d6..8a69382ec 100644 --- a/python4j/python4j-numpy/pom.xml +++ b/python4j/python4j-numpy/pom.xml @@ -114,7 +114,7 @@ For testing large zoo models, this may not be enough (so comment it out). --> - -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes" + "
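For reference, the new ctc_loss and ctc_loss_grad ops exercised in the tests above compute the standard connectionist temporal classification objective per batch element. A sketch of the usual definition, assuming the common formulation (the notation below is generic and not taken from this patch): with T input frames, per-frame log-softmax class scores as in the test's logits, a target label sequence y, and \mathcal{B}^{-1}(y) the set of length-T blank-augmented alignments \pi that collapse to y,

L_{\mathrm{CTC}}(x, y) \;=\; -\log p(y \mid x) \;=\; -\log \sum_{\pi \in \mathcal{B}^{-1}(y)} \prod_{t=1}^{T} p(\pi_t \mid x_t),

This is the per-sample quantity that ctc_loss_test1 checks against precomputed values, and ctc_loss_grad returns its gradient with respect to the logits, with the same (batch, frames, classes) shape as the logit input, as the expected arrays in ctc_loss_grad_test1 show.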