Fix ctc loss compilation
This commit is contained in:
parent cfe28d6a52
commit cb442f4b5a
				| @ -13,213 +13,34 @@ | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *******************************************************************************/ | ||||
|  // | ||||
|  // @author AbdelRauf | ||||
|  // | ||||
| 
 | ||||
| #include "cudnnUtils.h" | ||||
| #include <array/NDArrayFactory.h> | ||||
| #include <vector> | ||||
| // | ||||
| // @author AbdelRauf | ||||
| // | ||||
| 
 | ||||
| #include <type_traits> | ||||
| #include <cmath> | ||||
| #include <stdexcept> | ||||
| #include <memory> | ||||
| #include <execution/Threads.h> | ||||
| #include <execution/ThreadPool.h> | ||||
| #include <helpers/LoopsCoordsHelper.h> | ||||
| #include <ops/declarable/helpers/ctcLoss.h> | ||||
| 
 | ||||
| namespace sd | ||||
| { | ||||
|     namespace ops | ||||
|     { | ||||
|         namespace helpers | ||||
|         { | ||||
| 
 | ||||
| 
 | ||||
| namespace sd   { | ||||
| namespace ops     { | ||||
| namespace platforms { | ||||
|            void ctcLoss(graph::Context& block, const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex){ | ||||
|                   //not implemented yet | ||||
|                   throw std::runtime_error("ctcLoss:: Not implemented yet"); | ||||
| 			} | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     template<typename Op, typename ...Args> | ||||
|     void callCudnnIfNoErr(cudnnStatus_t &err, Op op, Args&&... args){ | ||||
|         if(err==CUDNN_STATUS_SUCCESS){ | ||||
|             err = op(std::forward<Args>(args)...); | ||||
|             if(err){ | ||||
|                 nd4j_printf("Cudnn error code %s\n",cudnnGetErrorString(err)); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     template <typename T> | ||||
|     const T* bufferInHost( const NDArray &array)  { | ||||
|         array.syncToHost(); | ||||
|         return reinterpret_cast<const T*>(array.buffer()); | ||||
|     } | ||||
| 
 | ||||
|     std::vector<int> getConcatTargets(const NDArray &targetLabels, const NDArray &targetLabelLengths){ | ||||
|                 //concatenate target labels | ||||
|                 const int32_t *tlabels = bufferInHost<int32_t>(targetLabels); | ||||
|                 const int32_t *tlens =bufferInHost<int32_t>(targetLabelLengths); | ||||
|                 int32_t nextOffset = targetLabels.strideAt(0); | ||||
|                 int32_t elStride = targetLabels.strideAt(1); | ||||
|                 int32_t batchCount = targetLabelLengths.lengthOf(); | ||||
|                 std::vector<int> labels; | ||||
|                 labels.resize(targetLabels.lengthOf()); | ||||
|                 int j=0; | ||||
|                 if(targetLabels.ews()){ | ||||
|                     for(int i=0; i<batchCount;i++){ | ||||
|                         int count = tlens[i]; | ||||
|                         for( int k=0;k<count;k++){ | ||||
|                             labels[j] = tlabels[k]; | ||||
|                             j++; | ||||
|                         } | ||||
|                         tlabels+=nextOffset; | ||||
|                     } | ||||
|                 }else{ | ||||
|                     for(int i=0; i<batchCount;i++){ | ||||
|                         int count = tlens[i]; | ||||
|                         for( int k=0;k<count;k++){ | ||||
|                             labels[j] = tlabels[k*elStride]; | ||||
|                             j++; | ||||
|                         } | ||||
|                         tlabels+=nextOffset; | ||||
|                     } | ||||
|                 } | ||||
|                 return labels; | ||||
|     } | ||||
| 
 | ||||
|     cudnnStatus_t cudnnCtcLoss(const LaunchContext  &context, const  NDArray &probs, const int32_t* targetLabelsPtr, const NDArray&  probInputLengthes, | ||||
|                                const NDArray &targetLabelLengths, NDArray &ctcLosses,  NDArray &grads){ | ||||
|         const int dims[] = {(int)probs.sizeAt(0), (int)probs.sizeAt(1), (int)probs.sizeAt(2)}; | ||||
|         const int strides[] = {(int)probs.strideAt(0), (int)probs.strideAt(1), (int)probs.strideAt(2)}; | ||||
|         auto handle = reinterpret_cast<cudnnHandle_t *>(context.getCuDnnHandle()); | ||||
|         cudnnStatus_t err = CUDNN_STATUS_SUCCESS; | ||||
|         callCudnnIfNoErr(err, cudnnSetStream, *handle, *context.getCudaStream()); | ||||
| 
 | ||||
|         cudnnCTCLossDescriptor_t  ctcLossDesc; | ||||
|         cudnnTensorDescriptor_t probsDesc = nullptr; | ||||
|         cudnnTensorDescriptor_t gradsDesc = nullptr; | ||||
|         callCudnnIfNoErr(err, cudnnCreateCTCLossDescriptor, &ctcLossDesc); | ||||
|         callCudnnIfNoErr(err, cudnnSetCTCLossDescriptorEx, ctcLossDesc, CUDNN_DATA_FLOAT, CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_PROPAGATE_NAN); | ||||
|         callCudnnIfNoErr(err, cudnnCreateTensorDescriptor, &probsDesc); | ||||
|         callCudnnIfNoErr(err, cudnnSetTensorNdDescriptor, probsDesc, cudnnDataType(probs.dataType()), probs.rankOf() , dims, strides); | ||||
|         if(!grads.isEmpty()){ | ||||
|             const int gradStrides[] = {(int)grads.strideAt(0), (int)grads.strideAt(1), (int)grads.strideAt(2)}; | ||||
|             callCudnnIfNoErr(err, cudnnCreateTensorDescriptor, &gradsDesc); | ||||
|             callCudnnIfNoErr(err, cudnnSetTensorNdDescriptor, gradsDesc, cudnnDataType(grads.dataType()), grads.rankOf() , dims, gradStrides); | ||||
|         } | ||||
| 
 | ||||
|         size_t tempWorkSpaceSize=0; | ||||
|         callCudnnIfNoErr(err,cudnnGetCTCLossWorkspaceSize, *handle,  probsDesc, gradsDesc, | ||||
|             targetLabelsPtr, | ||||
|             bufferInHost<int32_t>(targetLabelLengths), | ||||
|             bufferInHost<int32_t>(probInputLengthes), | ||||
|             CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, | ||||
|             ctcLossDesc, &tempWorkSpaceSize); | ||||
| 
 | ||||
|         // Allocate temp tempWorkspace buffer | ||||
|         void *tempWorkSpace = nullptr; | ||||
|         cudaMalloc(&tempWorkSpace, tempWorkSpaceSize); | ||||
| 
 | ||||
|         NDArray::prepareSpecialUse({&ctcLosses, &grads}, {&probs}); | ||||
|         callCudnnIfNoErr(err, cudnnCTCLoss,*handle, | ||||
|             probsDesc, | ||||
|             probs.specialBuffer(), | ||||
|             targetLabelsPtr, | ||||
|             bufferInHost<int32_t>(targetLabelLengths), | ||||
|             bufferInHost<int32_t>(probInputLengthes), | ||||
|             ctcLosses.specialBuffer(), | ||||
|             gradsDesc, | ||||
|             grads.specialBuffer(), | ||||
|             CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, | ||||
|             ctcLossDesc, | ||||
|             tempWorkSpace, | ||||
|             tempWorkSpaceSize); | ||||
| 
 | ||||
|         NDArray::registerSpecialUse({&ctcLosses, &grads}, {&probs}); | ||||
| 
 | ||||
|         cudaFree(tempWorkSpace); | ||||
|         callCudnnIfNoErr(err, cudnnDestroyTensorDescriptor,probsDesc); | ||||
|         if(gradsDesc) callCudnnIfNoErr(err, cudnnDestroyTensorDescriptor,gradsDesc); | ||||
|         callCudnnIfNoErr(err, cudnnDestroyCTCLossDescriptor,ctcLossDesc); | ||||
|         return err; | ||||
|      } | ||||
| 
 | ||||
|     PLATFORM_IMPL(ctc_loss, ENGINE_CUDA) { | ||||
|         auto targetLabels = INPUT_VARIABLE(0); | ||||
|         auto logitInput = INPUT_VARIABLE(1); | ||||
|         auto targetLabelLengths = INPUT_VARIABLE(2); | ||||
|         auto logitInputLengths = INPUT_VARIABLE(3); | ||||
|         auto outputLosses = OUTPUT_VARIABLE(0); | ||||
|         auto context = block.launchContext(); | ||||
|         //in cuDNN the batch is the middle dimension: {T, BATCH, C} | ||||
|         logitInput->permutei({1,0,2}); | ||||
|         //in cuDNN targets are concatenated instead of batched as a matrix | ||||
|         auto labels = getConcatTargets(*targetLabels, *targetLabelLengths); | ||||
|         const int32_t *ldata= labels.data(); | ||||
|         auto emptyGrads= NDArrayFactory::empty<float>(); | ||||
|         auto err = cudnnCtcLoss(*context, *logitInput, ldata, *logitInputLengths, *targetLabelLengths, *outputLosses, emptyGrads); | ||||
|         if(err!=CUDNN_STATUS_SUCCESS) throw sd::cuda_exception::build("ctc_loss CUDNN call failure ", err); | ||||
|         return Status::OK(); | ||||
|     } | ||||
| 
 | ||||
|     template<typename T> | ||||
|     bool checkLabelLength(const NDArray &labelLengthArr){ | ||||
|             //check label lengths | ||||
|             auto lenBatch = labelLengthArr.lengthOf(); | ||||
|             for(int i=0; i < lenBatch; i++){ | ||||
|                 // cuDNN does not support label lengths greater than 256 | ||||
|                 if(labelLengthArr.e<int32_t>(i)>256) return false; | ||||
|             } | ||||
|             return true; | ||||
|     } | ||||
| 
 | ||||
|     PLATFORM_CHECK(ctc_loss, ENGINE_CUDA) { | ||||
|         auto targetLabels = INPUT_VARIABLE(0); | ||||
|         auto logitInput = INPUT_VARIABLE(1); | ||||
|         auto targetLabelLengths = INPUT_VARIABLE(2); | ||||
|         auto logitInputLengths = INPUT_VARIABLE(3); | ||||
|         auto outputLosses = OUTPUT_VARIABLE(0); | ||||
|         int blankIndex = INT_ARG(0); | ||||
| 
 | ||||
|         auto dTypeInput = logitInput->dataType(); | ||||
|         auto intType = targetLabelLengths->dataType(); | ||||
|         auto dTypeOutput = outputLosses->dataType(); | ||||
| 
 | ||||
|         bool is_supported = blankIndex==0 && intType == DataType::INT32  && dTypeInput == DataType::FLOAT32; | ||||
|         is_supported = is_supported && outputLosses->ews() && targetLabelLengths->ews() && targetLabels->ews() && logitInputLengths->ews(); | ||||
|         is_supported = is_supported && checkLabelLength<int32_t>(*targetLabelLengths); | ||||
|         return  is_supported; | ||||
|     } | ||||
| 
 | ||||
|     PLATFORM_IMPL(ctc_loss_grad, ENGINE_CUDA) { | ||||
|         auto targetLabels = INPUT_VARIABLE(0); | ||||
|         auto logitInput = INPUT_VARIABLE(1); | ||||
|         auto targetLabelLengths = INPUT_VARIABLE(2); | ||||
|         auto logitInputLengths = INPUT_VARIABLE(3); | ||||
|         auto outputGradients = OUTPUT_VARIABLE(0); | ||||
|         auto context = block.launchContext(); | ||||
|         //in cuDNN the batch is the middle dimension: {T, BATCH, C} | ||||
|         logitInput->permutei({1,0,2}); | ||||
|         outputGradients->permutei({1,0,2}); | ||||
|         //in cuDNN targets are concatenated instead of batched as a matrix | ||||
|         auto labels = getConcatTargets(*targetLabels, *targetLabelLengths); | ||||
|         const int32_t * ldata= labels.data(); | ||||
|         auto tempLosses = NDArrayFactory::create<float>('c', {logitInputLengths->sizeAt(0)}); | ||||
|         auto err = cudnnCtcLoss(*context, *logitInput, ldata, *logitInputLengths, *targetLabelLengths, tempLosses, *outputGradients); | ||||
|         if(err!=CUDNN_STATUS_SUCCESS) throw sd::cuda_exception::build("ctc_loss CUDNN call failure ", err); | ||||
|         //restore grads shape from {T, BATCH, C} back to {BATCH, T, C} | ||||
|         outputGradients->permutei({1,0,2}); | ||||
|         //tempLosses.printIndexedBuffer("tempLosses"); | ||||
|         return Status::OK(); | ||||
|     } | ||||
| 
 | ||||
|     PLATFORM_CHECK(ctc_loss_grad, ENGINE_CUDA) { | ||||
|         auto targetLabels = INPUT_VARIABLE(0); | ||||
|         auto logitInput = INPUT_VARIABLE(1); | ||||
|         auto targetLabelLengths = INPUT_VARIABLE(2); | ||||
|         auto logitInputLengths = INPUT_VARIABLE(3); | ||||
|         auto outputGrads = OUTPUT_VARIABLE(0); | ||||
|         int blankIndex = INT_ARG(0); | ||||
| 
 | ||||
|         auto dTypeInput = logitInput->dataType(); | ||||
|         auto intType = targetLabelLengths->dataType(); | ||||
|         auto dTypeOutput = outputGrads->dataType(); | ||||
| 
 | ||||
|         bool is_supported = blankIndex==0 && intType == DataType::INT32  && dTypeInput == DataType::FLOAT32; | ||||
|         is_supported = is_supported && outputGrads->ews() && targetLabelLengths->ews() && targetLabels->ews() && logitInputLengths->ews(); | ||||
|         is_supported = is_supported && checkLabelLength<int32_t>(*targetLabelLengths); | ||||
|         return  is_supported; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| } | ||||
| } | ||||
|         } // namespace helpers | ||||
|     }     // namespace ops | ||||
| } // namespace sd | ||||
225  libnd4j/include/ops/declarable/platform/cudnn/ctcloss.cu  (Normal file)
							| @ -0,0 +1,225 @@ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2021 Deeplearning4j Contributors | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *******************************************************************************/ | ||||
|  // | ||||
|  // @author AbdelRauf | ||||
|  // | ||||
| 
 | ||||
| #include "cudnnUtils.h" | ||||
| #include <array/NDArrayFactory.h> | ||||
| #include <vector> | ||||
| 
 | ||||
| 
 | ||||
| namespace sd   { | ||||
| namespace ops     { | ||||
| namespace platforms { | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
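|     // Helper that chains cuDNN calls: once `err` holds a failure status, all later | ||||
|     // calls are skipped, so only the first error is reported and a single status check | ||||
|     // at the end of the call sequence is enough. Illustrative use: | ||||
|     //   cudnnStatus_t err = CUDNN_STATUS_SUCCESS; | ||||
|     //   callCudnnIfNoErr(err, cudnnCreateTensorDescriptor, &desc); | ||||
|     //   callCudnnIfNoErr(err, cudnnSetTensorNdDescriptor, desc, ...); | ||||
|     //   if (err != CUDNN_STATUS_SUCCESS) { /* handle once */ } | ||||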
|     template<typename Op, typename ...Args> | ||||
|     void callCudnnIfNoErr(cudnnStatus_t &err, Op op, Args&&... args){ | ||||
|         if(err==CUDNN_STATUS_SUCCESS){ | ||||
|             err = op(std::forward<Args>(args)...); | ||||
|             if(err){ | ||||
|                 nd4j_printf("Cudnn error code %s\n",cudnnGetErrorString(err)); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
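|     // The legacy (non-_v8) cudnnCTCLoss API reads the label and length arrays from | ||||
|     // host memory, hence the syncToHost() before handing out the raw buffer pointer. | ||||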
|     template <typename T> | ||||
|     const T* bufferInHost( const NDArray &array)  { | ||||
|         array.syncToHost(); | ||||
|         return reinterpret_cast<const T*>(array.buffer()); | ||||
|     } | ||||
| 
 | ||||
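|     // Flattens the batched {N, maxLabelLen} label matrix into the concatenated form | ||||
|     // cuDNN expects, e.g. targetLabels = [[1,2,0],[3,0,0]] with lengths [2,1] yields | ||||
|     // [1,2,3,...]; the vector is sized to targetLabels.lengthOf() but only the first | ||||
|     // sum(lengths) entries are consumed by cuDNN. | ||||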
|     std::vector<int> getConcatTargets(const NDArray &targetLabels, const NDArray &targetLabelLengths){ | ||||
|                 //concatenate target labels | ||||
|                 const int32_t *tlabels = bufferInHost<int32_t>(targetLabels); | ||||
|                 const int32_t *tlens =bufferInHost<int32_t>(targetLabelLengths); | ||||
|                 int32_t nextOffset = targetLabels.strideAt(0); | ||||
|                 int32_t elStride = targetLabels.strideAt(1); | ||||
|                 int32_t batchCount = targetLabelLengths.lengthOf(); | ||||
|                 std::vector<int> labels; | ||||
|                 labels.resize(targetLabels.lengthOf()); | ||||
|                 int j=0; | ||||
|                 if(targetLabels.ews()){ | ||||
|                     for(int i=0; i<batchCount;i++){ | ||||
|                         int count = tlens[i]; | ||||
|                         for( int k=0;k<count;k++){ | ||||
|                             labels[j] = tlabels[k]; | ||||
|                             j++; | ||||
|                         } | ||||
|                         tlabels+=nextOffset; | ||||
|                     } | ||||
|                 }else{ | ||||
|                     for(int i=0; i<batchCount;i++){ | ||||
|                         int count = tlens[i]; | ||||
|                         for( int k=0;k<count;k++){ | ||||
|                             labels[j] = tlabels[k*elStride]; | ||||
|                             j++; | ||||
|                         } | ||||
|                         tlabels+=nextOffset; | ||||
|                     } | ||||
|                 } | ||||
|                 return labels; | ||||
|     } | ||||
| 
 | ||||
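|     // Thin wrapper around cudnnCTCLoss: probs is expected time-major {T, BATCH, C} | ||||
|     // (callers permute from {BATCH, T, C} first), ctcLosses receives one loss per batch | ||||
|     // entry, and grads, when not empty, matches probs' shape. The required workspace is | ||||
|     // queried, allocated with cudaMalloc and freed again around the actual loss call. | ||||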
|     cudnnStatus_t cudnnCtcLoss(const LaunchContext  &context, const  NDArray &probs, const int32_t* targetLabelsPtr, const NDArray&  probInputLengthes, | ||||
|                                const NDArray &targetLabelLengths, NDArray &ctcLosses,  NDArray &grads){ | ||||
|         const int dims[] = {(int)probs.sizeAt(0), (int)probs.sizeAt(1), (int)probs.sizeAt(2)}; | ||||
|         const int strides[] = {(int)probs.strideAt(0), (int)probs.strideAt(1), (int)probs.strideAt(2)}; | ||||
|         auto handle = reinterpret_cast<cudnnHandle_t *>(context.getCuDnnHandle()); | ||||
|         cudnnStatus_t err = CUDNN_STATUS_SUCCESS; | ||||
|         callCudnnIfNoErr(err, cudnnSetStream, *handle, *context.getCudaStream()); | ||||
| 
 | ||||
|         cudnnCTCLossDescriptor_t  ctcLossDesc; | ||||
|         cudnnTensorDescriptor_t probsDesc = nullptr; | ||||
|         cudnnTensorDescriptor_t gradsDesc = nullptr; | ||||
|         callCudnnIfNoErr(err, cudnnCreateCTCLossDescriptor, &ctcLossDesc); | ||||
|         callCudnnIfNoErr(err, cudnnSetCTCLossDescriptorEx, ctcLossDesc, CUDNN_DATA_FLOAT, CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_PROPAGATE_NAN); | ||||
|         callCudnnIfNoErr(err, cudnnCreateTensorDescriptor, &probsDesc); | ||||
|         callCudnnIfNoErr(err, cudnnSetTensorNdDescriptor, probsDesc, cudnnDataType(probs.dataType()), probs.rankOf() , dims, strides); | ||||
|         if(!grads.isEmpty()){ | ||||
|             const int gradStrides[] = {(int)grads.strideAt(0), (int)grads.strideAt(1), (int)grads.strideAt(2)}; | ||||
|             callCudnnIfNoErr(err, cudnnCreateTensorDescriptor, &gradsDesc); | ||||
|             callCudnnIfNoErr(err, cudnnSetTensorNdDescriptor, gradsDesc, cudnnDataType(grads.dataType()), grads.rankOf() , dims, gradStrides); | ||||
|         } | ||||
| 
 | ||||
|         size_t tempWorkSpaceSize=0; | ||||
|         callCudnnIfNoErr(err,cudnnGetCTCLossWorkspaceSize, *handle,  probsDesc, gradsDesc, | ||||
|             targetLabelsPtr, | ||||
|             bufferInHost<int32_t>(targetLabelLengths), | ||||
|             bufferInHost<int32_t>(probInputLengthes), | ||||
|             CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, | ||||
|             ctcLossDesc, &tempWorkSpaceSize); | ||||
| 
 | ||||
|         // Allocate temp tempWorkspace buffer | ||||
|         void *tempWorkSpace = nullptr; | ||||
|         cudaMalloc(&tempWorkSpace, tempWorkSpaceSize); | ||||
| 
 | ||||
|         NDArray::prepareSpecialUse({&ctcLosses, &grads}, {&probs}); | ||||
|         callCudnnIfNoErr(err, cudnnCTCLoss,*handle, | ||||
|             probsDesc, | ||||
|             probs.specialBuffer(), | ||||
|             targetLabelsPtr, | ||||
|             bufferInHost<int32_t>(targetLabelLengths), | ||||
|             bufferInHost<int32_t>(probInputLengthes), | ||||
|             ctcLosses.specialBuffer(), | ||||
|             gradsDesc, | ||||
|             grads.specialBuffer(), | ||||
|             CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, | ||||
|             ctcLossDesc, | ||||
|             tempWorkSpace, | ||||
|             tempWorkSpaceSize); | ||||
| 
 | ||||
|         NDArray::registerSpecialUse({&ctcLosses, &grads}, {&probs}); | ||||
| 
 | ||||
|         cudaFree(tempWorkSpace); | ||||
|         callCudnnIfNoErr(err, cudnnDestroyTensorDescriptor,probsDesc); | ||||
|         if(gradsDesc) callCudnnIfNoErr(err, cudnnDestroyTensorDescriptor,gradsDesc); | ||||
|         callCudnnIfNoErr(err, cudnnDestroyCTCLossDescriptor,ctcLossDesc); | ||||
|         return err; | ||||
|      } | ||||
| 
 | ||||
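|     // Forward pass: inputs are targetLabels {BATCH, maxLabelLen}, logits {BATCH, T, C}, | ||||
|     // targetLabelLengths {BATCH} and logitInputLengths {BATCH}; the logits are permuted | ||||
|     // to cuDNN's time-major layout and the labels flattened before calling the wrapper. | ||||
|     // The output is one loss value per batch entry. | ||||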
|     PLATFORM_IMPL(ctc_loss, ENGINE_CUDA) { | ||||
|         auto targetLabels = INPUT_VARIABLE(0); | ||||
|         auto logitInput = INPUT_VARIABLE(1); | ||||
|         auto targetLabelLengths = INPUT_VARIABLE(2); | ||||
|         auto logitInputLengths = INPUT_VARIABLE(3); | ||||
|         auto outputLosses = OUTPUT_VARIABLE(0); | ||||
|         auto context = block.launchContext(); | ||||
|         //in cuDNN the batch is the middle dimension: {T, BATCH, C} | ||||
|         logitInput->permutei({1,0,2}); | ||||
|         //in cuDNN targets are concatenated instead of batched as a matrix | ||||
|         auto labels = getConcatTargets(*targetLabels, *targetLabelLengths); | ||||
|         const int32_t *ldata= labels.data(); | ||||
|         auto emptyGrads= NDArrayFactory::empty<float>(); | ||||
|         auto err = cudnnCtcLoss(*context, *logitInput, ldata, *logitInputLengths, *targetLabelLengths, *outputLosses, emptyGrads); | ||||
|         if(err!=CUDNN_STATUS_SUCCESS) throw sd::cuda_exception::build("ctc_loss CUDNN call failure ", err); | ||||
|         return Status::OK(); | ||||
|     } | ||||
| 
 | ||||
|     template<typename T> | ||||
|     bool checkLabelLength(const NDArray &labelLengthArr){ | ||||
|             //check label lengths | ||||
|             auto lenBatch = labelLengthArr.lengthOf(); | ||||
|             for(int i=0; i < lenBatch; i++){ | ||||
|                 // cuDNN does not support label lengths greater than 256 | ||||
|                 if(labelLengthArr.e<int32_t>(i)>256) return false; | ||||
|             } | ||||
|             return true; | ||||
|     } | ||||
| 
 | ||||
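|     // The cuDNN path is used only when blankIndex == 0, the length arrays are INT32, | ||||
|     // the logits are FLOAT32, all buffers report a valid element-wise stride (ews), | ||||
|     // and every label length is <= 256. | ||||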
|     PLATFORM_CHECK(ctc_loss, ENGINE_CUDA) { | ||||
|         auto targetLabels = INPUT_VARIABLE(0); | ||||
|         auto logitInput = INPUT_VARIABLE(1); | ||||
|         auto targetLabelLengths = INPUT_VARIABLE(2); | ||||
|         auto logitInputLengths = INPUT_VARIABLE(3); | ||||
|         auto outputLosses = OUTPUT_VARIABLE(0); | ||||
|         int blankIndex = INT_ARG(0); | ||||
| 
 | ||||
|         auto dTypeInput = logitInput->dataType(); | ||||
|         auto intType = targetLabelLengths->dataType(); | ||||
|         auto dTypeOutput = outputLosses->dataType(); | ||||
| 
 | ||||
|         bool is_supported = blankIndex==0 && intType == DataType::INT32  && dTypeInput == DataType::FLOAT32; | ||||
|         is_supported = is_supported && outputLosses->ews() && targetLabelLengths->ews() && targetLabels->ews() && logitInputLengths->ews(); | ||||
|         is_supported = is_supported && checkLabelLength<int32_t>(*targetLabelLengths); | ||||
|         return  is_supported; | ||||
|     } | ||||
| 
 | ||||
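|     // Backward pass: same inputs as ctc_loss, output is the gradient w.r.t. the logits | ||||
|     // in {BATCH, T, C}. cuDNN always produces the per-batch losses as well, so tempLosses | ||||
|     // serves as a scratch buffer; gradients are computed time-major and permuted back. | ||||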
|     PLATFORM_IMPL(ctc_loss_grad, ENGINE_CUDA) { | ||||
|         auto targetLabels = INPUT_VARIABLE(0); | ||||
|         auto logitInput = INPUT_VARIABLE(1); | ||||
|         auto targetLabelLengths = INPUT_VARIABLE(2); | ||||
|         auto logitInputLengths = INPUT_VARIABLE(3); | ||||
|         auto outputGradients = OUTPUT_VARIABLE(0); | ||||
|         auto context = block.launchContext(); | ||||
|         //in cuDNN the batch is the middle dimension: {T, BATCH, C} | ||||
|         logitInput->permutei({1,0,2}); | ||||
|         outputGradients->permutei({1,0,2}); | ||||
|         //in cuDNN targets are concatenated instead of batched as a matrix | ||||
|         auto labels = getConcatTargets(*targetLabels, *targetLabelLengths); | ||||
|         const int32_t * ldata= labels.data(); | ||||
|         auto tempLosses = NDArrayFactory::create<float>('c', {logitInputLengths->sizeAt(0)}); | ||||
|         auto err = cudnnCtcLoss(*context, *logitInput, ldata, *logitInputLengths, *targetLabelLengths, tempLosses, *outputGradients); | ||||
|         if(err!=CUDNN_STATUS_SUCCESS) throw sd::cuda_exception::build("ctc_loss CUDNN call failure ", err); | ||||
|         //restore grads shape from {T, BATCH, C} back to {BATCH, T, C} | ||||
|         outputGradients->permutei({1,0,2}); | ||||
|         //tempLosses.printIndexedBuffer("tempLosses"); | ||||
|         return Status::OK(); | ||||
|     } | ||||
| 
 | ||||
|     PLATFORM_CHECK(ctc_loss_grad, ENGINE_CUDA) { | ||||
|         auto targetLabels = INPUT_VARIABLE(0); | ||||
|         auto logitInput = INPUT_VARIABLE(1); | ||||
|         auto targetLabelLengths = INPUT_VARIABLE(2); | ||||
|         auto logitInputLengths = INPUT_VARIABLE(3); | ||||
|         auto outputGrads = OUTPUT_VARIABLE(0); | ||||
|         int blankIndex = INT_ARG(0); | ||||
| 
 | ||||
|         auto dTypeInput = logitInput->dataType(); | ||||
|         auto intType = targetLabelLengths->dataType(); | ||||
|         auto dTypeOutput = outputGrads->dataType(); | ||||
| 
 | ||||
|         bool is_supported = blankIndex==0 && intType == DataType::INT32  && dTypeInput == DataType::FLOAT32; | ||||
|         is_supported = is_supported && outputGrads->ews() && targetLabelLengths->ews() && targetLabels->ews() && logitInputLengths->ews(); | ||||
|         is_supported = is_supported && checkLabelLength<int32_t>(*targetLabelLengths); | ||||
|         return  is_supported; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| } | ||||
| } | ||||
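For reference, a minimal usage sketch of the ctc_loss op registered above. This is not part of the commit: the shapes and values are illustrative, and it assumes the DeclarableOp::evaluate(...) helper commonly used in libnd4j tests.

    #include <array/NDArrayFactory.h>
    #include <ops/declarable/CustomOperations.h>

    static void ctcLossUsageSketch() {
        using namespace sd;
        // N=2 sequences, T=4 time steps, C=5 classes; class 0 is the blank (required by the platform check)
        auto logits       = NDArrayFactory::create<float>('c', {2, 4, 5});            // {N, T, C}, FLOAT32 (fill with real activations)
        auto labels       = NDArrayFactory::create<int>('c', {2, 3}, {1,2,0, 3,0,0}); // {N, maxLabelLen}, INT32, zero-padded
        auto labelLengths = NDArrayFactory::create<int>('c', {2}, {2, 1});
        auto logitLengths = NDArrayFactory::create<int>('c', {2}, {4, 4});

        sd::ops::ctc_loss op;
        auto result = op.evaluate({&labels, &logits, &labelLengths, &logitLengths}, {}, {0 /*blankIndex*/});
        // result.at(0): one CTC loss value per batch entry, shape {N}
    }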