Add support for registering custom loss functions for Keras import (#427)
Signed-off-by: Paul Dubs <paul.dubs@gmail.com>master
parent
6c9a14d8c2
commit
4dbdaca967
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -28,6 +29,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
|||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.lossfunctions.ILossFunction;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -45,7 +47,7 @@ import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils.mapLo
|
|||
public class KerasLoss extends KerasLayer {
|
||||
|
||||
private final String KERAS_CLASS_NAME_LOSS = "Loss";
|
||||
private LossFunctions.LossFunction loss;
|
||||
private ILossFunction loss;
|
||||
|
||||
|
||||
/**
|
||||
|
@ -86,7 +88,7 @@ public class KerasLoss extends KerasLayer {
|
|||
if (enforceTrainingConfig)
|
||||
throw e;
|
||||
log.warn("Unsupported Keras loss function. Replacing with MSE.");
|
||||
loss = LossFunctions.LossFunction.SQUARED_LOSS;
|
||||
loss = LossFunctions.LossFunction.SQUARED_LOSS.getILossFunction();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -19,8 +20,13 @@ package org.deeplearning4j.nn.modelimport.keras.utils;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.nd4j.linalg.lossfunctions.ILossFunction;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
|
||||
/**
|
||||
* Utility functionality for keras loss functions
|
||||
*
|
||||
|
@ -28,13 +34,33 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
|||
*/
|
||||
@Slf4j
|
||||
public class KerasLossUtils {
|
||||
static final Map<String, ILossFunction> customLoss = new HashMap<>();
|
||||
|
||||
/**
|
||||
* Register a custom loss function
|
||||
*
|
||||
* @param lossName name of the lambda layer in the serialized Keras model
|
||||
* @param lossFunction SameDiffLambdaLayer instance to map to Keras Lambda layer
|
||||
*/
|
||||
public static void registerCustomLoss(String lossName, ILossFunction lossFunction) {
|
||||
customLoss.put(lossName, lossFunction);
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear all lambda layers
|
||||
*
|
||||
*/
|
||||
public static void clearCustomLoss() {
|
||||
customLoss.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Map Keras to DL4J loss functions.
|
||||
*
|
||||
* @param kerasLoss String containing Keras loss function name
|
||||
* @return String containing DL4J loss function
|
||||
*/
|
||||
public static LossFunctions.LossFunction mapLossFunction(String kerasLoss, KerasLayerConfiguration conf)
|
||||
public static ILossFunction mapLossFunction(String kerasLoss, KerasLayerConfiguration conf)
|
||||
throws UnsupportedKerasConfigurationException {
|
||||
LossFunctions.LossFunction dl4jLoss;
|
||||
if (kerasLoss.equals(conf.getKERAS_LOSS_MEAN_SQUARED_ERROR()) ||
|
||||
|
@ -67,8 +93,13 @@ public class KerasLossUtils {
|
|||
} else if (kerasLoss.equals(conf.getKERAS_LOSS_COSINE_PROXIMITY())) {
|
||||
dl4jLoss = LossFunctions.LossFunction.COSINE_PROXIMITY;
|
||||
} else {
|
||||
throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss);
|
||||
ILossFunction lossClass = customLoss.get(kerasLoss);
|
||||
if(lossClass != null){
|
||||
return lossClass;
|
||||
}else{
|
||||
throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss);
|
||||
}
|
||||
}
|
||||
return dl4jLoss;
|
||||
return dl4jLoss.getILossFunction();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.nn.modelimport.keras.e2e;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.common.resources.Resources;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.lossfunctions.SameDiffLoss;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.InputStream;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
|
||||
|
||||
/**
|
||||
* Test importing Keras models with custom loss.
|
||||
*
|
||||
* @author Paul Dubs
|
||||
*/
|
||||
public class KerasCustomLossTest extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
public class LogCosh extends SameDiffLoss {
|
||||
@Override
|
||||
public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
|
||||
return sd.math.log(sd.math.cosh(labels.sub(layerInput)));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSequentialLambdaLayerImport() throws Exception {
|
||||
KerasLossUtils.registerCustomLoss("logcosh", new LogCosh());
|
||||
|
||||
String modelPath = "modelimport/keras/examples/custom_loss.h5";
|
||||
|
||||
try(InputStream is = Resources.asStream(modelPath)) {
|
||||
File modelFile = testDir.newFile("tempModel" + System.currentTimeMillis() + ".h5");
|
||||
Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
|
||||
MultiLayerNetwork model = new KerasSequentialModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath())
|
||||
.enforceTrainingConfig(true).buildSequential().getMultiLayerNetwork();
|
||||
|
||||
System.out.println(model.summary());
|
||||
INDArray input = Nd4j.create(new int[]{10, 3});
|
||||
|
||||
model.output(input);
|
||||
} finally {
|
||||
KerasLossUtils.clearCustomLoss();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
Loading…
Reference in New Issue