306 lines
11 KiB
Java
306 lines
11 KiB
Java
/*
|
|
* ******************************************************************************
|
|
* *
|
|
* *
|
|
* * This program and the accompanying materials are made available under the
|
|
* * terms of the Apache License, Version 2.0 which is available at
|
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
* *
|
|
* * See the NOTICE file distributed with this work for additional
|
|
* * information regarding copyright ownership.
|
|
* * Unless required by applicable law or agreed to in writing, software
|
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
* * License for the specific language governing permissions and limitations
|
|
* * under the License.
|
|
* *
|
|
* * SPDX-License-Identifier: Apache-2.0
|
|
* *****************************************************************************
|
|
*/
|
|
|
|
package org.deeplearning4j.exceptions;
|
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
import org.deeplearning4j.BaseDL4JTest;
|
|
import org.deeplearning4j.exception.DL4JException;
|
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
|
import org.deeplearning4j.nn.conf.layers.*;
|
|
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|
import org.junit.jupiter.api.Tag;
|
|
import org.junit.jupiter.api.Test;
|
|
import org.nd4j.common.tests.tags.NativeTag;
|
|
import org.nd4j.common.tests.tags.TagNames;
|
|
import org.nd4j.linalg.activations.Activation;
|
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
import org.nd4j.linalg.factory.Nd4j;
|
|
|
|
import java.util.Map;
|
|
|
|
import static org.junit.jupiter.api.Assertions.*;
|
|
|
|
@Slf4j
|
|
@NativeTag
|
|
@Tag(TagNames.EVAL_METRICS)
|
|
@Tag(TagNames.TRAINING)
|
|
@Tag(TagNames.DL4J_OLD_API)
|
|
public class TestInvalidInput extends BaseDL4JTest {
|
|
|
|
@Test
|
|
public void testInputNinMismatchDense() {
|
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
|
|
.layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build())
|
|
.layer(1, new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build()).build();
|
|
|
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
|
net.init();
|
|
|
|
try {
|
|
net.feedForward(Nd4j.create(1, 20));
|
|
fail("Expected DL4JException");
|
|
} catch (DL4JException e) {
|
|
System.out.println("testInputNinMismatchDense(): " + e.getMessage());
|
|
} catch (Exception e) {
|
|
log.error("",e);
|
|
fail("Expected DL4JException");
|
|
}
|
|
}
|
|
|
|
|
|
@Test
|
|
public void testLabelsNOutMismatchOutputLayer() {
|
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
|
|
.layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build())
|
|
.layer(1, new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build()).build();
|
|
|
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
|
net.init();
|
|
|
|
try {
|
|
net.fit(Nd4j.create(1, 10), Nd4j.create(1, 20));
|
|
fail("Expected IllegalArgumentException");
|
|
} catch (IllegalArgumentException e) {
|
|
//From loss function
|
|
System.out.println("testLabelsNOutMismatchOutputLayer(): " + e.getMessage());
|
|
} catch (Exception e) {
|
|
log.error("",e);
|
|
fail("Expected DL4JException");
|
|
}
|
|
}
|
|
|
|
@Test
|
|
public void testLabelsNOutMismatchRnnOutputLayer() {
|
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
|
|
.layer(0, new LSTM.Builder().nIn(5).nOut(5).build())
|
|
.layer(1, new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build();
|
|
|
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
|
net.init();
|
|
|
|
try {
|
|
net.fit(Nd4j.create(1, 5, 8), Nd4j.create(1, 10, 8));
|
|
fail("Expected IllegalArgumentException");
|
|
} catch (IllegalArgumentException e) {
|
|
//From loss function
|
|
System.out.println("testLabelsNOutMismatchRnnOutputLayer(): " + e.getMessage());
|
|
} catch (Exception e) {
|
|
log.error("",e);
|
|
fail("Expected DL4JException");
|
|
}
|
|
}
|
|
|
|
@Test
|
|
public void testInputNinMismatchConvolutional() {
|
|
//Rank 4 input, but input channels does not match nIn channels
|
|
|
|
int h = 16;
|
|
int w = 16;
|
|
int d = 3;
|
|
|
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
|
|
.layer(0, new ConvolutionLayer.Builder().nIn(d).nOut(5).build())
|
|
.layer(1, new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build())
|
|
.setInputType(InputType.convolutional(h, w, d)).build();
|
|
|
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
|
net.init();
|
|
|
|
try {
|
|
net.feedForward(Nd4j.create(1, 5, h, w));
|
|
fail("Expected DL4JException");
|
|
} catch (DL4JException e) {
|
|
System.out.println("testInputNinMismatchConvolutional(): " + e.getMessage());
|
|
} catch (Exception e) {
|
|
log.error("",e);
|
|
fail("Expected DL4JException");
|
|
}
|
|
}
|
|
|
|
@Test
|
|
public void testInputNinRank2Convolutional() {
|
|
//Rank 2 input, instead of rank 4 input. For example, forgetting the
|
|
|
|
int h = 16;
|
|
int w = 16;
|
|
int d = 3;
|
|
|
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
|
|
.layer(0, new ConvolutionLayer.Builder().nIn(d).nOut(5).build())
|
|
.layer(1, new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build())
|
|
.setInputType(InputType.convolutional(h, w, d)).build();
|
|
|
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
|
net.init();
|
|
|
|
try {
|
|
net.feedForward(Nd4j.create(1, 5 * h * w));
|
|
fail("Expected DL4JException");
|
|
} catch (DL4JException e) {
|
|
System.out.println("testInputNinRank2Convolutional(): " + e.getMessage());
|
|
} catch (Exception e) {
|
|
log.error("",e);
|
|
fail("Expected DL4JException");
|
|
}
|
|
}
|
|
|
|
@Test
|
|
public void testInputNinRank2Subsampling() {
|
|
//Rank 2 input, instead of rank 4 input. For example, using the wrong input type
|
|
int h = 16;
|
|
int w = 16;
|
|
int d = 3;
|
|
|
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
|
|
.layer(0, new SubsamplingLayer.Builder().kernelSize(2, 2).build())
|
|
.layer(1, new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build())
|
|
.setInputType(InputType.convolutional(h, w, d)).build();
|
|
|
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
|
net.init();
|
|
|
|
try {
|
|
net.feedForward(Nd4j.create(1, 5 * h * w));
|
|
fail("Expected DL4JException");
|
|
} catch (DL4JException e) {
|
|
System.out.println("testInputNinRank2Subsampling(): " + e.getMessage());
|
|
} catch (Exception e) {
|
|
log.error("",e);
|
|
fail("Expected DL4JException");
|
|
}
|
|
}
|
|
|
|
|
|
@Test
|
|
public void testInputNinMismatchLSTM() {
|
|
|
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
|
|
.layer(0, new GravesLSTM.Builder().nIn(5).nOut(5).build())
|
|
.layer(1, new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build();
|
|
|
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
|
net.init();
|
|
|
|
try {
|
|
net.fit(Nd4j.create(1, 10, 5), Nd4j.create(1, 5, 5));
|
|
fail("Expected DL4JException");
|
|
} catch (DL4JException e) {
|
|
System.out.println("testInputNinMismatchLSTM(): " + e.getMessage());
|
|
} catch (Exception e) {
|
|
log.error("",e);
|
|
fail("Expected DL4JException");
|
|
}
|
|
}
|
|
|
|
@Test
|
|
public void testInputNinMismatchBidirectionalLSTM() {
|
|
|
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
|
|
.layer(0, new GravesBidirectionalLSTM.Builder().nIn(5).nOut(5).build())
|
|
.layer(1, new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build();
|
|
|
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
|
net.init();
|
|
|
|
try {
|
|
net.fit(Nd4j.create(1, 10, 5), Nd4j.create(1, 5, 5));
|
|
fail("Expected DL4JException");
|
|
} catch (DL4JException e) {
|
|
System.out.println("testInputNinMismatchBidirectionalLSTM(): " + e.getMessage());
|
|
} catch (Exception e) {
|
|
log.error("",e);
|
|
fail("Expected DL4JException");
|
|
}
|
|
|
|
}
|
|
|
|
@Test
|
|
public void testInputNinMismatchEmbeddingLayer() {
|
|
|
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
|
|
.layer(0, new EmbeddingLayer.Builder().nIn(10).nOut(10).build())
|
|
.layer(1, new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build()).build();
|
|
|
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
|
net.init();
|
|
|
|
try {
|
|
net.feedForward(Nd4j.create(10, 5));
|
|
fail("Expected DL4JException");
|
|
} catch (DL4JException e) {
|
|
System.out.println("testInputNinMismatchEmbeddingLayer(): " + e.getMessage());
|
|
} catch (Exception e) {
|
|
log.error("",e);
|
|
fail("Expected DL4JException");
|
|
}
|
|
}
|
|
|
|
|
|
@Test
|
|
public void testInvalidRnnTimeStep() {
|
|
//Idea: Using rnnTimeStep with a different number of examples between calls
|
|
//(i.e., not calling reset between time steps)
|
|
|
|
for(String layerType : new String[]{"simple", "lstm", "graves"}) {
|
|
|
|
Layer l;
|
|
switch (layerType){
|
|
case "simple":
|
|
l = new SimpleRnn.Builder().nIn(5).nOut(5).build();
|
|
break;
|
|
case "lstm":
|
|
l = new LSTM.Builder().nIn(5).nOut(5).build();
|
|
break;
|
|
case "graves":
|
|
l = new GravesLSTM.Builder().nIn(5).nOut(5).build();
|
|
break;
|
|
default:
|
|
throw new RuntimeException();
|
|
}
|
|
|
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
|
|
.layer(l)
|
|
.layer(new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build();
|
|
|
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
|
net.init();
|
|
|
|
net.rnnTimeStep(Nd4j.create(3, 5, 10));
|
|
|
|
Map<String, INDArray> m = net.rnnGetPreviousState(0);
|
|
assertNotNull(m);
|
|
assertFalse(m.isEmpty());
|
|
|
|
try {
|
|
net.rnnTimeStep(Nd4j.create(5, 5, 10));
|
|
fail("Expected Exception - " + layerType);
|
|
} catch (Exception e) {
|
|
log.error("",e);
|
|
String msg = e.getMessage();
|
|
assertTrue(msg != null && msg.contains("rnn") && msg.contains("batch"), msg);
|
|
}
|
|
}
|
|
}
|
|
}
|