92 lines
3.5 KiB
Java
Raw Normal View History

/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
2019-06-06 15:21:15 +03:00
Refactor packages to fix split package issues (#411) * Refactor nd4j-common: org.nd4j.* -> org.nd4j.common.* Signed-off-by: Alex Black <blacka101@gmail.com> * Fix CUDA (missed nd4j-common package refactoring changes) Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-kryo: org.nd4j -> org.nd4j.kryo Signed-off-by: Alex Black <blacka101@gmail.com> * Fix nd4j-common for deeplearning4j-cuda Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-grppc-client: org.nd4j.graph -> org.nd4j.remote.grpc Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-common: org.deeplearning4.* -> org.deeplearning4j.common.* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-core: org.deeplearning4j.* -> org.deeplearning.core.* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-cuda: org.deeplearning4j.nn.layers.* -> org.deeplearning4j.cuda.* Signed-off-by: Alex Black <blacka101@gmail.com> * Import fixes Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-nlp-*: org.deeplearning4.text.* -> org.deeplearning4j.nlp.(language).* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-ui-model: org.deeplearning4j.ui -> org.deeplearning4j.ui.model Signed-off-by: Alex Black <blacka101@gmail.com> * datavec-spark-inference-{server/model/client}: org.datavec.spark.transform -> org.datavec.spark.inference.{server/model/client} Signed-off-by: Alex Black <blacka101@gmail.com> * datavec-jdbc: org.datavec.api -> org.datavec.jdbc Signed-off-by: Alex Black <blacka101@gmail.com> * Delete org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter in favor of (essentially identical) org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter Signed-off-by: Alex Black <blacka101@gmail.com> * ND4S fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-common-tests: org.nd4j.* -> org.nd4j.common.tests Signed-off-by: Alex Black <blacka101@gmail.com> * Trigger CI 
Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * #8878 Ignore CUDA tests on modules with 'nd4j-native under cuda' issue Signed-off-by: Alex Black <blacka101@gmail.com> * Fix bad imports in tests Signed-off-by: Alex Black <blacka101@gmail.com> * Add ignore on test (already failing) due to #8882 Signed-off-by: Alex Black <blacka101@gmail.com> * Import fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Additional import fixes Signed-off-by: Alex Black <blacka101@gmail.com>
2020-04-29 11:19:26 +10:00
package org.deeplearning4j.cuda;
2019-06-06 15:21:15 +03:00
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer;
import org.deeplearning4j.nn.layers.normalization.BatchNormalization;
import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization;
import org.deeplearning4j.nn.layers.recurrent.LSTM;
import org.nd4j.common.base.Preconditions;
2021-03-20 19:06:24 +09:00
import org.nd4j.common.tests.tags.NativeTag;
2019-06-06 15:21:15 +03:00
import java.lang.reflect.Field;
/**
* Test utility methods specific to CuDNN: clearing the helper of network layers and asserting that helpers are present or absent.
*
* @author Alex Black
*/
2021-03-20 19:06:24 +09:00
@NativeTag
2019-06-06 15:21:15 +03:00
public class CuDNNTestUtils {

    /**
     * Layer classes whose declared {@code helper} field is cleared by
     * {@link #removeHelpers(Layer[])}. The entries are tested in this order and only the
     * first class a layer is an instance of is used for that layer.
     */
    private static final Class<?>[] HELPER_OWNERS = {
            ConvolutionLayer.class,
            SubsamplingLayer.class,
            BatchNormalization.class,
            LSTM.class,
            LocalResponseNormalization.class
    };

    /** Static utility holder; not meant to be instantiated. */
    private CuDNNTestUtils() {
    }

    /**
     * Sets the {@code helper} field of every supported layer to {@code null} using reflection,
     * then verifies through {@link Layer#getHelper()} that no helper is left on the layer.
     * Layers that match none of the supported classes are not modified, but are still verified.
     *
     * @param layers layers to strip helpers from
     * @throws NoSuchFieldException   if a supported layer class does not declare a {@code helper} field
     * @throws IllegalAccessException if the {@code helper} field cannot be written
     * @throws IllegalStateException  if a layer still reports a helper afterwards
     * @throws Exception              declared broadly so existing callers keep compiling
     */
    public static void removeHelpers(Layer[] layers) throws Exception {
        for (Layer l : layers) {
            for (Class<?> owner : HELPER_OWNERS) {
                if (owner.isInstance(l)) {
                    clearHelperField(owner, l);
                    // First match wins, mirroring an if / else-if chain over the same classes
                    break;
                }
            }

            if (l.getHelper() != null) {
                throw new IllegalStateException("Did not remove helper for layer: " + l.getClass().getSimpleName());
            }
        }
    }

    /**
     * Asserts that every layer which is expected to carry a helper actually has one.
     * Checked layers are: exactly {@code ConvolutionLayer} (subclasses are skipped on purpose),
     * and any instance of {@code SubsamplingLayer}, {@code BatchNormalization} or {@code LSTM}.
     * The check is performed with {@code Preconditions.checkNotNull}.
     *
     * @param layers layers to inspect
     * @throws Exception declared broadly so existing callers keep compiling
     */
    public static void assertHelpersPresent(Layer[] layers) throws Exception {
        for (Layer l : layers) {
            // Exact class comparison for ConvolutionLayer: instanceof would also match its subclasses
            boolean expectsHelper = l.getClass() == ConvolutionLayer.class
                    || l instanceof SubsamplingLayer
                    || l instanceof BatchNormalization
                    || l instanceof LSTM;
            if (expectsHelper) {
                Preconditions.checkNotNull(l.getHelper(),
                        "Expected a helper to be present for layer " + describe(l));
            }
        }
    }

    /**
     * Asserts that none of the given layers has a helper. The check is performed with
     * {@code Preconditions.checkState}.
     *
     * @param layers layers to inspect
     * @throws Exception declared broadly so existing callers keep compiling
     */
    public static void assertHelpersAbsent(Layer[] layers) throws Exception {
        for (Layer l : layers) {
            Preconditions.checkState(l.getHelper() == null,
                    "Expected no helper to be present for layer " + describe(l));
        }
    }

    /** Nulls out the {@code helper} field declared on {@code owner} for the given layer instance. */
    private static void clearHelperField(Class<?> owner, Layer layer) throws ReflectiveOperationException {
        Field helperField = owner.getDeclaredField("helper");
        helperField.setAccessible(true);
        helperField.set(layer, null);
    }

    /** Builds a short description (configured layer name plus concrete class) for failure messages. */
    private static String describe(Layer layer) {
        return "\"" + layer.conf().getLayer().getLayerName() + "\" (" + layer.getClass().getSimpleName() + ")";
    }
}