diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java index 7841fdf27..4ffdc20c3 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java @@ -334,6 +334,7 @@ public class KerasLayerConfiguration { private final String KERAS_ACTIVATION_SIGMOID = "sigmoid"; private final String KERAS_ACTIVATION_HARD_SIGMOID = "hard_sigmoid"; private final String KERAS_ACTIVATION_LINEAR = "linear"; + private final String KERAS_ACTIVATION_SWISH = "swish"; private final String KERAS_ACTIVATION_ELU = "elu"; // keras 2 only private final String KERAS_ACTIVATION_SELU = "selu"; // keras 2 only diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasActivationUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasActivationUtils.java index f0ddfd912..cdf8bc923 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasActivationUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasActivationUtils.java @@ -63,6 +63,8 @@ public class KerasActivationUtils { dl4jActivation = Activation.HARDSIGMOID; } else if (kerasActivation.equals(conf.getKERAS_ACTIVATION_LINEAR())) { dl4jActivation = Activation.IDENTITY; + } else if (kerasActivation.equals(conf.getKERAS_ACTIVATION_SWISH())) { + dl4jActivation = Activation.SWISH; } else { throw new UnsupportedKerasConfigurationException( "Unknown Keras activation function " + kerasActivation);