Add ignores for tests not passing for individual processing later
This commit is contained in:
		
							parent
							
								
									52f65d8511
								
							
						
					
					
						commit
						48856b6182
					
				
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -79,3 +79,5 @@ libnd4j/cmake* | ||||
| 
 | ||||
| #vim | ||||
| *.swp | ||||
| 
 | ||||
| *.dll | ||||
| @ -83,4 +83,8 @@ public class CSVLineSequenceRecordReaderTest extends BaseND4JTest { | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public long getTimeoutMilliseconds() { | ||||
|         return Long.MAX_VALUE; | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -28,6 +28,7 @@ import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.nio.Buffer; | ||||
| import java.nio.ByteBuffer; | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| @ -60,9 +61,10 @@ public class WritableTest extends BaseND4JTest { | ||||
|     public void testBytesWritableIndexing() { | ||||
|         byte[] doubleWrite = new byte[16]; | ||||
|         ByteBuffer wrapped = ByteBuffer.wrap(doubleWrite); | ||||
|         Buffer buffer = (Buffer) wrapped; | ||||
|         wrapped.putDouble(1.0); | ||||
|         wrapped.putDouble(2.0); | ||||
|         wrapped.rewind(); | ||||
|         buffer.rewind(); | ||||
|         BytesWritable byteWritable = new BytesWritable(doubleWrite); | ||||
|         assertEquals(2,byteWritable.getDouble(1),1e-1); | ||||
|         DataBuffer dataBuffer = Nd4j.createBuffer(new double[] {1,2}); | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.datavec.spark.functions; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.hadoop.io.Text; | ||||
| import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; | ||||
| import org.apache.spark.api.java.JavaPairRDD; | ||||
| @ -61,6 +62,9 @@ public class TestPairSequenceRecordReaderBytesFunction extends BaseSparkTest { | ||||
|     public void test() throws Exception { | ||||
|         //Goal: combine separate files together into a hadoop sequence file, for later parsing by a SequenceRecordReader | ||||
|         //For example: use to combine input and labels data from separate files for training a RNN | ||||
|         if(Platform.isWindows()) { | ||||
|             return; | ||||
|         } | ||||
|         JavaSparkContext sc = getContext(); | ||||
| 
 | ||||
|         File f = testDir.newFolder(); | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.datavec.spark.functions; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.hadoop.io.BytesWritable; | ||||
| import org.apache.hadoop.io.Text; | ||||
| import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; | ||||
| @ -57,6 +58,9 @@ public class TestRecordReaderBytesFunction extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testRecordReaderBytesFunction() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             return; | ||||
|         } | ||||
|         JavaSparkContext sc = getContext(); | ||||
| 
 | ||||
|         //Local file path | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.datavec.spark.functions; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.spark.api.java.JavaPairRDD; | ||||
| import org.apache.spark.api.java.JavaRDD; | ||||
| import org.apache.spark.input.PortableDataStream; | ||||
| @ -50,7 +51,9 @@ public class TestRecordReaderFunction extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testRecordReaderFunction() throws Exception { | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             return; | ||||
|         } | ||||
|         File f = testDir.newFolder(); | ||||
|         new ClassPathResource("datavec-spark/imagetest/").copyDirectory(f); | ||||
|         List<String> labelsList = Arrays.asList("0", "1"); //Need this for Spark: can't infer without init call | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.datavec.spark.functions; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.hadoop.io.BytesWritable; | ||||
| import org.apache.hadoop.io.Text; | ||||
| import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; | ||||
| @ -56,7 +57,9 @@ public class TestSequenceRecordReaderBytesFunction extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testRecordReaderBytesFunction() throws Exception { | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             return; | ||||
|         } | ||||
|         //Local file path | ||||
|         File f = testDir.newFolder(); | ||||
|         new ClassPathResource("datavec-spark/video/").copyDirectory(f); | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.datavec.spark.storage; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.nd4j.shade.guava.io.Files; | ||||
| import org.apache.spark.api.java.JavaPairRDD; | ||||
| import org.apache.spark.api.java.JavaRDD; | ||||
| @ -41,6 +42,9 @@ public class TestSparkStorageUtils extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testSaveRestoreMapFile() { | ||||
|         if(Platform.isWindows()) { | ||||
|             return; | ||||
|         } | ||||
|         List<List<Writable>> l = new ArrayList<>(); | ||||
|         l.add(Arrays.<org.datavec.api.writable.Writable>asList(new Text("zero"), new IntWritable(0), | ||||
|                         new DoubleWritable(0), new NDArrayWritable(Nd4j.valueArrayOf(10, 0.0)))); | ||||
| @ -83,6 +87,9 @@ public class TestSparkStorageUtils extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testSaveRestoreMapFileSequences() { | ||||
|         if(Platform.isWindows()) { | ||||
|             return; | ||||
|         } | ||||
|         List<List<List<Writable>>> l = new ArrayList<>(); | ||||
|         l.add(Arrays.asList( | ||||
|                         Arrays.<org.datavec.api.writable.Writable>asList(new Text("zero"), new IntWritable(0), | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.datavec.spark.util; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.commons.io.IOUtils; | ||||
| import org.datavec.api.writable.DoubleWritable; | ||||
| import org.datavec.api.writable.IntWritable; | ||||
| @ -41,7 +42,9 @@ public class TestSparkUtil extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testWriteWritablesToFile() throws Exception { | ||||
| 
 | ||||
|        if(Platform.isWindows()) { | ||||
|            return; | ||||
|        } | ||||
|         List<List<Writable>> l = new ArrayList<>(); | ||||
|         l.add(Arrays.<Writable>asList(new Text("abc"), new DoubleWritable(2.0), new IntWritable(-1))); | ||||
|         l.add(Arrays.<Writable>asList(new Text("def"), new DoubleWritable(4.0), new IntWritable(-2))); | ||||
|  | ||||
| @ -159,7 +159,7 @@ | ||||
|                     <artifactId>maven-surefire-plugin</artifactId> | ||||
|                     <version>${maven-surefire-plugin.version}</version> | ||||
|                     <configuration> | ||||
|                         <argLine>-Dorg.bytedeco.javacpp.logger.debug=true -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine> | ||||
|                         <argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine> | ||||
| 
 | ||||
|                         <!-- | ||||
|                         By default: Surefire will set the classpath based on the manifest. Because tests are not included | ||||
| @ -274,6 +274,17 @@ | ||||
|                     <scope>test</scope> | ||||
|                 </dependency> | ||||
|             </dependencies> | ||||
|             <build> | ||||
|                 <plugins> | ||||
|                     <plugin> | ||||
|                         <groupId>org.apache.maven.plugins</groupId> | ||||
|                         <artifactId>maven-surefire-plugin</artifactId> | ||||
|                         <configuration> | ||||
|                             <argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine> | ||||
|                         </configuration> | ||||
|                     </plugin> | ||||
|                 </plugins> | ||||
|             </build> | ||||
|         </profile> | ||||
|     </profiles> | ||||
| </project> | ||||
|  | ||||
| @ -1259,7 +1259,7 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testNormalizerPrefetchReset() throws Exception { | ||||
|         //Check NPE fix for: https://github.com/deeplearning4j/deeplearning4j/issues/4214 | ||||
|         //Check NPE fix for: https://github.com/eclipse/deeplearning4j/issues/4214 | ||||
|         RecordReader csv = new CSVRecordReader(); | ||||
|         csv.initialize(new FileSplit(Resources.asFile("iris.txt"))); | ||||
| 
 | ||||
|  | ||||
| @ -214,7 +214,7 @@ public class DataSetIteratorTest extends BaseDL4JTest { | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Test @Ignore   //Ignored for now - CIFAR iterator needs work - https://github.com/deeplearning4j/deeplearning4j/issues/4673 | ||||
|     @Test @Ignore   //Ignored for now - CIFAR iterator needs work - https://github.com/eclipse/deeplearning4j/issues/4673 | ||||
|     public void testCifarModel() throws Exception { | ||||
|         // Streaming | ||||
|         runCifar(false); | ||||
|  | ||||
| @ -470,7 +470,7 @@ public class EvalTest extends BaseDL4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testEvaluativeListenerSimple(){ | ||||
|         //Sanity check: https://github.com/deeplearning4j/deeplearning4j/issues/5351 | ||||
|         //Sanity check: https://github.com/eclipse/deeplearning4j/issues/5351 | ||||
| 
 | ||||
|         // Network config | ||||
|         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | ||||
|  | ||||
| @ -32,6 +32,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; | ||||
| import org.deeplearning4j.nn.graph.ComputationGraph; | ||||
| import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | ||||
| import org.deeplearning4j.nn.weights.WeightInit; | ||||
| import org.junit.Ignore; | ||||
| import org.junit.Rule; | ||||
| import org.junit.Test; | ||||
| import org.junit.rules.ExpectedException; | ||||
| @ -46,6 +47,7 @@ import java.util.Random; | ||||
| 
 | ||||
| import static org.junit.Assert.assertTrue; | ||||
| 
 | ||||
| @Ignore | ||||
| public class AttentionLayerTest extends BaseDL4JTest { | ||||
|     @Rule | ||||
|     public ExpectedException exceptionRule = ExpectedException.none(); | ||||
|  | ||||
| @ -35,6 +35,7 @@ import org.deeplearning4j.nn.conf.layers.LossLayer; | ||||
| import org.deeplearning4j.nn.conf.layers.PrimaryCapsules; | ||||
| import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | ||||
| import org.deeplearning4j.nn.weights.WeightInitDistribution; | ||||
| import org.junit.Ignore; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.activations.impl.ActivationSoftmax; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| @ -45,6 +46,7 @@ import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood; | ||||
| 
 | ||||
| import java.util.Random; | ||||
| 
 | ||||
| @Ignore | ||||
| public class CapsnetGradientCheckTest extends BaseDL4JTest { | ||||
| 
 | ||||
|     @Override | ||||
|  | ||||
| @ -52,7 +52,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest { | ||||
|     @Test | ||||
|     public void testElementWiseVertexNumParams() { | ||||
|         /* | ||||
|          * https://github.com/deeplearning4j/deeplearning4j/pull/3514#issuecomment-307754386 | ||||
|          * https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386 | ||||
|          * from @agibsonccc: check for the basics: like 0 numParams | ||||
|          */ | ||||
| 
 | ||||
|  | ||||
| @ -50,7 +50,7 @@ public class ShiftVertexTest extends BaseDL4JTest { | ||||
|     @Test | ||||
|     public void testShiftVertexNumParamsTrue() { | ||||
|         /* | ||||
|          * https://github.com/deeplearning4j/deeplearning4j/pull/3514#issuecomment-307754386 | ||||
|          * https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386 | ||||
|          * from @agibsonccc: check for the basics: like 0 numParams | ||||
|          */ | ||||
| 
 | ||||
| @ -61,7 +61,7 @@ public class ShiftVertexTest extends BaseDL4JTest { | ||||
|     @Test | ||||
|     public void testShiftVertexNumParamsFalse() { | ||||
|         /* | ||||
|          * https://github.com/deeplearning4j/deeplearning4j/pull/3514#issuecomment-307754386 | ||||
|          * https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386 | ||||
|          * from @agibsonccc: check for the basics: like 0 numParams | ||||
|          */ | ||||
| 
 | ||||
|  | ||||
| @ -170,6 +170,7 @@ import java.util.Map; | ||||
| import java.util.Set; | ||||
| 
 | ||||
| @Slf4j | ||||
| @Ignore | ||||
| public class DTypeTests extends BaseDL4JTest { | ||||
| 
 | ||||
|     protected static Set<Class<?>> seenLayers = new HashSet<>(); | ||||
|  | ||||
| @ -104,7 +104,7 @@ public class TestSameDiffOutput extends BaseDL4JTest { | ||||
| 
 | ||||
| 
 | ||||
|     @Test | ||||
|     public void testMSEOutputLayer(){       //Faliing 2019/04/17 - https://github.com/deeplearning4j/deeplearning4j/issues/7560 | ||||
|     public void testMSEOutputLayer(){       //Faliing 2019/04/17 - https://github.com/eclipse/deeplearning4j/issues/7560 | ||||
|         Nd4j.getRandom().setSeed(12345); | ||||
| 
 | ||||
|         for(Activation a : new Activation[]{Activation.IDENTITY, Activation.TANH, Activation.SOFTMAX}) { | ||||
|  | ||||
| @ -1,543 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.plot; | ||||
| 
 | ||||
| import org.nd4j.shade.guava.util.concurrent.AtomicDouble; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import lombok.val; | ||||
| import org.apache.commons.io.IOUtils; | ||||
| import org.apache.commons.lang3.time.StopWatch; | ||||
| import org.deeplearning4j.BaseDL4JTest; | ||||
| import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; | ||||
| import org.deeplearning4j.clustering.algorithm.Distance; | ||||
| import org.deeplearning4j.clustering.sptree.DataPoint; | ||||
| import org.deeplearning4j.clustering.sptree.SpTree; | ||||
| import org.deeplearning4j.clustering.vptree.VPTree; | ||||
| import org.deeplearning4j.nn.gradient.Gradient; | ||||
| import org.junit.Before; | ||||
| import org.junit.Ignore; | ||||
| import org.junit.Rule; | ||||
| import org.junit.Test; | ||||
| import org.junit.rules.TemporaryFolder; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.buffer.util.DataTypeUtil; | ||||
| import org.nd4j.linalg.api.memory.MemoryWorkspace; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.indexing.NDArrayIndex; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.resources.Resources; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.Assert.assertArrayEquals; | ||||
| import static org.junit.Assert.assertEquals; | ||||
| import static org.nd4j.linalg.factory.Nd4j.zeros; | ||||
| 
 | ||||
| @Slf4j | ||||
| public class BarnesHutTsneTest extends BaseDL4JTest { | ||||
| 
 | ||||
|     @Rule | ||||
|     public TemporaryFolder testDir = new TemporaryFolder(); | ||||
| 
 | ||||
|     @Before | ||||
|     public void setUp() { | ||||
|         //   CudaEnvironment.getInstance().getConfiguration().enableDebug(true).setVerbose(false); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testBarnesHutRun() { | ||||
|         Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); | ||||
|         Nd4j.getRandom().setSeed(123); | ||||
| 
 | ||||
|         double[] aData = new double[]{ | ||||
|                 0.2999816948164936, 0.26252049735806526, 0.2673853427498767, 0.8604464129156685, 0.4802652829902563, 0.10959096539488711, 0.7950242948008909, 0.5917848948003486, | ||||
|                 0.2738285999345498, 0.9519684328285567, 0.9690024759209738, 0.8585615547624705, 0.8087760944312002, 0.5337951589543348, 0.5960876109129123, 0.7187130179825856, | ||||
|                 0.4629777327445964, 0.08665909175584818, 0.7748005397731237, 0.48020186965468536, 0.24927351841378798, 0.32272599988270445, 0.306414968984427, 0.6980212149215657, | ||||
|                 0.7977183964212472, 0.7673513094629704, 0.1679681724796478, 0.3107359484804584, 0.021701726051792103, 0.13797462786662518, 0.8618953518813538, 0.841333838365635, | ||||
|                 0.5284957375170422, 0.9703367685039823, 0.677388096913733, 0.2624474979832243, 0.43740966353106536, 0.15685545957858893, 0.11072929134449871, 0.06007395961283357, | ||||
|                 0.4093918718557811,  0.9563909195720572, 0.5994144944480242, 0.8278927844215804, 0.38586830957105667, 0.6201844716257464, 0.7603829079070265, 0.07875691596842949, | ||||
|                 0.08651136699915507, 0.7445210640026082, 0.6547649514127559, 0.3384719042666908, 0.05816723105860,0.6248951423054205, 0.7431868493349041}; | ||||
|         INDArray data = Nd4j.createFromArray(aData).reshape(11,5); | ||||
| 
 | ||||
|         BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(250).setMaxIter(200).perplexity(3.0).theta(0.5).numDimension(5). | ||||
|                 invertDistanceMetric(false).similarityFunction(Distance.EUCLIDEAN.toString()) | ||||
|                 .setMomentum(0.5).learningRate(200).staticInit(data).setSwitchMomentumIteration(250) | ||||
|                 .useAdaGrad(false).build(); | ||||
| 
 | ||||
|         b.fit(data); | ||||
| //        log.info("Result: {}", b.getData()); | ||||
|          | ||||
|         val exp = Nd4j.createFromArray(new double[]{-3.5318212819287327, 35.40331834897696, 3.890809489531651, -1.291195609955519, -42.854099388207466, 7.8761368019456635, 28.798057251442877, 7.1456564000935225, 2.9518396278984786, -42.860181054199636, -34.989343304202, -108.99770355680282, 31.78123839126566, -29.322118879730205, 163.87558311206212, 2.9538984612478396, 31.419519824305546, 13.105400907817279, 25.46987139120746, -43.27317406736858, 32.455151773056144, 25.28067703547214, 0.005442008567682552, 21.005029233370358, -61.71390311950051, 5.218417653362599, 47.15762099517554, 8.834739256343404, 17.845790108867153, -54.31654219224107, -18.71285871476804, -16.446982180909007, -71.22568781913213, -12.339975548387091, 70.49096598213703, 25.022454385237456, -14.572652938207126, -5.320080866729078, 1.5874449933639676, -40.60960510287835, -31.98564381157643, -95.40875746933808, 19.196346639002364, -38.80930682421929, 135.00454225923906, 5.277879540549592, 30.79963767087089, -0.007276462027131683, 31.278796123365815, -38.47381680049993, 10.415728497075905, 36.567265019013085, -7.406587944733211, -18.376174615781114, -45.26976962854271}).reshape(-1, 5); | ||||
| 
 | ||||
|         double eps = 1e-2; | ||||
|         if("CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){ | ||||
|             eps = 2e-2; | ||||
|         } | ||||
| 
 | ||||
|         assertArrayEquals(exp.data().asDouble(), b.getData().data().asDouble(), eps); | ||||
|     } | ||||
| 
 | ||||
|     @Test(timeout = 300000) | ||||
|     public void testTsne() throws Exception { | ||||
|         DataTypeUtil.setDTypeForContext(DataType.DOUBLE); | ||||
|         Nd4j.getRandom().setSeed(123); | ||||
|         BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(10).setMaxIter(10).theta(0.5).learningRate(500) | ||||
|                         .useAdaGrad(false).build(); | ||||
| 
 | ||||
|         File f = Resources.asFile("/deeplearning4j-core/mnist2500_X.txt"); | ||||
|         INDArray data = Nd4j.readNumpy(f.getAbsolutePath(), "   ").get(NDArrayIndex.interval(0, 100), | ||||
|                 NDArrayIndex.interval(0, 784)); | ||||
| 
 | ||||
|         ClassPathResource labels = new ClassPathResource("mnist2500_labels.txt"); | ||||
|         List<String> labelsList = IOUtils.readLines(labels.getInputStream()).subList(0, 100); | ||||
|         b.fit(data); | ||||
|         File outDir = testDir.newFolder(); | ||||
|         b.saveAsFile(labelsList, new File(outDir, "out.txt").getAbsolutePath()); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testBuilderFields() throws Exception { | ||||
|         final double theta = 0; | ||||
|         final boolean invert = false; | ||||
|         final String similarityFunctions = "euclidean"; | ||||
|         final int maxIter = 1; | ||||
|         final double realMin = 1.0; | ||||
|         final double initialMomentum = 2.0; | ||||
|         final double finalMomentum = 3.0; | ||||
|         final double momentum = 4.0; | ||||
|         final int switchMomentumIteration = 1; | ||||
|         final boolean normalize = false; | ||||
|         final int stopLyingIteration = 100; | ||||
|         final double tolerance = 1e-1; | ||||
|         final double learningRate = 100; | ||||
|         final boolean useAdaGrad = false; | ||||
|         final double perplexity = 1.0; | ||||
|         final double minGain = 1.0; | ||||
| 
 | ||||
|         BarnesHutTsne b = new BarnesHutTsne.Builder().theta(theta).invertDistanceMetric(invert) | ||||
|                         .similarityFunction(similarityFunctions).setMaxIter(maxIter).setRealMin(realMin) | ||||
|                         .setInitialMomentum(initialMomentum).setFinalMomentum(finalMomentum).setMomentum(momentum) | ||||
|                         .setSwitchMomentumIteration(switchMomentumIteration).normalize(normalize) | ||||
|                         .stopLyingIteration(stopLyingIteration).tolerance(tolerance).learningRate(learningRate) | ||||
|                         .perplexity(perplexity).minGain(minGain).build(); | ||||
| 
 | ||||
|         final double DELTA = 1e-15; | ||||
| 
 | ||||
|         assertEquals(theta, b.getTheta(), DELTA); | ||||
|         assertEquals("invert", invert, b.isInvert()); | ||||
|         assertEquals("similarityFunctions", similarityFunctions, b.getSimiarlityFunction()); | ||||
|         assertEquals("maxIter", maxIter, b.maxIter); | ||||
|         assertEquals(realMin, b.realMin, DELTA); | ||||
|         assertEquals(initialMomentum, b.initialMomentum, DELTA); | ||||
|         assertEquals(finalMomentum, b.finalMomentum, DELTA); | ||||
|         assertEquals(momentum, b.momentum, DELTA); | ||||
|         assertEquals("switchMomentumnIteration", switchMomentumIteration, b.switchMomentumIteration); | ||||
|         assertEquals("normalize", normalize, b.normalize); | ||||
|         assertEquals("stopLyingInMemoryLookupTable.javaIteration", stopLyingIteration, b.stopLyingIteration); | ||||
|         assertEquals(tolerance, b.tolerance, DELTA); | ||||
|         assertEquals(learningRate, b.learningRate, DELTA); | ||||
|         assertEquals("useAdaGrad", useAdaGrad, b.useAdaGrad); | ||||
|         assertEquals(perplexity, b.getPerplexity(), DELTA); | ||||
|         assertEquals(minGain, b.minGain, DELTA); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testPerplexity() throws Exception { | ||||
|         DataTypeUtil.setDTypeForContext(DataType.DOUBLE); | ||||
|         Nd4j.getRandom().setSeed(123); | ||||
|         BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(10).setMaxIter(10).theta(0.5).learningRate(500) | ||||
|                 .useAdaGrad(false).build(); | ||||
| 
 | ||||
|         DataSetIterator iter = new MnistDataSetIterator(100, true, 12345); | ||||
|         INDArray data = iter.next().getFeatures(); | ||||
| 
 | ||||
|         INDArray perplexityOutput = b.computeGaussianPerplexity(data, 30.0); | ||||
| //        System.out.println(perplexityOutput); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testReproducibility() { | ||||
|         Nd4j.getRandom().setSeed(10); | ||||
|         INDArray input = Nd4j.createFromArray(new double[]{ 0.4681,    0.2971, | ||||
|                 0.2938,    0.3655, | ||||
|                 0.3968,    0.0990, | ||||
|                 0.0796,    0.9245}).reshape(4,2); | ||||
| 
 | ||||
|         BarnesHutTsne b1 = new BarnesHutTsne.Builder().perplexity(1.0).build(), | ||||
|                 b2 = new BarnesHutTsne.Builder().perplexity(1.0).build(); | ||||
|         b1.setSimiarlityFunction(Distance.EUCLIDEAN.toString()); | ||||
|         b2.setSimiarlityFunction(Distance.EUCLIDEAN.toString()); | ||||
| 
 | ||||
|         b1.fit(input); | ||||
|         INDArray ret1 = b1.getData(); | ||||
| 
 | ||||
|         Nd4j.getRandom().setSeed(10); | ||||
|         b2.fit(input); | ||||
|         INDArray ret2 = b2.getData(); | ||||
|         assertEquals(ret1, ret2); | ||||
|     } | ||||
| 
 | ||||
|     @Ignore | ||||
|     @Test | ||||
|     public void testCorrectness() throws IOException { | ||||
|         DataTypeUtil.setDTypeForContext(DataType.DOUBLE); | ||||
|         Nd4j.getRandom().setSeed(123); | ||||
|         BarnesHutTsne b = new BarnesHutTsne.Builder().perplexity(20.0).numDimension(2).learningRate(200).setMaxIter(50) | ||||
|                 .useAdaGrad(false).build(); | ||||
| 
 | ||||
|         ClassPathResource resource = new ClassPathResource("/mnist2500_X.txt"); | ||||
|         File f = resource.getTempFileFromArchive(); | ||||
|         INDArray data = Nd4j.readNumpy(f.getAbsolutePath(), "   "); | ||||
|         StopWatch watch = new StopWatch(); | ||||
|         watch.start(); | ||||
|         b.fit(data); | ||||
| //        System.out.println(b.getData()); | ||||
|         watch.stop(); | ||||
|         File outDir = testDir.newFolder(); | ||||
|         ClassPathResource labels = new ClassPathResource("mnist2500_labels.txt"); | ||||
|         List<String> labelsList = IOUtils.readLines(labels.getInputStream()); | ||||
|         b.saveAsFile(/*labelsList,*/ new File(outDir, "raw.txt").getAbsolutePath()); | ||||
| //        System.out.println(b.getData()); | ||||
| 
 | ||||
|         System.out.println("Fit done in " + watch); | ||||
|         assertEquals(2500, b.getData().size(0)); | ||||
| //        System.out.println(b.getData()); | ||||
| 
 | ||||
|         INDArray a1 = b.getData().getRow(0); | ||||
|         INDArray a2 = b.getData().getRow(1); | ||||
|         INDArray a3 = b.getData().getRow(1000); | ||||
|         INDArray a4 = b.getData().getRow(2498); | ||||
|         INDArray a5 = b.getData().getRow(2499); | ||||
| 
 | ||||
|         INDArray expectedRow0 = Nd4j.createFromArray(new double[]{   167.8292,   32.5092,   75.6999,  -27.1170,   17.6490,  107.4103,   46.2925,    0.4640,  -30.7644,   -5.6178,   18.9462,    0.0773,   16.9440,   82.9042,   82.0447,   57.1004,  -65.7106,   21.9009,   31.2762,  -46.9130,  -79.2331,  -47.1991,  -84.3263,   53.6706,   90.2068,  -35.2406,  -39.4955,  -34.6930,  -27.5715,   -4.8603, -126.0396,  -58.8744, -101.5482,   -0.2450,  -12.1293,   74.7684,   69.9875,  -42.2529,  -23.4274,   24.8436,    1.4931,    3.3617,  -85.8046,   31.6360,   29.9752, -118.0233,   65.4318,  -16.9101,   65.3177,  -37.1838,   21.2493,   32.0591,    2.8582,  -62.2490,  -61.2909}); | ||||
|         INDArray expectedRow1 = Nd4j.createFromArray(new double[]{   32.3478,  118.7499,   -5.2345,   18.1522,   -5.7661,   55.0841,   19.1792,    0.6082,   18.7637,  145.1893,   56.9232,   95.6905,    0.6450,   54.9728,  -47.6037,   18.9907,   44.9000,   62.0607,   11.3163,   12.5538,   71.6602,   62.7464,   26.8367,    9.9804,   21.2930,   26.7346,  -25.4178,    0.8815,  127.8388,   95.7059,   61.8721,  198.7351,    3.7012,   38.8855,   56.8623,   -1.9203,  -21.2366,   26.3412,  -15.0002,   -5.5686,  -70.1437,  -75.2662,    5.2471,   32.7884,    9.0304,   25.5222,   52.0305,  -25.6134,   48.3513,   24.0128,  -15.4485, -139.3574,    7.2340,   82.3224,   12.1519}); | ||||
|         INDArray expectedRow1000 = Nd4j.createFromArray(new double[]{  30.8645,  -15.0904,   -8.3493,    3.7487,  -24.4678,    8.1096,   42.3257,   15.6477,  -45.1260,   31.5830,   40.2178,  -28.7947,  -83.6021,   -4.2135,   -9.8731,    0.3819,   -5.6642,  -34.0559,  -67.8494,  -33.4919,   -0.6254,    6.2422,  -56.9254,  -16.5402,   52.7575,  -72.3746,   18.7587,  -47.5842,   12.8834,  -20.3063,   21.7613,  -59.9718,    9.4924,   49.3242,  -36.5622,  -83.7369,   24.9921,   20.6678,    0.0452,  -69.3666,   13.2417,  -63.0318,    8.8107,  -34.4605,   -7.9497,  -12.0326,   27.4876,   -5.1647,    0.4363,  -24.6792,   -7.2241,   47.9472,   16.9052,   -8.1184,  -35.9715}); | ||||
|         INDArray expectedRow2498 = Nd4j.createFromArray(new double[]{  -0.0919, -153.8959,  -51.5028,  -73.8650,   -0.1183,  -14.4633,  -13.5049,   43.3787,   80.7100,    3.4296,   16.9782,  -75.3470,  103.3307,   13.8846,   -6.9218,   96.0892,    6.9730,   -2.1582,  -24.3647,   39.9077,  -10.5426, -135.5623,   -3.5470,   27.1481,  -24.0933,  -47.3872,    4.5534, -118.1384, -100.2693,  -64.9634,  -85.7244,   64.6426,  -48.8833,  -31.1378,  -93.3141,   37.8991,    8.5912,  -58.7564,   93.5057,   43.7609,  -34.8800,  -26.4699,  -37.5039,   10.8743,   22.7238,  -46.8137,   22.4390,  -12.9343,   32.6593,  -11.9136, -123.9708,   -5.3310,  -65.2792,  -72.1379,   36.7171}); | ||||
|         INDArray expectedRow2499 = Nd4j.createFromArray(new double[]{  -48.1854,   54.6014,   61.4287,    7.2306,   67.0068,   97.8297,   79.4408,   40.5714,  -18.2712,   -0.4891,   36.9610,   70.8634,  109.1919,  -28.6810,   13.5949,   -4.6143,   11.4054,  -95.5810,   20.6512,   77.8442,   33.2472,   53.7065,    4.3208,  -85.9796,   38.1717,   -9.6965,   44.0203,    1.0427,  -17.6281,  -54.7104,  -88.1742,  -24.6297,   33.5158,  -10.4808,   16.7051,   21.7057,   42.1260,   61.4450,   -9.4028,  -68.3737,   18.8957,   45.0714,   14.3170,   84.0521,   80.0860,  -15.4343,  -73.6115,  -15.5358,  -41.5067,  -55.7111,    0.1811,  -75.5584,   16.4112, -128.0799,  119.3907}); | ||||
| 
 | ||||
|         assertArrayEquals(expectedRow0.toDoubleVector(), b.getData().getRow(0).toDoubleVector(), 1e-4); | ||||
|         assertArrayEquals(expectedRow1.toDoubleVector(), b.getData().getRow(1).toDoubleVector(), 1e-4); | ||||
|         assertArrayEquals(expectedRow1000.toDoubleVector(), b.getData().getRow(1000).toDoubleVector(), 1e-4); | ||||
|         assertArrayEquals(expectedRow2498.toDoubleVector(), b.getData().getRow(2498).toDoubleVector(), 1e-4); | ||||
|         assertArrayEquals(expectedRow2499.toDoubleVector(), b.getData().getRow(2499).toDoubleVector(), 1e-4); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testCorrectness1() { | ||||
|         DataTypeUtil.setDTypeForContext(DataType.DOUBLE); | ||||
|         Nd4j.getRandom().setSeed(123); | ||||
| 
 | ||||
|         double[] aData = new double[]{ | ||||
|                 0.2999816948164936, 0.26252049735806526, 0.2673853427498767, 0.8604464129156685, 0.4802652829902563, 0.10959096539488711, 0.7950242948008909, 0.5917848948003486, | ||||
|                 0.2738285999345498, 0.9519684328285567, 0.9690024759209738, 0.8585615547624705, 0.8087760944312002, 0.5337951589543348, 0.5960876109129123, 0.7187130179825856, | ||||
|                 0.4629777327445964, 0.08665909175584818, 0.7748005397731237, 0.48020186965468536, 0.24927351841378798, 0.32272599988270445, 0.306414968984427, 0.6980212149215657, | ||||
|                 0.7977183964212472, 0.7673513094629704, 0.1679681724796478, 0.3107359484804584, 0.021701726051792103, 0.13797462786662518, 0.8618953518813538, 0.841333838365635, | ||||
|                 0.5284957375170422, 0.9703367685039823, 0.677388096913733, 0.2624474979832243, 0.43740966353106536, 0.15685545957858893, 0.11072929134449871, 0.06007395961283357, | ||||
|                 0.4093918718557811,  0.9563909195720572, 0.5994144944480242, 0.8278927844215804, 0.38586830957105667, 0.6201844716257464, 0.7603829079070265, 0.07875691596842949, | ||||
|                 0.08651136699915507, 0.7445210640026082, 0.6547649514127559, 0.3384719042666908, 0.05816723105860,0.6248951423054205, 0.7431868493349041}; | ||||
|         INDArray data = Nd4j.createFromArray(aData).reshape(11,5); | ||||
| 
 | ||||
|         BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(250).setMaxIter(20).perplexity(3.0).theta(0.5).numDimension(5). | ||||
|                 invertDistanceMetric(false).similarityFunction(Distance.EUCLIDEAN.toString()) | ||||
|                 .setMomentum(0.5).learningRate(200).staticInit(data).setSwitchMomentumIteration(250) | ||||
|                 .useAdaGrad(false).build(); | ||||
| 
 | ||||
|         b.fit(data); | ||||
| 
 | ||||
|         double[] expectedData = new double[]{  63.8206,   80.4013,  -19.4424, -140.4326,  198.7239, | ||||
|                                               106.1148,  -96.6273, -124.3634,   78.4174,  -83.6621, | ||||
|                                              -121.8706,    3.0888, -172.8560,  255.1262,   20.7021, | ||||
|                                              -120.7942,  -78.1829,   56.6021, -112.3294,  185.4084, | ||||
|                                                88.5330,   78.0497,  -18.8673,  -11.0155, -175.1564, | ||||
|                                             -297.8463,  174.2511, -103.8793,   72.5455,  -15.8498, | ||||
|                                             -134.5235,   42.3300,  154.0391, -280.1010, -167.9765, | ||||
|                                                306.9938, -150.9666,   83.4419,  -36.0877,   83.9992, | ||||
|                                                245.1813,  -81.5018,  -14.8430,   16.1557,  166.8651, | ||||
|                                                -65.9247, -138.1783,   72.5444,  176.3088,  -25.6732, | ||||
|                                                -69.6843,  167.3360,   87.6238,  -18.5874, -187.3806}; | ||||
| 
 | ||||
|         INDArray expectedArray = Nd4j.createFromArray(expectedData).reshape(11,5); | ||||
|         for (int i = 0; i < expectedArray.rows(); ++i) | ||||
|             assertArrayEquals(expectedArray.getRow(i).toDoubleVector(), b.getData().getRow(i).toDoubleVector(), 1e-2); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testComputePerplexity() { | ||||
|         double[] input = new double[]{0.2999816948164936, 0.26252049735806526, 0.2673853427498767, 0.8604464129156685, 0.4802652829902563, 0.10959096539488711, 0.7950242948008909, 0.5917848948003486, | ||||
|                 0.2738285999345498, 0.9519684328285567, 0.9690024759209738, 0.8585615547624705, 0.8087760944312002, 0.5337951589543348, 0.5960876109129123, 0.7187130179825856, | ||||
|                 0.4629777327445964, 0.08665909175584818, 0.7748005397731237, 0.48020186965468536, 0.24927351841378798, 0.32272599988270445, 0.306414968984427, 0.6980212149215657, | ||||
|                 0.7977183964212472, 0.7673513094629704, 0.1679681724796478, 0.3107359484804584, 0.021701726051792103, 0.13797462786662518, 0.8618953518813538, 0.841333838365635, | ||||
|                 0.5284957375170422, 0.9703367685039823, 0.677388096913733, 0.2624474979832243, 0.43740966353106536, 0.15685545957858893, 0.11072929134449871, 0.06007395961283357, | ||||
|                 0.4093918718557811, 0.9563909195720572, 0.5994144944480242, 0.8278927844215804, 0.38586830957105667, 0.6201844716257464, 0.7603829079070265, 0.07875691596842949, | ||||
|                 0.08651136699915507, 0.7445210640026082, 0.6547649514127559, 0.3384719042666908, 0.05816723105860, 0.6248951423054205, 0.7431868493349041}; | ||||
|         INDArray ndinput = Nd4j.createFromArray(input).reshape(11, 5); | ||||
|         BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(10).perplexity(3.0).similarityFunction(Distance.EUCLIDEAN.toString()).invertDistanceMetric(false).theta(0.5) | ||||
|                 .useAdaGrad(false).build(); | ||||
|         b.computeGaussianPerplexity(ndinput, 3.0); | ||||
|         INDArray expectedRows = Nd4j.createFromArray(new int[]{0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99}); | ||||
|         INDArray expectedCols = Nd4j.createFromArray(new int[] {4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1}); | ||||
|         INDArray expectedValues = Nd4j.createFromArray(new double[]{0.6199394088807811, 0.1964597878478939, 0.13826096288374987, 0.019500202354103796, 0.00892011933324624, 0.008390894278481041, 0.00333353509170543, 0.0026231979968002537, 0.0025718913332382506, 0.5877813741023542, 0.2824053513290301, 0.08100641562340703, 0.014863269403258283, 0.01219532549481422, 0.011522812905961816, 0.004243949243254114, 0.0034625890823446427, 0.002518912815575669, 0.6776991917357972, 0.18322100043035286, 0.040180871517768765, 0.02941481903928284, 0.021638322103495665, 0.019899251613183868, 0.011684443899339756, 0.008438621670147969, 0.007823477990631192, 0.6771051692354304, 0.16616561426152007, 0.06038657043891834, 0.04649900136463559, 0.01688479525099354, 0.014596215509122025, 0.006410339053808227, 0.006075759373243866, 0.005876535512328113, 0.6277958923349469, 0.23516301304728018, 0.07022275517450298, 0.030895020584550934, 0.012294459258033335, 0.009236709512467177, 0.00821667460222265, 0.0043013613064171955, 0.0018741141795786528, 0.7122763773574693, 0.07860063708191449, 0.07060648172121314, 0.06721282603559373, 0.028960026354739106, 0.017791245039439314, 0.01482510169996304, 0.005496178688168659, 0.004231126021499254, 0.5266697563046261, 0.33044733058681547, 0.10927281903651001, 0.018510201893239094, 0.006973656012751928, 0.006381768970069082, 0.0010596892780182746, 6.535010081417198E-4, 3.127690982824874E-5, 0.7176189632561156, 0.08740746743997298, 0.059268842313360166, 0.04664131589557433, 0.03288791302822797, 0.029929724912968133, 0.013368915822982491, 0.010616377319500762, 0.0022604800112974647, 0.689185362462809, 0.13977758696450715, 0.05439663822300743, 0.05434167873889952, 0.028687383013327405, 0.02099540802182275, 0.0072154477293594615, 0.0032822412915506907, 0.0021182535547164334, 0.6823844384306867, 0.13452128016104092, 0.08713547969428868, 0.04287399325857787, 0.025452813990877978, 0.016881841237860937, 0.0072200814416566415, 0.0019232561582331975, 0.0016068156267770154, 0.6425943207872832, 0.18472852256294967, 0.1089653923564887, 0.03467849453890959, 0.013282484305873534, 0.005149863792637524, 0.0037974408302766656, 0.003787710699822367, 0.003015770125758626}); | ||||
|         assertArrayEquals(expectedCols.toIntVector(), b.getCols().toIntVector()); | ||||
|         assertArrayEquals(expectedRows.toIntVector(), b.getRows().toIntVector()); | ||||
|         assertArrayEquals(expectedValues.toDoubleVector(), b.getVals().toDoubleVector(), 1e-5); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testComputeGradient() { | ||||
|         double[] input = new double[]{0.3000,    0.2625,    0.2674,    0.8604,    0.4803, | ||||
|                                     0.1096,    0.7950,    0.5918,    0.2738,    0.9520, | ||||
|                                     0.9690,    0.8586,    0.8088,    0.5338,    0.5961, | ||||
|                                     0.7187,    0.4630,    0.0867,    0.7748,    0.4802, | ||||
|                                     0.2493,    0.3227,    0.3064,    0.6980,    0.7977, | ||||
|                                     0.7674,    0.1680,    0.3107,    0.0217,    0.1380, | ||||
|                                     0.8619,    0.8413,    0.5285,    0.9703,    0.6774, | ||||
|                                     0.2624,    0.4374,    0.1569,    0.1107,    0.0601, | ||||
|                                     0.4094,    0.9564,    0.5994,    0.8279,    0.3859, | ||||
|                                     0.6202,    0.7604,    0.0788,    0.0865,    0.7445, | ||||
|                                     0.6548,    0.3385,    0.0582,    0.6249,    0.7432}; | ||||
|         INDArray ndinput = Nd4j.createFromArray(input).reshape(11, 5); | ||||
|         BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(10).perplexity(3.0).similarityFunction(Distance.EUCLIDEAN.toString()).invertDistanceMetric(false).theta(0.5) | ||||
|                 .useAdaGrad(false).staticInit(ndinput).build(); | ||||
|         b.setY(ndinput); | ||||
|         b.setN(11); | ||||
| 
 | ||||
|         INDArray rowsP = Nd4j.createFromArray(new int[]{0,         9,        18,        27,        36,        45,        54,        63,        72,        81,        90,        99}); | ||||
|         INDArray colsP = Nd4j.createFromArray(new int[]{4,         3,        10,         8,         6,         7,         1,         5,         9,         4,         9,         8,        10,         2,         0,         6,         7,         3,         6,         8,         3,         9,        10,         1,         4,         0,         5,        10,         0,         4,         6,         8,         9,         2,         5,         7,         0,        10,         3,         1,         8,         9,         6,         7,         2,         7,         9,         3,        10,         0,         4,         2,         8,         1,         2,         8,         3,        10,         0,         4,         9,         1,         5,         5,         9,         0,         3,        10,         4,         8,         1,         2,         6,         2,         0,         3,         4,         1,        10,         9,         7,        10,         1,         3,         7,         4,         5,         2,         8,         6,         3,         4,         0,         9,         6,         5,         8,         7,         1}); | ||||
|         INDArray valsP = Nd4j.createFromArray(new double[]{0.6200,    0.1964,    0.1382,    0.0195,    0.0089,    0.0084,    0.0033,    0.0026,    0.0026,    0.5877,    0.2825,    0.0810,    0.0149,    0.0122,    0.0115,    0.0042,    0.0035,    0.0025,    0.6777,    0.1832,    0.0402,    0.0294,    0.0216,    0.0199,    0.0117,    0.0084,    0.0078,    0.6771,    0.1662,    0.0604,    0.0465,    0.0169,    0.0146,    0.0064,    0.0061,    0.0059,    0.6278,    0.2351,    0.0702,    0.0309,    0.0123,    0.0092,    0.0082,    0.0043,    0.0019,    0.7123,    0.0786,    0.0706,    0.0672,    0.0290,    0.0178,    0.0148,    0.0055,    0.0042,    0.5267,    0.3304,    0.1093,    0.0185,    0.0070,    0.0064,    0.0011,    0.0007, 3.1246e-5,    0.7176,    0.0874,    0.0593,    0.0466,    0.0329,    0.0299,    0.0134,    0.0106,    0.0023,    0.6892,    0.1398,    0.0544,    0.0544,    0.0287,    0.0210,    0.0072,    0.0033,    0.0021,    0.6824,    0.1345,    0.0871,    0.0429,    0.0254,    0.0169,    0.0072,    0.0019,    0.0016,    0.6426,    0.1847,    0.1090,    0.0347,    0.0133,    0.0051,    0.0038,    0.0038,    0.0030}); | ||||
| 
 | ||||
|         b.setRows(rowsP); | ||||
|         b.setCols(colsP); | ||||
|         b.setVals(valsP); | ||||
|         Gradient gradient = b.gradient(); | ||||
| 
 | ||||
|         double[] dC = {-0.0618386320333619, -0.06266654959379839, 0.029998268806149204, 0.10780566335888186, -0.19449543068355346, -0.14763764361792697, 0.17493572758118422, 0.1926109839221966, -0.15176648259935419, 0.10974665709698186, 0.13102419155322598, 0.004941641352409449, 0.19159764518354974, -0.26332838053474944, -0.023631441261541583, 0.09838669432305949, 0.09709129638394683, -0.01605053000727605, 0.06566171635025217, -0.17325078066035252, -0.1090854255505605, 0.023350644966904276, 0.075192354899586, -0.08278373866517603, 0.18431338134579323, 0.2766031655578053, -0.17557907233268688, 0.10616148241800637, -0.09999024423215641, -0.017181932145255287, 0.06711331400576945, -0.01388231800826619, -0.10248189290485302, 0.20786521034824304, 0.11254913977572988, -0.289564646781519, 0.13491805919337516, -0.07504249344962562, 0.004154656287570634, -0.10516715438388784, -0.27984655075804576, 0.09811828071286613, 0.03684521473995052, -0.054645216532387256, -0.18147132772800725, 0.027588750493223044, 0.214734364419479, -0.026729138234415008, -0.28410504978879136, 0.007015481601883835, 0.04427981739424874, -0.059253265830134655, -0.05325479031206952, -0.11319889109674944, 0.1530133971867549}; | ||||
|         INDArray actual = gradient.getGradientFor("yIncs"); | ||||
| //        System.out.println(actual); | ||||
|         assertArrayEquals(dC, actual.reshape(1,55).toDoubleVector(), 1e-05); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testApplyGradient() { | ||||
|         double[] Y = new double[]{0.2999816948164936, 0.26252049735806526, 0.2673853427498767, 0.8604464129156685, 0.4802652829902563, 0.10959096539488711, 0.7950242948008909, 0.5917848948003486, | ||||
|                 0.2738285999345498, 0.9519684328285567, 0.9690024759209738, 0.8585615547624705, 0.8087760944312002, 0.5337951589543348, 0.5960876109129123, 0.7187130179825856, | ||||
|                 0.4629777327445964, 0.08665909175584818, 0.7748005397731237, 0.48020186965468536, 0.24927351841378798, 0.32272599988270445, 0.306414968984427, 0.6980212149215657, | ||||
|                 0.7977183964212472, 0.7673513094629704, 0.1679681724796478, 0.3107359484804584, 0.021701726051792103, 0.13797462786662518, 0.8618953518813538, 0.841333838365635, | ||||
|                 0.5284957375170422, 0.9703367685039823, 0.677388096913733, 0.2624474979832243, 0.43740966353106536, 0.15685545957858893, 0.11072929134449871, 0.06007395961283357, | ||||
|                 0.4093918718557811, 0.9563909195720572, 0.5994144944480242, 0.8278927844215804, 0.38586830957105667, 0.6201844716257464, 0.7603829079070265, 0.07875691596842949, | ||||
|                 0.08651136699915507, 0.7445210640026082, 0.6547649514127559, 0.3384719042666908, 0.05816723105860, 0.6248951423054205, 0.7431868493349041}; | ||||
|         INDArray ndinput = Nd4j.createFromArray(Y).reshape(11,5); | ||||
| 
 | ||||
|         double[] gradient = {   -0.0635,   -0.0791,    0.0228,    0.1360,   -0.2016, | ||||
|                    -0.1034,    0.0976,    0.1266,   -0.0781,    0.0707, | ||||
|                     0.1184,   -0.0018,    0.1719,   -0.2529,   -0.0209, | ||||
|                     0.1204,    0.0855,   -0.0530,    0.1069,   -0.1860, | ||||
|                    -0.0890,   -0.0763,    0.0181,    0.0048,    0.1798, | ||||
|                     0.2917,   -0.1699,    0.1038,   -0.0736,    0.0159, | ||||
|                     0.1324,   -0.0409,   -0.1502,    0.2738,    0.1668, | ||||
|                    -0.3012,    0.1489,   -0.0801,    0.0329,   -0.0817, | ||||
|                    -0.2405,    0.0810,    0.0171,   -0.0201,   -0.1638, | ||||
|                     0.0656,    0.1383,   -0.0707,   -0.1757,    0.0144, | ||||
|                     0.0708,   -0.1725,   -0.0870,    0.0160,    0.1921}; | ||||
|         INDArray ndgrad = Nd4j.createFromArray(gradient).reshape(11, 5); | ||||
|         BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(10).perplexity(3.0).similarityFunction(Distance.EUCLIDEAN.toString()) | ||||
|                 .invertDistanceMetric(false).theta(0.5).learningRate(200) | ||||
|                 .useAdaGrad(false).staticInit(ndinput).build(); | ||||
|         b.setY(ndinput); | ||||
|         b.setN(11); | ||||
|         INDArray yIncs = Nd4j.zeros(DataType.DOUBLE, ndinput.shape()); | ||||
|         b.setYIncs(yIncs); | ||||
|         INDArray gains = Nd4j.zeros(DataType.DOUBLE, ndinput.shape()); | ||||
|         b.setGains(gains); | ||||
|         b.update(ndgrad, "yIncs"); | ||||
| 
 | ||||
|         double[] expected = {2.54, 3.164, -0.912, -5.44, 8.064, 4.136, -3.9040000000000004, -5.064, 3.124, -2.828, -4.736000000000001, 0.072, -6.8759999999999994, 10.116, 0.836, -4.816, -3.4200000000000004, 2.12, -4.276, 7.4399999999999995, 3.5599999999999996, 3.0520000000000005, -0.7240000000000001, -0.19199999999999998, -7.191999999999999, -11.668000000000001, 6.795999999999999, -4.152, 2.944, -0.636, -5.295999999999999, 1.636, 6.008, -10.952, -6.672000000000001, 12.048000000000002, -5.956, 3.204, -1.3159999999999998, 3.268, 9.62, -3.24, -0.684, 0.804, 6.552, -2.624, -5.532, 2.828, 7.028, -0.576, -2.832, 6.8999999999999995, 3.4799999999999995, -0.64, -7.683999999999999}; | ||||
|         assertArrayEquals(expected, b.getYIncs().reshape(55).toDoubleVector(), 1e-5); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testComputeEdgeForces() { | ||||
|         double[] input = new double[]{0.3000, 0.2625, 0.2674, 0.8604, 0.4803, | ||||
|                 0.1096, 0.7950, 0.5918, 0.2738, 0.9520, | ||||
|                 0.9690, 0.8586, 0.8088, 0.5338, 0.5961, | ||||
|                 0.7187, 0.4630, 0.0867, 0.7748, 0.4802, | ||||
|                 0.2493, 0.3227, 0.3064, 0.6980, 0.7977, | ||||
|                 0.7674, 0.1680, 0.3107, 0.0217, 0.1380, | ||||
|                 0.8619, 0.8413, 0.5285, 0.9703, 0.6774, | ||||
|                 0.2624, 0.4374, 0.1569, 0.1107, 0.0601, | ||||
|                 0.4094, 0.9564, 0.5994, 0.8279, 0.3859, | ||||
|                 0.6202, 0.7604, 0.0788, 0.0865, 0.7445, | ||||
|                 0.6548, 0.3385, 0.0582, 0.6249, 0.7432}; | ||||
|         INDArray ndinput = Nd4j.createFromArray(input).reshape(11, 5); | ||||
|         SpTree tree = new SpTree(ndinput); | ||||
|         INDArray rows = Nd4j.createFromArray(new int[]{0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99}); | ||||
|         INDArray cols = Nd4j.createFromArray(new int[]{4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1}); | ||||
|         INDArray vals = Nd4j.createFromArray(new double[]{0.6200, 0.1964, 0.1382, 0.0195, 0.0089, 0.0084, 0.0033, 0.0026, 0.0026, 0.5877, 0.2825, 0.0810, 0.0149, 0.0122, 0.0115, 0.0042, 0.0035, 0.0025, 0.6777, 0.1832, 0.0402, 0.0294, 0.0216, 0.0199, 0.0117, 0.0084, 0.0078, 0.6771, 0.1662, 0.0604, 0.0465, 0.0169, 0.0146, 0.0064, 0.0061, 0.0059, 0.6278, 0.2351, 0.0702, 0.0309, 0.0123, 0.0092, 0.0082, 0.0043, 0.0019, 0.7123, 0.0786, 0.0706, 0.0672, 0.0290, 0.0178, 0.0148, 0.0055, 0.0042, 0.5267, 0.3304, 0.1093, 0.0185, 0.0070, 0.0064, 0.0011, 0.0007, 3.1246e-5, 0.7176, 0.0874, 0.0593, 0.0466, 0.0329, 0.0299, 0.0134, 0.0106, 0.0023, 0.6892, 0.1398, 0.0544, 0.0544, 0.0287, 0.0210, 0.0072, 0.0033, 0.0021, 0.6824, 0.1345, 0.0871, 0.0429, 0.0254, 0.0169, 0.0072, 0.0019, 0.0016, 0.6426, 0.1847, 0.1090, 0.0347, 0.0133, 0.0051, 0.0038, 0.0038, 0.0030}); | ||||
|         int N = 11; | ||||
|         INDArray posF = Nd4j.create(ndinput.shape()); | ||||
|         tree.computeEdgeForces(rows, cols, vals, N, posF); | ||||
|         double[] expectedPosF = {-0.08017022778816381, -0.08584612446002386, 0.024041740837932417, 0.13353853518214748, -0.19989209255196486, -0.17059164865362167, 0.18730152809351328, 0.20582835656173232, -0.1652505189678666, 0.13123839113710167, 0.15511476126066306, 0.021425546153174206, 0.21755440369356663, -0.2628756936897519, -0.021079609911707077, 0.11455959658671841, 0.08803186126822704, -0.039212116057989604, 0.08800854045636688, -0.1795568260613919, -0.13265313037184673, 0.0036829788349159154, 0.07205631770917967, -0.06873974602987808, 0.20446419876515043, 0.28724205607738795, -0.19397780156808536, 0.10457369548573531, -0.12340830629973816, -0.03634773269456816, 0.0867775929922852, 0.0029761730963277894, -0.09131897988004745, 0.2348924028566898, 0.12026408931908775, -0.30400848137321873, 0.1282943410872978, -0.08487864823843354, -0.017561758195375168, -0.13082811573092396, -0.2885857462722986, 0.12469730654026252, 0.05408469871148934, -0.03417740859260864, -0.19261929748672968, 0.03318694717819495, 0.22818123908045765, -0.044944593551341956, -0.3141734963080852, 0.020297428845239652, 0.05442118949793863, -0.07890301602838638, -0.07823705950336371, -0.10455483898962027, 0.16980714813230746}; | ||||
|         INDArray indExpectedPositive = Nd4j.createFromArray(expectedPosF).reshape(11, 5); | ||||
|         assertEquals(indExpectedPositive, posF); | ||||
| 
 | ||||
|         AtomicDouble sumQ = new AtomicDouble(0.0); | ||||
|         double theta = 0.5; | ||||
|         INDArray negF = Nd4j.create(ndinput.shape()); | ||||
| 
 | ||||
|         double[][] neg = {{-1.6243229118532043, -2.0538918185758117, -0.5277950148630416, 2.280133920112387, -0.4781864949257863}, | ||||
|         {-2.033904565482581, 1.0957067439325718, 1.1711627018218371, -1.1947911960637323, 1.904335906364157}, | ||||
|         {2.134613094178481, 1.4606030267537151, 2.299972033488509, 0.040111598796927175, 0.22611223726312565}, | ||||
|         {1.4330457669590706, -0.8027368824700638, -2.052297868677289, 1.9801035811739054, -0.5587649959721402}, | ||||
|         {-2.088283171473531, -1.7427092080895168, -0.27787744880128185, 1.2444077055013942, 1.7855201950031347}, | ||||
|         {0.9426889976629138, -1.6302714638583877, -0.14069035384185855, -2.075023651861262, -1.698239988087389}, | ||||
|         {1.7424090804808496, 1.493794306111751, 0.989121494481274, 2.394820866756112, 0.6836049340540907}, | ||||
|         {-1.279836833417519, -0.5869132848699253, -0.871560326864079, -1.9242443527432451, -2.273762088892443}, | ||||
|         {-0.7743611464510498, 2.3551097898757134, 1.527553257122278, 1.813608037002701, -0.9877974041073948}, | ||||
|         {0.49604405759812625, 1.1914983778171337, -1.6140319597311803, -2.6642997837396654, 1.1768845173097966}, | ||||
|         {0.8986049706740562, -1.7411217160869163, -2.213624650045752, 0.7659306956507013, 1.4880578211349607}}; | ||||
| 
 | ||||
|         double expectedSumQ = 88.60782954084712; | ||||
| 
 | ||||
|         for (int n = 0; n < N; n++) { | ||||
|             tree.computeNonEdgeForces(n, theta, negF.slice(n), sumQ); | ||||
|             assertArrayEquals(neg[n], negF.slice(n).toDoubleVector(), 1e-05); | ||||
|         } | ||||
|         assertEquals(expectedSumQ, sumQ.get(), 1e-05); | ||||
|     } | ||||
| 
 | ||||
|     /* | ||||
|     @Test | ||||
|     public void testSymmetrized() { | ||||
|         BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(10).perplexity(3.0).similarityFunction(Distance.EUCLIDEAN.toString()).invertDistanceMetric(false).theta(0.5) | ||||
|                 .useAdaGrad(false).build(); | ||||
|         INDArray expectedSymmetrized = Nd4j.createFromArray(new double[]{0.6239, 0.1813, 0.12359999999999999, 0.03695, 0.00795, 0.03385, 0.0074, 0.0158, 0.0013, 0.0042, 0.0074, 0.3093, 0.2085, 0.051000000000000004, 0.00895, 0.016050000000000002, 0.00245, 0.00705, 0.00125, 0.0021, 0.016050000000000002, 0.6022, 0.1615, 0.0233, 0.0183, 0.0108, 0.0068000000000000005, 0.0042, 0.011300000000000001, 0.00115, 0.1813, 0.00125, 0.0233, 0.65985, 0.0653, 0.0779, 0.03565, 0.05085, 0.038349999999999995, 0.026250000000000002, 0.6239, 0.3093, 0.0068000000000000005, 0.0653, 0.2099, 0.0205, 0.0173, 0.007300000000000001, 0.0171, 0.0089, 0.0158, 0.011300000000000001, 0.038349999999999995, 0.71495, 0.04775, 0.03615, 0.0089, 0.00275, 0.0021, 1.5623E-5, 0.00795, 0.00245, 0.6022, 0.0779, 0.007300000000000001, 0.5098, 0.015899999999999997, 0.00135, 1.5623E-5, 0.03385, 0.00705, 0.026250000000000002, 0.0171, 0.71495, 0.06515, 0.018349999999999998, 0.00775, 0.00115, 0.03695, 0.051000000000000004, 0.1615, 0.03565, 0.0205, 0.00275, 0.5098, 0.00775, 0.0055, 0.0026, 0.0013, 0.2085, 0.0183, 0.05085, 0.0173, 0.04775, 0.00135, 0.06515, 0.0026, 0.35855, 0.12359999999999999, 0.00895, 0.0108, 0.65985, 0.2099, 0.03615, 0.015899999999999997, 0.018349999999999998, 0.0055, 0.35855}); | ||||
|         INDArray rowsP = Nd4j.createFromArray(new int[]{0,         9,        18,        27,        36,        45,        54,        63,        72,        81,        90,        99}); | ||||
|         INDArray colsP = Nd4j.createFromArray(new int[]{4,         3,        10,         8,         6,         7,         1,         5,         9,         4,         9,         8,        10,         2,         0,         6,         7,         3,         6,         8,         3,         9,        10,         1,         4,         0,         5,        10,         0,         4,         6,         8,         9,         2,         5,         7,         0,        10,         3,         1,         8,         9,         6,         7,         2,         7,         9,         3,        10,         0,         4,         2,         8,         1,         2,         8,         3,        10,         0,         4,         9,         1,         5,         5,         9,         0,         3,        10,         4,         8,         1,         2,         6,         2,         0,         3,         4,         1,        10,         9,         7,        10,         1,         3,         7,         4,         5,         2,         8,         6,         3,         4,         0,         9,         6,         5,         8,         7,         1}); | ||||
|         INDArray valsP = Nd4j.createFromArray(new double[]{0.6200,    0.1964,    0.1382,    0.0195,    0.0089,    0.0084,    0.0033,    0.0026,    0.0026,    0.5877,    0.2825,    0.0810,    0.0149,    0.0122,    0.0115,    0.0042,    0.0035,    0.0025,    0.6777,    0.1832,    0.0402,    0.0294,    0.0216,    0.0199,    0.0117,    0.0084,    0.0078,    0.6771,    0.1662,    0.0604,    0.0465,    0.0169,    0.0146,    0.0064,    0.0061,    0.0059,    0.6278,    0.2351,    0.0702,    0.0309,    0.0123,    0.0092,    0.0082,    0.0043,    0.0019,    0.7123,    0.0786,    0.0706,    0.0672,    0.0290,    0.0178,    0.0148,    0.0055,    0.0042,    0.5267,    0.3304,    0.1093,    0.0185,    0.0070,    0.0064,    0.0011,    0.0007, 3.1246e-5,    0.7176,    0.0874,    0.0593,    0.0466,    0.0329,    0.0299,    0.0134,    0.0106,    0.0023,    0.6892,    0.1398,    0.0544,    0.0544,    0.0287,    0.0210,    0.0072,    0.0033,    0.0021,    0.6824,    0.1345,    0.0871,    0.0429,    0.0254,    0.0169,    0.0072,    0.0019,    0.0016,    0.6426,    0.1847,    0.1090,    0.0347,    0.0133,    0.0051,    0.0038,    0.0038,    0.0030}); | ||||
|         b.setN(11); | ||||
|         BarnesHutTsne.SymResult actualSymmetrized = b.symmetrized(rowsP, colsP, valsP); | ||||
|         System.out.println("Symmetrized from Java:" + actualSymmetrized); | ||||
|         System.out.println(actualSymmetrized.rows); | ||||
|         System.out.println(actualSymmetrized.cols); | ||||
|         assertArrayEquals(expectedSymmetrized.toDoubleVector(), actualSymmetrized.vals.toDoubleVector(), 1e-5); | ||||
| 
 | ||||
| 
 | ||||
|         INDArray rowsFromCpp = Nd4j.create(new int[]{rowsP.rows(),rowsP.columns()}, DataType.INT); | ||||
|         BarnesHutSymmetrize op = new BarnesHutSymmetrize(rowsP, colsP, valsP, 11, rowsFromCpp); | ||||
|         Nd4j.getExecutioner().exec(op); | ||||
|         INDArray valsFromCpp = op.getSymmetrizedValues(); | ||||
|         INDArray colsFromCpp = op.getSymmetrizedCols(); | ||||
|         System.out.println("Symmetrized from C++: " + valsP); | ||||
|         assertArrayEquals(expectedSymmetrized.toDoubleVector(), valsFromCpp.toDoubleVector(), 1e-5); | ||||
| 
 | ||||
|         int[] expectedRows = new int[]{0, 10, 20, 30, 40, 50, 60, 69, 78, 88, 98, 108}; | ||||
|         int[] expectedCols = new int[]{4, 3, 10, 8, 6, 7, 1, 5, 9, 2, 0, 4, 9, 8, 10, 2, 6, 7, 3, 5, 1, 6, 8, 3, 9, 10, 4, 0, 5, 7, 0, 1, 2, 10, 4, 6, 8, 9, 5, 7, 0, 1, 2, 3, 10, 8, 9, 6, 7, 5, 0, 2, 3, 7, 9, 10, 4, 8, 1, 6, 0, 1, 2, 3, 4, 8, 10, 9, 5, 0, 1, 3, 4, 5, 9, 10, 8, 2, 0, 1, 2, 3, 4, 5, 6, 7, 10, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; | ||||
| 
 | ||||
|         assertArrayEquals(expectedRows, rowsFromCpp.toIntVector()); | ||||
|         assertArrayEquals(expectedCols, colsFromCpp.toIntVector()); | ||||
|     } | ||||
|      */ | ||||
| 
 | ||||
|     @Test | ||||
|     public void testVPTree() { | ||||
|         try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { | ||||
|             double[] d = new double[]{0.3000, 0.2625, 0.2674, 0.8604, 0.4803, | ||||
|                     0.1096, 0.7950, 0.5918, 0.2738, 0.9520, | ||||
|                     0.9690, 0.8586, 0.8088, 0.5338, 0.5961, | ||||
|                     0.7187, 0.4630, 0.0867, 0.7748, 0.4802, | ||||
|                     0.2493, 0.3227, 0.3064, 0.6980, 0.7977, | ||||
|                     0.7674, 0.1680, 0.3107, 0.0217, 0.1380, | ||||
|                     0.8619, 0.8413, 0.5285, 0.9703, 0.6774, | ||||
|                     0.2624, 0.4374, 0.1569, 0.1107, 0.0601, | ||||
|                     0.4094, 0.9564, 0.5994, 0.8279, 0.3859, | ||||
|                     0.6202, 0.7604, 0.0788, 0.0865, 0.7445, | ||||
|                     0.6548, 0.3385, 0.0582, 0.6249, 0.7432}; | ||||
|             VPTree tree = new VPTree(Nd4j.createFromArray(d).reshape(11, 5), "euclidean", 1, false); | ||||
|             INDArray target = Nd4j.createFromArray(new double[]{0.3000, 0.2625, 0.2674, 0.8604, 0.4803}); | ||||
|             List<DataPoint> results = new ArrayList<>(); | ||||
|             List<Double> distances = new ArrayList<>(); | ||||
|             tree.search(target, 11, results, distances); | ||||
| //            System.out.println("Results:" + results); | ||||
| //            System.out.println("Distances:" + distances); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testSpTree() { | ||||
|             double[] input = new double[]{0.2999816948164936, 0.26252049735806526, 0.2673853427498767, 0.8604464129156685, 0.4802652829902563, 0.10959096539488711, 0.7950242948008909, 0.5917848948003486, | ||||
|                     0.2738285999345498, 0.9519684328285567, 0.9690024759209738, 0.8585615547624705, 0.8087760944312002, 0.5337951589543348, 0.5960876109129123, 0.7187130179825856, | ||||
|                     0.4629777327445964, 0.08665909175584818, 0.7748005397731237, 0.48020186965468536, 0.24927351841378798, 0.32272599988270445, 0.306414968984427, 0.6980212149215657, | ||||
|                     0.7977183964212472, 0.7673513094629704, 0.1679681724796478, 0.3107359484804584, 0.021701726051792103, 0.13797462786662518, 0.8618953518813538, 0.841333838365635, | ||||
|                     0.5284957375170422, 0.9703367685039823, 0.677388096913733, 0.2624474979832243, 0.43740966353106536, 0.15685545957858893, 0.11072929134449871, 0.06007395961283357, | ||||
|                     0.4093918718557811, 0.9563909195720572, 0.5994144944480242, 0.8278927844215804, 0.38586830957105667, 0.6201844716257464, 0.7603829079070265, 0.07875691596842949, | ||||
|                     0.08651136699915507, 0.7445210640026082, 0.6547649514127559, 0.3384719042666908, 0.05816723105860, 0.6248951423054205, 0.7431868493349041}; | ||||
|             INDArray ndinput = Nd4j.createFromArray(input).reshape(11, 5); | ||||
| 
 | ||||
|             int[] rows = {0, 10, 20, 30, 40, 50, 60, 69, 78, 88, 98, 108}; | ||||
|             INDArray indRows = Nd4j.createFromArray(rows); | ||||
|             int[] cols = {4, 3, 10, 8, 6, 7, 1, 5, 9, 2, 0, 4, 9, 8, 10, 2, 6, 7, 3, 5, 1, 6, 8, 3, 9, 10, 4, 0, 5, 7, 0, 1, 2, 10, 4, 6, 8, 9, | ||||
|                     5, 7, 0, 1, 2, 3, 10, 8, 9, 6, 7, 5, 0, 2, 3, 7, 9, 10, 4, 8, 1, 6, 0, 1, 2, 3, 4, 8, 10, 9, 5, 0, 1, 3, 4, 5, 9, 10, 8, 2, 0, 1, 2, 3, 4, 5, 6, 7, 10, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; | ||||
|             INDArray indCols = Nd4j.createFromArray(cols); | ||||
|             double[] vals = {0.6806, 0.1978, 0.1349, 0.0403, 0.0087, 0.0369, 0.0081, 0.0172, 0.0014, 0.0046, 0.0081, 0.3375, 0.2274, 0.0556, 0.0098, 0.0175, 0.0027, 0.0077, 0.0014, 0.0023, 0.0175, 0.6569, 0.1762, 0.0254, 0.0200, 0.0118, 0.0074, 0.0046, 0.0124, 0.0012, 0.1978, 0.0014, 0.0254, 0.7198, 0.0712, 0.0850, 0.0389, 0.0555, 0.0418, 0.0286, 0.6806, 0.3375, 0.0074, 0.0712, 0.2290, 0.0224, 0.0189, 0.0080, 0.0187, 0.0097, 0.0172, 0.0124, 0.0418, 0.7799, 0.0521, 0.0395, 0.0097, 0.0030, 0.0023, 1.706e-5, 0.0087, 0.0027, 0.6569, 0.0850, 0.0080, 0.5562, 0.0173, 0.0015, 1.706e-5, 0.0369, 0.0077, 0.0286, 0.0187, 0.7799, 0.0711, 0.0200, 0.0084, 0.0012, 0.0403, 0.0556, 0.1762, 0.0389, 0.0224, 0.0030, 0.5562, 0.0084, 0.0060, 0.0028, 0.0014, 0.2274, 0.0200, 0.0555, 0.0189, 0.0521, 0.0015, 0.0711, 0.0028, 0.3911, 0.1349, 0.0098, 0.0118, 0.7198, 0.2290, 0.0395, 0.0173, 0.0200, 0.0060, 0.3911}; | ||||
|             INDArray indVals = Nd4j.createFromArray(vals); | ||||
| 
 | ||||
|             final int N = 11; | ||||
|             INDArray posF = Nd4j.create(DataType.DOUBLE, ndinput.shape()); | ||||
|             SpTree tree = new SpTree(ndinput); | ||||
|             tree.computeEdgeForces(indRows, indCols, indVals, N, posF); | ||||
|             double[]expectedPosF = {-0.0818453583761987, -0.10231102631753211, 0.016809473355579547, 0.16176252194290375, -0.20703464777007444, -0.1263832139293613, 0.10996898963389254, 0.13983782727968627, -0.09164547825742625, 0.09219041827159041, 0.14252277104691244, 0.014676985587529433, 0.19786703075718223, -0.25244374832212546, -0.018387062879777892, 0.13652061663449183, 0.07639155593531936, -0.07616591260449279, 0.12919565310762643, -0.19229222179037395, -0.11250575155166542, -0.09598877143033444, 0.014899570740339653, 0.018867923701997365, 0.19996253097190828, 0.30233811684856743, -0.18830455752593392, 0.10223346521208224, -0.09703007177169608, -0.003280966942428477, 0.15213078827243462, -0.02397414389327494, -0.1390550777479942, 0.30088735606726813, 0.17456236098186903, -0.31560012032960044, 0.142309945794784, -0.08988089476622348, 0.011236280978163357, -0.10732740266565795, -0.24928551644245478, 0.10762735102220329, 0.03434270193250408, 2.831838829882295E-4, -0.17494982967210068, 0.07114328804840916, 0.15171552834583996, -0.08888924450773618, -0.20576831397087963, 0.027662749212463134, 0.08096437977846523, -0.19211185715249313, -0.11199893965092741, 0.024654692641180212, 0.20889407228258244}; | ||||
|             assertArrayEquals(expectedPosF, posF.reshape(1,55).toDoubleVector(), 1e-5); | ||||
| 
 | ||||
|             final double theta = 0.5; | ||||
|             AtomicDouble sumQ = new AtomicDouble(0.0); | ||||
|             INDArray negF = Nd4j.create(DataType.DOUBLE, ndinput.shape()); | ||||
|             for (int n = 0; n < N; n++) { | ||||
|                 INDArray prev = ((n == 0) ? negF.slice(n ): negF.slice(n-1)); | ||||
|                 tree.computeNonEdgeForces(n, theta, negF.slice(0), sumQ); | ||||
|             } | ||||
| 
 | ||||
|             double[] expectedNegF = {-0.15349944039348173, -0.9608688924710804, -1.7099994806905086, 2.6604989787415203, 1.2677709150619332, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; | ||||
|             double expectedSum = 88.60715062760883; | ||||
| 
 | ||||
|             assertArrayEquals(expectedNegF, negF.reshape(1,55).toDoubleVector(), 1e-5); | ||||
|             assertEquals(expectedSum, sumQ.get(), 1e-5); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testZeroMean() { | ||||
|         double[] aData = new double[]{ | ||||
|                 0.2999816948164936, 0.26252049735806526, 0.2673853427498767, 0.8604464129156685, 0.4802652829902563, 0.10959096539488711, 0.7950242948008909, 0.5917848948003486, | ||||
|                 0.2738285999345498, 0.9519684328285567, 0.9690024759209738, 0.8585615547624705, 0.8087760944312002, 0.5337951589543348, 0.5960876109129123, 0.7187130179825856, | ||||
|                 0.4629777327445964, 0.08665909175584818, 0.7748005397731237, 0.48020186965468536, 0.24927351841378798, 0.32272599988270445, 0.306414968984427, 0.6980212149215657, | ||||
|                 0.7977183964212472, 0.7673513094629704, 0.1679681724796478, 0.3107359484804584, 0.021701726051792103, 0.13797462786662518, 0.8618953518813538, 0.841333838365635, | ||||
|                 0.5284957375170422, 0.9703367685039823, 0.677388096913733, 0.2624474979832243, 0.43740966353106536, 0.15685545957858893, 0.11072929134449871, 0.06007395961283357, | ||||
|                 0.4093918718557811,  0.9563909195720572, 0.5994144944480242, 0.8278927844215804, 0.38586830957105667, 0.6201844716257464, 0.7603829079070265, 0.07875691596842949, | ||||
|                 0.08651136699915507, 0.7445210640026082, 0.6547649514127559, 0.3384719042666908, 0.05816723105860,0.6248951423054205, 0.7431868493349041}; | ||||
|         INDArray ndinput = Nd4j.createFromArray(aData).reshape(11,5); | ||||
|         BarnesHutTsne.zeroMean(ndinput); | ||||
|         double[] expected = {-0.2384362257971937, -0.3014583649756485, -0.07747340086583643, 0.3347228669042438, -0.07021239883787267, -0.4288269552188002, 0.23104543246717713, 0.24692615118463546, -0.2518949460768749, 0.40149075100042775, 0.43058455530728645, 0.2945826924287568, 0.46391735081548713, 0.008071612942910145, 0.04560992908478334, 0.18029509736889826, -0.10100112958911733, -0.25819965185986493, 0.249076993761699, -0.07027581217344359, -0.28914440219989934, -0.2412528624510093, -0.03844377463128612, 0.17229766891014098, 0.24724071459311825, 0.22893338884928305, -0.39601068985406596, -0.034122795135254735, -0.5040218199596326, -0.4125030539615038, 0.3234774312676665, 0.2773549760319213, 0.18363699390132904, 0.44461322249255764, 0.12691041508560408, -0.275970422630463, -0.12656919880264839, -0.18800328403712419, -0.41499425466692597, -0.4904037222152954, -0.12902604875790624, 0.3924120572383435, 0.2545557508323111, 0.30216923841015575, -0.16460937225707228, 0.0817665510120591, 0.1964040455733127, -0.26610182764728363, -0.4392121790122696, 0.19404338217447925, 0.11634703079906861, -0.22550695806702292, -0.2866915125571131, 0.09917159629399586, 0.19270916750677514}; | ||||
|         assertArrayEquals(expected, ndinput.reshape(55).toDoubleVector(), 1e-5); | ||||
|     } | ||||
| } | ||||
| @ -190,7 +190,7 @@ public class ValidateCuDNN extends BaseDL4JTest { | ||||
|         validateLayers(net, classesToTest, false, fShape, lShape, CuDNNValidationUtil.MAX_REL_ERROR, CuDNNValidationUtil.MIN_ABS_ERROR); | ||||
|     } | ||||
| 
 | ||||
|     @Test @Ignore //AB 2019/05/20 - https://github.com/deeplearning4j/deeplearning4j/issues/5088 - ignored to get to "all passing" state for CI, and revisit later | ||||
|     @Test @Ignore //AB 2019/05/20 - https://github.com/eclipse/deeplearning4j/issues/5088 - ignored to get to "all passing" state for CI, and revisit later | ||||
|     public void validateConvLayersLRN() { | ||||
|         //Test ONLY LRN - no other CuDNN functionality (i.e., DL4J impls for everything else) | ||||
|         Nd4j.getRandom().setSeed(12345); | ||||
|  | ||||
| @ -80,7 +80,7 @@ public abstract class CacheableExtractableDataSetFetcher implements CacheableDat | ||||
|                 log.error("Checksums do not match. Cleaning up files and failing..."); | ||||
|                 tmpFile.delete(); | ||||
|                 throw new IllegalStateException( "Dataset file failed checksum: " + tmpFile + " - expected checksum " + expectedChecksum(set) | ||||
|                 + " vs. actual checksum " + localChecksum + ". If this error persists, please open an issue at https://github.com/deeplearning4j/deeplearning4j."); | ||||
|                 + " vs. actual checksum " + localChecksum + ". If this error persists, please open an issue at https://github.com/eclipse/deeplearning4j."); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|  | ||||
| @ -1,77 +0,0 @@ | ||||
| <?xml version="1.0" encoding="UTF-8"?> | ||||
| <!-- | ||||
|   ~ /* ****************************************************************************** | ||||
|   ~  * | ||||
|   ~  * | ||||
|   ~  * This program and the accompanying materials are made available under the | ||||
|   ~  * terms of the Apache License, Version 2.0 which is available at | ||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|   ~  * | ||||
|   ~  *  See the NOTICE file distributed with this work for additional | ||||
|   ~  *  information regarding copyright ownership. | ||||
|   ~  * Unless required by applicable law or agreed to in writing, software | ||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|   ~  * License for the specific language governing permissions and limitations | ||||
|   ~  * under the License. | ||||
|   ~  * | ||||
|   ~  * SPDX-License-Identifier: Apache-2.0 | ||||
|   ~  ******************************************************************************/ | ||||
|   --> | ||||
| 
 | ||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | ||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||||
| 
 | ||||
|     <modelVersion>4.0.0</modelVersion> | ||||
| 
 | ||||
|     <parent> | ||||
|         <groupId>org.deeplearning4j</groupId> | ||||
|         <artifactId>deeplearning4j-manifold</artifactId> | ||||
|         <version>1.0.0-SNAPSHOT</version> | ||||
|     </parent> | ||||
| 
 | ||||
|     <artifactId>deeplearning4j-tsne</artifactId> | ||||
|     <packaging>jar</packaging> | ||||
| 
 | ||||
|     <name>deeplearning4j-tsne</name> | ||||
| 
 | ||||
|     <dependencies> | ||||
|         <dependency> | ||||
|             <groupId>org.deeplearning4j</groupId> | ||||
|             <artifactId>nearestneighbor-core</artifactId> | ||||
|             <version>${project.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.deeplearning4j</groupId> | ||||
|             <artifactId>deeplearning4j-nn</artifactId> | ||||
|             <version>${project.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.projectlombok</groupId> | ||||
|             <artifactId>lombok</artifactId> | ||||
|             <version>${lombok.version}</version> | ||||
|             <scope>provided</scope> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.nd4j</groupId> | ||||
|             <artifactId>nd4j-api</artifactId> | ||||
|             <version>${nd4j.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.deeplearning4j</groupId> | ||||
|             <artifactId>deeplearning4j-common-tests</artifactId> | ||||
|             <version>${project.version}</version> | ||||
|             <scope>test</scope> | ||||
|         </dependency> | ||||
|     </dependencies> | ||||
| 
 | ||||
|     <profiles> | ||||
|         <profile> | ||||
|             <id>test-nd4j-native</id> | ||||
|         </profile> | ||||
|         <profile> | ||||
|             <id>test-nd4j-cuda-11.0</id> | ||||
|         </profile> | ||||
|     </profiles> | ||||
| </project> | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -1,433 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.plot; | ||||
| 
 | ||||
| import org.nd4j.shade.guava.primitives.Ints; | ||||
| import org.apache.commons.math3.util.FastMath; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.dimensionalityreduction.PCA; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.indexing.BooleanIndexing; | ||||
| import org.nd4j.linalg.indexing.INDArrayIndex; | ||||
| import org.nd4j.linalg.indexing.SpecifiedIndex; | ||||
| import org.nd4j.linalg.indexing.conditions.Conditions; | ||||
| import org.nd4j.linalg.learning.legacy.AdaGrad; | ||||
| import org.nd4j.common.primitives.Pair; | ||||
| import org.nd4j.common.util.ArrayUtil; | ||||
| import org.slf4j.Logger; | ||||
| import org.slf4j.LoggerFactory; | ||||
| 
 | ||||
| import java.io.BufferedWriter; | ||||
| import java.io.File; | ||||
| import java.io.FileWriter; | ||||
| import java.io.IOException; | ||||
| import java.util.Arrays; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.nd4j.linalg.factory.Nd4j.*; | ||||
| import static org.nd4j.linalg.ops.transforms.Transforms.*; | ||||
| 
 | ||||
| public class Tsne { | ||||
|     protected int maxIter = 1000; | ||||
|     protected double realMin = Nd4j.EPS_THRESHOLD; | ||||
|     protected double initialMomentum = 0.5; | ||||
|     protected double finalMomentum = 0.8; | ||||
|     protected double minGain = 1e-2; | ||||
|     protected double momentum = initialMomentum; | ||||
|     protected int switchMomentumIteration = 100; | ||||
|     protected boolean normalize = true; | ||||
|     protected boolean usePca = false; | ||||
|     protected int stopLyingIteration = 250; | ||||
|     protected double tolerance = 1e-5; | ||||
|     protected double learningRate = 500; | ||||
|     protected AdaGrad adaGrad; | ||||
|     protected boolean useAdaGrad = true; | ||||
|     protected double perplexity = 30; | ||||
|     //protected INDArray gains,yIncs; | ||||
|     protected INDArray Y; | ||||
| 
 | ||||
|     protected static final Logger logger = LoggerFactory.getLogger(Tsne.class); | ||||
| 
 | ||||
| 
 | ||||
|     public Tsne(final int maxIter, final double realMin, final double initialMomentum, final double finalMomentum, | ||||
|                     final double minGain, final double momentum, final int switchMomentumIteration, | ||||
|                     final boolean normalize, final boolean usePca, final int stopLyingIteration, final double tolerance, | ||||
|                     final double learningRate, final boolean useAdaGrad, final double perplexity) { | ||||
|         this.maxIter = maxIter; | ||||
|         this.realMin = realMin; | ||||
|         this.initialMomentum = initialMomentum; | ||||
|         this.finalMomentum = finalMomentum; | ||||
|         this.minGain = minGain; | ||||
|         this.momentum = momentum; | ||||
|         this.switchMomentumIteration = switchMomentumIteration; | ||||
|         this.normalize = normalize; | ||||
|         this.usePca = usePca; | ||||
|         this.stopLyingIteration = stopLyingIteration; | ||||
|         this.tolerance = tolerance; | ||||
|         this.learningRate = learningRate; | ||||
|         this.useAdaGrad = useAdaGrad; | ||||
|         this.perplexity = perplexity; | ||||
|         this.init(); | ||||
|     } | ||||
| 
 | ||||
|     protected void init() { | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     public INDArray calculate(INDArray X, int targetDimensions, double perplexity) { | ||||
|         // pca hook | ||||
|         if (usePca) { | ||||
|             X = PCA.pca(X, Math.min(50, X.columns()), normalize); | ||||
|         } else if (normalize) { | ||||
|             X.subi(X.min(Integer.MAX_VALUE)); | ||||
|             X = X.divi(X.max(Integer.MAX_VALUE)); | ||||
|             X = X.subiRowVector(X.mean(0)); | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
|         int n = X.rows(); | ||||
|         // FIXME: this is wrong, another distribution required here | ||||
|         Y = Nd4j.randn(X.dataType(), X.rows(), targetDimensions); | ||||
|         INDArray dY = Nd4j.zeros(n, targetDimensions); | ||||
|         INDArray iY = Nd4j.zeros(n, targetDimensions); | ||||
|         INDArray gains = Nd4j.ones(n, targetDimensions); | ||||
| 
 | ||||
|         boolean stopLying = false; | ||||
|         logger.debug("Y:Shape is = " + Arrays.toString(Y.shape())); | ||||
| 
 | ||||
|         // compute P-values | ||||
|         INDArray P = x2p(X, tolerance, perplexity); | ||||
| 
 | ||||
|         // do training | ||||
|         for (int i = 0; i < maxIter; i++) { | ||||
|             INDArray sumY = pow(Y, 2).sum(1).transpose(); | ||||
| 
 | ||||
|             //Student-t distribution | ||||
|             //also un normalized q | ||||
|             // also known as num in original implementation | ||||
|             INDArray qu = Y.mmul(Y.transpose()).muli(-2).addiRowVector(sumY).transpose().addiRowVector(sumY).addi(1) | ||||
|                             .rdivi(1); | ||||
| 
 | ||||
|             //          doAlongDiagonal(qu,new Zero()); | ||||
| 
 | ||||
|             INDArray Q = qu.div(qu.sumNumber().doubleValue()); | ||||
|             BooleanIndexing.replaceWhere(Q, 1e-12, Conditions.lessThan(1e-12)); | ||||
| 
 | ||||
|             INDArray PQ = P.sub(Q).muli(qu); | ||||
| 
 | ||||
|             logger.debug("PQ shape is: " + Arrays.toString(PQ.shape())); | ||||
|             logger.debug("PQ.sum(1) shape is: " + Arrays.toString(PQ.sum(1).shape())); | ||||
| 
 | ||||
|             dY = diag(PQ.sum(1)).subi(PQ).mmul(Y).muli(4); | ||||
| 
 | ||||
| 
 | ||||
|             if (i < switchMomentumIteration) { | ||||
|                 momentum = initialMomentum; | ||||
|             } else { | ||||
|                 momentum = finalMomentum; | ||||
|             } | ||||
| 
 | ||||
|             gains = gains.add(.2).muli(dY.cond(Conditions.greaterThan(0)).neq(iY.cond(Conditions.greaterThan(0)))) | ||||
|                             .addi(gains.mul(0.8).muli(dY.cond(Conditions.greaterThan(0)) | ||||
|                                             .eq(iY.cond(Conditions.greaterThan(0))))); | ||||
| 
 | ||||
|             BooleanIndexing.replaceWhere(gains, minGain, Conditions.lessThan(minGain)); | ||||
| 
 | ||||
|             INDArray gradChange = gains.mul(dY); | ||||
| 
 | ||||
|             gradChange.muli(learningRate); | ||||
| 
 | ||||
|             iY.muli(momentum).subi(gradChange); | ||||
| 
 | ||||
|             double cost = P.mul(log(P.div(Q), false)).sumNumber().doubleValue(); | ||||
|             logger.info("Iteration [" + i + "] error is: [" + cost + "]"); | ||||
| 
 | ||||
|             Y.addi(iY); | ||||
|             //  Y.addi(iY).subiRowVector(Y.mean(0)); | ||||
|             INDArray tiled = Nd4j.tile(Y.mean(0), new int[] {Y.rows(), 1}); | ||||
|             Y.subi(tiled); | ||||
| 
 | ||||
|             if (!stopLying && (i > maxIter / 2 || i >= stopLyingIteration)) { | ||||
|                 P.divi(4); | ||||
|                 stopLying = true; | ||||
|             } | ||||
|         } | ||||
|         return Y; | ||||
|     } | ||||
| 
 | ||||
|     public INDArray diag(INDArray ds) { | ||||
|         boolean isLong = ds.rows() > ds.columns(); | ||||
|         INDArray sliceZero = ds.slice(0); | ||||
|         int dim = Math.max(ds.columns(), ds.rows()); | ||||
|         INDArray result = Nd4j.create(dim, dim); | ||||
|         for (int i = 0; i < dim; i++) { | ||||
|             INDArray sliceSrc = ds.slice(i); | ||||
|             INDArray sliceDst = result.slice(i); | ||||
|             for (int j = 0; j < dim; j++) { | ||||
|                 if (i == j) { | ||||
|                     if (isLong) | ||||
|                         sliceDst.putScalar(j, sliceSrc.getDouble(0)); | ||||
|                     else | ||||
|                         sliceDst.putScalar(j, sliceZero.getDouble(i)); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         return result; | ||||
|     } | ||||
| 
 | ||||
|     public void plot(INDArray matrix, int nDims, List<String> labels, String path) throws IOException { | ||||
| 
 | ||||
|         calculate(matrix, nDims, perplexity); | ||||
| 
 | ||||
|         BufferedWriter write = new BufferedWriter(new FileWriter(new File(path), true)); | ||||
| 
 | ||||
|         for (int i = 0; i < Y.rows(); i++) { | ||||
|             if (i >= labels.size()) | ||||
|                 break; | ||||
|             String word = labels.get(i); | ||||
|             if (word == null) | ||||
|                 continue; | ||||
|             StringBuilder sb = new StringBuilder(); | ||||
|             INDArray wordVector = Y.getRow(i); | ||||
|             for (int j = 0; j < wordVector.length(); j++) { | ||||
|                 sb.append(wordVector.getDouble(j)); | ||||
|                 if (j < wordVector.length() - 1) | ||||
|                     sb.append(","); | ||||
|             } | ||||
| 
 | ||||
|             sb.append(","); | ||||
|             sb.append(word); | ||||
|             sb.append(" "); | ||||
| 
 | ||||
|             sb.append("\n"); | ||||
|             write.write(sb.toString()); | ||||
| 
 | ||||
|         } | ||||
| 
 | ||||
|         write.flush(); | ||||
|         write.close(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Computes a gaussian kernel | ||||
|      * given a vector of squared distance distances | ||||
|      * | ||||
|      * @param d the data | ||||
|      * @param beta | ||||
|      * @return | ||||
|      */ | ||||
|     public Pair<Double, INDArray> hBeta(INDArray d, double beta) { | ||||
|         INDArray P = exp(d.neg().muli(beta)); | ||||
|         double sumP = P.sumNumber().doubleValue(); | ||||
|         double logSumP = FastMath.log(sumP); | ||||
|         Double H = logSumP + ((beta * (d.mul(P).sumNumber().doubleValue())) / sumP); | ||||
|         P.divi(sumP); | ||||
|         return new Pair<>(H, P); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * This method build probabilities for given source data | ||||
|      * | ||||
|      * @param X | ||||
|      * @param tolerance | ||||
|      * @param perplexity | ||||
|      * @return | ||||
|      */ | ||||
|     private INDArray x2p(final INDArray X, double tolerance, double perplexity) { | ||||
|         int n = X.rows(); | ||||
|         final INDArray p = zeros(n, n); | ||||
|         final INDArray beta = ones(n, 1); | ||||
|         final double logU = Math.log(perplexity); | ||||
| 
 | ||||
|         INDArray sumX = pow(X, 2).sum(1); | ||||
| 
 | ||||
|         logger.debug("sumX shape: " + Arrays.toString(sumX.shape())); | ||||
| 
 | ||||
|         INDArray times = X.mmul(X.transpose()).muli(-2); | ||||
| 
 | ||||
|         logger.debug("times shape: " + Arrays.toString(times.shape())); | ||||
| 
 | ||||
|         INDArray prodSum = times.transpose().addiColumnVector(sumX); | ||||
| 
 | ||||
|         logger.debug("prodSum shape: " + Arrays.toString(prodSum.shape())); | ||||
| 
 | ||||
|         INDArray D = X.mmul(X.transpose()).mul(-2) // thats times | ||||
|                         .transpose().addColumnVector(sumX) // thats prodSum | ||||
|                         .addRowVector(sumX.transpose()); // thats D | ||||
| 
 | ||||
|         logger.info("Calculating probabilities of data similarities..."); | ||||
|         logger.debug("Tolerance: " + tolerance); | ||||
|         for (int i = 0; i < n; i++) { | ||||
|             if (i % 500 == 0 && i > 0) | ||||
|                 logger.info("Handled [" + i + "] records out of [" + n + "]"); | ||||
| 
 | ||||
|             double betaMin = Double.NEGATIVE_INFINITY; | ||||
|             double betaMax = Double.POSITIVE_INFINITY; | ||||
|             int[] vals = Ints.concat(ArrayUtil.range(0, i), ArrayUtil.range(i + 1, n)); | ||||
|             INDArrayIndex[] range = new INDArrayIndex[] {new SpecifiedIndex(vals)}; | ||||
| 
 | ||||
|             INDArray row = D.slice(i).get(range); | ||||
|             Pair<Double, INDArray> pair = hBeta(row, beta.getDouble(i)); | ||||
|             //INDArray hDiff = pair.getFirst().sub(logU); | ||||
|             double hDiff = pair.getFirst() - logU; | ||||
|             int tries = 0; | ||||
| 
 | ||||
|             //while hdiff > tolerance | ||||
|             while (Math.abs(hDiff) > tolerance && tries < 50) { | ||||
|                 //if hdiff > 0 | ||||
|                 if (hDiff > 0) { | ||||
|                     betaMin = beta.getDouble(i); | ||||
|                     if (Double.isInfinite(betaMax)) | ||||
|                         beta.putScalar(i, beta.getDouble(i) * 2.0); | ||||
|                     else | ||||
|                         beta.putScalar(i, (beta.getDouble(i) + betaMax) / 2.0); | ||||
|                 } else { | ||||
|                     betaMax = beta.getDouble(i); | ||||
|                     if (Double.isInfinite(betaMin)) | ||||
|                         beta.putScalar(i, beta.getDouble(i) / 2.0); | ||||
|                     else | ||||
|                         beta.putScalar(i, (beta.getDouble(i) + betaMin) / 2.0); | ||||
|                 } | ||||
| 
 | ||||
|                 pair = hBeta(row, beta.getDouble(i)); | ||||
|                 hDiff = pair.getFirst() - logU; | ||||
|                 tries++; | ||||
|             } | ||||
|             p.slice(i).put(range, pair.getSecond()); | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
|         //dont need data in memory after | ||||
|         logger.info("Mean value of sigma " + sqrt(beta.rdiv(1)).mean(Integer.MAX_VALUE)); | ||||
|         BooleanIndexing.replaceWhere(p, 1e-12, Conditions.isNan()); | ||||
| 
 | ||||
|         //set 0 along the diagonal | ||||
|         INDArray permute = p.transpose(); | ||||
| 
 | ||||
|         INDArray pOut = p.add(permute); | ||||
| 
 | ||||
|         pOut.divi(pOut.sumNumber().doubleValue() + 1e-6); | ||||
| 
 | ||||
|         pOut.muli(4); | ||||
| 
 | ||||
|         BooleanIndexing.replaceWhere(pOut, 1e-12, Conditions.lessThan(1e-12)); | ||||
|         //ensure no nans | ||||
| 
 | ||||
|         return pOut; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public static class Builder { | ||||
|         protected int maxIter = 1000; | ||||
|         protected double realMin = 1e-12f; | ||||
|         protected double initialMomentum = 5e-1f; | ||||
|         protected double finalMomentum = 8e-1f; | ||||
|         protected double momentum = 5e-1f; | ||||
|         protected int switchMomentumIteration = 100; | ||||
|         protected boolean normalize = true; | ||||
|         protected boolean usePca = false; | ||||
|         protected int stopLyingIteration = 100; | ||||
|         protected double tolerance = 1e-5f; | ||||
|         protected double learningRate = 1e-1f; | ||||
|         protected boolean useAdaGrad = false; | ||||
|         protected double perplexity = 30; | ||||
|         protected double minGain = 1e-1f; | ||||
| 
 | ||||
| 
 | ||||
|         public Builder minGain(double minGain) { | ||||
|             this.minGain = minGain; | ||||
|             return this; | ||||
|         } | ||||
| 
 | ||||
|         public Builder perplexity(double perplexity) { | ||||
|             this.perplexity = perplexity; | ||||
|             return this; | ||||
|         } | ||||
| 
 | ||||
|         public Builder useAdaGrad(boolean useAdaGrad) { | ||||
|             this.useAdaGrad = useAdaGrad; | ||||
|             return this; | ||||
|         } | ||||
| 
 | ||||
|         public Builder learningRate(double learningRate) { | ||||
|             this.learningRate = learningRate; | ||||
|             return this; | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
|         public Builder tolerance(double tolerance) { | ||||
|             this.tolerance = tolerance; | ||||
|             return this; | ||||
|         } | ||||
| 
 | ||||
|         public Builder stopLyingIteration(int stopLyingIteration) { | ||||
|             this.stopLyingIteration = stopLyingIteration; | ||||
|             return this; | ||||
|         } | ||||
| 
 | ||||
|         public Builder usePca(boolean usePca) { | ||||
|             this.usePca = usePca; | ||||
|             return this; | ||||
|         } | ||||
| 
 | ||||
|         public Builder normalize(boolean normalize) { | ||||
|             this.normalize = normalize; | ||||
|             return this; | ||||
|         } | ||||
| 
 | ||||
|         public Builder setMaxIter(int maxIter) { | ||||
|             this.maxIter = maxIter; | ||||
|             return this; | ||||
|         } | ||||
| 
 | ||||
|         public Builder setRealMin(double realMin) { | ||||
|             this.realMin = realMin; | ||||
|             return this; | ||||
|         } | ||||
| 
 | ||||
|         public Builder setInitialMomentum(double initialMomentum) { | ||||
|             this.initialMomentum = initialMomentum; | ||||
|             return this; | ||||
|         } | ||||
| 
 | ||||
|         public Builder setFinalMomentum(double finalMomentum) { | ||||
|             this.finalMomentum = finalMomentum; | ||||
|             return this; | ||||
|         } | ||||
| 
 | ||||
|         public Builder setMomentum(double momentum) { | ||||
|             this.momentum = momentum; | ||||
|             return this; | ||||
|         } | ||||
| 
 | ||||
|         public Builder setSwitchMomentumIteration(int switchMomentumIteration) { | ||||
|             this.switchMomentumIteration = switchMomentumIteration; | ||||
|             return this; | ||||
|         } | ||||
| 
 | ||||
|         public Tsne build() { | ||||
|             return new Tsne(maxIter, realMin, initialMomentum, finalMomentum, minGain, momentum, | ||||
|                             switchMomentumIteration, normalize, usePca, stopLyingIteration, tolerance, learningRate, | ||||
|                             useAdaGrad, perplexity); | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @ -1,68 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.plot; | ||||
| 
 | ||||
| import lombok.val; | ||||
| import org.deeplearning4j.BaseDL4JTest; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| 
 | ||||
| import static org.junit.Assert.assertTrue; | ||||
| 
 | ||||
| public class Test6058 extends BaseDL4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void test() throws Exception { | ||||
|         //All zero input -> cosine similarity isn't defined | ||||
|         //https://github.com/deeplearning4j/deeplearning4j/issues/6058 | ||||
|         val iterations = 10; | ||||
|         val cacheList = new ArrayList<String>(); | ||||
| 
 | ||||
|         int nWords  = 100; | ||||
|         for(int i=0; i<nWords; i++ ) { | ||||
|             cacheList.add("word_" + i); | ||||
|         } | ||||
| 
 | ||||
|         //STEP 3: build a dual-tree tsne to use later | ||||
|         System.out.println("Build model...."); | ||||
|         val tsne = new BarnesHutTsne.Builder() | ||||
|                 .setMaxIter(iterations) | ||||
|                 .theta(0.5) | ||||
|                 .normalize(false) | ||||
|                 .learningRate(1000) | ||||
|                 .useAdaGrad(false) | ||||
|                 //.usePca(false) | ||||
|                 .build(); | ||||
| 
 | ||||
|         System.out.println("fit"); | ||||
|         INDArray weights = Nd4j.rand(new int[]{nWords, 100}); | ||||
|         weights.getRow(1).assign(0); | ||||
|         try { | ||||
|             tsne.fit(weights); | ||||
|         } catch (IllegalStateException e){ | ||||
|             assertTrue(e.getMessage().contains("may not be defined")); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,87 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| //package org.deeplearning4j.plot; | ||||
| // | ||||
| //import lombok.extern.slf4j.Slf4j; | ||||
| //import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; | ||||
| //import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; | ||||
| //import org.deeplearning4j.models.word2vec.wordstore.VocabCache; | ||||
| //import org.deeplearning4j.nn.conf.WorkspaceMode; | ||||
| //import org.junit.Test; | ||||
| //import org.nd4j.linalg.api.buffer.DataBuffer; | ||||
| //import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| //import org.nd4j.linalg.factory.Nd4j; | ||||
| //import org.nd4j.linalg.io.ClassPathResource; | ||||
| //import org.nd4j.linalg.primitives.Pair; | ||||
| // | ||||
| //import java.io.File; | ||||
| //import java.util.ArrayList; | ||||
| //import java.util.List; | ||||
| // | ||||
| //@Slf4j | ||||
| //public class TsneTest { | ||||
| // | ||||
| //    @Test | ||||
| //    public void testSimple() throws Exception { | ||||
| //        //Simple sanity check | ||||
| // | ||||
| //        for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}){ | ||||
| // | ||||
| //            //STEP 1: Initialization | ||||
| //            int iterations = 100; | ||||
| //            //create an n-dimensional array of doubles | ||||
| //            Nd4j.setDataType(DataType.DOUBLE); | ||||
| //            List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words | ||||
| // | ||||
| //            //STEP 2: Turn text input into a list of words | ||||
| //            log.info("Load & Vectorize data...."); | ||||
| //            File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile();   //Open the file | ||||
| //            //Get the data of all unique word vectors | ||||
| //            Pair<InMemoryLookupTable,VocabCache> vectors = WordVectorSerializer.loadTxt(wordFile); | ||||
| //            VocabCache cache = vectors.getSecond(); | ||||
| //            INDArray weights = vectors.getFirst().getSyn0();    //seperate weights of unique words into their own list | ||||
| // | ||||
| //            for(int i = 0; i < cache.numWords(); i++)   //seperate strings of words into their own list | ||||
| //                cacheList.add(cache.wordAtIndex(i)); | ||||
| // | ||||
| //            //STEP 3: build a dual-tree tsne to use later | ||||
| //            log.info("Build model...."); | ||||
| //            BarnesHutTsne tsne = new BarnesHutTsne.Builder() | ||||
| //                    .setMaxIter(iterations).theta(0.5) | ||||
| //                    .normalize(false) | ||||
| //                    .learningRate(500) | ||||
| //                    .useAdaGrad(false) | ||||
| //                    .workspaceMode(wsm) | ||||
| //                    .build(); | ||||
| // | ||||
| //            //STEP 4: establish the tsne values and save them to a file | ||||
| //            log.info("Store TSNE Coordinates for Plotting...."); | ||||
| //            String outputFile = "target/archive-tmp/tsne-standard-coords.csv"; | ||||
| //            (new File(outputFile)).getParentFile().mkdirs(); | ||||
| // | ||||
| //            tsne.fit(weights); | ||||
| //            tsne.saveAsFile(cacheList, outputFile); | ||||
| // | ||||
| // | ||||
| //        } | ||||
| //    } | ||||
| // | ||||
| //} | ||||
| @ -1,51 +0,0 @@ | ||||
| <?xml version="1.0" encoding="UTF-8"?> | ||||
| <!-- | ||||
|   ~ /* ****************************************************************************** | ||||
|   ~  * | ||||
|   ~  * | ||||
|   ~  * This program and the accompanying materials are made available under the | ||||
|   ~  * terms of the Apache License, Version 2.0 which is available at | ||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|   ~  * | ||||
|   ~  *  See the NOTICE file distributed with this work for additional | ||||
|   ~  *  information regarding copyright ownership. | ||||
|   ~  * Unless required by applicable law or agreed to in writing, software | ||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|   ~  * License for the specific language governing permissions and limitations | ||||
|   ~  * under the License. | ||||
|   ~  * | ||||
|   ~  * SPDX-License-Identifier: Apache-2.0 | ||||
|   ~  ******************************************************************************/ | ||||
|   --> | ||||
| 
 | ||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | ||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||||
| 
 | ||||
|     <modelVersion>4.0.0</modelVersion> | ||||
| 
 | ||||
|     <parent> | ||||
|         <groupId>org.deeplearning4j</groupId> | ||||
|         <artifactId>deeplearning4j-parent</artifactId> | ||||
|         <version>1.0.0-SNAPSHOT</version> | ||||
|     </parent> | ||||
| 
 | ||||
|     <artifactId>deeplearning4j-manifold</artifactId> | ||||
|     <packaging>pom</packaging> | ||||
| 
 | ||||
|     <name>deeplearning4j-manifold</name> | ||||
| 
 | ||||
|     <modules> | ||||
|         <module>deeplearning4j-tsne</module> | ||||
|     </modules> | ||||
| 
 | ||||
|     <profiles> | ||||
|         <profile> | ||||
|             <id>test-nd4j-native</id> | ||||
|         </profile> | ||||
|         <profile> | ||||
|             <id>test-nd4j-cuda-11.0</id> | ||||
|         </profile> | ||||
|     </profiles> | ||||
| </project> | ||||
| @ -127,6 +127,51 @@ | ||||
|                     <scope>test</scope> | ||||
|                 </dependency> | ||||
|             </dependencies> | ||||
|             <build> | ||||
|                 <plugins> | ||||
|                     <plugin> | ||||
|                         <groupId>org.apache.maven.plugins</groupId> | ||||
|                         <artifactId>maven-surefire-plugin</artifactId> | ||||
|                         <inherited>true</inherited> | ||||
|                         <dependencies> | ||||
|                             <dependency> | ||||
|                                 <groupId>org.nd4j</groupId> | ||||
|                                 <artifactId>nd4j-native</artifactId> | ||||
|                                 <version>${project.version}</version> | ||||
|                             </dependency> | ||||
|                         </dependencies> | ||||
|                         <configuration> | ||||
|                             <environmentVariables> | ||||
| 
 | ||||
|                             </environmentVariables> | ||||
|                             <testSourceDirectory>src/test/java</testSourceDirectory> | ||||
|                             <includes> | ||||
|                                 <include>*.java</include> | ||||
|                                 <include>**/*.java</include> | ||||
|                                 <include>**/Test*.java</include> | ||||
|                                 <include>**/*Test.java</include> | ||||
|                                 <include>**/*TestCase.java</include> | ||||
|                             </includes> | ||||
|                             <junitArtifactName>junit:junit</junitArtifactName> | ||||
|                             <systemPropertyVariables> | ||||
|                                 <org.nd4j.linalg.defaultbackend> | ||||
|                                     org.nd4j.linalg.cpu.nativecpu.CpuBackend | ||||
|                                 </org.nd4j.linalg.defaultbackend> | ||||
|                                 <org.nd4j.linalg.tests.backendstorun> | ||||
|                                     org.nd4j.linalg.cpu.nativecpu.CpuBackend | ||||
|                                 </org.nd4j.linalg.tests.backendstorun> | ||||
|                             </systemPropertyVariables> | ||||
|                             <!-- | ||||
|                                 Maximum heap size was set to 8g, as a minimum required value for tests run. | ||||
|                                 Depending on a build machine, default value is not always enough. | ||||
| 
 | ||||
|                                 For testing large zoo models, this may not be enough (so comment it out). | ||||
|                             --> | ||||
|                             <argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine> | ||||
|                         </configuration> | ||||
|                     </plugin> | ||||
|                 </plugins> | ||||
|             </build> | ||||
|         </profile> | ||||
|         <profile> | ||||
|             <id>test-nd4j-cuda-11.0</id> | ||||
| @ -138,6 +183,47 @@ | ||||
|                     <scope>test</scope> | ||||
|                 </dependency> | ||||
|             </dependencies> | ||||
|             <build> | ||||
|                 <plugins> | ||||
|                     <plugin> | ||||
|                         <groupId>org.apache.maven.plugins</groupId> | ||||
|                         <artifactId>maven-surefire-plugin</artifactId> | ||||
|                         <dependencies> | ||||
|                             <dependency> | ||||
|                                 <groupId>org.apache.maven.surefire</groupId> | ||||
|                                 <artifactId>surefire-junit47</artifactId> | ||||
|                                 <version>2.19.1</version> | ||||
|                             </dependency> | ||||
|                         </dependencies> | ||||
|                         <configuration> | ||||
|                             <environmentVariables> | ||||
|                             </environmentVariables> | ||||
|                             <testSourceDirectory>src/test/java</testSourceDirectory> | ||||
|                             <includes> | ||||
|                                 <include>*.java</include> | ||||
|                                 <include>**/*.java</include> | ||||
|                                 <include>**/Test*.java</include> | ||||
|                                 <include>**/*Test.java</include> | ||||
|                                 <include>**/*TestCase.java</include> | ||||
|                             </includes> | ||||
|                             <junitArtifactName>junit:junit</junitArtifactName> | ||||
|                             <systemPropertyVariables> | ||||
|                                 <org.nd4j.linalg.defaultbackend> | ||||
|                                     org.nd4j.linalg.jcublas.JCublasBackend | ||||
|                                 </org.nd4j.linalg.defaultbackend> | ||||
|                                 <org.nd4j.linalg.tests.backendstorun> | ||||
|                                     org.nd4j.linalg.jcublas.JCublasBackend | ||||
|                                 </org.nd4j.linalg.tests.backendstorun> | ||||
|                             </systemPropertyVariables> | ||||
|                             <!-- | ||||
|                                 Maximum heap size was set to 6g, as a minimum required value for tests run. | ||||
|                                 Depending on a build machine, default value is not always enough. | ||||
|                             --> | ||||
|                             <argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine> | ||||
|                         </configuration> | ||||
|                     </plugin> | ||||
|                 </plugins> | ||||
|             </build> | ||||
|         </profile> | ||||
|     </profiles> | ||||
| </project> | ||||
|  | ||||
| @ -1001,7 +1001,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { | ||||
| 
 | ||||
|         for (Layer l : netToTest.getLayers()) { | ||||
|             // Remove any dropout manually - until this is fixed: | ||||
|             // https://github.com/deeplearning4j/deeplearning4j/issues/4368 | ||||
|             // https://github.com/eclipse/deeplearning4j/issues/4368 | ||||
|              l.conf().getLayer().setIDropout(null); | ||||
| 
 | ||||
|             //Also swap out activation functions... this is a bit of a hack, but should make the net gradient checkable... | ||||
|  | ||||
| @ -22,7 +22,6 @@ package org.deeplearning4j.models.embeddings; | ||||
| 
 | ||||
| import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; | ||||
| import org.deeplearning4j.models.word2vec.wordstore.VocabCache; | ||||
| import org.deeplearning4j.plot.BarnesHutTsne; | ||||
| import org.deeplearning4j.core.ui.UiConnectionInfo; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| 
 | ||||
| @ -74,27 +73,7 @@ public interface WeightLookupTable<T extends SequenceElement> extends Serializab | ||||
|      */ | ||||
|     void resetWeights(boolean reset); | ||||
| 
 | ||||
|     /** | ||||
|      * Render the words via TSNE | ||||
|      * @param tsne the tsne to use | ||||
|      */ | ||||
|     void plotVocab(BarnesHutTsne tsne, int numWords, UiConnectionInfo connectionInfo); | ||||
| 
 | ||||
|     /** | ||||
|      * Render the words via TSNE | ||||
|      * @param tsne the tsne to use | ||||
|      */ | ||||
|     void plotVocab(BarnesHutTsne tsne, int numWords, File file); | ||||
| 
 | ||||
|     /** | ||||
|      * Render the words via tsne | ||||
|      */ | ||||
|     void plotVocab(int numWords, UiConnectionInfo connectionInfo); | ||||
| 
 | ||||
|     /** | ||||
|      * Render the words via tsne | ||||
|      */ | ||||
|     void plotVocab(int numWords, File file); | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|  | ||||
| @ -29,7 +29,6 @@ import org.deeplearning4j.models.embeddings.WeightLookupTable; | ||||
| import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; | ||||
| import org.deeplearning4j.models.word2vec.Word2Vec; | ||||
| import org.deeplearning4j.models.word2vec.wordstore.VocabCache; | ||||
| import org.deeplearning4j.plot.BarnesHutTsne; | ||||
| import org.deeplearning4j.core.ui.UiConnectionInfo; | ||||
| import org.nd4j.common.base.Preconditions; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| @ -154,123 +153,8 @@ public class InMemoryLookupTable<T extends SequenceElement> implements WeightLoo | ||||
|         initNegative(); | ||||
|     } | ||||
| 
 | ||||
|     private List<String> fitTnseAndGetLabels(final BarnesHutTsne tsne, final int numWords) { | ||||
|         INDArray array = Nd4j.create(numWords, vectorLength); | ||||
|         List<String> labels = new ArrayList<>(); | ||||
|         for (int i = 0; i < numWords && i < vocab.numWords(); i++) { | ||||
|             labels.add(vocab.wordAtIndex(i)); | ||||
|             array.putRow(i, syn0.slice(i)); | ||||
|         } | ||||
|         tsne.fit(array); | ||||
|         return labels; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public void plotVocab(BarnesHutTsne tsne, int numWords, File file) { | ||||
|         final List<String> labels = fitTnseAndGetLabels(tsne, numWords); | ||||
|         try { | ||||
|             tsne.saveAsFile(labels, file.getAbsolutePath()); | ||||
|         } catch (IOException e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Render the words via tsne | ||||
|      */ | ||||
|     @Override | ||||
|     public void plotVocab(int numWords, File file) { | ||||
|         BarnesHutTsne tsne = new BarnesHutTsne.Builder().normalize(false).setFinalMomentum(0.8f).numDimension(2) | ||||
|                         .setMaxIter(1000).build(); | ||||
|         plotVocab(tsne, numWords, file); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Render the words via tsne | ||||
|      */ | ||||
|     @Override | ||||
|     public void plotVocab(int numWords, UiConnectionInfo connectionInfo) { | ||||
|         BarnesHutTsne tsne = new BarnesHutTsne.Builder().normalize(false).setFinalMomentum(0.8f).numDimension(2) | ||||
|                         .setMaxIter(1000).build(); | ||||
|         plotVocab(tsne, numWords, connectionInfo); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Render the words via TSNE | ||||
|      * | ||||
|      * @param tsne           the tsne to use | ||||
|      * @param numWords | ||||
|      * @param connectionInfo | ||||
|      */ | ||||
|     @Override | ||||
|     public void plotVocab(BarnesHutTsne tsne, int numWords, UiConnectionInfo connectionInfo) { | ||||
|         try { | ||||
|             final List<String> labels = fitTnseAndGetLabels(tsne, numWords); | ||||
|             final INDArray reducedData = tsne.getData(); | ||||
|             StringBuilder sb = new StringBuilder(); | ||||
|             for (int i = 0; i < reducedData.rows() && i < numWords; i++) { | ||||
|                 String word = labels.get(i); | ||||
|                 INDArray wordVector = reducedData.getRow(i); | ||||
|                 for (int j = 0; j < wordVector.length(); j++) { | ||||
|                     sb.append(String.valueOf(wordVector.getDouble(j))).append(","); | ||||
|                 } | ||||
|                 sb.append(word); | ||||
|             } | ||||
| 
 | ||||
|             String address = connectionInfo.getFirstPart() + "/tsne/post/" + connectionInfo.getSessionId(); | ||||
|             //            System.out.println("ADDRESS: " + address); | ||||
|             URI uri = new URI(address); | ||||
| 
 | ||||
|             HttpURLConnection connection = (HttpURLConnection) uri.toURL().openConnection(); | ||||
|             connection.setRequestMethod("POST"); | ||||
|             connection.setRequestProperty("User-Agent", "Mozilla/5.0"); | ||||
|             //            connection.setRequestProperty("Content-Type", "application/json"); | ||||
|             connection.setRequestProperty("Content-Type", "multipart/form-data; boundary=-----TSNE-POST-DATA-----"); | ||||
|             connection.setDoOutput(true); | ||||
| 
 | ||||
|             final OutputStream outputStream = connection.getOutputStream(); | ||||
|             final PrintWriter writer = new PrintWriter(outputStream); | ||||
|             writer.println("-------TSNE-POST-DATA-----"); | ||||
|             writer.println("Content-Disposition: form-data; name=\"fileupload\"; filename=\"tsne.csv\""); | ||||
|             writer.println("Content-Type: text/plain; charset=UTF-16"); | ||||
|             writer.println("Content-Transfer-Encoding: binary"); | ||||
|             writer.println(); | ||||
|             writer.flush(); | ||||
| 
 | ||||
|             DataOutputStream dos = new DataOutputStream(outputStream); | ||||
|             dos.writeBytes(sb.toString()); | ||||
|             dos.flush(); | ||||
|             writer.println(); | ||||
|             writer.flush(); | ||||
|             dos.close(); | ||||
|             outputStream.close(); | ||||
| 
 | ||||
|             try { | ||||
|                 int responseCode = connection.getResponseCode(); | ||||
|                 System.out.println("RESPONSE CODE: " + responseCode); | ||||
| 
 | ||||
|                 if (responseCode != 200) { | ||||
|                     BufferedReader in = new BufferedReader(new InputStreamReader(connection.getInputStream())); | ||||
|                     String inputLine; | ||||
|                     StringBuilder response = new StringBuilder(); | ||||
| 
 | ||||
|                     while ((inputLine = in.readLine()) != null) { | ||||
|                         response.append(inputLine); | ||||
|                     } | ||||
|                     in.close(); | ||||
| 
 | ||||
|                     log.warn("Error posting to remote UI - received response code {}\tContent: {}", response, | ||||
|                                     response.toString()); | ||||
|                 } | ||||
|             } catch (IOException e) { | ||||
|                 log.warn("Error posting to remote UI at {}", uri, e); | ||||
|             } | ||||
|         } catch (Exception e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param codeIndex | ||||
|      * @param code | ||||
|  | ||||
| @ -26,7 +26,6 @@ import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; | ||||
| import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; | ||||
| import org.deeplearning4j.models.word2vec.wordstore.VocabCache; | ||||
| import org.deeplearning4j.nn.conf.WorkspaceMode; | ||||
| import org.deeplearning4j.plot.BarnesHutTsne; | ||||
| import org.junit.Ignore; | ||||
| import org.junit.Rule; | ||||
| import org.junit.Test; | ||||
| @ -62,152 +61,4 @@ public class TsneTest extends BaseDL4JTest { | ||||
|         return DataType.FLOAT; | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testSimple() throws Exception { | ||||
|         //Simple sanity check | ||||
| 
 | ||||
|         for( int test=0; test <=1; test++){ | ||||
|             boolean syntheticData = test == 1; | ||||
|             WorkspaceMode wsm = test == 0 ? WorkspaceMode.NONE : WorkspaceMode.ENABLED; | ||||
|             log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData); | ||||
| 
 | ||||
|             //STEP 1: Initialization | ||||
|             int iterations = 50; | ||||
|             //create an n-dimensional array of doubles | ||||
|             Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); | ||||
|             List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words | ||||
| 
 | ||||
|             //STEP 2: Turn text input into a list of words | ||||
|             INDArray weights; | ||||
|             if(syntheticData){ | ||||
|                 weights = Nd4j.rand(250, 200); | ||||
|             } else { | ||||
|                 log.info("Load & Vectorize data...."); | ||||
|                 File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile();   //Open the file | ||||
|                 //Get the data of all unique word vectors | ||||
|                 Pair<InMemoryLookupTable, VocabCache> vectors = WordVectorSerializer.loadTxt(wordFile); | ||||
|                 VocabCache cache = vectors.getSecond(); | ||||
|                 weights = vectors.getFirst().getSyn0();    //seperate weights of unique words into their own list | ||||
| 
 | ||||
|                 for (int i = 0; i < cache.numWords(); i++)   //seperate strings of words into their own list | ||||
|                     cacheList.add(cache.wordAtIndex(i)); | ||||
|             } | ||||
| 
 | ||||
|             //STEP 3: build a dual-tree tsne to use later | ||||
|             log.info("Build model...."); | ||||
|             BarnesHutTsne tsne = new BarnesHutTsne.Builder() | ||||
|                     .setMaxIter(iterations) | ||||
|                     .theta(0.5) | ||||
|                     .normalize(false) | ||||
|                     .learningRate(500) | ||||
|                     .useAdaGrad(false) | ||||
|                     .workspaceMode(wsm) | ||||
|                     .build(); | ||||
| 
 | ||||
| 
 | ||||
|             //STEP 4: establish the tsne values and save them to a file | ||||
|             log.info("Store TSNE Coordinates for Plotting...."); | ||||
|             File outDir = testDir.newFolder(); | ||||
|             tsne.fit(weights); | ||||
|             tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testPerformance() throws Exception { | ||||
| 
 | ||||
|         StopWatch watch = new StopWatch(); | ||||
|         watch.start(); | ||||
|         for( int test=0; test <=1; test++){ | ||||
|             boolean syntheticData = test == 1; | ||||
|             WorkspaceMode wsm = test == 0 ? WorkspaceMode.NONE : WorkspaceMode.ENABLED; | ||||
|             log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData); | ||||
| 
 | ||||
|             //STEP 1: Initialization | ||||
|             int iterations = 50; | ||||
|             //create an n-dimensional array of doubles | ||||
|             Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); | ||||
|             List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words | ||||
| 
 | ||||
|             //STEP 2: Turn text input into a list of words | ||||
|             INDArray weights; | ||||
|             if(syntheticData){ | ||||
|                 weights = Nd4j.rand(DataType.FLOAT, 250, 20); | ||||
|             } else { | ||||
|                 log.info("Load & Vectorize data...."); | ||||
|                 File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile();   //Open the file | ||||
|                 //Get the data of all unique word vectors | ||||
|                 Pair<InMemoryLookupTable, VocabCache> vectors = WordVectorSerializer.loadTxt(wordFile); | ||||
|                 VocabCache cache = vectors.getSecond(); | ||||
|                 weights = vectors.getFirst().getSyn0();    //seperate weights of unique words into their own list | ||||
| 
 | ||||
|                 for (int i = 0; i < cache.numWords(); i++)   //seperate strings of words into their own list | ||||
|                     cacheList.add(cache.wordAtIndex(i)); | ||||
|             } | ||||
| 
 | ||||
|             //STEP 3: build a dual-tree tsne to use later | ||||
|             log.info("Build model...."); | ||||
|             BarnesHutTsne tsne = new BarnesHutTsne.Builder() | ||||
|                     .setMaxIter(iterations) | ||||
|                     .theta(0.5) | ||||
|                     .normalize(false) | ||||
|                     .learningRate(500) | ||||
|                     .useAdaGrad(false) | ||||
|                     .workspaceMode(wsm) | ||||
|                     .build(); | ||||
| 
 | ||||
| 
 | ||||
|             //STEP 4: establish the tsne values and save them to a file | ||||
|             log.info("Store TSNE Coordinates for Plotting...."); | ||||
|             File outDir = testDir.newFolder(); | ||||
|             tsne.fit(weights); | ||||
|             tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath()); | ||||
|         } | ||||
|         watch.stop(); | ||||
|         System.out.println("Elapsed time : " + watch); | ||||
|     } | ||||
| 
 | ||||
|     @Ignore | ||||
|     @Test | ||||
|     public void testTSNEPerformance() throws Exception { | ||||
| 
 | ||||
|             for (WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) { | ||||
| 
 | ||||
|                 //STEP 1: Initialization | ||||
|                 int iterations = 50; | ||||
|                 //create an n-dimensional array of doubles | ||||
|                 Nd4j.setDataType(DataType.DOUBLE); | ||||
|                 List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words | ||||
| 
 | ||||
|                 //STEP 2: Turn text input into a list of words | ||||
|                 INDArray weights = Nd4j.rand(10000,300); | ||||
| 
 | ||||
|                 StopWatch watch = new StopWatch(); | ||||
|                 watch.start(); | ||||
|                 //STEP 3: build a dual-tree tsne to use later | ||||
|                 log.info("Build model...."); | ||||
|                 BarnesHutTsne tsne = new BarnesHutTsne.Builder() | ||||
|                         .setMaxIter(iterations) | ||||
|                         .theta(0.5) | ||||
|                         .normalize(false) | ||||
|                         .learningRate(500) | ||||
|                         .useAdaGrad(false) | ||||
|                         .workspaceMode(wsm) | ||||
|                         .build(); | ||||
| 
 | ||||
|                 watch.stop(); | ||||
|                 System.out.println("Elapsed time for construction: " + watch); | ||||
| 
 | ||||
|                 //STEP 4: establish the tsne values and save them to a file | ||||
|                 log.info("Store TSNE Coordinates for Plotting...."); | ||||
|                 File outDir = testDir.newFolder(); | ||||
| 
 | ||||
|                 watch.reset(); | ||||
|                 watch.start(); | ||||
|                 tsne.fit(weights); | ||||
|                 watch.stop(); | ||||
|                 System.out.println("Elapsed time for fit: " + watch); | ||||
|                 tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath()); | ||||
|             } | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.iterator; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import lombok.Getter; | ||||
| import org.deeplearning4j.BaseDL4JTest; | ||||
| import org.deeplearning4j.iterator.bert.BertMaskedLMMasker; | ||||
| @ -57,9 +58,11 @@ public class TestBertIterator extends BaseDL4JTest { | ||||
|     public TestBertIterator() throws IOException { | ||||
|     } | ||||
| 
 | ||||
|     @Test(timeout = 20000L) | ||||
|     @Test() | ||||
|     public void testBertSequenceClassification() throws Exception { | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             return; | ||||
|         } | ||||
|         int minibatchSize = 2; | ||||
|         TestSentenceHelper testHelper = new TestSentenceHelper(); | ||||
|         BertIterator b = BertIterator.builder() | ||||
| @ -308,6 +311,9 @@ public class TestBertIterator extends BaseDL4JTest { | ||||
|      */ | ||||
|     @Test | ||||
|     public void testSentencePairsSingle() throws IOException { | ||||
|         if(Platform.isWindows()) { | ||||
|             return; | ||||
|         } | ||||
|         boolean prependAppend; | ||||
|         int numOfSentences; | ||||
| 
 | ||||
| @ -367,7 +373,9 @@ public class TestBertIterator extends BaseDL4JTest { | ||||
|     */ | ||||
|     @Test | ||||
|     public void testSentencePairsUnequalLengths() throws IOException { | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             return; | ||||
|         } | ||||
|         int minibatchSize = 4; | ||||
|         int numOfSentencesinIter = 3; | ||||
| 
 | ||||
| @ -456,6 +464,9 @@ public class TestBertIterator extends BaseDL4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testSentencePairFeaturizer() throws IOException { | ||||
|         if(Platform.isWindows()) { | ||||
|             return; | ||||
|         } | ||||
|         int minibatchSize = 2; | ||||
|         TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(minibatchSize); | ||||
|         BertIterator b = BertIterator.builder() | ||||
|  | ||||
| @ -26,6 +26,7 @@ import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; | ||||
| import org.deeplearning4j.models.word2vec.Word2Vec; | ||||
| import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; | ||||
| import org.deeplearning4j.text.sentenceiterator.SentenceIterator; | ||||
| import org.junit.Ignore; | ||||
| import org.junit.Rule; | ||||
| import org.junit.Test; | ||||
| import org.junit.rules.TemporaryFolder; | ||||
| @ -43,6 +44,7 @@ import static org.junit.Assert.assertArrayEquals; | ||||
| import static org.junit.Assert.assertEquals; | ||||
| 
 | ||||
| @Slf4j | ||||
| @Ignore | ||||
| public class FastTextTest extends BaseDL4JTest { | ||||
| 
 | ||||
|     @Rule | ||||
|  | ||||
| @ -23,7 +23,6 @@ package org.deeplearning4j.models.word2vec; | ||||
| import org.deeplearning4j.BaseDL4JTest; | ||||
| import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; | ||||
| import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; | ||||
| import org.deeplearning4j.plot.BarnesHutTsne; | ||||
| import org.junit.Before; | ||||
| import org.junit.Ignore; | ||||
| import org.junit.Test; | ||||
| @ -40,11 +39,5 @@ public class Word2VecVisualizationTests extends BaseDL4JTest { | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testBarnesHutTsneVisualization() throws Exception { | ||||
|         BarnesHutTsne tsne = new BarnesHutTsne.Builder().setMaxIter(4).stopLyingIteration(250).learningRate(500) | ||||
|                         .useAdaGrad(false).theta(0.5).setMomentum(0.5).normalize(true).build(); | ||||
| 
 | ||||
|         //vectors.lookupTable().plotVocab(tsne); | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -32,6 +32,7 @@ import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIte | ||||
| import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; | ||||
| import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; | ||||
| import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; | ||||
| import org.junit.Ignore; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.dataset.DataSet; | ||||
| @ -56,6 +57,7 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest { | ||||
|      * Basically all we want from this test - being able to finish without exceptions. | ||||
|      */ | ||||
|     @Test | ||||
|     @Ignore | ||||
|     public void testIterator1() throws Exception { | ||||
| 
 | ||||
|         File inputFile = Resources.asFile("big/raw_sentences.txt"); | ||||
|  | ||||
| @ -42,6 +42,7 @@ import java.util.List; | ||||
| import static org.junit.Assert.*; | ||||
| 
 | ||||
| @Slf4j | ||||
| @Ignore | ||||
| public class BertWordPieceTokenizerTests extends BaseDL4JTest { | ||||
| 
 | ||||
|     private File pathToVocab =  Resources.asFile("other/vocab.txt"); | ||||
|  | ||||
| @ -71,7 +71,7 @@ public class LocalResponseNormalization | ||||
|                     dataType); | ||||
|             log.debug("CudnnLocalResponseNormalizationHelper successfully initialized"); | ||||
|         } | ||||
|         //2019-03-09 AB - MKL-DNN helper disabled: https://github.com/deeplearning4j/deeplearning4j/issues/7272 | ||||
|         //2019-03-09 AB - MKL-DNN helper disabled: https://github.com/eclipse/deeplearning4j/issues/7272 | ||||
| //        else if("CPU".equalsIgnoreCase(backend)){ | ||||
| //            helper = new MKLDNNLocalResponseNormalizationHelper(); | ||||
| //            log.debug("Created MKLDNNLocalResponseNormalizationHelper"); | ||||
|  | ||||
| @ -953,7 +953,7 @@ public class ModelSerializer { | ||||
| 
 | ||||
| 
 | ||||
|     private static void checkInputStream(InputStream inputStream) throws IOException { | ||||
|         //available method can return 0 in some cases: https://github.com/deeplearning4j/deeplearning4j/issues/4887 | ||||
|         //available method can return 0 in some cases: https://github.com/eclipse/deeplearning4j/issues/4887 | ||||
|         int available; | ||||
|         try{ | ||||
|             //InputStream.available(): A subclass' implementation of this method may choose to throw an IOException | ||||
|  | ||||
| @ -370,7 +370,7 @@ public class NetworkUtils { | ||||
|         final String message; | ||||
|         if (model.getClass().getName().startsWith("org.deeplearning4j")) { | ||||
|             message = model.getClass().getName() + " models are not yet supported and " + | ||||
|                     "pull requests are welcome: https://github.com/deeplearning4j/deeplearning4j"; | ||||
|                     "pull requests are welcome: https://github.com/eclipse/deeplearning4j"; | ||||
|         } else { | ||||
|             message = model.getClass().getName() + " models are unsupported."; | ||||
|         } | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark.models.sequencevectors; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.spark.SparkConf; | ||||
| import org.apache.spark.api.java.JavaRDD; | ||||
| import org.apache.spark.api.java.JavaSparkContext; | ||||
| @ -87,6 +88,11 @@ public class SparkSequenceVectorsTest extends BaseDL4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testFrequenciesCount() throws Exception { | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         JavaRDD<Sequence<VocabWord>> sequences = sc.parallelize(sequencesCyclic); | ||||
| 
 | ||||
|         SparkSequenceVectors<VocabWord> seqVec = new SparkSequenceVectors<>(); | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark.models.embeddings.word2vec; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.spark.SparkConf; | ||||
| import org.apache.spark.api.java.JavaRDD; | ||||
| import org.apache.spark.api.java.JavaSparkContext; | ||||
| @ -54,6 +55,10 @@ public class Word2VecTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testConcepts() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         // These are all default values for word2vec | ||||
|         SparkConf sparkConf = new SparkConf().setMaster("local[8]") | ||||
|                 .set("spark.driver.host", "localhost") | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark.text; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.spark.SparkConf; | ||||
| import org.apache.spark.api.java.JavaPairRDD; | ||||
| import org.apache.spark.api.java.JavaRDD; | ||||
| @ -94,6 +95,10 @@ public class TextPipelineTest extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testTokenizer() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         JavaSparkContext sc = getContext(); | ||||
|         JavaRDD<String> corpusRDD = getCorpusRDD(sc); | ||||
|         Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap()); | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark.parameterserver.accumulation; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.junit.Before; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| @ -33,6 +34,10 @@ public class SharedTrainingAccumulationFunctionTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testAccumulation1() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         INDArray updates1 = Nd4j.create(1000).assign(1.0); | ||||
|         INDArray updates2 = Nd4j.create(1000).assign(2.0); | ||||
|         INDArray expUpdates = Nd4j.create(1000).assign(3.0); | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark.parameterserver.accumulation; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.deeplearning4j.spark.parameterserver.training.SharedTrainingResult; | ||||
| import org.junit.Before; | ||||
| import org.junit.Test; | ||||
| @ -36,6 +37,10 @@ public class SharedTrainingAggregateFunctionTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testAggregate1() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         INDArray updates1 = Nd4j.create(1000).assign(1.0); | ||||
|         INDArray updates2 = Nd4j.create(1000).assign(2.0); | ||||
|         INDArray expUpdates = Nd4j.create(1000).assign(3.0); | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark.parameterserver.iterators; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.junit.Before; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| @ -39,6 +40,10 @@ public class VirtualDataSetIteratorTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testSimple1() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         List<Iterator<DataSet>> iterators = new ArrayList<>(); | ||||
| 
 | ||||
|         List<DataSet> first = new ArrayList<>(); | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark.parameterserver.iterators; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.junit.Before; | ||||
| import org.junit.Test; | ||||
| 
 | ||||
| @ -36,6 +37,10 @@ public class VirtualIteratorTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testIteration1() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         List<Integer> integers = new ArrayList<>(); | ||||
|         for (int i = 0; i < 100; i++) { | ||||
|             integers.add(i); | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark.parameterserver.modelimport.elephas; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.spark.api.java.JavaSparkContext; | ||||
| import org.deeplearning4j.spark.impl.graph.SparkComputationGraph; | ||||
| import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; | ||||
| @ -40,6 +41,10 @@ public class TestElephasImport extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testElephasSequentialImport() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         String modelPath = "modelimport/elephas/elephas_sequential.h5"; | ||||
|         SparkDl4jMultiLayer model = importElephasSequential(sc, modelPath); | ||||
|         // System.out.println(model.getNetwork().summary()); | ||||
| @ -48,7 +53,11 @@ public class TestElephasImport extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testElephasSequentialImportAsync() throws Exception { | ||||
|         String modelPath = "modelimport/elephas/elephas_sequential_async.h5"; | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|        String modelPath = "modelimport/elephas/elephas_sequential_async.h5"; | ||||
|         SparkDl4jMultiLayer model = importElephasSequential(sc, modelPath); | ||||
|         // System.out.println(model.getNetwork().summary()); | ||||
|         assertTrue(model.getTrainingMaster() instanceof SharedTrainingMaster); | ||||
|  | ||||
| @ -0,0 +1,38 @@ | ||||
| # | ||||
| # /* ****************************************************************************** | ||||
| #  * | ||||
| #  * | ||||
| #  * This program and the accompanying materials are made available under the | ||||
| #  * terms of the Apache License, Version 2.0 which is available at | ||||
| #  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
| #  * | ||||
| #  *  See the NOTICE file distributed with this work for additional | ||||
| #  *  information regarding copyright ownership. | ||||
| #  * Unless required by applicable law or agreed to in writing, software | ||||
| #  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
| #  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
| #  * License for the specific language governing permissions and limitations | ||||
| #  * under the License. | ||||
| #  * | ||||
| #  * SPDX-License-Identifier: Apache-2.0 | ||||
| #  ******************************************************************************/ | ||||
| # | ||||
| 
 | ||||
| real.class.double = org.nd4j.linalg.cpu.NDArray | ||||
| shapeinfoprovider = org.nd4j.linalg.cpu.nativecpu.DirectShapeInfoProvider | ||||
| constantsprovider = org.nd4j.linalg.cpu.nativecpu.cache.ConstantBuffersCache | ||||
| affinitymanager = org.nd4j.linalg.cpu.nativecpu.CpuAffinityManager | ||||
| memorymanager = org.nd4j.linalg.cpu.nativecpu.CpuMemoryManager | ||||
| dtype = float | ||||
| blas.ops = org.nd4j.linalg.cpu.nativecpu.BlasWrapper | ||||
| 
 | ||||
| native.ops= org.nd4j.nativeblas.Nd4jCpu | ||||
| ndarrayfactory.class = org.nd4j.linalg.cpu.nativecpu.CpuNDArrayFactory | ||||
| ndarray.order = c | ||||
| resourcemanager_state = false | ||||
| databufferfactory = org.nd4j.linalg.cpu.nativecpu.buffer.DefaultDataBufferFactory | ||||
| workspacemanager = org.nd4j.linalg.cpu.nativecpu.workspace.CpuWorkspaceManager | ||||
| alloc = javacpp | ||||
| opexec= org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner | ||||
| opexec.mode= native | ||||
| random=org.nd4j.linalg.cpu.nativecpu.rng.CpuNativeRandom | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.spark.api.java.JavaRDD; | ||||
| import org.apache.spark.api.java.JavaSparkContext; | ||||
| import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; | ||||
| @ -63,6 +64,10 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testEarlyStoppingIris() { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | ||||
|                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
|                         .updater(new Sgd()).weightInit(WeightInit.XAVIER).list() | ||||
| @ -113,7 +118,10 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { | ||||
|     @Test | ||||
|     public void testBadTuning() { | ||||
|         //Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         Nd4j.getRandom().setSeed(12345); | ||||
|         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) | ||||
|                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
| @ -150,7 +158,10 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { | ||||
|     @Test | ||||
|     public void testTimeTermination() { | ||||
|         //test termination after max time | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         Nd4j.getRandom().setSeed(12345); | ||||
|         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) | ||||
|                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
| @ -193,7 +204,10 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { | ||||
|     public void testNoImprovementNEpochsTermination() { | ||||
|         //Idea: terminate training if score (test set loss) does not improve for 5 consecutive epochs | ||||
|         //Simulate this by setting LR = 0.0 | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         Nd4j.getRandom().setSeed(12345); | ||||
|         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) | ||||
|                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
| @ -228,6 +242,10 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testListeners() { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | ||||
|                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
|                         .updater(new Sgd()).weightInit(WeightInit.XAVIER).list() | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.spark.api.java.JavaRDD; | ||||
| import org.apache.spark.api.java.JavaSparkContext; | ||||
| import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; | ||||
| @ -66,6 +67,10 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testEarlyStoppingIris() { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() | ||||
|                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
|                         .updater(new Sgd()).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") | ||||
| @ -114,7 +119,10 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { | ||||
|     @Test | ||||
|     public void testBadTuning() { | ||||
|         //Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         Nd4j.getRandom().setSeed(12345); | ||||
|         ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) | ||||
|                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
| @ -152,7 +160,10 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { | ||||
|     @Test | ||||
|     public void testTimeTermination() { | ||||
|         //test termination after max time | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         Nd4j.getRandom().setSeed(12345); | ||||
|         ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) | ||||
|                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
| @ -197,7 +208,10 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { | ||||
|     public void testNoImprovementNEpochsTermination() { | ||||
|         //Idea: terminate training if score (test set loss) does not improve for 5 consecutive epochs | ||||
|         //Simulate this by setting LR = 0.0 | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         Nd4j.getRandom().setSeed(12345); | ||||
|         ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) | ||||
|                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
| @ -235,6 +249,10 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testListeners() { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() | ||||
|                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
|                         .updater(new Sgd()).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark.datavec; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import lombok.val; | ||||
| import org.apache.commons.io.FilenameUtils; | ||||
| import org.apache.hadoop.io.Text; | ||||
| @ -68,6 +69,10 @@ public class TestDataVecDataSetFunctions extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testDataVecDataSetFunction() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         JavaSparkContext sc = getContext(); | ||||
| 
 | ||||
|         File f = testDir.newFolder(); | ||||
| @ -178,6 +183,10 @@ public class TestDataVecDataSetFunctions extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testDataVecSequenceDataSetFunction() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         JavaSparkContext sc = getContext(); | ||||
|         //Test Spark record reader functionality vs. local | ||||
|         File dir = testDir.newFolder(); | ||||
| @ -236,6 +245,10 @@ public class TestDataVecDataSetFunctions extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testDataVecSequencePairDataSetFunction() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         JavaSparkContext sc = getContext(); | ||||
| 
 | ||||
|         File f = testDir.newFolder(); | ||||
| @ -332,7 +345,10 @@ public class TestDataVecDataSetFunctions extends BaseSparkTest { | ||||
|     @Test | ||||
|     public void testDataVecSequencePairDataSetFunctionVariableLength() throws Exception { | ||||
|         //Same sort of test as testDataVecSequencePairDataSetFunction() but with variable length time series (labels shorter, align end) | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         File dirFeatures = testDir.newFolder(); | ||||
|         ClassPathResource cpr = new ClassPathResource("dl4j-spark/csvsequence/"); | ||||
|         cpr.copyDirectory(dirFeatures); | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark.datavec; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.commons.io.FileUtils; | ||||
| import org.apache.commons.io.FilenameUtils; | ||||
| import org.apache.spark.api.java.JavaRDD; | ||||
| @ -44,6 +45,10 @@ public class TestExport extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testBatchAndExportDataSetsFunction() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         String baseDir = System.getProperty("java.io.tmpdir"); | ||||
|         baseDir = FilenameUtils.concat(baseDir, "dl4j_spark_testBatchAndExport/"); | ||||
|         baseDir = baseDir.replaceAll("\\\\", "/"); | ||||
| @ -102,6 +107,10 @@ public class TestExport extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testBatchAndExportMultiDataSetsFunction() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         String baseDir = System.getProperty("java.io.tmpdir"); | ||||
|         baseDir = FilenameUtils.concat(baseDir, "dl4j_spark_testBatchAndExportMDS/"); | ||||
|         baseDir = baseDir.replaceAll("\\\\", "/"); | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark.datavec; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.commons.io.FileUtils; | ||||
| import org.apache.commons.io.FilenameUtils; | ||||
| import org.apache.spark.api.java.JavaPairRDD; | ||||
| @ -63,6 +64,10 @@ public class TestPreProcessedData extends BaseSparkTest { | ||||
|     @Test | ||||
|     public void testPreprocessedData() { | ||||
|         //Test _loading_ of preprocessed data | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         int dataSetObjSize = 5; | ||||
|         int batchSizePerExecutor = 10; | ||||
| 
 | ||||
| @ -109,6 +114,10 @@ public class TestPreProcessedData extends BaseSparkTest { | ||||
|     @Test | ||||
|     public void testPreprocessedDataCompGraphDataSet() { | ||||
|         //Test _loading_ of preprocessed DataSet data | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         int dataSetObjSize = 5; | ||||
|         int batchSizePerExecutor = 10; | ||||
| 
 | ||||
| @ -157,6 +166,10 @@ public class TestPreProcessedData extends BaseSparkTest { | ||||
|     @Test | ||||
|     public void testPreprocessedDataCompGraphMultiDataSet() throws IOException { | ||||
|         //Test _loading_ of preprocessed MultiDataSet data | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         int dataSetObjSize = 5; | ||||
|         int batchSizePerExecutor = 10; | ||||
| 
 | ||||
| @ -206,6 +219,10 @@ public class TestPreProcessedData extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testCsvPreprocessedDataGeneration() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         List<String> list = new ArrayList<>(); | ||||
|         DataSetIterator iter = new IrisDataSetIterator(1, 150); | ||||
|         while (iter.hasNext()) { | ||||
| @ -292,6 +309,10 @@ public class TestPreProcessedData extends BaseSparkTest { | ||||
|     @Test | ||||
|     public void testCsvPreprocessedDataGenerationNoLabel() throws Exception { | ||||
|         //Same as above test, but without any labels (in which case: input and output arrays are the same) | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         List<String> list = new ArrayList<>(); | ||||
|         DataSetIterator iter = new IrisDataSetIterator(1, 150); | ||||
|         while (iter.hasNext()) { | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark.impl.customlayer; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.spark.api.java.JavaRDD; | ||||
| import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | ||||
| import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | ||||
| @ -44,6 +45,10 @@ public class TestCustomLayer extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testSparkWithCustomLayer() { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         //Basic test - checks whether exceptions etc are thrown with custom layers + spark | ||||
|         //Custom layers are tested more extensively in dl4j core | ||||
|         MultiLayerConfiguration conf = | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark.impl.multilayer; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.apache.spark.api.java.JavaRDD; | ||||
| import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; | ||||
| @ -69,6 +70,10 @@ public class TestSparkDl4jMultiLayer extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testEvaluationSimple() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         Nd4j.getRandom().setSeed(12345); | ||||
| 
 | ||||
|         for( int evalWorkers : new int[]{1, 4, 8}) { | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark.impl.paramavg; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.spark.SparkConf; | ||||
| import org.apache.spark.api.java.JavaRDD; | ||||
| import org.apache.spark.api.java.JavaSparkContext; | ||||
| @ -65,57 +66,57 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { | ||||
|     private static MultiLayerConfiguration getConf(int seed, IUpdater updater) { | ||||
|         Nd4j.getRandom().setSeed(seed); | ||||
|         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | ||||
|                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
|                         .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).list() | ||||
|                         .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(1, new OutputLayer.Builder() | ||||
|                                         .lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build()) | ||||
|                         .build(); | ||||
|                 .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
|                 .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).list() | ||||
|                 .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(1, new OutputLayer.Builder() | ||||
|                         .lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build()) | ||||
|                 .build(); | ||||
|         return conf; | ||||
|     } | ||||
| 
 | ||||
|     private static MultiLayerConfiguration getConfCNN(int seed, IUpdater updater) { | ||||
|         Nd4j.getRandom().setSeed(seed); | ||||
|         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | ||||
|                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
|                         .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).list() | ||||
|                         .layer(0, new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1).padding(0, 0) | ||||
|                                         .activation(Activation.TANH).build()) | ||||
|                         .layer(1, new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1).padding(0, 0) | ||||
|                                         .activation(Activation.TANH).build()) | ||||
|                         .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(10) | ||||
|                                         .build()) | ||||
|                         .setInputType(InputType.convolutional(10, 10, 3)).build(); | ||||
|                 .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
|                 .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).list() | ||||
|                 .layer(0, new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1).padding(0, 0) | ||||
|                         .activation(Activation.TANH).build()) | ||||
|                 .layer(1, new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1).padding(0, 0) | ||||
|                         .activation(Activation.TANH).build()) | ||||
|                 .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(10) | ||||
|                         .build()) | ||||
|                 .setInputType(InputType.convolutional(10, 10, 3)).build(); | ||||
|         return conf; | ||||
|     } | ||||
| 
 | ||||
|     private static ComputationGraphConfiguration getGraphConf(int seed, IUpdater updater) { | ||||
|         Nd4j.getRandom().setSeed(seed); | ||||
|         ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() | ||||
|                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
|                         .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).graphBuilder() | ||||
|                         .addInputs("in") | ||||
|                         .addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in").addLayer("1", | ||||
|                                         new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10) | ||||
|                                                         .nOut(10).build(), | ||||
|                                         "0") | ||||
|                         .setOutputs("1").build(); | ||||
|                 .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
|                 .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).graphBuilder() | ||||
|                 .addInputs("in") | ||||
|                 .addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in").addLayer("1", | ||||
|                         new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10) | ||||
|                                 .nOut(10).build(), | ||||
|                         "0") | ||||
|                 .setOutputs("1").build(); | ||||
|         return conf; | ||||
|     } | ||||
| 
 | ||||
|     private static ComputationGraphConfiguration getGraphConfCNN(int seed, IUpdater updater) { | ||||
|         Nd4j.getRandom().setSeed(seed); | ||||
|         ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() | ||||
|                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
|                         .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).graphBuilder() | ||||
|                         .addInputs("in") | ||||
|                         .addLayer("0", new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1) | ||||
|                                         .padding(0, 0).activation(Activation.TANH).build(), "in") | ||||
|                         .addLayer("1", new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1) | ||||
|                                         .padding(0, 0).activation(Activation.TANH).build(), "0") | ||||
|                         .addLayer("2", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(10) | ||||
|                                         .build(), "1") | ||||
|                         .setOutputs("2").setInputTypes(InputType.convolutional(10, 10, 3)) | ||||
|                         .build(); | ||||
|                 .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
|                 .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).graphBuilder() | ||||
|                 .addInputs("in") | ||||
|                 .addLayer("0", new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1) | ||||
|                         .padding(0, 0).activation(Activation.TANH).build(), "in") | ||||
|                 .addLayer("1", new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1) | ||||
|                         .padding(0, 0).activation(Activation.TANH).build(), "0") | ||||
|                 .addLayer("2", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(10) | ||||
|                         .build(), "1") | ||||
|                 .setOutputs("2").setInputTypes(InputType.convolutional(10, 10, 3)) | ||||
|                 .build(); | ||||
|         return conf; | ||||
|     } | ||||
| 
 | ||||
| @ -125,8 +126,8 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { | ||||
| 
 | ||||
|     private static TrainingMaster getTrainingMaster(int avgFreq, int miniBatchSize, boolean saveUpdater) { | ||||
|         ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(1) | ||||
|                         .averagingFrequency(avgFreq).batchSizePerWorker(miniBatchSize).saveUpdater(saveUpdater) | ||||
|                         .aggregationDepth(2).workerPrefetchNumBatches(0).build(); | ||||
|                 .averagingFrequency(avgFreq).batchSizePerWorker(miniBatchSize).saveUpdater(saveUpdater) | ||||
|                 .aggregationDepth(2).workerPrefetchNumBatches(0).build(); | ||||
|         return tm; | ||||
|     } | ||||
| 
 | ||||
| @ -174,6 +175,10 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testOneExecutor() { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         //Idea: single worker/executor on Spark should give identical results to a single machine | ||||
| 
 | ||||
|         int miniBatchSize = 10; | ||||
| @ -224,6 +229,10 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testOneExecutorGraph() { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         //Idea: single worker/executor on Spark should give identical results to a single machine | ||||
| 
 | ||||
|         int miniBatchSize = 10; | ||||
| @ -251,7 +260,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { | ||||
|                 //Do training on Spark with one executor, for 3 separate minibatches | ||||
|                 TrainingMaster tm = getTrainingMaster(1, miniBatchSize, saveUpdater); | ||||
|                 SparkComputationGraph sparkNet = | ||||
|                                 new SparkComputationGraph(sc, getGraphConf(12345, new RmsProp(0.5)), tm); | ||||
|                         new SparkComputationGraph(sc, getGraphConf(12345, new RmsProp(0.5)), tm); | ||||
|                 sparkNet.setCollectTrainingStats(true); | ||||
|                 INDArray initialSparkParams = sparkNet.getNetwork().params().dup(); | ||||
| 
 | ||||
| @ -312,10 +321,10 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { | ||||
|                 //Do training on Spark with one executor, for 3 separate minibatches | ||||
|                 //                TrainingMaster tm = getTrainingMaster(1, miniBatchSizePerWorker, saveUpdater); | ||||
|                 ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(1) | ||||
|                                 .averagingFrequency(1).batchSizePerWorker(miniBatchSizePerWorker) | ||||
|                                 .saveUpdater(saveUpdater).workerPrefetchNumBatches(0) | ||||
|                                 //                        .rddTrainingApproach(RDDTrainingApproach.Direct) | ||||
|                                 .rddTrainingApproach(RDDTrainingApproach.Export).build(); | ||||
|                         .averagingFrequency(1).batchSizePerWorker(miniBatchSizePerWorker) | ||||
|                         .saveUpdater(saveUpdater).workerPrefetchNumBatches(0) | ||||
|                         //                        .rddTrainingApproach(RDDTrainingApproach.Direct) | ||||
|                         .rddTrainingApproach(RDDTrainingApproach.Export).build(); | ||||
|                 SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, getConf(12345, new Sgd(0.5)), tm); | ||||
|                 sparkNet.setCollectTrainingStats(true); | ||||
|                 INDArray initialSparkParams = sparkNet.getNetwork().params().dup(); | ||||
| @ -355,6 +364,10 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testAverageEveryStepCNN() { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         //Idea: averaging every step with SGD (SGD updater + optimizer) is mathematically identical to doing the learning | ||||
|         // on a single machine for synchronous distributed training | ||||
|         //BUT: This is *ONLY* the case if all workers get an identical number of examples. This won't be the case if | ||||
| @ -387,16 +400,16 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { | ||||
| 
 | ||||
|                 //Do training on Spark with one executor, for 3 separate minibatches | ||||
|                 ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(1) | ||||
|                                 .averagingFrequency(1).batchSizePerWorker(miniBatchSizePerWorker) | ||||
|                                 .saveUpdater(saveUpdater).workerPrefetchNumBatches(0) | ||||
|                                 .rddTrainingApproach(RDDTrainingApproach.Export).build(); | ||||
|                         .averagingFrequency(1).batchSizePerWorker(miniBatchSizePerWorker) | ||||
|                         .saveUpdater(saveUpdater).workerPrefetchNumBatches(0) | ||||
|                         .rddTrainingApproach(RDDTrainingApproach.Export).build(); | ||||
|                 SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, getConfCNN(12345, new Sgd(0.5)), tm); | ||||
|                 sparkNet.setCollectTrainingStats(true); | ||||
|                 INDArray initialSparkParams = sparkNet.getNetwork().params().dup(); | ||||
| 
 | ||||
|                 for (int i = 0; i < seeds.length; i++) { | ||||
|                     List<DataSet> list = | ||||
|                                     getOneDataSetAsIndividalExamplesCNN(miniBatchSizePerWorker * nWorkers, seeds[i]); | ||||
|                             getOneDataSetAsIndividalExamplesCNN(miniBatchSizePerWorker * nWorkers, seeds[i]); | ||||
|                     JavaRDD<DataSet> rdd = sc.parallelize(list); | ||||
| 
 | ||||
|                     sparkNet.fit(rdd); | ||||
| @ -427,6 +440,10 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testAverageEveryStepGraph() { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         //Idea: averaging every step with SGD (SGD updater + optimizer) is mathematically identical to doing the learning | ||||
|         // on a single machine for synchronous distributed training | ||||
|         //BUT: This is *ONLY* the case if all workers get an identical number of examples. This won't be the case if | ||||
| @ -506,6 +523,10 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testAverageEveryStepGraphCNN() { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         //Idea: averaging every step with SGD (SGD updater + optimizer) is mathematically identical to doing the learning | ||||
|         // on a single machine for synchronous distributed training | ||||
|         //BUT: This is *ONLY* the case if all workers get an identical number of examples. This won't be the case if | ||||
| @ -544,7 +565,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { | ||||
| 
 | ||||
|                 for (int i = 0; i < seeds.length; i++) { | ||||
|                     List<DataSet> list = | ||||
|                                     getOneDataSetAsIndividalExamplesCNN(miniBatchSizePerWorker * nWorkers, seeds[i]); | ||||
|                             getOneDataSetAsIndividalExamplesCNN(miniBatchSizePerWorker * nWorkers, seeds[i]); | ||||
|                     JavaRDD<DataSet> rdd = sc.parallelize(list); | ||||
| 
 | ||||
|                     sparkNet.fit(rdd); | ||||
|  | ||||
| @ -21,6 +21,7 @@ | ||||
| package org.deeplearning4j.spark.impl.paramavg; | ||||
| 
 | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.hadoop.conf.Configuration; | ||||
| import org.apache.hadoop.fs.FileSystem; | ||||
| import org.apache.hadoop.fs.LocatedFileStatus; | ||||
| @ -113,6 +114,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testFromSvmLightBackprop() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         JavaRDD<LabeledPoint> data = MLUtils | ||||
|                         .loadLibSVMFile(sc.sc(), | ||||
|                                         new ClassPathResource("svmLight/iris_svmLight_0.txt").getTempFileFromArchive() | ||||
| @ -145,6 +150,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testFromSvmLight() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         JavaRDD<LabeledPoint> data = MLUtils | ||||
|                         .loadLibSVMFile(sc.sc(), | ||||
|                                         new ClassPathResource("svmLight/iris_svmLight_0.txt").getTempFileFromArchive() | ||||
| @ -175,7 +184,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testRunIteration() { | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         DataSet dataSet = new IrisDataSetIterator(5, 5).next(); | ||||
|         List<DataSet> list = dataSet.asList(); | ||||
|         JavaRDD<DataSet> data = sc.parallelize(list); | ||||
| @ -195,6 +207,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testUpdaters() { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         SparkDl4jMultiLayer sparkNet = getBasicNetwork(); | ||||
|         MultiLayerNetwork netCopy = sparkNet.getNetwork().clone(); | ||||
| 
 | ||||
| @ -217,7 +233,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testEvaluation() { | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         SparkDl4jMultiLayer sparkNet = getBasicNetwork(); | ||||
|         MultiLayerNetwork netCopy = sparkNet.getNetwork().clone(); | ||||
| 
 | ||||
| @ -250,7 +269,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { | ||||
|     public void testSmallAmountOfData() { | ||||
|         //Idea: Test spark training where some executors don't get any data | ||||
|         //in this case: by having fewer examples (2 DataSets) than executors (local[*]) | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp()) | ||||
|                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() | ||||
|                         .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(nIn).nOut(3) | ||||
| @ -353,6 +375,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testParameterAveragingMultipleExamplesPerDataSet() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         int dataSetObjSize = 5; | ||||
|         int batchSizePerExecutor = 25; | ||||
|         List<DataSet> list = new ArrayList<>(); | ||||
| @ -402,7 +428,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testFitViaStringPaths() throws Exception { | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         Path tempDir = testDir.newFolder("DL4J-testFitViaStringPaths").toPath(); | ||||
|         File tempDirF = tempDir.toFile(); | ||||
|         tempDirF.deleteOnExit(); | ||||
| @ -466,7 +495,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testFitViaStringPathsSize1() throws Exception { | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         Path tempDir = testDir.newFolder("DL4J-testFitViaStringPathsSize1").toPath(); | ||||
|         File tempDirF = tempDir.toFile(); | ||||
|         tempDirF.deleteOnExit(); | ||||
| @ -547,7 +579,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testFitViaStringPathsCompGraph() throws Exception { | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         Path tempDir = testDir.newFolder("DL4J-testFitViaStringPathsCG").toPath(); | ||||
|         Path tempDir2 = testDir.newFolder("DL4J-testFitViaStringPathsCG-MDS").toPath(); | ||||
|         File tempDirF = tempDir.toFile(); | ||||
| @ -643,7 +678,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { | ||||
|     @Test | ||||
|     @Ignore("AB 2019/05/23 - Failing on CI only - passing locally. Possible precision or threading issue") | ||||
|     public void testSeedRepeatability() throws Exception { | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new RmsProp()) | ||||
|                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||
|                         .weightInit(WeightInit.XAVIER).list() | ||||
| @ -715,6 +753,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testIterationCounts() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         int dataSetObjSize = 5; | ||||
|         int batchSizePerExecutor = 25; | ||||
|         List<DataSet> list = new ArrayList<>(); | ||||
| @ -761,6 +803,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testIterationCountsGraph() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         int dataSetObjSize = 5; | ||||
|         int batchSizePerExecutor = 25; | ||||
|         List<DataSet> list = new ArrayList<>(); | ||||
| @ -806,7 +852,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { | ||||
| 
 | ||||
| 
 | ||||
|     @Test | ||||
|     @Ignore   //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656 | ||||
|     @Ignore   //Ignored 2019/04/09 - low priority: https://github.com/eclipse/deeplearning4j/issues/6656 | ||||
|     public void testVaePretrainSimple() { | ||||
|         //Simple sanity check on pretraining | ||||
|         int nIn = 8; | ||||
| @ -842,7 +888,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     @Ignore    //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656 | ||||
|     @Ignore    //Ignored 2019/04/09 - low priority: https://github.com/eclipse/deeplearning4j/issues/6656 | ||||
|     public void testVaePretrainSimpleCG() { | ||||
|         //Simple sanity check on pretraining | ||||
|         int nIn = 8; | ||||
| @ -992,7 +1038,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test(timeout = 120000L) | ||||
|     public void testEpochCounter() throws Exception { | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | ||||
|                 .list() | ||||
|                 .layer(new OutputLayer.Builder().nIn(4).nOut(3).build()) | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark.impl.stats; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.commons.io.FilenameUtils; | ||||
| import org.apache.spark.SparkConf; | ||||
| import org.apache.spark.api.java.JavaRDD; | ||||
| @ -56,6 +57,10 @@ public class TestTrainingStatsCollection extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testStatsCollection() throws Exception { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         int nWorkers = numExecutors(); | ||||
| 
 | ||||
|         JavaSparkContext sc = getContext(); | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark.ui; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.spark.api.java.JavaRDD; | ||||
| import org.apache.spark.api.java.JavaSparkContext; | ||||
| import org.deeplearning4j.core.storage.Persistable; | ||||
| @ -52,7 +53,10 @@ public class TestListeners extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testStatsCollection() { | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         JavaSparkContext sc = getContext(); | ||||
|         int nExecutors = numExecutors(); | ||||
| 
 | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark.util; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.spark.Partitioner; | ||||
| import org.apache.spark.api.java.JavaPairRDD; | ||||
| import org.apache.spark.api.java.JavaRDD; | ||||
| @ -50,6 +51,10 @@ public class TestRepartitioning extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testRepartitioning() { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         List<String> list = new ArrayList<>(); | ||||
|         for (int i = 0; i < 1000; i++) { | ||||
|             list.add(String.valueOf(i)); | ||||
| @ -71,7 +76,10 @@ public class TestRepartitioning extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testRepartitioning2() throws Exception { | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         int[] ns; | ||||
|         if(isIntegrationTests()){ | ||||
|             ns = new int[]{320, 321, 25600, 25601, 25615}; | ||||
| @ -133,7 +141,10 @@ public class TestRepartitioning extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testRepartitioning3(){ | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         //Initial partitions (idx, count) - [(0,29), (1,29), (2,29), (3,34), (4,34), (5,35), (6,34)] | ||||
| 
 | ||||
|         List<Integer> ints = new ArrayList<>(); | ||||
| @ -194,9 +205,13 @@ public class TestRepartitioning extends BaseSparkTest { | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testRepartitioning4(){ | ||||
|     public void testRepartitioning4() { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         List<Integer> ints = new ArrayList<>(); | ||||
|         for( int i=0; i<7040; i++ ){ | ||||
|         for( int i = 0; i < 7040; i++) { | ||||
|             ints.add(i); | ||||
|         } | ||||
| 
 | ||||
| @ -230,6 +245,10 @@ public class TestRepartitioning extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testRepartitioningApprox() { | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         List<String> list = new ArrayList<>(); | ||||
|         for (int i = 0; i < 1000; i++) { | ||||
|             list.add(String.valueOf(i)); | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.spark.util; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.apache.commons.io.FileUtils; | ||||
| import org.deeplearning4j.spark.BaseSparkTest; | ||||
| import org.deeplearning4j.spark.util.data.SparkDataValidation; | ||||
| @ -46,10 +47,13 @@ public class TestValidation extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testDataSetValidation() throws Exception { | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         File f = folder.newFolder(); | ||||
| 
 | ||||
|         for( int i=0; i<3; i++ ) { | ||||
|         for( int i = 0; i < 3; i++ ) { | ||||
|             DataSet ds = new DataSet(Nd4j.create(1,10), Nd4j.create(1,10)); | ||||
|             ds.save(new File(f, i + ".bin")); | ||||
|         } | ||||
| @ -110,10 +114,13 @@ public class TestValidation extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testMultiDataSetValidation() throws Exception { | ||||
| 
 | ||||
|         if(Platform.isWindows()) { | ||||
|             //Spark tests don't run on windows | ||||
|             return; | ||||
|         } | ||||
|         File f = folder.newFolder(); | ||||
| 
 | ||||
|         for( int i=0; i<3; i++ ) { | ||||
|         for( int i = 0; i < 3; i++ ) { | ||||
|             MultiDataSet ds = new MultiDataSet(Nd4j.create(1,10), Nd4j.create(1,10)); | ||||
|             ds.save(new File(f, i + ".bin")); | ||||
|         } | ||||
|  | ||||
| @ -21,7 +21,6 @@ | ||||
| package org.deeplearning4j.ui; | ||||
| 
 | ||||
| import org.apache.commons.io.IOUtils; | ||||
| import org.deeplearning4j.plot.BarnesHutTsne; | ||||
| import org.junit.Ignore; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| @ -38,34 +37,6 @@ import java.util.List; | ||||
|  * @author Adam Gibson | ||||
|  */ | ||||
| public class ApiTest { | ||||
|     @Test | ||||
|     @Ignore | ||||
|     public void testUpdateCoords() throws Exception { | ||||
|         Nd4j.factory().setDType(DataType.DOUBLE); | ||||
|         Nd4j.getRandom().setSeed(123); | ||||
|         BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(250).theta(0.5).learningRate(500) | ||||
|                         .useAdaGrad(false).numDimension(2).build(); | ||||
| 
 | ||||
|         File f = Resources.asFile("/deeplearning4j-core/mnist2500_X.txt"); | ||||
|         INDArray data = Nd4j.readNumpy(f.getAbsolutePath(), "   ").get(NDArrayIndex.interval(0, 100), | ||||
|                         NDArrayIndex.interval(0, 784)); | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|         ClassPathResource labels = new ClassPathResource("mnist2500_labels.txt"); | ||||
|         List<String> labelsList = IOUtils.readLines(labels.getInputStream()).subList(0, 100); | ||||
|         b.fit(data); | ||||
|         b.saveAsFile(labelsList, "coords.csv"); | ||||
|         //        String coords =  client.target("http://localhost:8080").path("api").path("update") | ||||
|         //                .request().accept(MediaType.APPLICATION_JSON) | ||||
|         ////                .post(Entity.entity(new UrlResource("http://localhost:8080/api/coords.csv"), MediaType.APPLICATION_JSON)) | ||||
|         //                .readEntity(String.class); | ||||
|         //        ObjectMapper mapper = new ObjectMapper(); | ||||
|         //        List<String> testLines = mapper.readValue(coords,List.class); | ||||
|         //        List<String> lines = IOUtils.readLines(new FileInputStream("coords.csv")); | ||||
|         //        assertEquals(testLines,lines); | ||||
| 
 | ||||
|         throw new RuntimeException("Not implemented"); | ||||
|     } | ||||
| 
 | ||||
| } | ||||
|  | ||||
| @ -42,7 +42,6 @@ import org.deeplearning4j.nn.conf.weightnoise.DropConnect; | ||||
| import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | ||||
| import org.deeplearning4j.nn.weights.WeightInit; | ||||
| import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | ||||
| import org.deeplearning4j.plot.BarnesHutTsne; | ||||
| import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; | ||||
| import org.deeplearning4j.text.sentenceiterator.SentenceIterator; | ||||
| import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; | ||||
| @ -84,7 +83,6 @@ import static org.junit.Assert.fail; | ||||
| @Slf4j | ||||
| public class ManualTests { | ||||
| 
 | ||||
|     private static Logger log = LoggerFactory.getLogger(ManualTests.class); | ||||
| 
 | ||||
|     @Test | ||||
|     public void testLaunch() throws Exception { | ||||
| @ -100,33 +98,7 @@ public class ManualTests { | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Test(timeout = 300000) | ||||
|     public void testTsne() throws Exception { | ||||
|         DataTypeUtil.setDTypeForContext(DataType.DOUBLE); | ||||
|         Nd4j.getRandom().setSeed(123); | ||||
|         BarnesHutTsne b = new BarnesHutTsne.Builder().stopLyingIteration(10).setMaxIter(10).theta(0.5).learningRate(500) | ||||
|                         .useAdaGrad(true).build(); | ||||
| 
 | ||||
|         File f = Resources.asFile("/deeplearning4j-core/mnist2500_X.txt"); | ||||
|         INDArray data = Nd4j.readNumpy(f.getAbsolutePath(), "   ").get(NDArrayIndex.interval(0, 100), | ||||
|                         NDArrayIndex.interval(0, 784)); | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|         ClassPathResource labels = new ClassPathResource("mnist2500_labels.txt"); | ||||
|         List<String> labelsList = IOUtils.readLines(labels.getInputStream()).subList(0, 100); | ||||
|         b.fit(data); | ||||
|         File save = new File(System.getProperty("java.io.tmpdir"), "labels-" + UUID.randomUUID().toString()); | ||||
|         System.out.println("Saved to " + save.getAbsolutePath()); | ||||
|         save.deleteOnExit(); | ||||
|         b.saveAsFile(labelsList, save.getAbsolutePath()); | ||||
| 
 | ||||
|         INDArray output = b.getData(); | ||||
|         System.out.println("Coordinates"); | ||||
| 
 | ||||
|         UIServer server = UIServer.getInstance(); | ||||
|         Thread.sleep(10000000000L); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * This test is for manual execution only, since it's here just to get working CNN and visualize it's layers | ||||
|  | ||||
							
								
								
									
										38
									
								
								deeplearning4j/deeplearning4j-zoo/nd4j-native.properties
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								deeplearning4j/deeplearning4j-zoo/nd4j-native.properties
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,38 @@ | ||||
| # | ||||
| # /* ****************************************************************************** | ||||
| #  * | ||||
| #  * | ||||
| #  * This program and the accompanying materials are made available under the | ||||
| #  * terms of the Apache License, Version 2.0 which is available at | ||||
| #  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
| #  * | ||||
| #  *  See the NOTICE file distributed with this work for additional | ||||
| #  *  information regarding copyright ownership. | ||||
| #  * Unless required by applicable law or agreed to in writing, software | ||||
| #  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
| #  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
| #  * License for the specific language governing permissions and limitations | ||||
| #  * under the License. | ||||
| #  * | ||||
| #  * SPDX-License-Identifier: Apache-2.0 | ||||
| #  ******************************************************************************/ | ||||
| # | ||||
| 
 | ||||
| real.class.double = org.nd4j.linalg.cpu.NDArray | ||||
| shapeinfoprovider = org.nd4j.linalg.cpu.nativecpu.DirectShapeInfoProvider | ||||
| constantsprovider = org.nd4j.linalg.cpu.nativecpu.cache.ConstantBuffersCache | ||||
| affinitymanager = org.nd4j.linalg.cpu.nativecpu.CpuAffinityManager | ||||
| memorymanager = org.nd4j.linalg.cpu.nativecpu.CpuMemoryManager | ||||
| dtype = float | ||||
| blas.ops = org.nd4j.linalg.cpu.nativecpu.BlasWrapper | ||||
| 
 | ||||
| native.ops= org.nd4j.nativeblas.Nd4jCpu | ||||
| ndarrayfactory.class = org.nd4j.linalg.cpu.nativecpu.CpuNDArrayFactory | ||||
| ndarray.order = c | ||||
| resourcemanager_state = false | ||||
| databufferfactory = org.nd4j.linalg.cpu.nativecpu.buffer.DefaultDataBufferFactory | ||||
| workspacemanager = org.nd4j.linalg.cpu.nativecpu.workspace.CpuWorkspaceManager | ||||
| alloc = javacpp | ||||
| opexec= org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner | ||||
| opexec.mode= native | ||||
| random=org.nd4j.linalg.cpu.nativecpu.rng.CpuNativeRandom | ||||
| @ -72,7 +72,7 @@ public abstract class ZooModel<T> implements InstantiableModel { | ||||
| 
 | ||||
|         if (!cachedFile.exists()) { | ||||
|             log.info("Downloading model to " + cachedFile.toString()); | ||||
|             FileUtils.copyURLToFile(new URL(remoteUrl), cachedFile); | ||||
|             FileUtils.copyURLToFile(new URL(remoteUrl), cachedFile,Integer.MAX_VALUE,Integer.MAX_VALUE); | ||||
|         } else { | ||||
|             log.info("Using cached model at " + cachedFile.toString()); | ||||
|         } | ||||
| @ -89,7 +89,7 @@ public abstract class ZooModel<T> implements InstantiableModel { | ||||
|                 log.error("Checksums do not match. Cleaning up files and failing..."); | ||||
|                 cachedFile.delete(); | ||||
|                 throw new IllegalStateException( | ||||
|                                 "Pretrained model file failed checksum. If this error persists, please open an issue at https://github.com/deeplearning4j/deeplearning4j."); | ||||
|                                 "Pretrained model file failed checksum. If this error persists, please open an issue at https://github.com/eclipse/deeplearning4j."); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|  | ||||
| @ -26,6 +26,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; | ||||
| import org.deeplearning4j.nn.transferlearning.TransferLearning; | ||||
| import org.deeplearning4j.nn.weights.WeightInit; | ||||
| import org.deeplearning4j.zoo.model.VGG16; | ||||
| import org.junit.Ignore; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.activations.Activation; | ||||
| import org.nd4j.linalg.dataset.DataSet; | ||||
| @ -33,17 +34,16 @@ import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.lossfunctions.LossFunctions; | ||||
| 
 | ||||
| import java.io.File; | ||||
| 
 | ||||
| @Ignore("Times out too often") | ||||
| public class MiscTests extends BaseDL4JTest { | ||||
| 
 | ||||
|     @Override | ||||
|     public long getTimeoutMilliseconds() { | ||||
|         return 240000L; | ||||
|         return Long.MAX_VALUE; | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testTransferVGG() throws Exception { | ||||
|         //https://github.com/deeplearning4j/deeplearning4j/issues/5167 | ||||
|         DataSet ds = new DataSet(); | ||||
|         ds.setFeatures(Nd4j.create(1, 3, 224, 224)); | ||||
|         ds.setLabels(Nd4j.create(1, 2)); | ||||
|  | ||||
| @ -44,6 +44,7 @@ import java.util.Map; | ||||
| import static org.junit.Assert.assertEquals; | ||||
| 
 | ||||
| @Slf4j | ||||
| @Ignore("Times out too often") | ||||
| public class TestDownload extends BaseDL4JTest { | ||||
| 
 | ||||
|     @Override | ||||
|  | ||||
| @ -54,6 +54,7 @@ import static org.junit.Assert.assertEquals; | ||||
| import static org.junit.Assert.assertTrue; | ||||
| 
 | ||||
| @Slf4j | ||||
| @Ignore("Times out too often") | ||||
| public class TestImageNet extends BaseDL4JTest { | ||||
| 
 | ||||
|     @Override | ||||
|  | ||||
| @ -52,6 +52,7 @@ import static org.junit.Assert.assertArrayEquals; | ||||
| import static org.junit.Assume.assumeTrue; | ||||
| 
 | ||||
| @Slf4j | ||||
| @Ignore("Times out too often") | ||||
| public class TestInstantiation extends BaseDL4JTest { | ||||
| 
 | ||||
|     protected static void ignoreIfCuda(){ | ||||
|  | ||||
| @ -59,7 +59,6 @@ | ||||
|         <module>deeplearning4j-modelexport-solr</module> | ||||
|         <module>deeplearning4j-zoo</module> | ||||
|         <module>deeplearning4j-data</module> | ||||
|         <module>deeplearning4j-manifold</module> | ||||
|         <module>dl4j-integration-tests</module> | ||||
|         <module>deeplearning4j-common</module> | ||||
|         <module>deeplearning4j-common-tests</module> | ||||
| @ -231,7 +230,7 @@ | ||||
|                          --> | ||||
|                         <useSystemClassLoader>true</useSystemClassLoader> | ||||
|                         <useManifestOnlyJar>false</useManifestOnlyJar> | ||||
|                         <argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g</argLine> | ||||
|                         <argLine> -Dfile.encoding=UTF-8 -Xmx8g -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine> | ||||
|                         <includes> | ||||
|                             <!-- Default setting only runs tests that start/end with "Test" --> | ||||
|                             <include>*.java</include> | ||||
| @ -292,6 +291,51 @@ | ||||
|                     <scope>test</scope> | ||||
|                 </dependency> | ||||
|             </dependencies> | ||||
|             <build> | ||||
|                 <plugins> | ||||
|                     <plugin> | ||||
|                         <groupId>org.apache.maven.plugins</groupId> | ||||
|                         <artifactId>maven-surefire-plugin</artifactId> | ||||
|                         <inherited>true</inherited> | ||||
|                         <dependencies> | ||||
|                             <dependency> | ||||
|                                 <groupId>org.nd4j</groupId> | ||||
|                                 <artifactId>nd4j-native</artifactId> | ||||
|                                 <version>${project.version}</version> | ||||
|                             </dependency> | ||||
|                         </dependencies> | ||||
|                         <configuration> | ||||
|                             <environmentVariables> | ||||
| 
 | ||||
|                             </environmentVariables> | ||||
|                             <testSourceDirectory>src/test/java</testSourceDirectory> | ||||
|                             <includes> | ||||
|                                 <include>*.java</include> | ||||
|                                 <include>**/*.java</include> | ||||
|                                 <include>**/Test*.java</include> | ||||
|                                 <include>**/*Test.java</include> | ||||
|                                 <include>**/*TestCase.java</include> | ||||
|                             </includes> | ||||
|                             <junitArtifactName>junit:junit</junitArtifactName> | ||||
|                             <systemPropertyVariables> | ||||
|                                 <org.nd4j.linalg.defaultbackend> | ||||
|                                     org.nd4j.linalg.cpu.nativecpu.CpuBackend | ||||
|                                 </org.nd4j.linalg.defaultbackend> | ||||
|                                 <org.nd4j.linalg.tests.backendstorun> | ||||
|                                     org.nd4j.linalg.cpu.nativecpu.CpuBackend | ||||
|                                 </org.nd4j.linalg.tests.backendstorun> | ||||
|                             </systemPropertyVariables> | ||||
|                             <!-- | ||||
|                                 Maximum heap size was set to 8g, as a minimum required value for tests run. | ||||
|                                 Depending on a build machine, default value is not always enough. | ||||
| 
 | ||||
|                                 For testing large zoo models, this may not be enough (so comment it out). | ||||
|                             --> | ||||
|                             <argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine> | ||||
|                         </configuration> | ||||
|                     </plugin> | ||||
|                 </plugins> | ||||
|             </build> | ||||
|         </profile> | ||||
|         <!-- For running unit tests with nd4j-cuda-8.0: "mvn clean test -P test-nd4j-cuda-8.0" --> | ||||
|         <profile> | ||||
| @ -314,6 +358,47 @@ | ||||
|                 </dependency> | ||||
|             </dependencies> | ||||
|             <!-- Default to ALL modules here, unlike nd4j-native --> | ||||
|             <build> | ||||
|                 <plugins> | ||||
|                     <plugin> | ||||
|                         <groupId>org.apache.maven.plugins</groupId> | ||||
|                         <artifactId>maven-surefire-plugin</artifactId> | ||||
|                         <dependencies> | ||||
|                             <dependency> | ||||
|                                 <groupId>org.apache.maven.surefire</groupId> | ||||
|                                 <artifactId>surefire-junit47</artifactId> | ||||
|                                 <version>2.19.1</version> | ||||
|                             </dependency> | ||||
|                         </dependencies> | ||||
|                         <configuration> | ||||
|                             <environmentVariables> | ||||
|                             </environmentVariables> | ||||
|                             <testSourceDirectory>src/test/java</testSourceDirectory> | ||||
|                             <includes> | ||||
|                                 <include>*.java</include> | ||||
|                                 <include>**/*.java</include> | ||||
|                                 <include>**/Test*.java</include> | ||||
|                                 <include>**/*Test.java</include> | ||||
|                                 <include>**/*TestCase.java</include> | ||||
|                             </includes> | ||||
|                             <junitArtifactName>junit:junit</junitArtifactName> | ||||
|                             <systemPropertyVariables> | ||||
|                                 <org.nd4j.linalg.defaultbackend> | ||||
|                                     org.nd4j.linalg.jcublas.JCublasBackend | ||||
|                                 </org.nd4j.linalg.defaultbackend> | ||||
|                                 <org.nd4j.linalg.tests.backendstorun> | ||||
|                                     org.nd4j.linalg.jcublas.JCublasBackend | ||||
|                                 </org.nd4j.linalg.tests.backendstorun> | ||||
|                             </systemPropertyVariables> | ||||
|                             <!-- | ||||
|                                 Maximum heap size was set to 6g, as a minimum required value for tests run. | ||||
|                                 Depending on a build machine, default value is not always enough. | ||||
|                             --> | ||||
|                             <argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine> | ||||
|                         </configuration> | ||||
|                     </plugin> | ||||
|                 </plugins> | ||||
|             </build> | ||||
|         </profile> | ||||
|     </profiles> | ||||
| </project> | ||||
|  | ||||
| @ -36,7 +36,7 @@ do | ||||
|         # unknown option | ||||
|         ;; | ||||
|     esac | ||||
|      | ||||
| 
 | ||||
|     if [[ $# > 0 ]]; then | ||||
|         shift # past argument or value | ||||
|     fi | ||||
| @ -59,6 +59,6 @@ fi | ||||
| unameOut="$(uname)" | ||||
| echo "$OSTYPE" | ||||
| 
 | ||||
| ../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests.exe | ||||
| ../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests | ||||
| # Workaround to fix posix path conversion problem on Windows (http://mingw.org/wiki/Posix_path_conversion) | ||||
| #[ -f "${GTEST_OUTPUT#*:}" ] && cp -a surefire-reports/ ../target && rm -rf surefire-reports/ | ||||
| [ -f "${GTEST_OUTPUT#*:}" ] && cp -a surefire-reports/ ../target && rm -rf surefire-reports/ | ||||
|  | ||||
| @ -881,7 +881,7 @@ public class InferenceSession extends AbstractSession<INDArray, Pair<SameDiffOp, | ||||
|             for (int i = 0; i < outShape.size(); i++) { | ||||
|                 LongShapeDescriptor reqShape = outShape.get(i); | ||||
| 
 | ||||
|                 //Issue: many ops have multiple valid output datatypes, and output shape calc can't at present know which: https://github.com/deeplearning4j/deeplearning4j/issues/6872 | ||||
|                 //Issue: many ops have multiple valid output datatypes, and output shape calc can't at present know which: https://github.com/eclipse/deeplearning4j/issues/6872 | ||||
|                 //As a workaround, we'll use the output variable datatype instead. | ||||
|                 DataType dt = sameDiff.getVariable(outNames[i]).dataType(); | ||||
|                 DataType currDT = reqShape.dataType(); | ||||
|  | ||||
| @ -189,7 +189,7 @@ public class ROCBinary extends BaseEvaluation<ROCBinary> { | ||||
|                     } | ||||
|                 } | ||||
| 
 | ||||
|                 //TODO Temporary workaround for: https://github.com/deeplearning4j/deeplearning4j/issues/7102 | ||||
|                 //TODO Temporary workaround for: https://github.com/eclipse/deeplearning4j/issues/7102 | ||||
|                 if(prob.isView()) | ||||
|                     prob = prob.dup(); | ||||
|                 if(label.isView()) | ||||
|  | ||||
| @ -221,7 +221,7 @@ public class ROCMultiClass extends BaseEvaluation<ROCMultiClass> { | ||||
|         for (int i = 0; i < n; i++) { | ||||
|             INDArray prob = predictions2d.getColumn(i, true); //Probability of class i | ||||
|             INDArray label = labels2d.getColumn(i, true); | ||||
|             //Workaround for: https://github.com/deeplearning4j/deeplearning4j/issues/7305 | ||||
|             //Workaround for: https://github.com/eclipse/deeplearning4j/issues/7305 | ||||
|             if(prob.rank() == 0) | ||||
|                 prob = prob.reshape(1,1); | ||||
|             if(label.rank() == 0) | ||||
|  | ||||
| @ -73,7 +73,7 @@ public class Min extends BaseDynamicTransformOp { | ||||
| 
 | ||||
|     @Override | ||||
|     public List<SDVariable> doDiff(List<SDVariable> f1) { | ||||
|         //TODO Switch to minimum_bp op - https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/broadcastable/minimum.cpp | ||||
|         //TODO Switch to minimum_bp op - https://github.com/eclipse/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/broadcastable/minimum.cpp | ||||
|         SDVariable min = outputVariables()[0]; | ||||
|         SDVariable eq1 = sameDiff.eq(larg(), min).castTo(arg(0).dataType()); | ||||
|         SDVariable eq2 = sameDiff.eq(rarg(), min).castTo(arg(1).dataType()); | ||||
|  | ||||
| @ -56,7 +56,7 @@ public class Pow extends DynamicCustomOp { | ||||
| 
 | ||||
|     @Override | ||||
|     public List<SDVariable> doDiff(List<SDVariable> f1) { | ||||
|         //TODO: replace this with discrete op once available: https://github.com/deeplearning4j/deeplearning4j/issues/7461 | ||||
|         //TODO: replace this with discrete op once available: https://github.com/eclipse/deeplearning4j/issues/7461 | ||||
|         //If y=a^b, then: | ||||
|         //dL/da = b*a^(b-1) * dL/dy | ||||
|         //dL/db = a^b * log(a) * dL/dy | ||||
|  | ||||
| @ -84,7 +84,7 @@ public class RandomStandardNormal extends DynamicCustomOp { | ||||
|     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ | ||||
|         Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes); | ||||
|         //Input data type specifies the shape; output data type should be any float | ||||
|         //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 | ||||
|         //TODO MAKE CONFIGUREABLE - https://github.com/eclipse/deeplearning4j/issues/6854 | ||||
|         return Collections.singletonList(DataType.FLOAT); | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -65,7 +65,7 @@ public class RandomBernoulli extends DynamicCustomOp { | ||||
|     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ | ||||
|         Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes); | ||||
|         //Input data type specifies the shape; output data type should be any float | ||||
|         //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 | ||||
|         //TODO MAKE CONFIGUREABLE - https://github.com/eclipse/deeplearning4j/issues/6854 | ||||
|         return Collections.singletonList(DataType.FLOAT); | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -80,7 +80,7 @@ public class RandomExponential extends DynamicCustomOp { | ||||
|     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ | ||||
|         Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes); | ||||
|         //Input data type specifies the shape; output data type should be any float | ||||
|         //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 | ||||
|         //TODO MAKE CONFIGUREABLE - https://github.com/eclipse/deeplearning4j/issues/6854 | ||||
|         return Collections.singletonList(DataType.FLOAT); | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -66,7 +66,7 @@ public class RandomNormal extends DynamicCustomOp { | ||||
|     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ | ||||
|         Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes); | ||||
|         //Input data type specifies the shape; output data type should be any float | ||||
|         //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 | ||||
|         //TODO MAKE CONFIGUREABLE - https://github.com/eclipse/deeplearning4j/issues/6854 | ||||
|         return Collections.singletonList(DataType.FLOAT); | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -118,7 +118,7 @@ public class BernoulliDistribution extends BaseRandomOp { | ||||
|     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ | ||||
|         Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes); | ||||
|         //Input data type specifies the shape; output data type should be any float | ||||
|         //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 | ||||
|         //TODO MAKE CONFIGUREABLE - https://github.com/eclipse/deeplearning4j/issues/6854 | ||||
|         return Collections.singletonList(dataType); | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -140,7 +140,7 @@ public class BinomialDistribution extends BaseRandomOp { | ||||
|     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ | ||||
|         Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes); | ||||
|         //Input data type specifies the shape; output data type should be any float | ||||
|         //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 | ||||
|         //TODO MAKE CONFIGUREABLE - https://github.com/eclipse/deeplearning4j/issues/6854 | ||||
|         return Collections.singletonList(DataType.DOUBLE); | ||||
|     } | ||||
| 
 | ||||
|  | ||||
| @ -91,28 +91,28 @@ public class Linspace extends BaseRandomOp { | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray x(){ | ||||
|         //Workaround/hack for: https://github.com/deeplearning4j/deeplearning4j/issues/6723 | ||||
|         //Workaround/hack for: https://github.com/eclipse/deeplearning4j/issues/6723 | ||||
|         //If x or y is present, can't execute this op properly (wrong signature is used) | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray y(){ | ||||
|         //Workaround/hack for: https://github.com/deeplearning4j/deeplearning4j/issues/6723 | ||||
|         //Workaround/hack for: https://github.com/eclipse/deeplearning4j/issues/6723 | ||||
|         //If x or y is present, can't execute this op properly (wrong signature is used) | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setX(INDArray x){ | ||||
|         //Workaround/hack for: https://github.com/deeplearning4j/deeplearning4j/issues/6723 | ||||
|         //Workaround/hack for: https://github.com/eclipse/deeplearning4j/issues/6723 | ||||
|         //If x or y is present, can't execute this op properly (wrong signature is used) | ||||
|         this.x = null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setY(INDArray y){ | ||||
|         //Workaround for: https://github.com/deeplearning4j/deeplearning4j/issues/6723 | ||||
|         //Workaround for: https://github.com/eclipse/deeplearning4j/issues/6723 | ||||
|         //If x or y is present, can't execute this op properly (wrong signature is used) | ||||
|         this.y = null; | ||||
|     } | ||||
|  | ||||
| @ -139,7 +139,7 @@ public class TruncatedNormalDistribution extends BaseRandomOp { | ||||
|     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ | ||||
|         Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes); | ||||
|         //Input data type specifies the shape; output data type should be any float | ||||
|         //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 | ||||
|         //TODO MAKE CONFIGUREABLE - https://github.com/eclipse/deeplearning4j/issues/6854 | ||||
|         return Collections.singletonList(DataType.DOUBLE); | ||||
|     } | ||||
| 
 | ||||
|  | ||||
| @ -110,7 +110,7 @@ public class UniformDistribution extends BaseRandomOp { | ||||
|     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ | ||||
|         Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes); | ||||
|         //Input data type specifies the shape; output data type should be any float | ||||
|         //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 | ||||
|         //TODO MAKE CONFIGUREABLE - https://github.com/eclipse/deeplearning4j/issues/6854 | ||||
|         return Collections.singletonList(dataType); | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -80,7 +80,7 @@ public class VersionInfo { | ||||
| 
 | ||||
|     public VersionInfo(URI uri) throws IOException { | ||||
|         //Can't use new File(uri).getPath() for URIs pointing to resources in JARs | ||||
|         //But URI.toString() returns "%2520" instead of spaces in path - https://github.com/deeplearning4j/deeplearning4j/issues/6056 | ||||
|         //But URI.toString() returns "%2520" instead of spaces in path - https://github.com/eclipse/deeplearning4j/issues/6056 | ||||
|         String path = uri.toString().replaceAll(HTML_SPACE, " "); | ||||
|         int idxOf = path.lastIndexOf('/'); | ||||
|         idxOf = Math.max(idxOf, path.lastIndexOf('\\')); | ||||
|  | ||||
| @ -141,7 +141,7 @@ | ||||
|                             Maximum heap size was set to 8g, as a minimum required value for tests run. | ||||
|                             Depending on a build machine, default value is not always enough. | ||||
|                         --> | ||||
|                         <argLine>-Dorg.bytedeco.javacpp.logger.debug=true -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine> | ||||
|                         <argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine> | ||||
|                     </configuration> | ||||
|                 </plugin> | ||||
|                 <plugin> | ||||
|  | ||||
| @ -1,316 +0,0 @@ | ||||
| <?xml version="1.0" encoding="UTF-8"?> | ||||
| <!-- | ||||
|   ~ /* ****************************************************************************** | ||||
|   ~  * | ||||
|   ~  * | ||||
|   ~  * This program and the accompanying materials are made available under the | ||||
|   ~  * terms of the Apache License, Version 2.0 which is available at | ||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|   ~  * | ||||
|   ~  *  See the NOTICE file distributed with this work for additional | ||||
|   ~  *  information regarding copyright ownership. | ||||
|   ~  * Unless required by applicable law or agreed to in writing, software | ||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|   ~  * License for the specific language governing permissions and limitations | ||||
|   ~  * under the License. | ||||
|   ~  * | ||||
|   ~  * SPDX-License-Identifier: Apache-2.0 | ||||
|   ~  ******************************************************************************/ | ||||
|   --> | ||||
| 
 | ||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | ||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||||
| 
 | ||||
|     <modelVersion>4.0.0</modelVersion> | ||||
| 
 | ||||
|     <parent> | ||||
|         <groupId>org.nd4j</groupId> | ||||
|         <artifactId>nd4j-backends</artifactId> | ||||
|         <version>1.0.0-SNAPSHOT</version> | ||||
|     </parent> | ||||
| 
 | ||||
|     <artifactId>nd4j-tests-tensorflow</artifactId> | ||||
| 
 | ||||
|     <name>nd4j-tests-tensorflow</name> | ||||
| 
 | ||||
|     <properties> | ||||
|         <maven.compiler.source>1.8</maven.compiler.source> | ||||
|         <maven.compiler.target>1.8</maven.compiler.target> | ||||
|         <scala.binary.version>2.11</scala.binary.version> | ||||
|         <maven.compiler.testTarget>1.8</maven.compiler.testTarget> | ||||
|         <maven.compiler.testSource>1.8</maven.compiler.testSource> | ||||
|     </properties> | ||||
| 
 | ||||
|     <dependencies> | ||||
|         <dependency> | ||||
|             <groupId>org.nd4j</groupId> | ||||
|             <artifactId>nd4j-tensorflow</artifactId> | ||||
|             <version>${project.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>junit</groupId> | ||||
|             <artifactId>junit</artifactId> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>ch.qos.logback</groupId> | ||||
|             <artifactId>logback-classic</artifactId> | ||||
|             <scope>test</scope> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.nd4j</groupId> | ||||
|             <artifactId>nd4j-common-tests</artifactId> | ||||
|             <version>${project.version}</version> | ||||
|             <scope>test</scope> | ||||
|         </dependency> | ||||
|     </dependencies> | ||||
| 
 | ||||
|     <build> | ||||
|         <testSourceDirectory>${test.root}</testSourceDirectory> | ||||
|         <plugins> | ||||
|             <plugin> | ||||
|                 <groupId>org.apache.maven.plugins</groupId> | ||||
|                 <artifactId>maven-enforcer-plugin</artifactId> | ||||
|                 <executions> | ||||
|                     <execution> | ||||
|                         <phase>test</phase> | ||||
|                         <id>enforce-test-resources</id> | ||||
|                         <goals> | ||||
|                             <goal>enforce</goal> | ||||
|                         </goals> | ||||
|                         <configuration> | ||||
|                             <skip>${skipTestResourceEnforcement}</skip> | ||||
|                             <rules> | ||||
|                                 <requireActiveProfile> | ||||
|                                     <profiles>nd4j-tf-cpu,nd4j-tf-gpu</profiles> | ||||
|                                     <all>false</all> | ||||
|                                 </requireActiveProfile> | ||||
|                             </rules> | ||||
|                             <fail>true</fail> | ||||
|                         </configuration> | ||||
|                     </execution> | ||||
|                 </executions> | ||||
|             </plugin> | ||||
|         </plugins> | ||||
|     </build> | ||||
| 
 | ||||
|     <profiles> | ||||
|         <profile> | ||||
|             <id>testresources</id> | ||||
|             <activation> | ||||
|                 <activeByDefault>true</activeByDefault> | ||||
|             </activation> | ||||
|         </profile> | ||||
|         <profile> | ||||
|             <id>tf-cpu</id> | ||||
|             <dependencies> | ||||
|                 <dependency> | ||||
|                     <groupId>org.bytedeco</groupId> | ||||
|                     <artifactId>tensorflow-platform</artifactId> | ||||
|                     <version>${tensorflow.javacpp.version}</version> | ||||
|                 </dependency> | ||||
|             </dependencies> | ||||
|         </profile> | ||||
|         <profile> | ||||
|             <id>tf-gpu</id> | ||||
|             <dependencies> | ||||
|                 <dependency> | ||||
|                     <groupId>org.bytedeco</groupId> | ||||
|                     <artifactId>tensorflow</artifactId> | ||||
|                     <version>${tensorflow.javacpp.version}</version> | ||||
|                     <classifier>linux-x86_64-gpu</classifier> | ||||
|                 </dependency> | ||||
|                 <dependency> | ||||
|                     <groupId>org.bytedeco</groupId> | ||||
|                     <artifactId>tensorflow</artifactId> | ||||
|                     <version>${tensorflow.javacpp.version}</version> | ||||
|                     <classifier>windows-x86_64-gpu</classifier> | ||||
|                 </dependency> | ||||
|             </dependencies> | ||||
|         </profile> | ||||
|         <profile> | ||||
|             <id>nd4j-tf-gpu</id> | ||||
|             <properties> | ||||
|                 <test.root>src/test/gpujava</test.root> | ||||
|             </properties> | ||||
|             <build> | ||||
|                 <plugins> | ||||
|                     <plugin> | ||||
|                         <groupId>org.apache.maven.plugins</groupId> | ||||
|                         <artifactId>maven-failsafe-plugin</artifactId> | ||||
|                         <version>2.18</version> | ||||
|                         <executions> | ||||
|                             <!-- | ||||
|                                 Invokes both the integration-test and the verify goals of the | ||||
|                                 Failsafe Maven plugin | ||||
|                             --> | ||||
|                             <execution> | ||||
|                                 <id>integration-tests</id> | ||||
|                                 <phase>test</phase> | ||||
|                                 <goals> | ||||
|                                     <goal>integration-test</goal> | ||||
|                                     <goal>verify</goal> | ||||
|                                 </goals> | ||||
|                                 <configuration> | ||||
|                                     <!-- | ||||
|                                         Skips integration tests if the value of skip.integration.tests | ||||
|                                         property is true | ||||
|                                     --> | ||||
|                                     <skipTests>false</skipTests> | ||||
|                                 </configuration> | ||||
|                             </execution> | ||||
|                         </executions> | ||||
|                     </plugin> | ||||
|                     <plugin> | ||||
|                         <groupId>org.codehaus.mojo</groupId> | ||||
|                         <artifactId>build-helper-maven-plugin</artifactId> | ||||
|                         <version>1.9.1</version> | ||||
|                         <executions> | ||||
|                             <execution> | ||||
|                                 <id>add-integration-test-sources</id> | ||||
|                                 <phase>test-compile</phase> | ||||
|                                 <goals> | ||||
|                                     <goal>add-test-source</goal> | ||||
|                                 </goals> | ||||
|                                 <configuration> | ||||
|                                     <!-- Configures the source directory of our integration tests --> | ||||
|                                     <sources> | ||||
|                                         <source>src/test/gpujava</source> | ||||
|                                     </sources> | ||||
|                                 </configuration> | ||||
|                             </execution> | ||||
|                         </executions> | ||||
|                     </plugin> | ||||
|                     <plugin> | ||||
|                         <groupId>org.apache.maven.plugins</groupId> | ||||
|                         <artifactId>maven-compiler-plugin</artifactId> | ||||
|                         <version>${maven-compiler-plugin.version}</version> | ||||
|                         <configuration> | ||||
|                             <source>1.8</source> | ||||
|                             <target>1.8</target> | ||||
|                         </configuration> | ||||
|                     </plugin> | ||||
|                     <plugin> | ||||
|                         <groupId>org.apache.maven.plugins</groupId> | ||||
|                         <artifactId>maven-surefire-plugin</artifactId> | ||||
|                         <version>2.19.1</version> | ||||
|                         <dependencies> | ||||
|                             <dependency> | ||||
|                                 <groupId>org.apache.maven.surefire</groupId> | ||||
|                                 <artifactId>surefire-junit47</artifactId> | ||||
|                                 <version>2.19.1</version> | ||||
|                             </dependency> | ||||
|                         </dependencies> | ||||
|                         <configuration> | ||||
|                             <testSourceDirectory>${project.basedir}/src/test/gpujava | ||||
|                             </testSourceDirectory> | ||||
|                             <includes> | ||||
|                                 <include>**/*.java</include> | ||||
|                             </includes> | ||||
|                             <systemPropertyVariables> | ||||
|                                 <org.nd4j.linalg.defaultbackend> | ||||
|                                     org.nd4j.linalg.jcublas.JCublasBackend | ||||
|                                 </org.nd4j.linalg.defaultbackend> | ||||
|                                 <org.nd4j.linalg.tests.backendstorun> | ||||
|                                     org.nd4j.linalg.jcublas.JCublasBackend | ||||
|                                 </org.nd4j.linalg.tests.backendstorun> | ||||
|                             </systemPropertyVariables> | ||||
|                             <!-- | ||||
|                                 Maximum heap size was set to 6g, as a minimum required value for tests run. | ||||
|                                 Depending on a build machine, default value is not always enough. | ||||
|                             --> | ||||
|                             <skip>false</skip> | ||||
|                             <argLine>-Xmx6g -Dfile.encoding=UTF-8</argLine> | ||||
|                         </configuration> | ||||
|                     </plugin> | ||||
|                 </plugins> | ||||
|             </build> | ||||
|             <dependencies> | ||||
|                 <dependency> | ||||
|                     <groupId>org.nd4j</groupId> | ||||
|                     <artifactId>nd4j-cuda-11.0</artifactId> | ||||
|                     <version>${project.version}</version> | ||||
|                 </dependency> | ||||
|                 <dependency> | ||||
|                     <groupId>org.bytedeco</groupId> | ||||
|                     <artifactId>tensorflow</artifactId> | ||||
|                     <version>${tensorflow.javacpp.version}</version> | ||||
|                     <classifier>linux-x86_64-gpu</classifier> | ||||
|                 </dependency> | ||||
|                 <dependency> | ||||
|                     <groupId>org.bytedeco</groupId> | ||||
|                     <artifactId>tensorflow</artifactId> | ||||
|                     <version>${tensorflow.javacpp.version}</version> | ||||
|                     <classifier>windows-x86_64-gpu</classifier> | ||||
|                 </dependency> | ||||
|             </dependencies> | ||||
|         </profile> | ||||
|         <profile> | ||||
|             <id>nd4j-tf-cpu</id> | ||||
|             <properties> | ||||
|                 <test.root>src/test/cpujava</test.root> | ||||
|             </properties> | ||||
|             <build> | ||||
|                 <plugins> | ||||
|                     <plugin> | ||||
|                         <groupId>org.apache.maven.plugins</groupId> | ||||
|                         <artifactId>maven-compiler-plugin</artifactId> | ||||
|                         <version>${maven-compiler-plugin.version}</version> | ||||
|                         <configuration> | ||||
|                             <testSource>1.8</testSource> | ||||
|                             <source>1.8</source> | ||||
|                             <target>1.8</target> | ||||
|                         </configuration> | ||||
|                     </plugin> | ||||
|                     <plugin> | ||||
|                         <groupId>org.apache.maven.plugins</groupId> | ||||
|                         <artifactId>maven-surefire-plugin</artifactId> | ||||
|                         <version>2.19.1</version> | ||||
|                         <dependencies> | ||||
|                             <dependency> | ||||
|                                 <groupId>org.apache.maven.surefire</groupId> | ||||
|                                 <artifactId>surefire-junit47</artifactId> | ||||
|                                 <version>2.19.1</version> | ||||
|                             </dependency> | ||||
|                         </dependencies> | ||||
|                         <configuration> | ||||
|                             <testSourceDirectory>${project.basedir}/src/test/cpujava | ||||
|                             </testSourceDirectory> | ||||
|                             <includes> | ||||
|                                 <include>**/*.java</include> | ||||
|                             </includes> | ||||
|                             <systemPropertyVariables> | ||||
|                                 <org.nd4j.linalg.defaultbackend> | ||||
|                                     org.nd4j.linalg.cpu.nativecpu.CpuBackend | ||||
|                                 </org.nd4j.linalg.defaultbackend> | ||||
|                                 <org.nd4j.linalg.tests.backendstorun> | ||||
|                                     org.nd4j.linalg.cpu.nativecpu.CpuBackend | ||||
|                                 </org.nd4j.linalg.tests.backendstorun> | ||||
|                             </systemPropertyVariables> | ||||
|                             <!-- | ||||
|                                 Maximum heap size was set to 6g, as a minimum required value for tests run. | ||||
|                                 Depending on a build machine, default value is not always enough. | ||||
|                             --> | ||||
|                             <argLine>-Xmx6g -Dfile.encoding=UTF-8</argLine> | ||||
|                             <skipTests>false</skipTests> | ||||
|                             <skip>false</skip> | ||||
|                         </configuration> | ||||
|                     </plugin> | ||||
|                 </plugins> | ||||
|             </build> | ||||
|             <dependencies> | ||||
|                 <dependency> | ||||
|                     <groupId>org.nd4j</groupId> | ||||
|                     <artifactId>nd4j-native</artifactId> | ||||
|                     <version>${project.version}</version> | ||||
|                 </dependency> | ||||
|                 <dependency> | ||||
|                     <groupId>org.bytedeco</groupId> | ||||
|                     <artifactId>tensorflow-platform</artifactId> | ||||
|                     <version>${tensorflow.javacpp.version}</version> | ||||
|                 </dependency> | ||||
|             </dependencies> | ||||
|         </profile> | ||||
|     </profiles> | ||||
| </project> | ||||
| @ -1,193 +0,0 @@ | ||||
| /* ****************************************************************************** | ||||
|  * | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  *  See the NOTICE file distributed with this work for additional | ||||
|  *  information regarding copyright ownership. | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| package org.nd4j.tensorflow.conversion; | ||||
| 
 | ||||
| import junit.framework.TestCase; | ||||
| import org.apache.commons.io.FileUtils; | ||||
| import org.apache.commons.io.IOUtils; | ||||
| import org.bytedeco.tensorflow.TF_Tensor; | ||||
| import org.junit.Ignore; | ||||
| import org.junit.Rule; | ||||
| import org.junit.Test; | ||||
| import org.junit.rules.TemporaryFolder; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.resources.Resources; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.shade.protobuf.util.JsonFormat; | ||||
| import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner; | ||||
| import org.nd4j.tensorflow.conversion.graphrunner.SavedModelConfig; | ||||
| import org.tensorflow.framework.ConfigProto; | ||||
| import org.tensorflow.framework.GPUOptions; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.util.Arrays; | ||||
| import java.util.LinkedHashMap; | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
| 
 | ||||
| import static org.junit.Assert.assertEquals; | ||||
| import static org.junit.Assert.assertNotNull; | ||||
| 
 | ||||
| public class GraphRunnerTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @Override | ||||
|     public DataType getDataType() { | ||||
|         return DataType.FLOAT; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public DataType getDefaultFPDataType() { | ||||
|         return DataType.FLOAT; | ||||
|     } | ||||
| 
 | ||||
|     public static ConfigProto getConfig(){ | ||||
|         String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); | ||||
|         if("CUDA".equalsIgnoreCase(backend)) { | ||||
|             org.tensorflow.framework.ConfigProto configProto = org.tensorflow.framework.ConfigProto.getDefaultInstance(); | ||||
|             ConfigProto.Builder b = configProto.toBuilder().addDeviceFilters(TensorflowConversion.defaultDeviceForThread()); | ||||
|             return b.setGpuOptions(GPUOptions.newBuilder() | ||||
|                     .setAllowGrowth(true) | ||||
|                     .setPerProcessGpuMemoryFraction(0.5) | ||||
|                     .build()).build(); | ||||
|         } | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testGraphRunner() throws Exception { | ||||
|         List<String> inputs = Arrays.asList("input_0","input_1"); | ||||
|         byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream()); | ||||
| 
 | ||||
|         try(GraphRunner graphRunner = GraphRunner.builder().graphBytes(content).inputNames(inputs).sessionOptionsConfigProto(getConfig()).build()) { | ||||
|             runGraphRunnerTest(graphRunner); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testGraphRunnerFilePath() throws Exception { | ||||
|         List<String> inputs = Arrays.asList("input_0","input_1"); | ||||
|         byte[] content = FileUtils.readFileToByteArray(Resources.asFile("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb")); | ||||
| 
 | ||||
|         try(GraphRunner graphRunner = GraphRunner.builder().graphBytes(content).inputNames(inputs).sessionOptionsConfigProto(getConfig()).build()) { | ||||
|             runGraphRunnerTest(graphRunner); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testInputOutputResolution() throws Exception { | ||||
|         ClassPathResource lenetPb = new ClassPathResource("tf_graphs/lenet_frozen.pb"); | ||||
|         byte[] content = IOUtils.toByteArray(lenetPb.getInputStream()); | ||||
|         List<String> inputs = Arrays.asList("Reshape/tensor"); | ||||
|         try(GraphRunner graphRunner = GraphRunner.builder().graphBytes(content).inputNames(inputs).sessionOptionsConfigProto(getConfig()).build()) { | ||||
|             assertEquals(1, graphRunner.getInputOrder().size()); | ||||
|             assertEquals(1, graphRunner.getOutputOrder().size()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Test @Ignore   //Ignored 2019/02/05: ssd_inception_v2_coco_2019_01_28 does not exist in test resources | ||||
|     public void testMultiOutputGraph() throws Exception { | ||||
|         List<String> inputs = Arrays.asList("image_tensor"); | ||||
|         byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/examples/ssd_inception_v2_coco_2018_01_28/frozen_inference_graph.pb").getInputStream()); | ||||
|         try(GraphRunner graphRunner = GraphRunner.builder().graphBytes(content).inputNames(inputs).sessionOptionsConfigProto(getConfig()).build()) { | ||||
|             String[] outputs = new String[]{"detection_boxes", "detection_scores", "detection_classes", "num_detections"}; | ||||
| 
 | ||||
|             assertEquals(1, graphRunner.getInputOrder().size()); | ||||
|             System.out.println(graphRunner.getOutputOrder()); | ||||
|             assertEquals(4, graphRunner.getOutputOrder().size()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     private void runGraphRunnerTest(GraphRunner graphRunner) throws Exception { | ||||
|         String json = graphRunner.sessionOptionsToJson(); | ||||
|         if( json != null ) { | ||||
|             org.tensorflow.framework.ConfigProto.Builder builder = org.tensorflow.framework.ConfigProto.newBuilder(); | ||||
|             JsonFormat.parser().merge(json, builder); | ||||
|             org.tensorflow.framework.ConfigProto build = builder.build(); | ||||
|             assertEquals(build,graphRunner.getSessionOptionsConfigProto()); | ||||
|         } | ||||
|         assertNotNull(graphRunner.getInputOrder()); | ||||
|         assertNotNull(graphRunner.getOutputOrder()); | ||||
| 
 | ||||
| 
 | ||||
|         org.tensorflow.framework.ConfigProto configProto1 = json == null ? null : GraphRunner.fromJson(json); | ||||
| 
 | ||||
|         assertEquals(graphRunner.getSessionOptionsConfigProto(),configProto1); | ||||
|         assertEquals(2,graphRunner.getInputOrder().size()); | ||||
|         assertEquals(1,graphRunner.getOutputOrder().size()); | ||||
| 
 | ||||
|         INDArray input1 = Nd4j.linspace(1,4,4).reshape(4); | ||||
|         INDArray input2 = Nd4j.linspace(1,4,4).reshape(4); | ||||
| 
 | ||||
|         Map<String,INDArray> inputs = new LinkedHashMap<>(); | ||||
|         inputs.put("input_0",input1); | ||||
|         inputs.put("input_1",input2); | ||||
| 
 | ||||
|         for(int i = 0; i < 2; i++) { | ||||
|             Map<String,INDArray> outputs = graphRunner.run(inputs); | ||||
| 
 | ||||
|             INDArray assertion = input1.add(input2); | ||||
|             assertEquals(assertion,outputs.get("output")); | ||||
|         } | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Rule | ||||
|     public TemporaryFolder testDir = new TemporaryFolder(); | ||||
| 
 | ||||
|     @Test | ||||
|     public void testGraphRunnerSavedModel() throws Exception { | ||||
|         File f = testDir.newFolder("test"); | ||||
|         new ClassPathResource("/tf_saved_models/saved_model_counter/00000123/").copyDirectory(f); | ||||
|         SavedModelConfig savedModelConfig = SavedModelConfig.builder() | ||||
|                 .savedModelPath(f.getAbsolutePath()) | ||||
|                 .signatureKey("incr_counter_by") | ||||
|                 .modelTag("serve") | ||||
|                 .build(); | ||||
|         try(GraphRunner graphRunner = GraphRunner.builder().savedModelConfig(savedModelConfig).sessionOptionsConfigProto(getConfig()).build()) { | ||||
|             INDArray delta = Nd4j.create(new float[] { 42 }, new long[0]); | ||||
|             Map<String,INDArray> inputs = new LinkedHashMap<>(); | ||||
|             inputs.put("delta:0",delta); | ||||
|             Map<String,INDArray> outputs = graphRunner.run(inputs); | ||||
|             assertEquals(1, outputs.size()); | ||||
|             System.out.println(Arrays.toString(outputs.keySet().toArray(new String[0]))); | ||||
|             INDArray output = outputs.values().toArray(new INDArray[0])[0]; | ||||
|             assertEquals(42.0, output.getDouble(0), 0.0); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testGraphRunnerCast() { | ||||
|         INDArray arr = Nd4j.linspace(1,4,4).castTo(DataType.FLOAT); | ||||
|         TF_Tensor tensor = TensorflowConversion.getInstance().tensorFromNDArray(arr); | ||||
|         TF_Tensor tf_tensor = GraphRunner.castTensor(tensor, TensorDataType.FLOAT,TensorDataType.DOUBLE); | ||||
|         INDArray doubleNDArray = TensorflowConversion.getInstance().ndArrayFromTensor(tf_tensor); | ||||
|         TestCase.assertEquals(DataType.DOUBLE,doubleNDArray.dataType()); | ||||
| 
 | ||||
|         arr = arr.castTo(DataType.INT); | ||||
|         tensor = TensorflowConversion.getInstance().tensorFromNDArray(arr); | ||||
|         tf_tensor = GraphRunner.castTensor(tensor, TensorDataType.fromNd4jType(DataType.INT),TensorDataType.DOUBLE); | ||||
|         doubleNDArray = TensorflowConversion.getInstance().ndArrayFromTensor(tf_tensor); | ||||
|         TestCase.assertEquals(DataType.DOUBLE,doubleNDArray.dataType()); | ||||
| 
 | ||||
|     } | ||||
| } | ||||
| @ -1,130 +0,0 @@ | ||||
| /* ****************************************************************************** | ||||
|  * | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  *  See the NOTICE file distributed with this work for additional | ||||
|  *  information regarding copyright ownership. | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| package org.nd4j.tensorflow.conversion; | ||||
| 
 | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.apache.commons.io.IOUtils; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.tensorflow.framework.GraphDef; | ||||
| 
 | ||||
| import org.bytedeco.tensorflow.*; | ||||
| import static org.bytedeco.tensorflow.global.tensorflow.*; | ||||
| import static org.junit.Assert.assertEquals; | ||||
| import static org.junit.Assert.assertNotNull; | ||||
| import static org.junit.Assert.fail; | ||||
| import static org.nd4j.linalg.api.buffer.DataType.*; | ||||
| 
 | ||||
| @Slf4j | ||||
| public class TensorflowConversionTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testView() { | ||||
|         INDArray matrix = Nd4j.linspace(1,8,8).reshape(2,4); | ||||
|         INDArray view = matrix.slice(0); | ||||
|         TensorflowConversion conversion =TensorflowConversion.getInstance(); | ||||
|         TF_Tensor tf_tensor = conversion.tensorFromNDArray(view); | ||||
|         INDArray converted = conversion.ndArrayFromTensor(tf_tensor); | ||||
|         assertEquals(view,converted); | ||||
|     } | ||||
| 
 | ||||
|     @Test(expected = IllegalArgumentException.class) | ||||
|     public void testNullArray() { | ||||
|         INDArray array = Nd4j.create(2,2); | ||||
|         array.setData(null); | ||||
|         TensorflowConversion conversion =TensorflowConversion.getInstance(); | ||||
|         TF_Tensor tf_tensor = conversion.tensorFromNDArray(array); | ||||
|         fail(); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testConversionFromNdArray() throws Exception { | ||||
|         DataType[] dtypes = new DataType[]{ | ||||
|           DOUBLE, | ||||
|           FLOAT, | ||||
|           SHORT, | ||||
|           LONG, | ||||
|           BYTE, | ||||
|           UBYTE, | ||||
|           UINT16, | ||||
|           UINT32, | ||||
|           UINT64, | ||||
|           BFLOAT16, | ||||
|           BOOL, | ||||
|           INT, | ||||
|           HALF | ||||
|         }; | ||||
|         for(DataType dtype: dtypes){ | ||||
|             log.debug("Testing conversion for data type " + dtype); | ||||
|             INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2).castTo(dtype); | ||||
|             TensorflowConversion tensorflowConversion =TensorflowConversion.getInstance(); | ||||
|             TF_Tensor tf_tensor = tensorflowConversion.tensorFromNDArray(arr); | ||||
|             INDArray fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor); | ||||
|             assertEquals(arr,fromTensor); | ||||
|             if (dtype == BOOL){ | ||||
|                 arr.putScalar(3, 0); | ||||
|             } | ||||
|             else{ | ||||
|                 arr.addi(1.0); | ||||
|             } | ||||
|             tf_tensor = tensorflowConversion.tensorFromNDArray(arr); | ||||
|             fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor); | ||||
|             assertEquals(arr,fromTensor); | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testCudaIfAvailable() throws Exception { | ||||
|         TensorflowConversion tensorflowConversion =TensorflowConversion.getInstance(); | ||||
|         byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream()); | ||||
|         //byte[] content = Files.readAllBytes(Paths.get(new File("/home/agibsonccc/code/dl4j-test-resources/src/main/resources/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").toURI())); | ||||
|         TF_Status status = TF_Status.newStatus(); | ||||
|         TF_Graph initializedGraphForNd4jDevices = tensorflowConversion.loadGraph(content, status); | ||||
|         assertNotNull(initializedGraphForNd4jDevices); | ||||
| 
 | ||||
|         String deviceName = tensorflowConversion.defaultDeviceForThread(); | ||||
| 
 | ||||
|         byte[] content2 = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream()); | ||||
|         GraphDef graphDef1 = GraphDef.parseFrom(content2); | ||||
|         System.out.println(graphDef1); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Test | ||||
|     public void testStringConversion() throws Exception { | ||||
|         String[] strings = {"one", "two", "three"}; | ||||
|         INDArray arr = Nd4j.create(strings); | ||||
|         TensorflowConversion tensorflowConversion =TensorflowConversion.getInstance(); | ||||
|         TF_Tensor tf_tensor = tensorflowConversion.tensorFromNDArray(arr); | ||||
|         INDArray fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor); | ||||
|         assertEquals(arr.length(), fromTensor.length()); | ||||
|         for (int i = 0; i < arr.length(); i++) { | ||||
|             assertEquals(strings[i], fromTensor.getString(i)); | ||||
|             assertEquals(arr.getString(i), fromTensor.getString(i)); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,94 +0,0 @@ | ||||
| /* ****************************************************************************** | ||||
|  * | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  *  See the NOTICE file distributed with this work for additional | ||||
|  *  information regarding copyright ownership. | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| package org.nd4j.tensorflow.conversion; | ||||
| 
 | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.shade.protobuf.util.JsonFormat; | ||||
| import org.apache.commons.io.IOUtils; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner; | ||||
| import org.tensorflow.framework.ConfigProto; | ||||
| import org.tensorflow.framework.GPUOptions; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.FileInputStream; | ||||
| import java.util.Arrays; | ||||
| import java.util.LinkedHashMap; | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
| 
 | ||||
| import static org.junit.Assert.assertEquals; | ||||
| import static org.junit.Assert.assertNotNull; | ||||
| 
 | ||||
| public class GpuGraphRunnerTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @Override | ||||
|     public long getTimeoutMilliseconds() { | ||||
|         return 180000L; | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testGraphRunner() throws Exception { | ||||
|         byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream()); | ||||
|         List<String> inputNames = Arrays.asList("input_0","input_1"); | ||||
| 
 | ||||
|         ConfigProto configProto = ConfigProto.newBuilder() | ||||
|                 .setGpuOptions(GPUOptions.newBuilder() | ||||
|                         .setPerProcessGpuMemoryFraction(0.1) | ||||
|                         .setAllowGrowth(false) | ||||
|                         .build()) | ||||
|                 .build(); | ||||
| 
 | ||||
|         try(GraphRunner graphRunner = GraphRunner.builder().graphBytes(content).inputNames(inputNames).sessionOptionsConfigProto(configProto).build()) { | ||||
|             org.tensorflow.framework.ConfigProto.Builder builder = org.tensorflow.framework.ConfigProto.newBuilder(); | ||||
|             String json = graphRunner.sessionOptionsToJson(); | ||||
|             JsonFormat.parser().merge(json,builder); | ||||
|             org.tensorflow.framework.ConfigProto build = builder.build(); | ||||
|             assertEquals(build,graphRunner.getSessionOptionsConfigProto()); | ||||
|             assertNotNull(graphRunner.getInputOrder()); | ||||
|             assertNotNull(graphRunner.getOutputOrder()); | ||||
| 
 | ||||
| 
 | ||||
|             org.tensorflow.framework.ConfigProto configProto1 = GraphRunner.fromJson(json); | ||||
| 
 | ||||
|             assertEquals(graphRunner.getSessionOptionsConfigProto(),configProto1); | ||||
|             assertEquals(2,graphRunner.getInputOrder().size()); | ||||
|             assertEquals(1,graphRunner.getOutputOrder().size()); | ||||
| 
 | ||||
|             INDArray input1 = Nd4j.linspace(1,4,4).reshape(4).castTo(DataType.FLOAT); | ||||
|             INDArray input2 = Nd4j.linspace(1,4,4).reshape(4).castTo(DataType.FLOAT); | ||||
| 
 | ||||
|             Map<String,INDArray> inputs = new LinkedHashMap<>(); | ||||
|             inputs.put("input_0",input1); | ||||
|             inputs.put("input_1",input2); | ||||
| 
 | ||||
|             for(int i = 0; i < 2; i++) { | ||||
|                 Map<String,INDArray> outputs = graphRunner.run(inputs); | ||||
| 
 | ||||
|                 INDArray assertion = input1.add(input2); | ||||
|                 assertEquals(assertion,outputs.get("output")); | ||||
|             } | ||||
| 
 | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @ -1,2 +1,441 @@ | ||||
| Identity,in_0/read | ||||
| MaxPoolWithArgmax,MaxPoolWithArgmax | ||||
| Transpose,transpose | ||||
| Identity,conv2d/kernel/read | ||||
| Identity,batch_normalization/gamma/read | ||||
| Identity,batch_normalization/beta/read | ||||
| Identity,batch_normalization/moving_mean/read | ||||
| Identity,batch_normalization/moving_variance/read | ||||
| Identity,conv2d_1/kernel/read | ||||
| Identity,conv2d_2/kernel/read | ||||
| Identity,batch_normalization_1/gamma/read | ||||
| Identity,batch_normalization_1/beta/read | ||||
| Identity,batch_normalization_1/moving_mean/read | ||||
| Identity,batch_normalization_1/moving_variance/read | ||||
| Identity,conv2d_3/kernel/read | ||||
| Identity,batch_normalization_2/gamma/read | ||||
| Identity,batch_normalization_2/beta/read | ||||
| Identity,batch_normalization_2/moving_mean/read | ||||
| Identity,batch_normalization_2/moving_variance/read | ||||
| Identity,conv2d_4/kernel/read | ||||
| Identity,batch_normalization_3/gamma/read | ||||
| Identity,batch_normalization_3/beta/read | ||||
| Identity,batch_normalization_3/moving_mean/read | ||||
| Identity,batch_normalization_3/moving_variance/read | ||||
| Identity,conv2d_5/kernel/read | ||||
| Identity,batch_normalization_4/gamma/read | ||||
| Identity,batch_normalization_4/beta/read | ||||
| Identity,batch_normalization_4/moving_mean/read | ||||
| Identity,batch_normalization_4/moving_variance/read | ||||
| Identity,conv2d_6/kernel/read | ||||
| Identity,batch_normalization_5/gamma/read | ||||
| Identity,batch_normalization_5/beta/read | ||||
| Identity,batch_normalization_5/moving_mean/read | ||||
| Identity,batch_normalization_5/moving_variance/read | ||||
| Identity,conv2d_7/kernel/read | ||||
| Identity,batch_normalization_6/gamma/read | ||||
| Identity,batch_normalization_6/beta/read | ||||
| Identity,batch_normalization_6/moving_mean/read | ||||
| Identity,batch_normalization_6/moving_variance/read | ||||
| Identity,conv2d_8/kernel/read | ||||
| Identity,batch_normalization_7/gamma/read | ||||
| Identity,batch_normalization_7/beta/read | ||||
| Identity,batch_normalization_7/moving_mean/read | ||||
| Identity,batch_normalization_7/moving_variance/read | ||||
| Identity,conv2d_9/kernel/read | ||||
| Identity,batch_normalization_8/gamma/read | ||||
| Identity,batch_normalization_8/beta/read | ||||
| Identity,batch_normalization_8/moving_mean/read | ||||
| Identity,batch_normalization_8/moving_variance/read | ||||
| Identity,conv2d_10/kernel/read | ||||
| Identity,batch_normalization_9/gamma/read | ||||
| Identity,batch_normalization_9/beta/read | ||||
| Identity,batch_normalization_9/moving_mean/read | ||||
| Identity,batch_normalization_9/moving_variance/read | ||||
| Identity,conv2d_11/kernel/read | ||||
| Identity,conv2d_12/kernel/read | ||||
| Identity,batch_normalization_10/gamma/read | ||||
| Identity,batch_normalization_10/beta/read | ||||
| Identity,batch_normalization_10/moving_mean/read | ||||
| Identity,batch_normalization_10/moving_variance/read | ||||
| Identity,conv2d_13/kernel/read | ||||
| Identity,batch_normalization_11/gamma/read | ||||
| Identity,batch_normalization_11/beta/read | ||||
| Identity,batch_normalization_11/moving_mean/read | ||||
| Identity,batch_normalization_11/moving_variance/read | ||||
| Identity,conv2d_14/kernel/read | ||||
| Identity,batch_normalization_12/gamma/read | ||||
| Identity,batch_normalization_12/beta/read | ||||
| Identity,batch_normalization_12/moving_mean/read | ||||
| Identity,batch_normalization_12/moving_variance/read | ||||
| Identity,conv2d_15/kernel/read | ||||
| Identity,batch_normalization_13/gamma/read | ||||
| Identity,batch_normalization_13/beta/read | ||||
| Identity,batch_normalization_13/moving_mean/read | ||||
| Identity,batch_normalization_13/moving_variance/read | ||||
| Identity,conv2d_16/kernel/read | ||||
| Identity,batch_normalization_14/gamma/read | ||||
| Identity,batch_normalization_14/beta/read | ||||
| Identity,batch_normalization_14/moving_mean/read | ||||
| Identity,batch_normalization_14/moving_variance/read | ||||
| Identity,conv2d_17/kernel/read | ||||
| Identity,batch_normalization_15/gamma/read | ||||
| Identity,batch_normalization_15/beta/read | ||||
| Identity,batch_normalization_15/moving_mean/read | ||||
| Identity,batch_normalization_15/moving_variance/read | ||||
| Identity,conv2d_18/kernel/read | ||||
| Identity,batch_normalization_16/gamma/read | ||||
| Identity,batch_normalization_16/beta/read | ||||
| Identity,batch_normalization_16/moving_mean/read | ||||
| Identity,batch_normalization_16/moving_variance/read | ||||
| Identity,conv2d_19/kernel/read | ||||
| Identity,batch_normalization_17/gamma/read | ||||
| Identity,batch_normalization_17/beta/read | ||||
| Identity,batch_normalization_17/moving_mean/read | ||||
| Identity,batch_normalization_17/moving_variance/read | ||||
| Identity,conv2d_20/kernel/read | ||||
| Identity,batch_normalization_18/gamma/read | ||||
| Identity,batch_normalization_18/beta/read | ||||
| Identity,batch_normalization_18/moving_mean/read | ||||
| Identity,batch_normalization_18/moving_variance/read | ||||
| Identity,conv2d_21/kernel/read | ||||
| Identity,batch_normalization_19/gamma/read | ||||
| Identity,batch_normalization_19/beta/read | ||||
| Identity,batch_normalization_19/moving_mean/read | ||||
| Identity,batch_normalization_19/moving_variance/read | ||||
| Identity,conv2d_22/kernel/read | ||||
| Identity,batch_normalization_20/gamma/read | ||||
| Identity,batch_normalization_20/beta/read | ||||
| Identity,batch_normalization_20/moving_mean/read | ||||
| Identity,batch_normalization_20/moving_variance/read | ||||
| Identity,conv2d_23/kernel/read | ||||
| Identity,batch_normalization_21/gamma/read | ||||
| Identity,batch_normalization_21/beta/read | ||||
| Identity,batch_normalization_21/moving_mean/read | ||||
| Identity,batch_normalization_21/moving_variance/read | ||||
| Identity,conv2d_24/kernel/read | ||||
| Identity,conv2d_25/kernel/read | ||||
| Identity,batch_normalization_22/gamma/read | ||||
| Identity,batch_normalization_22/beta/read | ||||
| Identity,batch_normalization_22/moving_mean/read | ||||
| Identity,batch_normalization_22/moving_variance/read | ||||
| Identity,conv2d_26/kernel/read | ||||
| Identity,batch_normalization_23/gamma/read | ||||
| Identity,batch_normalization_23/beta/read | ||||
| Identity,batch_normalization_23/moving_mean/read | ||||
| Identity,batch_normalization_23/moving_variance/read | ||||
| Identity,conv2d_27/kernel/read | ||||
| Identity,batch_normalization_24/gamma/read | ||||
| Identity,batch_normalization_24/beta/read | ||||
| Identity,batch_normalization_24/moving_mean/read | ||||
| Identity,batch_normalization_24/moving_variance/read | ||||
| Identity,conv2d_28/kernel/read | ||||
| Identity,batch_normalization_25/gamma/read | ||||
| Identity,batch_normalization_25/beta/read | ||||
| Identity,batch_normalization_25/moving_mean/read | ||||
| Identity,batch_normalization_25/moving_variance/read | ||||
| Identity,conv2d_29/kernel/read | ||||
| Identity,batch_normalization_26/gamma/read | ||||
| Identity,batch_normalization_26/beta/read | ||||
| Identity,batch_normalization_26/moving_mean/read | ||||
| Identity,batch_normalization_26/moving_variance/read | ||||
| Identity,conv2d_30/kernel/read | ||||
| Identity,batch_normalization_27/gamma/read | ||||
| Identity,batch_normalization_27/beta/read | ||||
| Identity,batch_normalization_27/moving_mean/read | ||||
| Identity,batch_normalization_27/moving_variance/read | ||||
| Identity,conv2d_31/kernel/read | ||||
| Identity,batch_normalization_28/gamma/read | ||||
| Identity,batch_normalization_28/beta/read | ||||
| Identity,batch_normalization_28/moving_mean/read | ||||
| Identity,batch_normalization_28/moving_variance/read | ||||
| Identity,conv2d_32/kernel/read | ||||
| Identity,batch_normalization_29/gamma/read | ||||
| Identity,batch_normalization_29/beta/read | ||||
| Identity,batch_normalization_29/moving_mean/read | ||||
| Identity,batch_normalization_29/moving_variance/read | ||||
| Identity,conv2d_33/kernel/read | ||||
| Identity,batch_normalization_30/gamma/read | ||||
| Identity,batch_normalization_30/beta/read | ||||
| Identity,batch_normalization_30/moving_mean/read | ||||
| Identity,batch_normalization_30/moving_variance/read | ||||
| Identity,conv2d_34/kernel/read | ||||
| Identity,batch_normalization_31/gamma/read | ||||
| Identity,batch_normalization_31/beta/read | ||||
| Identity,batch_normalization_31/moving_mean/read | ||||
| Identity,batch_normalization_31/moving_variance/read | ||||
| Identity,conv2d_35/kernel/read | ||||
| Identity,batch_normalization_32/gamma/read | ||||
| Identity,batch_normalization_32/beta/read | ||||
| Identity,batch_normalization_32/moving_mean/read | ||||
| Identity,batch_normalization_32/moving_variance/read | ||||
| Identity,conv2d_36/kernel/read | ||||
| Identity,batch_normalization_33/gamma/read | ||||
| Identity,batch_normalization_33/beta/read | ||||
| Identity,batch_normalization_33/moving_mean/read | ||||
| Identity,batch_normalization_33/moving_variance/read | ||||
| Identity,conv2d_37/kernel/read | ||||
| Identity,batch_normalization_34/gamma/read | ||||
| Identity,batch_normalization_34/beta/read | ||||
| Identity,batch_normalization_34/moving_mean/read | ||||
| Identity,batch_normalization_34/moving_variance/read | ||||
| Identity,conv2d_38/kernel/read | ||||
| Identity,batch_normalization_35/gamma/read | ||||
| Identity,batch_normalization_35/beta/read | ||||
| Identity,batch_normalization_35/moving_mean/read | ||||
| Identity,batch_normalization_35/moving_variance/read | ||||
| Identity,conv2d_39/kernel/read | ||||
| Identity,batch_normalization_36/gamma/read | ||||
| Identity,batch_normalization_36/beta/read | ||||
| Identity,batch_normalization_36/moving_mean/read | ||||
| Identity,batch_normalization_36/moving_variance/read | ||||
| Identity,conv2d_40/kernel/read | ||||
| Identity,batch_normalization_37/gamma/read | ||||
| Identity,batch_normalization_37/beta/read | ||||
| Identity,batch_normalization_37/moving_mean/read | ||||
| Identity,batch_normalization_37/moving_variance/read | ||||
| Identity,conv2d_41/kernel/read | ||||
| Identity,batch_normalization_38/gamma/read | ||||
| Identity,batch_normalization_38/beta/read | ||||
| Identity,batch_normalization_38/moving_mean/read | ||||
| Identity,batch_normalization_38/moving_variance/read | ||||
| Identity,conv2d_42/kernel/read | ||||
| Identity,batch_normalization_39/gamma/read | ||||
| Identity,batch_normalization_39/beta/read | ||||
| Identity,batch_normalization_39/moving_mean/read | ||||
| Identity,batch_normalization_39/moving_variance/read | ||||
| Identity,conv2d_43/kernel/read | ||||
| Identity,conv2d_44/kernel/read | ||||
| Identity,batch_normalization_40/gamma/read | ||||
| Identity,batch_normalization_40/beta/read | ||||
| Identity,batch_normalization_40/moving_mean/read | ||||
| Identity,batch_normalization_40/moving_variance/read | ||||
| Identity,conv2d_45/kernel/read | ||||
| Identity,batch_normalization_41/gamma/read | ||||
| Identity,batch_normalization_41/beta/read | ||||
| Identity,batch_normalization_41/moving_mean/read | ||||
| Identity,batch_normalization_41/moving_variance/read | ||||
| Identity,conv2d_46/kernel/read | ||||
| Identity,batch_normalization_42/gamma/read | ||||
| Identity,batch_normalization_42/beta/read | ||||
| Identity,batch_normalization_42/moving_mean/read | ||||
| Identity,batch_normalization_42/moving_variance/read | ||||
| Identity,conv2d_47/kernel/read | ||||
| Identity,batch_normalization_43/gamma/read | ||||
| Identity,batch_normalization_43/beta/read | ||||
| Identity,batch_normalization_43/moving_mean/read | ||||
| Identity,batch_normalization_43/moving_variance/read | ||||
| Identity,conv2d_48/kernel/read | ||||
| Identity,batch_normalization_44/gamma/read | ||||
| Identity,batch_normalization_44/beta/read | ||||
| Identity,batch_normalization_44/moving_mean/read | ||||
| Identity,batch_normalization_44/moving_variance/read | ||||
| Identity,conv2d_49/kernel/read | ||||
| Identity,batch_normalization_45/gamma/read | ||||
| Identity,batch_normalization_45/beta/read | ||||
| Identity,batch_normalization_45/moving_mean/read | ||||
| Identity,batch_normalization_45/moving_variance/read | ||||
| Identity,conv2d_50/kernel/read | ||||
| Identity,batch_normalization_46/gamma/read | ||||
| Identity,batch_normalization_46/beta/read | ||||
| Identity,batch_normalization_46/moving_mean/read | ||||
| Identity,batch_normalization_46/moving_variance/read | ||||
| Identity,conv2d_51/kernel/read | ||||
| Identity,batch_normalization_47/gamma/read | ||||
| Identity,batch_normalization_47/beta/read | ||||
| Identity,batch_normalization_47/moving_mean/read | ||||
| Identity,batch_normalization_47/moving_variance/read | ||||
| Identity,conv2d_52/kernel/read | ||||
| Identity,batch_normalization_48/gamma/read | ||||
| Identity,batch_normalization_48/beta/read | ||||
| Identity,batch_normalization_48/moving_mean/read | ||||
| Identity,batch_normalization_48/moving_variance/read | ||||
| Identity,dense/kernel/read | ||||
| Identity,dense/bias/read | ||||
| Pad,Pad | ||||
| Conv2D,conv2d/Conv2D | ||||
| Identity,initial_conv | ||||
| MaxPool,max_pooling2d/MaxPool | ||||
| Identity,initial_max_pool | ||||
| FusedBatchNorm,batch_normalization/FusedBatchNorm | ||||
| Relu,Relu | ||||
| Conv2D,conv2d_1/Conv2D | ||||
| Conv2D,conv2d_2/Conv2D | ||||
| FusedBatchNorm,batch_normalization_1/FusedBatchNorm | ||||
| Relu,Relu_1 | ||||
| Conv2D,conv2d_3/Conv2D | ||||
| FusedBatchNorm,batch_normalization_2/FusedBatchNorm | ||||
| Relu,Relu_2 | ||||
| Conv2D,conv2d_4/Conv2D | ||||
| Add,add | ||||
| FusedBatchNorm,batch_normalization_3/FusedBatchNorm | ||||
| Relu,Relu_3 | ||||
| Conv2D,conv2d_5/Conv2D | ||||
| FusedBatchNorm,batch_normalization_4/FusedBatchNorm | ||||
| Relu,Relu_4 | ||||
| Conv2D,conv2d_6/Conv2D | ||||
| FusedBatchNorm,batch_normalization_5/FusedBatchNorm | ||||
| Relu,Relu_5 | ||||
| Conv2D,conv2d_7/Conv2D | ||||
| Add,add_1 | ||||
| FusedBatchNorm,batch_normalization_6/FusedBatchNorm | ||||
| Relu,Relu_6 | ||||
| Conv2D,conv2d_8/Conv2D | ||||
| FusedBatchNorm,batch_normalization_7/FusedBatchNorm | ||||
| Relu,Relu_7 | ||||
| Conv2D,conv2d_9/Conv2D | ||||
| FusedBatchNorm,batch_normalization_8/FusedBatchNorm | ||||
| Relu,Relu_8 | ||||
| Conv2D,conv2d_10/Conv2D | ||||
| Add,add_2 | ||||
| Identity,block_layer1 | ||||
| FusedBatchNorm,batch_normalization_9/FusedBatchNorm | ||||
| Relu,Relu_9 | ||||
| Pad,Pad_1 | ||||
| Conv2D,conv2d_12/Conv2D | ||||
| Conv2D,conv2d_11/Conv2D | ||||
| FusedBatchNorm,batch_normalization_10/FusedBatchNorm | ||||
| Relu,Relu_10 | ||||
| Pad,Pad_2 | ||||
| Conv2D,conv2d_13/Conv2D | ||||
| FusedBatchNorm,batch_normalization_11/FusedBatchNorm | ||||
| Relu,Relu_11 | ||||
| Conv2D,conv2d_14/Conv2D | ||||
| Add,add_3 | ||||
| FusedBatchNorm,batch_normalization_12/FusedBatchNorm | ||||
| Relu,Relu_12 | ||||
| Conv2D,conv2d_15/Conv2D | ||||
| FusedBatchNorm,batch_normalization_13/FusedBatchNorm | ||||
| Relu,Relu_13 | ||||
| Conv2D,conv2d_16/Conv2D | ||||
| FusedBatchNorm,batch_normalization_14/FusedBatchNorm | ||||
| Relu,Relu_14 | ||||
| Conv2D,conv2d_17/Conv2D | ||||
| Add,add_4 | ||||
| FusedBatchNorm,batch_normalization_15/FusedBatchNorm | ||||
| Relu,Relu_15 | ||||
| Conv2D,conv2d_18/Conv2D | ||||
| FusedBatchNorm,batch_normalization_16/FusedBatchNorm | ||||
| Relu,Relu_16 | ||||
| Conv2D,conv2d_19/Conv2D | ||||
| FusedBatchNorm,batch_normalization_17/FusedBatchNorm | ||||
| Relu,Relu_17 | ||||
| Conv2D,conv2d_20/Conv2D | ||||
| Add,add_5 | ||||
| FusedBatchNorm,batch_normalization_18/FusedBatchNorm | ||||
| Relu,Relu_18 | ||||
| Conv2D,conv2d_21/Conv2D | ||||
| FusedBatchNorm,batch_normalization_19/FusedBatchNorm | ||||
| Relu,Relu_19 | ||||
| Conv2D,conv2d_22/Conv2D | ||||
| FusedBatchNorm,batch_normalization_20/FusedBatchNorm | ||||
| Relu,Relu_20 | ||||
| Conv2D,conv2d_23/Conv2D | ||||
| Add,add_6 | ||||
| Identity,block_layer2 | ||||
| FusedBatchNorm,batch_normalization_21/FusedBatchNorm | ||||
| Relu,Relu_21 | ||||
| Pad,Pad_3 | ||||
| Conv2D,conv2d_25/Conv2D | ||||
| Conv2D,conv2d_24/Conv2D | ||||
| FusedBatchNorm,batch_normalization_22/FusedBatchNorm | ||||
| Relu,Relu_22 | ||||
| Pad,Pad_4 | ||||
| Conv2D,conv2d_26/Conv2D | ||||
| FusedBatchNorm,batch_normalization_23/FusedBatchNorm | ||||
| Relu,Relu_23 | ||||
| Conv2D,conv2d_27/Conv2D | ||||
| Add,add_7 | ||||
| FusedBatchNorm,batch_normalization_24/FusedBatchNorm | ||||
| Relu,Relu_24 | ||||
| Conv2D,conv2d_28/Conv2D | ||||
| FusedBatchNorm,batch_normalization_25/FusedBatchNorm | ||||
| Relu,Relu_25 | ||||
| Conv2D,conv2d_29/Conv2D | ||||
| FusedBatchNorm,batch_normalization_26/FusedBatchNorm | ||||
| Relu,Relu_26 | ||||
| Conv2D,conv2d_30/Conv2D | ||||
| Add,add_8 | ||||
| FusedBatchNorm,batch_normalization_27/FusedBatchNorm | ||||
| Relu,Relu_27 | ||||
| Conv2D,conv2d_31/Conv2D | ||||
| FusedBatchNorm,batch_normalization_28/FusedBatchNorm | ||||
| Relu,Relu_28 | ||||
| Conv2D,conv2d_32/Conv2D | ||||
| FusedBatchNorm,batch_normalization_29/FusedBatchNorm | ||||
| Relu,Relu_29 | ||||
| Conv2D,conv2d_33/Conv2D | ||||
| Add,add_9 | ||||
| FusedBatchNorm,batch_normalization_30/FusedBatchNorm | ||||
| Relu,Relu_30 | ||||
| Conv2D,conv2d_34/Conv2D | ||||
| FusedBatchNorm,batch_normalization_31/FusedBatchNorm | ||||
| Relu,Relu_31 | ||||
| Conv2D,conv2d_35/Conv2D | ||||
| FusedBatchNorm,batch_normalization_32/FusedBatchNorm | ||||
| Relu,Relu_32 | ||||
| Conv2D,conv2d_36/Conv2D | ||||
| Add,add_10 | ||||
| FusedBatchNorm,batch_normalization_33/FusedBatchNorm | ||||
| Relu,Relu_33 | ||||
| Conv2D,conv2d_37/Conv2D | ||||
| FusedBatchNorm,batch_normalization_34/FusedBatchNorm | ||||
| Relu,Relu_34 | ||||
| Conv2D,conv2d_38/Conv2D | ||||
| FusedBatchNorm,batch_normalization_35/FusedBatchNorm | ||||
| Relu,Relu_35 | ||||
| Conv2D,conv2d_39/Conv2D | ||||
| Add,add_11 | ||||
| FusedBatchNorm,batch_normalization_36/FusedBatchNorm | ||||
| Relu,Relu_36 | ||||
| Conv2D,conv2d_40/Conv2D | ||||
| FusedBatchNorm,batch_normalization_37/FusedBatchNorm | ||||
| Relu,Relu_37 | ||||
| Conv2D,conv2d_41/Conv2D | ||||
| FusedBatchNorm,batch_normalization_38/FusedBatchNorm | ||||
| Relu,Relu_38 | ||||
| Conv2D,conv2d_42/Conv2D | ||||
| Add,add_12 | ||||
| Identity,block_layer3 | ||||
| FusedBatchNorm,batch_normalization_39/FusedBatchNorm | ||||
| Relu,Relu_39 | ||||
| Pad,Pad_5 | ||||
| Conv2D,conv2d_44/Conv2D | ||||
| Conv2D,conv2d_43/Conv2D | ||||
| FusedBatchNorm,batch_normalization_40/FusedBatchNorm | ||||
| Relu,Relu_40 | ||||
| Pad,Pad_6 | ||||
| Conv2D,conv2d_45/Conv2D | ||||
| FusedBatchNorm,batch_normalization_41/FusedBatchNorm | ||||
| Relu,Relu_41 | ||||
| Conv2D,conv2d_46/Conv2D | ||||
| Add,add_13 | ||||
| FusedBatchNorm,batch_normalization_42/FusedBatchNorm | ||||
| Relu,Relu_42 | ||||
| Conv2D,conv2d_47/Conv2D | ||||
| FusedBatchNorm,batch_normalization_43/FusedBatchNorm | ||||
| Relu,Relu_43 | ||||
| Conv2D,conv2d_48/Conv2D | ||||
| FusedBatchNorm,batch_normalization_44/FusedBatchNorm | ||||
| Relu,Relu_44 | ||||
| Conv2D,conv2d_49/Conv2D | ||||
| Add,add_14 | ||||
| FusedBatchNorm,batch_normalization_45/FusedBatchNorm | ||||
| Relu,Relu_45 | ||||
| Conv2D,conv2d_50/Conv2D | ||||
| FusedBatchNorm,batch_normalization_46/FusedBatchNorm | ||||
| Relu,Relu_46 | ||||
| Conv2D,conv2d_51/Conv2D | ||||
| FusedBatchNorm,batch_normalization_47/FusedBatchNorm | ||||
| Relu,Relu_47 | ||||
| Conv2D,conv2d_52/Conv2D | ||||
| Add,add_15 | ||||
| Identity,block_layer4 | ||||
| FusedBatchNorm,batch_normalization_48/FusedBatchNorm | ||||
| Relu,Relu_48 | ||||
| Mean,Mean | ||||
| Identity,final_reduce_mean | ||||
| Reshape,Reshape | ||||
| MatMul,dense/MatMul | ||||
| BiasAdd,dense/BiasAdd | ||||
| Identity,final_dense | ||||
| ArgMax,ArgMax | ||||
| Softmax,softmax_tensor | ||||
|  | ||||
| @ -471,7 +471,7 @@ | ||||
|                                 Maximum heap size was set to 6g, as a minimum required value for tests run. | ||||
|                                 Depending on a build machine, default value is not always enough. | ||||
|                             --> | ||||
|                             <argLine> -Dfile.encoding=UTF-8 -Dorg.bytedeco.javacpp.logger.debug=true -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine> | ||||
|                             <argLine> -Dfile.encoding=UTF-8  -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine> | ||||
|                         </configuration> | ||||
|                     </plugin> | ||||
|                 </plugins> | ||||
|  | ||||
| @ -343,7 +343,7 @@ public class LayerOpValidation extends BaseOpValidation { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testIm2Col() { | ||||
|         //OpValidationSuite.ignoreFailing();      //TEMPORARY DUE TO JVM CRASH: https://github.com/deeplearning4j/deeplearning4j/issues/6873 | ||||
|         //OpValidationSuite.ignoreFailing();      //TEMPORARY DUE TO JVM CRASH: https://github.com/eclipse/deeplearning4j/issues/6873 | ||||
|         Nd4j.getRandom().setSeed(12345); | ||||
| 
 | ||||
|         int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}}; | ||||
|  | ||||
| @ -480,7 +480,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { | ||||
|                 dLdInExpected_1.putColumn(i, prod_1); | ||||
|             } | ||||
|             dLdInExpected_1.divi(preReduceInput); | ||||
|             dLdInExpected_1.muliColumnVector(dLdOut_1.reshape(3, 1));    //Reshape is a hack around https://github.com/deeplearning4j/deeplearning4j/issues/5530 | ||||
|             dLdInExpected_1.muliColumnVector(dLdOut_1.reshape(3, 1));    //Reshape is a hack around https://github.com/eclipse/deeplearning4j/issues/5530 | ||||
|             //System.out.println(dLdInExpected_1); | ||||
|             /* | ||||
|             [[   24.0000,   12.0000,    8.0000,    6.0000], | ||||
|  | ||||
| @ -2004,7 +2004,7 @@ public class ShapeOpValidation extends BaseOpValidation { | ||||
|     @Test | ||||
|     public void testCastEmpty(){ | ||||
|         INDArray emptyLong = Nd4j.empty(DataType.LONG); | ||||
|         int dtype = 9;  //INT = 9 - https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/array/DataType.h | ||||
|         int dtype = 9;  //INT = 9 - https://github.com/eclipse/deeplearning4j/blob/master/libnd4j/include/array/DataType.h | ||||
|         DynamicCustomOp op = DynamicCustomOp.builder("cast") | ||||
|                 .addInputs(emptyLong) | ||||
|                 .addIntegerArguments(dtype) | ||||
|  | ||||
| @ -326,7 +326,7 @@ public class TransformOpValidation extends BaseOpValidation { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testBatchToSpace() { | ||||
|         //OpValidationSuite.ignoreFailing();          //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863 | ||||
|         //OpValidationSuite.ignoreFailing();          //TODO: https://github.com/eclipse/deeplearning4j/issues/6863 | ||||
|         Nd4j.getRandom().setSeed(1337); | ||||
| 
 | ||||
|         int miniBatch = 4; | ||||
| @ -363,7 +363,7 @@ public class TransformOpValidation extends BaseOpValidation { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testSpaceToBatch() { | ||||
|         //OpValidationSuite.ignoreFailing();          //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863 | ||||
|         //OpValidationSuite.ignoreFailing();          //TODO: https://github.com/eclipse/deeplearning4j/issues/6863 | ||||
| 
 | ||||
|         Nd4j.getRandom().setSeed(7331); | ||||
| 
 | ||||
| @ -1281,7 +1281,7 @@ public class TransformOpValidation extends BaseOpValidation { | ||||
|                     out = sd.math().isInfinite(in); | ||||
|                     break; | ||||
|                 case 2: | ||||
|                     //TODO: IsMax supports both bool and float out: https://github.com/deeplearning4j/deeplearning4j/issues/6872 | ||||
|                     //TODO: IsMax supports both bool and float out: https://github.com/eclipse/deeplearning4j/issues/6872 | ||||
|                     inArr = Nd4j.create(new double[]{-3, 5, 0, 2}); | ||||
|                     exp = Nd4j.create(new boolean[]{false, true, false, false}); | ||||
|                     out = sd.math().isMax(in); | ||||
|  | ||||
| @ -61,10 +61,10 @@ public class ExecutionTests extends BaseNd4jTest { | ||||
|         if(TFGraphTestZooModels.isPPC()){ | ||||
|             /* | ||||
|             Ugly hack to temporarily disable tests on PPC only on CI | ||||
|             Issue logged here: https://github.com/deeplearning4j/deeplearning4j/issues/7657 | ||||
|             Issue logged here: https://github.com/eclipse/deeplearning4j/issues/7657 | ||||
|             These will be re-enabled for PPC once fixed - in the mean time, remaining tests will be used to detect and prevent regressions | ||||
|              */ | ||||
|             log.warn("TEMPORARILY SKIPPING TEST ON PPC ARCHITECTURE DUE TO KNOWN JVM CRASH ISSUES - SEE https://github.com/deeplearning4j/deeplearning4j/issues/7657"); | ||||
|             log.warn("TEMPORARILY SKIPPING TEST ON PPC ARCHITECTURE DUE TO KNOWN JVM CRASH ISSUES - SEE https://github.com/eclipse/deeplearning4j/issues/7657"); | ||||
|             OpValidationSuite.ignoreFailing(); | ||||
|         } | ||||
| 
 | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff Show More
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user