diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/DL4JModelValidator.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/DL4JModelValidator.java index 712b9c12b..8567dc379 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/DL4JModelValidator.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/DL4JModelValidator.java @@ -3,6 +3,7 @@ package org.deeplearning4j.util; import lombok.NonNull; import org.apache.commons.io.IOUtils; import org.deeplearning4j.nn.api.Model; +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -121,7 +122,7 @@ public class DL4JModelValidator { } try{ - MultiLayerConfiguration.fromJson(config); + ComputationGraphConfiguration.fromJson(config); } catch (Throwable t){ return ValidationResult.builder() .formatType("ComputationGraph")