Add support for registering custom loss functions for Keras import (#427)

Signed-off-by: Paul Dubs <paul.dubs@gmail.com>
master
Paul Dubs 2020-05-02 08:08:04 +02:00 committed by GitHub
parent 6c9a14d8c2
commit 4dbdaca967
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 116 additions and 5 deletions

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -28,6 +29,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.ArrayList; import java.util.ArrayList;
@ -45,7 +47,7 @@ import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils.mapLo
public class KerasLoss extends KerasLayer { public class KerasLoss extends KerasLayer {
private final String KERAS_CLASS_NAME_LOSS = "Loss"; private final String KERAS_CLASS_NAME_LOSS = "Loss";
private LossFunctions.LossFunction loss; private ILossFunction loss;
/** /**
@ -86,7 +88,7 @@ public class KerasLoss extends KerasLayer {
if (enforceTrainingConfig) if (enforceTrainingConfig)
throw e; throw e;
log.warn("Unsupported Keras loss function. Replacing with MSE."); log.warn("Unsupported Keras loss function. Replacing with MSE.");
loss = LossFunctions.LossFunction.SQUARED_LOSS; loss = LossFunctions.LossFunction.SQUARED_LOSS.getILossFunction();
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -19,8 +20,13 @@ package org.deeplearning4j.nn.modelimport.keras.utils;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.HashMap;
import java.util.Map;
/** /**
* Utility functionality for keras loss functions * Utility functionality for keras loss functions
* *
@ -28,13 +34,33 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
*/ */
@Slf4j @Slf4j
public class KerasLossUtils { public class KerasLossUtils {
static final Map<String, ILossFunction> customLoss = new HashMap<>();
/**
 * Register a custom loss function, so that models referring to it by name can be imported.
 *
 * @param lossName name of the loss function in the serialized Keras model
 * @param lossFunction ILossFunction instance to map to that Keras loss name
 */
public static void registerCustomLoss(String lossName, ILossFunction lossFunction) {
customLoss.put(lossName, lossFunction);
}
/**
 * Clear all registered custom loss functions.
 *
 */
public static void clearCustomLoss() {
customLoss.clear();
}
/** /**
* Map Keras to DL4J loss functions. * Map Keras to DL4J loss functions.
* *
* @param kerasLoss String containing Keras loss function name * @param kerasLoss String containing Keras loss function name
* @return String containing DL4J loss function * @return String containing DL4J loss function
*/ */
public static LossFunctions.LossFunction mapLossFunction(String kerasLoss, KerasLayerConfiguration conf) public static ILossFunction mapLossFunction(String kerasLoss, KerasLayerConfiguration conf)
throws UnsupportedKerasConfigurationException { throws UnsupportedKerasConfigurationException {
LossFunctions.LossFunction dl4jLoss; LossFunctions.LossFunction dl4jLoss;
if (kerasLoss.equals(conf.getKERAS_LOSS_MEAN_SQUARED_ERROR()) || if (kerasLoss.equals(conf.getKERAS_LOSS_MEAN_SQUARED_ERROR()) ||
@ -67,8 +93,13 @@ public class KerasLossUtils {
} else if (kerasLoss.equals(conf.getKERAS_LOSS_COSINE_PROXIMITY())) { } else if (kerasLoss.equals(conf.getKERAS_LOSS_COSINE_PROXIMITY())) {
dl4jLoss = LossFunctions.LossFunction.COSINE_PROXIMITY; dl4jLoss = LossFunctions.LossFunction.COSINE_PROXIMITY;
} else { } else {
throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss); ILossFunction lossClass = customLoss.get(kerasLoss);
if(lossClass != null){
return lossClass;
}else{
throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss);
}
} }
return dl4jLoss; return dl4jLoss.getILossFunction();
} }
} }

View File

@ -0,0 +1,78 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras.e2e;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.SameDiffLoss;
import java.io.File;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
/**
* Test importing Keras models with custom loss.
*
* @author Paul Dubs
*/
public class KerasCustomLossTest extends BaseDL4JTest {
    @Rule
    public TemporaryFolder testDir = new TemporaryFolder();

    /**
     * Log-cosh loss implemented via SameDiff: loss = log(cosh(labels - predictions)).
     * Declared as a static nested class so it does not capture a hidden reference to the
     * enclosing test instance (a non-static inner class would carry that reference, which
     * can leak the test fixture and break serialization of the loss function).
     */
    public static class LogCosh extends SameDiffLoss {
        @Override
        public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
            return sd.math.log(sd.math.cosh(labels.sub(layerInput)));
        }
    }

    /**
     * Imports a sequential Keras model compiled with a custom "logcosh" loss, after
     * registering a matching DL4J loss under the same name, and runs a forward pass.
     * NOTE(review): the method name looks like a copy-paste from the lambda-layer test;
     * kept unchanged to preserve the public interface.
     *
     * @throws Exception on model import or I/O failure
     */
    @Test
    public void testSequentialLambdaLayerImport() throws Exception {
        // The name must match the loss name stored in the serialized Keras model.
        KerasLossUtils.registerCustomLoss("logcosh", new LogCosh());

        String modelPath = "modelimport/keras/examples/custom_loss.h5";

        try (InputStream is = Resources.asStream(modelPath)) {
            // Copy the classpath resource to a real file, since the HDF5 loader needs a filename.
            File modelFile = testDir.newFile("tempModel" + System.currentTimeMillis() + ".h5");
            Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING);

            // enforceTrainingConfig(true): import must fail rather than silently substitute
            // the loss, so a successful build proves the custom loss was resolved.
            MultiLayerNetwork model = new KerasSequentialModel().modelBuilder()
                    .modelHdf5Filename(modelFile.getAbsolutePath())
                    .enforceTrainingConfig(true)
                    .buildSequential()
                    .getMultiLayerNetwork();
            System.out.println(model.summary());

            // Smoke-test a forward pass with a batch of 10 inputs of width 3.
            INDArray input = Nd4j.create(new int[]{10, 3});
            model.output(input);
        } finally {
            // Clean up the global registry so other tests are not affected.
            KerasLossUtils.clearCustomLoss();
        }
    }
}