From ebe413267b3951cedd428c4f60b1e8ce8050ee15 Mon Sep 17 00:00:00 2001 From: Rhys Date: Tue, 24 Mar 2020 12:56:53 +1300 Subject: [PATCH] Added swish activation function in mapToActivation Swish function already implemented, and accounted for in getActivationFunction, just not for in the if-else chain of mapToActivation --- .../nn/modelimport/keras/config/KerasLayerConfiguration.java | 1 + .../nn/modelimport/keras/utils/KerasActivationUtils.java | 2 ++ 2 files changed, 3 insertions(+) diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java index 7841fdf27..4ffdc20c3 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java @@ -334,6 +334,7 @@ public class KerasLayerConfiguration { private final String KERAS_ACTIVATION_SIGMOID = "sigmoid"; private final String KERAS_ACTIVATION_HARD_SIGMOID = "hard_sigmoid"; private final String KERAS_ACTIVATION_LINEAR = "linear"; + private final String KERAS_ACTIVATION_SWISH = "swish"; private final String KERAS_ACTIVATION_ELU = "elu"; // keras 2 only private final String KERAS_ACTIVATION_SELU = "selu"; // keras 2 only diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasActivationUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasActivationUtils.java index f0ddfd912..cdf8bc923 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasActivationUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasActivationUtils.java @@ -63,6 +63,8 @@ public class KerasActivationUtils { dl4jActivation = Activation.HARDSIGMOID; } else if (kerasActivation.equals(conf.getKERAS_ACTIVATION_LINEAR())) { dl4jActivation = Activation.IDENTITY; + } else if (kerasActivation.equals(conf.getKERAS_ACTIVATION_SWISH())) { + dl4jActivation = Activation.SWISH; } else { throw new UnsupportedKerasConfigurationException( "Unknown Keras activation function " + kerasActivation);