Add ctc loss to TF import
parent
0c81654567
commit
df1be4a116
|
@ -22,6 +22,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.loss.bp.CtcLossBp;
|
import org.nd4j.linalg.api.ops.impl.loss.bp.CtcLossBp;
|
||||||
|
|
||||||
|
@ -30,6 +31,11 @@ import java.util.List;
|
||||||
public class CtcLoss extends DynamicCustomOp {
|
public class CtcLoss extends DynamicCustomOp {
|
||||||
|
|
||||||
|
|
||||||
|
public CtcLoss(INDArray targetLabels, INDArray logitInputs, INDArray targetLabelLengths, INDArray logitInputLengths) {
|
||||||
|
super(new INDArray[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths},null);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
public CtcLoss(SameDiff sameDiff, SDVariable targetLabels,SDVariable logitInputs,SDVariable targetLabelLengths,SDVariable logitInputLengths){
|
public CtcLoss(SameDiff sameDiff, SDVariable targetLabels,SDVariable logitInputs,SDVariable targetLabelLengths,SDVariable logitInputLengths){
|
||||||
super(sameDiff,new SDVariable[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths});
|
super(sameDiff,new SDVariable[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths});
|
||||||
}
|
}
|
||||||
|
|
|
@ -419,6 +419,16 @@ val checkNumerics = TensorflowMappingProcess(
|
||||||
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "tensor")))
|
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "tensor")))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
val ctcLoss = TensorflowMappingProcess(
|
||||||
|
opName = "ctc_loss",
|
||||||
|
inputFrameworkOpName = "CTCLoss",
|
||||||
|
opMappingRegistry = tensorflowOpRegistry,
|
||||||
|
attributeMappingRules = listOf(),
|
||||||
|
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("logitInput" to "inputs","targetLabels" to "labels_values",
|
||||||
|
"targetLabelLengths" to "labels_indices","logitInputLengths" to "sequence_length")))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
//only exists in tf2, tf-java can't run it
|
//only exists in tf2, tf-java can't run it
|
||||||
|
|
||||||
val checkNumericsV2 = TensorflowMappingProcess(
|
val checkNumericsV2 = TensorflowMappingProcess(
|
||||||
|
|
|
@ -10307,6 +10307,41 @@ mappings {
|
||||||
inputFrameworkOpName: "UniqueWithCounts"
|
inputFrameworkOpName: "UniqueWithCounts"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
mappings {
|
||||||
|
frameworkName: "tensorflow"
|
||||||
|
opName: "ctc_loss"
|
||||||
|
inputFrameworkOpName: "CTCLoss"
|
||||||
|
rule {
|
||||||
|
ruleName: "ndarraymapping"
|
||||||
|
functionName: "ndarraymapping"
|
||||||
|
inputTensorName: "inputs"
|
||||||
|
inputTensorName: "labels_values"
|
||||||
|
inputTensorName: "labels_indices"
|
||||||
|
inputTensorName: "sequence_length"
|
||||||
|
outputTensorName: "logitInput"
|
||||||
|
outputTensorName: "targetLabels"
|
||||||
|
outputTensorName: "targetLabelLengths"
|
||||||
|
outputTensorName: "logitInputLengths"
|
||||||
|
inputToOutput {
|
||||||
|
key: "logitInput"
|
||||||
|
value: "inputs"
|
||||||
|
}
|
||||||
|
inputToOutput {
|
||||||
|
key: "targetLabels"
|
||||||
|
value: "labels_values"
|
||||||
|
}
|
||||||
|
inputToOutput {
|
||||||
|
key: "targetLabelLengths"
|
||||||
|
value: "labels_indices"
|
||||||
|
}
|
||||||
|
inputToOutput {
|
||||||
|
key: "logitInputLengths"
|
||||||
|
value: "sequence_length"
|
||||||
|
}
|
||||||
|
ruleType: "tensor"
|
||||||
|
inputFrameworkOpName: "CTCLoss"
|
||||||
|
}
|
||||||
|
}
|
||||||
mappings {
|
mappings {
|
||||||
frameworkName: "tensorflow"
|
frameworkName: "tensorflow"
|
||||||
opName: "randomuniform"
|
opName: "randomuniform"
|
||||||
|
|
Loading…
Reference in New Issue