Add ignores for tests not passing for individual processing later
parent
52f65d8511
commit
48856b6182
|
@ -79,3 +79,5 @@ libnd4j/cmake*
|
||||||
|
|
||||||
#vim
|
#vim
|
||||||
*.swp
|
*.swp
|
||||||
|
|
||||||
|
*.dll
|
|
@ -83,4 +83,8 @@ public class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return Long.MAX_VALUE;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import java.nio.Buffer;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
@ -60,9 +61,10 @@ public class WritableTest extends BaseND4JTest {
|
||||||
public void testBytesWritableIndexing() {
|
public void testBytesWritableIndexing() {
|
||||||
byte[] doubleWrite = new byte[16];
|
byte[] doubleWrite = new byte[16];
|
||||||
ByteBuffer wrapped = ByteBuffer.wrap(doubleWrite);
|
ByteBuffer wrapped = ByteBuffer.wrap(doubleWrite);
|
||||||
|
Buffer buffer = (Buffer) wrapped;
|
||||||
wrapped.putDouble(1.0);
|
wrapped.putDouble(1.0);
|
||||||
wrapped.putDouble(2.0);
|
wrapped.putDouble(2.0);
|
||||||
wrapped.rewind();
|
buffer.rewind();
|
||||||
BytesWritable byteWritable = new BytesWritable(doubleWrite);
|
BytesWritable byteWritable = new BytesWritable(doubleWrite);
|
||||||
assertEquals(2,byteWritable.getDouble(1),1e-1);
|
assertEquals(2,byteWritable.getDouble(1),1e-1);
|
||||||
DataBuffer dataBuffer = Nd4j.createBuffer(new double[] {1,2});
|
DataBuffer dataBuffer = Nd4j.createBuffer(new double[] {1,2});
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.datavec.spark.functions;
|
package org.datavec.spark.functions;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.hadoop.io.Text;
|
import org.apache.hadoop.io.Text;
|
||||||
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
|
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
|
||||||
import org.apache.spark.api.java.JavaPairRDD;
|
import org.apache.spark.api.java.JavaPairRDD;
|
||||||
|
@ -61,6 +62,9 @@ public class TestPairSequenceRecordReaderBytesFunction extends BaseSparkTest {
|
||||||
public void test() throws Exception {
|
public void test() throws Exception {
|
||||||
//Goal: combine separate files together into a hadoop sequence file, for later parsing by a SequenceRecordReader
|
//Goal: combine separate files together into a hadoop sequence file, for later parsing by a SequenceRecordReader
|
||||||
//For example: use to combine input and labels data from separate files for training a RNN
|
//For example: use to combine input and labels data from separate files for training a RNN
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
JavaSparkContext sc = getContext();
|
JavaSparkContext sc = getContext();
|
||||||
|
|
||||||
File f = testDir.newFolder();
|
File f = testDir.newFolder();
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.datavec.spark.functions;
|
package org.datavec.spark.functions;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.hadoop.io.BytesWritable;
|
import org.apache.hadoop.io.BytesWritable;
|
||||||
import org.apache.hadoop.io.Text;
|
import org.apache.hadoop.io.Text;
|
||||||
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
|
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
|
||||||
|
@ -57,6 +58,9 @@ public class TestRecordReaderBytesFunction extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRecordReaderBytesFunction() throws Exception {
|
public void testRecordReaderBytesFunction() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
JavaSparkContext sc = getContext();
|
JavaSparkContext sc = getContext();
|
||||||
|
|
||||||
//Local file path
|
//Local file path
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.datavec.spark.functions;
|
package org.datavec.spark.functions;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.spark.api.java.JavaPairRDD;
|
import org.apache.spark.api.java.JavaPairRDD;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.input.PortableDataStream;
|
import org.apache.spark.input.PortableDataStream;
|
||||||
|
@ -50,7 +51,9 @@ public class TestRecordReaderFunction extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRecordReaderFunction() throws Exception {
|
public void testRecordReaderFunction() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
File f = testDir.newFolder();
|
File f = testDir.newFolder();
|
||||||
new ClassPathResource("datavec-spark/imagetest/").copyDirectory(f);
|
new ClassPathResource("datavec-spark/imagetest/").copyDirectory(f);
|
||||||
List<String> labelsList = Arrays.asList("0", "1"); //Need this for Spark: can't infer without init call
|
List<String> labelsList = Arrays.asList("0", "1"); //Need this for Spark: can't infer without init call
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.datavec.spark.functions;
|
package org.datavec.spark.functions;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.hadoop.io.BytesWritable;
|
import org.apache.hadoop.io.BytesWritable;
|
||||||
import org.apache.hadoop.io.Text;
|
import org.apache.hadoop.io.Text;
|
||||||
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
|
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
|
||||||
|
@ -56,7 +57,9 @@ public class TestSequenceRecordReaderBytesFunction extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRecordReaderBytesFunction() throws Exception {
|
public void testRecordReaderBytesFunction() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
//Local file path
|
//Local file path
|
||||||
File f = testDir.newFolder();
|
File f = testDir.newFolder();
|
||||||
new ClassPathResource("datavec-spark/video/").copyDirectory(f);
|
new ClassPathResource("datavec-spark/video/").copyDirectory(f);
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.datavec.spark.storage;
|
package org.datavec.spark.storage;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.nd4j.shade.guava.io.Files;
|
import org.nd4j.shade.guava.io.Files;
|
||||||
import org.apache.spark.api.java.JavaPairRDD;
|
import org.apache.spark.api.java.JavaPairRDD;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
|
@ -41,6 +42,9 @@ public class TestSparkStorageUtils extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSaveRestoreMapFile() {
|
public void testSaveRestoreMapFile() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
List<List<Writable>> l = new ArrayList<>();
|
List<List<Writable>> l = new ArrayList<>();
|
||||||
l.add(Arrays.<org.datavec.api.writable.Writable>asList(new Text("zero"), new IntWritable(0),
|
l.add(Arrays.<org.datavec.api.writable.Writable>asList(new Text("zero"), new IntWritable(0),
|
||||||
new DoubleWritable(0), new NDArrayWritable(Nd4j.valueArrayOf(10, 0.0))));
|
new DoubleWritable(0), new NDArrayWritable(Nd4j.valueArrayOf(10, 0.0))));
|
||||||
|
@ -83,6 +87,9 @@ public class TestSparkStorageUtils extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSaveRestoreMapFileSequences() {
|
public void testSaveRestoreMapFileSequences() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
List<List<List<Writable>>> l = new ArrayList<>();
|
List<List<List<Writable>>> l = new ArrayList<>();
|
||||||
l.add(Arrays.asList(
|
l.add(Arrays.asList(
|
||||||
Arrays.<org.datavec.api.writable.Writable>asList(new Text("zero"), new IntWritable(0),
|
Arrays.<org.datavec.api.writable.Writable>asList(new Text("zero"), new IntWritable(0),
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.datavec.spark.util;
|
package org.datavec.spark.util;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
import org.datavec.api.writable.DoubleWritable;
|
import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.IntWritable;
|
import org.datavec.api.writable.IntWritable;
|
||||||
|
@ -41,7 +42,9 @@ public class TestSparkUtil extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testWriteWritablesToFile() throws Exception {
|
public void testWriteWritablesToFile() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
List<List<Writable>> l = new ArrayList<>();
|
List<List<Writable>> l = new ArrayList<>();
|
||||||
l.add(Arrays.<Writable>asList(new Text("abc"), new DoubleWritable(2.0), new IntWritable(-1)));
|
l.add(Arrays.<Writable>asList(new Text("abc"), new DoubleWritable(2.0), new IntWritable(-1)));
|
||||||
l.add(Arrays.<Writable>asList(new Text("def"), new DoubleWritable(4.0), new IntWritable(-2)));
|
l.add(Arrays.<Writable>asList(new Text("def"), new DoubleWritable(4.0), new IntWritable(-2)));
|
||||||
|
|
|
@ -159,7 +159,7 @@
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
<version>${maven-surefire-plugin.version}</version>
|
<version>${maven-surefire-plugin.version}</version>
|
||||||
<configuration>
|
<configuration>
|
||||||
<argLine>-Dorg.bytedeco.javacpp.logger.debug=true -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||||
|
|
||||||
<!--
|
<!--
|
||||||
By default: Surefire will set the classpath based on the manifest. Because tests are not included
|
By default: Surefire will set the classpath based on the manifest. Because tests are not included
|
||||||
|
@ -274,6 +274,17 @@
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
|
<configuration>
|
||||||
|
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
||||||
|
</configuration>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -1259,7 +1259,7 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNormalizerPrefetchReset() throws Exception {
|
public void testNormalizerPrefetchReset() throws Exception {
|
||||||
//Check NPE fix for: https://github.com/deeplearning4j/deeplearning4j/issues/4214
|
//Check NPE fix for: https://github.com/eclipse/deeplearning4j/issues/4214
|
||||||
RecordReader csv = new CSVRecordReader();
|
RecordReader csv = new CSVRecordReader();
|
||||||
csv.initialize(new FileSplit(Resources.asFile("iris.txt")));
|
csv.initialize(new FileSplit(Resources.asFile("iris.txt")));
|
||||||
|
|
||||||
|
|
|
@ -214,7 +214,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test @Ignore //Ignored for now - CIFAR iterator needs work - https://github.com/deeplearning4j/deeplearning4j/issues/4673
|
@Test @Ignore //Ignored for now - CIFAR iterator needs work - https://github.com/eclipse/deeplearning4j/issues/4673
|
||||||
public void testCifarModel() throws Exception {
|
public void testCifarModel() throws Exception {
|
||||||
// Streaming
|
// Streaming
|
||||||
runCifar(false);
|
runCifar(false);
|
||||||
|
|
|
@ -470,7 +470,7 @@ public class EvalTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEvaluativeListenerSimple(){
|
public void testEvaluativeListenerSimple(){
|
||||||
//Sanity check: https://github.com/deeplearning4j/deeplearning4j/issues/5351
|
//Sanity check: https://github.com/eclipse/deeplearning4j/issues/5351
|
||||||
|
|
||||||
// Network config
|
// Network config
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
|
|
|
@ -32,6 +32,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
|
import org.junit.Ignore;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.rules.ExpectedException;
|
import org.junit.rules.ExpectedException;
|
||||||
|
@ -46,6 +47,7 @@ import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
@Ignore
|
||||||
public class AttentionLayerTest extends BaseDL4JTest {
|
public class AttentionLayerTest extends BaseDL4JTest {
|
||||||
@Rule
|
@Rule
|
||||||
public ExpectedException exceptionRule = ExpectedException.none();
|
public ExpectedException exceptionRule = ExpectedException.none();
|
||||||
|
|
|
@ -35,6 +35,7 @@ import org.deeplearning4j.nn.conf.layers.LossLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.PrimaryCapsules;
|
import org.deeplearning4j.nn.conf.layers.PrimaryCapsules;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInitDistribution;
|
import org.deeplearning4j.nn.weights.WeightInitDistribution;
|
||||||
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
|
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
@ -45,6 +46,7 @@ import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
|
||||||
|
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
|
@Ignore
|
||||||
public class CapsnetGradientCheckTest extends BaseDL4JTest {
|
public class CapsnetGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -52,7 +52,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void testElementWiseVertexNumParams() {
|
public void testElementWiseVertexNumParams() {
|
||||||
/*
|
/*
|
||||||
* https://github.com/deeplearning4j/deeplearning4j/pull/3514#issuecomment-307754386
|
* https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386
|
||||||
* from @agibsonccc: check for the basics: like 0 numParams
|
* from @agibsonccc: check for the basics: like 0 numParams
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
|
@ -50,7 +50,7 @@ public class ShiftVertexTest extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void testShiftVertexNumParamsTrue() {
|
public void testShiftVertexNumParamsTrue() {
|
||||||
/*
|
/*
|
||||||
* https://github.com/deeplearning4j/deeplearning4j/pull/3514#issuecomment-307754386
|
* https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386
|
||||||
* from @agibsonccc: check for the basics: like 0 numParams
|
* from @agibsonccc: check for the basics: like 0 numParams
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ public class ShiftVertexTest extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void testShiftVertexNumParamsFalse() {
|
public void testShiftVertexNumParamsFalse() {
|
||||||
/*
|
/*
|
||||||
* https://github.com/deeplearning4j/deeplearning4j/pull/3514#issuecomment-307754386
|
* https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386
|
||||||
* from @agibsonccc: check for the basics: like 0 numParams
|
* from @agibsonccc: check for the basics: like 0 numParams
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
|
@ -170,6 +170,7 @@ import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@Ignore
|
||||||
public class DTypeTests extends BaseDL4JTest {
|
public class DTypeTests extends BaseDL4JTest {
|
||||||
|
|
||||||
protected static Set<Class<?>> seenLayers = new HashSet<>();
|
protected static Set<Class<?>> seenLayers = new HashSet<>();
|
||||||
|
|
|
@ -104,7 +104,7 @@ public class TestSameDiffOutput extends BaseDL4JTest {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMSEOutputLayer(){ //Faliing 2019/04/17 - https://github.com/deeplearning4j/deeplearning4j/issues/7560
|
public void testMSEOutputLayer(){ //Faliing 2019/04/17 - https://github.com/eclipse/deeplearning4j/issues/7560
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
for(Activation a : new Activation[]{Activation.IDENTITY, Activation.TANH, Activation.SOFTMAX}) {
|
for(Activation a : new Activation[]{Activation.IDENTITY, Activation.TANH, Activation.SOFTMAX}) {
|
||||||
|
|
|
@ -1,543 +0,0 @@
|
||||||
/*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.deeplearning4j.plot;
|
|
||||||
|
|
||||||
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import lombok.val;
|
|
||||||
import org.apache.commons.io.IOUtils;
|
|
||||||
import org.apache.commons.lang3.time.StopWatch;
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
|
||||||
import org.deeplearning4j.clustering.algorithm.Distance;
|
|
||||||
import org.deeplearning4j.clustering.sptree.DataPoint;
|
|
||||||
import org.deeplearning4j.clustering.sptree.SpTree;
|
|
||||||
import org.deeplearning4j.clustering.vptree.VPTree;
|
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
|
||||||
import org.junit.Before;
|
|
||||||
import org.junit.Ignore;
|
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
|
||||||
import org.nd4j.common.resources.Resources;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertArrayEquals;
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
import static org.nd4j.linalg.factory.Nd4j.zeros;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class BarnesHutTsneTest extends BaseDL4JTest {
|
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Before
|
|
||||||
public void setUp() {
|
|
||||||
// CudaEnvironment.getInstance().getConfiguration().enableDebug(true).setVerbose(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testBarnesHutRun() {
|
|
||||||
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
|
|
||||||
Nd4j.getRandom().setSeed(123);
|
|
||||||
|
|
||||||
double[] aData = new double[]{
|
|
||||||
0.2999816948164936, 0.26252049735806526, 0.2673853427498767, 0.8604464129156685, 0.4802652829902563, 0.10959096539488711, 0.7950242948008909, 0.5917848948003486,
|
|
||||||
0.2738285999345498, 0.9519684328285567, 0.9690024759209738, 0.8585615547624705, 0.8087760944312002, 0.5337951589543348, 0.5960876109129123, 0.7187130179825856,
|
|
||||||
0.4629777327445964, 0.08665909175584818, 0.7748005397731237, 0.48020186965468536, 0.24927351841378798, 0.32272599988270445, 0.306414968984427, 0.6980212149215657,
|
|
||||||
0.7977183964212472, 0.7673513094629704, 0.1679681724796478, 0.3107359484804584, 0.021701726051792103, 0.13797462786662518, 0.8618953518813538, 0.841333838365635,
|
|
||||||
0.5284957375170422, 0.9703367685039823, 0.677388096913733, 0.2624474979832243, 0.43740966353106536, 0.15685545957858893, 0.11072929134449871, 0.06007395961283357,
|
|
||||||
0.4093918718557811, 0.9563909195720572, 0.5994144944480242, 0.8278927844215804, 0.38586830957105667, 0.6201844716257464, 0.7603829079070265, 0.07875691596842949,
|
|
||||||
0.08651136699915507, 0.7445210640026082, 0.6547649514127559, 0.3384719042666908, 0.05816723105860,0.6248951423054205, 0.7431868493349041};
|
|
||||||
INDArray data = Nd4j.createFromArray(aData).reshape(11,5);
|
|
||||||
|
|
||||||
BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(250).setMaxIter(200).perplexity(3.0).theta(0.5).numDimension(5).
|
|
||||||
invertDistanceMetric(false).similarityFunction(Distance.EUCLIDEAN.toString())
|
|
||||||
.setMomentum(0.5).learningRate(200).staticInit(data).setSwitchMomentumIteration(250)
|
|
||||||
.useAdaGrad(false).build();
|
|
||||||
|
|
||||||
b.fit(data);
|
|
||||||
// log.info("Result: {}", b.getData());
|
|
||||||
|
|
||||||
val exp = Nd4j.createFromArray(new double[]{-3.5318212819287327, 35.40331834897696, 3.890809489531651, -1.291195609955519, -42.854099388207466, 7.8761368019456635, 28.798057251442877, 7.1456564000935225, 2.9518396278984786, -42.860181054199636, -34.989343304202, -108.99770355680282, 31.78123839126566, -29.322118879730205, 163.87558311206212, 2.9538984612478396, 31.419519824305546, 13.105400907817279, 25.46987139120746, -43.27317406736858, 32.455151773056144, 25.28067703547214, 0.005442008567682552, 21.005029233370358, -61.71390311950051, 5.218417653362599, 47.15762099517554, 8.834739256343404, 17.845790108867153, -54.31654219224107, -18.71285871476804, -16.446982180909007, -71.22568781913213, -12.339975548387091, 70.49096598213703, 25.022454385237456, -14.572652938207126, -5.320080866729078, 1.5874449933639676, -40.60960510287835, -31.98564381157643, -95.40875746933808, 19.196346639002364, -38.80930682421929, 135.00454225923906, 5.277879540549592, 30.79963767087089, -0.007276462027131683, 31.278796123365815, -38.47381680049993, 10.415728497075905, 36.567265019013085, -7.406587944733211, -18.376174615781114, -45.26976962854271}).reshape(-1, 5);
|
|
||||||
|
|
||||||
double eps = 1e-2;
|
|
||||||
if("CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){
|
|
||||||
eps = 2e-2;
|
|
||||||
}
|
|
||||||
|
|
||||||
assertArrayEquals(exp.data().asDouble(), b.getData().data().asDouble(), eps);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test(timeout = 300000)
|
|
||||||
public void testTsne() throws Exception {
|
|
||||||
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
|
||||||
Nd4j.getRandom().setSeed(123);
|
|
||||||
BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(10).setMaxIter(10).theta(0.5).learningRate(500)
|
|
||||||
.useAdaGrad(false).build();
|
|
||||||
|
|
||||||
File f = Resources.asFile("/deeplearning4j-core/mnist2500_X.txt");
|
|
||||||
INDArray data = Nd4j.readNumpy(f.getAbsolutePath(), " ").get(NDArrayIndex.interval(0, 100),
|
|
||||||
NDArrayIndex.interval(0, 784));
|
|
||||||
|
|
||||||
ClassPathResource labels = new ClassPathResource("mnist2500_labels.txt");
|
|
||||||
List<String> labelsList = IOUtils.readLines(labels.getInputStream()).subList(0, 100);
|
|
||||||
b.fit(data);
|
|
||||||
File outDir = testDir.newFolder();
|
|
||||||
b.saveAsFile(labelsList, new File(outDir, "out.txt").getAbsolutePath());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testBuilderFields() throws Exception {
|
|
||||||
final double theta = 0;
|
|
||||||
final boolean invert = false;
|
|
||||||
final String similarityFunctions = "euclidean";
|
|
||||||
final int maxIter = 1;
|
|
||||||
final double realMin = 1.0;
|
|
||||||
final double initialMomentum = 2.0;
|
|
||||||
final double finalMomentum = 3.0;
|
|
||||||
final double momentum = 4.0;
|
|
||||||
final int switchMomentumIteration = 1;
|
|
||||||
final boolean normalize = false;
|
|
||||||
final int stopLyingIteration = 100;
|
|
||||||
final double tolerance = 1e-1;
|
|
||||||
final double learningRate = 100;
|
|
||||||
final boolean useAdaGrad = false;
|
|
||||||
final double perplexity = 1.0;
|
|
||||||
final double minGain = 1.0;
|
|
||||||
|
|
||||||
BarnesHutTsne b = new BarnesHutTsne.Builder().theta(theta).invertDistanceMetric(invert)
|
|
||||||
.similarityFunction(similarityFunctions).setMaxIter(maxIter).setRealMin(realMin)
|
|
||||||
.setInitialMomentum(initialMomentum).setFinalMomentum(finalMomentum).setMomentum(momentum)
|
|
||||||
.setSwitchMomentumIteration(switchMomentumIteration).normalize(normalize)
|
|
||||||
.stopLyingIteration(stopLyingIteration).tolerance(tolerance).learningRate(learningRate)
|
|
||||||
.perplexity(perplexity).minGain(minGain).build();
|
|
||||||
|
|
||||||
final double DELTA = 1e-15;
|
|
||||||
|
|
||||||
assertEquals(theta, b.getTheta(), DELTA);
|
|
||||||
assertEquals("invert", invert, b.isInvert());
|
|
||||||
assertEquals("similarityFunctions", similarityFunctions, b.getSimiarlityFunction());
|
|
||||||
assertEquals("maxIter", maxIter, b.maxIter);
|
|
||||||
assertEquals(realMin, b.realMin, DELTA);
|
|
||||||
assertEquals(initialMomentum, b.initialMomentum, DELTA);
|
|
||||||
assertEquals(finalMomentum, b.finalMomentum, DELTA);
|
|
||||||
assertEquals(momentum, b.momentum, DELTA);
|
|
||||||
assertEquals("switchMomentumnIteration", switchMomentumIteration, b.switchMomentumIteration);
|
|
||||||
assertEquals("normalize", normalize, b.normalize);
|
|
||||||
assertEquals("stopLyingInMemoryLookupTable.javaIteration", stopLyingIteration, b.stopLyingIteration);
|
|
||||||
assertEquals(tolerance, b.tolerance, DELTA);
|
|
||||||
assertEquals(learningRate, b.learningRate, DELTA);
|
|
||||||
assertEquals("useAdaGrad", useAdaGrad, b.useAdaGrad);
|
|
||||||
assertEquals(perplexity, b.getPerplexity(), DELTA);
|
|
||||||
assertEquals(minGain, b.minGain, DELTA);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testPerplexity() throws Exception {
|
|
||||||
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
|
||||||
Nd4j.getRandom().setSeed(123);
|
|
||||||
BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(10).setMaxIter(10).theta(0.5).learningRate(500)
|
|
||||||
.useAdaGrad(false).build();
|
|
||||||
|
|
||||||
DataSetIterator iter = new MnistDataSetIterator(100, true, 12345);
|
|
||||||
INDArray data = iter.next().getFeatures();
|
|
||||||
|
|
||||||
INDArray perplexityOutput = b.computeGaussianPerplexity(data, 30.0);
|
|
||||||
// System.out.println(perplexityOutput);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testReproducibility() {
|
|
||||||
Nd4j.getRandom().setSeed(10);
|
|
||||||
INDArray input = Nd4j.createFromArray(new double[]{ 0.4681, 0.2971,
|
|
||||||
0.2938, 0.3655,
|
|
||||||
0.3968, 0.0990,
|
|
||||||
0.0796, 0.9245}).reshape(4,2);
|
|
||||||
|
|
||||||
BarnesHutTsne b1 = new BarnesHutTsne.Builder().perplexity(1.0).build(),
|
|
||||||
b2 = new BarnesHutTsne.Builder().perplexity(1.0).build();
|
|
||||||
b1.setSimiarlityFunction(Distance.EUCLIDEAN.toString());
|
|
||||||
b2.setSimiarlityFunction(Distance.EUCLIDEAN.toString());
|
|
||||||
|
|
||||||
b1.fit(input);
|
|
||||||
INDArray ret1 = b1.getData();
|
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(10);
|
|
||||||
b2.fit(input);
|
|
||||||
INDArray ret2 = b2.getData();
|
|
||||||
assertEquals(ret1, ret2);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Ignore
|
|
||||||
@Test
|
|
||||||
public void testCorrectness() throws IOException {
|
|
||||||
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
|
||||||
Nd4j.getRandom().setSeed(123);
|
|
||||||
BarnesHutTsne b = new BarnesHutTsne.Builder().perplexity(20.0).numDimension(2).learningRate(200).setMaxIter(50)
|
|
||||||
.useAdaGrad(false).build();
|
|
||||||
|
|
||||||
ClassPathResource resource = new ClassPathResource("/mnist2500_X.txt");
|
|
||||||
File f = resource.getTempFileFromArchive();
|
|
||||||
INDArray data = Nd4j.readNumpy(f.getAbsolutePath(), " ");
|
|
||||||
StopWatch watch = new StopWatch();
|
|
||||||
watch.start();
|
|
||||||
b.fit(data);
|
|
||||||
// System.out.println(b.getData());
|
|
||||||
watch.stop();
|
|
||||||
File outDir = testDir.newFolder();
|
|
||||||
ClassPathResource labels = new ClassPathResource("mnist2500_labels.txt");
|
|
||||||
List<String> labelsList = IOUtils.readLines(labels.getInputStream());
|
|
||||||
b.saveAsFile(/*labelsList,*/ new File(outDir, "raw.txt").getAbsolutePath());
|
|
||||||
// System.out.println(b.getData());
|
|
||||||
|
|
||||||
System.out.println("Fit done in " + watch);
|
|
||||||
assertEquals(2500, b.getData().size(0));
|
|
||||||
// System.out.println(b.getData());
|
|
||||||
|
|
||||||
INDArray a1 = b.getData().getRow(0);
|
|
||||||
INDArray a2 = b.getData().getRow(1);
|
|
||||||
INDArray a3 = b.getData().getRow(1000);
|
|
||||||
INDArray a4 = b.getData().getRow(2498);
|
|
||||||
INDArray a5 = b.getData().getRow(2499);
|
|
||||||
|
|
||||||
INDArray expectedRow0 = Nd4j.createFromArray(new double[]{ 167.8292, 32.5092, 75.6999, -27.1170, 17.6490, 107.4103, 46.2925, 0.4640, -30.7644, -5.6178, 18.9462, 0.0773, 16.9440, 82.9042, 82.0447, 57.1004, -65.7106, 21.9009, 31.2762, -46.9130, -79.2331, -47.1991, -84.3263, 53.6706, 90.2068, -35.2406, -39.4955, -34.6930, -27.5715, -4.8603, -126.0396, -58.8744, -101.5482, -0.2450, -12.1293, 74.7684, 69.9875, -42.2529, -23.4274, 24.8436, 1.4931, 3.3617, -85.8046, 31.6360, 29.9752, -118.0233, 65.4318, -16.9101, 65.3177, -37.1838, 21.2493, 32.0591, 2.8582, -62.2490, -61.2909});
|
|
||||||
INDArray expectedRow1 = Nd4j.createFromArray(new double[]{ 32.3478, 118.7499, -5.2345, 18.1522, -5.7661, 55.0841, 19.1792, 0.6082, 18.7637, 145.1893, 56.9232, 95.6905, 0.6450, 54.9728, -47.6037, 18.9907, 44.9000, 62.0607, 11.3163, 12.5538, 71.6602, 62.7464, 26.8367, 9.9804, 21.2930, 26.7346, -25.4178, 0.8815, 127.8388, 95.7059, 61.8721, 198.7351, 3.7012, 38.8855, 56.8623, -1.9203, -21.2366, 26.3412, -15.0002, -5.5686, -70.1437, -75.2662, 5.2471, 32.7884, 9.0304, 25.5222, 52.0305, -25.6134, 48.3513, 24.0128, -15.4485, -139.3574, 7.2340, 82.3224, 12.1519});
|
|
||||||
INDArray expectedRow1000 = Nd4j.createFromArray(new double[]{ 30.8645, -15.0904, -8.3493, 3.7487, -24.4678, 8.1096, 42.3257, 15.6477, -45.1260, 31.5830, 40.2178, -28.7947, -83.6021, -4.2135, -9.8731, 0.3819, -5.6642, -34.0559, -67.8494, -33.4919, -0.6254, 6.2422, -56.9254, -16.5402, 52.7575, -72.3746, 18.7587, -47.5842, 12.8834, -20.3063, 21.7613, -59.9718, 9.4924, 49.3242, -36.5622, -83.7369, 24.9921, 20.6678, 0.0452, -69.3666, 13.2417, -63.0318, 8.8107, -34.4605, -7.9497, -12.0326, 27.4876, -5.1647, 0.4363, -24.6792, -7.2241, 47.9472, 16.9052, -8.1184, -35.9715});
|
|
||||||
INDArray expectedRow2498 = Nd4j.createFromArray(new double[]{ -0.0919, -153.8959, -51.5028, -73.8650, -0.1183, -14.4633, -13.5049, 43.3787, 80.7100, 3.4296, 16.9782, -75.3470, 103.3307, 13.8846, -6.9218, 96.0892, 6.9730, -2.1582, -24.3647, 39.9077, -10.5426, -135.5623, -3.5470, 27.1481, -24.0933, -47.3872, 4.5534, -118.1384, -100.2693, -64.9634, -85.7244, 64.6426, -48.8833, -31.1378, -93.3141, 37.8991, 8.5912, -58.7564, 93.5057, 43.7609, -34.8800, -26.4699, -37.5039, 10.8743, 22.7238, -46.8137, 22.4390, -12.9343, 32.6593, -11.9136, -123.9708, -5.3310, -65.2792, -72.1379, 36.7171});
|
|
||||||
INDArray expectedRow2499 = Nd4j.createFromArray(new double[]{ -48.1854, 54.6014, 61.4287, 7.2306, 67.0068, 97.8297, 79.4408, 40.5714, -18.2712, -0.4891, 36.9610, 70.8634, 109.1919, -28.6810, 13.5949, -4.6143, 11.4054, -95.5810, 20.6512, 77.8442, 33.2472, 53.7065, 4.3208, -85.9796, 38.1717, -9.6965, 44.0203, 1.0427, -17.6281, -54.7104, -88.1742, -24.6297, 33.5158, -10.4808, 16.7051, 21.7057, 42.1260, 61.4450, -9.4028, -68.3737, 18.8957, 45.0714, 14.3170, 84.0521, 80.0860, -15.4343, -73.6115, -15.5358, -41.5067, -55.7111, 0.1811, -75.5584, 16.4112, -128.0799, 119.3907});
|
|
||||||
|
|
||||||
assertArrayEquals(expectedRow0.toDoubleVector(), b.getData().getRow(0).toDoubleVector(), 1e-4);
|
|
||||||
assertArrayEquals(expectedRow1.toDoubleVector(), b.getData().getRow(1).toDoubleVector(), 1e-4);
|
|
||||||
assertArrayEquals(expectedRow1000.toDoubleVector(), b.getData().getRow(1000).toDoubleVector(), 1e-4);
|
|
||||||
assertArrayEquals(expectedRow2498.toDoubleVector(), b.getData().getRow(2498).toDoubleVector(), 1e-4);
|
|
||||||
assertArrayEquals(expectedRow2499.toDoubleVector(), b.getData().getRow(2499).toDoubleVector(), 1e-4);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testCorrectness1() {
|
|
||||||
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
|
||||||
Nd4j.getRandom().setSeed(123);
|
|
||||||
|
|
||||||
double[] aData = new double[]{
|
|
||||||
0.2999816948164936, 0.26252049735806526, 0.2673853427498767, 0.8604464129156685, 0.4802652829902563, 0.10959096539488711, 0.7950242948008909, 0.5917848948003486,
|
|
||||||
0.2738285999345498, 0.9519684328285567, 0.9690024759209738, 0.8585615547624705, 0.8087760944312002, 0.5337951589543348, 0.5960876109129123, 0.7187130179825856,
|
|
||||||
0.4629777327445964, 0.08665909175584818, 0.7748005397731237, 0.48020186965468536, 0.24927351841378798, 0.32272599988270445, 0.306414968984427, 0.6980212149215657,
|
|
||||||
0.7977183964212472, 0.7673513094629704, 0.1679681724796478, 0.3107359484804584, 0.021701726051792103, 0.13797462786662518, 0.8618953518813538, 0.841333838365635,
|
|
||||||
0.5284957375170422, 0.9703367685039823, 0.677388096913733, 0.2624474979832243, 0.43740966353106536, 0.15685545957858893, 0.11072929134449871, 0.06007395961283357,
|
|
||||||
0.4093918718557811, 0.9563909195720572, 0.5994144944480242, 0.8278927844215804, 0.38586830957105667, 0.6201844716257464, 0.7603829079070265, 0.07875691596842949,
|
|
||||||
0.08651136699915507, 0.7445210640026082, 0.6547649514127559, 0.3384719042666908, 0.05816723105860,0.6248951423054205, 0.7431868493349041};
|
|
||||||
INDArray data = Nd4j.createFromArray(aData).reshape(11,5);
|
|
||||||
|
|
||||||
BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(250).setMaxIter(20).perplexity(3.0).theta(0.5).numDimension(5).
|
|
||||||
invertDistanceMetric(false).similarityFunction(Distance.EUCLIDEAN.toString())
|
|
||||||
.setMomentum(0.5).learningRate(200).staticInit(data).setSwitchMomentumIteration(250)
|
|
||||||
.useAdaGrad(false).build();
|
|
||||||
|
|
||||||
b.fit(data);
|
|
||||||
|
|
||||||
double[] expectedData = new double[]{ 63.8206, 80.4013, -19.4424, -140.4326, 198.7239,
|
|
||||||
106.1148, -96.6273, -124.3634, 78.4174, -83.6621,
|
|
||||||
-121.8706, 3.0888, -172.8560, 255.1262, 20.7021,
|
|
||||||
-120.7942, -78.1829, 56.6021, -112.3294, 185.4084,
|
|
||||||
88.5330, 78.0497, -18.8673, -11.0155, -175.1564,
|
|
||||||
-297.8463, 174.2511, -103.8793, 72.5455, -15.8498,
|
|
||||||
-134.5235, 42.3300, 154.0391, -280.1010, -167.9765,
|
|
||||||
306.9938, -150.9666, 83.4419, -36.0877, 83.9992,
|
|
||||||
245.1813, -81.5018, -14.8430, 16.1557, 166.8651,
|
|
||||||
-65.9247, -138.1783, 72.5444, 176.3088, -25.6732,
|
|
||||||
-69.6843, 167.3360, 87.6238, -18.5874, -187.3806};
|
|
||||||
|
|
||||||
INDArray expectedArray = Nd4j.createFromArray(expectedData).reshape(11,5);
|
|
||||||
for (int i = 0; i < expectedArray.rows(); ++i)
|
|
||||||
assertArrayEquals(expectedArray.getRow(i).toDoubleVector(), b.getData().getRow(i).toDoubleVector(), 1e-2);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testComputePerplexity() {
|
|
||||||
double[] input = new double[]{0.2999816948164936, 0.26252049735806526, 0.2673853427498767, 0.8604464129156685, 0.4802652829902563, 0.10959096539488711, 0.7950242948008909, 0.5917848948003486,
|
|
||||||
0.2738285999345498, 0.9519684328285567, 0.9690024759209738, 0.8585615547624705, 0.8087760944312002, 0.5337951589543348, 0.5960876109129123, 0.7187130179825856,
|
|
||||||
0.4629777327445964, 0.08665909175584818, 0.7748005397731237, 0.48020186965468536, 0.24927351841378798, 0.32272599988270445, 0.306414968984427, 0.6980212149215657,
|
|
||||||
0.7977183964212472, 0.7673513094629704, 0.1679681724796478, 0.3107359484804584, 0.021701726051792103, 0.13797462786662518, 0.8618953518813538, 0.841333838365635,
|
|
||||||
0.5284957375170422, 0.9703367685039823, 0.677388096913733, 0.2624474979832243, 0.43740966353106536, 0.15685545957858893, 0.11072929134449871, 0.06007395961283357,
|
|
||||||
0.4093918718557811, 0.9563909195720572, 0.5994144944480242, 0.8278927844215804, 0.38586830957105667, 0.6201844716257464, 0.7603829079070265, 0.07875691596842949,
|
|
||||||
0.08651136699915507, 0.7445210640026082, 0.6547649514127559, 0.3384719042666908, 0.05816723105860, 0.6248951423054205, 0.7431868493349041};
|
|
||||||
INDArray ndinput = Nd4j.createFromArray(input).reshape(11, 5);
|
|
||||||
BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(10).perplexity(3.0).similarityFunction(Distance.EUCLIDEAN.toString()).invertDistanceMetric(false).theta(0.5)
|
|
||||||
.useAdaGrad(false).build();
|
|
||||||
b.computeGaussianPerplexity(ndinput, 3.0);
|
|
||||||
INDArray expectedRows = Nd4j.createFromArray(new int[]{0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99});
|
|
||||||
INDArray expectedCols = Nd4j.createFromArray(new int[] {4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1});
|
|
||||||
INDArray expectedValues = Nd4j.createFromArray(new double[]{0.6199394088807811, 0.1964597878478939, 0.13826096288374987, 0.019500202354103796, 0.00892011933324624, 0.008390894278481041, 0.00333353509170543, 0.0026231979968002537, 0.0025718913332382506, 0.5877813741023542, 0.2824053513290301, 0.08100641562340703, 0.014863269403258283, 0.01219532549481422, 0.011522812905961816, 0.004243949243254114, 0.0034625890823446427, 0.002518912815575669, 0.6776991917357972, 0.18322100043035286, 0.040180871517768765, 0.02941481903928284, 0.021638322103495665, 0.019899251613183868, 0.011684443899339756, 0.008438621670147969, 0.007823477990631192, 0.6771051692354304, 0.16616561426152007, 0.06038657043891834, 0.04649900136463559, 0.01688479525099354, 0.014596215509122025, 0.006410339053808227, 0.006075759373243866, 0.005876535512328113, 0.6277958923349469, 0.23516301304728018, 0.07022275517450298, 0.030895020584550934, 0.012294459258033335, 0.009236709512467177, 0.00821667460222265, 0.0043013613064171955, 0.0018741141795786528, 0.7122763773574693, 0.07860063708191449, 0.07060648172121314, 0.06721282603559373, 0.028960026354739106, 0.017791245039439314, 0.01482510169996304, 0.005496178688168659, 0.004231126021499254, 0.5266697563046261, 0.33044733058681547, 0.10927281903651001, 0.018510201893239094, 0.006973656012751928, 0.006381768970069082, 0.0010596892780182746, 6.535010081417198E-4, 3.127690982824874E-5, 0.7176189632561156, 0.08740746743997298, 0.059268842313360166, 0.04664131589557433, 0.03288791302822797, 0.029929724912968133, 0.013368915822982491, 0.010616377319500762, 0.0022604800112974647, 0.689185362462809, 0.13977758696450715, 0.05439663822300743, 0.05434167873889952, 0.028687383013327405, 0.02099540802182275, 0.0072154477293594615, 0.0032822412915506907, 0.0021182535547164334, 0.6823844384306867, 0.13452128016104092, 0.08713547969428868, 0.04287399325857787, 0.025452813990877978, 0.016881841237860937, 0.0072200814416566415, 0.0019232561582331975, 0.0016068156267770154, 0.6425943207872832, 0.18472852256294967, 0.1089653923564887, 0.03467849453890959, 0.013282484305873534, 0.005149863792637524, 0.0037974408302766656, 0.003787710699822367, 0.003015770125758626});
|
|
||||||
assertArrayEquals(expectedCols.toIntVector(), b.getCols().toIntVector());
|
|
||||||
assertArrayEquals(expectedRows.toIntVector(), b.getRows().toIntVector());
|
|
||||||
assertArrayEquals(expectedValues.toDoubleVector(), b.getVals().toDoubleVector(), 1e-5);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testComputeGradient() {
|
|
||||||
double[] input = new double[]{0.3000, 0.2625, 0.2674, 0.8604, 0.4803,
|
|
||||||
0.1096, 0.7950, 0.5918, 0.2738, 0.9520,
|
|
||||||
0.9690, 0.8586, 0.8088, 0.5338, 0.5961,
|
|
||||||
0.7187, 0.4630, 0.0867, 0.7748, 0.4802,
|
|
||||||
0.2493, 0.3227, 0.3064, 0.6980, 0.7977,
|
|
||||||
0.7674, 0.1680, 0.3107, 0.0217, 0.1380,
|
|
||||||
0.8619, 0.8413, 0.5285, 0.9703, 0.6774,
|
|
||||||
0.2624, 0.4374, 0.1569, 0.1107, 0.0601,
|
|
||||||
0.4094, 0.9564, 0.5994, 0.8279, 0.3859,
|
|
||||||
0.6202, 0.7604, 0.0788, 0.0865, 0.7445,
|
|
||||||
0.6548, 0.3385, 0.0582, 0.6249, 0.7432};
|
|
||||||
INDArray ndinput = Nd4j.createFromArray(input).reshape(11, 5);
|
|
||||||
BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(10).perplexity(3.0).similarityFunction(Distance.EUCLIDEAN.toString()).invertDistanceMetric(false).theta(0.5)
|
|
||||||
.useAdaGrad(false).staticInit(ndinput).build();
|
|
||||||
b.setY(ndinput);
|
|
||||||
b.setN(11);
|
|
||||||
|
|
||||||
INDArray rowsP = Nd4j.createFromArray(new int[]{0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99});
|
|
||||||
INDArray colsP = Nd4j.createFromArray(new int[]{4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1});
|
|
||||||
INDArray valsP = Nd4j.createFromArray(new double[]{0.6200, 0.1964, 0.1382, 0.0195, 0.0089, 0.0084, 0.0033, 0.0026, 0.0026, 0.5877, 0.2825, 0.0810, 0.0149, 0.0122, 0.0115, 0.0042, 0.0035, 0.0025, 0.6777, 0.1832, 0.0402, 0.0294, 0.0216, 0.0199, 0.0117, 0.0084, 0.0078, 0.6771, 0.1662, 0.0604, 0.0465, 0.0169, 0.0146, 0.0064, 0.0061, 0.0059, 0.6278, 0.2351, 0.0702, 0.0309, 0.0123, 0.0092, 0.0082, 0.0043, 0.0019, 0.7123, 0.0786, 0.0706, 0.0672, 0.0290, 0.0178, 0.0148, 0.0055, 0.0042, 0.5267, 0.3304, 0.1093, 0.0185, 0.0070, 0.0064, 0.0011, 0.0007, 3.1246e-5, 0.7176, 0.0874, 0.0593, 0.0466, 0.0329, 0.0299, 0.0134, 0.0106, 0.0023, 0.6892, 0.1398, 0.0544, 0.0544, 0.0287, 0.0210, 0.0072, 0.0033, 0.0021, 0.6824, 0.1345, 0.0871, 0.0429, 0.0254, 0.0169, 0.0072, 0.0019, 0.0016, 0.6426, 0.1847, 0.1090, 0.0347, 0.0133, 0.0051, 0.0038, 0.0038, 0.0030});
|
|
||||||
|
|
||||||
b.setRows(rowsP);
|
|
||||||
b.setCols(colsP);
|
|
||||||
b.setVals(valsP);
|
|
||||||
Gradient gradient = b.gradient();
|
|
||||||
|
|
||||||
double[] dC = {-0.0618386320333619, -0.06266654959379839, 0.029998268806149204, 0.10780566335888186, -0.19449543068355346, -0.14763764361792697, 0.17493572758118422, 0.1926109839221966, -0.15176648259935419, 0.10974665709698186, 0.13102419155322598, 0.004941641352409449, 0.19159764518354974, -0.26332838053474944, -0.023631441261541583, 0.09838669432305949, 0.09709129638394683, -0.01605053000727605, 0.06566171635025217, -0.17325078066035252, -0.1090854255505605, 0.023350644966904276, 0.075192354899586, -0.08278373866517603, 0.18431338134579323, 0.2766031655578053, -0.17557907233268688, 0.10616148241800637, -0.09999024423215641, -0.017181932145255287, 0.06711331400576945, -0.01388231800826619, -0.10248189290485302, 0.20786521034824304, 0.11254913977572988, -0.289564646781519, 0.13491805919337516, -0.07504249344962562, 0.004154656287570634, -0.10516715438388784, -0.27984655075804576, 0.09811828071286613, 0.03684521473995052, -0.054645216532387256, -0.18147132772800725, 0.027588750493223044, 0.214734364419479, -0.026729138234415008, -0.28410504978879136, 0.007015481601883835, 0.04427981739424874, -0.059253265830134655, -0.05325479031206952, -0.11319889109674944, 0.1530133971867549};
|
|
||||||
INDArray actual = gradient.getGradientFor("yIncs");
|
|
||||||
// System.out.println(actual);
|
|
||||||
assertArrayEquals(dC, actual.reshape(1,55).toDoubleVector(), 1e-05);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testApplyGradient() {
|
|
||||||
double[] Y = new double[]{0.2999816948164936, 0.26252049735806526, 0.2673853427498767, 0.8604464129156685, 0.4802652829902563, 0.10959096539488711, 0.7950242948008909, 0.5917848948003486,
|
|
||||||
0.2738285999345498, 0.9519684328285567, 0.9690024759209738, 0.8585615547624705, 0.8087760944312002, 0.5337951589543348, 0.5960876109129123, 0.7187130179825856,
|
|
||||||
0.4629777327445964, 0.08665909175584818, 0.7748005397731237, 0.48020186965468536, 0.24927351841378798, 0.32272599988270445, 0.306414968984427, 0.6980212149215657,
|
|
||||||
0.7977183964212472, 0.7673513094629704, 0.1679681724796478, 0.3107359484804584, 0.021701726051792103, 0.13797462786662518, 0.8618953518813538, 0.841333838365635,
|
|
||||||
0.5284957375170422, 0.9703367685039823, 0.677388096913733, 0.2624474979832243, 0.43740966353106536, 0.15685545957858893, 0.11072929134449871, 0.06007395961283357,
|
|
||||||
0.4093918718557811, 0.9563909195720572, 0.5994144944480242, 0.8278927844215804, 0.38586830957105667, 0.6201844716257464, 0.7603829079070265, 0.07875691596842949,
|
|
||||||
0.08651136699915507, 0.7445210640026082, 0.6547649514127559, 0.3384719042666908, 0.05816723105860, 0.6248951423054205, 0.7431868493349041};
|
|
||||||
INDArray ndinput = Nd4j.createFromArray(Y).reshape(11,5);
|
|
||||||
|
|
||||||
double[] gradient = { -0.0635, -0.0791, 0.0228, 0.1360, -0.2016,
|
|
||||||
-0.1034, 0.0976, 0.1266, -0.0781, 0.0707,
|
|
||||||
0.1184, -0.0018, 0.1719, -0.2529, -0.0209,
|
|
||||||
0.1204, 0.0855, -0.0530, 0.1069, -0.1860,
|
|
||||||
-0.0890, -0.0763, 0.0181, 0.0048, 0.1798,
|
|
||||||
0.2917, -0.1699, 0.1038, -0.0736, 0.0159,
|
|
||||||
0.1324, -0.0409, -0.1502, 0.2738, 0.1668,
|
|
||||||
-0.3012, 0.1489, -0.0801, 0.0329, -0.0817,
|
|
||||||
-0.2405, 0.0810, 0.0171, -0.0201, -0.1638,
|
|
||||||
0.0656, 0.1383, -0.0707, -0.1757, 0.0144,
|
|
||||||
0.0708, -0.1725, -0.0870, 0.0160, 0.1921};
|
|
||||||
INDArray ndgrad = Nd4j.createFromArray(gradient).reshape(11, 5);
|
|
||||||
BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(10).perplexity(3.0).similarityFunction(Distance.EUCLIDEAN.toString())
|
|
||||||
.invertDistanceMetric(false).theta(0.5).learningRate(200)
|
|
||||||
.useAdaGrad(false).staticInit(ndinput).build();
|
|
||||||
b.setY(ndinput);
|
|
||||||
b.setN(11);
|
|
||||||
INDArray yIncs = Nd4j.zeros(DataType.DOUBLE, ndinput.shape());
|
|
||||||
b.setYIncs(yIncs);
|
|
||||||
INDArray gains = Nd4j.zeros(DataType.DOUBLE, ndinput.shape());
|
|
||||||
b.setGains(gains);
|
|
||||||
b.update(ndgrad, "yIncs");
|
|
||||||
|
|
||||||
double[] expected = {2.54, 3.164, -0.912, -5.44, 8.064, 4.136, -3.9040000000000004, -5.064, 3.124, -2.828, -4.736000000000001, 0.072, -6.8759999999999994, 10.116, 0.836, -4.816, -3.4200000000000004, 2.12, -4.276, 7.4399999999999995, 3.5599999999999996, 3.0520000000000005, -0.7240000000000001, -0.19199999999999998, -7.191999999999999, -11.668000000000001, 6.795999999999999, -4.152, 2.944, -0.636, -5.295999999999999, 1.636, 6.008, -10.952, -6.672000000000001, 12.048000000000002, -5.956, 3.204, -1.3159999999999998, 3.268, 9.62, -3.24, -0.684, 0.804, 6.552, -2.624, -5.532, 2.828, 7.028, -0.576, -2.832, 6.8999999999999995, 3.4799999999999995, -0.64, -7.683999999999999};
|
|
||||||
assertArrayEquals(expected, b.getYIncs().reshape(55).toDoubleVector(), 1e-5);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testComputeEdgeForces() {
|
|
||||||
double[] input = new double[]{0.3000, 0.2625, 0.2674, 0.8604, 0.4803,
|
|
||||||
0.1096, 0.7950, 0.5918, 0.2738, 0.9520,
|
|
||||||
0.9690, 0.8586, 0.8088, 0.5338, 0.5961,
|
|
||||||
0.7187, 0.4630, 0.0867, 0.7748, 0.4802,
|
|
||||||
0.2493, 0.3227, 0.3064, 0.6980, 0.7977,
|
|
||||||
0.7674, 0.1680, 0.3107, 0.0217, 0.1380,
|
|
||||||
0.8619, 0.8413, 0.5285, 0.9703, 0.6774,
|
|
||||||
0.2624, 0.4374, 0.1569, 0.1107, 0.0601,
|
|
||||||
0.4094, 0.9564, 0.5994, 0.8279, 0.3859,
|
|
||||||
0.6202, 0.7604, 0.0788, 0.0865, 0.7445,
|
|
||||||
0.6548, 0.3385, 0.0582, 0.6249, 0.7432};
|
|
||||||
INDArray ndinput = Nd4j.createFromArray(input).reshape(11, 5);
|
|
||||||
SpTree tree = new SpTree(ndinput);
|
|
||||||
INDArray rows = Nd4j.createFromArray(new int[]{0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99});
|
|
||||||
INDArray cols = Nd4j.createFromArray(new int[]{4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1});
|
|
||||||
INDArray vals = Nd4j.createFromArray(new double[]{0.6200, 0.1964, 0.1382, 0.0195, 0.0089, 0.0084, 0.0033, 0.0026, 0.0026, 0.5877, 0.2825, 0.0810, 0.0149, 0.0122, 0.0115, 0.0042, 0.0035, 0.0025, 0.6777, 0.1832, 0.0402, 0.0294, 0.0216, 0.0199, 0.0117, 0.0084, 0.0078, 0.6771, 0.1662, 0.0604, 0.0465, 0.0169, 0.0146, 0.0064, 0.0061, 0.0059, 0.6278, 0.2351, 0.0702, 0.0309, 0.0123, 0.0092, 0.0082, 0.0043, 0.0019, 0.7123, 0.0786, 0.0706, 0.0672, 0.0290, 0.0178, 0.0148, 0.0055, 0.0042, 0.5267, 0.3304, 0.1093, 0.0185, 0.0070, 0.0064, 0.0011, 0.0007, 3.1246e-5, 0.7176, 0.0874, 0.0593, 0.0466, 0.0329, 0.0299, 0.0134, 0.0106, 0.0023, 0.6892, 0.1398, 0.0544, 0.0544, 0.0287, 0.0210, 0.0072, 0.0033, 0.0021, 0.6824, 0.1345, 0.0871, 0.0429, 0.0254, 0.0169, 0.0072, 0.0019, 0.0016, 0.6426, 0.1847, 0.1090, 0.0347, 0.0133, 0.0051, 0.0038, 0.0038, 0.0030});
|
|
||||||
int N = 11;
|
|
||||||
INDArray posF = Nd4j.create(ndinput.shape());
|
|
||||||
tree.computeEdgeForces(rows, cols, vals, N, posF);
|
|
||||||
double[] expectedPosF = {-0.08017022778816381, -0.08584612446002386, 0.024041740837932417, 0.13353853518214748, -0.19989209255196486, -0.17059164865362167, 0.18730152809351328, 0.20582835656173232, -0.1652505189678666, 0.13123839113710167, 0.15511476126066306, 0.021425546153174206, 0.21755440369356663, -0.2628756936897519, -0.021079609911707077, 0.11455959658671841, 0.08803186126822704, -0.039212116057989604, 0.08800854045636688, -0.1795568260613919, -0.13265313037184673, 0.0036829788349159154, 0.07205631770917967, -0.06873974602987808, 0.20446419876515043, 0.28724205607738795, -0.19397780156808536, 0.10457369548573531, -0.12340830629973816, -0.03634773269456816, 0.0867775929922852, 0.0029761730963277894, -0.09131897988004745, 0.2348924028566898, 0.12026408931908775, -0.30400848137321873, 0.1282943410872978, -0.08487864823843354, -0.017561758195375168, -0.13082811573092396, -0.2885857462722986, 0.12469730654026252, 0.05408469871148934, -0.03417740859260864, -0.19261929748672968, 0.03318694717819495, 0.22818123908045765, -0.044944593551341956, -0.3141734963080852, 0.020297428845239652, 0.05442118949793863, -0.07890301602838638, -0.07823705950336371, -0.10455483898962027, 0.16980714813230746};
|
|
||||||
INDArray indExpectedPositive = Nd4j.createFromArray(expectedPosF).reshape(11, 5);
|
|
||||||
assertEquals(indExpectedPositive, posF);
|
|
||||||
|
|
||||||
AtomicDouble sumQ = new AtomicDouble(0.0);
|
|
||||||
double theta = 0.5;
|
|
||||||
INDArray negF = Nd4j.create(ndinput.shape());
|
|
||||||
|
|
||||||
double[][] neg = {{-1.6243229118532043, -2.0538918185758117, -0.5277950148630416, 2.280133920112387, -0.4781864949257863},
|
|
||||||
{-2.033904565482581, 1.0957067439325718, 1.1711627018218371, -1.1947911960637323, 1.904335906364157},
|
|
||||||
{2.134613094178481, 1.4606030267537151, 2.299972033488509, 0.040111598796927175, 0.22611223726312565},
|
|
||||||
{1.4330457669590706, -0.8027368824700638, -2.052297868677289, 1.9801035811739054, -0.5587649959721402},
|
|
||||||
{-2.088283171473531, -1.7427092080895168, -0.27787744880128185, 1.2444077055013942, 1.7855201950031347},
|
|
||||||
{0.9426889976629138, -1.6302714638583877, -0.14069035384185855, -2.075023651861262, -1.698239988087389},
|
|
||||||
{1.7424090804808496, 1.493794306111751, 0.989121494481274, 2.394820866756112, 0.6836049340540907},
|
|
||||||
{-1.279836833417519, -0.5869132848699253, -0.871560326864079, -1.9242443527432451, -2.273762088892443},
|
|
||||||
{-0.7743611464510498, 2.3551097898757134, 1.527553257122278, 1.813608037002701, -0.9877974041073948},
|
|
||||||
{0.49604405759812625, 1.1914983778171337, -1.6140319597311803, -2.6642997837396654, 1.1768845173097966},
|
|
||||||
{0.8986049706740562, -1.7411217160869163, -2.213624650045752, 0.7659306956507013, 1.4880578211349607}};
|
|
||||||
|
|
||||||
double expectedSumQ = 88.60782954084712;
|
|
||||||
|
|
||||||
for (int n = 0; n < N; n++) {
|
|
||||||
tree.computeNonEdgeForces(n, theta, negF.slice(n), sumQ);
|
|
||||||
assertArrayEquals(neg[n], negF.slice(n).toDoubleVector(), 1e-05);
|
|
||||||
}
|
|
||||||
assertEquals(expectedSumQ, sumQ.get(), 1e-05);
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
@Test
|
|
||||||
public void testSymmetrized() {
|
|
||||||
BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(10).perplexity(3.0).similarityFunction(Distance.EUCLIDEAN.toString()).invertDistanceMetric(false).theta(0.5)
|
|
||||||
.useAdaGrad(false).build();
|
|
||||||
INDArray expectedSymmetrized = Nd4j.createFromArray(new double[]{0.6239, 0.1813, 0.12359999999999999, 0.03695, 0.00795, 0.03385, 0.0074, 0.0158, 0.0013, 0.0042, 0.0074, 0.3093, 0.2085, 0.051000000000000004, 0.00895, 0.016050000000000002, 0.00245, 0.00705, 0.00125, 0.0021, 0.016050000000000002, 0.6022, 0.1615, 0.0233, 0.0183, 0.0108, 0.0068000000000000005, 0.0042, 0.011300000000000001, 0.00115, 0.1813, 0.00125, 0.0233, 0.65985, 0.0653, 0.0779, 0.03565, 0.05085, 0.038349999999999995, 0.026250000000000002, 0.6239, 0.3093, 0.0068000000000000005, 0.0653, 0.2099, 0.0205, 0.0173, 0.007300000000000001, 0.0171, 0.0089, 0.0158, 0.011300000000000001, 0.038349999999999995, 0.71495, 0.04775, 0.03615, 0.0089, 0.00275, 0.0021, 1.5623E-5, 0.00795, 0.00245, 0.6022, 0.0779, 0.007300000000000001, 0.5098, 0.015899999999999997, 0.00135, 1.5623E-5, 0.03385, 0.00705, 0.026250000000000002, 0.0171, 0.71495, 0.06515, 0.018349999999999998, 0.00775, 0.00115, 0.03695, 0.051000000000000004, 0.1615, 0.03565, 0.0205, 0.00275, 0.5098, 0.00775, 0.0055, 0.0026, 0.0013, 0.2085, 0.0183, 0.05085, 0.0173, 0.04775, 0.00135, 0.06515, 0.0026, 0.35855, 0.12359999999999999, 0.00895, 0.0108, 0.65985, 0.2099, 0.03615, 0.015899999999999997, 0.018349999999999998, 0.0055, 0.35855});
|
|
||||||
INDArray rowsP = Nd4j.createFromArray(new int[]{0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99});
|
|
||||||
INDArray colsP = Nd4j.createFromArray(new int[]{4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1});
|
|
||||||
INDArray valsP = Nd4j.createFromArray(new double[]{0.6200, 0.1964, 0.1382, 0.0195, 0.0089, 0.0084, 0.0033, 0.0026, 0.0026, 0.5877, 0.2825, 0.0810, 0.0149, 0.0122, 0.0115, 0.0042, 0.0035, 0.0025, 0.6777, 0.1832, 0.0402, 0.0294, 0.0216, 0.0199, 0.0117, 0.0084, 0.0078, 0.6771, 0.1662, 0.0604, 0.0465, 0.0169, 0.0146, 0.0064, 0.0061, 0.0059, 0.6278, 0.2351, 0.0702, 0.0309, 0.0123, 0.0092, 0.0082, 0.0043, 0.0019, 0.7123, 0.0786, 0.0706, 0.0672, 0.0290, 0.0178, 0.0148, 0.0055, 0.0042, 0.5267, 0.3304, 0.1093, 0.0185, 0.0070, 0.0064, 0.0011, 0.0007, 3.1246e-5, 0.7176, 0.0874, 0.0593, 0.0466, 0.0329, 0.0299, 0.0134, 0.0106, 0.0023, 0.6892, 0.1398, 0.0544, 0.0544, 0.0287, 0.0210, 0.0072, 0.0033, 0.0021, 0.6824, 0.1345, 0.0871, 0.0429, 0.0254, 0.0169, 0.0072, 0.0019, 0.0016, 0.6426, 0.1847, 0.1090, 0.0347, 0.0133, 0.0051, 0.0038, 0.0038, 0.0030});
|
|
||||||
b.setN(11);
|
|
||||||
BarnesHutTsne.SymResult actualSymmetrized = b.symmetrized(rowsP, colsP, valsP);
|
|
||||||
System.out.println("Symmetrized from Java:" + actualSymmetrized);
|
|
||||||
System.out.println(actualSymmetrized.rows);
|
|
||||||
System.out.println(actualSymmetrized.cols);
|
|
||||||
assertArrayEquals(expectedSymmetrized.toDoubleVector(), actualSymmetrized.vals.toDoubleVector(), 1e-5);
|
|
||||||
|
|
||||||
|
|
||||||
INDArray rowsFromCpp = Nd4j.create(new int[]{rowsP.rows(),rowsP.columns()}, DataType.INT);
|
|
||||||
BarnesHutSymmetrize op = new BarnesHutSymmetrize(rowsP, colsP, valsP, 11, rowsFromCpp);
|
|
||||||
Nd4j.getExecutioner().exec(op);
|
|
||||||
INDArray valsFromCpp = op.getSymmetrizedValues();
|
|
||||||
INDArray colsFromCpp = op.getSymmetrizedCols();
|
|
||||||
System.out.println("Symmetrized from C++: " + valsP);
|
|
||||||
assertArrayEquals(expectedSymmetrized.toDoubleVector(), valsFromCpp.toDoubleVector(), 1e-5);
|
|
||||||
|
|
||||||
int[] expectedRows = new int[]{0, 10, 20, 30, 40, 50, 60, 69, 78, 88, 98, 108};
|
|
||||||
int[] expectedCols = new int[]{4, 3, 10, 8, 6, 7, 1, 5, 9, 2, 0, 4, 9, 8, 10, 2, 6, 7, 3, 5, 1, 6, 8, 3, 9, 10, 4, 0, 5, 7, 0, 1, 2, 10, 4, 6, 8, 9, 5, 7, 0, 1, 2, 3, 10, 8, 9, 6, 7, 5, 0, 2, 3, 7, 9, 10, 4, 8, 1, 6, 0, 1, 2, 3, 4, 8, 10, 9, 5, 0, 1, 3, 4, 5, 9, 10, 8, 2, 0, 1, 2, 3, 4, 5, 6, 7, 10, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
|
|
||||||
|
|
||||||
assertArrayEquals(expectedRows, rowsFromCpp.toIntVector());
|
|
||||||
assertArrayEquals(expectedCols, colsFromCpp.toIntVector());
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testVPTree() {
|
|
||||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
|
||||||
double[] d = new double[]{0.3000, 0.2625, 0.2674, 0.8604, 0.4803,
|
|
||||||
0.1096, 0.7950, 0.5918, 0.2738, 0.9520,
|
|
||||||
0.9690, 0.8586, 0.8088, 0.5338, 0.5961,
|
|
||||||
0.7187, 0.4630, 0.0867, 0.7748, 0.4802,
|
|
||||||
0.2493, 0.3227, 0.3064, 0.6980, 0.7977,
|
|
||||||
0.7674, 0.1680, 0.3107, 0.0217, 0.1380,
|
|
||||||
0.8619, 0.8413, 0.5285, 0.9703, 0.6774,
|
|
||||||
0.2624, 0.4374, 0.1569, 0.1107, 0.0601,
|
|
||||||
0.4094, 0.9564, 0.5994, 0.8279, 0.3859,
|
|
||||||
0.6202, 0.7604, 0.0788, 0.0865, 0.7445,
|
|
||||||
0.6548, 0.3385, 0.0582, 0.6249, 0.7432};
|
|
||||||
VPTree tree = new VPTree(Nd4j.createFromArray(d).reshape(11, 5), "euclidean", 1, false);
|
|
||||||
INDArray target = Nd4j.createFromArray(new double[]{0.3000, 0.2625, 0.2674, 0.8604, 0.4803});
|
|
||||||
List<DataPoint> results = new ArrayList<>();
|
|
||||||
List<Double> distances = new ArrayList<>();
|
|
||||||
tree.search(target, 11, results, distances);
|
|
||||||
// System.out.println("Results:" + results);
|
|
||||||
// System.out.println("Distances:" + distances);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testSpTree() {
|
|
||||||
double[] input = new double[]{0.2999816948164936, 0.26252049735806526, 0.2673853427498767, 0.8604464129156685, 0.4802652829902563, 0.10959096539488711, 0.7950242948008909, 0.5917848948003486,
|
|
||||||
0.2738285999345498, 0.9519684328285567, 0.9690024759209738, 0.8585615547624705, 0.8087760944312002, 0.5337951589543348, 0.5960876109129123, 0.7187130179825856,
|
|
||||||
0.4629777327445964, 0.08665909175584818, 0.7748005397731237, 0.48020186965468536, 0.24927351841378798, 0.32272599988270445, 0.306414968984427, 0.6980212149215657,
|
|
||||||
0.7977183964212472, 0.7673513094629704, 0.1679681724796478, 0.3107359484804584, 0.021701726051792103, 0.13797462786662518, 0.8618953518813538, 0.841333838365635,
|
|
||||||
0.5284957375170422, 0.9703367685039823, 0.677388096913733, 0.2624474979832243, 0.43740966353106536, 0.15685545957858893, 0.11072929134449871, 0.06007395961283357,
|
|
||||||
0.4093918718557811, 0.9563909195720572, 0.5994144944480242, 0.8278927844215804, 0.38586830957105667, 0.6201844716257464, 0.7603829079070265, 0.07875691596842949,
|
|
||||||
0.08651136699915507, 0.7445210640026082, 0.6547649514127559, 0.3384719042666908, 0.05816723105860, 0.6248951423054205, 0.7431868493349041};
|
|
||||||
INDArray ndinput = Nd4j.createFromArray(input).reshape(11, 5);
|
|
||||||
|
|
||||||
int[] rows = {0, 10, 20, 30, 40, 50, 60, 69, 78, 88, 98, 108};
|
|
||||||
INDArray indRows = Nd4j.createFromArray(rows);
|
|
||||||
int[] cols = {4, 3, 10, 8, 6, 7, 1, 5, 9, 2, 0, 4, 9, 8, 10, 2, 6, 7, 3, 5, 1, 6, 8, 3, 9, 10, 4, 0, 5, 7, 0, 1, 2, 10, 4, 6, 8, 9,
|
|
||||||
5, 7, 0, 1, 2, 3, 10, 8, 9, 6, 7, 5, 0, 2, 3, 7, 9, 10, 4, 8, 1, 6, 0, 1, 2, 3, 4, 8, 10, 9, 5, 0, 1, 3, 4, 5, 9, 10, 8, 2, 0, 1, 2, 3, 4, 5, 6, 7, 10, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
|
|
||||||
INDArray indCols = Nd4j.createFromArray(cols);
|
|
||||||
double[] vals = {0.6806, 0.1978, 0.1349, 0.0403, 0.0087, 0.0369, 0.0081, 0.0172, 0.0014, 0.0046, 0.0081, 0.3375, 0.2274, 0.0556, 0.0098, 0.0175, 0.0027, 0.0077, 0.0014, 0.0023, 0.0175, 0.6569, 0.1762, 0.0254, 0.0200, 0.0118, 0.0074, 0.0046, 0.0124, 0.0012, 0.1978, 0.0014, 0.0254, 0.7198, 0.0712, 0.0850, 0.0389, 0.0555, 0.0418, 0.0286, 0.6806, 0.3375, 0.0074, 0.0712, 0.2290, 0.0224, 0.0189, 0.0080, 0.0187, 0.0097, 0.0172, 0.0124, 0.0418, 0.7799, 0.0521, 0.0395, 0.0097, 0.0030, 0.0023, 1.706e-5, 0.0087, 0.0027, 0.6569, 0.0850, 0.0080, 0.5562, 0.0173, 0.0015, 1.706e-5, 0.0369, 0.0077, 0.0286, 0.0187, 0.7799, 0.0711, 0.0200, 0.0084, 0.0012, 0.0403, 0.0556, 0.1762, 0.0389, 0.0224, 0.0030, 0.5562, 0.0084, 0.0060, 0.0028, 0.0014, 0.2274, 0.0200, 0.0555, 0.0189, 0.0521, 0.0015, 0.0711, 0.0028, 0.3911, 0.1349, 0.0098, 0.0118, 0.7198, 0.2290, 0.0395, 0.0173, 0.0200, 0.0060, 0.3911};
|
|
||||||
INDArray indVals = Nd4j.createFromArray(vals);
|
|
||||||
|
|
||||||
final int N = 11;
|
|
||||||
INDArray posF = Nd4j.create(DataType.DOUBLE, ndinput.shape());
|
|
||||||
SpTree tree = new SpTree(ndinput);
|
|
||||||
tree.computeEdgeForces(indRows, indCols, indVals, N, posF);
|
|
||||||
double[]expectedPosF = {-0.0818453583761987, -0.10231102631753211, 0.016809473355579547, 0.16176252194290375, -0.20703464777007444, -0.1263832139293613, 0.10996898963389254, 0.13983782727968627, -0.09164547825742625, 0.09219041827159041, 0.14252277104691244, 0.014676985587529433, 0.19786703075718223, -0.25244374832212546, -0.018387062879777892, 0.13652061663449183, 0.07639155593531936, -0.07616591260449279, 0.12919565310762643, -0.19229222179037395, -0.11250575155166542, -0.09598877143033444, 0.014899570740339653, 0.018867923701997365, 0.19996253097190828, 0.30233811684856743, -0.18830455752593392, 0.10223346521208224, -0.09703007177169608, -0.003280966942428477, 0.15213078827243462, -0.02397414389327494, -0.1390550777479942, 0.30088735606726813, 0.17456236098186903, -0.31560012032960044, 0.142309945794784, -0.08988089476622348, 0.011236280978163357, -0.10732740266565795, -0.24928551644245478, 0.10762735102220329, 0.03434270193250408, 2.831838829882295E-4, -0.17494982967210068, 0.07114328804840916, 0.15171552834583996, -0.08888924450773618, -0.20576831397087963, 0.027662749212463134, 0.08096437977846523, -0.19211185715249313, -0.11199893965092741, 0.024654692641180212, 0.20889407228258244};
|
|
||||||
assertArrayEquals(expectedPosF, posF.reshape(1,55).toDoubleVector(), 1e-5);
|
|
||||||
|
|
||||||
final double theta = 0.5;
|
|
||||||
AtomicDouble sumQ = new AtomicDouble(0.0);
|
|
||||||
INDArray negF = Nd4j.create(DataType.DOUBLE, ndinput.shape());
|
|
||||||
for (int n = 0; n < N; n++) {
|
|
||||||
INDArray prev = ((n == 0) ? negF.slice(n ): negF.slice(n-1));
|
|
||||||
tree.computeNonEdgeForces(n, theta, negF.slice(0), sumQ);
|
|
||||||
}
|
|
||||||
|
|
||||||
double[] expectedNegF = {-0.15349944039348173, -0.9608688924710804, -1.7099994806905086, 2.6604989787415203, 1.2677709150619332, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
|
|
||||||
double expectedSum = 88.60715062760883;
|
|
||||||
|
|
||||||
assertArrayEquals(expectedNegF, negF.reshape(1,55).toDoubleVector(), 1e-5);
|
|
||||||
assertEquals(expectedSum, sumQ.get(), 1e-5);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testZeroMean() {
|
|
||||||
double[] aData = new double[]{
|
|
||||||
0.2999816948164936, 0.26252049735806526, 0.2673853427498767, 0.8604464129156685, 0.4802652829902563, 0.10959096539488711, 0.7950242948008909, 0.5917848948003486,
|
|
||||||
0.2738285999345498, 0.9519684328285567, 0.9690024759209738, 0.8585615547624705, 0.8087760944312002, 0.5337951589543348, 0.5960876109129123, 0.7187130179825856,
|
|
||||||
0.4629777327445964, 0.08665909175584818, 0.7748005397731237, 0.48020186965468536, 0.24927351841378798, 0.32272599988270445, 0.306414968984427, 0.6980212149215657,
|
|
||||||
0.7977183964212472, 0.7673513094629704, 0.1679681724796478, 0.3107359484804584, 0.021701726051792103, 0.13797462786662518, 0.8618953518813538, 0.841333838365635,
|
|
||||||
0.5284957375170422, 0.9703367685039823, 0.677388096913733, 0.2624474979832243, 0.43740966353106536, 0.15685545957858893, 0.11072929134449871, 0.06007395961283357,
|
|
||||||
0.4093918718557811, 0.9563909195720572, 0.5994144944480242, 0.8278927844215804, 0.38586830957105667, 0.6201844716257464, 0.7603829079070265, 0.07875691596842949,
|
|
||||||
0.08651136699915507, 0.7445210640026082, 0.6547649514127559, 0.3384719042666908, 0.05816723105860,0.6248951423054205, 0.7431868493349041};
|
|
||||||
INDArray ndinput = Nd4j.createFromArray(aData).reshape(11,5);
|
|
||||||
BarnesHutTsne.zeroMean(ndinput);
|
|
||||||
double[] expected = {-0.2384362257971937, -0.3014583649756485, -0.07747340086583643, 0.3347228669042438, -0.07021239883787267, -0.4288269552188002, 0.23104543246717713, 0.24692615118463546, -0.2518949460768749, 0.40149075100042775, 0.43058455530728645, 0.2945826924287568, 0.46391735081548713, 0.008071612942910145, 0.04560992908478334, 0.18029509736889826, -0.10100112958911733, -0.25819965185986493, 0.249076993761699, -0.07027581217344359, -0.28914440219989934, -0.2412528624510093, -0.03844377463128612, 0.17229766891014098, 0.24724071459311825, 0.22893338884928305, -0.39601068985406596, -0.034122795135254735, -0.5040218199596326, -0.4125030539615038, 0.3234774312676665, 0.2773549760319213, 0.18363699390132904, 0.44461322249255764, 0.12691041508560408, -0.275970422630463, -0.12656919880264839, -0.18800328403712419, -0.41499425466692597, -0.4904037222152954, -0.12902604875790624, 0.3924120572383435, 0.2545557508323111, 0.30216923841015575, -0.16460937225707228, 0.0817665510120591, 0.1964040455733127, -0.26610182764728363, -0.4392121790122696, 0.19404338217447925, 0.11634703079906861, -0.22550695806702292, -0.2866915125571131, 0.09917159629399586, 0.19270916750677514};
|
|
||||||
assertArrayEquals(expected, ndinput.reshape(55).toDoubleVector(), 1e-5);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -190,7 +190,7 @@ public class ValidateCuDNN extends BaseDL4JTest {
|
||||||
validateLayers(net, classesToTest, false, fShape, lShape, CuDNNValidationUtil.MAX_REL_ERROR, CuDNNValidationUtil.MIN_ABS_ERROR);
|
validateLayers(net, classesToTest, false, fShape, lShape, CuDNNValidationUtil.MAX_REL_ERROR, CuDNNValidationUtil.MIN_ABS_ERROR);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test @Ignore //AB 2019/05/20 - https://github.com/deeplearning4j/deeplearning4j/issues/5088 - ignored to get to "all passing" state for CI, and revisit later
|
@Test @Ignore //AB 2019/05/20 - https://github.com/eclipse/deeplearning4j/issues/5088 - ignored to get to "all passing" state for CI, and revisit later
|
||||||
public void validateConvLayersLRN() {
|
public void validateConvLayersLRN() {
|
||||||
//Test ONLY LRN - no other CuDNN functionality (i.e., DL4J impls for everything else)
|
//Test ONLY LRN - no other CuDNN functionality (i.e., DL4J impls for everything else)
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
|
@ -80,7 +80,7 @@ public abstract class CacheableExtractableDataSetFetcher implements CacheableDat
|
||||||
log.error("Checksums do not match. Cleaning up files and failing...");
|
log.error("Checksums do not match. Cleaning up files and failing...");
|
||||||
tmpFile.delete();
|
tmpFile.delete();
|
||||||
throw new IllegalStateException( "Dataset file failed checksum: " + tmpFile + " - expected checksum " + expectedChecksum(set)
|
throw new IllegalStateException( "Dataset file failed checksum: " + tmpFile + " - expected checksum " + expectedChecksum(set)
|
||||||
+ " vs. actual checksum " + localChecksum + ". If this error persists, please open an issue at https://github.com/deeplearning4j/deeplearning4j.");
|
+ " vs. actual checksum " + localChecksum + ". If this error persists, please open an issue at https://github.com/eclipse/deeplearning4j.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,77 +0,0 @@
|
||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<!--
|
|
||||||
~ /* ******************************************************************************
|
|
||||||
~ *
|
|
||||||
~ *
|
|
||||||
~ * This program and the accompanying materials are made available under the
|
|
||||||
~ * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
~ * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
~ *
|
|
||||||
~ * See the NOTICE file distributed with this work for additional
|
|
||||||
~ * information regarding copyright ownership.
|
|
||||||
~ * Unless required by applicable law or agreed to in writing, software
|
|
||||||
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
~ * License for the specific language governing permissions and limitations
|
|
||||||
~ * under the License.
|
|
||||||
~ *
|
|
||||||
~ * SPDX-License-Identifier: Apache-2.0
|
|
||||||
~ ******************************************************************************/
|
|
||||||
-->
|
|
||||||
|
|
||||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
|
||||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
|
||||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
|
||||||
|
|
||||||
<modelVersion>4.0.0</modelVersion>
|
|
||||||
|
|
||||||
<parent>
|
|
||||||
<groupId>org.deeplearning4j</groupId>
|
|
||||||
<artifactId>deeplearning4j-manifold</artifactId>
|
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
|
||||||
</parent>
|
|
||||||
|
|
||||||
<artifactId>deeplearning4j-tsne</artifactId>
|
|
||||||
<packaging>jar</packaging>
|
|
||||||
|
|
||||||
<name>deeplearning4j-tsne</name>
|
|
||||||
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.deeplearning4j</groupId>
|
|
||||||
<artifactId>nearestneighbor-core</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.deeplearning4j</groupId>
|
|
||||||
<artifactId>deeplearning4j-nn</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.projectlombok</groupId>
|
|
||||||
<artifactId>lombok</artifactId>
|
|
||||||
<version>${lombok.version}</version>
|
|
||||||
<scope>provided</scope>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-api</artifactId>
|
|
||||||
<version>${nd4j.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.deeplearning4j</groupId>
|
|
||||||
<artifactId>deeplearning4j-common-tests</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
|
|
||||||
<profiles>
|
|
||||||
<profile>
|
|
||||||
<id>test-nd4j-native</id>
|
|
||||||
</profile>
|
|
||||||
<profile>
|
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
|
||||||
</profile>
|
|
||||||
</profiles>
|
|
||||||
</project>
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,433 +0,0 @@
|
||||||
/*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.deeplearning4j.plot;
|
|
||||||
|
|
||||||
import org.nd4j.shade.guava.primitives.Ints;
|
|
||||||
import org.apache.commons.math3.util.FastMath;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.dimensionalityreduction.PCA;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.linalg.indexing.BooleanIndexing;
|
|
||||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
|
||||||
import org.nd4j.linalg.indexing.SpecifiedIndex;
|
|
||||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
|
||||||
import org.nd4j.linalg.learning.legacy.AdaGrad;
|
|
||||||
import org.nd4j.common.primitives.Pair;
|
|
||||||
import org.nd4j.common.util.ArrayUtil;
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
|
|
||||||
import java.io.BufferedWriter;
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.FileWriter;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static org.nd4j.linalg.factory.Nd4j.*;
|
|
||||||
import static org.nd4j.linalg.ops.transforms.Transforms.*;
|
|
||||||
|
|
||||||
public class Tsne {
|
|
||||||
protected int maxIter = 1000;
|
|
||||||
protected double realMin = Nd4j.EPS_THRESHOLD;
|
|
||||||
protected double initialMomentum = 0.5;
|
|
||||||
protected double finalMomentum = 0.8;
|
|
||||||
protected double minGain = 1e-2;
|
|
||||||
protected double momentum = initialMomentum;
|
|
||||||
protected int switchMomentumIteration = 100;
|
|
||||||
protected boolean normalize = true;
|
|
||||||
protected boolean usePca = false;
|
|
||||||
protected int stopLyingIteration = 250;
|
|
||||||
protected double tolerance = 1e-5;
|
|
||||||
protected double learningRate = 500;
|
|
||||||
protected AdaGrad adaGrad;
|
|
||||||
protected boolean useAdaGrad = true;
|
|
||||||
protected double perplexity = 30;
|
|
||||||
//protected INDArray gains,yIncs;
|
|
||||||
protected INDArray Y;
|
|
||||||
|
|
||||||
protected static final Logger logger = LoggerFactory.getLogger(Tsne.class);
|
|
||||||
|
|
||||||
|
|
||||||
public Tsne(final int maxIter, final double realMin, final double initialMomentum, final double finalMomentum,
|
|
||||||
final double minGain, final double momentum, final int switchMomentumIteration,
|
|
||||||
final boolean normalize, final boolean usePca, final int stopLyingIteration, final double tolerance,
|
|
||||||
final double learningRate, final boolean useAdaGrad, final double perplexity) {
|
|
||||||
this.maxIter = maxIter;
|
|
||||||
this.realMin = realMin;
|
|
||||||
this.initialMomentum = initialMomentum;
|
|
||||||
this.finalMomentum = finalMomentum;
|
|
||||||
this.minGain = minGain;
|
|
||||||
this.momentum = momentum;
|
|
||||||
this.switchMomentumIteration = switchMomentumIteration;
|
|
||||||
this.normalize = normalize;
|
|
||||||
this.usePca = usePca;
|
|
||||||
this.stopLyingIteration = stopLyingIteration;
|
|
||||||
this.tolerance = tolerance;
|
|
||||||
this.learningRate = learningRate;
|
|
||||||
this.useAdaGrad = useAdaGrad;
|
|
||||||
this.perplexity = perplexity;
|
|
||||||
this.init();
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void init() {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public INDArray calculate(INDArray X, int targetDimensions, double perplexity) {
|
|
||||||
// pca hook
|
|
||||||
if (usePca) {
|
|
||||||
X = PCA.pca(X, Math.min(50, X.columns()), normalize);
|
|
||||||
} else if (normalize) {
|
|
||||||
X.subi(X.min(Integer.MAX_VALUE));
|
|
||||||
X = X.divi(X.max(Integer.MAX_VALUE));
|
|
||||||
X = X.subiRowVector(X.mean(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
int n = X.rows();
|
|
||||||
// FIXME: this is wrong, another distribution required here
|
|
||||||
Y = Nd4j.randn(X.dataType(), X.rows(), targetDimensions);
|
|
||||||
INDArray dY = Nd4j.zeros(n, targetDimensions);
|
|
||||||
INDArray iY = Nd4j.zeros(n, targetDimensions);
|
|
||||||
INDArray gains = Nd4j.ones(n, targetDimensions);
|
|
||||||
|
|
||||||
boolean stopLying = false;
|
|
||||||
logger.debug("Y:Shape is = " + Arrays.toString(Y.shape()));
|
|
||||||
|
|
||||||
// compute P-values
|
|
||||||
INDArray P = x2p(X, tolerance, perplexity);
|
|
||||||
|
|
||||||
// do training
|
|
||||||
for (int i = 0; i < maxIter; i++) {
|
|
||||||
INDArray sumY = pow(Y, 2).sum(1).transpose();
|
|
||||||
|
|
||||||
//Student-t distribution
|
|
||||||
//also un normalized q
|
|
||||||
// also known as num in original implementation
|
|
||||||
INDArray qu = Y.mmul(Y.transpose()).muli(-2).addiRowVector(sumY).transpose().addiRowVector(sumY).addi(1)
|
|
||||||
.rdivi(1);
|
|
||||||
|
|
||||||
// doAlongDiagonal(qu,new Zero());
|
|
||||||
|
|
||||||
INDArray Q = qu.div(qu.sumNumber().doubleValue());
|
|
||||||
BooleanIndexing.replaceWhere(Q, 1e-12, Conditions.lessThan(1e-12));
|
|
||||||
|
|
||||||
INDArray PQ = P.sub(Q).muli(qu);
|
|
||||||
|
|
||||||
logger.debug("PQ shape is: " + Arrays.toString(PQ.shape()));
|
|
||||||
logger.debug("PQ.sum(1) shape is: " + Arrays.toString(PQ.sum(1).shape()));
|
|
||||||
|
|
||||||
dY = diag(PQ.sum(1)).subi(PQ).mmul(Y).muli(4);
|
|
||||||
|
|
||||||
|
|
||||||
if (i < switchMomentumIteration) {
|
|
||||||
momentum = initialMomentum;
|
|
||||||
} else {
|
|
||||||
momentum = finalMomentum;
|
|
||||||
}
|
|
||||||
|
|
||||||
gains = gains.add(.2).muli(dY.cond(Conditions.greaterThan(0)).neq(iY.cond(Conditions.greaterThan(0))))
|
|
||||||
.addi(gains.mul(0.8).muli(dY.cond(Conditions.greaterThan(0))
|
|
||||||
.eq(iY.cond(Conditions.greaterThan(0)))));
|
|
||||||
|
|
||||||
BooleanIndexing.replaceWhere(gains, minGain, Conditions.lessThan(minGain));
|
|
||||||
|
|
||||||
INDArray gradChange = gains.mul(dY);
|
|
||||||
|
|
||||||
gradChange.muli(learningRate);
|
|
||||||
|
|
||||||
iY.muli(momentum).subi(gradChange);
|
|
||||||
|
|
||||||
double cost = P.mul(log(P.div(Q), false)).sumNumber().doubleValue();
|
|
||||||
logger.info("Iteration [" + i + "] error is: [" + cost + "]");
|
|
||||||
|
|
||||||
Y.addi(iY);
|
|
||||||
// Y.addi(iY).subiRowVector(Y.mean(0));
|
|
||||||
INDArray tiled = Nd4j.tile(Y.mean(0), new int[] {Y.rows(), 1});
|
|
||||||
Y.subi(tiled);
|
|
||||||
|
|
||||||
if (!stopLying && (i > maxIter / 2 || i >= stopLyingIteration)) {
|
|
||||||
P.divi(4);
|
|
||||||
stopLying = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Y;
|
|
||||||
}
|
|
||||||
|
|
||||||
public INDArray diag(INDArray ds) {
|
|
||||||
boolean isLong = ds.rows() > ds.columns();
|
|
||||||
INDArray sliceZero = ds.slice(0);
|
|
||||||
int dim = Math.max(ds.columns(), ds.rows());
|
|
||||||
INDArray result = Nd4j.create(dim, dim);
|
|
||||||
for (int i = 0; i < dim; i++) {
|
|
||||||
INDArray sliceSrc = ds.slice(i);
|
|
||||||
INDArray sliceDst = result.slice(i);
|
|
||||||
for (int j = 0; j < dim; j++) {
|
|
||||||
if (i == j) {
|
|
||||||
if (isLong)
|
|
||||||
sliceDst.putScalar(j, sliceSrc.getDouble(0));
|
|
||||||
else
|
|
||||||
sliceDst.putScalar(j, sliceZero.getDouble(i));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void plot(INDArray matrix, int nDims, List<String> labels, String path) throws IOException {
|
|
||||||
|
|
||||||
calculate(matrix, nDims, perplexity);
|
|
||||||
|
|
||||||
BufferedWriter write = new BufferedWriter(new FileWriter(new File(path), true));
|
|
||||||
|
|
||||||
for (int i = 0; i < Y.rows(); i++) {
|
|
||||||
if (i >= labels.size())
|
|
||||||
break;
|
|
||||||
String word = labels.get(i);
|
|
||||||
if (word == null)
|
|
||||||
continue;
|
|
||||||
StringBuilder sb = new StringBuilder();
|
|
||||||
INDArray wordVector = Y.getRow(i);
|
|
||||||
for (int j = 0; j < wordVector.length(); j++) {
|
|
||||||
sb.append(wordVector.getDouble(j));
|
|
||||||
if (j < wordVector.length() - 1)
|
|
||||||
sb.append(",");
|
|
||||||
}
|
|
||||||
|
|
||||||
sb.append(",");
|
|
||||||
sb.append(word);
|
|
||||||
sb.append(" ");
|
|
||||||
|
|
||||||
sb.append("\n");
|
|
||||||
write.write(sb.toString());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
write.flush();
|
|
||||||
write.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes a gaussian kernel
|
|
||||||
* given a vector of squared distance distances
|
|
||||||
*
|
|
||||||
* @param d the data
|
|
||||||
* @param beta
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public Pair<Double, INDArray> hBeta(INDArray d, double beta) {
|
|
||||||
INDArray P = exp(d.neg().muli(beta));
|
|
||||||
double sumP = P.sumNumber().doubleValue();
|
|
||||||
double logSumP = FastMath.log(sumP);
|
|
||||||
Double H = logSumP + ((beta * (d.mul(P).sumNumber().doubleValue())) / sumP);
|
|
||||||
P.divi(sumP);
|
|
||||||
return new Pair<>(H, P);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method build probabilities for given source data
|
|
||||||
*
|
|
||||||
* @param X
|
|
||||||
* @param tolerance
|
|
||||||
* @param perplexity
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
private INDArray x2p(final INDArray X, double tolerance, double perplexity) {
|
|
||||||
int n = X.rows();
|
|
||||||
final INDArray p = zeros(n, n);
|
|
||||||
final INDArray beta = ones(n, 1);
|
|
||||||
final double logU = Math.log(perplexity);
|
|
||||||
|
|
||||||
INDArray sumX = pow(X, 2).sum(1);
|
|
||||||
|
|
||||||
logger.debug("sumX shape: " + Arrays.toString(sumX.shape()));
|
|
||||||
|
|
||||||
INDArray times = X.mmul(X.transpose()).muli(-2);
|
|
||||||
|
|
||||||
logger.debug("times shape: " + Arrays.toString(times.shape()));
|
|
||||||
|
|
||||||
INDArray prodSum = times.transpose().addiColumnVector(sumX);
|
|
||||||
|
|
||||||
logger.debug("prodSum shape: " + Arrays.toString(prodSum.shape()));
|
|
||||||
|
|
||||||
INDArray D = X.mmul(X.transpose()).mul(-2) // thats times
|
|
||||||
.transpose().addColumnVector(sumX) // thats prodSum
|
|
||||||
.addRowVector(sumX.transpose()); // thats D
|
|
||||||
|
|
||||||
logger.info("Calculating probabilities of data similarities...");
|
|
||||||
logger.debug("Tolerance: " + tolerance);
|
|
||||||
for (int i = 0; i < n; i++) {
|
|
||||||
if (i % 500 == 0 && i > 0)
|
|
||||||
logger.info("Handled [" + i + "] records out of [" + n + "]");
|
|
||||||
|
|
||||||
double betaMin = Double.NEGATIVE_INFINITY;
|
|
||||||
double betaMax = Double.POSITIVE_INFINITY;
|
|
||||||
int[] vals = Ints.concat(ArrayUtil.range(0, i), ArrayUtil.range(i + 1, n));
|
|
||||||
INDArrayIndex[] range = new INDArrayIndex[] {new SpecifiedIndex(vals)};
|
|
||||||
|
|
||||||
INDArray row = D.slice(i).get(range);
|
|
||||||
Pair<Double, INDArray> pair = hBeta(row, beta.getDouble(i));
|
|
||||||
//INDArray hDiff = pair.getFirst().sub(logU);
|
|
||||||
double hDiff = pair.getFirst() - logU;
|
|
||||||
int tries = 0;
|
|
||||||
|
|
||||||
//while hdiff > tolerance
|
|
||||||
while (Math.abs(hDiff) > tolerance && tries < 50) {
|
|
||||||
//if hdiff > 0
|
|
||||||
if (hDiff > 0) {
|
|
||||||
betaMin = beta.getDouble(i);
|
|
||||||
if (Double.isInfinite(betaMax))
|
|
||||||
beta.putScalar(i, beta.getDouble(i) * 2.0);
|
|
||||||
else
|
|
||||||
beta.putScalar(i, (beta.getDouble(i) + betaMax) / 2.0);
|
|
||||||
} else {
|
|
||||||
betaMax = beta.getDouble(i);
|
|
||||||
if (Double.isInfinite(betaMin))
|
|
||||||
beta.putScalar(i, beta.getDouble(i) / 2.0);
|
|
||||||
else
|
|
||||||
beta.putScalar(i, (beta.getDouble(i) + betaMin) / 2.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
pair = hBeta(row, beta.getDouble(i));
|
|
||||||
hDiff = pair.getFirst() - logU;
|
|
||||||
tries++;
|
|
||||||
}
|
|
||||||
p.slice(i).put(range, pair.getSecond());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
//dont need data in memory after
|
|
||||||
logger.info("Mean value of sigma " + sqrt(beta.rdiv(1)).mean(Integer.MAX_VALUE));
|
|
||||||
BooleanIndexing.replaceWhere(p, 1e-12, Conditions.isNan());
|
|
||||||
|
|
||||||
//set 0 along the diagonal
|
|
||||||
INDArray permute = p.transpose();
|
|
||||||
|
|
||||||
INDArray pOut = p.add(permute);
|
|
||||||
|
|
||||||
pOut.divi(pOut.sumNumber().doubleValue() + 1e-6);
|
|
||||||
|
|
||||||
pOut.muli(4);
|
|
||||||
|
|
||||||
BooleanIndexing.replaceWhere(pOut, 1e-12, Conditions.lessThan(1e-12));
|
|
||||||
//ensure no nans
|
|
||||||
|
|
||||||
return pOut;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public static class Builder {
|
|
||||||
protected int maxIter = 1000;
|
|
||||||
protected double realMin = 1e-12f;
|
|
||||||
protected double initialMomentum = 5e-1f;
|
|
||||||
protected double finalMomentum = 8e-1f;
|
|
||||||
protected double momentum = 5e-1f;
|
|
||||||
protected int switchMomentumIteration = 100;
|
|
||||||
protected boolean normalize = true;
|
|
||||||
protected boolean usePca = false;
|
|
||||||
protected int stopLyingIteration = 100;
|
|
||||||
protected double tolerance = 1e-5f;
|
|
||||||
protected double learningRate = 1e-1f;
|
|
||||||
protected boolean useAdaGrad = false;
|
|
||||||
protected double perplexity = 30;
|
|
||||||
protected double minGain = 1e-1f;
|
|
||||||
|
|
||||||
|
|
||||||
public Builder minGain(double minGain) {
|
|
||||||
this.minGain = minGain;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder perplexity(double perplexity) {
|
|
||||||
this.perplexity = perplexity;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder useAdaGrad(boolean useAdaGrad) {
|
|
||||||
this.useAdaGrad = useAdaGrad;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder learningRate(double learningRate) {
|
|
||||||
this.learningRate = learningRate;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public Builder tolerance(double tolerance) {
|
|
||||||
this.tolerance = tolerance;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder stopLyingIteration(int stopLyingIteration) {
|
|
||||||
this.stopLyingIteration = stopLyingIteration;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder usePca(boolean usePca) {
|
|
||||||
this.usePca = usePca;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder normalize(boolean normalize) {
|
|
||||||
this.normalize = normalize;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder setMaxIter(int maxIter) {
|
|
||||||
this.maxIter = maxIter;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder setRealMin(double realMin) {
|
|
||||||
this.realMin = realMin;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder setInitialMomentum(double initialMomentum) {
|
|
||||||
this.initialMomentum = initialMomentum;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder setFinalMomentum(double finalMomentum) {
|
|
||||||
this.finalMomentum = finalMomentum;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder setMomentum(double momentum) {
|
|
||||||
this.momentum = momentum;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder setSwitchMomentumIteration(int switchMomentumIteration) {
|
|
||||||
this.switchMomentumIteration = switchMomentumIteration;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Tsne build() {
|
|
||||||
return new Tsne(maxIter, realMin, initialMomentum, finalMomentum, minGain, momentum,
|
|
||||||
switchMomentumIteration, normalize, usePca, stopLyingIteration, tolerance, learningRate,
|
|
||||||
useAdaGrad, perplexity);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,68 +0,0 @@
|
||||||
/*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.deeplearning4j.plot;
|
|
||||||
|
|
||||||
import lombok.val;
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
public class Test6058 extends BaseDL4JTest {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void test() throws Exception {
|
|
||||||
//All zero input -> cosine similarity isn't defined
|
|
||||||
//https://github.com/deeplearning4j/deeplearning4j/issues/6058
|
|
||||||
val iterations = 10;
|
|
||||||
val cacheList = new ArrayList<String>();
|
|
||||||
|
|
||||||
int nWords = 100;
|
|
||||||
for(int i=0; i<nWords; i++ ) {
|
|
||||||
cacheList.add("word_" + i);
|
|
||||||
}
|
|
||||||
|
|
||||||
//STEP 3: build a dual-tree tsne to use later
|
|
||||||
System.out.println("Build model....");
|
|
||||||
val tsne = new BarnesHutTsne.Builder()
|
|
||||||
.setMaxIter(iterations)
|
|
||||||
.theta(0.5)
|
|
||||||
.normalize(false)
|
|
||||||
.learningRate(1000)
|
|
||||||
.useAdaGrad(false)
|
|
||||||
//.usePca(false)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
System.out.println("fit");
|
|
||||||
INDArray weights = Nd4j.rand(new int[]{nWords, 100});
|
|
||||||
weights.getRow(1).assign(0);
|
|
||||||
try {
|
|
||||||
tsne.fit(weights);
|
|
||||||
} catch (IllegalStateException e){
|
|
||||||
assertTrue(e.getMessage().contains("may not be defined"));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,87 +0,0 @@
|
||||||
/*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
//package org.deeplearning4j.plot;
|
|
||||||
//
|
|
||||||
//import lombok.extern.slf4j.Slf4j;
|
|
||||||
//import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
|
||||||
//import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
|
||||||
//import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
|
||||||
//import org.deeplearning4j.nn.conf.WorkspaceMode;
|
|
||||||
//import org.junit.Test;
|
|
||||||
//import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
||||||
//import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
//import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
//import org.nd4j.linalg.io.ClassPathResource;
|
|
||||||
//import org.nd4j.linalg.primitives.Pair;
|
|
||||||
//
|
|
||||||
//import java.io.File;
|
|
||||||
//import java.util.ArrayList;
|
|
||||||
//import java.util.List;
|
|
||||||
//
|
|
||||||
//@Slf4j
|
|
||||||
//public class TsneTest {
|
|
||||||
//
|
|
||||||
// @Test
|
|
||||||
// public void testSimple() throws Exception {
|
|
||||||
// //Simple sanity check
|
|
||||||
//
|
|
||||||
// for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}){
|
|
||||||
//
|
|
||||||
// //STEP 1: Initialization
|
|
||||||
// int iterations = 100;
|
|
||||||
// //create an n-dimensional array of doubles
|
|
||||||
// Nd4j.setDataType(DataType.DOUBLE);
|
|
||||||
// List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
|
|
||||||
//
|
|
||||||
// //STEP 2: Turn text input into a list of words
|
|
||||||
// log.info("Load & Vectorize data....");
|
|
||||||
// File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file
|
|
||||||
// //Get the data of all unique word vectors
|
|
||||||
// Pair<InMemoryLookupTable,VocabCache> vectors = WordVectorSerializer.loadTxt(wordFile);
|
|
||||||
// VocabCache cache = vectors.getSecond();
|
|
||||||
// INDArray weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list
|
|
||||||
//
|
|
||||||
// for(int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list
|
|
||||||
// cacheList.add(cache.wordAtIndex(i));
|
|
||||||
//
|
|
||||||
// //STEP 3: build a dual-tree tsne to use later
|
|
||||||
// log.info("Build model....");
|
|
||||||
// BarnesHutTsne tsne = new BarnesHutTsne.Builder()
|
|
||||||
// .setMaxIter(iterations).theta(0.5)
|
|
||||||
// .normalize(false)
|
|
||||||
// .learningRate(500)
|
|
||||||
// .useAdaGrad(false)
|
|
||||||
// .workspaceMode(wsm)
|
|
||||||
// .build();
|
|
||||||
//
|
|
||||||
// //STEP 4: establish the tsne values and save them to a file
|
|
||||||
// log.info("Store TSNE Coordinates for Plotting....");
|
|
||||||
// String outputFile = "target/archive-tmp/tsne-standard-coords.csv";
|
|
||||||
// (new File(outputFile)).getParentFile().mkdirs();
|
|
||||||
//
|
|
||||||
// tsne.fit(weights);
|
|
||||||
// tsne.saveAsFile(cacheList, outputFile);
|
|
||||||
//
|
|
||||||
//
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
//}
|
|
|
@ -1,51 +0,0 @@
|
||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<!--
|
|
||||||
~ /* ******************************************************************************
|
|
||||||
~ *
|
|
||||||
~ *
|
|
||||||
~ * This program and the accompanying materials are made available under the
|
|
||||||
~ * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
~ * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
~ *
|
|
||||||
~ * See the NOTICE file distributed with this work for additional
|
|
||||||
~ * information regarding copyright ownership.
|
|
||||||
~ * Unless required by applicable law or agreed to in writing, software
|
|
||||||
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
~ * License for the specific language governing permissions and limitations
|
|
||||||
~ * under the License.
|
|
||||||
~ *
|
|
||||||
~ * SPDX-License-Identifier: Apache-2.0
|
|
||||||
~ ******************************************************************************/
|
|
||||||
-->
|
|
||||||
|
|
||||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
|
||||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
|
||||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
|
||||||
|
|
||||||
<modelVersion>4.0.0</modelVersion>
|
|
||||||
|
|
||||||
<parent>
|
|
||||||
<groupId>org.deeplearning4j</groupId>
|
|
||||||
<artifactId>deeplearning4j-parent</artifactId>
|
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
|
||||||
</parent>
|
|
||||||
|
|
||||||
<artifactId>deeplearning4j-manifold</artifactId>
|
|
||||||
<packaging>pom</packaging>
|
|
||||||
|
|
||||||
<name>deeplearning4j-manifold</name>
|
|
||||||
|
|
||||||
<modules>
|
|
||||||
<module>deeplearning4j-tsne</module>
|
|
||||||
</modules>
|
|
||||||
|
|
||||||
<profiles>
|
|
||||||
<profile>
|
|
||||||
<id>test-nd4j-native</id>
|
|
||||||
</profile>
|
|
||||||
<profile>
|
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
|
||||||
</profile>
|
|
||||||
</profiles>
|
|
||||||
</project>
|
|
|
@ -127,6 +127,51 @@
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
|
<inherited>true</inherited>
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<artifactId>nd4j-native</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
<configuration>
|
||||||
|
<environmentVariables>
|
||||||
|
|
||||||
|
</environmentVariables>
|
||||||
|
<testSourceDirectory>src/test/java</testSourceDirectory>
|
||||||
|
<includes>
|
||||||
|
<include>*.java</include>
|
||||||
|
<include>**/*.java</include>
|
||||||
|
<include>**/Test*.java</include>
|
||||||
|
<include>**/*Test.java</include>
|
||||||
|
<include>**/*TestCase.java</include>
|
||||||
|
</includes>
|
||||||
|
<junitArtifactName>junit:junit</junitArtifactName>
|
||||||
|
<systemPropertyVariables>
|
||||||
|
<org.nd4j.linalg.defaultbackend>
|
||||||
|
org.nd4j.linalg.cpu.nativecpu.CpuBackend
|
||||||
|
</org.nd4j.linalg.defaultbackend>
|
||||||
|
<org.nd4j.linalg.tests.backendstorun>
|
||||||
|
org.nd4j.linalg.cpu.nativecpu.CpuBackend
|
||||||
|
</org.nd4j.linalg.tests.backendstorun>
|
||||||
|
</systemPropertyVariables>
|
||||||
|
<!--
|
||||||
|
Maximum heap size was set to 8g, as a minimum required value for tests run.
|
||||||
|
Depending on a build machine, default value is not always enough.
|
||||||
|
|
||||||
|
For testing large zoo models, this may not be enough (so comment it out).
|
||||||
|
-->
|
||||||
|
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||||
|
</configuration>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
|
@ -138,6 +183,47 @@
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.maven.surefire</groupId>
|
||||||
|
<artifactId>surefire-junit47</artifactId>
|
||||||
|
<version>2.19.1</version>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
<configuration>
|
||||||
|
<environmentVariables>
|
||||||
|
</environmentVariables>
|
||||||
|
<testSourceDirectory>src/test/java</testSourceDirectory>
|
||||||
|
<includes>
|
||||||
|
<include>*.java</include>
|
||||||
|
<include>**/*.java</include>
|
||||||
|
<include>**/Test*.java</include>
|
||||||
|
<include>**/*Test.java</include>
|
||||||
|
<include>**/*TestCase.java</include>
|
||||||
|
</includes>
|
||||||
|
<junitArtifactName>junit:junit</junitArtifactName>
|
||||||
|
<systemPropertyVariables>
|
||||||
|
<org.nd4j.linalg.defaultbackend>
|
||||||
|
org.nd4j.linalg.jcublas.JCublasBackend
|
||||||
|
</org.nd4j.linalg.defaultbackend>
|
||||||
|
<org.nd4j.linalg.tests.backendstorun>
|
||||||
|
org.nd4j.linalg.jcublas.JCublasBackend
|
||||||
|
</org.nd4j.linalg.tests.backendstorun>
|
||||||
|
</systemPropertyVariables>
|
||||||
|
<!--
|
||||||
|
Maximum heap size was set to 6g, as a minimum required value for tests run.
|
||||||
|
Depending on a build machine, default value is not always enough.
|
||||||
|
-->
|
||||||
|
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
||||||
|
</configuration>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -1001,7 +1001,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
||||||
|
|
||||||
for (Layer l : netToTest.getLayers()) {
|
for (Layer l : netToTest.getLayers()) {
|
||||||
// Remove any dropout manually - until this is fixed:
|
// Remove any dropout manually - until this is fixed:
|
||||||
// https://github.com/deeplearning4j/deeplearning4j/issues/4368
|
// https://github.com/eclipse/deeplearning4j/issues/4368
|
||||||
l.conf().getLayer().setIDropout(null);
|
l.conf().getLayer().setIDropout(null);
|
||||||
|
|
||||||
//Also swap out activation functions... this is a bit of a hack, but should make the net gradient checkable...
|
//Also swap out activation functions... this is a bit of a hack, but should make the net gradient checkable...
|
||||||
|
|
|
@ -22,7 +22,6 @@ package org.deeplearning4j.models.embeddings;
|
||||||
|
|
||||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||||
import org.deeplearning4j.plot.BarnesHutTsne;
|
|
||||||
import org.deeplearning4j.core.ui.UiConnectionInfo;
|
import org.deeplearning4j.core.ui.UiConnectionInfo;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
@ -74,27 +73,7 @@ public interface WeightLookupTable<T extends SequenceElement> extends Serializab
|
||||||
*/
|
*/
|
||||||
void resetWeights(boolean reset);
|
void resetWeights(boolean reset);
|
||||||
|
|
||||||
/**
|
|
||||||
* Render the words via TSNE
|
|
||||||
* @param tsne the tsne to use
|
|
||||||
*/
|
|
||||||
void plotVocab(BarnesHutTsne tsne, int numWords, UiConnectionInfo connectionInfo);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Render the words via TSNE
|
|
||||||
* @param tsne the tsne to use
|
|
||||||
*/
|
|
||||||
void plotVocab(BarnesHutTsne tsne, int numWords, File file);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Render the words via tsne
|
|
||||||
*/
|
|
||||||
void plotVocab(int numWords, UiConnectionInfo connectionInfo);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Render the words via tsne
|
|
||||||
*/
|
|
||||||
void plotVocab(int numWords, File file);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
|
|
@ -29,7 +29,6 @@ import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
||||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||||
import org.deeplearning4j.models.word2vec.Word2Vec;
|
import org.deeplearning4j.models.word2vec.Word2Vec;
|
||||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||||
import org.deeplearning4j.plot.BarnesHutTsne;
|
|
||||||
import org.deeplearning4j.core.ui.UiConnectionInfo;
|
import org.deeplearning4j.core.ui.UiConnectionInfo;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
@ -154,123 +153,8 @@ public class InMemoryLookupTable<T extends SequenceElement> implements WeightLoo
|
||||||
initNegative();
|
initNegative();
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<String> fitTnseAndGetLabels(final BarnesHutTsne tsne, final int numWords) {
|
|
||||||
INDArray array = Nd4j.create(numWords, vectorLength);
|
|
||||||
List<String> labels = new ArrayList<>();
|
|
||||||
for (int i = 0; i < numWords && i < vocab.numWords(); i++) {
|
|
||||||
labels.add(vocab.wordAtIndex(i));
|
|
||||||
array.putRow(i, syn0.slice(i));
|
|
||||||
}
|
|
||||||
tsne.fit(array);
|
|
||||||
return labels;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void plotVocab(BarnesHutTsne tsne, int numWords, File file) {
|
|
||||||
final List<String> labels = fitTnseAndGetLabels(tsne, numWords);
|
|
||||||
try {
|
|
||||||
tsne.saveAsFile(labels, file.getAbsolutePath());
|
|
||||||
} catch (IOException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Render the words via tsne
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void plotVocab(int numWords, File file) {
|
|
||||||
BarnesHutTsne tsne = new BarnesHutTsne.Builder().normalize(false).setFinalMomentum(0.8f).numDimension(2)
|
|
||||||
.setMaxIter(1000).build();
|
|
||||||
plotVocab(tsne, numWords, file);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Render the words via tsne
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void plotVocab(int numWords, UiConnectionInfo connectionInfo) {
|
|
||||||
BarnesHutTsne tsne = new BarnesHutTsne.Builder().normalize(false).setFinalMomentum(0.8f).numDimension(2)
|
|
||||||
.setMaxIter(1000).build();
|
|
||||||
plotVocab(tsne, numWords, connectionInfo);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Render the words via TSNE
|
|
||||||
*
|
|
||||||
* @param tsne the tsne to use
|
|
||||||
* @param numWords
|
|
||||||
* @param connectionInfo
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void plotVocab(BarnesHutTsne tsne, int numWords, UiConnectionInfo connectionInfo) {
|
|
||||||
try {
|
|
||||||
final List<String> labels = fitTnseAndGetLabels(tsne, numWords);
|
|
||||||
final INDArray reducedData = tsne.getData();
|
|
||||||
StringBuilder sb = new StringBuilder();
|
|
||||||
for (int i = 0; i < reducedData.rows() && i < numWords; i++) {
|
|
||||||
String word = labels.get(i);
|
|
||||||
INDArray wordVector = reducedData.getRow(i);
|
|
||||||
for (int j = 0; j < wordVector.length(); j++) {
|
|
||||||
sb.append(String.valueOf(wordVector.getDouble(j))).append(",");
|
|
||||||
}
|
|
||||||
sb.append(word);
|
|
||||||
}
|
|
||||||
|
|
||||||
String address = connectionInfo.getFirstPart() + "/tsne/post/" + connectionInfo.getSessionId();
|
|
||||||
// System.out.println("ADDRESS: " + address);
|
|
||||||
URI uri = new URI(address);
|
|
||||||
|
|
||||||
HttpURLConnection connection = (HttpURLConnection) uri.toURL().openConnection();
|
|
||||||
connection.setRequestMethod("POST");
|
|
||||||
connection.setRequestProperty("User-Agent", "Mozilla/5.0");
|
|
||||||
// connection.setRequestProperty("Content-Type", "application/json");
|
|
||||||
connection.setRequestProperty("Content-Type", "multipart/form-data; boundary=-----TSNE-POST-DATA-----");
|
|
||||||
connection.setDoOutput(true);
|
|
||||||
|
|
||||||
final OutputStream outputStream = connection.getOutputStream();
|
|
||||||
final PrintWriter writer = new PrintWriter(outputStream);
|
|
||||||
writer.println("-------TSNE-POST-DATA-----");
|
|
||||||
writer.println("Content-Disposition: form-data; name=\"fileupload\"; filename=\"tsne.csv\"");
|
|
||||||
writer.println("Content-Type: text/plain; charset=UTF-16");
|
|
||||||
writer.println("Content-Transfer-Encoding: binary");
|
|
||||||
writer.println();
|
|
||||||
writer.flush();
|
|
||||||
|
|
||||||
DataOutputStream dos = new DataOutputStream(outputStream);
|
|
||||||
dos.writeBytes(sb.toString());
|
|
||||||
dos.flush();
|
|
||||||
writer.println();
|
|
||||||
writer.flush();
|
|
||||||
dos.close();
|
|
||||||
outputStream.close();
|
|
||||||
|
|
||||||
try {
|
|
||||||
int responseCode = connection.getResponseCode();
|
|
||||||
System.out.println("RESPONSE CODE: " + responseCode);
|
|
||||||
|
|
||||||
if (responseCode != 200) {
|
|
||||||
BufferedReader in = new BufferedReader(new InputStreamReader(connection.getInputStream()));
|
|
||||||
String inputLine;
|
|
||||||
StringBuilder response = new StringBuilder();
|
|
||||||
|
|
||||||
while ((inputLine = in.readLine()) != null) {
|
|
||||||
response.append(inputLine);
|
|
||||||
}
|
|
||||||
in.close();
|
|
||||||
|
|
||||||
log.warn("Error posting to remote UI - received response code {}\tContent: {}", response,
|
|
||||||
response.toString());
|
|
||||||
}
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.warn("Error posting to remote UI at {}", uri, e);
|
|
||||||
}
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param codeIndex
|
* @param codeIndex
|
||||||
* @param code
|
* @param code
|
||||||
|
|
|
@ -26,7 +26,6 @@ import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
||||||
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
||||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||||
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
||||||
import org.deeplearning4j.plot.BarnesHutTsne;
|
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
@ -62,152 +61,4 @@ public class TsneTest extends BaseDL4JTest {
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testSimple() throws Exception {
|
|
||||||
//Simple sanity check
|
|
||||||
|
|
||||||
for( int test=0; test <=1; test++){
|
|
||||||
boolean syntheticData = test == 1;
|
|
||||||
WorkspaceMode wsm = test == 0 ? WorkspaceMode.NONE : WorkspaceMode.ENABLED;
|
|
||||||
log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData);
|
|
||||||
|
|
||||||
//STEP 1: Initialization
|
|
||||||
int iterations = 50;
|
|
||||||
//create an n-dimensional array of doubles
|
|
||||||
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
|
||||||
List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
|
|
||||||
|
|
||||||
//STEP 2: Turn text input into a list of words
|
|
||||||
INDArray weights;
|
|
||||||
if(syntheticData){
|
|
||||||
weights = Nd4j.rand(250, 200);
|
|
||||||
} else {
|
|
||||||
log.info("Load & Vectorize data....");
|
|
||||||
File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file
|
|
||||||
//Get the data of all unique word vectors
|
|
||||||
Pair<InMemoryLookupTable, VocabCache> vectors = WordVectorSerializer.loadTxt(wordFile);
|
|
||||||
VocabCache cache = vectors.getSecond();
|
|
||||||
weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list
|
|
||||||
|
|
||||||
for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list
|
|
||||||
cacheList.add(cache.wordAtIndex(i));
|
|
||||||
}
|
|
||||||
|
|
||||||
//STEP 3: build a dual-tree tsne to use later
|
|
||||||
log.info("Build model....");
|
|
||||||
BarnesHutTsne tsne = new BarnesHutTsne.Builder()
|
|
||||||
.setMaxIter(iterations)
|
|
||||||
.theta(0.5)
|
|
||||||
.normalize(false)
|
|
||||||
.learningRate(500)
|
|
||||||
.useAdaGrad(false)
|
|
||||||
.workspaceMode(wsm)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
|
|
||||||
//STEP 4: establish the tsne values and save them to a file
|
|
||||||
log.info("Store TSNE Coordinates for Plotting....");
|
|
||||||
File outDir = testDir.newFolder();
|
|
||||||
tsne.fit(weights);
|
|
||||||
tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testPerformance() throws Exception {
|
|
||||||
|
|
||||||
StopWatch watch = new StopWatch();
|
|
||||||
watch.start();
|
|
||||||
for( int test=0; test <=1; test++){
|
|
||||||
boolean syntheticData = test == 1;
|
|
||||||
WorkspaceMode wsm = test == 0 ? WorkspaceMode.NONE : WorkspaceMode.ENABLED;
|
|
||||||
log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData);
|
|
||||||
|
|
||||||
//STEP 1: Initialization
|
|
||||||
int iterations = 50;
|
|
||||||
//create an n-dimensional array of doubles
|
|
||||||
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
|
||||||
List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
|
|
||||||
|
|
||||||
//STEP 2: Turn text input into a list of words
|
|
||||||
INDArray weights;
|
|
||||||
if(syntheticData){
|
|
||||||
weights = Nd4j.rand(DataType.FLOAT, 250, 20);
|
|
||||||
} else {
|
|
||||||
log.info("Load & Vectorize data....");
|
|
||||||
File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file
|
|
||||||
//Get the data of all unique word vectors
|
|
||||||
Pair<InMemoryLookupTable, VocabCache> vectors = WordVectorSerializer.loadTxt(wordFile);
|
|
||||||
VocabCache cache = vectors.getSecond();
|
|
||||||
weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list
|
|
||||||
|
|
||||||
for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list
|
|
||||||
cacheList.add(cache.wordAtIndex(i));
|
|
||||||
}
|
|
||||||
|
|
||||||
//STEP 3: build a dual-tree tsne to use later
|
|
||||||
log.info("Build model....");
|
|
||||||
BarnesHutTsne tsne = new BarnesHutTsne.Builder()
|
|
||||||
.setMaxIter(iterations)
|
|
||||||
.theta(0.5)
|
|
||||||
.normalize(false)
|
|
||||||
.learningRate(500)
|
|
||||||
.useAdaGrad(false)
|
|
||||||
.workspaceMode(wsm)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
|
|
||||||
//STEP 4: establish the tsne values and save them to a file
|
|
||||||
log.info("Store TSNE Coordinates for Plotting....");
|
|
||||||
File outDir = testDir.newFolder();
|
|
||||||
tsne.fit(weights);
|
|
||||||
tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath());
|
|
||||||
}
|
|
||||||
watch.stop();
|
|
||||||
System.out.println("Elapsed time : " + watch);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Ignore
|
|
||||||
@Test
|
|
||||||
public void testTSNEPerformance() throws Exception {
|
|
||||||
|
|
||||||
for (WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) {
|
|
||||||
|
|
||||||
//STEP 1: Initialization
|
|
||||||
int iterations = 50;
|
|
||||||
//create an n-dimensional array of doubles
|
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
|
||||||
List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
|
|
||||||
|
|
||||||
//STEP 2: Turn text input into a list of words
|
|
||||||
INDArray weights = Nd4j.rand(10000,300);
|
|
||||||
|
|
||||||
StopWatch watch = new StopWatch();
|
|
||||||
watch.start();
|
|
||||||
//STEP 3: build a dual-tree tsne to use later
|
|
||||||
log.info("Build model....");
|
|
||||||
BarnesHutTsne tsne = new BarnesHutTsne.Builder()
|
|
||||||
.setMaxIter(iterations)
|
|
||||||
.theta(0.5)
|
|
||||||
.normalize(false)
|
|
||||||
.learningRate(500)
|
|
||||||
.useAdaGrad(false)
|
|
||||||
.workspaceMode(wsm)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
watch.stop();
|
|
||||||
System.out.println("Elapsed time for construction: " + watch);
|
|
||||||
|
|
||||||
//STEP 4: establish the tsne values and save them to a file
|
|
||||||
log.info("Store TSNE Coordinates for Plotting....");
|
|
||||||
File outDir = testDir.newFolder();
|
|
||||||
|
|
||||||
watch.reset();
|
|
||||||
watch.start();
|
|
||||||
tsne.fit(weights);
|
|
||||||
watch.stop();
|
|
||||||
System.out.println("Elapsed time for fit: " + watch);
|
|
||||||
tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.iterator;
|
package org.deeplearning4j.iterator;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.iterator.bert.BertMaskedLMMasker;
|
import org.deeplearning4j.iterator.bert.BertMaskedLMMasker;
|
||||||
|
@ -57,9 +58,11 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
public TestBertIterator() throws IOException {
|
public TestBertIterator() throws IOException {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 20000L)
|
@Test()
|
||||||
public void testBertSequenceClassification() throws Exception {
|
public void testBertSequenceClassification() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
int minibatchSize = 2;
|
int minibatchSize = 2;
|
||||||
TestSentenceHelper testHelper = new TestSentenceHelper();
|
TestSentenceHelper testHelper = new TestSentenceHelper();
|
||||||
BertIterator b = BertIterator.builder()
|
BertIterator b = BertIterator.builder()
|
||||||
|
@ -308,6 +311,9 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
*/
|
*/
|
||||||
@Test
|
@Test
|
||||||
public void testSentencePairsSingle() throws IOException {
|
public void testSentencePairsSingle() throws IOException {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
boolean prependAppend;
|
boolean prependAppend;
|
||||||
int numOfSentences;
|
int numOfSentences;
|
||||||
|
|
||||||
|
@ -367,7 +373,9 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
*/
|
*/
|
||||||
@Test
|
@Test
|
||||||
public void testSentencePairsUnequalLengths() throws IOException {
|
public void testSentencePairsUnequalLengths() throws IOException {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
int minibatchSize = 4;
|
int minibatchSize = 4;
|
||||||
int numOfSentencesinIter = 3;
|
int numOfSentencesinIter = 3;
|
||||||
|
|
||||||
|
@ -456,6 +464,9 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSentencePairFeaturizer() throws IOException {
|
public void testSentencePairFeaturizer() throws IOException {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
int minibatchSize = 2;
|
int minibatchSize = 2;
|
||||||
TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(minibatchSize);
|
TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(minibatchSize);
|
||||||
BertIterator b = BertIterator.builder()
|
BertIterator b = BertIterator.builder()
|
||||||
|
|
|
@ -26,6 +26,7 @@ import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
||||||
import org.deeplearning4j.models.word2vec.Word2Vec;
|
import org.deeplearning4j.models.word2vec.Word2Vec;
|
||||||
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
||||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||||
|
import org.junit.Ignore;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.rules.TemporaryFolder;
|
||||||
|
@ -43,6 +44,7 @@ import static org.junit.Assert.assertArrayEquals;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@Ignore
|
||||||
public class FastTextTest extends BaseDL4JTest {
|
public class FastTextTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
|
|
|
@ -23,7 +23,6 @@ package org.deeplearning4j.models.word2vec;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
||||||
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
|
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
|
||||||
import org.deeplearning4j.plot.BarnesHutTsne;
|
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
@ -40,11 +39,5 @@ public class Word2VecVisualizationTests extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testBarnesHutTsneVisualization() throws Exception {
|
|
||||||
BarnesHutTsne tsne = new BarnesHutTsne.Builder().setMaxIter(4).stopLyingIteration(250).learningRate(500)
|
|
||||||
.useAdaGrad(false).theta(0.5).setMomentum(0.5).normalize(true).build();
|
|
||||||
|
|
||||||
//vectors.lookupTable().plotVocab(tsne);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,6 +32,7 @@ import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIte
|
||||||
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
|
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
|
||||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
|
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
|
||||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||||
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
@ -56,6 +57,7 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest {
|
||||||
* Basically all we want from this test - being able to finish without exceptions.
|
* Basically all we want from this test - being able to finish without exceptions.
|
||||||
*/
|
*/
|
||||||
@Test
|
@Test
|
||||||
|
@Ignore
|
||||||
public void testIterator1() throws Exception {
|
public void testIterator1() throws Exception {
|
||||||
|
|
||||||
File inputFile = Resources.asFile("big/raw_sentences.txt");
|
File inputFile = Resources.asFile("big/raw_sentences.txt");
|
||||||
|
|
|
@ -42,6 +42,7 @@ import java.util.List;
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@Ignore
|
||||||
public class BertWordPieceTokenizerTests extends BaseDL4JTest {
|
public class BertWordPieceTokenizerTests extends BaseDL4JTest {
|
||||||
|
|
||||||
private File pathToVocab = Resources.asFile("other/vocab.txt");
|
private File pathToVocab = Resources.asFile("other/vocab.txt");
|
||||||
|
|
|
@ -71,7 +71,7 @@ public class LocalResponseNormalization
|
||||||
dataType);
|
dataType);
|
||||||
log.debug("CudnnLocalResponseNormalizationHelper successfully initialized");
|
log.debug("CudnnLocalResponseNormalizationHelper successfully initialized");
|
||||||
}
|
}
|
||||||
//2019-03-09 AB - MKL-DNN helper disabled: https://github.com/deeplearning4j/deeplearning4j/issues/7272
|
//2019-03-09 AB - MKL-DNN helper disabled: https://github.com/eclipse/deeplearning4j/issues/7272
|
||||||
// else if("CPU".equalsIgnoreCase(backend)){
|
// else if("CPU".equalsIgnoreCase(backend)){
|
||||||
// helper = new MKLDNNLocalResponseNormalizationHelper();
|
// helper = new MKLDNNLocalResponseNormalizationHelper();
|
||||||
// log.debug("Created MKLDNNLocalResponseNormalizationHelper");
|
// log.debug("Created MKLDNNLocalResponseNormalizationHelper");
|
||||||
|
|
|
@ -953,7 +953,7 @@ public class ModelSerializer {
|
||||||
|
|
||||||
|
|
||||||
private static void checkInputStream(InputStream inputStream) throws IOException {
|
private static void checkInputStream(InputStream inputStream) throws IOException {
|
||||||
//available method can return 0 in some cases: https://github.com/deeplearning4j/deeplearning4j/issues/4887
|
//available method can return 0 in some cases: https://github.com/eclipse/deeplearning4j/issues/4887
|
||||||
int available;
|
int available;
|
||||||
try{
|
try{
|
||||||
//InputStream.available(): A subclass' implementation of this method may choose to throw an IOException
|
//InputStream.available(): A subclass' implementation of this method may choose to throw an IOException
|
||||||
|
|
|
@ -370,7 +370,7 @@ public class NetworkUtils {
|
||||||
final String message;
|
final String message;
|
||||||
if (model.getClass().getName().startsWith("org.deeplearning4j")) {
|
if (model.getClass().getName().startsWith("org.deeplearning4j")) {
|
||||||
message = model.getClass().getName() + " models are not yet supported and " +
|
message = model.getClass().getName() + " models are not yet supported and " +
|
||||||
"pull requests are welcome: https://github.com/deeplearning4j/deeplearning4j";
|
"pull requests are welcome: https://github.com/eclipse/deeplearning4j";
|
||||||
} else {
|
} else {
|
||||||
message = model.getClass().getName() + " models are unsupported.";
|
message = model.getClass().getName() + " models are unsupported.";
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.models.sequencevectors;
|
package org.deeplearning4j.spark.models.sequencevectors;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
@ -87,6 +88,11 @@ public class SparkSequenceVectorsTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testFrequenciesCount() throws Exception {
|
public void testFrequenciesCount() throws Exception {
|
||||||
|
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
JavaRDD<Sequence<VocabWord>> sequences = sc.parallelize(sequencesCyclic);
|
JavaRDD<Sequence<VocabWord>> sequences = sc.parallelize(sequencesCyclic);
|
||||||
|
|
||||||
SparkSequenceVectors<VocabWord> seqVec = new SparkSequenceVectors<>();
|
SparkSequenceVectors<VocabWord> seqVec = new SparkSequenceVectors<>();
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.models.embeddings.word2vec;
|
package org.deeplearning4j.spark.models.embeddings.word2vec;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
@ -54,6 +55,10 @@ public class Word2VecTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testConcepts() throws Exception {
|
public void testConcepts() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
// These are all default values for word2vec
|
// These are all default values for word2vec
|
||||||
SparkConf sparkConf = new SparkConf().setMaster("local[8]")
|
SparkConf sparkConf = new SparkConf().setMaster("local[8]")
|
||||||
.set("spark.driver.host", "localhost")
|
.set("spark.driver.host", "localhost")
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.text;
|
package org.deeplearning4j.spark.text;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.api.java.JavaPairRDD;
|
import org.apache.spark.api.java.JavaPairRDD;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
|
@ -94,6 +95,10 @@ public class TextPipelineTest extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testTokenizer() throws Exception {
|
public void testTokenizer() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
JavaSparkContext sc = getContext();
|
JavaSparkContext sc = getContext();
|
||||||
JavaRDD<String> corpusRDD = getCorpusRDD(sc);
|
JavaRDD<String> corpusRDD = getCorpusRDD(sc);
|
||||||
Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
|
Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.parameterserver.accumulation;
|
package org.deeplearning4j.spark.parameterserver.accumulation;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -33,6 +34,10 @@ public class SharedTrainingAccumulationFunctionTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAccumulation1() throws Exception {
|
public void testAccumulation1() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
INDArray updates1 = Nd4j.create(1000).assign(1.0);
|
INDArray updates1 = Nd4j.create(1000).assign(1.0);
|
||||||
INDArray updates2 = Nd4j.create(1000).assign(2.0);
|
INDArray updates2 = Nd4j.create(1000).assign(2.0);
|
||||||
INDArray expUpdates = Nd4j.create(1000).assign(3.0);
|
INDArray expUpdates = Nd4j.create(1000).assign(3.0);
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.parameterserver.accumulation;
|
package org.deeplearning4j.spark.parameterserver.accumulation;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingResult;
|
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingResult;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
@ -36,6 +37,10 @@ public class SharedTrainingAggregateFunctionTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAggregate1() throws Exception {
|
public void testAggregate1() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
INDArray updates1 = Nd4j.create(1000).assign(1.0);
|
INDArray updates1 = Nd4j.create(1000).assign(1.0);
|
||||||
INDArray updates2 = Nd4j.create(1000).assign(2.0);
|
INDArray updates2 = Nd4j.create(1000).assign(2.0);
|
||||||
INDArray expUpdates = Nd4j.create(1000).assign(3.0);
|
INDArray expUpdates = Nd4j.create(1000).assign(3.0);
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.parameterserver.iterators;
|
package org.deeplearning4j.spark.parameterserver.iterators;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -39,6 +40,10 @@ public class VirtualDataSetIteratorTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSimple1() throws Exception {
|
public void testSimple1() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
List<Iterator<DataSet>> iterators = new ArrayList<>();
|
List<Iterator<DataSet>> iterators = new ArrayList<>();
|
||||||
|
|
||||||
List<DataSet> first = new ArrayList<>();
|
List<DataSet> first = new ArrayList<>();
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.parameterserver.iterators;
|
package org.deeplearning4j.spark.parameterserver.iterators;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
|
@ -36,6 +37,10 @@ public class VirtualIteratorTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testIteration1() throws Exception {
|
public void testIteration1() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
List<Integer> integers = new ArrayList<>();
|
List<Integer> integers = new ArrayList<>();
|
||||||
for (int i = 0; i < 100; i++) {
|
for (int i = 0; i < 100; i++) {
|
||||||
integers.add(i);
|
integers.add(i);
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.parameterserver.modelimport.elephas;
|
package org.deeplearning4j.spark.parameterserver.modelimport.elephas;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
|
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
|
||||||
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
|
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
|
||||||
|
@ -40,6 +41,10 @@ public class TestElephasImport extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testElephasSequentialImport() throws Exception {
|
public void testElephasSequentialImport() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
String modelPath = "modelimport/elephas/elephas_sequential.h5";
|
String modelPath = "modelimport/elephas/elephas_sequential.h5";
|
||||||
SparkDl4jMultiLayer model = importElephasSequential(sc, modelPath);
|
SparkDl4jMultiLayer model = importElephasSequential(sc, modelPath);
|
||||||
// System.out.println(model.getNetwork().summary());
|
// System.out.println(model.getNetwork().summary());
|
||||||
|
@ -48,7 +53,11 @@ public class TestElephasImport extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testElephasSequentialImportAsync() throws Exception {
|
public void testElephasSequentialImportAsync() throws Exception {
|
||||||
String modelPath = "modelimport/elephas/elephas_sequential_async.h5";
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
String modelPath = "modelimport/elephas/elephas_sequential_async.h5";
|
||||||
SparkDl4jMultiLayer model = importElephasSequential(sc, modelPath);
|
SparkDl4jMultiLayer model = importElephasSequential(sc, modelPath);
|
||||||
// System.out.println(model.getNetwork().summary());
|
// System.out.println(model.getNetwork().summary());
|
||||||
assertTrue(model.getTrainingMaster() instanceof SharedTrainingMaster);
|
assertTrue(model.getTrainingMaster() instanceof SharedTrainingMaster);
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
#
|
||||||
|
# /* ******************************************************************************
|
||||||
|
# *
|
||||||
|
# *
|
||||||
|
# * This program and the accompanying materials are made available under the
|
||||||
|
# * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
# * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
# *
|
||||||
|
# * See the NOTICE file distributed with this work for additional
|
||||||
|
# * information regarding copyright ownership.
|
||||||
|
# * Unless required by applicable law or agreed to in writing, software
|
||||||
|
# * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
# * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
# * License for the specific language governing permissions and limitations
|
||||||
|
# * under the License.
|
||||||
|
# *
|
||||||
|
# * SPDX-License-Identifier: Apache-2.0
|
||||||
|
# ******************************************************************************/
|
||||||
|
#
|
||||||
|
|
||||||
|
real.class.double = org.nd4j.linalg.cpu.NDArray
|
||||||
|
shapeinfoprovider = org.nd4j.linalg.cpu.nativecpu.DirectShapeInfoProvider
|
||||||
|
constantsprovider = org.nd4j.linalg.cpu.nativecpu.cache.ConstantBuffersCache
|
||||||
|
affinitymanager = org.nd4j.linalg.cpu.nativecpu.CpuAffinityManager
|
||||||
|
memorymanager = org.nd4j.linalg.cpu.nativecpu.CpuMemoryManager
|
||||||
|
dtype = float
|
||||||
|
blas.ops = org.nd4j.linalg.cpu.nativecpu.BlasWrapper
|
||||||
|
|
||||||
|
native.ops= org.nd4j.nativeblas.Nd4jCpu
|
||||||
|
ndarrayfactory.class = org.nd4j.linalg.cpu.nativecpu.CpuNDArrayFactory
|
||||||
|
ndarray.order = c
|
||||||
|
resourcemanager_state = false
|
||||||
|
databufferfactory = org.nd4j.linalg.cpu.nativecpu.buffer.DefaultDataBufferFactory
|
||||||
|
workspacemanager = org.nd4j.linalg.cpu.nativecpu.workspace.CpuWorkspaceManager
|
||||||
|
alloc = javacpp
|
||||||
|
opexec= org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner
|
||||||
|
opexec.mode= native
|
||||||
|
random=org.nd4j.linalg.cpu.nativecpu.rng.CpuNativeRandom
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark;
|
package org.deeplearning4j.spark;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||||
|
@ -63,6 +64,10 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEarlyStoppingIris() {
|
public void testEarlyStoppingIris() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.updater(new Sgd()).weightInit(WeightInit.XAVIER).list()
|
.updater(new Sgd()).weightInit(WeightInit.XAVIER).list()
|
||||||
|
@ -113,7 +118,10 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
|
||||||
@Test
|
@Test
|
||||||
public void testBadTuning() {
|
public void testBadTuning() {
|
||||||
//Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition
|
//Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
|
@ -150,7 +158,10 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
|
||||||
@Test
|
@Test
|
||||||
public void testTimeTermination() {
|
public void testTimeTermination() {
|
||||||
//test termination after max time
|
//test termination after max time
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
|
@ -193,7 +204,10 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
|
||||||
public void testNoImprovementNEpochsTermination() {
|
public void testNoImprovementNEpochsTermination() {
|
||||||
//Idea: terminate training if score (test set loss) does not improve for 5 consecutive epochs
|
//Idea: terminate training if score (test set loss) does not improve for 5 consecutive epochs
|
||||||
//Simulate this by setting LR = 0.0
|
//Simulate this by setting LR = 0.0
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
|
@ -228,6 +242,10 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testListeners() {
|
public void testListeners() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.updater(new Sgd()).weightInit(WeightInit.XAVIER).list()
|
.updater(new Sgd()).weightInit(WeightInit.XAVIER).list()
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark;
|
package org.deeplearning4j.spark;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||||
|
@ -66,6 +67,10 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEarlyStoppingIris() {
|
public void testEarlyStoppingIris() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.updater(new Sgd()).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in")
|
.updater(new Sgd()).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in")
|
||||||
|
@ -114,7 +119,10 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
|
||||||
@Test
|
@Test
|
||||||
public void testBadTuning() {
|
public void testBadTuning() {
|
||||||
//Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition
|
//Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
|
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
|
@ -152,7 +160,10 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
|
||||||
@Test
|
@Test
|
||||||
public void testTimeTermination() {
|
public void testTimeTermination() {
|
||||||
//test termination after max time
|
//test termination after max time
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
|
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
|
@ -197,7 +208,10 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
|
||||||
public void testNoImprovementNEpochsTermination() {
|
public void testNoImprovementNEpochsTermination() {
|
||||||
//Idea: terminate training if score (test set loss) does not improve for 5 consecutive epochs
|
//Idea: terminate training if score (test set loss) does not improve for 5 consecutive epochs
|
||||||
//Simulate this by setting LR = 0.0
|
//Simulate this by setting LR = 0.0
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
|
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
|
@ -235,6 +249,10 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testListeners() {
|
public void testListeners() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.updater(new Sgd()).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in")
|
.updater(new Sgd()).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in")
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.datavec;
|
package org.deeplearning4j.spark.datavec;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
import org.apache.hadoop.io.Text;
|
import org.apache.hadoop.io.Text;
|
||||||
|
@ -68,6 +69,10 @@ public class TestDataVecDataSetFunctions extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDataVecDataSetFunction() throws Exception {
|
public void testDataVecDataSetFunction() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
JavaSparkContext sc = getContext();
|
JavaSparkContext sc = getContext();
|
||||||
|
|
||||||
File f = testDir.newFolder();
|
File f = testDir.newFolder();
|
||||||
|
@ -178,6 +183,10 @@ public class TestDataVecDataSetFunctions extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDataVecSequenceDataSetFunction() throws Exception {
|
public void testDataVecSequenceDataSetFunction() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
JavaSparkContext sc = getContext();
|
JavaSparkContext sc = getContext();
|
||||||
//Test Spark record reader functionality vs. local
|
//Test Spark record reader functionality vs. local
|
||||||
File dir = testDir.newFolder();
|
File dir = testDir.newFolder();
|
||||||
|
@ -236,6 +245,10 @@ public class TestDataVecDataSetFunctions extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDataVecSequencePairDataSetFunction() throws Exception {
|
public void testDataVecSequencePairDataSetFunction() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
JavaSparkContext sc = getContext();
|
JavaSparkContext sc = getContext();
|
||||||
|
|
||||||
File f = testDir.newFolder();
|
File f = testDir.newFolder();
|
||||||
|
@ -332,7 +345,10 @@ public class TestDataVecDataSetFunctions extends BaseSparkTest {
|
||||||
@Test
|
@Test
|
||||||
public void testDataVecSequencePairDataSetFunctionVariableLength() throws Exception {
|
public void testDataVecSequencePairDataSetFunctionVariableLength() throws Exception {
|
||||||
//Same sort of test as testDataVecSequencePairDataSetFunction() but with variable length time series (labels shorter, align end)
|
//Same sort of test as testDataVecSequencePairDataSetFunction() but with variable length time series (labels shorter, align end)
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
File dirFeatures = testDir.newFolder();
|
File dirFeatures = testDir.newFolder();
|
||||||
ClassPathResource cpr = new ClassPathResource("dl4j-spark/csvsequence/");
|
ClassPathResource cpr = new ClassPathResource("dl4j-spark/csvsequence/");
|
||||||
cpr.copyDirectory(dirFeatures);
|
cpr.copyDirectory(dirFeatures);
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.datavec;
|
package org.deeplearning4j.spark.datavec;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
|
@ -44,6 +45,10 @@ public class TestExport extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBatchAndExportDataSetsFunction() throws Exception {
|
public void testBatchAndExportDataSetsFunction() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
String baseDir = System.getProperty("java.io.tmpdir");
|
String baseDir = System.getProperty("java.io.tmpdir");
|
||||||
baseDir = FilenameUtils.concat(baseDir, "dl4j_spark_testBatchAndExport/");
|
baseDir = FilenameUtils.concat(baseDir, "dl4j_spark_testBatchAndExport/");
|
||||||
baseDir = baseDir.replaceAll("\\\\", "/");
|
baseDir = baseDir.replaceAll("\\\\", "/");
|
||||||
|
@ -102,6 +107,10 @@ public class TestExport extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBatchAndExportMultiDataSetsFunction() throws Exception {
|
public void testBatchAndExportMultiDataSetsFunction() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
String baseDir = System.getProperty("java.io.tmpdir");
|
String baseDir = System.getProperty("java.io.tmpdir");
|
||||||
baseDir = FilenameUtils.concat(baseDir, "dl4j_spark_testBatchAndExportMDS/");
|
baseDir = FilenameUtils.concat(baseDir, "dl4j_spark_testBatchAndExportMDS/");
|
||||||
baseDir = baseDir.replaceAll("\\\\", "/");
|
baseDir = baseDir.replaceAll("\\\\", "/");
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.datavec;
|
package org.deeplearning4j.spark.datavec;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
import org.apache.spark.api.java.JavaPairRDD;
|
import org.apache.spark.api.java.JavaPairRDD;
|
||||||
|
@ -63,6 +64,10 @@ public class TestPreProcessedData extends BaseSparkTest {
|
||||||
@Test
|
@Test
|
||||||
public void testPreprocessedData() {
|
public void testPreprocessedData() {
|
||||||
//Test _loading_ of preprocessed data
|
//Test _loading_ of preprocessed data
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
int dataSetObjSize = 5;
|
int dataSetObjSize = 5;
|
||||||
int batchSizePerExecutor = 10;
|
int batchSizePerExecutor = 10;
|
||||||
|
|
||||||
|
@ -109,6 +114,10 @@ public class TestPreProcessedData extends BaseSparkTest {
|
||||||
@Test
|
@Test
|
||||||
public void testPreprocessedDataCompGraphDataSet() {
|
public void testPreprocessedDataCompGraphDataSet() {
|
||||||
//Test _loading_ of preprocessed DataSet data
|
//Test _loading_ of preprocessed DataSet data
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
int dataSetObjSize = 5;
|
int dataSetObjSize = 5;
|
||||||
int batchSizePerExecutor = 10;
|
int batchSizePerExecutor = 10;
|
||||||
|
|
||||||
|
@ -157,6 +166,10 @@ public class TestPreProcessedData extends BaseSparkTest {
|
||||||
@Test
|
@Test
|
||||||
public void testPreprocessedDataCompGraphMultiDataSet() throws IOException {
|
public void testPreprocessedDataCompGraphMultiDataSet() throws IOException {
|
||||||
//Test _loading_ of preprocessed MultiDataSet data
|
//Test _loading_ of preprocessed MultiDataSet data
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
int dataSetObjSize = 5;
|
int dataSetObjSize = 5;
|
||||||
int batchSizePerExecutor = 10;
|
int batchSizePerExecutor = 10;
|
||||||
|
|
||||||
|
@ -206,6 +219,10 @@ public class TestPreProcessedData extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCsvPreprocessedDataGeneration() throws Exception {
|
public void testCsvPreprocessedDataGeneration() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
List<String> list = new ArrayList<>();
|
List<String> list = new ArrayList<>();
|
||||||
DataSetIterator iter = new IrisDataSetIterator(1, 150);
|
DataSetIterator iter = new IrisDataSetIterator(1, 150);
|
||||||
while (iter.hasNext()) {
|
while (iter.hasNext()) {
|
||||||
|
@ -292,6 +309,10 @@ public class TestPreProcessedData extends BaseSparkTest {
|
||||||
@Test
|
@Test
|
||||||
public void testCsvPreprocessedDataGenerationNoLabel() throws Exception {
|
public void testCsvPreprocessedDataGenerationNoLabel() throws Exception {
|
||||||
//Same as above test, but without any labels (in which case: input and output arrays are the same)
|
//Same as above test, but without any labels (in which case: input and output arrays are the same)
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
List<String> list = new ArrayList<>();
|
List<String> list = new ArrayList<>();
|
||||||
DataSetIterator iter = new IrisDataSetIterator(1, 150);
|
DataSetIterator iter = new IrisDataSetIterator(1, 150);
|
||||||
while (iter.hasNext()) {
|
while (iter.hasNext()) {
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.impl.customlayer;
|
package org.deeplearning4j.spark.impl.customlayer;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
@ -44,6 +45,10 @@ public class TestCustomLayer extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSparkWithCustomLayer() {
|
public void testSparkWithCustomLayer() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
//Basic test - checks whether exceptions etc are thrown with custom layers + spark
|
//Basic test - checks whether exceptions etc are thrown with custom layers + spark
|
||||||
//Custom layers are tested more extensively in dl4j core
|
//Custom layers are tested more extensively in dl4j core
|
||||||
MultiLayerConfiguration conf =
|
MultiLayerConfiguration conf =
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.impl.multilayer;
|
package org.deeplearning4j.spark.impl.multilayer;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
|
@ -69,6 +70,10 @@ public class TestSparkDl4jMultiLayer extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEvaluationSimple() throws Exception {
|
public void testEvaluationSimple() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
for( int evalWorkers : new int[]{1, 4, 8}) {
|
for( int evalWorkers : new int[]{1, 4, 8}) {
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.impl.paramavg;
|
package org.deeplearning4j.spark.impl.paramavg;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
@ -65,57 +66,57 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
private static MultiLayerConfiguration getConf(int seed, IUpdater updater) {
|
private static MultiLayerConfiguration getConf(int seed, IUpdater updater) {
|
||||||
Nd4j.getRandom().setSeed(seed);
|
Nd4j.getRandom().setSeed(seed);
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.weightInit(WeightInit.XAVIER).updater(updater).seed(seed).list()
|
.weightInit(WeightInit.XAVIER).updater(updater).seed(seed).list()
|
||||||
.layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(1, new OutputLayer.Builder()
|
.layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(1, new OutputLayer.Builder()
|
||||||
.lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build())
|
.lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build())
|
||||||
.build();
|
.build();
|
||||||
return conf;
|
return conf;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static MultiLayerConfiguration getConfCNN(int seed, IUpdater updater) {
|
private static MultiLayerConfiguration getConfCNN(int seed, IUpdater updater) {
|
||||||
Nd4j.getRandom().setSeed(seed);
|
Nd4j.getRandom().setSeed(seed);
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.weightInit(WeightInit.XAVIER).updater(updater).seed(seed).list()
|
.weightInit(WeightInit.XAVIER).updater(updater).seed(seed).list()
|
||||||
.layer(0, new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1).padding(0, 0)
|
.layer(0, new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1).padding(0, 0)
|
||||||
.activation(Activation.TANH).build())
|
.activation(Activation.TANH).build())
|
||||||
.layer(1, new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1).padding(0, 0)
|
.layer(1, new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1).padding(0, 0)
|
||||||
.activation(Activation.TANH).build())
|
.activation(Activation.TANH).build())
|
||||||
.layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(10)
|
.layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(10)
|
||||||
.build())
|
.build())
|
||||||
.setInputType(InputType.convolutional(10, 10, 3)).build();
|
.setInputType(InputType.convolutional(10, 10, 3)).build();
|
||||||
return conf;
|
return conf;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static ComputationGraphConfiguration getGraphConf(int seed, IUpdater updater) {
|
private static ComputationGraphConfiguration getGraphConf(int seed, IUpdater updater) {
|
||||||
Nd4j.getRandom().setSeed(seed);
|
Nd4j.getRandom().setSeed(seed);
|
||||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.weightInit(WeightInit.XAVIER).updater(updater).seed(seed).graphBuilder()
|
.weightInit(WeightInit.XAVIER).updater(updater).seed(seed).graphBuilder()
|
||||||
.addInputs("in")
|
.addInputs("in")
|
||||||
.addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in").addLayer("1",
|
.addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in").addLayer("1",
|
||||||
new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10)
|
new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10)
|
||||||
.nOut(10).build(),
|
.nOut(10).build(),
|
||||||
"0")
|
"0")
|
||||||
.setOutputs("1").build();
|
.setOutputs("1").build();
|
||||||
return conf;
|
return conf;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static ComputationGraphConfiguration getGraphConfCNN(int seed, IUpdater updater) {
|
private static ComputationGraphConfiguration getGraphConfCNN(int seed, IUpdater updater) {
|
||||||
Nd4j.getRandom().setSeed(seed);
|
Nd4j.getRandom().setSeed(seed);
|
||||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.weightInit(WeightInit.XAVIER).updater(updater).seed(seed).graphBuilder()
|
.weightInit(WeightInit.XAVIER).updater(updater).seed(seed).graphBuilder()
|
||||||
.addInputs("in")
|
.addInputs("in")
|
||||||
.addLayer("0", new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1)
|
.addLayer("0", new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1)
|
||||||
.padding(0, 0).activation(Activation.TANH).build(), "in")
|
.padding(0, 0).activation(Activation.TANH).build(), "in")
|
||||||
.addLayer("1", new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1)
|
.addLayer("1", new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1)
|
||||||
.padding(0, 0).activation(Activation.TANH).build(), "0")
|
.padding(0, 0).activation(Activation.TANH).build(), "0")
|
||||||
.addLayer("2", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(10)
|
.addLayer("2", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(10)
|
||||||
.build(), "1")
|
.build(), "1")
|
||||||
.setOutputs("2").setInputTypes(InputType.convolutional(10, 10, 3))
|
.setOutputs("2").setInputTypes(InputType.convolutional(10, 10, 3))
|
||||||
.build();
|
.build();
|
||||||
return conf;
|
return conf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,8 +126,8 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
|
|
||||||
private static TrainingMaster getTrainingMaster(int avgFreq, int miniBatchSize, boolean saveUpdater) {
|
private static TrainingMaster getTrainingMaster(int avgFreq, int miniBatchSize, boolean saveUpdater) {
|
||||||
ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(1)
|
ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(1)
|
||||||
.averagingFrequency(avgFreq).batchSizePerWorker(miniBatchSize).saveUpdater(saveUpdater)
|
.averagingFrequency(avgFreq).batchSizePerWorker(miniBatchSize).saveUpdater(saveUpdater)
|
||||||
.aggregationDepth(2).workerPrefetchNumBatches(0).build();
|
.aggregationDepth(2).workerPrefetchNumBatches(0).build();
|
||||||
return tm;
|
return tm;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -174,6 +175,10 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testOneExecutor() {
|
public void testOneExecutor() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
//Idea: single worker/executor on Spark should give identical results to a single machine
|
//Idea: single worker/executor on Spark should give identical results to a single machine
|
||||||
|
|
||||||
int miniBatchSize = 10;
|
int miniBatchSize = 10;
|
||||||
|
@ -224,6 +229,10 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testOneExecutorGraph() {
|
public void testOneExecutorGraph() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
//Idea: single worker/executor on Spark should give identical results to a single machine
|
//Idea: single worker/executor on Spark should give identical results to a single machine
|
||||||
|
|
||||||
int miniBatchSize = 10;
|
int miniBatchSize = 10;
|
||||||
|
@ -251,7 +260,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
//Do training on Spark with one executor, for 3 separate minibatches
|
//Do training on Spark with one executor, for 3 separate minibatches
|
||||||
TrainingMaster tm = getTrainingMaster(1, miniBatchSize, saveUpdater);
|
TrainingMaster tm = getTrainingMaster(1, miniBatchSize, saveUpdater);
|
||||||
SparkComputationGraph sparkNet =
|
SparkComputationGraph sparkNet =
|
||||||
new SparkComputationGraph(sc, getGraphConf(12345, new RmsProp(0.5)), tm);
|
new SparkComputationGraph(sc, getGraphConf(12345, new RmsProp(0.5)), tm);
|
||||||
sparkNet.setCollectTrainingStats(true);
|
sparkNet.setCollectTrainingStats(true);
|
||||||
INDArray initialSparkParams = sparkNet.getNetwork().params().dup();
|
INDArray initialSparkParams = sparkNet.getNetwork().params().dup();
|
||||||
|
|
||||||
|
@ -312,10 +321,10 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
//Do training on Spark with one executor, for 3 separate minibatches
|
//Do training on Spark with one executor, for 3 separate minibatches
|
||||||
// TrainingMaster tm = getTrainingMaster(1, miniBatchSizePerWorker, saveUpdater);
|
// TrainingMaster tm = getTrainingMaster(1, miniBatchSizePerWorker, saveUpdater);
|
||||||
ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(1)
|
ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(1)
|
||||||
.averagingFrequency(1).batchSizePerWorker(miniBatchSizePerWorker)
|
.averagingFrequency(1).batchSizePerWorker(miniBatchSizePerWorker)
|
||||||
.saveUpdater(saveUpdater).workerPrefetchNumBatches(0)
|
.saveUpdater(saveUpdater).workerPrefetchNumBatches(0)
|
||||||
// .rddTrainingApproach(RDDTrainingApproach.Direct)
|
// .rddTrainingApproach(RDDTrainingApproach.Direct)
|
||||||
.rddTrainingApproach(RDDTrainingApproach.Export).build();
|
.rddTrainingApproach(RDDTrainingApproach.Export).build();
|
||||||
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, getConf(12345, new Sgd(0.5)), tm);
|
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, getConf(12345, new Sgd(0.5)), tm);
|
||||||
sparkNet.setCollectTrainingStats(true);
|
sparkNet.setCollectTrainingStats(true);
|
||||||
INDArray initialSparkParams = sparkNet.getNetwork().params().dup();
|
INDArray initialSparkParams = sparkNet.getNetwork().params().dup();
|
||||||
|
@ -355,6 +364,10 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAverageEveryStepCNN() {
|
public void testAverageEveryStepCNN() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
//Idea: averaging every step with SGD (SGD updater + optimizer) is mathematically identical to doing the learning
|
//Idea: averaging every step with SGD (SGD updater + optimizer) is mathematically identical to doing the learning
|
||||||
// on a single machine for synchronous distributed training
|
// on a single machine for synchronous distributed training
|
||||||
//BUT: This is *ONLY* the case if all workers get an identical number of examples. This won't be the case if
|
//BUT: This is *ONLY* the case if all workers get an identical number of examples. This won't be the case if
|
||||||
|
@ -387,16 +400,16 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
|
|
||||||
//Do training on Spark with one executor, for 3 separate minibatches
|
//Do training on Spark with one executor, for 3 separate minibatches
|
||||||
ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(1)
|
ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(1)
|
||||||
.averagingFrequency(1).batchSizePerWorker(miniBatchSizePerWorker)
|
.averagingFrequency(1).batchSizePerWorker(miniBatchSizePerWorker)
|
||||||
.saveUpdater(saveUpdater).workerPrefetchNumBatches(0)
|
.saveUpdater(saveUpdater).workerPrefetchNumBatches(0)
|
||||||
.rddTrainingApproach(RDDTrainingApproach.Export).build();
|
.rddTrainingApproach(RDDTrainingApproach.Export).build();
|
||||||
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, getConfCNN(12345, new Sgd(0.5)), tm);
|
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, getConfCNN(12345, new Sgd(0.5)), tm);
|
||||||
sparkNet.setCollectTrainingStats(true);
|
sparkNet.setCollectTrainingStats(true);
|
||||||
INDArray initialSparkParams = sparkNet.getNetwork().params().dup();
|
INDArray initialSparkParams = sparkNet.getNetwork().params().dup();
|
||||||
|
|
||||||
for (int i = 0; i < seeds.length; i++) {
|
for (int i = 0; i < seeds.length; i++) {
|
||||||
List<DataSet> list =
|
List<DataSet> list =
|
||||||
getOneDataSetAsIndividalExamplesCNN(miniBatchSizePerWorker * nWorkers, seeds[i]);
|
getOneDataSetAsIndividalExamplesCNN(miniBatchSizePerWorker * nWorkers, seeds[i]);
|
||||||
JavaRDD<DataSet> rdd = sc.parallelize(list);
|
JavaRDD<DataSet> rdd = sc.parallelize(list);
|
||||||
|
|
||||||
sparkNet.fit(rdd);
|
sparkNet.fit(rdd);
|
||||||
|
@ -427,6 +440,10 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAverageEveryStepGraph() {
|
public void testAverageEveryStepGraph() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
//Idea: averaging every step with SGD (SGD updater + optimizer) is mathematically identical to doing the learning
|
//Idea: averaging every step with SGD (SGD updater + optimizer) is mathematically identical to doing the learning
|
||||||
// on a single machine for synchronous distributed training
|
// on a single machine for synchronous distributed training
|
||||||
//BUT: This is *ONLY* the case if all workers get an identical number of examples. This won't be the case if
|
//BUT: This is *ONLY* the case if all workers get an identical number of examples. This won't be the case if
|
||||||
|
@ -506,6 +523,10 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAverageEveryStepGraphCNN() {
|
public void testAverageEveryStepGraphCNN() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
//Idea: averaging every step with SGD (SGD updater + optimizer) is mathematically identical to doing the learning
|
//Idea: averaging every step with SGD (SGD updater + optimizer) is mathematically identical to doing the learning
|
||||||
// on a single machine for synchronous distributed training
|
// on a single machine for synchronous distributed training
|
||||||
//BUT: This is *ONLY* the case if all workers get an identical number of examples. This won't be the case if
|
//BUT: This is *ONLY* the case if all workers get an identical number of examples. This won't be the case if
|
||||||
|
@ -544,7 +565,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
|
|
||||||
for (int i = 0; i < seeds.length; i++) {
|
for (int i = 0; i < seeds.length; i++) {
|
||||||
List<DataSet> list =
|
List<DataSet> list =
|
||||||
getOneDataSetAsIndividalExamplesCNN(miniBatchSizePerWorker * nWorkers, seeds[i]);
|
getOneDataSetAsIndividalExamplesCNN(miniBatchSizePerWorker * nWorkers, seeds[i]);
|
||||||
JavaRDD<DataSet> rdd = sc.parallelize(list);
|
JavaRDD<DataSet> rdd = sc.parallelize(list);
|
||||||
|
|
||||||
sparkNet.fit(rdd);
|
sparkNet.fit(rdd);
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.spark.impl.paramavg;
|
package org.deeplearning4j.spark.impl.paramavg;
|
||||||
|
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.hadoop.conf.Configuration;
|
import org.apache.hadoop.conf.Configuration;
|
||||||
import org.apache.hadoop.fs.FileSystem;
|
import org.apache.hadoop.fs.FileSystem;
|
||||||
import org.apache.hadoop.fs.LocatedFileStatus;
|
import org.apache.hadoop.fs.LocatedFileStatus;
|
||||||
|
@ -113,6 +114,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testFromSvmLightBackprop() throws Exception {
|
public void testFromSvmLightBackprop() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
JavaRDD<LabeledPoint> data = MLUtils
|
JavaRDD<LabeledPoint> data = MLUtils
|
||||||
.loadLibSVMFile(sc.sc(),
|
.loadLibSVMFile(sc.sc(),
|
||||||
new ClassPathResource("svmLight/iris_svmLight_0.txt").getTempFileFromArchive()
|
new ClassPathResource("svmLight/iris_svmLight_0.txt").getTempFileFromArchive()
|
||||||
|
@ -145,6 +150,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testFromSvmLight() throws Exception {
|
public void testFromSvmLight() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
JavaRDD<LabeledPoint> data = MLUtils
|
JavaRDD<LabeledPoint> data = MLUtils
|
||||||
.loadLibSVMFile(sc.sc(),
|
.loadLibSVMFile(sc.sc(),
|
||||||
new ClassPathResource("svmLight/iris_svmLight_0.txt").getTempFileFromArchive()
|
new ClassPathResource("svmLight/iris_svmLight_0.txt").getTempFileFromArchive()
|
||||||
|
@ -175,7 +184,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRunIteration() {
|
public void testRunIteration() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
DataSet dataSet = new IrisDataSetIterator(5, 5).next();
|
DataSet dataSet = new IrisDataSetIterator(5, 5).next();
|
||||||
List<DataSet> list = dataSet.asList();
|
List<DataSet> list = dataSet.asList();
|
||||||
JavaRDD<DataSet> data = sc.parallelize(list);
|
JavaRDD<DataSet> data = sc.parallelize(list);
|
||||||
|
@ -195,6 +207,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testUpdaters() {
|
public void testUpdaters() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
SparkDl4jMultiLayer sparkNet = getBasicNetwork();
|
SparkDl4jMultiLayer sparkNet = getBasicNetwork();
|
||||||
MultiLayerNetwork netCopy = sparkNet.getNetwork().clone();
|
MultiLayerNetwork netCopy = sparkNet.getNetwork().clone();
|
||||||
|
|
||||||
|
@ -217,7 +233,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEvaluation() {
|
public void testEvaluation() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
SparkDl4jMultiLayer sparkNet = getBasicNetwork();
|
SparkDl4jMultiLayer sparkNet = getBasicNetwork();
|
||||||
MultiLayerNetwork netCopy = sparkNet.getNetwork().clone();
|
MultiLayerNetwork netCopy = sparkNet.getNetwork().clone();
|
||||||
|
|
||||||
|
@ -250,7 +269,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
public void testSmallAmountOfData() {
|
public void testSmallAmountOfData() {
|
||||||
//Idea: Test spark training where some executors don't get any data
|
//Idea: Test spark training where some executors don't get any data
|
||||||
//in this case: by having fewer examples (2 DataSets) than executors (local[*])
|
//in this case: by having fewer examples (2 DataSets) than executors (local[*])
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp())
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp())
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
|
||||||
.layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(nIn).nOut(3)
|
.layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(nIn).nOut(3)
|
||||||
|
@ -353,6 +375,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testParameterAveragingMultipleExamplesPerDataSet() throws Exception {
|
public void testParameterAveragingMultipleExamplesPerDataSet() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
int dataSetObjSize = 5;
|
int dataSetObjSize = 5;
|
||||||
int batchSizePerExecutor = 25;
|
int batchSizePerExecutor = 25;
|
||||||
List<DataSet> list = new ArrayList<>();
|
List<DataSet> list = new ArrayList<>();
|
||||||
|
@ -402,7 +428,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testFitViaStringPaths() throws Exception {
|
public void testFitViaStringPaths() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
Path tempDir = testDir.newFolder("DL4J-testFitViaStringPaths").toPath();
|
Path tempDir = testDir.newFolder("DL4J-testFitViaStringPaths").toPath();
|
||||||
File tempDirF = tempDir.toFile();
|
File tempDirF = tempDir.toFile();
|
||||||
tempDirF.deleteOnExit();
|
tempDirF.deleteOnExit();
|
||||||
|
@ -466,7 +495,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testFitViaStringPathsSize1() throws Exception {
|
public void testFitViaStringPathsSize1() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
Path tempDir = testDir.newFolder("DL4J-testFitViaStringPathsSize1").toPath();
|
Path tempDir = testDir.newFolder("DL4J-testFitViaStringPathsSize1").toPath();
|
||||||
File tempDirF = tempDir.toFile();
|
File tempDirF = tempDir.toFile();
|
||||||
tempDirF.deleteOnExit();
|
tempDirF.deleteOnExit();
|
||||||
|
@ -547,7 +579,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testFitViaStringPathsCompGraph() throws Exception {
|
public void testFitViaStringPathsCompGraph() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
Path tempDir = testDir.newFolder("DL4J-testFitViaStringPathsCG").toPath();
|
Path tempDir = testDir.newFolder("DL4J-testFitViaStringPathsCG").toPath();
|
||||||
Path tempDir2 = testDir.newFolder("DL4J-testFitViaStringPathsCG-MDS").toPath();
|
Path tempDir2 = testDir.newFolder("DL4J-testFitViaStringPathsCG-MDS").toPath();
|
||||||
File tempDirF = tempDir.toFile();
|
File tempDirF = tempDir.toFile();
|
||||||
|
@ -643,7 +678,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
@Test
|
@Test
|
||||||
@Ignore("AB 2019/05/23 - Failing on CI only - passing locally. Possible precision or threading issue")
|
@Ignore("AB 2019/05/23 - Failing on CI only - passing locally. Possible precision or threading issue")
|
||||||
public void testSeedRepeatability() throws Exception {
|
public void testSeedRepeatability() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new RmsProp())
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new RmsProp())
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.weightInit(WeightInit.XAVIER).list()
|
.weightInit(WeightInit.XAVIER).list()
|
||||||
|
@ -715,6 +753,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testIterationCounts() throws Exception {
|
public void testIterationCounts() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
int dataSetObjSize = 5;
|
int dataSetObjSize = 5;
|
||||||
int batchSizePerExecutor = 25;
|
int batchSizePerExecutor = 25;
|
||||||
List<DataSet> list = new ArrayList<>();
|
List<DataSet> list = new ArrayList<>();
|
||||||
|
@ -761,6 +803,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testIterationCountsGraph() throws Exception {
|
public void testIterationCountsGraph() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
int dataSetObjSize = 5;
|
int dataSetObjSize = 5;
|
||||||
int batchSizePerExecutor = 25;
|
int batchSizePerExecutor = 25;
|
||||||
List<DataSet> list = new ArrayList<>();
|
List<DataSet> list = new ArrayList<>();
|
||||||
|
@ -806,7 +852,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Ignore //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656
|
@Ignore //Ignored 2019/04/09 - low priority: https://github.com/eclipse/deeplearning4j/issues/6656
|
||||||
public void testVaePretrainSimple() {
|
public void testVaePretrainSimple() {
|
||||||
//Simple sanity check on pretraining
|
//Simple sanity check on pretraining
|
||||||
int nIn = 8;
|
int nIn = 8;
|
||||||
|
@ -842,7 +888,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Ignore //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656
|
@Ignore //Ignored 2019/04/09 - low priority: https://github.com/eclipse/deeplearning4j/issues/6656
|
||||||
public void testVaePretrainSimpleCG() {
|
public void testVaePretrainSimpleCG() {
|
||||||
//Simple sanity check on pretraining
|
//Simple sanity check on pretraining
|
||||||
int nIn = 8;
|
int nIn = 8;
|
||||||
|
@ -992,7 +1038,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Test(timeout = 120000L)
|
||||||
public void testEpochCounter() throws Exception {
|
public void testEpochCounter() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.list()
|
.list()
|
||||||
.layer(new OutputLayer.Builder().nIn(4).nOut(3).build())
|
.layer(new OutputLayer.Builder().nIn(4).nOut(3).build())
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.impl.stats;
|
package org.deeplearning4j.spark.impl.stats;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
|
@ -56,6 +57,10 @@ public class TestTrainingStatsCollection extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testStatsCollection() throws Exception {
|
public void testStatsCollection() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
int nWorkers = numExecutors();
|
int nWorkers = numExecutors();
|
||||||
|
|
||||||
JavaSparkContext sc = getContext();
|
JavaSparkContext sc = getContext();
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.ui;
|
package org.deeplearning4j.spark.ui;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.deeplearning4j.core.storage.Persistable;
|
import org.deeplearning4j.core.storage.Persistable;
|
||||||
|
@ -52,7 +53,10 @@ public class TestListeners extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testStatsCollection() {
|
public void testStatsCollection() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
JavaSparkContext sc = getContext();
|
JavaSparkContext sc = getContext();
|
||||||
int nExecutors = numExecutors();
|
int nExecutors = numExecutors();
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.util;
|
package org.deeplearning4j.spark.util;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.spark.Partitioner;
|
import org.apache.spark.Partitioner;
|
||||||
import org.apache.spark.api.java.JavaPairRDD;
|
import org.apache.spark.api.java.JavaPairRDD;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
|
@ -50,6 +51,10 @@ public class TestRepartitioning extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRepartitioning() {
|
public void testRepartitioning() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
List<String> list = new ArrayList<>();
|
List<String> list = new ArrayList<>();
|
||||||
for (int i = 0; i < 1000; i++) {
|
for (int i = 0; i < 1000; i++) {
|
||||||
list.add(String.valueOf(i));
|
list.add(String.valueOf(i));
|
||||||
|
@ -71,7 +76,10 @@ public class TestRepartitioning extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRepartitioning2() throws Exception {
|
public void testRepartitioning2() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
int[] ns;
|
int[] ns;
|
||||||
if(isIntegrationTests()){
|
if(isIntegrationTests()){
|
||||||
ns = new int[]{320, 321, 25600, 25601, 25615};
|
ns = new int[]{320, 321, 25600, 25601, 25615};
|
||||||
|
@ -133,7 +141,10 @@ public class TestRepartitioning extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRepartitioning3(){
|
public void testRepartitioning3(){
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
//Initial partitions (idx, count) - [(0,29), (1,29), (2,29), (3,34), (4,34), (5,35), (6,34)]
|
//Initial partitions (idx, count) - [(0,29), (1,29), (2,29), (3,34), (4,34), (5,35), (6,34)]
|
||||||
|
|
||||||
List<Integer> ints = new ArrayList<>();
|
List<Integer> ints = new ArrayList<>();
|
||||||
|
@ -194,9 +205,13 @@ public class TestRepartitioning extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRepartitioning4(){
|
public void testRepartitioning4() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
List<Integer> ints = new ArrayList<>();
|
List<Integer> ints = new ArrayList<>();
|
||||||
for( int i=0; i<7040; i++ ){
|
for( int i = 0; i < 7040; i++) {
|
||||||
ints.add(i);
|
ints.add(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -230,6 +245,10 @@ public class TestRepartitioning extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRepartitioningApprox() {
|
public void testRepartitioningApprox() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
List<String> list = new ArrayList<>();
|
List<String> list = new ArrayList<>();
|
||||||
for (int i = 0; i < 1000; i++) {
|
for (int i = 0; i < 1000; i++) {
|
||||||
list.add(String.valueOf(i));
|
list.add(String.valueOf(i));
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.util;
|
package org.deeplearning4j.spark.util;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.deeplearning4j.spark.BaseSparkTest;
|
import org.deeplearning4j.spark.BaseSparkTest;
|
||||||
import org.deeplearning4j.spark.util.data.SparkDataValidation;
|
import org.deeplearning4j.spark.util.data.SparkDataValidation;
|
||||||
|
@ -46,10 +47,13 @@ public class TestValidation extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDataSetValidation() throws Exception {
|
public void testDataSetValidation() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
File f = folder.newFolder();
|
File f = folder.newFolder();
|
||||||
|
|
||||||
for( int i=0; i<3; i++ ) {
|
for( int i = 0; i < 3; i++ ) {
|
||||||
DataSet ds = new DataSet(Nd4j.create(1,10), Nd4j.create(1,10));
|
DataSet ds = new DataSet(Nd4j.create(1,10), Nd4j.create(1,10));
|
||||||
ds.save(new File(f, i + ".bin"));
|
ds.save(new File(f, i + ".bin"));
|
||||||
}
|
}
|
||||||
|
@ -110,10 +114,13 @@ public class TestValidation extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMultiDataSetValidation() throws Exception {
|
public void testMultiDataSetValidation() throws Exception {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
//Spark tests don't run on windows
|
||||||
|
return;
|
||||||
|
}
|
||||||
File f = folder.newFolder();
|
File f = folder.newFolder();
|
||||||
|
|
||||||
for( int i=0; i<3; i++ ) {
|
for( int i = 0; i < 3; i++ ) {
|
||||||
MultiDataSet ds = new MultiDataSet(Nd4j.create(1,10), Nd4j.create(1,10));
|
MultiDataSet ds = new MultiDataSet(Nd4j.create(1,10), Nd4j.create(1,10));
|
||||||
ds.save(new File(f, i + ".bin"));
|
ds.save(new File(f, i + ".bin"));
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,6 @@
|
||||||
package org.deeplearning4j.ui;
|
package org.deeplearning4j.ui;
|
||||||
|
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
import org.deeplearning4j.plot.BarnesHutTsne;
|
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
@ -38,34 +37,6 @@ import java.util.List;
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
*/
|
*/
|
||||||
public class ApiTest {
|
public class ApiTest {
|
||||||
@Test
|
|
||||||
@Ignore
|
|
||||||
public void testUpdateCoords() throws Exception {
|
|
||||||
Nd4j.factory().setDType(DataType.DOUBLE);
|
|
||||||
Nd4j.getRandom().setSeed(123);
|
|
||||||
BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(250).theta(0.5).learningRate(500)
|
|
||||||
.useAdaGrad(false).numDimension(2).build();
|
|
||||||
|
|
||||||
File f = Resources.asFile("/deeplearning4j-core/mnist2500_X.txt");
|
|
||||||
INDArray data = Nd4j.readNumpy(f.getAbsolutePath(), " ").get(NDArrayIndex.interval(0, 100),
|
|
||||||
NDArrayIndex.interval(0, 784));
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ClassPathResource labels = new ClassPathResource("mnist2500_labels.txt");
|
|
||||||
List<String> labelsList = IOUtils.readLines(labels.getInputStream()).subList(0, 100);
|
|
||||||
b.fit(data);
|
|
||||||
b.saveAsFile(labelsList, "coords.csv");
|
|
||||||
// String coords = client.target("http://localhost:8080").path("api").path("update")
|
|
||||||
// .request().accept(MediaType.APPLICATION_JSON)
|
|
||||||
//// .post(Entity.entity(new UrlResource("http://localhost:8080/api/coords.csv"), MediaType.APPLICATION_JSON))
|
|
||||||
// .readEntity(String.class);
|
|
||||||
// ObjectMapper mapper = new ObjectMapper();
|
|
||||||
// List<String> testLines = mapper.readValue(coords,List.class);
|
|
||||||
// List<String> lines = IOUtils.readLines(new FileInputStream("coords.csv"));
|
|
||||||
// assertEquals(testLines,lines);
|
|
||||||
|
|
||||||
throw new RuntimeException("Not implemented");
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,7 +42,6 @@ import org.deeplearning4j.nn.conf.weightnoise.DropConnect;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
||||||
import org.deeplearning4j.plot.BarnesHutTsne;
|
|
||||||
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
||||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||||
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
|
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
|
||||||
|
@ -84,7 +83,6 @@ import static org.junit.Assert.fail;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class ManualTests {
|
public class ManualTests {
|
||||||
|
|
||||||
private static Logger log = LoggerFactory.getLogger(ManualTests.class);
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testLaunch() throws Exception {
|
public void testLaunch() throws Exception {
|
||||||
|
@ -100,33 +98,7 @@ public class ManualTests {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test(timeout = 300000)
|
|
||||||
public void testTsne() throws Exception {
|
|
||||||
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
|
||||||
Nd4j.getRandom().setSeed(123);
|
|
||||||
BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(10).setMaxIter(10).theta(0.5).learningRate(500)
|
|
||||||
.useAdaGrad(true).build();
|
|
||||||
|
|
||||||
File f = Resources.asFile("/deeplearning4j-core/mnist2500_X.txt");
|
|
||||||
INDArray data = Nd4j.readNumpy(f.getAbsolutePath(), " ").get(NDArrayIndex.interval(0, 100),
|
|
||||||
NDArrayIndex.interval(0, 784));
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ClassPathResource labels = new ClassPathResource("mnist2500_labels.txt");
|
|
||||||
List<String> labelsList = IOUtils.readLines(labels.getInputStream()).subList(0, 100);
|
|
||||||
b.fit(data);
|
|
||||||
File save = new File(System.getProperty("java.io.tmpdir"), "labels-" + UUID.randomUUID().toString());
|
|
||||||
System.out.println("Saved to " + save.getAbsolutePath());
|
|
||||||
save.deleteOnExit();
|
|
||||||
b.saveAsFile(labelsList, save.getAbsolutePath());
|
|
||||||
|
|
||||||
INDArray output = b.getData();
|
|
||||||
System.out.println("Coordinates");
|
|
||||||
|
|
||||||
UIServer server = UIServer.getInstance();
|
|
||||||
Thread.sleep(10000000000L);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This test is for manual execution only, since it's here just to get working CNN and visualize it's layers
|
* This test is for manual execution only, since it's here just to get working CNN and visualize it's layers
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
#
|
||||||
|
# /* ******************************************************************************
|
||||||
|
# *
|
||||||
|
# *
|
||||||
|
# * This program and the accompanying materials are made available under the
|
||||||
|
# * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
# * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
# *
|
||||||
|
# * See the NOTICE file distributed with this work for additional
|
||||||
|
# * information regarding copyright ownership.
|
||||||
|
# * Unless required by applicable law or agreed to in writing, software
|
||||||
|
# * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
# * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
# * License for the specific language governing permissions and limitations
|
||||||
|
# * under the License.
|
||||||
|
# *
|
||||||
|
# * SPDX-License-Identifier: Apache-2.0
|
||||||
|
# ******************************************************************************/
|
||||||
|
#
|
||||||
|
|
||||||
|
real.class.double = org.nd4j.linalg.cpu.NDArray
|
||||||
|
shapeinfoprovider = org.nd4j.linalg.cpu.nativecpu.DirectShapeInfoProvider
|
||||||
|
constantsprovider = org.nd4j.linalg.cpu.nativecpu.cache.ConstantBuffersCache
|
||||||
|
affinitymanager = org.nd4j.linalg.cpu.nativecpu.CpuAffinityManager
|
||||||
|
memorymanager = org.nd4j.linalg.cpu.nativecpu.CpuMemoryManager
|
||||||
|
dtype = float
|
||||||
|
blas.ops = org.nd4j.linalg.cpu.nativecpu.BlasWrapper
|
||||||
|
|
||||||
|
native.ops= org.nd4j.nativeblas.Nd4jCpu
|
||||||
|
ndarrayfactory.class = org.nd4j.linalg.cpu.nativecpu.CpuNDArrayFactory
|
||||||
|
ndarray.order = c
|
||||||
|
resourcemanager_state = false
|
||||||
|
databufferfactory = org.nd4j.linalg.cpu.nativecpu.buffer.DefaultDataBufferFactory
|
||||||
|
workspacemanager = org.nd4j.linalg.cpu.nativecpu.workspace.CpuWorkspaceManager
|
||||||
|
alloc = javacpp
|
||||||
|
opexec= org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner
|
||||||
|
opexec.mode= native
|
||||||
|
random=org.nd4j.linalg.cpu.nativecpu.rng.CpuNativeRandom
|
|
@ -72,7 +72,7 @@ public abstract class ZooModel<T> implements InstantiableModel {
|
||||||
|
|
||||||
if (!cachedFile.exists()) {
|
if (!cachedFile.exists()) {
|
||||||
log.info("Downloading model to " + cachedFile.toString());
|
log.info("Downloading model to " + cachedFile.toString());
|
||||||
FileUtils.copyURLToFile(new URL(remoteUrl), cachedFile);
|
FileUtils.copyURLToFile(new URL(remoteUrl), cachedFile,Integer.MAX_VALUE,Integer.MAX_VALUE);
|
||||||
} else {
|
} else {
|
||||||
log.info("Using cached model at " + cachedFile.toString());
|
log.info("Using cached model at " + cachedFile.toString());
|
||||||
}
|
}
|
||||||
|
@ -89,7 +89,7 @@ public abstract class ZooModel<T> implements InstantiableModel {
|
||||||
log.error("Checksums do not match. Cleaning up files and failing...");
|
log.error("Checksums do not match. Cleaning up files and failing...");
|
||||||
cachedFile.delete();
|
cachedFile.delete();
|
||||||
throw new IllegalStateException(
|
throw new IllegalStateException(
|
||||||
"Pretrained model file failed checksum. If this error persists, please open an issue at https://github.com/deeplearning4j/deeplearning4j.");
|
"Pretrained model file failed checksum. If this error persists, please open an issue at https://github.com/eclipse/deeplearning4j.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.transferlearning.TransferLearning;
|
import org.deeplearning4j.nn.transferlearning.TransferLearning;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.zoo.model.VGG16;
|
import org.deeplearning4j.zoo.model.VGG16;
|
||||||
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
@ -33,17 +34,16 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
@Ignore("Times out too often")
|
||||||
public class MiscTests extends BaseDL4JTest {
|
public class MiscTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 240000L;
|
return Long.MAX_VALUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testTransferVGG() throws Exception {
|
public void testTransferVGG() throws Exception {
|
||||||
//https://github.com/deeplearning4j/deeplearning4j/issues/5167
|
|
||||||
DataSet ds = new DataSet();
|
DataSet ds = new DataSet();
|
||||||
ds.setFeatures(Nd4j.create(1, 3, 224, 224));
|
ds.setFeatures(Nd4j.create(1, 3, 224, 224));
|
||||||
ds.setLabels(Nd4j.create(1, 2));
|
ds.setLabels(Nd4j.create(1, 2));
|
||||||
|
|
|
@ -44,6 +44,7 @@ import java.util.Map;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@Ignore("Times out too often")
|
||||||
public class TestDownload extends BaseDL4JTest {
|
public class TestDownload extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -54,6 +54,7 @@ import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@Ignore("Times out too often")
|
||||||
public class TestImageNet extends BaseDL4JTest {
|
public class TestImageNet extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -52,6 +52,7 @@ import static org.junit.Assert.assertArrayEquals;
|
||||||
import static org.junit.Assume.assumeTrue;
|
import static org.junit.Assume.assumeTrue;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@Ignore("Times out too often")
|
||||||
public class TestInstantiation extends BaseDL4JTest {
|
public class TestInstantiation extends BaseDL4JTest {
|
||||||
|
|
||||||
protected static void ignoreIfCuda(){
|
protected static void ignoreIfCuda(){
|
||||||
|
|
|
@ -59,7 +59,6 @@
|
||||||
<module>deeplearning4j-modelexport-solr</module>
|
<module>deeplearning4j-modelexport-solr</module>
|
||||||
<module>deeplearning4j-zoo</module>
|
<module>deeplearning4j-zoo</module>
|
||||||
<module>deeplearning4j-data</module>
|
<module>deeplearning4j-data</module>
|
||||||
<module>deeplearning4j-manifold</module>
|
|
||||||
<module>dl4j-integration-tests</module>
|
<module>dl4j-integration-tests</module>
|
||||||
<module>deeplearning4j-common</module>
|
<module>deeplearning4j-common</module>
|
||||||
<module>deeplearning4j-common-tests</module>
|
<module>deeplearning4j-common-tests</module>
|
||||||
|
@ -231,7 +230,7 @@
|
||||||
-->
|
-->
|
||||||
<useSystemClassLoader>true</useSystemClassLoader>
|
<useSystemClassLoader>true</useSystemClassLoader>
|
||||||
<useManifestOnlyJar>false</useManifestOnlyJar>
|
<useManifestOnlyJar>false</useManifestOnlyJar>
|
||||||
<argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g</argLine>
|
<argLine> -Dfile.encoding=UTF-8 -Xmx8g -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||||
<includes>
|
<includes>
|
||||||
<!-- Default setting only runs tests that start/end with "Test" -->
|
<!-- Default setting only runs tests that start/end with "Test" -->
|
||||||
<include>*.java</include>
|
<include>*.java</include>
|
||||||
|
@ -292,6 +291,51 @@
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
|
<inherited>true</inherited>
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<artifactId>nd4j-native</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
<configuration>
|
||||||
|
<environmentVariables>
|
||||||
|
|
||||||
|
</environmentVariables>
|
||||||
|
<testSourceDirectory>src/test/java</testSourceDirectory>
|
||||||
|
<includes>
|
||||||
|
<include>*.java</include>
|
||||||
|
<include>**/*.java</include>
|
||||||
|
<include>**/Test*.java</include>
|
||||||
|
<include>**/*Test.java</include>
|
||||||
|
<include>**/*TestCase.java</include>
|
||||||
|
</includes>
|
||||||
|
<junitArtifactName>junit:junit</junitArtifactName>
|
||||||
|
<systemPropertyVariables>
|
||||||
|
<org.nd4j.linalg.defaultbackend>
|
||||||
|
org.nd4j.linalg.cpu.nativecpu.CpuBackend
|
||||||
|
</org.nd4j.linalg.defaultbackend>
|
||||||
|
<org.nd4j.linalg.tests.backendstorun>
|
||||||
|
org.nd4j.linalg.cpu.nativecpu.CpuBackend
|
||||||
|
</org.nd4j.linalg.tests.backendstorun>
|
||||||
|
</systemPropertyVariables>
|
||||||
|
<!--
|
||||||
|
Maximum heap size was set to 8g, as a minimum required value for tests run.
|
||||||
|
Depending on a build machine, default value is not always enough.
|
||||||
|
|
||||||
|
For testing large zoo models, this may not be enough (so comment it out).
|
||||||
|
-->
|
||||||
|
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||||
|
</configuration>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
</profile>
|
</profile>
|
||||||
<!-- For running unit tests with nd4j-cuda-8.0: "mvn clean test -P test-nd4j-cuda-8.0" -->
|
<!-- For running unit tests with nd4j-cuda-8.0: "mvn clean test -P test-nd4j-cuda-8.0" -->
|
||||||
<profile>
|
<profile>
|
||||||
|
@ -314,6 +358,47 @@
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
<!-- Default to ALL modules here, unlike nd4j-native -->
|
<!-- Default to ALL modules here, unlike nd4j-native -->
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.maven.surefire</groupId>
|
||||||
|
<artifactId>surefire-junit47</artifactId>
|
||||||
|
<version>2.19.1</version>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
<configuration>
|
||||||
|
<environmentVariables>
|
||||||
|
</environmentVariables>
|
||||||
|
<testSourceDirectory>src/test/java</testSourceDirectory>
|
||||||
|
<includes>
|
||||||
|
<include>*.java</include>
|
||||||
|
<include>**/*.java</include>
|
||||||
|
<include>**/Test*.java</include>
|
||||||
|
<include>**/*Test.java</include>
|
||||||
|
<include>**/*TestCase.java</include>
|
||||||
|
</includes>
|
||||||
|
<junitArtifactName>junit:junit</junitArtifactName>
|
||||||
|
<systemPropertyVariables>
|
||||||
|
<org.nd4j.linalg.defaultbackend>
|
||||||
|
org.nd4j.linalg.jcublas.JCublasBackend
|
||||||
|
</org.nd4j.linalg.defaultbackend>
|
||||||
|
<org.nd4j.linalg.tests.backendstorun>
|
||||||
|
org.nd4j.linalg.jcublas.JCublasBackend
|
||||||
|
</org.nd4j.linalg.tests.backendstorun>
|
||||||
|
</systemPropertyVariables>
|
||||||
|
<!--
|
||||||
|
Maximum heap size was set to 6g, as a minimum required value for tests run.
|
||||||
|
Depending on a build machine, default value is not always enough.
|
||||||
|
-->
|
||||||
|
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
||||||
|
</configuration>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -59,6 +59,6 @@ fi
|
||||||
unameOut="$(uname)"
|
unameOut="$(uname)"
|
||||||
echo "$OSTYPE"
|
echo "$OSTYPE"
|
||||||
|
|
||||||
../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests.exe
|
../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests
|
||||||
# Workaround to fix posix path conversion problem on Windows (http://mingw.org/wiki/Posix_path_conversion)
|
# Workaround to fix posix path conversion problem on Windows (http://mingw.org/wiki/Posix_path_conversion)
|
||||||
#[ -f "${GTEST_OUTPUT#*:}" ] && cp -a surefire-reports/ ../target && rm -rf surefire-reports/
|
[ -f "${GTEST_OUTPUT#*:}" ] && cp -a surefire-reports/ ../target && rm -rf surefire-reports/
|
||||||
|
|
|
@ -881,7 +881,7 @@ public class InferenceSession extends AbstractSession<INDArray, Pair<SameDiffOp,
|
||||||
for (int i = 0; i < outShape.size(); i++) {
|
for (int i = 0; i < outShape.size(); i++) {
|
||||||
LongShapeDescriptor reqShape = outShape.get(i);
|
LongShapeDescriptor reqShape = outShape.get(i);
|
||||||
|
|
||||||
//Issue: many ops have multiple valid output datatypes, and output shape calc can't at present know which: https://github.com/deeplearning4j/deeplearning4j/issues/6872
|
//Issue: many ops have multiple valid output datatypes, and output shape calc can't at present know which: https://github.com/eclipse/deeplearning4j/issues/6872
|
||||||
//As a workaround, we'll use the output variable datatype instead.
|
//As a workaround, we'll use the output variable datatype instead.
|
||||||
DataType dt = sameDiff.getVariable(outNames[i]).dataType();
|
DataType dt = sameDiff.getVariable(outNames[i]).dataType();
|
||||||
DataType currDT = reqShape.dataType();
|
DataType currDT = reqShape.dataType();
|
||||||
|
|
|
@ -189,7 +189,7 @@ public class ROCBinary extends BaseEvaluation<ROCBinary> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO Temporary workaround for: https://github.com/deeplearning4j/deeplearning4j/issues/7102
|
//TODO Temporary workaround for: https://github.com/eclipse/deeplearning4j/issues/7102
|
||||||
if(prob.isView())
|
if(prob.isView())
|
||||||
prob = prob.dup();
|
prob = prob.dup();
|
||||||
if(label.isView())
|
if(label.isView())
|
||||||
|
|
|
@ -221,7 +221,7 @@ public class ROCMultiClass extends BaseEvaluation<ROCMultiClass> {
|
||||||
for (int i = 0; i < n; i++) {
|
for (int i = 0; i < n; i++) {
|
||||||
INDArray prob = predictions2d.getColumn(i, true); //Probability of class i
|
INDArray prob = predictions2d.getColumn(i, true); //Probability of class i
|
||||||
INDArray label = labels2d.getColumn(i, true);
|
INDArray label = labels2d.getColumn(i, true);
|
||||||
//Workaround for: https://github.com/deeplearning4j/deeplearning4j/issues/7305
|
//Workaround for: https://github.com/eclipse/deeplearning4j/issues/7305
|
||||||
if(prob.rank() == 0)
|
if(prob.rank() == 0)
|
||||||
prob = prob.reshape(1,1);
|
prob = prob.reshape(1,1);
|
||||||
if(label.rank() == 0)
|
if(label.rank() == 0)
|
||||||
|
|
|
@ -73,7 +73,7 @@ public class Min extends BaseDynamicTransformOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
//TODO Switch to minimum_bp op - https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/broadcastable/minimum.cpp
|
//TODO Switch to minimum_bp op - https://github.com/eclipse/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/broadcastable/minimum.cpp
|
||||||
SDVariable min = outputVariables()[0];
|
SDVariable min = outputVariables()[0];
|
||||||
SDVariable eq1 = sameDiff.eq(larg(), min).castTo(arg(0).dataType());
|
SDVariable eq1 = sameDiff.eq(larg(), min).castTo(arg(0).dataType());
|
||||||
SDVariable eq2 = sameDiff.eq(rarg(), min).castTo(arg(1).dataType());
|
SDVariable eq2 = sameDiff.eq(rarg(), min).castTo(arg(1).dataType());
|
||||||
|
|
|
@ -56,7 +56,7 @@ public class Pow extends DynamicCustomOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
//TODO: replace this with discrete op once available: https://github.com/deeplearning4j/deeplearning4j/issues/7461
|
//TODO: replace this with discrete op once available: https://github.com/eclipse/deeplearning4j/issues/7461
|
||||||
//If y=a^b, then:
|
//If y=a^b, then:
|
||||||
//dL/da = b*a^(b-1) * dL/dy
|
//dL/da = b*a^(b-1) * dL/dy
|
||||||
//dL/db = a^b * log(a) * dL/dy
|
//dL/db = a^b * log(a) * dL/dy
|
||||||
|
|
|
@ -84,7 +84,7 @@ public class RandomStandardNormal extends DynamicCustomOp {
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
|
||||||
//Input data type specifies the shape; output data type should be any float
|
//Input data type specifies the shape; output data type should be any float
|
||||||
//TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
|
//TODO MAKE CONFIGUREABLE - https://github.com/eclipse/deeplearning4j/issues/6854
|
||||||
return Collections.singletonList(DataType.FLOAT);
|
return Collections.singletonList(DataType.FLOAT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,7 +65,7 @@ public class RandomBernoulli extends DynamicCustomOp {
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
|
||||||
//Input data type specifies the shape; output data type should be any float
|
//Input data type specifies the shape; output data type should be any float
|
||||||
//TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
|
//TODO MAKE CONFIGUREABLE - https://github.com/eclipse/deeplearning4j/issues/6854
|
||||||
return Collections.singletonList(DataType.FLOAT);
|
return Collections.singletonList(DataType.FLOAT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -80,7 +80,7 @@ public class RandomExponential extends DynamicCustomOp {
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
|
||||||
//Input data type specifies the shape; output data type should be any float
|
//Input data type specifies the shape; output data type should be any float
|
||||||
//TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
|
//TODO MAKE CONFIGUREABLE - https://github.com/eclipse/deeplearning4j/issues/6854
|
||||||
return Collections.singletonList(DataType.FLOAT);
|
return Collections.singletonList(DataType.FLOAT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,7 +66,7 @@ public class RandomNormal extends DynamicCustomOp {
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
|
||||||
//Input data type specifies the shape; output data type should be any float
|
//Input data type specifies the shape; output data type should be any float
|
||||||
//TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
|
//TODO MAKE CONFIGUREABLE - https://github.com/eclipse/deeplearning4j/issues/6854
|
||||||
return Collections.singletonList(DataType.FLOAT);
|
return Collections.singletonList(DataType.FLOAT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -118,7 +118,7 @@ public class BernoulliDistribution extends BaseRandomOp {
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes);
|
Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes);
|
||||||
//Input data type specifies the shape; output data type should be any float
|
//Input data type specifies the shape; output data type should be any float
|
||||||
//TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
|
//TODO MAKE CONFIGUREABLE - https://github.com/eclipse/deeplearning4j/issues/6854
|
||||||
return Collections.singletonList(dataType);
|
return Collections.singletonList(dataType);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -140,7 +140,7 @@ public class BinomialDistribution extends BaseRandomOp {
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes);
|
Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes);
|
||||||
//Input data type specifies the shape; output data type should be any float
|
//Input data type specifies the shape; output data type should be any float
|
||||||
//TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
|
//TODO MAKE CONFIGUREABLE - https://github.com/eclipse/deeplearning4j/issues/6854
|
||||||
return Collections.singletonList(DataType.DOUBLE);
|
return Collections.singletonList(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -91,28 +91,28 @@ public class Linspace extends BaseRandomOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray x(){
|
public INDArray x(){
|
||||||
//Workaround/hack for: https://github.com/deeplearning4j/deeplearning4j/issues/6723
|
//Workaround/hack for: https://github.com/eclipse/deeplearning4j/issues/6723
|
||||||
//If x or y is present, can't execute this op properly (wrong signature is used)
|
//If x or y is present, can't execute this op properly (wrong signature is used)
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray y(){
|
public INDArray y(){
|
||||||
//Workaround/hack for: https://github.com/deeplearning4j/deeplearning4j/issues/6723
|
//Workaround/hack for: https://github.com/eclipse/deeplearning4j/issues/6723
|
||||||
//If x or y is present, can't execute this op properly (wrong signature is used)
|
//If x or y is present, can't execute this op properly (wrong signature is used)
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setX(INDArray x){
|
public void setX(INDArray x){
|
||||||
//Workaround/hack for: https://github.com/deeplearning4j/deeplearning4j/issues/6723
|
//Workaround/hack for: https://github.com/eclipse/deeplearning4j/issues/6723
|
||||||
//If x or y is present, can't execute this op properly (wrong signature is used)
|
//If x or y is present, can't execute this op properly (wrong signature is used)
|
||||||
this.x = null;
|
this.x = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setY(INDArray y){
|
public void setY(INDArray y){
|
||||||
//Workaround for: https://github.com/deeplearning4j/deeplearning4j/issues/6723
|
//Workaround for: https://github.com/eclipse/deeplearning4j/issues/6723
|
||||||
//If x or y is present, can't execute this op properly (wrong signature is used)
|
//If x or y is present, can't execute this op properly (wrong signature is used)
|
||||||
this.y = null;
|
this.y = null;
|
||||||
}
|
}
|
||||||
|
|
|
@ -139,7 +139,7 @@ public class TruncatedNormalDistribution extends BaseRandomOp {
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes);
|
Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes);
|
||||||
//Input data type specifies the shape; output data type should be any float
|
//Input data type specifies the shape; output data type should be any float
|
||||||
//TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
|
//TODO MAKE CONFIGUREABLE - https://github.com/eclipse/deeplearning4j/issues/6854
|
||||||
return Collections.singletonList(DataType.DOUBLE);
|
return Collections.singletonList(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -110,7 +110,7 @@ public class UniformDistribution extends BaseRandomOp {
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes);
|
Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes);
|
||||||
//Input data type specifies the shape; output data type should be any float
|
//Input data type specifies the shape; output data type should be any float
|
||||||
//TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
|
//TODO MAKE CONFIGUREABLE - https://github.com/eclipse/deeplearning4j/issues/6854
|
||||||
return Collections.singletonList(dataType);
|
return Collections.singletonList(dataType);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -80,7 +80,7 @@ public class VersionInfo {
|
||||||
|
|
||||||
public VersionInfo(URI uri) throws IOException {
|
public VersionInfo(URI uri) throws IOException {
|
||||||
//Can't use new File(uri).getPath() for URIs pointing to resources in JARs
|
//Can't use new File(uri).getPath() for URIs pointing to resources in JARs
|
||||||
//But URI.toString() returns "%2520" instead of spaces in path - https://github.com/deeplearning4j/deeplearning4j/issues/6056
|
//But URI.toString() returns "%2520" instead of spaces in path - https://github.com/eclipse/deeplearning4j/issues/6056
|
||||||
String path = uri.toString().replaceAll(HTML_SPACE, " ");
|
String path = uri.toString().replaceAll(HTML_SPACE, " ");
|
||||||
int idxOf = path.lastIndexOf('/');
|
int idxOf = path.lastIndexOf('/');
|
||||||
idxOf = Math.max(idxOf, path.lastIndexOf('\\'));
|
idxOf = Math.max(idxOf, path.lastIndexOf('\\'));
|
||||||
|
|
|
@ -141,7 +141,7 @@
|
||||||
Maximum heap size was set to 8g, as a minimum required value for tests run.
|
Maximum heap size was set to 8g, as a minimum required value for tests run.
|
||||||
Depending on a build machine, default value is not always enough.
|
Depending on a build machine, default value is not always enough.
|
||||||
-->
|
-->
|
||||||
<argLine>-Dorg.bytedeco.javacpp.logger.debug=true -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
||||||
</configuration>
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
<plugin>
|
<plugin>
|
||||||
|
|
|
@ -1,316 +0,0 @@
|
||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<!--
|
|
||||||
~ /* ******************************************************************************
|
|
||||||
~ *
|
|
||||||
~ *
|
|
||||||
~ * This program and the accompanying materials are made available under the
|
|
||||||
~ * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
~ * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
~ *
|
|
||||||
~ * See the NOTICE file distributed with this work for additional
|
|
||||||
~ * information regarding copyright ownership.
|
|
||||||
~ * Unless required by applicable law or agreed to in writing, software
|
|
||||||
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
~ * License for the specific language governing permissions and limitations
|
|
||||||
~ * under the License.
|
|
||||||
~ *
|
|
||||||
~ * SPDX-License-Identifier: Apache-2.0
|
|
||||||
~ ******************************************************************************/
|
|
||||||
-->
|
|
||||||
|
|
||||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
|
||||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
|
||||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
|
||||||
|
|
||||||
<modelVersion>4.0.0</modelVersion>
|
|
||||||
|
|
||||||
<parent>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-backends</artifactId>
|
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
|
||||||
</parent>
|
|
||||||
|
|
||||||
<artifactId>nd4j-tests-tensorflow</artifactId>
|
|
||||||
|
|
||||||
<name>nd4j-tests-tensorflow</name>
|
|
||||||
|
|
||||||
<properties>
|
|
||||||
<maven.compiler.source>1.8</maven.compiler.source>
|
|
||||||
<maven.compiler.target>1.8</maven.compiler.target>
|
|
||||||
<scala.binary.version>2.11</scala.binary.version>
|
|
||||||
<maven.compiler.testTarget>1.8</maven.compiler.testTarget>
|
|
||||||
<maven.compiler.testSource>1.8</maven.compiler.testSource>
|
|
||||||
</properties>
|
|
||||||
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-tensorflow</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>junit</groupId>
|
|
||||||
<artifactId>junit</artifactId>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>ch.qos.logback</groupId>
|
|
||||||
<artifactId>logback-classic</artifactId>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-common-tests</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
|
|
||||||
<build>
|
|
||||||
<testSourceDirectory>${test.root}</testSourceDirectory>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-enforcer-plugin</artifactId>
|
|
||||||
<executions>
|
|
||||||
<execution>
|
|
||||||
<phase>test</phase>
|
|
||||||
<id>enforce-test-resources</id>
|
|
||||||
<goals>
|
|
||||||
<goal>enforce</goal>
|
|
||||||
</goals>
|
|
||||||
<configuration>
|
|
||||||
<skip>${skipTestResourceEnforcement}</skip>
|
|
||||||
<rules>
|
|
||||||
<requireActiveProfile>
|
|
||||||
<profiles>nd4j-tf-cpu,nd4j-tf-gpu</profiles>
|
|
||||||
<all>false</all>
|
|
||||||
</requireActiveProfile>
|
|
||||||
</rules>
|
|
||||||
<fail>true</fail>
|
|
||||||
</configuration>
|
|
||||||
</execution>
|
|
||||||
</executions>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
|
|
||||||
<profiles>
|
|
||||||
<profile>
|
|
||||||
<id>testresources</id>
|
|
||||||
<activation>
|
|
||||||
<activeByDefault>true</activeByDefault>
|
|
||||||
</activation>
|
|
||||||
</profile>
|
|
||||||
<profile>
|
|
||||||
<id>tf-cpu</id>
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.bytedeco</groupId>
|
|
||||||
<artifactId>tensorflow-platform</artifactId>
|
|
||||||
<version>${tensorflow.javacpp.version}</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
</profile>
|
|
||||||
<profile>
|
|
||||||
<id>tf-gpu</id>
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.bytedeco</groupId>
|
|
||||||
<artifactId>tensorflow</artifactId>
|
|
||||||
<version>${tensorflow.javacpp.version}</version>
|
|
||||||
<classifier>linux-x86_64-gpu</classifier>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.bytedeco</groupId>
|
|
||||||
<artifactId>tensorflow</artifactId>
|
|
||||||
<version>${tensorflow.javacpp.version}</version>
|
|
||||||
<classifier>windows-x86_64-gpu</classifier>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
</profile>
|
|
||||||
<profile>
|
|
||||||
<id>nd4j-tf-gpu</id>
|
|
||||||
<properties>
|
|
||||||
<test.root>src/test/gpujava</test.root>
|
|
||||||
</properties>
|
|
||||||
<build>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-failsafe-plugin</artifactId>
|
|
||||||
<version>2.18</version>
|
|
||||||
<executions>
|
|
||||||
<!--
|
|
||||||
Invokes both the integration-test and the verify goals of the
|
|
||||||
Failsafe Maven plugin
|
|
||||||
-->
|
|
||||||
<execution>
|
|
||||||
<id>integration-tests</id>
|
|
||||||
<phase>test</phase>
|
|
||||||
<goals>
|
|
||||||
<goal>integration-test</goal>
|
|
||||||
<goal>verify</goal>
|
|
||||||
</goals>
|
|
||||||
<configuration>
|
|
||||||
<!--
|
|
||||||
Skips integration tests if the value of skip.integration.tests
|
|
||||||
property is true
|
|
||||||
-->
|
|
||||||
<skipTests>false</skipTests>
|
|
||||||
</configuration>
|
|
||||||
</execution>
|
|
||||||
</executions>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.codehaus.mojo</groupId>
|
|
||||||
<artifactId>build-helper-maven-plugin</artifactId>
|
|
||||||
<version>1.9.1</version>
|
|
||||||
<executions>
|
|
||||||
<execution>
|
|
||||||
<id>add-integration-test-sources</id>
|
|
||||||
<phase>test-compile</phase>
|
|
||||||
<goals>
|
|
||||||
<goal>add-test-source</goal>
|
|
||||||
</goals>
|
|
||||||
<configuration>
|
|
||||||
<!-- Configures the source directory of our integration tests -->
|
|
||||||
<sources>
|
|
||||||
<source>src/test/gpujava</source>
|
|
||||||
</sources>
|
|
||||||
</configuration>
|
|
||||||
</execution>
|
|
||||||
</executions>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-compiler-plugin</artifactId>
|
|
||||||
<version>${maven-compiler-plugin.version}</version>
|
|
||||||
<configuration>
|
|
||||||
<source>1.8</source>
|
|
||||||
<target>1.8</target>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
|
||||||
<version>2.19.1</version>
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.maven.surefire</groupId>
|
|
||||||
<artifactId>surefire-junit47</artifactId>
|
|
||||||
<version>2.19.1</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
<configuration>
|
|
||||||
<testSourceDirectory>${project.basedir}/src/test/gpujava
|
|
||||||
</testSourceDirectory>
|
|
||||||
<includes>
|
|
||||||
<include>**/*.java</include>
|
|
||||||
</includes>
|
|
||||||
<systemPropertyVariables>
|
|
||||||
<org.nd4j.linalg.defaultbackend>
|
|
||||||
org.nd4j.linalg.jcublas.JCublasBackend
|
|
||||||
</org.nd4j.linalg.defaultbackend>
|
|
||||||
<org.nd4j.linalg.tests.backendstorun>
|
|
||||||
org.nd4j.linalg.jcublas.JCublasBackend
|
|
||||||
</org.nd4j.linalg.tests.backendstorun>
|
|
||||||
</systemPropertyVariables>
|
|
||||||
<!--
|
|
||||||
Maximum heap size was set to 6g, as a minimum required value for tests run.
|
|
||||||
Depending on a build machine, default value is not always enough.
|
|
||||||
-->
|
|
||||||
<skip>false</skip>
|
|
||||||
<argLine>-Xmx6g -Dfile.encoding=UTF-8</argLine>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-cuda-11.0</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.bytedeco</groupId>
|
|
||||||
<artifactId>tensorflow</artifactId>
|
|
||||||
<version>${tensorflow.javacpp.version}</version>
|
|
||||||
<classifier>linux-x86_64-gpu</classifier>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.bytedeco</groupId>
|
|
||||||
<artifactId>tensorflow</artifactId>
|
|
||||||
<version>${tensorflow.javacpp.version}</version>
|
|
||||||
<classifier>windows-x86_64-gpu</classifier>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
</profile>
|
|
||||||
<profile>
|
|
||||||
<id>nd4j-tf-cpu</id>
|
|
||||||
<properties>
|
|
||||||
<test.root>src/test/cpujava</test.root>
|
|
||||||
</properties>
|
|
||||||
<build>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-compiler-plugin</artifactId>
|
|
||||||
<version>${maven-compiler-plugin.version}</version>
|
|
||||||
<configuration>
|
|
||||||
<testSource>1.8</testSource>
|
|
||||||
<source>1.8</source>
|
|
||||||
<target>1.8</target>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
|
||||||
<version>2.19.1</version>
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.maven.surefire</groupId>
|
|
||||||
<artifactId>surefire-junit47</artifactId>
|
|
||||||
<version>2.19.1</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
<configuration>
|
|
||||||
<testSourceDirectory>${project.basedir}/src/test/cpujava
|
|
||||||
</testSourceDirectory>
|
|
||||||
<includes>
|
|
||||||
<include>**/*.java</include>
|
|
||||||
</includes>
|
|
||||||
<systemPropertyVariables>
|
|
||||||
<org.nd4j.linalg.defaultbackend>
|
|
||||||
org.nd4j.linalg.cpu.nativecpu.CpuBackend
|
|
||||||
</org.nd4j.linalg.defaultbackend>
|
|
||||||
<org.nd4j.linalg.tests.backendstorun>
|
|
||||||
org.nd4j.linalg.cpu.nativecpu.CpuBackend
|
|
||||||
</org.nd4j.linalg.tests.backendstorun>
|
|
||||||
</systemPropertyVariables>
|
|
||||||
<!--
|
|
||||||
Maximum heap size was set to 6g, as a minimum required value for tests run.
|
|
||||||
Depending on a build machine, default value is not always enough.
|
|
||||||
-->
|
|
||||||
<argLine>-Xmx6g -Dfile.encoding=UTF-8</argLine>
|
|
||||||
<skipTests>false</skipTests>
|
|
||||||
<skip>false</skip>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-native</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.bytedeco</groupId>
|
|
||||||
<artifactId>tensorflow-platform</artifactId>
|
|
||||||
<version>${tensorflow.javacpp.version}</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
</profile>
|
|
||||||
</profiles>
|
|
||||||
</project>
|
|
|
@ -1,193 +0,0 @@
|
||||||
/* ******************************************************************************
|
|
||||||
*
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* See the NOTICE file distributed with this work for additional
|
|
||||||
* information regarding copyright ownership.
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
package org.nd4j.tensorflow.conversion;
|
|
||||||
|
|
||||||
import junit.framework.TestCase;
|
|
||||||
import org.apache.commons.io.FileUtils;
|
|
||||||
import org.apache.commons.io.IOUtils;
|
|
||||||
import org.bytedeco.tensorflow.TF_Tensor;
|
|
||||||
import org.junit.Ignore;
|
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
|
||||||
import org.nd4j.common.resources.Resources;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.shade.protobuf.util.JsonFormat;
|
|
||||||
import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner;
|
|
||||||
import org.nd4j.tensorflow.conversion.graphrunner.SavedModelConfig;
|
|
||||||
import org.tensorflow.framework.ConfigProto;
|
|
||||||
import org.tensorflow.framework.GPUOptions;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
import static org.junit.Assert.assertNotNull;
|
|
||||||
|
|
||||||
public class GraphRunnerTest extends BaseND4JTest {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public DataType getDataType() {
|
|
||||||
return DataType.FLOAT;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public DataType getDefaultFPDataType() {
|
|
||||||
return DataType.FLOAT;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static ConfigProto getConfig(){
|
|
||||||
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
|
||||||
if("CUDA".equalsIgnoreCase(backend)) {
|
|
||||||
org.tensorflow.framework.ConfigProto configProto = org.tensorflow.framework.ConfigProto.getDefaultInstance();
|
|
||||||
ConfigProto.Builder b = configProto.toBuilder().addDeviceFilters(TensorflowConversion.defaultDeviceForThread());
|
|
||||||
return b.setGpuOptions(GPUOptions.newBuilder()
|
|
||||||
.setAllowGrowth(true)
|
|
||||||
.setPerProcessGpuMemoryFraction(0.5)
|
|
||||||
.build()).build();
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testGraphRunner() throws Exception {
|
|
||||||
List<String> inputs = Arrays.asList("input_0","input_1");
|
|
||||||
byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream());
|
|
||||||
|
|
||||||
try(GraphRunner graphRunner = GraphRunner.builder().graphBytes(content).inputNames(inputs).sessionOptionsConfigProto(getConfig()).build()) {
|
|
||||||
runGraphRunnerTest(graphRunner);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testGraphRunnerFilePath() throws Exception {
|
|
||||||
List<String> inputs = Arrays.asList("input_0","input_1");
|
|
||||||
byte[] content = FileUtils.readFileToByteArray(Resources.asFile("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb"));
|
|
||||||
|
|
||||||
try(GraphRunner graphRunner = GraphRunner.builder().graphBytes(content).inputNames(inputs).sessionOptionsConfigProto(getConfig()).build()) {
|
|
||||||
runGraphRunnerTest(graphRunner);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testInputOutputResolution() throws Exception {
|
|
||||||
ClassPathResource lenetPb = new ClassPathResource("tf_graphs/lenet_frozen.pb");
|
|
||||||
byte[] content = IOUtils.toByteArray(lenetPb.getInputStream());
|
|
||||||
List<String> inputs = Arrays.asList("Reshape/tensor");
|
|
||||||
try(GraphRunner graphRunner = GraphRunner.builder().graphBytes(content).inputNames(inputs).sessionOptionsConfigProto(getConfig()).build()) {
|
|
||||||
assertEquals(1, graphRunner.getInputOrder().size());
|
|
||||||
assertEquals(1, graphRunner.getOutputOrder().size());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test @Ignore //Ignored 2019/02/05: ssd_inception_v2_coco_2019_01_28 does not exist in test resources
|
|
||||||
public void testMultiOutputGraph() throws Exception {
|
|
||||||
List<String> inputs = Arrays.asList("image_tensor");
|
|
||||||
byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/examples/ssd_inception_v2_coco_2018_01_28/frozen_inference_graph.pb").getInputStream());
|
|
||||||
try(GraphRunner graphRunner = GraphRunner.builder().graphBytes(content).inputNames(inputs).sessionOptionsConfigProto(getConfig()).build()) {
|
|
||||||
String[] outputs = new String[]{"detection_boxes", "detection_scores", "detection_classes", "num_detections"};
|
|
||||||
|
|
||||||
assertEquals(1, graphRunner.getInputOrder().size());
|
|
||||||
System.out.println(graphRunner.getOutputOrder());
|
|
||||||
assertEquals(4, graphRunner.getOutputOrder().size());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void runGraphRunnerTest(GraphRunner graphRunner) throws Exception {
|
|
||||||
String json = graphRunner.sessionOptionsToJson();
|
|
||||||
if( json != null ) {
|
|
||||||
org.tensorflow.framework.ConfigProto.Builder builder = org.tensorflow.framework.ConfigProto.newBuilder();
|
|
||||||
JsonFormat.parser().merge(json, builder);
|
|
||||||
org.tensorflow.framework.ConfigProto build = builder.build();
|
|
||||||
assertEquals(build,graphRunner.getSessionOptionsConfigProto());
|
|
||||||
}
|
|
||||||
assertNotNull(graphRunner.getInputOrder());
|
|
||||||
assertNotNull(graphRunner.getOutputOrder());
|
|
||||||
|
|
||||||
|
|
||||||
org.tensorflow.framework.ConfigProto configProto1 = json == null ? null : GraphRunner.fromJson(json);
|
|
||||||
|
|
||||||
assertEquals(graphRunner.getSessionOptionsConfigProto(),configProto1);
|
|
||||||
assertEquals(2,graphRunner.getInputOrder().size());
|
|
||||||
assertEquals(1,graphRunner.getOutputOrder().size());
|
|
||||||
|
|
||||||
INDArray input1 = Nd4j.linspace(1,4,4).reshape(4);
|
|
||||||
INDArray input2 = Nd4j.linspace(1,4,4).reshape(4);
|
|
||||||
|
|
||||||
Map<String,INDArray> inputs = new LinkedHashMap<>();
|
|
||||||
inputs.put("input_0",input1);
|
|
||||||
inputs.put("input_1",input2);
|
|
||||||
|
|
||||||
for(int i = 0; i < 2; i++) {
|
|
||||||
Map<String,INDArray> outputs = graphRunner.run(inputs);
|
|
||||||
|
|
||||||
INDArray assertion = input1.add(input2);
|
|
||||||
assertEquals(assertion,outputs.get("output"));
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testGraphRunnerSavedModel() throws Exception {
|
|
||||||
File f = testDir.newFolder("test");
|
|
||||||
new ClassPathResource("/tf_saved_models/saved_model_counter/00000123/").copyDirectory(f);
|
|
||||||
SavedModelConfig savedModelConfig = SavedModelConfig.builder()
|
|
||||||
.savedModelPath(f.getAbsolutePath())
|
|
||||||
.signatureKey("incr_counter_by")
|
|
||||||
.modelTag("serve")
|
|
||||||
.build();
|
|
||||||
try(GraphRunner graphRunner = GraphRunner.builder().savedModelConfig(savedModelConfig).sessionOptionsConfigProto(getConfig()).build()) {
|
|
||||||
INDArray delta = Nd4j.create(new float[] { 42 }, new long[0]);
|
|
||||||
Map<String,INDArray> inputs = new LinkedHashMap<>();
|
|
||||||
inputs.put("delta:0",delta);
|
|
||||||
Map<String,INDArray> outputs = graphRunner.run(inputs);
|
|
||||||
assertEquals(1, outputs.size());
|
|
||||||
System.out.println(Arrays.toString(outputs.keySet().toArray(new String[0])));
|
|
||||||
INDArray output = outputs.values().toArray(new INDArray[0])[0];
|
|
||||||
assertEquals(42.0, output.getDouble(0), 0.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testGraphRunnerCast() {
|
|
||||||
INDArray arr = Nd4j.linspace(1,4,4).castTo(DataType.FLOAT);
|
|
||||||
TF_Tensor tensor = TensorflowConversion.getInstance().tensorFromNDArray(arr);
|
|
||||||
TF_Tensor tf_tensor = GraphRunner.castTensor(tensor, TensorDataType.FLOAT,TensorDataType.DOUBLE);
|
|
||||||
INDArray doubleNDArray = TensorflowConversion.getInstance().ndArrayFromTensor(tf_tensor);
|
|
||||||
TestCase.assertEquals(DataType.DOUBLE,doubleNDArray.dataType());
|
|
||||||
|
|
||||||
arr = arr.castTo(DataType.INT);
|
|
||||||
tensor = TensorflowConversion.getInstance().tensorFromNDArray(arr);
|
|
||||||
tf_tensor = GraphRunner.castTensor(tensor, TensorDataType.fromNd4jType(DataType.INT),TensorDataType.DOUBLE);
|
|
||||||
doubleNDArray = TensorflowConversion.getInstance().ndArrayFromTensor(tf_tensor);
|
|
||||||
TestCase.assertEquals(DataType.DOUBLE,doubleNDArray.dataType());
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,130 +0,0 @@
|
||||||
/* ******************************************************************************
|
|
||||||
*
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* See the NOTICE file distributed with this work for additional
|
|
||||||
* information regarding copyright ownership.
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.tensorflow.conversion;
|
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.io.IOUtils;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
|
||||||
import org.tensorflow.framework.GraphDef;
|
|
||||||
|
|
||||||
import org.bytedeco.tensorflow.*;
|
|
||||||
import static org.bytedeco.tensorflow.global.tensorflow.*;
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
import static org.junit.Assert.assertNotNull;
|
|
||||||
import static org.junit.Assert.fail;
|
|
||||||
import static org.nd4j.linalg.api.buffer.DataType.*;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class TensorflowConversionTest extends BaseND4JTest {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testView() {
|
|
||||||
INDArray matrix = Nd4j.linspace(1,8,8).reshape(2,4);
|
|
||||||
INDArray view = matrix.slice(0);
|
|
||||||
TensorflowConversion conversion =TensorflowConversion.getInstance();
|
|
||||||
TF_Tensor tf_tensor = conversion.tensorFromNDArray(view);
|
|
||||||
INDArray converted = conversion.ndArrayFromTensor(tf_tensor);
|
|
||||||
assertEquals(view,converted);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
|
||||||
public void testNullArray() {
|
|
||||||
INDArray array = Nd4j.create(2,2);
|
|
||||||
array.setData(null);
|
|
||||||
TensorflowConversion conversion =TensorflowConversion.getInstance();
|
|
||||||
TF_Tensor tf_tensor = conversion.tensorFromNDArray(array);
|
|
||||||
fail();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testConversionFromNdArray() throws Exception {
|
|
||||||
DataType[] dtypes = new DataType[]{
|
|
||||||
DOUBLE,
|
|
||||||
FLOAT,
|
|
||||||
SHORT,
|
|
||||||
LONG,
|
|
||||||
BYTE,
|
|
||||||
UBYTE,
|
|
||||||
UINT16,
|
|
||||||
UINT32,
|
|
||||||
UINT64,
|
|
||||||
BFLOAT16,
|
|
||||||
BOOL,
|
|
||||||
INT,
|
|
||||||
HALF
|
|
||||||
};
|
|
||||||
for(DataType dtype: dtypes){
|
|
||||||
log.debug("Testing conversion for data type " + dtype);
|
|
||||||
INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2).castTo(dtype);
|
|
||||||
TensorflowConversion tensorflowConversion =TensorflowConversion.getInstance();
|
|
||||||
TF_Tensor tf_tensor = tensorflowConversion.tensorFromNDArray(arr);
|
|
||||||
INDArray fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor);
|
|
||||||
assertEquals(arr,fromTensor);
|
|
||||||
if (dtype == BOOL){
|
|
||||||
arr.putScalar(3, 0);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
arr.addi(1.0);
|
|
||||||
}
|
|
||||||
tf_tensor = tensorflowConversion.tensorFromNDArray(arr);
|
|
||||||
fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor);
|
|
||||||
assertEquals(arr,fromTensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testCudaIfAvailable() throws Exception {
|
|
||||||
TensorflowConversion tensorflowConversion =TensorflowConversion.getInstance();
|
|
||||||
byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream());
|
|
||||||
//byte[] content = Files.readAllBytes(Paths.get(new File("/home/agibsonccc/code/dl4j-test-resources/src/main/resources/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").toURI()));
|
|
||||||
TF_Status status = TF_Status.newStatus();
|
|
||||||
TF_Graph initializedGraphForNd4jDevices = tensorflowConversion.loadGraph(content, status);
|
|
||||||
assertNotNull(initializedGraphForNd4jDevices);
|
|
||||||
|
|
||||||
String deviceName = tensorflowConversion.defaultDeviceForThread();
|
|
||||||
|
|
||||||
byte[] content2 = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream());
|
|
||||||
GraphDef graphDef1 = GraphDef.parseFrom(content2);
|
|
||||||
System.out.println(graphDef1);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testStringConversion() throws Exception {
|
|
||||||
String[] strings = {"one", "two", "three"};
|
|
||||||
INDArray arr = Nd4j.create(strings);
|
|
||||||
TensorflowConversion tensorflowConversion =TensorflowConversion.getInstance();
|
|
||||||
TF_Tensor tf_tensor = tensorflowConversion.tensorFromNDArray(arr);
|
|
||||||
INDArray fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor);
|
|
||||||
assertEquals(arr.length(), fromTensor.length());
|
|
||||||
for (int i = 0; i < arr.length(); i++) {
|
|
||||||
assertEquals(strings[i], fromTensor.getString(i));
|
|
||||||
assertEquals(arr.getString(i), fromTensor.getString(i));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,94 +0,0 @@
|
||||||
/* ******************************************************************************
|
|
||||||
*
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* See the NOTICE file distributed with this work for additional
|
|
||||||
* information regarding copyright ownership.
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.tensorflow.conversion;
|
|
||||||
|
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.shade.protobuf.util.JsonFormat;
|
|
||||||
import org.apache.commons.io.IOUtils;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
|
||||||
import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner;
|
|
||||||
import org.tensorflow.framework.ConfigProto;
|
|
||||||
import org.tensorflow.framework.GPUOptions;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.FileInputStream;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
import static org.junit.Assert.assertNotNull;
|
|
||||||
|
|
||||||
public class GpuGraphRunnerTest extends BaseND4JTest {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testGraphRunner() throws Exception {
|
|
||||||
byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream());
|
|
||||||
List<String> inputNames = Arrays.asList("input_0","input_1");
|
|
||||||
|
|
||||||
ConfigProto configProto = ConfigProto.newBuilder()
|
|
||||||
.setGpuOptions(GPUOptions.newBuilder()
|
|
||||||
.setPerProcessGpuMemoryFraction(0.1)
|
|
||||||
.setAllowGrowth(false)
|
|
||||||
.build())
|
|
||||||
.build();
|
|
||||||
|
|
||||||
try(GraphRunner graphRunner = GraphRunner.builder().graphBytes(content).inputNames(inputNames).sessionOptionsConfigProto(configProto).build()) {
|
|
||||||
org.tensorflow.framework.ConfigProto.Builder builder = org.tensorflow.framework.ConfigProto.newBuilder();
|
|
||||||
String json = graphRunner.sessionOptionsToJson();
|
|
||||||
JsonFormat.parser().merge(json,builder);
|
|
||||||
org.tensorflow.framework.ConfigProto build = builder.build();
|
|
||||||
assertEquals(build,graphRunner.getSessionOptionsConfigProto());
|
|
||||||
assertNotNull(graphRunner.getInputOrder());
|
|
||||||
assertNotNull(graphRunner.getOutputOrder());
|
|
||||||
|
|
||||||
|
|
||||||
org.tensorflow.framework.ConfigProto configProto1 = GraphRunner.fromJson(json);
|
|
||||||
|
|
||||||
assertEquals(graphRunner.getSessionOptionsConfigProto(),configProto1);
|
|
||||||
assertEquals(2,graphRunner.getInputOrder().size());
|
|
||||||
assertEquals(1,graphRunner.getOutputOrder().size());
|
|
||||||
|
|
||||||
INDArray input1 = Nd4j.linspace(1,4,4).reshape(4).castTo(DataType.FLOAT);
|
|
||||||
INDArray input2 = Nd4j.linspace(1,4,4).reshape(4).castTo(DataType.FLOAT);
|
|
||||||
|
|
||||||
Map<String,INDArray> inputs = new LinkedHashMap<>();
|
|
||||||
inputs.put("input_0",input1);
|
|
||||||
inputs.put("input_1",input2);
|
|
||||||
|
|
||||||
for(int i = 0; i < 2; i++) {
|
|
||||||
Map<String,INDArray> outputs = graphRunner.run(inputs);
|
|
||||||
|
|
||||||
INDArray assertion = input1.add(input2);
|
|
||||||
assertEquals(assertion,outputs.get("output"));
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,2 +1,441 @@
|
||||||
Identity,in_0/read
|
Transpose,transpose
|
||||||
MaxPoolWithArgmax,MaxPoolWithArgmax
|
Identity,conv2d/kernel/read
|
||||||
|
Identity,batch_normalization/gamma/read
|
||||||
|
Identity,batch_normalization/beta/read
|
||||||
|
Identity,batch_normalization/moving_mean/read
|
||||||
|
Identity,batch_normalization/moving_variance/read
|
||||||
|
Identity,conv2d_1/kernel/read
|
||||||
|
Identity,conv2d_2/kernel/read
|
||||||
|
Identity,batch_normalization_1/gamma/read
|
||||||
|
Identity,batch_normalization_1/beta/read
|
||||||
|
Identity,batch_normalization_1/moving_mean/read
|
||||||
|
Identity,batch_normalization_1/moving_variance/read
|
||||||
|
Identity,conv2d_3/kernel/read
|
||||||
|
Identity,batch_normalization_2/gamma/read
|
||||||
|
Identity,batch_normalization_2/beta/read
|
||||||
|
Identity,batch_normalization_2/moving_mean/read
|
||||||
|
Identity,batch_normalization_2/moving_variance/read
|
||||||
|
Identity,conv2d_4/kernel/read
|
||||||
|
Identity,batch_normalization_3/gamma/read
|
||||||
|
Identity,batch_normalization_3/beta/read
|
||||||
|
Identity,batch_normalization_3/moving_mean/read
|
||||||
|
Identity,batch_normalization_3/moving_variance/read
|
||||||
|
Identity,conv2d_5/kernel/read
|
||||||
|
Identity,batch_normalization_4/gamma/read
|
||||||
|
Identity,batch_normalization_4/beta/read
|
||||||
|
Identity,batch_normalization_4/moving_mean/read
|
||||||
|
Identity,batch_normalization_4/moving_variance/read
|
||||||
|
Identity,conv2d_6/kernel/read
|
||||||
|
Identity,batch_normalization_5/gamma/read
|
||||||
|
Identity,batch_normalization_5/beta/read
|
||||||
|
Identity,batch_normalization_5/moving_mean/read
|
||||||
|
Identity,batch_normalization_5/moving_variance/read
|
||||||
|
Identity,conv2d_7/kernel/read
|
||||||
|
Identity,batch_normalization_6/gamma/read
|
||||||
|
Identity,batch_normalization_6/beta/read
|
||||||
|
Identity,batch_normalization_6/moving_mean/read
|
||||||
|
Identity,batch_normalization_6/moving_variance/read
|
||||||
|
Identity,conv2d_8/kernel/read
|
||||||
|
Identity,batch_normalization_7/gamma/read
|
||||||
|
Identity,batch_normalization_7/beta/read
|
||||||
|
Identity,batch_normalization_7/moving_mean/read
|
||||||
|
Identity,batch_normalization_7/moving_variance/read
|
||||||
|
Identity,conv2d_9/kernel/read
|
||||||
|
Identity,batch_normalization_8/gamma/read
|
||||||
|
Identity,batch_normalization_8/beta/read
|
||||||
|
Identity,batch_normalization_8/moving_mean/read
|
||||||
|
Identity,batch_normalization_8/moving_variance/read
|
||||||
|
Identity,conv2d_10/kernel/read
|
||||||
|
Identity,batch_normalization_9/gamma/read
|
||||||
|
Identity,batch_normalization_9/beta/read
|
||||||
|
Identity,batch_normalization_9/moving_mean/read
|
||||||
|
Identity,batch_normalization_9/moving_variance/read
|
||||||
|
Identity,conv2d_11/kernel/read
|
||||||
|
Identity,conv2d_12/kernel/read
|
||||||
|
Identity,batch_normalization_10/gamma/read
|
||||||
|
Identity,batch_normalization_10/beta/read
|
||||||
|
Identity,batch_normalization_10/moving_mean/read
|
||||||
|
Identity,batch_normalization_10/moving_variance/read
|
||||||
|
Identity,conv2d_13/kernel/read
|
||||||
|
Identity,batch_normalization_11/gamma/read
|
||||||
|
Identity,batch_normalization_11/beta/read
|
||||||
|
Identity,batch_normalization_11/moving_mean/read
|
||||||
|
Identity,batch_normalization_11/moving_variance/read
|
||||||
|
Identity,conv2d_14/kernel/read
|
||||||
|
Identity,batch_normalization_12/gamma/read
|
||||||
|
Identity,batch_normalization_12/beta/read
|
||||||
|
Identity,batch_normalization_12/moving_mean/read
|
||||||
|
Identity,batch_normalization_12/moving_variance/read
|
||||||
|
Identity,conv2d_15/kernel/read
|
||||||
|
Identity,batch_normalization_13/gamma/read
|
||||||
|
Identity,batch_normalization_13/beta/read
|
||||||
|
Identity,batch_normalization_13/moving_mean/read
|
||||||
|
Identity,batch_normalization_13/moving_variance/read
|
||||||
|
Identity,conv2d_16/kernel/read
|
||||||
|
Identity,batch_normalization_14/gamma/read
|
||||||
|
Identity,batch_normalization_14/beta/read
|
||||||
|
Identity,batch_normalization_14/moving_mean/read
|
||||||
|
Identity,batch_normalization_14/moving_variance/read
|
||||||
|
Identity,conv2d_17/kernel/read
|
||||||
|
Identity,batch_normalization_15/gamma/read
|
||||||
|
Identity,batch_normalization_15/beta/read
|
||||||
|
Identity,batch_normalization_15/moving_mean/read
|
||||||
|
Identity,batch_normalization_15/moving_variance/read
|
||||||
|
Identity,conv2d_18/kernel/read
|
||||||
|
Identity,batch_normalization_16/gamma/read
|
||||||
|
Identity,batch_normalization_16/beta/read
|
||||||
|
Identity,batch_normalization_16/moving_mean/read
|
||||||
|
Identity,batch_normalization_16/moving_variance/read
|
||||||
|
Identity,conv2d_19/kernel/read
|
||||||
|
Identity,batch_normalization_17/gamma/read
|
||||||
|
Identity,batch_normalization_17/beta/read
|
||||||
|
Identity,batch_normalization_17/moving_mean/read
|
||||||
|
Identity,batch_normalization_17/moving_variance/read
|
||||||
|
Identity,conv2d_20/kernel/read
|
||||||
|
Identity,batch_normalization_18/gamma/read
|
||||||
|
Identity,batch_normalization_18/beta/read
|
||||||
|
Identity,batch_normalization_18/moving_mean/read
|
||||||
|
Identity,batch_normalization_18/moving_variance/read
|
||||||
|
Identity,conv2d_21/kernel/read
|
||||||
|
Identity,batch_normalization_19/gamma/read
|
||||||
|
Identity,batch_normalization_19/beta/read
|
||||||
|
Identity,batch_normalization_19/moving_mean/read
|
||||||
|
Identity,batch_normalization_19/moving_variance/read
|
||||||
|
Identity,conv2d_22/kernel/read
|
||||||
|
Identity,batch_normalization_20/gamma/read
|
||||||
|
Identity,batch_normalization_20/beta/read
|
||||||
|
Identity,batch_normalization_20/moving_mean/read
|
||||||
|
Identity,batch_normalization_20/moving_variance/read
|
||||||
|
Identity,conv2d_23/kernel/read
|
||||||
|
Identity,batch_normalization_21/gamma/read
|
||||||
|
Identity,batch_normalization_21/beta/read
|
||||||
|
Identity,batch_normalization_21/moving_mean/read
|
||||||
|
Identity,batch_normalization_21/moving_variance/read
|
||||||
|
Identity,conv2d_24/kernel/read
|
||||||
|
Identity,conv2d_25/kernel/read
|
||||||
|
Identity,batch_normalization_22/gamma/read
|
||||||
|
Identity,batch_normalization_22/beta/read
|
||||||
|
Identity,batch_normalization_22/moving_mean/read
|
||||||
|
Identity,batch_normalization_22/moving_variance/read
|
||||||
|
Identity,conv2d_26/kernel/read
|
||||||
|
Identity,batch_normalization_23/gamma/read
|
||||||
|
Identity,batch_normalization_23/beta/read
|
||||||
|
Identity,batch_normalization_23/moving_mean/read
|
||||||
|
Identity,batch_normalization_23/moving_variance/read
|
||||||
|
Identity,conv2d_27/kernel/read
|
||||||
|
Identity,batch_normalization_24/gamma/read
|
||||||
|
Identity,batch_normalization_24/beta/read
|
||||||
|
Identity,batch_normalization_24/moving_mean/read
|
||||||
|
Identity,batch_normalization_24/moving_variance/read
|
||||||
|
Identity,conv2d_28/kernel/read
|
||||||
|
Identity,batch_normalization_25/gamma/read
|
||||||
|
Identity,batch_normalization_25/beta/read
|
||||||
|
Identity,batch_normalization_25/moving_mean/read
|
||||||
|
Identity,batch_normalization_25/moving_variance/read
|
||||||
|
Identity,conv2d_29/kernel/read
|
||||||
|
Identity,batch_normalization_26/gamma/read
|
||||||
|
Identity,batch_normalization_26/beta/read
|
||||||
|
Identity,batch_normalization_26/moving_mean/read
|
||||||
|
Identity,batch_normalization_26/moving_variance/read
|
||||||
|
Identity,conv2d_30/kernel/read
|
||||||
|
Identity,batch_normalization_27/gamma/read
|
||||||
|
Identity,batch_normalization_27/beta/read
|
||||||
|
Identity,batch_normalization_27/moving_mean/read
|
||||||
|
Identity,batch_normalization_27/moving_variance/read
|
||||||
|
Identity,conv2d_31/kernel/read
|
||||||
|
Identity,batch_normalization_28/gamma/read
|
||||||
|
Identity,batch_normalization_28/beta/read
|
||||||
|
Identity,batch_normalization_28/moving_mean/read
|
||||||
|
Identity,batch_normalization_28/moving_variance/read
|
||||||
|
Identity,conv2d_32/kernel/read
|
||||||
|
Identity,batch_normalization_29/gamma/read
|
||||||
|
Identity,batch_normalization_29/beta/read
|
||||||
|
Identity,batch_normalization_29/moving_mean/read
|
||||||
|
Identity,batch_normalization_29/moving_variance/read
|
||||||
|
Identity,conv2d_33/kernel/read
|
||||||
|
Identity,batch_normalization_30/gamma/read
|
||||||
|
Identity,batch_normalization_30/beta/read
|
||||||
|
Identity,batch_normalization_30/moving_mean/read
|
||||||
|
Identity,batch_normalization_30/moving_variance/read
|
||||||
|
Identity,conv2d_34/kernel/read
|
||||||
|
Identity,batch_normalization_31/gamma/read
|
||||||
|
Identity,batch_normalization_31/beta/read
|
||||||
|
Identity,batch_normalization_31/moving_mean/read
|
||||||
|
Identity,batch_normalization_31/moving_variance/read
|
||||||
|
Identity,conv2d_35/kernel/read
|
||||||
|
Identity,batch_normalization_32/gamma/read
|
||||||
|
Identity,batch_normalization_32/beta/read
|
||||||
|
Identity,batch_normalization_32/moving_mean/read
|
||||||
|
Identity,batch_normalization_32/moving_variance/read
|
||||||
|
Identity,conv2d_36/kernel/read
|
||||||
|
Identity,batch_normalization_33/gamma/read
|
||||||
|
Identity,batch_normalization_33/beta/read
|
||||||
|
Identity,batch_normalization_33/moving_mean/read
|
||||||
|
Identity,batch_normalization_33/moving_variance/read
|
||||||
|
Identity,conv2d_37/kernel/read
|
||||||
|
Identity,batch_normalization_34/gamma/read
|
||||||
|
Identity,batch_normalization_34/beta/read
|
||||||
|
Identity,batch_normalization_34/moving_mean/read
|
||||||
|
Identity,batch_normalization_34/moving_variance/read
|
||||||
|
Identity,conv2d_38/kernel/read
|
||||||
|
Identity,batch_normalization_35/gamma/read
|
||||||
|
Identity,batch_normalization_35/beta/read
|
||||||
|
Identity,batch_normalization_35/moving_mean/read
|
||||||
|
Identity,batch_normalization_35/moving_variance/read
|
||||||
|
Identity,conv2d_39/kernel/read
|
||||||
|
Identity,batch_normalization_36/gamma/read
|
||||||
|
Identity,batch_normalization_36/beta/read
|
||||||
|
Identity,batch_normalization_36/moving_mean/read
|
||||||
|
Identity,batch_normalization_36/moving_variance/read
|
||||||
|
Identity,conv2d_40/kernel/read
|
||||||
|
Identity,batch_normalization_37/gamma/read
|
||||||
|
Identity,batch_normalization_37/beta/read
|
||||||
|
Identity,batch_normalization_37/moving_mean/read
|
||||||
|
Identity,batch_normalization_37/moving_variance/read
|
||||||
|
Identity,conv2d_41/kernel/read
|
||||||
|
Identity,batch_normalization_38/gamma/read
|
||||||
|
Identity,batch_normalization_38/beta/read
|
||||||
|
Identity,batch_normalization_38/moving_mean/read
|
||||||
|
Identity,batch_normalization_38/moving_variance/read
|
||||||
|
Identity,conv2d_42/kernel/read
|
||||||
|
Identity,batch_normalization_39/gamma/read
|
||||||
|
Identity,batch_normalization_39/beta/read
|
||||||
|
Identity,batch_normalization_39/moving_mean/read
|
||||||
|
Identity,batch_normalization_39/moving_variance/read
|
||||||
|
Identity,conv2d_43/kernel/read
|
||||||
|
Identity,conv2d_44/kernel/read
|
||||||
|
Identity,batch_normalization_40/gamma/read
|
||||||
|
Identity,batch_normalization_40/beta/read
|
||||||
|
Identity,batch_normalization_40/moving_mean/read
|
||||||
|
Identity,batch_normalization_40/moving_variance/read
|
||||||
|
Identity,conv2d_45/kernel/read
|
||||||
|
Identity,batch_normalization_41/gamma/read
|
||||||
|
Identity,batch_normalization_41/beta/read
|
||||||
|
Identity,batch_normalization_41/moving_mean/read
|
||||||
|
Identity,batch_normalization_41/moving_variance/read
|
||||||
|
Identity,conv2d_46/kernel/read
|
||||||
|
Identity,batch_normalization_42/gamma/read
|
||||||
|
Identity,batch_normalization_42/beta/read
|
||||||
|
Identity,batch_normalization_42/moving_mean/read
|
||||||
|
Identity,batch_normalization_42/moving_variance/read
|
||||||
|
Identity,conv2d_47/kernel/read
|
||||||
|
Identity,batch_normalization_43/gamma/read
|
||||||
|
Identity,batch_normalization_43/beta/read
|
||||||
|
Identity,batch_normalization_43/moving_mean/read
|
||||||
|
Identity,batch_normalization_43/moving_variance/read
|
||||||
|
Identity,conv2d_48/kernel/read
|
||||||
|
Identity,batch_normalization_44/gamma/read
|
||||||
|
Identity,batch_normalization_44/beta/read
|
||||||
|
Identity,batch_normalization_44/moving_mean/read
|
||||||
|
Identity,batch_normalization_44/moving_variance/read
|
||||||
|
Identity,conv2d_49/kernel/read
|
||||||
|
Identity,batch_normalization_45/gamma/read
|
||||||
|
Identity,batch_normalization_45/beta/read
|
||||||
|
Identity,batch_normalization_45/moving_mean/read
|
||||||
|
Identity,batch_normalization_45/moving_variance/read
|
||||||
|
Identity,conv2d_50/kernel/read
|
||||||
|
Identity,batch_normalization_46/gamma/read
|
||||||
|
Identity,batch_normalization_46/beta/read
|
||||||
|
Identity,batch_normalization_46/moving_mean/read
|
||||||
|
Identity,batch_normalization_46/moving_variance/read
|
||||||
|
Identity,conv2d_51/kernel/read
|
||||||
|
Identity,batch_normalization_47/gamma/read
|
||||||
|
Identity,batch_normalization_47/beta/read
|
||||||
|
Identity,batch_normalization_47/moving_mean/read
|
||||||
|
Identity,batch_normalization_47/moving_variance/read
|
||||||
|
Identity,conv2d_52/kernel/read
|
||||||
|
Identity,batch_normalization_48/gamma/read
|
||||||
|
Identity,batch_normalization_48/beta/read
|
||||||
|
Identity,batch_normalization_48/moving_mean/read
|
||||||
|
Identity,batch_normalization_48/moving_variance/read
|
||||||
|
Identity,dense/kernel/read
|
||||||
|
Identity,dense/bias/read
|
||||||
|
Pad,Pad
|
||||||
|
Conv2D,conv2d/Conv2D
|
||||||
|
Identity,initial_conv
|
||||||
|
MaxPool,max_pooling2d/MaxPool
|
||||||
|
Identity,initial_max_pool
|
||||||
|
FusedBatchNorm,batch_normalization/FusedBatchNorm
|
||||||
|
Relu,Relu
|
||||||
|
Conv2D,conv2d_1/Conv2D
|
||||||
|
Conv2D,conv2d_2/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_1/FusedBatchNorm
|
||||||
|
Relu,Relu_1
|
||||||
|
Conv2D,conv2d_3/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_2/FusedBatchNorm
|
||||||
|
Relu,Relu_2
|
||||||
|
Conv2D,conv2d_4/Conv2D
|
||||||
|
Add,add
|
||||||
|
FusedBatchNorm,batch_normalization_3/FusedBatchNorm
|
||||||
|
Relu,Relu_3
|
||||||
|
Conv2D,conv2d_5/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_4/FusedBatchNorm
|
||||||
|
Relu,Relu_4
|
||||||
|
Conv2D,conv2d_6/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_5/FusedBatchNorm
|
||||||
|
Relu,Relu_5
|
||||||
|
Conv2D,conv2d_7/Conv2D
|
||||||
|
Add,add_1
|
||||||
|
FusedBatchNorm,batch_normalization_6/FusedBatchNorm
|
||||||
|
Relu,Relu_6
|
||||||
|
Conv2D,conv2d_8/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_7/FusedBatchNorm
|
||||||
|
Relu,Relu_7
|
||||||
|
Conv2D,conv2d_9/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_8/FusedBatchNorm
|
||||||
|
Relu,Relu_8
|
||||||
|
Conv2D,conv2d_10/Conv2D
|
||||||
|
Add,add_2
|
||||||
|
Identity,block_layer1
|
||||||
|
FusedBatchNorm,batch_normalization_9/FusedBatchNorm
|
||||||
|
Relu,Relu_9
|
||||||
|
Pad,Pad_1
|
||||||
|
Conv2D,conv2d_12/Conv2D
|
||||||
|
Conv2D,conv2d_11/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_10/FusedBatchNorm
|
||||||
|
Relu,Relu_10
|
||||||
|
Pad,Pad_2
|
||||||
|
Conv2D,conv2d_13/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_11/FusedBatchNorm
|
||||||
|
Relu,Relu_11
|
||||||
|
Conv2D,conv2d_14/Conv2D
|
||||||
|
Add,add_3
|
||||||
|
FusedBatchNorm,batch_normalization_12/FusedBatchNorm
|
||||||
|
Relu,Relu_12
|
||||||
|
Conv2D,conv2d_15/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_13/FusedBatchNorm
|
||||||
|
Relu,Relu_13
|
||||||
|
Conv2D,conv2d_16/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_14/FusedBatchNorm
|
||||||
|
Relu,Relu_14
|
||||||
|
Conv2D,conv2d_17/Conv2D
|
||||||
|
Add,add_4
|
||||||
|
FusedBatchNorm,batch_normalization_15/FusedBatchNorm
|
||||||
|
Relu,Relu_15
|
||||||
|
Conv2D,conv2d_18/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_16/FusedBatchNorm
|
||||||
|
Relu,Relu_16
|
||||||
|
Conv2D,conv2d_19/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_17/FusedBatchNorm
|
||||||
|
Relu,Relu_17
|
||||||
|
Conv2D,conv2d_20/Conv2D
|
||||||
|
Add,add_5
|
||||||
|
FusedBatchNorm,batch_normalization_18/FusedBatchNorm
|
||||||
|
Relu,Relu_18
|
||||||
|
Conv2D,conv2d_21/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_19/FusedBatchNorm
|
||||||
|
Relu,Relu_19
|
||||||
|
Conv2D,conv2d_22/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_20/FusedBatchNorm
|
||||||
|
Relu,Relu_20
|
||||||
|
Conv2D,conv2d_23/Conv2D
|
||||||
|
Add,add_6
|
||||||
|
Identity,block_layer2
|
||||||
|
FusedBatchNorm,batch_normalization_21/FusedBatchNorm
|
||||||
|
Relu,Relu_21
|
||||||
|
Pad,Pad_3
|
||||||
|
Conv2D,conv2d_25/Conv2D
|
||||||
|
Conv2D,conv2d_24/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_22/FusedBatchNorm
|
||||||
|
Relu,Relu_22
|
||||||
|
Pad,Pad_4
|
||||||
|
Conv2D,conv2d_26/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_23/FusedBatchNorm
|
||||||
|
Relu,Relu_23
|
||||||
|
Conv2D,conv2d_27/Conv2D
|
||||||
|
Add,add_7
|
||||||
|
FusedBatchNorm,batch_normalization_24/FusedBatchNorm
|
||||||
|
Relu,Relu_24
|
||||||
|
Conv2D,conv2d_28/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_25/FusedBatchNorm
|
||||||
|
Relu,Relu_25
|
||||||
|
Conv2D,conv2d_29/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_26/FusedBatchNorm
|
||||||
|
Relu,Relu_26
|
||||||
|
Conv2D,conv2d_30/Conv2D
|
||||||
|
Add,add_8
|
||||||
|
FusedBatchNorm,batch_normalization_27/FusedBatchNorm
|
||||||
|
Relu,Relu_27
|
||||||
|
Conv2D,conv2d_31/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_28/FusedBatchNorm
|
||||||
|
Relu,Relu_28
|
||||||
|
Conv2D,conv2d_32/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_29/FusedBatchNorm
|
||||||
|
Relu,Relu_29
|
||||||
|
Conv2D,conv2d_33/Conv2D
|
||||||
|
Add,add_9
|
||||||
|
FusedBatchNorm,batch_normalization_30/FusedBatchNorm
|
||||||
|
Relu,Relu_30
|
||||||
|
Conv2D,conv2d_34/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_31/FusedBatchNorm
|
||||||
|
Relu,Relu_31
|
||||||
|
Conv2D,conv2d_35/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_32/FusedBatchNorm
|
||||||
|
Relu,Relu_32
|
||||||
|
Conv2D,conv2d_36/Conv2D
|
||||||
|
Add,add_10
|
||||||
|
FusedBatchNorm,batch_normalization_33/FusedBatchNorm
|
||||||
|
Relu,Relu_33
|
||||||
|
Conv2D,conv2d_37/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_34/FusedBatchNorm
|
||||||
|
Relu,Relu_34
|
||||||
|
Conv2D,conv2d_38/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_35/FusedBatchNorm
|
||||||
|
Relu,Relu_35
|
||||||
|
Conv2D,conv2d_39/Conv2D
|
||||||
|
Add,add_11
|
||||||
|
FusedBatchNorm,batch_normalization_36/FusedBatchNorm
|
||||||
|
Relu,Relu_36
|
||||||
|
Conv2D,conv2d_40/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_37/FusedBatchNorm
|
||||||
|
Relu,Relu_37
|
||||||
|
Conv2D,conv2d_41/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_38/FusedBatchNorm
|
||||||
|
Relu,Relu_38
|
||||||
|
Conv2D,conv2d_42/Conv2D
|
||||||
|
Add,add_12
|
||||||
|
Identity,block_layer3
|
||||||
|
FusedBatchNorm,batch_normalization_39/FusedBatchNorm
|
||||||
|
Relu,Relu_39
|
||||||
|
Pad,Pad_5
|
||||||
|
Conv2D,conv2d_44/Conv2D
|
||||||
|
Conv2D,conv2d_43/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_40/FusedBatchNorm
|
||||||
|
Relu,Relu_40
|
||||||
|
Pad,Pad_6
|
||||||
|
Conv2D,conv2d_45/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_41/FusedBatchNorm
|
||||||
|
Relu,Relu_41
|
||||||
|
Conv2D,conv2d_46/Conv2D
|
||||||
|
Add,add_13
|
||||||
|
FusedBatchNorm,batch_normalization_42/FusedBatchNorm
|
||||||
|
Relu,Relu_42
|
||||||
|
Conv2D,conv2d_47/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_43/FusedBatchNorm
|
||||||
|
Relu,Relu_43
|
||||||
|
Conv2D,conv2d_48/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_44/FusedBatchNorm
|
||||||
|
Relu,Relu_44
|
||||||
|
Conv2D,conv2d_49/Conv2D
|
||||||
|
Add,add_14
|
||||||
|
FusedBatchNorm,batch_normalization_45/FusedBatchNorm
|
||||||
|
Relu,Relu_45
|
||||||
|
Conv2D,conv2d_50/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_46/FusedBatchNorm
|
||||||
|
Relu,Relu_46
|
||||||
|
Conv2D,conv2d_51/Conv2D
|
||||||
|
FusedBatchNorm,batch_normalization_47/FusedBatchNorm
|
||||||
|
Relu,Relu_47
|
||||||
|
Conv2D,conv2d_52/Conv2D
|
||||||
|
Add,add_15
|
||||||
|
Identity,block_layer4
|
||||||
|
FusedBatchNorm,batch_normalization_48/FusedBatchNorm
|
||||||
|
Relu,Relu_48
|
||||||
|
Mean,Mean
|
||||||
|
Identity,final_reduce_mean
|
||||||
|
Reshape,Reshape
|
||||||
|
MatMul,dense/MatMul
|
||||||
|
BiasAdd,dense/BiasAdd
|
||||||
|
Identity,final_dense
|
||||||
|
ArgMax,ArgMax
|
||||||
|
Softmax,softmax_tensor
|
||||||
|
|
|
@ -471,7 +471,7 @@
|
||||||
Maximum heap size was set to 6g, as a minimum required value for tests run.
|
Maximum heap size was set to 6g, as a minimum required value for tests run.
|
||||||
Depending on a build machine, default value is not always enough.
|
Depending on a build machine, default value is not always enough.
|
||||||
-->
|
-->
|
||||||
<argLine> -Dfile.encoding=UTF-8 -Dorg.bytedeco.javacpp.logger.debug=true -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
<argLine> -Dfile.encoding=UTF-8 -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
||||||
</configuration>
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
|
|
|
@ -343,7 +343,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testIm2Col() {
|
public void testIm2Col() {
|
||||||
//OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/deeplearning4j/deeplearning4j/issues/6873
|
//OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/eclipse/deeplearning4j/issues/6873
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}};
|
int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}};
|
||||||
|
|
|
@ -480,7 +480,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
dLdInExpected_1.putColumn(i, prod_1);
|
dLdInExpected_1.putColumn(i, prod_1);
|
||||||
}
|
}
|
||||||
dLdInExpected_1.divi(preReduceInput);
|
dLdInExpected_1.divi(preReduceInput);
|
||||||
dLdInExpected_1.muliColumnVector(dLdOut_1.reshape(3, 1)); //Reshape is a hack around https://github.com/deeplearning4j/deeplearning4j/issues/5530
|
dLdInExpected_1.muliColumnVector(dLdOut_1.reshape(3, 1)); //Reshape is a hack around https://github.com/eclipse/deeplearning4j/issues/5530
|
||||||
//System.out.println(dLdInExpected_1);
|
//System.out.println(dLdInExpected_1);
|
||||||
/*
|
/*
|
||||||
[[ 24.0000, 12.0000, 8.0000, 6.0000],
|
[[ 24.0000, 12.0000, 8.0000, 6.0000],
|
||||||
|
|
|
@ -2004,7 +2004,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
@Test
|
@Test
|
||||||
public void testCastEmpty(){
|
public void testCastEmpty(){
|
||||||
INDArray emptyLong = Nd4j.empty(DataType.LONG);
|
INDArray emptyLong = Nd4j.empty(DataType.LONG);
|
||||||
int dtype = 9; //INT = 9 - https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/array/DataType.h
|
int dtype = 9; //INT = 9 - https://github.com/eclipse/deeplearning4j/blob/master/libnd4j/include/array/DataType.h
|
||||||
DynamicCustomOp op = DynamicCustomOp.builder("cast")
|
DynamicCustomOp op = DynamicCustomOp.builder("cast")
|
||||||
.addInputs(emptyLong)
|
.addInputs(emptyLong)
|
||||||
.addIntegerArguments(dtype)
|
.addIntegerArguments(dtype)
|
||||||
|
|
|
@ -326,7 +326,7 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBatchToSpace() {
|
public void testBatchToSpace() {
|
||||||
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863
|
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863
|
||||||
Nd4j.getRandom().setSeed(1337);
|
Nd4j.getRandom().setSeed(1337);
|
||||||
|
|
||||||
int miniBatch = 4;
|
int miniBatch = 4;
|
||||||
|
@ -363,7 +363,7 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSpaceToBatch() {
|
public void testSpaceToBatch() {
|
||||||
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863
|
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(7331);
|
Nd4j.getRandom().setSeed(7331);
|
||||||
|
|
||||||
|
@ -1281,7 +1281,7 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
out = sd.math().isInfinite(in);
|
out = sd.math().isInfinite(in);
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
//TODO: IsMax supports both bool and float out: https://github.com/deeplearning4j/deeplearning4j/issues/6872
|
//TODO: IsMax supports both bool and float out: https://github.com/eclipse/deeplearning4j/issues/6872
|
||||||
inArr = Nd4j.create(new double[]{-3, 5, 0, 2});
|
inArr = Nd4j.create(new double[]{-3, 5, 0, 2});
|
||||||
exp = Nd4j.create(new boolean[]{false, true, false, false});
|
exp = Nd4j.create(new boolean[]{false, true, false, false});
|
||||||
out = sd.math().isMax(in);
|
out = sd.math().isMax(in);
|
||||||
|
|
|
@ -61,10 +61,10 @@ public class ExecutionTests extends BaseNd4jTest {
|
||||||
if(TFGraphTestZooModels.isPPC()){
|
if(TFGraphTestZooModels.isPPC()){
|
||||||
/*
|
/*
|
||||||
Ugly hack to temporarily disable tests on PPC only on CI
|
Ugly hack to temporarily disable tests on PPC only on CI
|
||||||
Issue logged here: https://github.com/deeplearning4j/deeplearning4j/issues/7657
|
Issue logged here: https://github.com/eclipse/deeplearning4j/issues/7657
|
||||||
These will be re-enabled for PPC once fixed - in the mean time, remaining tests will be used to detect and prevent regressions
|
These will be re-enabled for PPC once fixed - in the mean time, remaining tests will be used to detect and prevent regressions
|
||||||
*/
|
*/
|
||||||
log.warn("TEMPORARILY SKIPPING TEST ON PPC ARCHITECTURE DUE TO KNOWN JVM CRASH ISSUES - SEE https://github.com/deeplearning4j/deeplearning4j/issues/7657");
|
log.warn("TEMPORARILY SKIPPING TEST ON PPC ARCHITECTURE DUE TO KNOWN JVM CRASH ISSUES - SEE https://github.com/eclipse/deeplearning4j/issues/7657");
|
||||||
OpValidationSuite.ignoreFailing();
|
OpValidationSuite.ignoreFailing();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue