Fix lenet input shape (#9130)

* Add options for changing the default CNN2D format

This means that when a model is being loaded (e.g., from a Keras H5 file), the default CNN2DFormat can be set. Currently it always defaults to 'Channels First' which can cause problems for some models.

* 🐛 Fix default Input shape for LeNet

Input shape should be [1, 28, 28], as per manual testing and https://github.com/BVLC/caffe/blob/master/examples/mnist/lenet.prototxt
master
Rhys Compton 2020-11-23 21:07:40 +13:00 committed by GitHub
parent a768f4c904
commit 8e591bbf39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -55,7 +55,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
public class LeNet extends ZooModel { public class LeNet extends ZooModel {
@Builder.Default private long seed = 1234; @Builder.Default private long seed = 1234;
@Builder.Default private int[] inputShape = new int[] {3, 224, 224}; @Builder.Default private int[] inputShape = new int[] {1, 28, 28};
@Builder.Default private int numClasses = 0; @Builder.Default private int numClasses = 0;
@Builder.Default private IUpdater updater = new AdaDelta(); @Builder.Default private IUpdater updater = new AdaDelta();
@Builder.Default private CacheMode cacheMode = CacheMode.NONE; @Builder.Default private CacheMode cacheMode = CacheMode.NONE;