2019-08-29 17:31:57 -07:00
|
|
|
/*
|
2021-02-01 14:31:20 +09:00
|
|
|
* ******************************************************************************
|
|
|
|
|
* *
|
|
|
|
|
* *
|
|
|
|
|
* * This program and the accompanying materials are made available under the
|
|
|
|
|
* * terms of the Apache License, Version 2.0 which is available at
|
|
|
|
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
|
* *
|
2021-02-01 17:47:29 +09:00
|
|
|
* * See the NOTICE file distributed with this work for additional
|
|
|
|
|
* * information regarding copyright ownership.
|
2021-02-01 14:31:20 +09:00
|
|
|
* * Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
|
* * License for the specific language governing permissions and limitations
|
|
|
|
|
* * under the License.
|
|
|
|
|
* *
|
|
|
|
|
* * SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
* *****************************************************************************
|
2019-08-29 17:31:57 -07:00
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
package org.deeplearning4j.regressiontest;
|
|
|
|
|
|
2021-03-16 11:57:24 +09:00
|
|
|
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
|
|
|
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
|
|
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
2019-08-29 17:31:57 -07:00
|
|
|
|
|
|
|
|
import java.io.DataInputStream;
|
|
|
|
|
import java.io.File;
|
|
|
|
|
import java.io.FileInputStream;
|
|
|
|
|
import org.deeplearning4j.BaseDL4JTest;
|
|
|
|
|
import org.deeplearning4j.TestUtils;
|
|
|
|
|
import org.deeplearning4j.nn.conf.BackpropType;
|
|
|
|
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
|
|
|
|
import org.deeplearning4j.nn.conf.graph.LayerVertex;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.LSTM;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.PoolingType;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
|
|
|
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
|
|
|
|
import org.deeplearning4j.nn.graph.vertex.impl.MergeVertex;
|
|
|
|
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|
|
|
|
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
|
|
|
|
import org.deeplearning4j.regressiontest.customlayer100a.CustomLayer;
|
2021-03-16 11:57:24 +09:00
|
|
|
import org.junit.jupiter.api.Test;
|
2019-08-29 17:31:57 -07:00
|
|
|
import org.nd4j.linalg.activations.impl.ActivationIdentity;
|
|
|
|
|
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
|
|
|
|
import org.nd4j.linalg.activations.impl.ActivationReLU;
|
|
|
|
|
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
|
|
|
|
|
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
|
|
|
|
|
import org.nd4j.linalg.activations.impl.ActivationTanH;
|
|
|
|
|
import org.nd4j.linalg.api.buffer.DataType;
|
|
|
|
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
|
|
|
import org.nd4j.linalg.factory.Nd4j;
|
|
|
|
|
import org.nd4j.linalg.learning.config.Adam;
|
|
|
|
|
import org.nd4j.linalg.learning.config.RmsProp;
|
|
|
|
|
import org.nd4j.linalg.learning.regularization.L2Regularization;
|
|
|
|
|
import org.nd4j.linalg.lossfunctions.impl.LossMAE;
|
|
|
|
|
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
|
2020-04-29 11:19:26 +10:00
|
|
|
import org.nd4j.common.resources.Resources;
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-08-29 17:31:57 -07:00
|
|
|
public class RegressionTest100b4 extends BaseDL4JTest {
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public DataType getDataType() {
|
|
|
|
|
return DataType.FLOAT;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Test
|
|
|
|
|
public void testCustomLayer() throws Exception {
|
|
|
|
|
|
|
|
|
|
for (DataType dtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
|
|
|
|
|
|
|
|
|
|
String dtypeName = dtype.toString().toLowerCase();
|
|
|
|
|
|
|
|
|
|
File f = Resources.asFile("regression_testing/100b4/CustomLayerExample_100b4_" + dtypeName + ".bin");
|
|
|
|
|
MultiLayerNetwork.load(f, true);
|
|
|
|
|
|
|
|
|
|
MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
|
|
|
|
|
// net = net.clone();
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
DenseLayer l0 = (DenseLayer) net.getLayer(0).getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(new ActivationTanH(), l0.getActivationFn());
|
|
|
|
|
assertEquals(new L2Regularization(0.03), TestUtils.getL2Reg(l0));
|
2023-05-08 09:22:38 +02:00
|
|
|
assertEquals(new RmsProp(0.95), l0.getUpdater());
|
2019-08-29 17:31:57 -07:00
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
CustomLayer l1 = (CustomLayer) net.getLayer(1).getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(new ActivationTanH(), l1.getActivationFn());
|
|
|
|
|
assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction());
|
2023-05-08 09:22:38 +02:00
|
|
|
assertEquals(new RmsProp(0.95), l1.getUpdater());
|
2019-08-29 17:31:57 -07:00
|
|
|
|
|
|
|
|
INDArray outExp;
|
|
|
|
|
File f2 = Resources
|
|
|
|
|
.asFile("regression_testing/100b4/CustomLayerExample_Output_100b4_" + dtypeName + ".bin");
|
|
|
|
|
try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
|
|
|
|
|
outExp = Nd4j.read(dis);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
INDArray in;
|
|
|
|
|
File f3 = Resources.asFile("regression_testing/100b4/CustomLayerExample_Input_100b4_" + dtypeName + ".bin");
|
|
|
|
|
try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
|
|
|
|
|
in = Nd4j.read(dis);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
assertEquals(dtype, in.dataType());
|
|
|
|
|
assertEquals(dtype, outExp.dataType());
|
2023-03-23 17:39:00 +01:00
|
|
|
assertEquals(dtype, net.getModelParams().dataType());
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(dtype, net.getFlattenedGradients().dataType());
|
|
|
|
|
assertEquals(dtype, net.getUpdater().getStateViewArray().dataType());
|
|
|
|
|
|
|
|
|
|
//System.out.println(Arrays.toString(net.params().data().asFloat()));
|
|
|
|
|
|
|
|
|
|
INDArray outAct = net.output(in);
|
|
|
|
|
assertEquals(dtype, outAct.dataType());
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
assertEquals(dtype, net.getNetConfiguration().getDataType());
|
|
|
|
|
assertEquals(dtype, net.getModelParams().dataType());
|
2020-04-10 17:57:02 +03:00
|
|
|
boolean eq = outExp.equalsWithEps(outAct, 0.01);
|
2022-09-20 15:40:53 +02:00
|
|
|
assertTrue(eq, "Test for dtype: " + dtypeName + "\n" + outExp + " vs " + outAct);
|
2019-08-29 17:31:57 -07:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Test
|
|
|
|
|
public void testLSTM() throws Exception {
|
|
|
|
|
|
|
|
|
|
File f = Resources.asFile("regression_testing/100b4/GravesLSTMCharModelingExample_100b4.bin");
|
|
|
|
|
MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
LSTM l0 = (LSTM) net.getLayer(0).getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(new ActivationTanH(), l0.getActivationFn());
|
|
|
|
|
assertEquals(200, l0.getNOut());
|
2023-03-23 17:39:00 +01:00
|
|
|
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
|
2023-05-08 09:22:38 +02:00
|
|
|
assertEquals(new Adam(0.005), l0.getUpdater());
|
2019-08-29 17:31:57 -07:00
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(new ActivationTanH(), l1.getActivationFn());
|
|
|
|
|
assertEquals(200, l1.getNOut());
|
2023-03-23 17:39:00 +01:00
|
|
|
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
|
2023-05-08 09:22:38 +02:00
|
|
|
assertEquals(new Adam(0.005), l1.getUpdater());
|
2019-08-29 17:31:57 -07:00
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(new ActivationSoftmax(), l2.getActivationFn());
|
|
|
|
|
assertEquals(77, l2.getNOut());
|
2023-03-23 17:39:00 +01:00
|
|
|
assertEquals(new WeightInitXavier(), l2.getWeightInit());
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
|
2023-05-08 09:22:38 +02:00
|
|
|
assertEquals(new Adam(0.005), l2.getUpdater());
|
2019-08-29 17:31:57 -07:00
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType());
|
|
|
|
|
assertEquals(50, net.getNetConfiguration().getTbpttBackLength());
|
|
|
|
|
assertEquals(50, net.getNetConfiguration().getTbpttFwdLength());
|
2019-08-29 17:31:57 -07:00
|
|
|
|
|
|
|
|
INDArray outExp;
|
|
|
|
|
File f2 = Resources.asFile("regression_testing/100b4/GravesLSTMCharModelingExample_Output_100b4.bin");
|
|
|
|
|
try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
|
|
|
|
|
outExp = Nd4j.read(dis);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
INDArray in;
|
|
|
|
|
File f3 = Resources.asFile("regression_testing/100b4/GravesLSTMCharModelingExample_Input_100b4.bin");
|
|
|
|
|
try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
|
|
|
|
|
in = Nd4j.read(dis);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
INDArray outAct = net.output(in);
|
|
|
|
|
|
|
|
|
|
assertEquals(outExp, outAct);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Test
|
|
|
|
|
public void testVae() throws Exception {
|
|
|
|
|
|
|
|
|
|
File f = Resources.asFile("regression_testing/100b4/VaeMNISTAnomaly_100b4.bin");
|
|
|
|
|
MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
VariationalAutoencoder l0 = (VariationalAutoencoder) net.getLayer(0).getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(new ActivationLReLU(), l0.getActivationFn());
|
|
|
|
|
assertEquals(32, l0.getNOut());
|
|
|
|
|
assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes());
|
|
|
|
|
assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes());
|
2023-03-23 17:39:00 +01:00
|
|
|
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
|
2023-05-08 09:22:38 +02:00
|
|
|
assertEquals(new Adam(1e-3), l0.getUpdater());
|
2019-08-29 17:31:57 -07:00
|
|
|
|
|
|
|
|
INDArray outExp;
|
|
|
|
|
File f2 = Resources.asFile("regression_testing/100b4/VaeMNISTAnomaly_Output_100b4.bin");
|
|
|
|
|
try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
|
|
|
|
|
outExp = Nd4j.read(dis);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
INDArray in;
|
|
|
|
|
File f3 = Resources.asFile("regression_testing/100b4/VaeMNISTAnomaly_Input_100b4.bin");
|
|
|
|
|
try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
|
|
|
|
|
in = Nd4j.read(dis);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
INDArray outAct = net.output(in);
|
|
|
|
|
|
|
|
|
|
assertEquals(outExp, outAct);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Test
|
2022-09-20 15:40:53 +02:00
|
|
|
////@Ignore("Failing due to new data format changes. Sept 10,2020")
|
2019-08-29 17:31:57 -07:00
|
|
|
public void testYoloHouseNumber() throws Exception {
|
|
|
|
|
|
|
|
|
|
File f = Resources.asFile("regression_testing/100b4/HouseNumberDetection_100b4.bin");
|
|
|
|
|
ComputationGraph net = ComputationGraph.load(f, true);
|
|
|
|
|
|
|
|
|
|
int nBoxes = 5;
|
|
|
|
|
int nClasses = 10;
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
ConvolutionLayer cl = (ConvolutionLayer) ((LayerVertex) net.getComputationGraphConfiguration().getVertices()
|
|
|
|
|
.get("convolution2d_9")).getNetConfiguration().getFirstLayer();
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(nBoxes * (5 + nClasses), cl.getNOut());
|
|
|
|
|
assertEquals(new ActivationIdentity(), cl.getActivationFn());
|
|
|
|
|
assertEquals(ConvolutionMode.Same, cl.getConvolutionMode());
|
2023-03-23 17:39:00 +01:00
|
|
|
assertEquals(new WeightInitXavier(), cl.getWeightInit());
|
2019-08-29 17:31:57 -07:00
|
|
|
assertArrayEquals(new int[]{1, 1}, cl.getKernelSize());
|
|
|
|
|
|
|
|
|
|
INDArray outExp;
|
|
|
|
|
File f2 = Resources.asFile("regression_testing/100b4/HouseNumberDetection_Output_100b4.bin");
|
|
|
|
|
try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
|
|
|
|
|
outExp = Nd4j.read(dis);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
INDArray in;
|
|
|
|
|
File f3 = Resources.asFile("regression_testing/100b4/HouseNumberDetection_Input_100b4.bin");
|
|
|
|
|
try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
|
|
|
|
|
in = Nd4j.read(dis);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
INDArray outAct = net.outputSingle(in);
|
|
|
|
|
|
|
|
|
|
boolean eq = outExp.equalsWithEps(outAct.castTo(outExp.dataType()), 1e-3);
|
|
|
|
|
assertTrue(eq);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Test
|
2022-09-20 15:40:53 +02:00
|
|
|
////@Ignore("failing due to new input data format changes.")
|
2019-08-29 17:31:57 -07:00
|
|
|
public void testSyntheticCNN() throws Exception {
|
|
|
|
|
|
|
|
|
|
File f = Resources.asFile("regression_testing/100b4/SyntheticCNN_100b4.bin");
|
|
|
|
|
MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(new ActivationReLU(), l0.getActivationFn());
|
|
|
|
|
assertEquals(4, l0.getNOut());
|
2023-03-23 17:39:00 +01:00
|
|
|
assertEquals(new WeightInitXavier(), l0.getWeightInit());
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0));
|
2023-05-08 09:22:38 +02:00
|
|
|
assertEquals(new Adam(0.005), l0.getUpdater());
|
2019-08-29 17:31:57 -07:00
|
|
|
assertArrayEquals(new int[]{3, 3}, l0.getKernelSize());
|
|
|
|
|
assertArrayEquals(new int[]{2, 1}, l0.getStride());
|
|
|
|
|
assertArrayEquals(new int[]{1, 1}, l0.getDilation());
|
|
|
|
|
assertArrayEquals(new int[]{0, 0}, l0.getPadding());
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(new ActivationReLU(), l1.getActivationFn());
|
|
|
|
|
assertEquals(8, l1.getNOut());
|
2023-03-23 17:39:00 +01:00
|
|
|
assertEquals(new WeightInitXavier(), l1.getWeightInit());
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
|
2023-05-08 09:22:38 +02:00
|
|
|
assertEquals(new Adam(0.005), l1.getUpdater());
|
2019-08-29 17:31:57 -07:00
|
|
|
assertArrayEquals(new int[]{3, 3}, l1.getKernelSize());
|
|
|
|
|
assertArrayEquals(new int[]{1, 1}, l1.getStride());
|
|
|
|
|
assertArrayEquals(new int[]{1, 1}, l1.getDilation());
|
|
|
|
|
assertArrayEquals(new int[]{0, 0}, l1.getPadding());
|
|
|
|
|
assertEquals(ConvolutionMode.Same, l1.getConvolutionMode());
|
|
|
|
|
assertEquals(1, l1.getDepthMultiplier());
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
SubsamplingLayer l2 = (SubsamplingLayer) net.getLayer(2).getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
assertArrayEquals(new int[]{3, 3}, l2.getKernelSize());
|
|
|
|
|
assertArrayEquals(new int[]{2, 2}, l2.getStride());
|
|
|
|
|
assertArrayEquals(new int[]{1, 1}, l2.getDilation());
|
|
|
|
|
assertArrayEquals(new int[]{0, 0}, l2.getPadding());
|
|
|
|
|
assertEquals(PoolingType.MAX, l2.getPoolingType());
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
ZeroPaddingLayer l3 = (ZeroPaddingLayer) net.getLayer(3).getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
assertArrayEquals(new int[]{4, 4, 4, 4}, l3.getPadding());
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
Upsampling2D l4 = (Upsampling2D) net.getLayer(4).getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
assertArrayEquals(new int[]{3, 3}, l4.getSize());
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(new ActivationReLU(), l5.getActivationFn());
|
|
|
|
|
assertEquals(16, l5.getNOut());
|
2023-03-23 17:39:00 +01:00
|
|
|
assertEquals(new WeightInitXavier(), l5.getWeightInit());
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
|
2023-05-08 09:22:38 +02:00
|
|
|
assertEquals(new Adam(0.005), l5.getUpdater());
|
2019-08-29 17:31:57 -07:00
|
|
|
assertArrayEquals(new int[]{3, 3}, l5.getKernelSize());
|
|
|
|
|
assertArrayEquals(new int[]{1, 1}, l5.getStride());
|
|
|
|
|
assertArrayEquals(new int[]{1, 1}, l5.getDilation());
|
|
|
|
|
assertArrayEquals(new int[]{0, 0}, l5.getPadding());
|
|
|
|
|
assertEquals(2, l5.getDepthMultiplier());
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
SubsamplingLayer l6 = (SubsamplingLayer) net.getLayer(6).getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
assertArrayEquals(new int[]{2, 2}, l6.getKernelSize());
|
|
|
|
|
assertArrayEquals(new int[]{2, 2}, l6.getStride());
|
|
|
|
|
assertArrayEquals(new int[]{1, 1}, l6.getDilation());
|
|
|
|
|
assertArrayEquals(new int[]{0, 0}, l6.getPadding());
|
|
|
|
|
assertEquals(PoolingType.MAX, l6.getPoolingType());
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
Cropping2D l7 = (Cropping2D) net.getLayer(7).getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
assertArrayEquals(new int[]{3, 3, 2, 2}, l7.getCropping());
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(4, l8.getNOut());
|
2023-03-23 17:39:00 +01:00
|
|
|
assertEquals(new WeightInitXavier(), l8.getWeightInit());
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8));
|
2023-05-08 09:22:38 +02:00
|
|
|
assertEquals(new Adam(0.005), l8.getUpdater());
|
2019-08-29 17:31:57 -07:00
|
|
|
assertArrayEquals(new int[]{4, 4}, l8.getKernelSize());
|
|
|
|
|
assertArrayEquals(new int[]{1, 1}, l8.getStride());
|
|
|
|
|
assertArrayEquals(new int[]{1, 1}, l8.getDilation());
|
|
|
|
|
assertArrayEquals(new int[]{0, 0}, l8.getPadding());
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration();
|
|
|
|
|
assertEquals(new WeightInitXavier(), l9.getWeightInit());
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9));
|
2023-05-08 09:22:38 +02:00
|
|
|
assertEquals(new Adam(0.005), l9.getUpdater());
|
2023-04-24 18:09:11 +02:00
|
|
|
assertEquals(new LossMAE(), l9.getLossFunction());
|
2019-08-29 17:31:57 -07:00
|
|
|
|
|
|
|
|
INDArray outExp;
|
|
|
|
|
File f2 = Resources.asFile("regression_testing/100b4/SyntheticCNN_Output_100b4.bin");
|
|
|
|
|
try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
|
|
|
|
|
outExp = Nd4j.read(dis);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
INDArray in;
|
|
|
|
|
File f3 = Resources.asFile("regression_testing/100b4/SyntheticCNN_Input_100b4.bin");
|
|
|
|
|
try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
|
|
|
|
|
in = Nd4j.read(dis);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
INDArray outAct = net.output(in);
|
|
|
|
|
|
2019-11-16 17:04:29 +11:00
|
|
|
//19 layers - CPU vs. GPU difference accumulates notably, but appears to be correct
|
|
|
|
|
if(Nd4j.getBackend().getClass().getName().toLowerCase().contains("native")){
|
|
|
|
|
assertEquals(outExp, outAct);
|
|
|
|
|
} else {
|
|
|
|
|
boolean eq = outExp.equalsWithEps(outAct, 0.1);
|
|
|
|
|
assertTrue(eq);
|
|
|
|
|
}
|
2019-08-29 17:31:57 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Test
|
|
|
|
|
public void testSyntheticBidirectionalRNNGraph() throws Exception {
|
|
|
|
|
|
|
|
|
|
File f = Resources.asFile("regression_testing/100b4/SyntheticBidirectionalRNNGraph_100b4.bin");
|
|
|
|
|
ComputationGraph net = ComputationGraph.load(f, true);
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
Bidirectional l0 = (Bidirectional) net.getLayer("rnn1").getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
|
|
|
|
|
LSTM l1 = (LSTM) l0.getFwd();
|
|
|
|
|
assertEquals(16, l1.getNOut());
|
|
|
|
|
assertEquals(new ActivationReLU(), l1.getActivationFn());
|
|
|
|
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1));
|
|
|
|
|
|
|
|
|
|
LSTM l2 = (LSTM) l0.getBwd();
|
|
|
|
|
assertEquals(16, l2.getNOut());
|
|
|
|
|
assertEquals(new ActivationReLU(), l2.getActivationFn());
|
|
|
|
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2));
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
Bidirectional l3 = (Bidirectional) net.getLayer("rnn2").getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
|
|
|
|
|
SimpleRnn l4 = (SimpleRnn) l3.getFwd();
|
|
|
|
|
assertEquals(16, l4.getNOut());
|
|
|
|
|
assertEquals(new ActivationReLU(), l4.getActivationFn());
|
|
|
|
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l4));
|
|
|
|
|
|
|
|
|
|
SimpleRnn l5 = (SimpleRnn) l3.getBwd();
|
|
|
|
|
assertEquals(16, l5.getNOut());
|
|
|
|
|
assertEquals(new ActivationReLU(), l5.getActivationFn());
|
|
|
|
|
assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5));
|
|
|
|
|
|
|
|
|
|
MergeVertex mv = (MergeVertex) net.getVertex("concat");
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
GlobalPoolingLayer gpl = (GlobalPoolingLayer) net.getLayer("pooling").getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(PoolingType.MAX, gpl.getPoolingType());
|
|
|
|
|
assertArrayEquals(new int[]{2}, gpl.getPoolingDimensions());
|
|
|
|
|
assertTrue(gpl.isCollapseDimensions());
|
|
|
|
|
|
2023-03-23 17:39:00 +01:00
|
|
|
OutputLayer outl = (OutputLayer) net.getLayer("out").getLayerConfiguration();
|
2019-08-29 17:31:57 -07:00
|
|
|
assertEquals(3, outl.getNOut());
|
2023-04-24 18:09:11 +02:00
|
|
|
assertEquals(new LossMCXENT(), outl.getLossFunction());
|
2019-08-29 17:31:57 -07:00
|
|
|
|
|
|
|
|
INDArray outExp;
|
|
|
|
|
File f2 = Resources.asFile("regression_testing/100b4/SyntheticBidirectionalRNNGraph_Output_100b4.bin");
|
|
|
|
|
try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) {
|
|
|
|
|
outExp = Nd4j.read(dis);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
INDArray in;
|
|
|
|
|
File f3 = Resources.asFile("regression_testing/100b4/SyntheticBidirectionalRNNGraph_Input_100b4.bin");
|
|
|
|
|
try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) {
|
|
|
|
|
in = Nd4j.read(dis);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
INDArray outAct = net.output(in)[0];
|
|
|
|
|
|
|
|
|
|
assertEquals(outExp, outAct);
|
|
|
|
|
}
|
|
|
|
|
}
|