diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CtcLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CtcLoss.java index 5edfdc363..cf4253e22 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CtcLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CtcLoss.java @@ -22,6 +22,7 @@ package org.nd4j.linalg.api.ops.impl.loss; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.loss.bp.CtcLossBp; @@ -30,6 +31,11 @@ import java.util.List; public class CtcLoss extends DynamicCustomOp { + public CtcLoss(INDArray targetLabels, INDArray logitInputs, INDArray targetLabelLengths, INDArray logitInputLengths) { + super(new INDArray[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths},null); + } + + public CtcLoss(SameDiff sameDiff, SDVariable targetLabels,SDVariable logitInputs,SDVariable targetLabelLengths,SDVariable logitInputLengths){ super(sameDiff,new SDVariable[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java index d0fcf2e83..3829deaf8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java @@ -40,7 +40,7 @@ public class HingeLoss extends BaseLoss { this(sameDiff, lossReduce, predictions, weights, labels); } - public HingeLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){ + public HingeLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce) { super(lossReduce, predictions, weights, labels); } diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt b/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt index 94a6a416a..00fdf907d 100644 --- a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt @@ -419,6 +419,16 @@ val checkNumerics = TensorflowMappingProcess( tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "tensor"))) ) +val ctcLoss = TensorflowMappingProcess( + opName = "ctc_loss", + inputFrameworkOpName = "CTCLoss", + opMappingRegistry = tensorflowOpRegistry, + attributeMappingRules = listOf(), + tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("logitInput" to "inputs","targetLabels" to "labels_values", + "targetLabelLengths" to "labels_indices","logitInputLengths" to "sequence_length"))) +) + + //only exists in tf2, tf-java can't run it val checkNumericsV2 = TensorflowMappingProcess( diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-mapping-ruleset.pbtxt b/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-mapping-ruleset.pbtxt index 96303b8f3..2c711166c 100644 --- a/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-mapping-ruleset.pbtxt +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-mapping-ruleset.pbtxt @@ -10307,6 +10307,41 @@ mappings { inputFrameworkOpName: "UniqueWithCounts" } } +mappings { + frameworkName: "tensorflow" + opName: "ctc_loss" + inputFrameworkOpName: "CTCLoss" + rule { + ruleName: "ndarraymapping" + functionName: "ndarraymapping" + inputTensorName: "inputs" + inputTensorName: "labels_values" + inputTensorName: "labels_indices" + inputTensorName: "sequence_length" + outputTensorName: "logitInput" + outputTensorName: "targetLabels" + outputTensorName: "targetLabelLengths" + outputTensorName: "logitInputLengths" + inputToOutput { + key: "logitInput" + value: "inputs" + } + inputToOutput { + key: "targetLabels" + value: "labels_values" + } + inputToOutput { + key: "targetLabelLengths" + value: "labels_indices" + } + inputToOutput { + key: "logitInputLengths" + value: "sequence_length" + } + ruleType: "tensor" + inputFrameworkOpName: "CTCLoss" + } +} mappings { frameworkName: "tensorflow" opName: "randomuniform"