/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j;

import org.apache.commons.compress.utils.IOUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer;
import org.deeplearning4j.nn.layers.normalization.BatchNormalization;
import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization;
import org.deeplearning4j.nn.layers.recurrent.LSTM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;

import java.io.*;
import java.lang.reflect.Field;
import java.util.List;
import java.util.Random;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

public class TestUtils {
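
    /**
     * Round-trips a MultiLayerNetwork through ModelSerializer (write to a byte array, then restore)
     * and asserts that both the configuration and the parameters survive unchanged.
     *
     * @return the restored network
     */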
    public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){
        MultiLayerNetwork restored;
        try {
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            ModelSerializer.writeModel(net, baos, true);
            byte[] bytes = baos.toByteArray();

            ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
            restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);

            assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
            assertEquals(net.params(), restored.params());
        } catch (IOException e){
            //Should never happen
            throw new RuntimeException(e);
        }

        //Also check the MultiLayerConfiguration is serializable (required by Spark etc)
        MultiLayerConfiguration conf = net.getLayerWiseConfigurations();
        serializeDeserializeJava(conf);

        return restored;
    }
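
    /**
     * As above, but for a ComputationGraph: round-trips the model through ModelSerializer and
     * asserts that the configuration and parameters survive unchanged.
     *
     * @return the restored network
     */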
    public static ComputationGraph testModelSerialization(ComputationGraph net){
        ComputationGraph restored;
        try {
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            ModelSerializer.writeModel(net, baos, true);
            byte[] bytes = baos.toByteArray();

            ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
            restored = ModelSerializer.restoreComputationGraph(bais, true);

            assertEquals(net.getConfiguration(), restored.getConfiguration());
            assertEquals(net.params(), restored.params());
        } catch (IOException e){
            //Should never happen
            throw new RuntimeException(e);
        }

        //Also check the ComputationGraphConfiguration is serializable (required by Spark etc)
        ComputationGraphConfiguration conf = net.getConfiguration();
        serializeDeserializeJava(conf);

        return restored;
    }
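
    /**
     * Round-trips an object through plain Java serialization and asserts equality with the
     * original. Used to verify that configurations are Serializable, as required by Spark etc.
     */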
    @SuppressWarnings("unchecked")
    private static <T> T serializeDeserializeJava(T object){
        byte[] bytes;
        try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){
            oos.writeObject(object);
            oos.close();
            bytes = baos.toByteArray();
        } catch (IOException e){
            //Should never happen
            throw new RuntimeException(e);
        }

        T out;
        try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))){
            out = (T) ois.readObject();
        } catch (IOException | ClassNotFoundException e){
            throw new RuntimeException(e);
        }

        assertEquals(object, out);
        return out;
    }
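
    /**
     * Creates an [examples, nOut] one-hot matrix: each row has a single 1.0 at a random column.
     * Overloads without an explicit seed or Random use a fixed seed (12345) so tests are repeatable.
     */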
    public static INDArray randomOneHot(long examples, long nOut){
        return randomOneHot(examples, nOut, new Random(12345));
    }

    public static INDArray randomOneHot(DataType dataType, long examples, long nOut){
        return randomOneHot(dataType, examples, nOut, new Random(12345));
    }

    public static INDArray randomOneHot(long examples, long nOut, long rngSeed){
        return randomOneHot(examples, nOut, new Random(rngSeed));
    }

    public static INDArray randomOneHot(long examples, long nOut, Random rng) {
        return randomOneHot(Nd4j.defaultFloatingPointType(), examples, nOut, rng);
    }

    public static INDArray randomOneHot(DataType dataType, long examples, long nOut, Random rng){
        INDArray arr = Nd4j.create(dataType, examples, nOut);
        for( int i=0; i<examples; i++ ){
            arr.putScalar(i, rng.nextInt((int) nOut), 1.0);
        }
        return arr;
    }
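
    /**
     * Creates a [minibatch, outSize, tsLength] one-hot time series: a single 1.0 at a random
     * position for each (example, time step) pair. Note that the overload without a seed or
     * Random uses an unseeded Random and is therefore not repeatable.
     */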
    public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength){
        return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random());
    }

    public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, long rngSeed){
        return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random(rngSeed));
    }

    public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng){
        INDArray out = Nd4j.create(new int[]{minibatch, outSize, tsLength}, 'f');
        for( int i=0; i<minibatch; i++ ){
            for( int j=0; j<tsLength; j++ ){
                out.putScalar(i, rng.nextInt(outSize), j, 1.0);
            }
        }
        return out;
    }
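
    /**
     * Creates an array of the given shape filled with Bernoulli samples: each element is 1.0 with
     * probability p (default 0.5), otherwise 0.0.
     */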
    public static INDArray randomBernoulli(int... shape) {
        return randomBernoulli(0.5, shape);
    }

    public static INDArray randomBernoulli(double p, int... shape){
        INDArray ret = Nd4j.createUninitialized(shape);
        Nd4j.getExecutioner().exec(new BernoulliDistribution(ret, p));
        return ret;
    }
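
    /**
     * Writes the full contents of the input stream to the given file. The input stream itself is
     * not closed by this method.
     */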
    public static void writeStreamToFile(File out, InputStream is) throws IOException {
        byte[] b = IOUtils.toByteArray(is);
        try (OutputStream os = new BufferedOutputStream(new FileOutputStream(out))) {
            os.write(b);
        }
    }
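
    //Regularization lookup helpers: return the first instance of the requested type from a layer's
    //regularization list, or null if no such regularization is configured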
    public static L1Regularization getL1Reg(List<Regularization> l){
        for(Regularization r : l){
            if(r instanceof L1Regularization){
                return (L1Regularization) r;
            }
        }
        return null;
    }

    public static L2Regularization getL2Reg(BaseLayer baseLayer){
        return getL2Reg(baseLayer.getRegularization());
    }

    public static L2Regularization getL2Reg(List<Regularization> l){
        for(Regularization r : l){
            if(r instanceof L2Regularization){
                return (L2Regularization) r;
            }
        }
        return null;
    }

    public static WeightDecay getWeightDecayReg(BaseLayer bl){
        return getWeightDecayReg(bl.getRegularization());
    }

    public static WeightDecay getWeightDecayReg(List<Regularization> l){
        for(Regularization r : l){
            if(r instanceof WeightDecay){
                return (WeightDecay) r;
            }
        }
        return null;
    }
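
    //Coefficient extractors: pull the L1 / L2 / weight decay coefficient out of a layer's
    //regularization list; getL1 and getL2 fail the test via assertNotNull if the corresponding
    //regularization is absent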
    public static double getL1(BaseLayer layer) {
        List<Regularization> l = layer.getRegularization();
        return getL1(l);
    }

    public static double getL1(List<Regularization> l){
        L1Regularization l1Reg = null;
        for(Regularization reg : l){
            if(reg instanceof L1Regularization)
                l1Reg = (L1Regularization) reg;
        }
        assertNotNull(l1Reg);
        return l1Reg.getL1().valueAt(0, 0);
    }

    public static double getL2(BaseLayer layer) {
        List<Regularization> l = layer.getRegularization();
        return getL2(l);
    }

    public static double getL2(List<Regularization> l){
        L2Regularization l2Reg = null;
        for(Regularization reg : l){
            if(reg instanceof L2Regularization)
                l2Reg = (L2Regularization) reg;
        }
        assertNotNull(l2Reg);
        return l2Reg.getL2().valueAt(0, 0);
    }

    public static double getL1(AbstractSameDiffLayer layer){
        return getL1(layer.getRegularization());
    }

    public static double getL2(AbstractSameDiffLayer layer){
        return getL2(layer.getRegularization());
    }

    public static double getWeightDecay(BaseLayer layer) {
        return getWeightDecayReg(layer.getRegularization()).getCoeff().valueAt(0, 0);
    }
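
    //Helper-removal utilities: use reflection to null out the private "helper" field (the
    //platform-specific implementation, e.g. the cuDNN-backed one) so that tests exercise the
    //built-in implementations instead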
    public static void removeHelper(Layer layer) throws Exception {
        removeHelpers(new Layer[]{layer});
    }

    public static void removeHelpers(Layer[] layers) throws Exception {
        for(Layer l : layers){
            Field f = null;
            if(l instanceof ConvolutionLayer){
                f = ConvolutionLayer.class.getDeclaredField("helper");
            } else if(l instanceof SubsamplingLayer){
                f = SubsamplingLayer.class.getDeclaredField("helper");
            } else if(l instanceof BatchNormalization) {
                f = BatchNormalization.class.getDeclaredField("helper");
            } else if(l instanceof LSTM){
                f = LSTM.class.getDeclaredField("helper");
            } else if(l instanceof LocalResponseNormalization){
                f = LocalResponseNormalization.class.getDeclaredField("helper");
            }

            if(f != null){
                f.setAccessible(true);
                f.set(l, null);
            }

            if(l.getHelper() != null){
                throw new IllegalStateException("Did not remove helper for layer: " + l.getClass().getSimpleName());
            }
        }
    }
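
    //Currently a no-op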
    public static void assertHelperPresent(Layer layer){

    }
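
    /**
     * Asserts that every layer of a helper-capable type (convolution, subsampling, batch
     * normalization, LSTM) has a helper attached.
     */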
    public static void assertHelpersPresent(Layer[] layers) throws Exception {
        for(Layer l : layers){
            //Don't use instanceof for the convolution check - ConvolutionLayer has subclasses
            if(l.getClass() == ConvolutionLayer.class || l instanceof SubsamplingLayer || l instanceof BatchNormalization || l instanceof LSTM){
                Preconditions.checkNotNull(l.getHelper(), l.conf().getLayer().getLayerName());
            }
        }
    }
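
    /**
     * Asserts that no layer in the array has a helper attached.
     */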
    public static void assertHelpersAbsent(Layer[] layers) throws Exception {
        for(Layer l : layers){
            Preconditions.checkState(l.getHelper() == null, l.conf().getLayer().getLayerName());
        }
    }
}