Add ctc loss to TF import
This commit is contained in:
		
							parent
							
								
									0c81654567
								
							
						
					
					
						commit
						df1be4a116
					
				| @ -22,6 +22,7 @@ package org.nd4j.linalg.api.ops.impl.loss; | ||||
| 
 | ||||
| import org.nd4j.autodiff.samediff.SDVariable; | ||||
| import org.nd4j.autodiff.samediff.SameDiff; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||||
| import org.nd4j.linalg.api.ops.impl.loss.bp.CtcLossBp; | ||||
| 
 | ||||
| @ -30,6 +31,11 @@ import java.util.List; | ||||
| public class CtcLoss extends DynamicCustomOp { | ||||
| 
 | ||||
| 
 | ||||
|     public CtcLoss(INDArray targetLabels, INDArray logitInputs, INDArray targetLabelLengths, INDArray logitInputLengths) { | ||||
|         super(new INDArray[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths},null); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public CtcLoss(SameDiff sameDiff, SDVariable targetLabels,SDVariable logitInputs,SDVariable targetLabelLengths,SDVariable logitInputLengths){ | ||||
|         super(sameDiff,new SDVariable[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths}); | ||||
|     } | ||||
|  | ||||
| @ -40,7 +40,7 @@ public class HingeLoss extends BaseLoss { | ||||
|         this(sameDiff, lossReduce, predictions, weights, labels); | ||||
|     } | ||||
| 
 | ||||
|     public HingeLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){ | ||||
|     public HingeLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce) { | ||||
|         super(lossReduce, predictions, weights, labels); | ||||
|     } | ||||
| 
 | ||||
|  | ||||
| @ -419,6 +419,16 @@ val checkNumerics = TensorflowMappingProcess( | ||||
|         tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "tensor"))) | ||||
| ) | ||||
| 
 | ||||
| val ctcLoss = TensorflowMappingProcess( | ||||
|         opName = "ctc_loss", | ||||
|         inputFrameworkOpName = "CTCLoss", | ||||
|         opMappingRegistry = tensorflowOpRegistry, | ||||
|         attributeMappingRules = listOf(), | ||||
|         tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("logitInput" to "inputs","targetLabels" to "labels_values", | ||||
|                 "targetLabelLengths" to "labels_indices","logitInputLengths" to "sequence_length"))) | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| //only exists in tf2, tf-java can't run it | ||||
| 
 | ||||
| val checkNumericsV2 = TensorflowMappingProcess( | ||||
|  | ||||
| @ -10307,6 +10307,41 @@ mappings { | ||||
|     inputFrameworkOpName: "UniqueWithCounts" | ||||
|   } | ||||
| } | ||||
| mappings { | ||||
|   frameworkName: "tensorflow" | ||||
|   opName: "ctc_loss" | ||||
|   inputFrameworkOpName: "CTCLoss" | ||||
|   rule { | ||||
|     ruleName: "ndarraymapping" | ||||
|     functionName: "ndarraymapping" | ||||
|     inputTensorName: "inputs" | ||||
|     inputTensorName: "labels_values" | ||||
|     inputTensorName: "labels_indices" | ||||
|     inputTensorName: "sequence_length" | ||||
|     outputTensorName: "logitInput" | ||||
|     outputTensorName: "targetLabels" | ||||
|     outputTensorName: "targetLabelLengths" | ||||
|     outputTensorName: "logitInputLengths" | ||||
|     inputToOutput { | ||||
|       key: "logitInput" | ||||
|       value: "inputs" | ||||
|     } | ||||
|     inputToOutput { | ||||
|       key: "targetLabels" | ||||
|       value: "labels_values" | ||||
|     } | ||||
|     inputToOutput { | ||||
|       key: "targetLabelLengths" | ||||
|       value: "labels_indices" | ||||
|     } | ||||
|     inputToOutput { | ||||
|       key: "logitInputLengths" | ||||
|       value: "sequence_length" | ||||
|     } | ||||
|     ruleType: "tensor" | ||||
|     inputFrameworkOpName: "CTCLoss" | ||||
|   } | ||||
| } | ||||
| mappings { | ||||
|   frameworkName: "tensorflow" | ||||
|   opName: "randomuniform" | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user