More test fixes
parent
82dec223ac
commit
5c98c5e1ed
|
@ -11,3 +11,5 @@ RUN wget -nv https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.
|
|||
rm cmake-3.24.2-linux-x86_64.sh
|
||||
|
||||
|
||||
RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf
|
||||
|
||||
|
|
16
build.gradle
16
build.gradle
|
@ -44,7 +44,6 @@ ext {
|
|||
|
||||
scalaVersion = "2.12"
|
||||
logger.quiet("Scala main version is set to {}", scalaVersion)
|
||||
logger.quiet("Running java {}", JavaVersion.current())
|
||||
}
|
||||
|
||||
configurations.all {
|
||||
|
@ -64,8 +63,8 @@ allprojects { Project proj ->
|
|||
|
||||
|
||||
plugins.withType(JavaPlugin) {
|
||||
sourceCompatibility = JavaVersion.VERSION_11
|
||||
targetCompatibility = JavaVersion.VERSION_1_8
|
||||
sourceCompatibility = 11
|
||||
targetCompatibility = 1.8
|
||||
tasks.withType(JavaCompile) {
|
||||
options.release = 8
|
||||
}
|
||||
|
@ -107,17 +106,14 @@ allprojects { Project proj ->
|
|||
}
|
||||
|
||||
plugins.withType(MavenPublishPlugin) {
|
||||
|
||||
publishing {
|
||||
publications {
|
||||
if(! proj.name.contains("cavis-full")) {
|
||||
mavenJava(MavenPublication) {
|
||||
/* Need to verify the property exists, as some
|
||||
mavenJava(MavenPublication) {
|
||||
/* Need to verify the property exists, as some
|
||||
modules may not declare it (i.e. the java-platform plugin)
|
||||
*/
|
||||
if (components.hasProperty("java")) {
|
||||
from components.java
|
||||
}
|
||||
if (components.hasProperty("java") && !proj.name.equals("cavis-native-lib")) {
|
||||
from components.java
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -129,16 +129,4 @@ echo "nameserver 8.8.8.8" | sudo tee -a /etc/resolv.conf
|
|||
# Buildparameter: #
|
||||
|
||||
-P\<xxx>\
|
||||
CAVIS_AVX_EXTENSION = {avx2 | avx512}, default is avx2
|
||||
|
||||
# Zeppelin Spark dependencies #
|
||||
3
|
||||
|
||||
|
||||
To add the dependency to the language models, use the following format in the Dependencies section of the of the Spark Interpreter configuration (Interpreters -> Spark -> Edit -> Dependencies):
|
||||
|
||||
groupId:artifactId:packaging:classifier:version
|
||||
|
||||
In your case it should work with
|
||||
|
||||
edu.stanford.nlp:stanford-corenlp:jar:models:3.8.0
|
||||
CAVIS_AVX_EXTENSION = {avx2 | avx512}, default is avx2
|
|
@ -68,8 +68,8 @@ dependencies {
|
|||
api "org.projectlombok:lombok:1.18.24"
|
||||
|
||||
/*Logging*/
|
||||
api 'org.slf4j:slf4j-api:1.7.30'
|
||||
api 'org.slf4j:slf4j-simple:1.7.25'
|
||||
api 'org.slf4j:slf4j-api:2.0.3'
|
||||
api 'org.slf4j:slf4j-simple:2.0.3'
|
||||
|
||||
api "org.apache.logging.log4j:log4j-core:2.17.0"
|
||||
api "ch.qos.logback:logback-classic:1.2.3"
|
||||
|
|
|
@ -87,14 +87,14 @@ public class BalancedPathFilter extends RandomPathFilter {
|
|||
|
||||
protected boolean acceptLabel(String name) {
|
||||
if (labels == null || labels.length == 0) {
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
for (String label : labels) {
|
||||
if (name.equals(label)) {
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -107,7 +107,7 @@ public class BalancedPathFilter extends RandomPathFilter {
|
|||
URI path = paths[i];
|
||||
Writable label = labelGenerator.getLabelForPath(path);
|
||||
if (!acceptLabel(label.toString())) {
|
||||
continue;
|
||||
continue; //we skip label in case it is null, empty or already in the collection
|
||||
}
|
||||
List<URI> pathList = labelPaths.get(label);
|
||||
if (pathList == null) {
|
||||
|
|
|
@ -20,13 +20,15 @@
|
|||
|
||||
package org.datavec.api.util.files;
|
||||
|
||||
import lombok.NonNull;
|
||||
|
||||
import java.io.File;
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
|
||||
public class URIUtil {
|
||||
|
||||
public static URI fileToURI(File f) {
|
||||
public static URI fileToURI(@NonNull File f) {
|
||||
try {
|
||||
// manually construct URI (this is faster)
|
||||
String sp = slashify(f.getAbsoluteFile().getPath(), false);
|
||||
|
|
|
@ -23,7 +23,7 @@ apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
|
|||
dependencies {
|
||||
implementation "com.codepoetics:protonpack:1.15"
|
||||
implementation projects.cavisDatavec.cavisDatavecApi
|
||||
implementation projects.cavisDatavec.cavisDatavecArrow
|
||||
implementation projects.cavisDatavec.cavisDatavecData.cavisDatavecDataArrow
|
||||
|
||||
implementation projects.cavisDnn.cavisDnnApi
|
||||
implementation "com.google.guava:guava"
|
||||
|
|
|
@ -40,6 +40,8 @@ public class DL4JSystemProperties {
|
|||
*/
|
||||
public static final String DL4J_RESOURCES_DIR_PROPERTY = "org.deeplearning4j.resources.directory";
|
||||
|
||||
public static final String DISABLE_HELPER_PROPERTY = "org.deeplearning4j.disablehelperloading";
|
||||
public static final String HELPER_DISABLE_DEFAULT_VALUE = "false";
|
||||
/**
|
||||
* Applicability: Numerous modules, including deeplearning4j-datasets and deeplearning4j-zoo<br>
|
||||
* Description: Used to set the base URL for hosting of resources such as datasets (like MNIST) and pretrained
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.nd4j.common.base.Preconditions;
|
|||
import java.io.File;
|
||||
import java.net.MalformedURLException;
|
||||
import java.net.URL;
|
||||
import java.nio.file.Path;
|
||||
|
||||
public class DL4JResources {
|
||||
|
||||
|
@ -128,6 +129,10 @@ public class DL4JResources {
|
|||
baseDirectory = f;
|
||||
}
|
||||
|
||||
public static void setBaseDirectory(@NonNull Path p) {
|
||||
setBaseDirectory(p.toFile());
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The base storage directory for DL4J resources
|
||||
*/
|
||||
|
|
|
@ -24,10 +24,7 @@ import org.deeplearning4j.BaseDL4JTest;
|
|||
import org.deeplearning4j.datasets.base.MnistFetcher;
|
||||
import org.deeplearning4j.common.resources.DL4JResources;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestInstance;
|
||||
import org.junit.jupiter.api.*;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
||||
|
@ -37,6 +34,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Path;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
|
@ -45,24 +43,14 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
|
|||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@org.junit.jupiter.api.Timeout(300)
|
||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||
public class MnistFetcherTest extends BaseDL4JTest {
|
||||
|
||||
@TempDir
|
||||
public File testDir;
|
||||
|
||||
@BeforeAll
|
||||
public void setup() throws Exception {
|
||||
DL4JResources.setBaseDirectory(testDir);
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
public void after() {
|
||||
DL4JResources.resetBaseDirectoryLocation();
|
||||
}
|
||||
public Path testDir;
|
||||
|
||||
@Test
|
||||
public void testMnist() throws Exception {
|
||||
DL4JResources.setBaseDirectory(testDir);
|
||||
DataSetIterator iter = new MnistDataSetIterator(32, 60000, false, true, false, -1);
|
||||
int count = 0;
|
||||
while(iter.hasNext()){
|
||||
|
|
|
@ -33,6 +33,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceTo
|
|||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.Timeout;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -48,6 +49,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
|||
|
||||
|
||||
@Slf4j
|
||||
@Timeout(300)
|
||||
public class Keras2ModelConfigurationTest extends BaseDL4JTest {
|
||||
|
||||
ClassLoader classLoader = getClass().getClassLoader();
|
||||
|
@ -231,11 +233,6 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest {
|
|||
runModelConfigTest("modelimport/keras/configs/keras2/simple_add_tf_keras_2.json");
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 999999999L;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void embeddingConcatTest() throws Exception {
|
||||
runModelConfigTest("/modelimport/keras/configs/keras2/model_concat_embedding_sequences_tf_keras_2.json");
|
||||
|
|
|
@ -32,6 +32,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
|
|||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.Timeout;
|
||||
import org.nd4j.common.resources.Resources;
|
||||
import org.nd4j.linalg.convolution.Convolution;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -45,12 +46,8 @@ import static org.junit.jupiter.api.Assertions.assertNotNull;
|
|||
* Test import of Keras models.
|
||||
*/
|
||||
@Slf4j
|
||||
@Timeout(300)
|
||||
public class KerasModelImportTest extends BaseDL4JTest {
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 9999999999999L;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testH5WithoutTensorflowScope() throws Exception {
|
||||
MultiLayerNetwork model = loadModel("modelimport/keras/tfscope/model.h5");
|
||||
|
|
|
@ -44,10 +44,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
|
|||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
|
||||
import org.deeplearning4j.nn.transferlearning.TransferLearning;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.*;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.activations.IActivation;
|
||||
|
@ -78,6 +75,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
* @author dave@skymind.io, Max Pumperla
|
||||
*/
|
||||
@Slf4j
|
||||
@Timeout(300)
|
||||
public class KerasModelEndToEndTest extends BaseDL4JTest {
|
||||
private static final String GROUP_ATTR_INPUTS = "inputs";
|
||||
private static final String GROUP_ATTR_OUTPUTS = "outputs";
|
||||
|
@ -93,11 +91,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
|||
@TempDir
|
||||
public File testDir;
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 900000000L; //Most benchmarks should run very quickly; large timeout is to avoid issues with unusually slow download of test resources
|
||||
}
|
||||
|
||||
@Test
|
||||
public void fileNotFoundEndToEnd() throws Exception {
|
||||
String modelPath = "modelimport/keras/examples/foo/bar.h5";
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasModel;
|
|||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.Timeout;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -45,16 +46,12 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Slf4j
|
||||
@Timeout(300)
|
||||
public class KerasWeightSettingTests extends BaseDL4JTest {
|
||||
|
||||
@TempDir
|
||||
private File testDir;
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 9999999L;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSimpleLayersWithWeights() throws Exception {
|
||||
int[] kerasVersions = new int[]{1, 2};
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.deeplearning4j.nn.conf.WorkspaceMode;
|
|||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.Timeout;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -42,13 +43,9 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
@Timeout(300)
|
||||
public class TsneTest extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 180000L;
|
||||
}
|
||||
|
||||
@TempDir
|
||||
public File testDir;
|
||||
|
||||
|
|
|
@ -78,11 +78,6 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
@Timeout(240)
|
||||
public class ParagraphVectorsTest extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return isIntegrationTests() ? 600_000 : 240_000;
|
||||
}
|
||||
|
||||
@TempDir
|
||||
public File testDir;
|
||||
|
||||
|
|
|
@ -60,14 +60,10 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
|||
|
||||
|
||||
@Slf4j
|
||||
@Timeout(300)
|
||||
public class Word2VecTestsSmall extends BaseDL4JTest {
|
||||
WordVectors word2vec;
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return isIntegrationTests() ? 240000 : 60000;
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
public void setUp() throws Exception {
|
||||
word2vec = WordVectorSerializer.readWord2VecModel(new ClassPathResource("vec.bin").getFile());
|
||||
|
|
|
@ -34,6 +34,7 @@ import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFac
|
|||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.Timeout;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.common.resources.Resources;
|
||||
|
@ -46,13 +47,9 @@ import java.util.concurrent.atomic.AtomicInteger;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
|
||||
@Timeout(300)
|
||||
public class Word2VecDataSetIteratorTest extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 60000L;
|
||||
}
|
||||
|
||||
/**
|
||||
* Basically all we want from this test - being able to finish without exceptions.
|
||||
*/
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
|
||||
|
||||
dependencies {
|
||||
implementation projects.cavisDnn.cavisDnnData.cavisDnnDataUtilityIterators
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
package org.deeplearning4j.nn.layers;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static org.deeplearning4j.common.config.DL4JSystemProperties.DISABLE_HELPER_PROPERTY;
|
||||
import static org.deeplearning4j.common.config.DL4JSystemProperties.HELPER_DISABLE_DEFAULT_VALUE;
|
||||
|
||||
/**
|
||||
* Simple meta helper util class for instantiating
|
||||
* platform specific layer helpers that handle interaction with
|
||||
* lower level libraries like cudnn and onednn.
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@Slf4j
|
||||
public class HelperUtils {
|
||||
|
||||
|
||||
/**
|
||||
* Creates a {@link LayerHelper}
|
||||
* for use with platform specific code.
|
||||
* @param <T> the actual class type to be returned
|
||||
* @param cudnnHelperClassName the cudnn class name
|
||||
* @param oneDnnClassName the one dnn class name
|
||||
* @param layerHelperSuperClass the layer helper super class
|
||||
* @param layerName the name of the layer to be created
|
||||
* @param arguments the arguments to be used in creation of the layer
|
||||
* @return
|
||||
*/
|
||||
public static <T extends LayerHelper> T createHelper(String cudnnHelperClassName,
|
||||
String oneDnnClassName,
|
||||
Class<? extends LayerHelper> layerHelperSuperClass,
|
||||
String layerName,
|
||||
Object... arguments) {
|
||||
|
||||
Boolean disabled = Boolean.parseBoolean(System.getProperty(DISABLE_HELPER_PROPERTY,HELPER_DISABLE_DEFAULT_VALUE));
|
||||
if(disabled) {
|
||||
System.out.println("Disabled helper creation, returning null");
|
||||
return null;
|
||||
}
|
||||
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||
LayerHelper helperRet = null;
|
||||
if("CUDA".equalsIgnoreCase(backend) && cudnnHelperClassName != null && !cudnnHelperClassName.isEmpty()) {
|
||||
if(DL4JClassLoading.loadClassByName(cudnnHelperClassName) != null) {
|
||||
log.debug("Attempting to initialize cudnn helper {}",cudnnHelperClassName);
|
||||
helperRet = (LayerHelper) DL4JClassLoading.<LayerHelper>createNewInstance(
|
||||
cudnnHelperClassName,
|
||||
(Class<? super LayerHelper>) layerHelperSuperClass,
|
||||
new Object[]{arguments});
|
||||
log.debug("Cudnn helper {} successfully initialized",cudnnHelperClassName);
|
||||
|
||||
}
|
||||
else {
|
||||
log.warn("Unable to find class {} using the classloader set for Dl4jClassLoading. Trying to use class loader that loaded the class {} instead.",cudnnHelperClassName,layerHelperSuperClass.getName());
|
||||
ClassLoader classLoader = DL4JClassLoading.getDl4jClassloader();
|
||||
DL4JClassLoading.setDl4jClassloaderFromClass(layerHelperSuperClass);
|
||||
try {
|
||||
helperRet = (LayerHelper) DL4JClassLoading.<LayerHelper>createNewInstance(
|
||||
cudnnHelperClassName,
|
||||
(Class<? super LayerHelper>) layerHelperSuperClass,
|
||||
arguments);
|
||||
|
||||
} catch (Exception e) {
|
||||
log.warn("Unable to use helper implementation {} for helper type {}, please check your classpath. Falling back to built in normal methods for now.",cudnnHelperClassName,layerHelperSuperClass.getName());
|
||||
}
|
||||
|
||||
log.warn("Returning class loader to original one.");
|
||||
DL4JClassLoading.setDl4jClassloader(classLoader);
|
||||
|
||||
}
|
||||
|
||||
if (helperRet != null && !helperRet.checkSupported()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if(helperRet != null) {
|
||||
log.debug("{} successfully initialized",cudnnHelperClassName);
|
||||
}
|
||||
|
||||
} else if("CPU".equalsIgnoreCase(backend) && oneDnnClassName != null && !oneDnnClassName.isEmpty()) {
|
||||
helperRet = DL4JClassLoading.<LayerHelper>createNewInstance(
|
||||
oneDnnClassName,
|
||||
arguments);
|
||||
log.trace("Created oneDNN helper: {}, layer {}", oneDnnClassName,layerName);
|
||||
}
|
||||
|
||||
if (helperRet != null && !helperRet.checkSupported()) {
|
||||
log.debug("Removed helper {} as not supported", helperRet.getClass());
|
||||
return null;
|
||||
}
|
||||
|
||||
return (T) helperRet;
|
||||
}
|
||||
|
||||
}
|
|
@ -37,4 +37,6 @@ public interface LayerHelper {
|
|||
*/
|
||||
Map<String,Long> helperMemoryUse();
|
||||
|
||||
boolean checkSupported();
|
||||
|
||||
}
|
||||
|
|
|
@ -59,4 +59,8 @@ public class BaseMKLDNNHelper {
|
|||
}
|
||||
}
|
||||
|
||||
public boolean checkSupported() {
|
||||
return mklDnnEnabled();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -197,4 +197,9 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper {
|
|||
public Map<String, Long> helperMemoryUse() {
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean checkSupported() {
|
||||
return BaseMKLDNNHelper.mklDnnEnabled();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.deeplearning4j.nn.workspace.ArrayType;
|
|||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.nd4j.linalg.activations.IActivation;
|
||||
import org.nd4j.linalg.activations.impl.*;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
|
@ -43,6 +44,7 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
|
||||
public class MKLDNNLSTMHelper implements LSTMHelper {
|
||||
public MKLDNNLSTMHelper(DataType dataType) {}
|
||||
@Override
|
||||
public boolean checkSupported(IActivation gateActivationFn, IActivation activationFn, boolean hasPeepholeConnections) {
|
||||
//TODO check other activation functions for MKLDNN
|
||||
|
@ -159,6 +161,11 @@ public class MKLDNNLSTMHelper implements LSTMHelper {
|
|||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean checkSupported() {
|
||||
return BaseMKLDNNHelper.mklDnnEnabled();
|
||||
}
|
||||
|
||||
private int activationToArg(IActivation a){
|
||||
//0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
|
||||
if(a instanceof ActivationTanH)
|
||||
|
|
|
@ -94,4 +94,5 @@ public class MKLDNNLocalResponseNormalizationHelper extends BaseMKLDNNHelper imp
|
|||
public Map<String, Long> helperMemoryUse() {
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -592,6 +592,11 @@ public class BidirectionalLayer implements RecurrentLayer {
|
|||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean checkSupported() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -20,50 +20,21 @@
|
|||
package org.deeplearning4j.nn.layers;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.layers.convolution.ConvolutionHelper;
|
||||
import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingHelper;
|
||||
import org.deeplearning4j.nn.layers.mkldnn.*;
|
||||
import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
|
||||
import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalizationHelper;
|
||||
import org.deeplearning4j.nn.layers.recurrent.LSTMHelper;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.activations.impl.ActivationELU;
|
||||
import org.nd4j.linalg.activations.impl.ActivationRationalTanh;
|
||||
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
|
||||
/**
|
||||
*/
|
||||
@DisplayName("Activation Layer Test")
|
||||
@NativeTag
|
||||
@Tag(TagNames.CUSTOM_FUNCTIONALITY)
|
||||
@Tag(TagNames.DL4J_OLD_API)
|
||||
public class HelperUtilsTest extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
|
|
|
@ -25,6 +25,9 @@ dependencies {
|
|||
implementation 'org.slf4j:slf4j-api'
|
||||
implementation "com.google.guava:guava"
|
||||
|
||||
implementation "com.fasterxml.jackson.core:jackson-annotations"
|
||||
implementation "com.fasterxml.jackson.core:jackson-core"
|
||||
|
||||
implementation projects.cavisNd4j.cavisNd4jParameterServer.cavisNd4jParameterServerCore
|
||||
implementation projects.cavisNd4j.cavisNd4jParameterServer.cavisNd4jParameterServerClient
|
||||
implementation projects.cavisDnn.cavisDnnCore
|
||||
|
@ -36,7 +39,6 @@ dependencies {
|
|||
|
||||
testImplementation projects.cavisUi.cavisUiStandalone
|
||||
|
||||
|
||||
testImplementation projects.cavisDnn.cavisDnnCommonTests
|
||||
testImplementation projects.cavisUi.cavisUiModel
|
||||
testImplementation projects.cavisUi.cavisUiVertx
|
||||
|
|
|
@ -1,66 +1,52 @@
|
|||
plugins {
|
||||
id 'java-library'
|
||||
id 'maven-publish'
|
||||
id 'com.github.johnrengelman.shadow' version '7.1.2'
|
||||
}
|
||||
|
||||
apply from: rootProject.projectDir.path+"/chooseBackend.gradle"
|
||||
configurations.archives.artifacts.with { archives ->
|
||||
archives.each {
|
||||
println(it.name)
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
afterEvaluate {
|
||||
//Todo clean this
|
||||
api platform(project(":cavis-common-platform"))
|
||||
//api "org.bytedeco:javacpp:1.5.7" //for some reason we needed to apply version numbers here, they do not end up in POM otherwise
|
||||
//api "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5"
|
||||
//api 'org.slf4j:slf4j-simple:2.0.3'
|
||||
//api 'org.slf4j:slf4j-api:2.0.3'
|
||||
//TODO for the two below.. either platform specific uber jars or a single big one with all platforms
|
||||
//api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64"
|
||||
//Todo clean this
|
||||
api platform(project(":cavis-common-platform"))
|
||||
api "org.bytedeco:javacpp:1.5.7" //for some reason we needed to apply version numbers here, they do not end up in POM otherwise
|
||||
api "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5"
|
||||
api 'org.slf4j:slf4j-simple:2.0.3'
|
||||
api 'org.slf4j:slf4j-api:2.0.3'
|
||||
//api group: "org.bytedeco", name: "javacpp", classifier: "linux-x64_86"
|
||||
|
||||
rootProject.getAllprojects().each { Project sproj ->
|
||||
if (!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform")
|
||||
&& !sproj.name.equals("Cavis")
|
||||
&& !sproj.name.equals("cavis-datavec")
|
||||
&& !sproj.name.equals("cavis-dnn")
|
||||
&& !sproj.name.equals("cavis-native") && !sproj.name.equals("cavis-native-lib")
|
||||
&& !sproj.name.equals("cavis-nd4j")
|
||||
&& !sproj.name.equals("cavis-ui")
|
||||
&& !sproj.name.equals("cavis-zoo")) {
|
||||
api project(path: sproj.path, configuration: 'runtimeElements')
|
||||
rootProject.getAllprojects().each { Project sproj ->
|
||||
if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform")
|
||||
&& !sproj.name.equals("Cavis")
|
||||
&& !sproj.name.equals("cavis-datavec")
|
||||
&& !sproj.name.equals("cavis-dnn")
|
||||
&& !sproj.name.equals("cavis-native")
|
||||
&& !sproj.name.equals("cavis-nd4j")
|
||||
&& !sproj.name.equals("cavis-ui")
|
||||
&& !sproj.name.equals("cavis-zoo")) {
|
||||
//compileOnly project(""+sproj.path)
|
||||
api sproj
|
||||
if(! sproj.configurations.empty) {
|
||||
//compileOnly project(sproj.getPath())
|
||||
|
||||
/*
|
||||
sproj.configurations.each {Configuration conf ->
|
||||
conf.dependencies.each {Dependency dep ->
|
||||
compileOnly dep
|
||||
}
|
||||
}
|
||||
|
||||
*/
|
||||
}
|
||||
}
|
||||
// if(withCpu) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportApiElements")
|
||||
// if(withCuda) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportApiElements")
|
||||
|
||||
api(projects.cavisNative.cavisNativeLib) {
|
||||
capabilities {
|
||||
//if(withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version)
|
||||
if (withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version)
|
||||
}
|
||||
}
|
||||
api(projects.cavisNative.cavisNativeLib) {
|
||||
capabilities {
|
||||
if (withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version)
|
||||
//if(withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version)
|
||||
}
|
||||
}
|
||||
|
||||
//if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportImplementation")
|
||||
//if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportImplementation")
|
||||
//if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportCompileClasspath")
|
||||
|
||||
/*
|
||||
api (project(':cavis-native:cavis-native-lib')) {
|
||||
capabilities {
|
||||
if(withCpu()) requireCapability("net.brutex.cavis.cavis-native:cavis-native-lib-cpu-support")
|
||||
//if(withCuda()) requireCapability("net.brutex.cavis.cavis-native:cavis-native-lib-cuda-support")
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
tasks.getByName("jar") {
|
||||
|
||||
|
@ -85,39 +71,19 @@ tasks.getByName("jar") {
|
|||
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
/*
|
||||
|
||||
/*
|
||||
artifacts {
|
||||
archives shadowJar
|
||||
}
|
||||
|
||||
shadowJar {
|
||||
enabled true;
|
||||
zip64 true //need this to support jars with more than 65535 entries
|
||||
archiveClassifier.set('')
|
||||
archives customFatJar
|
||||
}
|
||||
*/
|
||||
|
||||
publishing {
|
||||
publications {
|
||||
/*mavenJava(MavenPublication) {
|
||||
//artifact customFatJar
|
||||
mavenJava(MavenPublication) {
|
||||
// artifact customFatJar
|
||||
// from components.java
|
||||
/* pom.withXml {
|
||||
def dependenciesNode = asNode().dependencies
|
||||
def dependencyNode = dependenciesNode.appendNode()
|
||||
|
||||
dependencyNode.appendNode('groupId', 'net.brutex.cavis')
|
||||
dependencyNode.appendNode('artifactId', 'cavis-native-lib')
|
||||
dependencyNode.appendNode('version', '1.0.0-SNAPSHOT')
|
||||
//dependencyNode.appendNode('classifier', 'linux-x86_64-avx2-cpu')
|
||||
//dependencyNode.appendNode('scope', 'compile')
|
||||
}
|
||||
|
||||
}
|
||||
*/
|
||||
shadow(MavenPublication) { publication ->
|
||||
project.shadow.component(publication)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -116,7 +116,6 @@ include ':cavis-dnn:cavis-dnn-spark:cavis-dnn-spark-parameterserver'
|
|||
include ':cavis-dnn:cavis-dnn-tsne'
|
||||
include ':cavis-datavec'
|
||||
include ':cavis-datavec:cavis-datavec-api'
|
||||
include ':cavis-datavec:dvec-api'
|
||||
include ':cavis-datavec:cavis-datavec-data'
|
||||
include ':cavis-datavec:cavis-datavec-data:cavis-datavec-data-arrow'
|
||||
include ':cavis-datavec:cavis-datavec-data:cavis-datavec-data-image'
|
||||
|
@ -149,5 +148,7 @@ include ':cavis-ui:cavis-ui-standalone'
|
|||
include ':cavis-ui:cavis-ui-vertx'
|
||||
include ':cavis-zoo'
|
||||
include ':cavis-zoo:cavis-zoo-models'
|
||||
|
||||
include ':brutex-extended-tests'
|
||||
include ':cavis-full'
|
||||
|
||||
|
|
Loading…
Reference in New Issue