From b57f1d52cc153edfc550c93623e10d163598e26a Mon Sep 17 00:00:00 2001 From: Susan Eraly Date: Thu, 25 Jul 2019 20:37:52 -0700 Subject: [PATCH] Keras model import - updater lr fix (#84) * Keras model import - updater lr fix Signed-off-by: eraly * Keras model import - updater lr fix, cleanup Signed-off-by: eraly --- .../modelimport/keras/utils/KerasOptimizerUtils.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasOptimizerUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasOptimizerUtils.java index 6d230d1fe..c489f6b0e 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasOptimizerUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasOptimizerUtils.java @@ -75,7 +75,7 @@ public class KerasOptimizerUtils { dl4jOptimizer = new Adam.Builder() .beta1(beta1).beta2(beta2) .epsilon(epsilon).learningRate(lr) - .learningRateSchedule(new InverseSchedule(ScheduleType.ITERATION, 1, decay, 1)) + .learningRateSchedule(decay == 0 ? null : new InverseSchedule(ScheduleType.ITERATION, lr, decay, 1)) .build(); break; } @@ -96,7 +96,7 @@ public class KerasOptimizerUtils { dl4jOptimizer = new AdaGrad.Builder() .epsilon(epsilon).learningRate(lr) - .learningRateSchedule(new InverseSchedule(ScheduleType.ITERATION, 1, decay, 1)) + .learningRateSchedule(decay == 0 ? null : new InverseSchedule(ScheduleType.ITERATION, lr, decay, 1)) .build(); break; } @@ -119,8 +119,8 @@ public class KerasOptimizerUtils { dl4jOptimizer = new Nadam.Builder() .beta1(beta1).beta2(beta2) .epsilon(epsilon).learningRate(lr) - .learningRateSchedule(new InverseSchedule(ScheduleType.ITERATION, 1, - scheduleDecay, 1)) + .learningRateSchedule(scheduleDecay == 0 ? null : new InverseSchedule(ScheduleType.ITERATION, lr, + scheduleDecay, 1)) .build(); break; } @@ -132,7 +132,7 @@ public class KerasOptimizerUtils { dl4jOptimizer = new Nesterovs.Builder() .momentum(momentum).learningRate(lr) - .learningRateSchedule(new InverseSchedule(ScheduleType.ITERATION, 1, decay, 1)) + .learningRateSchedule(decay == 0 ? null : new InverseSchedule(ScheduleType.ITERATION, lr, decay, 1)) .build(); break; } @@ -144,7 +144,7 @@ public class KerasOptimizerUtils { dl4jOptimizer = new RmsProp.Builder() .epsilon(epsilon).rmsDecay(rho).learningRate(lr) - .learningRateSchedule(new InverseSchedule(ScheduleType.ITERATION, 1, decay, 1)) + .learningRateSchedule(decay == 0 ? null : new InverseSchedule(ScheduleType.ITERATION, lr, decay, 1)) .build(); break; }