diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasLoss.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasLoss.java index e3c603287..d47309d1d 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasLoss.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasLoss.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -28,6 +29,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.ArrayList; @@ -45,7 +47,7 @@ import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils.mapLo public class KerasLoss extends KerasLayer { private final String KERAS_CLASS_NAME_LOSS = "Loss"; - private LossFunctions.LossFunction loss; + private ILossFunction loss; /** @@ -86,7 +88,7 @@ public class KerasLoss extends KerasLayer { if (enforceTrainingConfig) throw e; log.warn("Unsupported Keras loss function. Replacing with MSE."); - loss = LossFunctions.LossFunction.SQUARED_LOSS; + loss = LossFunctions.LossFunction.SQUARED_LOSS.getILossFunction(); } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLossUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLossUtils.java index 35cf34170..b9e0ddfce 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLossUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLossUtils.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,8 +20,13 @@ package org.deeplearning4j.nn.modelimport.keras.utils; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; +import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions; +import java.util.HashMap; +import java.util.Map; + + /** * Utility functionality for keras loss functions * @@ -28,13 +34,33 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; */ @Slf4j public class KerasLossUtils { + static final Map customLoss = new HashMap<>(); + + /** + * Register a custom loss function + * + * @param lossName name of the lambda layer in the serialized Keras model + * @param lossFunction SameDiffLambdaLayer instance to map to Keras Lambda layer + */ + public static void registerCustomLoss(String lossName, ILossFunction lossFunction) { + customLoss.put(lossName, lossFunction); + } + + /** + * Clear all lambda layers + * + */ + public static void clearCustomLoss() { + customLoss.clear(); + } + /** * Map Keras to DL4J loss functions. * * @param kerasLoss String containing Keras loss function name * @return String containing DL4J loss function */ - public static LossFunctions.LossFunction mapLossFunction(String kerasLoss, KerasLayerConfiguration conf) + public static ILossFunction mapLossFunction(String kerasLoss, KerasLayerConfiguration conf) throws UnsupportedKerasConfigurationException { LossFunctions.LossFunction dl4jLoss; if (kerasLoss.equals(conf.getKERAS_LOSS_MEAN_SQUARED_ERROR()) || @@ -67,8 +93,13 @@ public class KerasLossUtils { } else if (kerasLoss.equals(conf.getKERAS_LOSS_COSINE_PROXIMITY())) { dl4jLoss = LossFunctions.LossFunction.COSINE_PROXIMITY; } else { - throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss); + ILossFunction lossClass = customLoss.get(kerasLoss); + if(lossClass != null){ + return lossClass; + }else{ + throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss); + } } - return dl4jLoss; + return dl4jLoss.getILossFunction(); } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java new file mode 100644 index 000000000..23c46835e --- /dev/null +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java @@ -0,0 +1,78 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.nn.modelimport.keras.e2e; + +import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel; +import org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.common.resources.Resources; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.lossfunctions.SameDiffLoss; + +import java.io.File; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.StandardCopyOption; + + +/** + * Test importing Keras models with custom loss. + * + * @author Paul Dubs + */ +public class KerasCustomLossTest extends BaseDL4JTest { + + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + + public class LogCosh extends SameDiffLoss { + @Override + public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) { + return sd.math.log(sd.math.cosh(labels.sub(layerInput))); + } + } + + @Test + public void testSequentialLambdaLayerImport() throws Exception { + KerasLossUtils.registerCustomLoss("logcosh", new LogCosh()); + + String modelPath = "modelimport/keras/examples/custom_loss.h5"; + + try(InputStream is = Resources.asStream(modelPath)) { + File modelFile = testDir.newFile("tempModel" + System.currentTimeMillis() + ".h5"); + Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING); + MultiLayerNetwork model = new KerasSequentialModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath()) + .enforceTrainingConfig(true).buildSequential().getMultiLayerNetwork(); + + System.out.println(model.summary()); + INDArray input = Nd4j.create(new int[]{10, 3}); + + model.output(input); + } finally { + KerasLossUtils.clearCustomLoss(); + } + } + + +}