Merge pull request #9230 from eclipse/ag_ctc_loss_tf_import

Add ctc loss to TF import
master
Adam Gibson 2021-03-13 09:18:29 +09:00 committed by GitHub
commit 4b19482051
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 52 additions and 1 deletions

View File

@ -22,6 +22,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.loss.bp.CtcLossBp; import org.nd4j.linalg.api.ops.impl.loss.bp.CtcLossBp;
@ -30,6 +31,11 @@ import java.util.List;
public class CtcLoss extends DynamicCustomOp { public class CtcLoss extends DynamicCustomOp {
public CtcLoss(INDArray targetLabels, INDArray logitInputs, INDArray targetLabelLengths, INDArray logitInputLengths) {
super(new INDArray[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths},null);
}
public CtcLoss(SameDiff sameDiff, SDVariable targetLabels,SDVariable logitInputs,SDVariable targetLabelLengths,SDVariable logitInputLengths){ public CtcLoss(SameDiff sameDiff, SDVariable targetLabels,SDVariable logitInputs,SDVariable targetLabelLengths,SDVariable logitInputLengths){
super(sameDiff,new SDVariable[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths}); super(sameDiff,new SDVariable[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths});
} }

View File

@ -419,6 +419,16 @@ val checkNumerics = TensorflowMappingProcess(
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "tensor"))) tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "tensor")))
) )
val ctcLoss = TensorflowMappingProcess(
opName = "ctc_loss",
inputFrameworkOpName = "CTCLoss",
opMappingRegistry = tensorflowOpRegistry,
attributeMappingRules = listOf(),
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("logitInput" to "inputs","targetLabels" to "labels_values",
"targetLabelLengths" to "labels_indices","logitInputLengths" to "sequence_length")))
)
//only exists in tf2, tf-java can't run it //only exists in tf2, tf-java can't run it
val checkNumericsV2 = TensorflowMappingProcess( val checkNumericsV2 = TensorflowMappingProcess(

View File

@ -10307,6 +10307,41 @@ mappings {
inputFrameworkOpName: "UniqueWithCounts" inputFrameworkOpName: "UniqueWithCounts"
} }
} }
mappings {
frameworkName: "tensorflow"
opName: "ctc_loss"
inputFrameworkOpName: "CTCLoss"
rule {
ruleName: "ndarraymapping"
functionName: "ndarraymapping"
inputTensorName: "inputs"
inputTensorName: "labels_values"
inputTensorName: "labels_indices"
inputTensorName: "sequence_length"
outputTensorName: "logitInput"
outputTensorName: "targetLabels"
outputTensorName: "targetLabelLengths"
outputTensorName: "logitInputLengths"
inputToOutput {
key: "logitInput"
value: "inputs"
}
inputToOutput {
key: "targetLabels"
value: "labels_values"
}
inputToOutput {
key: "targetLabelLengths"
value: "labels_indices"
}
inputToOutput {
key: "logitInputLengths"
value: "sequence_length"
}
ruleType: "tensor"
inputFrameworkOpName: "CTCLoss"
}
}
mappings { mappings {
frameworkName: "tensorflow" frameworkName: "tensorflow"
opName: "randomuniform" opName: "randomuniform"