Add ctc loss to TF import
parent
0c81654567
commit
df1be4a116
|
@ -22,6 +22,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
|
|||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.loss.bp.CtcLossBp;
|
||||
|
||||
|
@ -30,6 +31,11 @@ import java.util.List;
|
|||
public class CtcLoss extends DynamicCustomOp {
|
||||
|
||||
|
||||
public CtcLoss(INDArray targetLabels, INDArray logitInputs, INDArray targetLabelLengths, INDArray logitInputLengths) {
|
||||
super(new INDArray[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths},null);
|
||||
}
|
||||
|
||||
|
||||
public CtcLoss(SameDiff sameDiff, SDVariable targetLabels,SDVariable logitInputs,SDVariable targetLabelLengths,SDVariable logitInputLengths){
|
||||
super(sameDiff,new SDVariable[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths});
|
||||
}
|
||||
|
|
|
@ -419,6 +419,16 @@ val checkNumerics = TensorflowMappingProcess(
|
|||
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "tensor")))
|
||||
)
|
||||
|
||||
val ctcLoss = TensorflowMappingProcess(
|
||||
opName = "ctc_loss",
|
||||
inputFrameworkOpName = "CTCLoss",
|
||||
opMappingRegistry = tensorflowOpRegistry,
|
||||
attributeMappingRules = listOf(),
|
||||
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("logitInput" to "inputs","targetLabels" to "labels_values",
|
||||
"targetLabelLengths" to "labels_indices","logitInputLengths" to "sequence_length")))
|
||||
)
|
||||
|
||||
|
||||
//only exists in tf2, tf-java can't run it
|
||||
|
||||
val checkNumericsV2 = TensorflowMappingProcess(
|
||||
|
|
|
@ -10307,6 +10307,41 @@ mappings {
|
|||
inputFrameworkOpName: "UniqueWithCounts"
|
||||
}
|
||||
}
|
||||
mappings {
|
||||
frameworkName: "tensorflow"
|
||||
opName: "ctc_loss"
|
||||
inputFrameworkOpName: "CTCLoss"
|
||||
rule {
|
||||
ruleName: "ndarraymapping"
|
||||
functionName: "ndarraymapping"
|
||||
inputTensorName: "inputs"
|
||||
inputTensorName: "labels_values"
|
||||
inputTensorName: "labels_indices"
|
||||
inputTensorName: "sequence_length"
|
||||
outputTensorName: "logitInput"
|
||||
outputTensorName: "targetLabels"
|
||||
outputTensorName: "targetLabelLengths"
|
||||
outputTensorName: "logitInputLengths"
|
||||
inputToOutput {
|
||||
key: "logitInput"
|
||||
value: "inputs"
|
||||
}
|
||||
inputToOutput {
|
||||
key: "targetLabels"
|
||||
value: "labels_values"
|
||||
}
|
||||
inputToOutput {
|
||||
key: "targetLabelLengths"
|
||||
value: "labels_indices"
|
||||
}
|
||||
inputToOutput {
|
||||
key: "logitInputLengths"
|
||||
value: "sequence_length"
|
||||
}
|
||||
ruleType: "tensor"
|
||||
inputFrameworkOpName: "CTCLoss"
|
||||
}
|
||||
}
|
||||
mappings {
|
||||
frameworkName: "tensorflow"
|
||||
opName: "randomuniform"
|
||||
|
|
Loading…
Reference in New Issue