Add support for registering custom loss functions for Keras import (#427)
Signed-off-by: Paul Dubs <paul.dubs@gmail.com>
parent 6c9a14d8c2
commit 4dbdaca967
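Before the diffs, a minimal usage sketch of the new registration API. It reuses the LogCosh loss and the registerCustomLoss/clearCustomLoss calls introduced by this commit; the standalone class, the main method, the placeholder path "my_model.h5", and the use of the KerasModelImport convenience entry point (instead of the builder used in the new test) are illustrative assumptions, not part of the change.

import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.lossfunctions.SameDiffLoss;

public class CustomLossImportSketch {

    // Log-cosh loss expressed through SameDiff, mirroring the test added in this commit.
    public static class LogCosh extends SameDiffLoss {
        @Override
        public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
            return sd.math.log(sd.math.cosh(labels.sub(layerInput)));
        }
    }

    public static void main(String[] args) throws Exception {
        // Register the loss under the name it carried when the Keras model was compiled and saved.
        KerasLossUtils.registerCustomLoss("logcosh", new LogCosh());
        try {
            // "my_model.h5" is a placeholder path for a Keras model saved with the custom loss.
            MultiLayerNetwork model = KerasModelImport
                    .importKerasSequentialModelAndWeights("my_model.h5", true);
            System.out.println(model.summary());
        } finally {
            // The registry is static, so clear it once the import is done.
            KerasLossUtils.clearCustomLoss();
        }
    }
}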
KerasLoss.java
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -28,6 +29,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.lossfunctions.ILossFunction;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
 
 import java.util.ArrayList;
@@ -45,7 +47,7 @@ import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils.mapLossFunction;
 public class KerasLoss extends KerasLayer {
 
     private final String KERAS_CLASS_NAME_LOSS = "Loss";
-    private LossFunctions.LossFunction loss;
+    private ILossFunction loss;
 
 
     /**
@@ -86,7 +88,7 @@
             if (enforceTrainingConfig)
                 throw e;
             log.warn("Unsupported Keras loss function. Replacing with MSE.");
-            loss = LossFunctions.LossFunction.SQUARED_LOSS;
+            loss = LossFunctions.LossFunction.SQUARED_LOSS.getILossFunction();
         }
     }
 
KerasLossUtils.java
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -19,8 +20,13 @@ package org.deeplearning4j.nn.modelimport.keras.utils;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
+import org.nd4j.linalg.lossfunctions.ILossFunction;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
 
+import java.util.HashMap;
+import java.util.Map;
+
+
 /**
  * Utility functionality for keras loss functions
  *
@@ -28,13 +34,33 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
  */
 @Slf4j
 public class KerasLossUtils {
+    static final Map<String, ILossFunction> customLoss = new HashMap<>();
+
+    /**
+     * Register a custom loss function
+     *
+     * @param lossName     name of the loss function in the serialized Keras model
+     * @param lossFunction ILossFunction instance to map that name to on import
+     */
+    public static void registerCustomLoss(String lossName, ILossFunction lossFunction) {
+        customLoss.put(lossName, lossFunction);
+    }
+
+    /**
+     * Clear all registered custom loss functions
+     *
+     */
+    public static void clearCustomLoss() {
+        customLoss.clear();
+    }
+
     /**
      * Map Keras to DL4J loss functions.
      *
      * @param kerasLoss String containing Keras loss function name
      * @return String containing DL4J loss function
      */
-    public static LossFunctions.LossFunction mapLossFunction(String kerasLoss, KerasLayerConfiguration conf)
+    public static ILossFunction mapLossFunction(String kerasLoss, KerasLayerConfiguration conf)
             throws UnsupportedKerasConfigurationException {
         LossFunctions.LossFunction dl4jLoss;
         if (kerasLoss.equals(conf.getKERAS_LOSS_MEAN_SQUARED_ERROR()) ||
@@ -67,8 +93,13 @@
         } else if (kerasLoss.equals(conf.getKERAS_LOSS_COSINE_PROXIMITY())) {
             dl4jLoss = LossFunctions.LossFunction.COSINE_PROXIMITY;
         } else {
+            ILossFunction lossClass = customLoss.get(kerasLoss);
+            if (lossClass != null) {
+                return lossClass;
+            } else {
                 throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss);
+            }
         }
-        return dl4jLoss;
+        return dl4jLoss.getILossFunction();
     }
 }
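For clarity, a hedged sketch of the resolution order that mapLossFunction follows after this change: built-in Keras loss names still resolve through the LossFunctions enum (converted with getILossFunction()), and only names it does not recognise fall back to the custom registry. Keras2LayerConfiguration, LossMSE, and the literal name "mean_squared_error" are assumptions used for illustration, not part of this commit.

import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;

public class LossResolutionSketch {
    public static void main(String[] args) throws Exception {
        // Any ILossFunction implementation can be registered; LossMSE stands in for a real custom loss here.
        KerasLossUtils.registerCustomLoss("my_custom_loss", new LossMSE());

        Keras2LayerConfiguration conf = new Keras2LayerConfiguration();

        // Built-in name: mapped via the LossFunctions enum, then converted to an ILossFunction.
        ILossFunction builtIn = KerasLossUtils.mapLossFunction("mean_squared_error", conf);

        // Unknown name: looked up in the custom registry populated above.
        ILossFunction custom = KerasLossUtils.mapLossFunction("my_custom_loss", conf);

        System.out.println(builtIn.getClass().getSimpleName() + " / " + custom.getClass().getSimpleName());
        KerasLossUtils.clearCustomLoss();
    }
}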
KerasCustomLossTest.java (new file)
@@ -0,0 +1,78 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.nn.modelimport.keras.e2e;
+
+import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
+import org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.common.resources.Resources;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.lossfunctions.SameDiffLoss;
+
+import java.io.File;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.StandardCopyOption;
+
+
+/**
+ * Test importing Keras models with custom loss.
+ *
+ * @author Paul Dubs
+ */
+public class KerasCustomLossTest extends BaseDL4JTest {
+
+    @Rule
+    public TemporaryFolder testDir = new TemporaryFolder();
+
+    public class LogCosh extends SameDiffLoss {
+        @Override
+        public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
+            return sd.math.log(sd.math.cosh(labels.sub(layerInput)));
+        }
+    }
+
+    @Test
+    public void testSequentialLambdaLayerImport() throws Exception {
+        KerasLossUtils.registerCustomLoss("logcosh", new LogCosh());
+
+        String modelPath = "modelimport/keras/examples/custom_loss.h5";
+
+        try (InputStream is = Resources.asStream(modelPath)) {
+            File modelFile = testDir.newFile("tempModel" + System.currentTimeMillis() + ".h5");
+            Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
+            MultiLayerNetwork model = new KerasSequentialModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath())
+                    .enforceTrainingConfig(true).buildSequential().getMultiLayerNetwork();
+
+            System.out.println(model.summary());
+            INDArray input = Nd4j.create(new int[]{10, 3});
+
+            model.output(input);
+        } finally {
+            KerasLossUtils.clearCustomLoss();
+        }
+    }
+
+}