From 82bdcc21d2cfb7320bd0b7ad756eac81ab4e95e0 Mon Sep 17 00:00:00 2001 From: agibsonccc Date: Tue, 16 Mar 2021 11:57:24 +0900 Subject: [PATCH] All tests compile --- .../impl/CSVLineSequenceRecordReaderTest.java | 2 +- .../CSVMultiSequenceRecordReaderTest.java | 2 +- .../impl/CSVSequenceRecordReaderTest.java | 2 +- .../impl/FileBatchRecordReaderTest.java | 2 +- .../impl/JacksonLineRecordReaderTest.java | 2 +- .../reader/impl/JacksonRecordReaderTest.java | 2 +- .../records/reader/impl/LineReaderTest.java | 2 +- .../reader/impl/RegexRecordReaderTest.java | 2 +- .../impl/TestCollectionRecordReaders.java | 4 +- .../impl/TestConcatenatingRecordReader.java | 4 +- .../reader/impl/TestSerialization.java | 4 +- .../TransformProcessRecordReaderTests.java | 6 +- .../datavec/api/split/InputSplitTests.java | 4 +- .../split/NumberedFileInputSplitTests.java | 98 ++-- .../api/split/TestStreamInputSplit.java | 27 +- .../api/split/parittion/PartitionerTests.java | 6 +- .../api/transform/TestTransformProcess.java | 4 +- .../transform/condition/TestConditions.java | 6 +- .../api/transform/filter/TestFilters.java | 6 +- .../datavec/api/transform/join/TestJoin.java | 49 +- .../transform/ops/AggregatorImplsTest.java | 37 +- .../transform/reduce/TestMultiOpReduce.java | 20 +- .../api/transform/reduce/TestReductions.java | 4 +- .../api/transform/schema/TestJsonYaml.java | 4 +- .../transform/schema/TestSchemaMethods.java | 4 +- .../TestReduceSequenceByWindowFunction.java | 4 +- .../transform/sequence/TestSequenceSplit.java | 4 +- .../sequence/TestWindowFunctions.java | 4 +- .../serde/TestCustomTransformJsonYaml.java | 4 +- .../transform/serde/TestYamlJsonSerde.java | 4 +- .../transform/stringreduce/TestReduce.java | 4 +- .../transform/RegressionTestJson.java | 4 +- .../api/transform/transform/TestJsonYaml.java | 4 +- .../transform/transform/TestTransforms.java | 4 +- .../TestNDArrayWritableTransforms.java | 4 +- .../transform/ndarray/TestYamlJsonSerde.java | 4 +- .../org/datavec/api/transform/ui/TestUI.java | 20 +- .../TestNDArrayWritableAndSerialization.java | 4 +- .../org/datavec/arrow/ArrowConverterTest.java | 2 +- ...rowWritableRecordTimeSeriesBatchTests.java | 12 +- .../org/datavec/image/LabelGeneratorTest.java | 2 +- .../org/datavec/image/loader/LoaderTests.java | 24 +- .../datavec/image/loader/TestImageLoader.java | 4 +- .../image/loader/TestNativeImageLoader.java | 30 +- .../FileBatchRecordReaderTest.java | 2 +- .../recordreader/TestImageRecordReader.java | 53 +- .../TestObjectDetectionRecordReader.java | 17 +- .../objdetect/TestVocLabelProvider.java | 16 +- .../image/transform/TestImageTransform.java | 8 +- .../poi/excel/ExcelRecordWriterTest.java | 2 +- .../reader/impl/JDBCRecordReaderTest.java | 2 +- ...ocalTransformProcessRecordReaderTests.java | 4 +- .../transforms/analysis/TestAnalyzeLocal.java | 13 +- .../TestLineRecordReaderFunction.java | 6 +- .../TestNDArrayToWritablesFunction.java | 4 +- .../TestWritablesToNDArrayFunction.java | 4 +- .../TestWritablesToStringFunctions.java | 4 +- .../transform/TestGeoTransforms.java | 19 +- .../transform/TestPythonTransformProcess.java | 33 +- .../transforms/transform/join/TestJoin.java | 4 +- .../rank/TestCalculateSortedRank.java | 4 +- .../sequence/TestConvertToSequence.java | 6 +- datavec/datavec-spark/pom.xml | 6 + .../datavec/spark/TestKryoSerialization.java | 8 +- .../TestLineRecordReaderFunction.java | 6 +- .../TestNDArrayToWritablesFunction.java | 4 +- ...PairSequenceRecordReaderBytesFunction.java | 18 +- .../TestRecordReaderBytesFunction.java | 18 +- .../functions/TestRecordReaderFunction.java | 19 +- ...TestSequenceRecordReaderBytesFunction.java | 18 +- .../TestSequenceRecordReaderFunction.java | 23 +- .../TestWritablesToNDArrayFunction.java | 4 +- .../TestWritablesToStringFunctions.java | 4 +- .../spark/storage/TestSparkStorageUtils.java | 6 +- .../spark/transform/DataFramesTests.java | 4 +- .../spark/transform/NormalizationTests.java | 4 +- .../transform/analysis/TestAnalysis.java | 4 +- .../spark/transform/join/TestJoin.java | 4 +- .../rank/TestCalculateSortedRank.java | 4 +- .../sequence/TestConvertToSequence.java | 6 +- .../org/datavec/spark/util/TestSparkUtil.java | 4 +- .../java/org/deeplearning4j/BaseDL4JTest.java | 6 +- .../LayerHelperValidationUtil.java | 18 +- .../java/org/deeplearning4j/RandomTests.java | 6 +- .../java/org/deeplearning4j/TestUtils.java | 4 +- .../datasets/MnistFetcherTest.java | 2 +- .../deeplearning4j/datasets/TestDataSets.java | 2 +- .../RecordReaderDataSetiteratorTest.java | 7 +- .../RecordReaderMultiDataSetIteratorTest.java | 7 +- .../fetchers/SvhnDataFetcherTest.java | 4 +- .../iterator/CombinedPreProcessorTests.java | 4 +- .../iterator/DataSetSplitterTests.java | 104 ++-- .../DummyBlockDataSetIteratorTests.java | 8 +- .../EarlyTerminationDataSetIteratorTest.java | 23 +- ...lyTerminationMultiDataSetIteratorTest.java | 28 +- .../JointMultiDataSetIteratorTests.java | 11 +- .../iterator/LoaderIteratorTests.java | 6 +- .../iterator/MultiDataSetSplitterTests.java | 122 ++--- .../iterator/MultipleEpochsIteratorTest.java | 7 +- .../datasets/iterator/TestAsyncIterator.java | 8 +- .../iterator/TestEmnistDataSetIterator.java | 11 +- .../datasets/iterator/TestFileIterators.java | 54 +- .../earlystopping/TestEarlyStopping.java | 24 +- .../TestEarlyStoppingCompGraph.java | 4 +- .../eval/EvaluationToolsTests.java | 2 +- .../exceptions/TestInvalidConfigurations.java | 89 +++- .../exceptions/TestInvalidInput.java | 6 +- .../exceptions/TestRecordReaders.java | 38 +- .../gradientcheck/AttentionLayerTest.java | 38 +- .../gradientcheck/DropoutGradientCheck.java | 6 +- .../GlobalPoolingGradientCheckTests.java | 4 +- .../gradientcheck/GradientCheckTests.java | 34 +- .../GradientCheckTestsComputationGraph.java | 60 +-- .../GradientCheckTestsMasking.java | 22 +- .../gradientcheck/LRNGradientCheckTests.java | 4 +- .../gradientcheck/LSTMGradientCheckTests.java | 14 +- .../LossFunctionGradientCheck.java | 14 +- .../NoBiasGradientCheckTests.java | 14 +- .../OutputLayerGradientChecks.java | 10 +- .../gradientcheck/RnnGradientChecks.java | 16 +- .../UtilLayerGradientChecks.java | 4 +- .../gradientcheck/VaeGradientCheckTests.java | 12 +- .../gradientcheck/YoloGradientCheckTests.java | 21 +- .../MultiLayerNeuralNetConfigurationTest.java | 2 +- .../nn/conf/constraints/TestConstraints.java | 6 +- .../nn/conf/dropout/TestDropout.java | 8 +- .../conf/preprocessor/TestPreProcessors.java | 24 +- .../nn/conf/weightnoise/TestWeightNoise.java | 6 +- .../deeplearning4j/nn/dtypes/DTypeTests.java | 86 ++-- .../nn/graph/ComputationGraphTestRNN.java | 4 +- .../nn/graph/TestCompGraphCNN.java | 87 ++-- .../nn/graph/TestCompGraphUnsupervised.java | 6 +- .../nn/graph/TestComputationGraphNetwork.java | 53 +- .../nn/graph/TestSetGetParameters.java | 4 +- .../nn/graph/TestVariableLengthTSCG.java | 15 +- .../nn/graph/graphnodes/TestGraphNodes.java | 4 +- .../deeplearning4j/nn/layers/TestDropout.java | 8 +- .../convolution/ConvDataFormatTests.java | 64 +-- .../convolution/TestConvolutionModes.java | 4 +- .../layers/custom/TestCustomActivation.java | 6 +- .../nn/layers/custom/TestCustomLayers.java | 6 +- .../objdetect/TestYolo2OutputLayer.java | 34 +- .../nn/layers/ocnn/OCNNOutputLayerTest.java | 2 +- .../pooling/GlobalPoolingMaskingTests.java | 12 +- .../layers/recurrent/RnnDataFormatTests.java | 60 +-- .../recurrent/TestLastTimeStepLayer.java | 4 +- .../recurrent/TestRecurrentWeightInit.java | 12 +- .../nn/layers/recurrent/TestRnnLayers.java | 24 +- .../nn/layers/recurrent/TestSimpleRnn.java | 6 +- .../layers/recurrent/TestTimeDistributed.java | 4 +- .../samediff/SameDiffCustomLayerTests.java | 75 +-- .../nn/layers/samediff/TestSameDiffConv.java | 10 +- .../nn/layers/samediff/TestSameDiffDense.java | 14 +- .../samediff/TestSameDiffDenseVertex.java | 6 +- .../layers/samediff/TestSameDiffLambda.java | 12 +- .../layers/samediff/TestSameDiffOutput.java | 8 +- .../TestReconstructionDistributions.java | 4 +- .../nn/layers/variational/TestVAE.java | 4 +- .../nn/misc/CloseNetworkTests.java | 12 +- .../deeplearning4j/nn/misc/TestLrChanges.java | 4 +- .../nn/misc/TestMemoryReports.java | 6 +- .../nn/misc/TestNetConversion.java | 4 +- .../nn/misc/WorkspaceTests.java | 12 +- .../nn/mkldnn/ValidateMKLDNN.java | 6 +- .../nn/multilayer/BackPropMLPTest.java | 2 +- .../nn/multilayer/MultiLayerTest.java | 3 +- .../nn/multilayer/MultiLayerTestRNN.java | 4 +- .../nn/multilayer/TestMasking.java | 4 +- .../nn/multilayer/TestSetGetParameters.java | 8 +- .../nn/multilayer/TestVariableLengthTS.java | 12 +- .../rl/TestMultiModelGradientApplication.java | 6 +- .../nn/transferlearning/TestFrozenLayers.java | 12 +- .../TestTransferLearningJson.java | 4 +- .../TestTransferLearningModelSerializer.java | 4 +- .../TransferLearningComplex.java | 8 +- .../nn/updater/TestGradientNormalization.java | 4 +- .../nn/updater/TestUpdaters.java | 12 +- .../nn/updater/custom/TestCustomUpdater.java | 6 +- .../optimize/solver/TestOptimizers.java | 17 +- .../accumulation/ThresholdAlgorithmTests.java | 4 +- .../listener/TestCheckpointListener.java | 40 +- .../listener/TestFailureListener.java | 16 +- .../optimizer/listener/TestListeners.java | 21 +- .../parallelism/FancyBlockingQueueTests.java | 4 +- ...lExistingMiniBatchDataSetIteratorTest.java | 2 +- .../parallelism/RandomTests.java | 4 +- .../perf/listener/SystemPollingTest.java | 2 +- .../perf/listener/TestHardWareMetric.java | 8 +- .../listener/TestSystemInfoPrintListener.java | 22 +- .../regressiontest/MiscRegressionTests.java | 8 +- .../regressiontest/RegressionTest050.java | 8 +- .../regressiontest/RegressionTest060.java | 4 +- .../regressiontest/RegressionTest071.java | 4 +- .../regressiontest/RegressionTest080.java | 4 +- .../regressiontest/RegressionTest100a.java | 16 +- .../regressiontest/RegressionTest100b3.java | 12 +- .../regressiontest/RegressionTest100b4.java | 18 +- .../regressiontest/RegressionTest100b6.java | 10 +- .../TestDistributionDeserializer.java | 6 +- .../CompareTrainingImplementations.java | 26 +- .../util/CrashReportingUtilTest.java | 2 +- .../deeplearning4j/util/ModelGuesserTest.java | 7 +- .../util/ModelSerializerTest.java | 2 +- .../util/ModelValidatorTests.java | 20 +- .../util/SerializationUtilsTest.java | 2 +- .../deeplearning4j/util/TestUIDProvider.java | 6 +- .../deeplearning4j/cuda/TestDataTypes.java | 4 +- .../org/deeplearning4j/cuda/TestUtils.java | 4 +- .../deeplearning4j/cuda/ValidateCuDNN.java | 6 +- .../cuda/convolution/ConvDataFormatTests.java | 4 +- .../cuda/convolution/TestConvolution.java | 8 +- .../gradientcheck/CuDNNGradientChecks.java | 6 +- .../cuda/lstm/ValidateCudnnDropout.java | 6 +- .../cuda/lstm/ValidateCudnnLSTM.java | 4 +- .../cuda/util/CuDNNValidationUtil.java | 2 +- .../graph/data/TestGraphLoading.java | 14 +- .../graph/data/TestGraphLoadingWeighted.java | 13 +- .../deeplearning4j/graph/graph/TestGraph.java | 19 +- .../deepwalk/DeepWalkGradientCheck.java | 21 +- .../graph/models/deepwalk/TestDeepWalk.java | 35 +- .../models/deepwalk/TestGraphHuffman.java | 8 +- .../solr/ltr/model/ScoringModelTest.java | 2 +- .../nn/modelimport/keras/KerasTestUtils.java | 2 +- .../nn/modelimport/keras/MiscTests.java | 63 +-- .../configurations/FullModelComparisons.java | 72 ++- .../Keras2ModelConfigurationTest.java | 8 +- .../keras/e2e/KerasCustomLayerTest.java | 2 +- .../keras/e2e/KerasCustomLossTest.java | 2 +- .../keras/e2e/KerasLambdaTest.java | 2 +- .../keras/e2e/KerasModelEndToEndTest.java | 2 +- .../keras/e2e/KerasYolo9000Test.java | 2 +- .../layers/core/KerasActivationLayer.java | 4 +- .../keras/optimizers/OptimizerImport.java | 2 +- .../TimeSeriesGeneratorImportTest.java | 6 +- .../sequence/TimeSeriesGeneratorTest.java | 4 +- .../text/TokenizerImportTest.java | 12 +- .../preprocessing/text/TokenizerTest.java | 6 +- .../weights/KerasWeightSettingTests.java | 115 ++--- .../deeplearning4j-nlp/pom.xml | 6 + .../java/org/deeplearning4j/TsneTest.java | 64 --- .../vectorizer/BagOfWordsVectorizerTest.java | 45 +- .../vectorizer/TfidfVectorizerTest.java | 133 ++--- .../iterator/TestBertIterator.java | 14 +- .../TestCnnSentenceDataSetIterator.java | 10 +- .../inmemory/InMemoryLookupTableTest.java | 39 +- .../loader/WordVectorSerializerTest.java | 33 +- .../reader/impl/FlatModelUtilsTest.java | 12 +- .../wordvectors/WordVectorsImplTest.java | 8 +- .../models/fasttext/FastTextTest.java | 115 +++-- .../ParagraphVectorsTest.java | 73 +-- .../sequencevectors/SequenceVectorsTest.java | 14 +- .../walkers/impl/PopularityWalkerTest.java | 10 +- .../graph/walkers/impl/RandomWalkerTest.java | 8 +- .../walkers/impl/WeightedWalkerTest.java | 10 +- .../AbstractElementFactoryTest.java | 8 +- .../serialization/VocabWordFactoryTest.java | 8 +- .../impl/GraphTransformerTest.java | 8 +- .../ParallelTransformerIteratorTest.java | 37 +- .../models/word2vec/Word2VecTestsSmall.java | 17 +- .../word2vec/Word2VecVisualizationTests.java | 10 +- .../iterator/Word2VecDataSetIteratorTest.java | 8 +- .../wordstore/VocabConstructorTest.java | 29 +- .../wordstore/VocabularyHolderTest.java | 4 +- .../wordstore/inmemory/AbstractCacheTest.java | 8 +- .../AsyncLabelAwareIteratorTest.java | 8 +- .../BasicLabelAwareIteratorTest.java | 15 +- .../DefaultDocumentIteratorTest.java | 4 +- .../FileDocumentIteratorTest.java | 41 +- .../FileLabelAwareIteratorTest.java | 30 +- .../FilenamesLabelAwareIteratorTest.java | 22 +- .../documentiterator/LabelsSourceTest.java | 8 +- .../AggregatingSentenceIteratorTest.java | 8 +- .../BasicLineIteratorTest.java | 15 +- .../BasicResultSetIteratorTest.java | 8 +- .../MutipleEpochsSentenceIteratorTest.java | 8 +- .../PrefetchingSentenceIteratorTest.java | 13 +- .../StreamLineIteratorTest.java | 6 +- .../BertWordPieceTokenizerTests.java | 18 +- .../tokenizer/DefaulTokenizerTests.java | 6 +- .../tokenizer/NGramTokenizerTest.java | 6 +- .../EndingPreProcessorTest.java | 4 +- .../NGramTokenizerFactoryTest.java | 4 +- .../wordstore/InMemoryVocabStoreTests.java | 4 +- .../ParameterServerParallelWrapperTest.java | 2 +- .../InplaceParallelInferenceTest.java | 4 +- .../parallelism/ParallelInferenceTest.java | 52 +- .../parallelism/ParallelWrapperTest.java | 6 +- .../parallelism/TestListeners.java | 4 +- .../TestParallelEarlyStopping.java | 4 +- .../TestParallelEarlyStoppingUI.java | 8 +- .../factory/DefaultTrainerContextTest.java | 4 +- .../factory/SymmetricTrainerContextTest.java | 4 +- .../BatchedInferenceObservableTest.java | 14 +- .../main/ParallelWrapperMainTest.java | 17 +- .../SparkSequenceVectorsTest.java | 14 +- .../export/ExportContainerTest.java | 8 +- .../models/word2vec/SparkWord2VecTest.java | 16 +- .../embeddings/word2vec/Word2VecTest.java | 27 +- .../spark/text/BaseSparkTest.java | 8 +- .../spark/text/TextPipelineTest.java | 18 +- .../spark/parameterserver/BaseSparkTest.java | 8 +- ...haredTrainingAccumulationFunctionTest.java | 8 +- .../SharedTrainingAggregateFunctionTest.java | 8 +- .../iterators/VirtualDataSetIteratorTest.java | 8 +- .../iterators/VirtualIteratorTest.java | 8 +- .../elephas/TestElephasImport.java | 4 +- .../train/GradientSharingTrainingTest.java | 44 +- .../deeplearning4j/spark/BaseSparkTest.java | 8 +- .../spark/TestEarlyStoppingSpark.java | 8 +- .../TestEarlyStoppingSparkCompGraph.java | 4 +- .../org/deeplearning4j/spark/TestKryo.java | 6 +- .../deeplearning4j/spark/common/AddTest.java | 4 +- .../spark/data/TestShuffleExamples.java | 6 +- .../spark/data/TestSparkDataUtils.java | 2 +- .../spark/datavec/MiniBatchTests.java | 6 +- .../datavec/TestDataVecDataSetFunctions.java | 34 +- .../spark/datavec/TestExport.java | 6 +- .../spark/datavec/TestPreProcessedData.java | 6 +- .../datavec/iterator/TestIteratorUtils.java | 4 +- .../spark/impl/TestKryoWarning.java | 16 +- .../repartition/BalancedPartitionerTest.java | 8 +- .../HashingBalancedPartitionerTest.java | 4 +- .../impl/customlayer/TestCustomLayer.java | 2 +- .../impl/graph/TestSparkComputationGraph.java | 18 +- .../spark/impl/misc/TestFrozenLayers.java | 12 +- .../impl/multilayer/TestMiscFunctions.java | 6 +- .../multilayer/TestSparkDl4jMultiLayer.java | 6 +- ...arameterAveragingSparkVsSingleMachine.java | 8 +- .../spark/impl/paramavg/TestJsonYaml.java | 4 +- ...TestSparkMultiLayerParameterAveraging.java | 42 +- .../impl/paramavg/util/ExportSupportTest.java | 6 +- .../stats/TestTrainingStatsCollection.java | 6 +- .../spark/time/TestTimeSource.java | 6 +- .../spark/ui/TestListeners.java | 6 +- .../spark/util/MLLIbUtilTest.java | 4 +- .../spark/util/TestRepartitioning.java | 9 +- .../spark/util/TestValidation.java | 25 +- .../ui/TestComponentSerialization.java | 4 +- .../org/deeplearning4j/ui/TestRendering.java | 6 +- .../org/deeplearning4j/ui/TestStandAlone.java | 4 +- .../ui/TestStorageMetaData.java | 4 +- .../ui/stats/TestStatsClasses.java | 159 +++--- .../ui/stats/TestStatsListener.java | 6 +- .../ui/stats/TestTransferStatsCollection.java | 6 +- .../ui/storage/TestStatsStorage.java | 66 +-- .../deeplearning4j/ui/TestRemoteReceiver.java | 14 +- .../org/deeplearning4j/ui/TestSameDiffUI.java | 25 +- .../org/deeplearning4j/ui/TestVertxUI.java | 52 +- .../deeplearning4j/ui/TestVertxUIManual.java | 24 +- .../ui/TestVertxUIMultiSession.java | 38 +- .../org/deeplearning4j/zoo/MiscTests.java | 6 +- .../org/deeplearning4j/zoo/TestDownload.java | 22 +- .../org/deeplearning4j/zoo/TestImageNet.java | 12 +- .../deeplearning4j/zoo/TestInstantiation.java | 26 +- .../org/deeplearning4j/zoo/TestUtils.java | 2 +- .../IntegrationTestBaselineGenerator.java | 2 +- .../integration/IntegrationTestRunner.java | 63 +-- .../integration/IntegrationTestsDL4J.java | 19 +- .../integration/IntegrationTestsSameDiff.java | 15 +- .../deeplearning4j/integration/TestUtils.java | 2 +- .../allocator/DeviceLocalNDArrayTests.java | 6 +- .../allocator/impl/MemoryTrackerTest.java | 2 +- .../jita/workspace/CudaWorkspaceTest.java | 4 +- .../buffer/BaseCudaDataBufferTest.java | 8 +- .../test/java/org/nd4j/OpValidationSuite.java | 7 +- .../java/org/nd4j/autodiff/TestOpMapping.java | 26 +- .../java/org/nd4j/autodiff/TestSessions.java | 8 +- .../internal/TestDependencyTracker.java | 4 +- .../opvalidation/ActivationGradChecks.java | 4 +- .../opvalidation/BaseOpValidation.java | 4 +- .../opvalidation/LayerOpValidation.java | 136 ++--- .../opvalidation/LossOpValidation.java | 10 +- .../opvalidation/MiscOpValidation.java | 26 +- .../opvalidation/RandomOpValidation.java | 16 +- .../opvalidation/ReductionBpOpValidation.java | 14 +- .../opvalidation/ReductionOpValidation.java | 32 +- .../opvalidation/RnnOpValidation.java | 4 +- .../opvalidation/ShapeOpValidation.java | 36 +- .../opvalidation/TransformOpValidation.java | 26 +- .../autodiff/samediff/ConvConfigTests.java | 6 +- .../samediff/FailingSameDiffTests.java | 10 +- .../samediff/FlatBufferSerdeTest.java | 44 +- .../samediff/GraphTransformUtilTests.java | 8 +- .../nd4j/autodiff/samediff/MemoryMgrTest.java | 4 +- .../autodiff/samediff/NameScopeTests.java | 20 +- .../samediff/SameDiffMultiThreadTests.java | 39 +- .../autodiff/samediff/SameDiffOutputTest.java | 8 +- .../SameDiffSpecifiedLossVarsTests.java | 4 +- .../nd4j/autodiff/samediff/SameDiffTests.java | 100 ++-- .../samediff/SameDiffTrainingTest.java | 12 +- .../listeners/CheckpointListenerTest.java | 35 +- .../listeners/ExecDebuggingListenerTest.java | 2 +- .../samediff/listeners/ListenerTest.java | 6 +- .../listeners/ProfilingListenerTest.java | 21 +- .../nd4j/autodiff/ui/FileReadWriteTests.java | 26 +- .../org/nd4j/autodiff/ui/UIListenerTest.java | 33 +- .../nd4j/evaluation/CustomEvaluationTest.java | 6 +- .../nd4j/evaluation/EmptyEvaluationTests.java | 24 +- .../nd4j/evaluation/EvalCustomThreshold.java | 6 +- .../org/nd4j/evaluation/EvalJsonTest.java | 6 +- .../java/org/nd4j/evaluation/EvalTest.java | 30 +- .../nd4j/evaluation/EvaluationBinaryTest.java | 14 +- .../evaluation/EvaluationCalibrationTest.java | 6 +- .../org/nd4j/evaluation/NewInstanceTest.java | 4 +- .../org/nd4j/evaluation/ROCBinaryTest.java | 20 +- .../java/org/nd4j/evaluation/ROCTest.java | 8 +- .../nd4j/evaluation/RegressionEvalTest.java | 30 +- .../evaluation/TestLegacyJsonLoading.java | 4 +- .../java/org/nd4j/imports/ByteOrderTests.java | 10 +- .../java/org/nd4j/imports/ExecutionTests.java | 8 +- .../test/java/org/nd4j/imports/NameTests.java | 4 +- .../nd4j/imports/TFGraphs/BERTGraphTest.java | 18 +- .../nd4j/imports/TFGraphs/CustomOpTests.java | 6 +- .../imports/TFGraphs/NodeReaderTests.java | 6 +- .../TFGraphs/TFGraphTestAllHelper.java | 30 +- .../TFGraphs/TFGraphTestAllLibnd4j.java | 25 +- .../TFGraphs/TFGraphTestAllSameDiff.java | 24 +- .../imports/TFGraphs/TFGraphTestList.java | 28 +- .../TFGraphs/TFGraphTestZooModels.java | 27 +- .../TFGraphs/ValidateZooModelPredictions.java | 31 +- .../nd4j/imports/TensorFlowImportTest.java | 63 +-- .../java/org/nd4j/imports/TestReverse.java | 2 +- .../listeners/ImportModelDebugger.java | 8 +- .../java/org/nd4j/linalg/AveragingTests.java | 12 +- .../java/org/nd4j/linalg/BaseNd4jTest.java | 4 +- .../java/org/nd4j/linalg/DataTypeTest.java | 6 +- .../org/nd4j/linalg/InputValidationTests.java | 4 +- .../test/java/org/nd4j/linalg/LoneTest.java | 20 +- .../test/java/org/nd4j/linalg/MmulBug.java | 4 +- .../org/nd4j/linalg/NDArrayTestsFortran.java | 60 +-- .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 481 ++++++++++-------- .../org/nd4j/linalg/Nd4jTestsComparisonC.java | 44 +- .../linalg/Nd4jTestsComparisonFortran.java | 40 +- .../test/java/org/nd4j/linalg/Nd4jTestsF.java | 4 +- .../java/org/nd4j/linalg/ShufflesTests.java | 20 +- .../test/java/org/nd4j/linalg/TestEigen.java | 12 +- .../java/org/nd4j/linalg/ToStringTest.java | 4 +- .../linalg/activations/TestActivation.java | 10 +- .../java/org/nd4j/linalg/api/TestBackend.java | 4 +- .../org/nd4j/linalg/api/TestEnvironment.java | 4 +- .../nd4j/linalg/api/TestNDArrayCreation.java | 18 +- .../linalg/api/TestNDArrayCreationUtil.java | 14 +- .../org/nd4j/linalg/api/TestNamespaces.java | 2 +- .../org/nd4j/linalg/api/blas/LapackTest.java | 4 +- .../org/nd4j/linalg/api/blas/Level1Test.java | 6 +- .../org/nd4j/linalg/api/blas/Level2Test.java | 4 +- .../org/nd4j/linalg/api/blas/Level3Test.java | 4 +- .../linalg/api/blas/params/ParamsTestsF.java | 4 +- .../linalg/api/buffer/DataBufferTests.java | 8 +- .../api/buffer/DataTypeValidationTests.java | 53 +- .../api/buffer/DoubleDataBufferTest.java | 27 +- .../api/buffer/FloatDataBufferTest.java | 41 +- .../linalg/api/buffer/IntDataBufferTests.java | 4 +- .../linalg/api/indexing/IndexingTests.java | 6 +- .../linalg/api/indexing/IndexingTestsC.java | 10 +- .../resolve/NDArrayIndexResolveTests.java | 4 +- .../api/indexing/shape/IndexShapeTests.java | 4 +- .../api/indexing/shape/IndexShapeTests2d.java | 4 +- .../api/iterator/NDIndexIteratorTest.java | 4 +- .../api/ndarray/TestNdArrReadWriteTxt.java | 25 +- .../api/ndarray/TestNdArrReadWriteTxtC.java | 13 +- .../linalg/api/ndarray/TestSerialization.java | 4 +- .../TestSerializationDoubleToFloat.java | 8 +- .../TestSerializationFloatToDouble.java | 10 +- .../org/nd4j/linalg/api/rng/RngTests.java | 4 +- .../linalg/api/string/TestFormatting.java | 2 +- .../api/tad/TestTensorAlongDimension.java | 8 +- .../java/org/nd4j/linalg/blas/BlasTests.java | 8 +- .../linalg/broadcast/BasicBroadcastTests.java | 79 +-- .../compression/CompressionMagicTests.java | 8 +- .../CompressionPerformanceTests.java | 6 +- .../compression/CompressionSerDeTests.java | 4 +- .../linalg/compression/CompressionTests.java | 10 +- .../linalg/convolution/ConvolutionTests.java | 44 +- .../linalg/convolution/ConvolutionTestsC.java | 14 +- .../nd4j/linalg/convolution/DeconvTests.java | 19 +- .../java/org/nd4j/linalg/crash/CrashTest.java | 6 +- .../org/nd4j/linalg/crash/SpecialTests.java | 28 +- .../nd4j/linalg/custom/CustomOpsTests.java | 214 ++++---- .../linalg/custom/ExpandableOpsTests.java | 6 +- .../dataset/BalanceMinibatchesTest.java | 31 +- .../dataset/CachingDataSetIteratorTest.java | 4 +- .../org/nd4j/linalg/dataset/DataSetTest.java | 35 +- .../dataset/ImagePreProcessortTest.java | 6 +- .../linalg/dataset/KFoldIteratorTest.java | 28 +- .../nd4j/linalg/dataset/MinMaxStatsTest.java | 4 +- .../MiniBatchFileDataSetIteratorTest.java | 17 +- .../nd4j/linalg/dataset/MultiDataSetTest.java | 6 +- .../dataset/MultiNormalizerHybridTest.java | 8 +- .../MultiNormalizerMinMaxScalerTest.java | 8 +- .../MultiNormalizerStandardizeTest.java | 8 +- .../dataset/NormalizerMinMaxScalerTest.java | 4 +- .../dataset/NormalizerSerializerTest.java | 70 +-- .../NormalizerStandardizeLabelsTest.java | 8 +- .../dataset/NormalizerStandardizeTest.java | 4 +- .../nd4j/linalg/dataset/NormalizerTests.java | 14 +- .../linalg/dataset/PreProcessor3D4DTest.java | 4 +- .../linalg/dataset/PreProcessorTests.java | 4 +- .../linalg/dataset/StandardScalerTest.java | 6 +- .../CompositeDataSetPreProcessorTest.java | 18 +- .../CropAndResizeDataSetPreProcessorTest.java | 64 ++- .../api/preprocessor/MinMaxStrategyTest.java | 4 +- .../PermuteDataSetPreProcessorTest.java | 17 +- ...RGBtoGrayscaleDataSetPreProcessorTest.java | 18 +- .../UnderSamplingPreProcessorTest.java | 54 +- .../dimensionalityreduction/TestPCA.java | 38 +- .../TestRandomProjection.java | 41 +- .../org/nd4j/linalg/factory/Nd4jTest.java | 34 +- .../nd4j/linalg/factory/ops/NDBaseTest.java | 4 +- .../nd4j/linalg/factory/ops/NDLossTest.java | 6 +- .../nd4j/linalg/generated/SDLinalgTest.java | 10 +- .../linalg/indexing/BooleanIndexingTest.java | 4 +- .../nd4j/linalg/indexing/TransformsTest.java | 6 +- .../linalg/inverse/TestInvertMatrices.java | 24 +- .../org/nd4j/linalg/lapack/LapackTestsC.java | 12 +- .../org/nd4j/linalg/lapack/LapackTestsF.java | 12 +- .../org/nd4j/linalg/learning/UpdaterTest.java | 4 +- .../linalg/learning/UpdaterValidation.java | 4 +- .../lossfunctions/LossFunctionJson.java | 4 +- .../lossfunctions/LossFunctionTest.java | 4 +- .../TestLossFunctionsSizeChecks.java | 2 +- .../nd4j/linalg/memory/AccountingTests.java | 8 +- .../nd4j/linalg/memory/CloseableTests.java | 28 +- .../memory/DeviceLocalNDArrayTests.java | 6 +- .../linalg/mixed/MixedDataTypesTests.java | 49 +- .../nd4j/linalg/mixed/StringArrayTests.java | 12 +- .../multithreading/MultithreadedTests.java | 4 +- .../nd4j/linalg/nativ/NativeBlasTests.java | 12 +- .../nd4j/linalg/nativ/OpsMappingTests.java | 2 +- .../org/nd4j/linalg/ops/DerivativeTests.java | 14 +- .../nd4j/linalg/ops/OpConstructorTests.java | 10 +- .../nd4j/linalg/ops/OpExecutionerTests.java | 54 +- .../nd4j/linalg/ops/OpExecutionerTestsC.java | 60 +-- .../org/nd4j/linalg/ops/RationalTanhTest.java | 4 +- .../ops/broadcast/row/RowVectorOpsC.java | 4 +- .../org/nd4j/linalg/ops/copy/CopyTest.java | 4 +- .../linalg/options/ArrayOptionsTests.java | 14 +- .../nd4j/linalg/profiling/InfNanTests.java | 64 ++- .../profiling/OperationProfilerTests.java | 125 +++-- .../profiling/PerformanceTrackerTests.java | 20 +- .../profiling/StackAggregatorTests.java | 18 +- .../java/org/nd4j/linalg/rng/HalfTests.java | 10 +- .../linalg/rng/RandomPerformanceTests.java | 4 +- .../java/org/nd4j/linalg/rng/RandomTests.java | 36 +- .../nd4j/linalg/rng/RngValidationTests.java | 10 +- .../nd4j/linalg/schedule/TestSchedules.java | 6 +- .../nd4j/linalg/serde/BasicSerDeTests.java | 6 +- .../org/nd4j/linalg/serde/JsonSerdeTests.java | 6 +- .../nd4j/linalg/serde/LargeSerDeTests.java | 12 +- .../nd4j/linalg/serde/NumpyFormatTests.java | 124 +++-- .../org/nd4j/linalg/shape/EmptyTests.java | 28 +- .../org/nd4j/linalg/shape/LongShapeTests.java | 6 +- .../nd4j/linalg/shape/NDArrayMathTests.java | 4 +- .../nd4j/linalg/shape/ShapeBufferTests.java | 4 +- .../org/nd4j/linalg/shape/ShapeTests.java | 10 +- .../org/nd4j/linalg/shape/ShapeTestsC.java | 26 +- .../nd4j/linalg/shape/StaticShapeTests.java | 6 +- .../java/org/nd4j/linalg/shape/TADTests.java | 10 +- .../nd4j/linalg/shape/concat/ConcatTests.java | 10 +- .../linalg/shape/concat/ConcatTestsC.java | 16 +- .../shape/concat/padding/PaddingTests.java | 6 +- .../shape/concat/padding/PaddingTestsC.java | 6 +- .../linalg/shape/indexing/IndexingTests.java | 22 +- .../linalg/shape/indexing/IndexingTestsC.java | 22 +- .../shape/ones/LeadingAndTrailingOnes.java | 4 +- .../shape/ones/LeadingAndTrailingOnesC.java | 4 +- .../linalg/shape/reshape/ReshapeTests.java | 6 +- .../org/nd4j/linalg/slicing/SlicingTests.java | 4 +- .../nd4j/linalg/slicing/SlicingTestsC.java | 6 +- .../org/nd4j/linalg/specials/CudaTests.java | 12 +- .../org/nd4j/linalg/specials/LongTests.java | 22 +- .../nd4j/linalg/specials/RavelIndexTest.java | 10 +- .../nd4j/linalg/specials/SortCooTests.java | 12 +- .../nd4j/linalg/util/DataSetUtilsTest.java | 22 +- .../org/nd4j/linalg/util/NDArrayUtilTest.java | 6 +- .../nd4j/linalg/util/PreconditionsTest.java | 6 +- .../java/org/nd4j/linalg/util/ShapeTest.java | 4 +- .../java/org/nd4j/linalg/util/ShapeTestC.java | 20 +- .../org/nd4j/linalg/util/TestArrayUtils.java | 6 +- .../org/nd4j/linalg/util/TestCollections.java | 6 +- .../nd4j/linalg/util/ValidationUtilTests.java | 112 ++-- .../linalg/workspace/BasicWorkspaceTests.java | 18 +- .../linalg/workspace/CudaWorkspaceTests.java | 6 +- .../workspace/CyclicWorkspaceTests.java | 6 +- .../nd4j/linalg/workspace/DebugModeTests.java | 12 +- .../workspace/EndlessWorkspaceTests.java | 16 +- .../workspace/SpecialWorkspaceTests.java | 52 +- .../workspace/WorkspaceProviderTests.java | 22 +- .../java/org/nd4j/list/NDArrayListTest.java | 4 +- .../org/nd4j/serde/base64/Nd4jBase64Test.java | 4 +- .../nd4j/serde/binary/BinarySerdeTest.java | 4 +- .../java/org/nd4j/smoketests/SmokeTest.java | 2 +- .../org/nd4j/systeminfo/TestSystemInfo.java | 2 +- .../custom/CustomOpTensorflowInteropTests.kt | 4 +- .../nd4j/common/base/TestPreconditions.java | 6 +- .../common/function/FunctionalUtilsTest.java | 4 +- .../nd4j/common/io/ClassPathResourceTest.java | 16 +- .../org/nd4j/common/loader/TestFileBatch.java | 30 +- .../nd4j/common/primitives/AtomicTest.java | 4 +- .../common/primitives/CounterMapTest.java | 4 +- .../nd4j/common/primitives/CounterTest.java | 4 +- .../common/resources/TestArchiveUtils.java | 14 +- .../nd4j/common/resources/TestStrumpf.java | 23 +- .../org/nd4j/common/tools/BToolsTest.java | 4 +- .../org/nd4j/common/tools/InfoLineTest.java | 4 +- .../org/nd4j/common/tools/InfoValuesTest.java | 4 +- .../nd4j/common/tools/PropertyParserTest.java | 20 +- .../java/org/nd4j/common/tools/SISTest.java | 21 +- .../org/nd4j/common/util/ArrayUtilTest.java | 4 +- .../nd4j/common/util/OneTimeLoggerTest.java | 6 +- .../RemoteParameterServerClientTests.java | 15 +- .../ParameterServerClientPartialTest.java | 47 +- .../client/ParameterServerClientTest.java | 39 +- .../VoidParameterServerStressTest.java | 30 +- .../distributed/VoidParameterServerTest.java | 37 +- .../conf/VoidConfigurationTest.java | 37 +- .../distributed/logic/ClipboardTest.java | 14 +- .../logic/FrameCompletionHandlerTest.java | 19 +- .../logic/routing/InterleavedRouterTest.java | 19 +- .../distributed/messages/FrameTest.java | 19 +- .../distributed/messages/VoidMessageTest.java | 16 +- .../aggregations/VoidAggregationTest.java | 14 +- .../transport/RoutedTransportTest.java | 16 +- .../util/NetworkOrganizerTest.java | 15 +- .../v2/DelayedModelParameterServerTest.java | 50 +- .../v2/ModelParameterServerTest.java | 13 +- .../v2/chunks/impl/FileChunksTrackerTest.java | 8 +- .../impl/InmemoryChunksTrackerTest.java | 8 +- .../v2/messages/VoidMessageTest.java | 4 +- .../history/HashHistoryHolderTest.java | 4 +- .../transport/impl/AeronUdpTransportTest.java | 8 +- .../v2/transport/impl/DummyTransportTest.java | 4 +- .../v2/util/MeshOrganizerTest.java | 8 +- .../v2/util/MessageSplitterTest.java | 4 +- .../node/ParameterServerNodeTest.java | 35 +- .../updater/storage/UpdaterStorageTests.java | 6 +- .../status/play/StatusServerTests.java | 6 +- .../status/play/StorageTests.java | 11 +- .../updater/ParameterServerUpdaterTests.java | 10 +- .../updater/storage/UpdaterStorageTests.java | 24 +- .../nd4j/aeron/ipc/AeronNDArraySerdeTest.java | 12 +- .../nd4j/aeron/ipc/LargeNdArrayIpcTest.java | 18 +- .../nd4j/aeron/ipc/NDArrayMessageTest.java | 8 +- .../org/nd4j/aeron/ipc/NdArrayIpcTest.java | 16 +- .../ipc/chunk/ChunkAccumulatorTests.java | 8 +- .../ipc/chunk/NDArrayMessageChunkTests.java | 10 +- .../response/AeronNDArrayResponseTest.java | 12 +- .../java/org/nd4j/arrow/ArrowSerdeTest.java | 4 +- .../org/nd4j/TestNd4jKryoSerialization.java | 20 +- .../frameworkimport/onnx/TestOnnxIR.kt | 42 +- .../importer/TestOnnxFrameworkImporter.kt | 4 +- .../onnx/modelzoo/TestPretrainedModels.kt | 6 +- .../tensorflow/TestTensorflowIR.kt | 67 +-- .../importer/TestTensorflowImporter.kt | 6 +- .../test/java/PythonBasicExecutionTest.java | 13 +- .../src/test/java/PythonCollectionsTest.java | 2 +- .../test/java/PythonContextManagerTest.java | 2 +- .../src/test/java/PythonGCTest.java | 2 +- .../src/test/java/PythonMultiThreadTest.java | 12 +- .../test/java/PythonPrimitiveTypesTest.java | 2 +- .../src/test/java/PythonNumpyBasicTest.java | 2 +- .../test/java/PythonNumpyCollectionsTest.java | 2 +- .../src/test/java/PythonNumpyGCTest.java | 2 +- .../src/test/java/PythonNumpyImportTest.java | 2 +- .../test/java/PythonNumpyMultiThreadTest.java | 2 +- .../java/PythonNumpyServiceLoaderTest.java | 2 +- rl4j/rl4j-core/pom.xml | 16 + .../rl4j/agent/AgentLearnerTest.java | 4 +- .../deeplearning4j/rl4j/agent/AgentTest.java | 16 +- .../NonRecurrentActorCriticHelperTest.java | 6 +- .../NonRecurrentAdvantageActorCriticTest.java | 8 +- .../RecurrentActorCriticHelperTest.java | 6 +- .../RecurrentAdvantageActorCriticTest.java | 8 +- .../learning/algorithm/dqn/DoubleDQNTest.java | 8 +- .../algorithm/dqn/StandardDQNTest.java | 8 +- .../NonRecurrentNStepQLearningHelperTest.java | 4 +- .../NonRecurrentNStepQLearningTest.java | 4 +- .../RecurrentNStepQLearningHelperTest.java | 6 +- .../RecurrentNStepQLearningTest.java | 4 +- .../behavior/LearningBehaviorTest.java | 10 +- .../learning/update/FeaturesBuilderTest.java | 6 +- .../learning/update/FeaturesLabelsTest.java | 4 +- .../agent/learning/update/FeaturesTest.java | 6 +- .../agent/learning/update/GradientsTest.java | 6 +- .../agent/learning/update/UpdateRuleTest.java | 8 +- .../AsyncGradientsNeuralNetUpdaterTest.java | 2 +- .../AsyncLabelsNeuralNetUpdaterTest.java | 2 +- .../AsyncSharedNetworksUpdateHandlerTest.java | 6 +- .../SyncGradientsNeuralNetUpdaterTest.java | 2 +- .../sync/SyncLabelsNeuralNetUpdaterTest.java | 6 +- .../builder/BaseAgentLearnerBuilderTest.java | 6 +- .../ReplayMemoryExperienceHandlerTest.java | 4 +- .../StateActionExperienceHandlerTest.java | 4 +- .../rl4j/helper/INDArrayHelperTest.java | 4 +- .../rl4j/learning/HistoryProcessorTest.java | 4 +- .../learning/async/AsyncLearningTest.java | 6 +- .../async/AsyncThreadDiscreteTest.java | 12 +- .../rl4j/learning/async/AsyncThreadTest.java | 8 +- ...vantageActorCriticUpdateAlgorithmTest.java | 8 +- .../AsyncTrainingListenerListTest.java | 6 +- .../QLearningUpdateAlgorithmTest.java | 4 +- .../listener/TrainingListenerListTest.java | 4 +- .../rl4j/learning/sync/ExpReplayTest.java | 4 +- .../sync/StateActionRewardStateTest.java | 4 +- .../rl4j/learning/sync/SyncLearningTest.java | 6 +- .../qlearning/QLearningConfigurationTest.java | 7 +- .../discrete/QLearningDiscreteTest.java | 10 +- .../rl4j/network/ActorCriticNetworkTest.java | 6 +- .../rl4j/network/BaseNetworkTest.java | 6 +- .../ChannelToNetworkInputMapperTest.java | 4 +- .../network/CompoundNetworkHandlerTest.java | 4 +- .../network/ComputationGraphHandlerTest.java | 4 +- .../network/MultiLayerNetworkHandlerTest.java | 4 +- .../rl4j/network/NetworkHelperTest.java | 6 +- .../rl4j/network/QNetworkTest.java | 6 +- .../rl4j/network/ac/ActorCriticTest.java | 6 +- .../rl4j/network/dqn/DQNTest.java | 4 +- .../transform/TransformProcessTest.java | 300 ++++++----- .../filter/UniformSkippingFilterTest.java | 14 +- .../ArrayToINDArrayTransformTest.java | 6 +- .../operation/HistoryMergeTransformTest.java | 4 +- .../SimpleNormalizationTransformTest.java | 15 +- .../historymerge/CircularFifoStoreTest.java | 13 +- .../HistoryStackAssemblerTest.java | 4 +- .../rl4j/policy/PolicyTest.java | 4 +- .../rl4j/trainer/AsyncTrainerTest.java | 8 +- .../rl4j/trainer/SyncTrainerTest.java | 6 +- .../util/DataManagerTrainingListenerTest.java | 6 +- .../rl4j/mdp/gym/GymEnvTest.java | 8 +- 729 files changed, 6080 insertions(+), 5619 deletions(-) delete mode 100644 deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java index 5ce4cb254..b450983dc 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java @@ -25,7 +25,7 @@ import org.datavec.api.records.reader.impl.csv.CSVLineSequenceRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java index f108a4438..59a28f4b3 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java @@ -25,7 +25,7 @@ import org.datavec.api.records.reader.impl.csv.CSVMultiSequenceRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java index e022746e0..2ce347893 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java @@ -26,7 +26,7 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; import org.datavec.api.split.InputSplit; import org.datavec.api.split.NumberedFileInputSplit; import org.datavec.api.writable.Writable; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java index 1acbf2fac..87d313ded 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java @@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader; import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader; import org.datavec.api.writable.Writable; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java index 4095d1af7..f182bc9ee 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java @@ -27,7 +27,7 @@ import org.datavec.api.split.CollectionInputSplit; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java index 2e4a2261b..aa5026207 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java @@ -30,7 +30,7 @@ import org.datavec.api.split.NumberedFileInputSplit; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java index dd81758d0..2339c47df 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java @@ -29,7 +29,7 @@ import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; import org.datavec.api.split.InputStreamInputSplit; import org.datavec.api.writable.Writable; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java index 997a6de10..dbf3ea379 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java @@ -32,7 +32,7 @@ import org.datavec.api.split.InputSplit; import org.datavec.api.split.NumberedFileInputSplit; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java index 3165c7df3..decbf0275 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java @@ -26,14 +26,14 @@ import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestCollectionRecordReaders extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestConcatenatingRecordReader.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestConcatenatingRecordReader.java index f99a36325..b39a678ce 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestConcatenatingRecordReader.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestConcatenatingRecordReader.java @@ -23,11 +23,11 @@ package org.datavec.api.records.reader.impl; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestConcatenatingRecordReader extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java index 2c40fcad0..933d65103 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java @@ -37,7 +37,7 @@ import org.datavec.api.transform.TransformProcess; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; import org.nd4j.shade.jackson.core.JsonFactory; @@ -47,7 +47,7 @@ import java.io.*; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestSerialization extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java index 7d00ef96b..ee2c9b091 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java @@ -30,7 +30,7 @@ import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.Writable; import org.joda.time.DateTimeZone; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; @@ -38,8 +38,8 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TransformProcessRecordReaderTests extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java index 0a5c7588f..74c4c3bc9 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java @@ -26,7 +26,7 @@ import org.datavec.api.io.filters.BalancedPathFilter; import org.datavec.api.io.filters.RandomPathFilter; import org.datavec.api.io.labels.ParentPathLabelGenerator; import org.datavec.api.io.labels.PatternPathLabelGenerator; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.io.*; import java.net.URI; @@ -35,7 +35,7 @@ import java.util.ArrayList; import java.util.Random; import static junit.framework.TestCase.assertTrue; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/NumberedFileInputSplitTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/NumberedFileInputSplitTests.java index 34bb06eaa..72c06bd5d 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/NumberedFileInputSplitTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/NumberedFileInputSplitTests.java @@ -20,13 +20,12 @@ package org.datavec.api.split; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import java.net.URI; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.*; public class NumberedFileInputSplitTests extends BaseND4JTest { @Test @@ -69,60 +68,81 @@ public class NumberedFileInputSplitTests extends BaseND4JTest { runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); } - @Test(expected = IllegalArgumentException.class) + @Test() public void testNumberedFileInputSplitWithLeadingSpaces() { - String baseString = "/path/to/files/prefix-%5d.suffix"; - int minIdx = 0; - int maxIdx = 10; - runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); + assertThrows(IllegalArgumentException.class,() -> { + String baseString = "/path/to/files/prefix-%5d.suffix"; + int minIdx = 0; + int maxIdx = 10; + runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); + }); + } - @Test(expected = IllegalArgumentException.class) + @Test() public void testNumberedFileInputSplitWithNoLeadingZeroInPadding() { - String baseString = "/path/to/files/prefix%5d.suffix"; - int minIdx = 0; - int maxIdx = 10; - runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); + assertThrows(IllegalArgumentException.class, () -> { + String baseString = "/path/to/files/prefix%5d.suffix"; + int minIdx = 0; + int maxIdx = 10; + runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); + }); + } - @Test(expected = IllegalArgumentException.class) + @Test() public void testNumberedFileInputSplitWithLeadingPlusInPadding() { - String baseString = "/path/to/files/prefix%+5d.suffix"; - int minIdx = 0; - int maxIdx = 10; - runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); + assertThrows(IllegalArgumentException.class,() -> { + String baseString = "/path/to/files/prefix%+5d.suffix"; + int minIdx = 0; + int maxIdx = 10; + runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); + }); + } - @Test(expected = IllegalArgumentException.class) + @Test() public void testNumberedFileInputSplitWithLeadingMinusInPadding() { - String baseString = "/path/to/files/prefix%-5d.suffix"; - int minIdx = 0; - int maxIdx = 10; - runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); + assertThrows(IllegalArgumentException.class,() -> { + String baseString = "/path/to/files/prefix%-5d.suffix"; + int minIdx = 0; + int maxIdx = 10; + runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); + }); + } - @Test(expected = IllegalArgumentException.class) + @Test() public void testNumberedFileInputSplitWithTwoDigitsInPadding() { - String baseString = "/path/to/files/prefix%011d.suffix"; - int minIdx = 0; - int maxIdx = 10; - runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); + assertThrows(IllegalArgumentException.class,() -> { + String baseString = "/path/to/files/prefix%011d.suffix"; + int minIdx = 0; + int maxIdx = 10; + runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); + }); + } - @Test(expected = IllegalArgumentException.class) + @Test() public void testNumberedFileInputSplitWithInnerZerosInPadding() { - String baseString = "/path/to/files/prefix%101d.suffix"; - int minIdx = 0; - int maxIdx = 10; - runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); + assertThrows(IllegalArgumentException.class,() -> { + String baseString = "/path/to/files/prefix%101d.suffix"; + int minIdx = 0; + int maxIdx = 10; + runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); + }); + } - @Test(expected = IllegalArgumentException.class) + @Test() public void testNumberedFileInputSplitWithRepeatInnerZerosInPadding() { - String baseString = "/path/to/files/prefix%0505d.suffix"; - int minIdx = 0; - int maxIdx = 10; - runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); + assertThrows(IllegalArgumentException.class,() -> { + String baseString = "/path/to/files/prefix%0505d.suffix"; + int minIdx = 0; + int maxIdx = 10; + runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); + }); + } @@ -135,7 +155,7 @@ public class NumberedFileInputSplitTests extends BaseND4JTest { String path = locs[j++].getPath(); String exp = String.format(baseString, i); String msg = exp + " vs " + path; - assertTrue(msg, path.endsWith(exp)); //Note: on Windows, Java can prepend drive to path - "/C:/" + assertTrue(path.endsWith(exp),msg); //Note: on Windows, Java can prepend drive to path - "/C:/" } } } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java index 45fe5d77a..fc718ed56 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java @@ -25,9 +25,10 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.function.Function; @@ -37,22 +38,22 @@ import java.io.IOException; import java.io.InputStream; import java.net.URI; import java.nio.charset.StandardCharsets; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Random; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; public class TestStreamInputSplit extends BaseND4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Test - public void testCsvSimple() throws Exception { - File dir = testDir.newFolder(); + public void testCsvSimple(@TempDir Path testDir) throws Exception { + File dir = testDir.toFile(); File f1 = new File(dir, "file1.txt"); File f2 = new File(dir, "file2.txt"); @@ -93,9 +94,9 @@ public class TestStreamInputSplit extends BaseND4JTest { @Test - public void testCsvSequenceSimple() throws Exception { + public void testCsvSequenceSimple(@TempDir Path testDir) throws Exception { - File dir = testDir.newFolder(); + File dir = testDir.toFile(); File f1 = new File(dir, "file1.txt"); File f2 = new File(dir, "file2.txt"); @@ -137,8 +138,8 @@ public class TestStreamInputSplit extends BaseND4JTest { } @Test - public void testShuffle() throws Exception { - File dir = testDir.newFolder(); + public void testShuffle(@TempDir Path testDir) throws Exception { + File dir = testDir.toFile(); File f1 = new File(dir, "file1.txt"); File f2 = new File(dir, "file2.txt"); File f3 = new File(dir, "file3.txt"); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java index 0627a5647..f9c5cb1b6 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java @@ -27,14 +27,14 @@ import org.datavec.api.split.FileSplit; import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.split.partition.PartitionMetaData; import org.datavec.api.split.partition.Partitioner; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.io.File; import java.io.OutputStream; import static junit.framework.TestCase.assertTrue; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; public class PartitionerTests extends BaseND4JTest { @Test diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java index 609ae3dc8..7a968ddfe 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java @@ -29,12 +29,12 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import java.util.*; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestTransformProcess extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java index 3cb62a4a1..f49e0c4d4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java @@ -27,13 +27,13 @@ import org.datavec.api.transform.condition.string.StringRegexColumnCondition; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.TestTransforms; import org.datavec.api.writable.*; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import java.util.*; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestConditions extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java index 5ad6d6813..0b339bffa 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java @@ -27,7 +27,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import java.util.ArrayList; @@ -36,8 +36,8 @@ import java.util.Collections; import java.util.List; import static java.util.Arrays.asList; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestFilters extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java index 044a084b6..6db425d9e 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java @@ -26,19 +26,22 @@ import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; public class TestJoin extends BaseND4JTest { @Test - public void testJoin() { + public void testJoin(@TempDir Path testDir) { Schema firstSchema = new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("first0", "first1").build(); @@ -46,20 +49,20 @@ public class TestJoin extends BaseND4JTest { Schema secondSchema = new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("second0").build(); List> first = new ArrayList<>(); - first.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1))); - first.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11))); + first.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1))); + first.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11))); List> second = new ArrayList<>(); - second.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(100))); - second.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(110))); + second.add(Arrays.asList(new Text("key0"), new IntWritable(100))); + second.add(Arrays.asList(new Text("key1"), new IntWritable(110))); Join join = new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn") .setSchemas(firstSchema, secondSchema).build(); List> expected = new ArrayList<>(); - expected.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1), + expected.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1), new IntWritable(100))); - expected.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11), + expected.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11), new IntWritable(110))); @@ -94,27 +97,31 @@ public class TestJoin extends BaseND4JTest { } - @Test(expected = IllegalArgumentException.class) + @Test() public void testJoinValidation() { + assertThrows(IllegalArgumentException.class,() -> { + Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1") + .build(); - Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1") - .build(); + Schema secondSchema = new Schema.Builder().addColumnString("keyColumn2").addColumnsInteger("second0").build(); - Schema secondSchema = new Schema.Builder().addColumnString("keyColumn2").addColumnsInteger("second0").build(); + new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1", "thisDoesntExist") + .setSchemas(firstSchema, secondSchema).build(); + }); - new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1", "thisDoesntExist") - .setSchemas(firstSchema, secondSchema).build(); } - @Test(expected = IllegalArgumentException.class) + @Test() public void testJoinValidation2() { + assertThrows(IllegalArgumentException.class,() -> { + Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1") + .build(); - Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1") - .build(); + Schema secondSchema = new Schema.Builder().addColumnString("keyColumn2").addColumnsInteger("second0").build(); - Schema secondSchema = new Schema.Builder().addColumnString("keyColumn2").addColumnsInteger("second0").build(); + new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1").setSchemas(firstSchema, secondSchema) + .build(); + }); - new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1").setSchemas(firstSchema, secondSchema) - .build(); } } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java index e7c8de557..c2549b405 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java @@ -19,17 +19,18 @@ */ package org.datavec.api.transform.ops; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.rules.ExpectedException; import org.nd4j.common.tests.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; + import org.junit.jupiter.api.DisplayName; +import static org.junit.jupiter.api.Assertions.*; + @DisplayName("Aggregator Impls Test") class AggregatorImplsTest extends BaseND4JTest { @@ -265,23 +266,25 @@ class AggregatorImplsTest extends BaseND4JTest { assertEquals(9, cu.get().toInt()); } - @Rule - public final ExpectedException exception = ExpectedException.none(); + @Test @DisplayName("Incompatible Aggregator Test") void incompatibleAggregatorTest() { - AggregatorImpls.AggregableSum sm = new AggregatorImpls.AggregableSum<>(); - for (int i = 0; i < intList.size(); i++) { - sm.accept(intList.get(i)); - } - assertEquals(45, sm.get().toInt()); - AggregatorImpls.AggregableMean reverse = new AggregatorImpls.AggregableMean<>(); - for (int i = 0; i < intList.size(); i++) { - reverse.accept(intList.get(intList.size() - i - 1)); - } - exception.expect(UnsupportedOperationException.class); - sm.combine(reverse); - assertEquals(45, sm.get().toInt()); + assertThrows(UnsupportedOperationException.class,() -> { + AggregatorImpls.AggregableSum sm = new AggregatorImpls.AggregableSum<>(); + for (int i = 0; i < intList.size(); i++) { + sm.accept(intList.get(i)); + } + assertEquals(45, sm.get().toInt()); + AggregatorImpls.AggregableMean reverse = new AggregatorImpls.AggregableMean<>(); + for (int i = 0; i < intList.size(); i++) { + reverse.accept(intList.get(intList.size() - i - 1)); + } + + sm.combine(reverse); + assertEquals(45, sm.get().toInt()); + }); + } } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java index fb24eb4dc..ec32079a7 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java @@ -32,13 +32,13 @@ import org.datavec.api.transform.ops.AggregableMultiOp; import org.datavec.api.transform.ops.IAggregableReduceOp; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.*; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import java.util.*; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; public class TestMultiOpReduce extends BaseND4JTest { @@ -46,10 +46,10 @@ public class TestMultiOpReduce extends BaseND4JTest { public void testMultiOpReducerDouble() { List> inputs = new ArrayList<>(); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(0))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(1))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(2))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(2))); + inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(0))); + inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(1))); + inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(2))); + inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(2))); Map exp = new LinkedHashMap<>(); exp.put(ReduceOp.Min, 0.0); @@ -82,7 +82,7 @@ public class TestMultiOpReduce extends BaseND4JTest { assertEquals(out.get(0), new Text("someKey")); String msg = op.toString(); - assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5); + assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg); } } @@ -126,7 +126,7 @@ public class TestMultiOpReduce extends BaseND4JTest { assertEquals(out.get(0), new Text("someKey")); String msg = op.toString(); - assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5); + assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg); } } @@ -210,7 +210,7 @@ public class TestMultiOpReduce extends BaseND4JTest { assertEquals(out.get(0), new Text("someKey")); String msg = op.toString(); - assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5); + assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg); } for (ReduceOp op : Arrays.asList(ReduceOp.Min, ReduceOp.Max, ReduceOp.Range, ReduceOp.Sum, ReduceOp.Mean, diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java index 65b1bbf0d..f7aa89170 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java @@ -24,13 +24,13 @@ import org.datavec.api.transform.ops.IAggregableReduceOp; import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestReductions extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java index 72c98d266..0f9263bb4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java @@ -22,10 +22,10 @@ package org.datavec.api.transform.schema; import org.datavec.api.transform.metadata.ColumnMetaData; import org.joda.time.DateTimeZone; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestJsonYaml extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestSchemaMethods.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestSchemaMethods.java index d2a7efbe3..1439cfc40 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestSchemaMethods.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestSchemaMethods.java @@ -21,10 +21,10 @@ package org.datavec.api.transform.schema; import org.datavec.api.transform.ColumnType; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestSchemaMethods extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java index 7b57ba4d7..1bb9ae62a 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java @@ -33,7 +33,7 @@ import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.Writable; import org.joda.time.DateTimeZone; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import java.util.ArrayList; @@ -41,7 +41,7 @@ import java.util.Arrays; import java.util.List; import java.util.concurrent.TimeUnit; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestReduceSequenceByWindowFunction extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java index 1d219214f..c26eaec61 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java @@ -27,7 +27,7 @@ import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.joda.time.DateTimeZone; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import java.util.ArrayList; @@ -35,7 +35,7 @@ import java.util.Arrays; import java.util.List; import java.util.concurrent.TimeUnit; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestSequenceSplit extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java index 6e0ff65ee..ff45a3f3e 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java @@ -29,7 +29,7 @@ import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.Writable; import org.joda.time.DateTimeZone; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import java.util.ArrayList; @@ -37,7 +37,7 @@ import java.util.Arrays; import java.util.List; import java.util.concurrent.TimeUnit; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestWindowFunctions extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestCustomTransformJsonYaml.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestCustomTransformJsonYaml.java index 32100d93f..53b63bb49 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestCustomTransformJsonYaml.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestCustomTransformJsonYaml.java @@ -26,10 +26,10 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.serde.testClasses.CustomCondition; import org.datavec.api.transform.serde.testClasses.CustomFilter; import org.datavec.api.transform.serde.testClasses.CustomTransform; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestCustomTransformJsonYaml extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestYamlJsonSerde.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestYamlJsonSerde.java index 1076126c1..84da1c272 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestYamlJsonSerde.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestYamlJsonSerde.java @@ -64,13 +64,13 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform; import org.datavec.api.writable.comparator.DoubleWritableComparator; import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeZone; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import java.util.*; import java.util.concurrent.TimeUnit; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestYamlJsonSerde extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java index 30017e62a..f7eaa85ad 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java @@ -24,12 +24,12 @@ import org.datavec.api.transform.StringReduceOp; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import java.util.*; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestReduce extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/RegressionTestJson.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/RegressionTestJson.java index e32fd5436..c6d4d3a67 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/RegressionTestJson.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/RegressionTestJson.java @@ -50,7 +50,7 @@ import org.datavec.api.writable.Text; import org.datavec.api.writable.comparator.LongWritableComparator; import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeZone; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; @@ -61,7 +61,7 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.TimeUnit; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class RegressionTestJson extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java index dec17ac40..b45e67a67 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java @@ -50,13 +50,13 @@ import org.datavec.api.writable.Text; import org.datavec.api.writable.comparator.LongWritableComparator; import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeZone; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import java.util.*; import java.util.concurrent.TimeUnit; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestJsonYaml extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java index 4a2174c0c..c42981d27 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java @@ -59,7 +59,7 @@ import org.datavec.api.writable.*; import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeZone; import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -72,7 +72,7 @@ import java.util.*; import java.util.concurrent.TimeUnit; import static junit.framework.TestCase.assertEquals; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestTransforms extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java index c05b5cabf..8c4a44687 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java @@ -29,7 +29,7 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -39,7 +39,7 @@ import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestNDArrayWritableTransforms extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestYamlJsonSerde.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestYamlJsonSerde.java index 019d03ab8..b60f4b903 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestYamlJsonSerde.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestYamlJsonSerde.java @@ -30,13 +30,13 @@ import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.serde.JsonSerializer; import org.datavec.api.transform.serde.YamlSerializer; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestYamlJsonSerde extends BaseND4JTest { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java index 8761f183f..6032e13a3 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java @@ -35,26 +35,26 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.joda.time.DateTimeZone; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; import java.io.File; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestUI extends BaseND4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); @Test - public void testUI() throws Exception { + public void testUI(@TempDir Path testDir) throws Exception { Schema schema = new Schema.Builder().addColumnString("StringColumn").addColumnInteger("IntColumn") .addColumnInteger("IntColumn2").addColumnInteger("IntColumn3") .addColumnTime("TimeColumn", DateTimeZone.UTC).build(); @@ -92,7 +92,7 @@ public class TestUI extends BaseND4JTest { DataAnalysis da = new DataAnalysis(schema, list); - File fDir = testDir.newFolder(); + File fDir = testDir.toFile(); String tempDir = fDir.getAbsolutePath(); String outPath = FilenameUtils.concat(tempDir, "datavec_transform_UITest.html"); System.out.println(outPath); @@ -143,7 +143,7 @@ public class TestUI extends BaseND4JTest { @Test - @Ignore + @Disabled public void testSequencePlot() throws Exception { Schema schema = new SequenceSchema.Builder().addColumnDouble("sinx") diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/writable/TestNDArrayWritableAndSerialization.java b/datavec/datavec-api/src/test/java/org/datavec/api/writable/TestNDArrayWritableAndSerialization.java index 11db62c57..71149b9b2 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/writable/TestNDArrayWritableAndSerialization.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/writable/TestNDArrayWritableAndSerialization.java @@ -21,14 +21,14 @@ package org.datavec.api.writable; import org.datavec.api.transform.metadata.NDArrayMetaData; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.io.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestNDArrayWritableAndSerialization extends BaseND4JTest { diff --git a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java index 23ffdb856..a0300d73c 100644 --- a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java +++ b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java @@ -41,7 +41,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.*; import org.datavec.arrow.recordreader.ArrowRecordReader; import org.datavec.arrow.recordreader.ArrowWritableRecordBatch; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; diff --git a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java index bb14ce351..a18dd11c0 100644 --- a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java +++ b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java @@ -29,16 +29,16 @@ import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.datavec.arrow.ArrowConverter; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest { @@ -69,7 +69,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest { assertEquals(3,fieldVectors.size()); for(FieldVector fieldVector : fieldVectors) { for(int i = 0; i < fieldVector.getValueCount(); i++) { - assertFalse("Index " + i + " was null for field vector " + fieldVector, fieldVector.isNull(i)); + assertFalse( fieldVector.isNull(i),"Index " + i + " was null for field vector " + fieldVector); } } @@ -79,7 +79,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest { @Test //not worried about this till after next release - @Ignore + @Disabled public void testVariableLengthTS() { Schema.Builder schema = new Schema.Builder() .addColumnString("str") diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/LabelGeneratorTest.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/LabelGeneratorTest.java index 5cdc2bf40..4ef1a2443 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/LabelGeneratorTest.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/LabelGeneratorTest.java @@ -23,7 +23,7 @@ import org.apache.commons.io.FileUtils; import org.datavec.api.io.labels.ParentPathLabelGenerator; import org.datavec.api.split.FileSplit; import org.datavec.image.recordreader.ImageRecordReader; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/LoaderTests.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/LoaderTests.java index c0ddabb43..b35b6966f 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/LoaderTests.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/LoaderTests.java @@ -22,8 +22,8 @@ package org.datavec.image.loader; import org.apache.commons.io.FilenameUtils; import org.datavec.api.records.reader.RecordReader; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.dataset.DataSet; import java.io.File; @@ -32,9 +32,9 @@ import java.io.InputStream; import java.util.List; import java.util.Random; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; /** * @@ -182,7 +182,7 @@ public class LoaderTests { } - @Ignore // Use when confirming data is getting stored + @Disabled // Use when confirming data is getting stored @Test public void testProcessCifar() { int row = 32; @@ -208,15 +208,15 @@ public class LoaderTests { int minibatch = 100; int nMinibatches = 50000 / minibatch; - for( int i=0; i()); - new ImageRecordReader().initialize(data, null); + assertThrows(IllegalArgumentException.class,() -> { + InputSplit data = new CollectionInputSplit(new ArrayList<>()); + new ImageRecordReader().initialize(data, null); + }); + } @Test - public void testMetaData() throws IOException { + public void testMetaData(@TempDir Path testDir) throws IOException { - File parentDir = testDir.newFolder(); + File parentDir = testDir.toFile(); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(parentDir); // System.out.println(f.getAbsolutePath()); // System.out.println(f.getParentFile().getParentFile().getAbsolutePath()); @@ -104,11 +107,11 @@ public class TestImageRecordReader { } @Test - public void testImageRecordReaderLabelsOrder() throws Exception { + public void testImageRecordReaderLabelsOrder(@TempDir Path testDir) throws Exception { //Labels order should be consistent, regardless of file iteration order //Idea: labels order should be consistent regardless of input file order - File f = testDir.newFolder(); + File f = testDir.toFile(); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f); File f0 = new File(f, "/class0/0.jpg"); File f1 = new File(f, "/class1/A.jpg"); @@ -135,11 +138,11 @@ public class TestImageRecordReader { @Test - public void testImageRecordReaderRandomization() throws Exception { + public void testImageRecordReaderRandomization(@TempDir Path testDir) throws Exception { //Order of FileSplit+ImageRecordReader should be different after reset //Idea: labels order should be consistent regardless of input file order - File f0 = testDir.newFolder(); + File f0 = testDir.toFile(); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0); FileSplit fs = new FileSplit(f0, new Random(12345)); @@ -189,13 +192,13 @@ public class TestImageRecordReader { @Test - public void testImageRecordReaderRegression() throws Exception { + public void testImageRecordReaderRegression(@TempDir Path testDir) throws Exception { PathLabelGenerator regressionLabelGen = new TestRegressionLabelGen(); ImageRecordReader rr = new ImageRecordReader(28, 28, 3, regressionLabelGen); - File rootDir = testDir.newFolder(); + File rootDir = testDir.toFile(); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir); FileSplit fs = new FileSplit(rootDir); rr.initialize(fs); @@ -244,10 +247,10 @@ public class TestImageRecordReader { } @Test - public void testListenerInvocationBatch() throws IOException { + public void testListenerInvocationBatch(@TempDir Path testDir) throws IOException { ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker); - File f = testDir.newFolder(); + File f = testDir.toFile(); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f); File parent = f; @@ -260,10 +263,10 @@ public class TestImageRecordReader { } @Test - public void testListenerInvocationSingle() throws IOException { + public void testListenerInvocationSingle(@TempDir Path testDir) throws IOException { ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker); - File parent = testDir.newFolder(); + File parent = testDir.toFile(); new ClassPathResource("datavec-data-image/testimages/class0/").copyDirectory(parent); int numFiles = parent.list().length; rr.initialize(new FileSplit(parent)); @@ -315,7 +318,7 @@ public class TestImageRecordReader { @Test - public void testImageRecordReaderPathMultiLabelGenerator() throws Exception { + public void testImageRecordReaderPathMultiLabelGenerator(@TempDir Path testDir) throws Exception { Nd4j.setDataType(DataType.FLOAT); //Assumption: 2 multi-class (one hot) classification labels: 2 and 3 classes respectively // PLUS single value (Writable) regression label @@ -324,7 +327,7 @@ public class TestImageRecordReader { ImageRecordReader rr = new ImageRecordReader(28, 28, 3, multiLabelGen); - File rootDir = testDir.newFolder(); + File rootDir = testDir.toFile(); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir); FileSplit fs = new FileSplit(rootDir); rr.initialize(fs); @@ -471,9 +474,9 @@ public class TestImageRecordReader { @Test - public void testNCHW_NCHW() throws Exception { + public void testNCHW_NCHW(@TempDir Path testDir) throws Exception { //Idea: labels order should be consistent regardless of input file order - File f0 = testDir.newFolder(); + File f0 = testDir.toFile(); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0); FileSplit fs0 = new FileSplit(f0, new Random(12345)); diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java index a30d42827..4c69b76cf 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java @@ -35,9 +35,10 @@ import org.datavec.image.transform.FlipImageTransform; import org.datavec.image.transform.ImageTransform; import org.datavec.image.transform.PipelineImageTransform; import org.datavec.image.transform.ResizeImageTransform; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.BooleanIndexing; @@ -46,24 +47,24 @@ import org.nd4j.common.io.ClassPathResource; import java.io.File; import java.net.URI; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestObjectDetectionRecordReader { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Test - public void test() throws Exception { + public void test(@TempDir Path testDir) throws Exception { for(boolean nchw : new boolean[]{true, false}) { ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider(); - File f = testDir.newFolder(); + File f = testDir.toFile(); new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f); String path = new File(f, "000012.jpg").getParent(); diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/objdetect/TestVocLabelProvider.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/objdetect/TestVocLabelProvider.java index 11114219a..0a4e61660 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/objdetect/TestVocLabelProvider.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/objdetect/TestVocLabelProvider.java @@ -21,27 +21,27 @@ package org.datavec.image.recordreader.objdetect; import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; import java.io.File; +import java.nio.file.Path; import java.util.Arrays; import java.util.Collections; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestVocLabelProvider { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); @Test - public void testVocLabelProvider() throws Exception { + public void testVocLabelProvider(@TempDir Path testDir) throws Exception { - File f = testDir.newFolder(); + File f = testDir.toFile(); new ClassPathResource("datavec-data-image/voc/2007/").copyDirectory(f); String path = f.getAbsolutePath(); //new ClassPathResource("voc/2007/JPEGImages/000005.jpg").getFile().getParentFile().getParent(); diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/transform/TestImageTransform.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/transform/TestImageTransform.java index 5e6a4f588..e337500ea 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/transform/TestImageTransform.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/transform/TestImageTransform.java @@ -28,8 +28,8 @@ import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.primitives.Pair; import org.datavec.image.data.ImageWritable; import org.datavec.image.loader.NativeImageLoader; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import java.awt.*; import java.util.LinkedList; @@ -40,7 +40,7 @@ import org.bytedeco.opencv.opencv_core.*; import static org.bytedeco.opencv.global.opencv_core.*; import static org.bytedeco.opencv.global.opencv_imgproc.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; /** * @@ -255,7 +255,7 @@ public class TestImageTransform { assertEquals(22, transformed[1], 0); } - @Ignore + @Disabled @Test public void testFilterImageTransform() throws Exception { ImageWritable writable = makeRandomImage(0, 0, 4); diff --git a/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/ExcelRecordWriterTest.java b/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/ExcelRecordWriterTest.java index 3d03f764e..bc706daa3 100644 --- a/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/ExcelRecordWriterTest.java +++ b/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/ExcelRecordWriterTest.java @@ -25,7 +25,7 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.primitives.Triple; diff --git a/datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/impl/JDBCRecordReaderTest.java b/datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/impl/JDBCRecordReaderTest.java index fb7daa5e9..a3ab033b6 100644 --- a/datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/impl/JDBCRecordReaderTest.java +++ b/datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/impl/JDBCRecordReaderTest.java @@ -49,7 +49,7 @@ import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.DisplayName; diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/LocalTransformProcessRecordReaderTests.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/LocalTransformProcessRecordReaderTests.java index 3464f9547..2ec96607e 100644 --- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/LocalTransformProcessRecordReaderTests.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/LocalTransformProcessRecordReaderTests.java @@ -36,14 +36,14 @@ import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.joda.time.DateTimeZone; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class LocalTransformProcessRecordReaderTests { diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java index b0489f9e0..2ed08bd95 100644 --- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java @@ -29,9 +29,9 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.writable.Writable; import org.datavec.local.transforms.AnalyzeLocal; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.io.ClassPathResource; @@ -39,12 +39,11 @@ import org.nd4j.common.io.ClassPathResource; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestAnalyzeLocal { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Test public void testAnalysisBasic() throws Exception { @@ -72,7 +71,7 @@ public class TestAnalyzeLocal { INDArray mean = arr.mean(0); INDArray std = arr.std(0); - for( int i=0; i<5; i++ ){ + for( int i = 0; i < 5; i++) { double m = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getMean(); double stddev = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getSampleStdev(); assertEquals(mean.getDouble(i), m, 1e-3); diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/functions/TestLineRecordReaderFunction.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/functions/TestLineRecordReaderFunction.java index 182642ebe..11d4672b1 100644 --- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/functions/TestLineRecordReaderFunction.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/functions/TestLineRecordReaderFunction.java @@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; import java.io.File; @@ -36,8 +36,8 @@ import java.util.List; import java.util.Set; import java.util.stream.Collectors; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestLineRecordReaderFunction { diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/functions/TestNDArrayToWritablesFunction.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/functions/TestNDArrayToWritablesFunction.java index d7d2c55f4..37a86a2f3 100644 --- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/functions/TestNDArrayToWritablesFunction.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/functions/TestNDArrayToWritablesFunction.java @@ -25,7 +25,7 @@ import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.Writable; import org.datavec.local.transforms.misc.NDArrayToWritablesFunction; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -33,7 +33,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestNDArrayToWritablesFunction { diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToNDArrayFunction.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToNDArrayFunction.java index f86dbd411..1cc2943f8 100644 --- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToNDArrayFunction.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToNDArrayFunction.java @@ -25,7 +25,7 @@ import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.Writable; import org.datavec.local.transforms.misc.WritablesToNDArrayFunction; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -33,7 +33,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestWritablesToNDArrayFunction { diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToStringFunctions.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToStringFunctions.java index bbc08d90f..fca45adb1 100644 --- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToStringFunctions.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToStringFunctions.java @@ -30,12 +30,12 @@ import org.datavec.api.writable.Writable; import org.datavec.local.transforms.misc.SequenceWritablesToStringFunction; import org.datavec.local.transforms.misc.WritablesToStringFunction; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestWritablesToStringFunctions { diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestGeoTransforms.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestGeoTransforms.java index d1b304c96..f81fdfd2e 100644 --- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestGeoTransforms.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestGeoTransforms.java @@ -32,7 +32,8 @@ import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.AfterClass; import org.junit.BeforeClass; -import org.junit.Test; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; import java.io.*; @@ -40,14 +41,14 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author saudet */ public class TestGeoTransforms { - @BeforeClass + @BeforeAll public static void beforeClass() throws Exception { //Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile(); @@ -63,7 +64,7 @@ public class TestGeoTransforms { @Test public void testCoordinatesDistanceTransform() throws Exception { Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev") - .build(); + .build(); Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|"); transform.setInputSchema(schema); @@ -72,14 +73,14 @@ public class TestGeoTransforms { assertEquals(4, out.numColumns()); assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames()); assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double), - out.getColumnTypes()); + out.getColumnTypes()); assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)), - transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10")))); + transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10")))); assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"), - new DoubleWritable(Math.sqrt(160))), - transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), - new Text("10|5")))); + new DoubleWritable(Math.sqrt(160))), + transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), + new Text("10|5")))); } @Test diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java index 2659b6136..d8b9d423b 100644 --- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java @@ -30,7 +30,8 @@ import org.datavec.local.transforms.LocalTransformExecutor; import org.datavec.api.writable.*; import org.datavec.python.PythonCondition; import org.datavec.python.PythonTransform; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -43,7 +44,7 @@ import java.util.List; import static junit.framework.TestCase.assertTrue; import static org.datavec.api.transform.schema.Schema.Builder; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @NotThreadSafe public class TestPythonTransformProcess { @@ -77,8 +78,9 @@ public class TestPythonTransformProcess { } - @Test(timeout = 60000L) - public void testMixedTypes() throws Exception{ + @Test() + @Timeout(60000L) + public void testMixedTypes() throws Exception { Builder schemaBuilder = new Builder(); schemaBuilder .addColumnInteger("col1") @@ -99,7 +101,7 @@ public class TestPythonTransformProcess { .inputSchema(initialSchema) .build() ).build(); - List inputs = Arrays.asList((Writable)new IntWritable(10), + List inputs = Arrays.asList(new IntWritable(10), new FloatWritable(3.5f), new Text("5"), new DoubleWritable(2.0) @@ -109,8 +111,9 @@ public class TestPythonTransformProcess { assertEquals(((LongWritable)outputs.get(4)).get(), 36); } - @Test(timeout = 60000L) - public void testNDArray() throws Exception{ + @Test() + @Timeout(60000L) + public void testNDArray() throws Exception { long[] shape = new long[]{3, 2}; INDArray arr1 = Nd4j.rand(shape); INDArray arr2 = Nd4j.rand(shape); @@ -145,8 +148,9 @@ public class TestPythonTransformProcess { } - @Test(timeout = 60000L) - public void testNDArray2() throws Exception{ + @Test() + @Timeout(60000L) + public void testNDArray2() throws Exception { long[] shape = new long[]{3, 2}; INDArray arr1 = Nd4j.rand(shape); INDArray arr2 = Nd4j.rand(shape); @@ -181,7 +185,8 @@ public class TestPythonTransformProcess { } - @Test(timeout = 60000L) + @Test() + @Timeout(60000L) public void testNDArrayMixed() throws Exception{ long[] shape = new long[]{3, 2}; INDArray arr1 = Nd4j.rand(DataType.DOUBLE, shape); @@ -217,7 +222,8 @@ public class TestPythonTransformProcess { } - @Test(timeout = 60000L) + @Test() + @Timeout(60000L) public void testPythonFilter() { Schema schema = new Builder().addColumnInteger("column").build(); @@ -237,8 +243,9 @@ public class TestPythonTransformProcess { } - @Test(timeout = 60000L) - public void testPythonFilterAndTransform() throws Exception{ + @Test() + @Timeout(60000L) + public void testPythonFilterAndTransform() throws Exception { Builder schemaBuilder = new Builder(); schemaBuilder .addColumnInteger("col1") diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/join/TestJoin.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/join/TestJoin.java index 6d0006b4c..adb511603 100644 --- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/join/TestJoin.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/join/TestJoin.java @@ -28,11 +28,11 @@ import org.datavec.api.writable.*; import org.datavec.local.transforms.LocalTransformExecutor; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.util.*; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestJoin { diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/rank/TestCalculateSortedRank.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/rank/TestCalculateSortedRank.java index 9061b9ba7..39f3405a9 100644 --- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/rank/TestCalculateSortedRank.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/rank/TestCalculateSortedRank.java @@ -31,13 +31,13 @@ import org.datavec.api.writable.comparator.DoubleWritableComparator; import org.datavec.local.transforms.LocalTransformExecutor; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestCalculateSortedRank { diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/sequence/TestConvertToSequence.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/sequence/TestConvertToSequence.java index 2659e3701..04a4a5c47 100644 --- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/sequence/TestConvertToSequence.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/sequence/TestConvertToSequence.java @@ -31,14 +31,14 @@ import org.datavec.api.writable.Writable; import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch; import org.datavec.local.transforms.LocalTransformExecutor; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.util.Arrays; import java.util.Collections; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestConvertToSequence { diff --git a/datavec/datavec-spark/pom.xml b/datavec/datavec-spark/pom.xml index 431fa2233..27648bdfe 100644 --- a/datavec/datavec-spark/pom.xml +++ b/datavec/datavec-spark/pom.xml @@ -41,6 +41,12 @@ + + com.tdunning + t-digest + 3.2 + test + org.scala-lang scala-library diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/TestKryoSerialization.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/TestKryoSerialization.java index c63aafc1c..a684fb61d 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/TestKryoSerialization.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/TestKryoSerialization.java @@ -25,15 +25,15 @@ import org.apache.spark.serializer.SerializerInstance; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; import java.io.File; import java.nio.ByteBuffer; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestKryoSerialization extends BaseSparkTest { diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestLineRecordReaderFunction.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestLineRecordReaderFunction.java index 8e9a7150b..d7a906597 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestLineRecordReaderFunction.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestLineRecordReaderFunction.java @@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Writable; import org.datavec.spark.BaseSparkTest; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; import java.io.File; @@ -35,8 +35,8 @@ import java.util.HashSet; import java.util.List; import java.util.Set; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestLineRecordReaderFunction extends BaseSparkTest { diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestNDArrayToWritablesFunction.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestNDArrayToWritablesFunction.java index 2b5ebf12f..4990cfe03 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestNDArrayToWritablesFunction.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestNDArrayToWritablesFunction.java @@ -24,7 +24,7 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.Writable; import org.datavec.spark.transform.misc.NDArrayToWritablesFunction; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -32,7 +32,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestNDArrayToWritablesFunction { diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestPairSequenceRecordReaderBytesFunction.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestPairSequenceRecordReaderBytesFunction.java index 64df3e679..3ce9afe46 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestPairSequenceRecordReaderBytesFunction.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestPairSequenceRecordReaderBytesFunction.java @@ -38,9 +38,10 @@ import org.datavec.spark.functions.pairdata.PairSequenceRecordReaderBytesFunctio import org.datavec.spark.functions.pairdata.PathToKeyConverter; import org.datavec.spark.functions.pairdata.PathToKeyConverterFilename; import org.datavec.spark.util.DataVecSparkUtil; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; import scala.Tuple2; @@ -50,16 +51,13 @@ import java.nio.file.Path; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; public class TestPairSequenceRecordReaderBytesFunction extends BaseSparkTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - @Test - public void test() throws Exception { + public void test(@TempDir Path testDir) throws Exception { //Goal: combine separate files together into a hadoop sequence file, for later parsing by a SequenceRecordReader //For example: use to combine input and labels data from separate files for training a RNN if(Platform.isWindows()) { @@ -67,7 +65,7 @@ public class TestPairSequenceRecordReaderBytesFunction extends BaseSparkTest { } JavaSparkContext sc = getContext(); - File f = testDir.newFolder(); + File f = testDir.toFile(); new ClassPathResource("datavec-spark/video/").copyDirectory(f); String path = f.getAbsolutePath() + "/*"; diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestRecordReaderBytesFunction.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestRecordReaderBytesFunction.java index d917d6e3e..0ccd52a35 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestRecordReaderBytesFunction.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestRecordReaderBytesFunction.java @@ -36,9 +36,10 @@ import org.datavec.image.recordreader.ImageRecordReader; import org.datavec.spark.BaseSparkTest; import org.datavec.spark.functions.data.FilesAsBytesFunction; import org.datavec.spark.functions.data.RecordReaderBytesFunction; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; import java.io.File; @@ -48,23 +49,22 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; public class TestRecordReaderBytesFunction extends BaseSparkTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Test - public void testRecordReaderBytesFunction() throws Exception { + public void testRecordReaderBytesFunction(@TempDir Path testDir) throws Exception { if(Platform.isWindows()) { return; } JavaSparkContext sc = getContext(); //Local file path - File f = testDir.newFolder(); + File f = testDir.toFile(); new ClassPathResource("datavec-spark/imagetest/").copyDirectory(f); List labelsList = Arrays.asList("0", "1"); //Need this for Spark: can't infer without init call String path = f.getAbsolutePath() + "/*"; diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestRecordReaderFunction.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestRecordReaderFunction.java index 63a8b8e3e..8d4090096 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestRecordReaderFunction.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestRecordReaderFunction.java @@ -31,30 +31,29 @@ import org.datavec.api.writable.ArrayWritable; import org.datavec.api.writable.Writable; import org.datavec.image.recordreader.ImageRecordReader; import org.datavec.spark.BaseSparkTest; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; import java.io.File; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; public class TestRecordReaderFunction extends BaseSparkTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - @Test - public void testRecordReaderFunction() throws Exception { + public void testRecordReaderFunction(@TempDir Path testDir) throws Exception { if(Platform.isWindows()) { return; } - File f = testDir.newFolder(); + File f = testDir.toFile(); new ClassPathResource("datavec-spark/imagetest/").copyDirectory(f); List labelsList = Arrays.asList("0", "1"); //Need this for Spark: can't infer without init call diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderBytesFunction.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderBytesFunction.java index 44d45001d..550f0f12a 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderBytesFunction.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderBytesFunction.java @@ -36,9 +36,10 @@ import org.datavec.codec.reader.CodecRecordReader; import org.datavec.spark.BaseSparkTest; import org.datavec.spark.functions.data.FilesAsBytesFunction; import org.datavec.spark.functions.data.SequenceRecordReaderBytesFunction; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; import java.io.File; @@ -47,21 +48,20 @@ import java.nio.file.Path; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; public class TestSequenceRecordReaderBytesFunction extends BaseSparkTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Test - public void testRecordReaderBytesFunction() throws Exception { + public void testRecordReaderBytesFunction(@TempDir Path testDir) throws Exception { if(Platform.isWindows()) { return; } //Local file path - File f = testDir.newFolder(); + File f = testDir.toFile(); new ClassPathResource("datavec-spark/video/").copyDirectory(f); String path = f.getAbsolutePath() + "/*"; diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderFunction.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderFunction.java index b48360b64..48062a036 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderFunction.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderFunction.java @@ -33,28 +33,29 @@ import org.datavec.api.writable.ArrayWritable; import org.datavec.api.writable.Writable; import org.datavec.codec.reader.CodecRecordReader; import org.datavec.spark.BaseSparkTest; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; import java.io.File; +import java.nio.file.Path; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; public class TestSequenceRecordReaderFunction extends BaseSparkTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Test - public void testSequenceRecordReaderFunctionCSV() throws Exception { + public void testSequenceRecordReaderFunctionCSV(@TempDir Path testDir) throws Exception { JavaSparkContext sc = getContext(); - File f = testDir.newFolder(); + File f = testDir.toFile(); new ClassPathResource("datavec-spark/csvsequence/").copyDirectory(f); String path = f.getAbsolutePath() + "/*"; @@ -120,10 +121,10 @@ public class TestSequenceRecordReaderFunction extends BaseSparkTest { @Test - public void testSequenceRecordReaderFunctionVideo() throws Exception { + public void testSequenceRecordReaderFunctionVideo(@TempDir Path testDir) throws Exception { JavaSparkContext sc = getContext(); - File f = testDir.newFolder(); + File f = testDir.toFile(); new ClassPathResource("datavec-spark/video/").copyDirectory(f); String path = f.getAbsolutePath() + "/*"; diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestWritablesToNDArrayFunction.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestWritablesToNDArrayFunction.java index 964d8de54..62021a252 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestWritablesToNDArrayFunction.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestWritablesToNDArrayFunction.java @@ -22,7 +22,7 @@ package org.datavec.spark.functions; import org.datavec.api.writable.*; import org.datavec.spark.transform.misc.WritablesToNDArrayFunction; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -30,7 +30,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestWritablesToNDArrayFunction { diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestWritablesToStringFunctions.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestWritablesToStringFunctions.java index 2ee8d4c78..070bda4ed 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestWritablesToStringFunctions.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestWritablesToStringFunctions.java @@ -29,14 +29,14 @@ import org.datavec.api.writable.Writable; import org.datavec.spark.BaseSparkTest; import org.datavec.spark.transform.misc.SequenceWritablesToStringFunction; import org.datavec.spark.transform.misc.WritablesToStringFunction; -import org.junit.Test; +import org.junit.jupiter.api.Test; import scala.Tuple2; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestWritablesToStringFunctions extends BaseSparkTest { diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java index 6366703a7..f3964af6d 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java @@ -26,7 +26,7 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.datavec.api.writable.*; import org.datavec.spark.BaseSparkTest; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.factory.Nd4j; import java.io.File; @@ -35,8 +35,8 @@ import java.util.Arrays; import java.util.List; import java.util.Map; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestSparkStorageUtils extends BaseSparkTest { diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/DataFramesTests.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/DataFramesTests.java index f62e5b08a..62237f0b4 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/DataFramesTests.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/DataFramesTests.java @@ -30,13 +30,13 @@ import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Writable; import org.datavec.spark.BaseSparkTest; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.util.*; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class DataFramesTests extends BaseSparkTest { diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java index 61ebfcb6b..61a7c59be 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java @@ -28,7 +28,7 @@ import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Writable; import org.datavec.spark.BaseSparkTest; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -41,7 +41,7 @@ import java.util.ArrayList; import java.util.List; import static junit.framework.TestCase.assertTrue; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class NormalizationTests extends BaseSparkTest { diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java index 1516bfe87..4fc4f3323 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java @@ -38,7 +38,7 @@ import org.datavec.local.transforms.AnalyzeLocal; import org.datavec.spark.BaseSparkTest; import org.datavec.spark.transform.AnalyzeSpark; import org.joda.time.DateTimeZone; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.io.ClassPathResource; @@ -47,7 +47,7 @@ import java.io.File; import java.nio.file.Files; import java.util.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestAnalysis extends BaseSparkTest { diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/join/TestJoin.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/join/TestJoin.java index 4625ecea0..29da7a0a4 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/join/TestJoin.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/join/TestJoin.java @@ -27,11 +27,11 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.*; import org.datavec.spark.BaseSparkTest; import org.datavec.spark.transform.SparkTransformExecutor; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.util.*; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestJoin extends BaseSparkTest { diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/rank/TestCalculateSortedRank.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/rank/TestCalculateSortedRank.java index 13265df69..6ff564418 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/rank/TestCalculateSortedRank.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/rank/TestCalculateSortedRank.java @@ -30,13 +30,13 @@ import org.datavec.api.writable.Writable; import org.datavec.api.writable.comparator.DoubleWritableComparator; import org.datavec.spark.BaseSparkTest; import org.datavec.spark.transform.SparkTransformExecutor; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestCalculateSortedRank extends BaseSparkTest { diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/sequence/TestConvertToSequence.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/sequence/TestConvertToSequence.java index b98771858..7faca7235 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/sequence/TestConvertToSequence.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/sequence/TestConvertToSequence.java @@ -29,14 +29,14 @@ import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.datavec.spark.BaseSparkTest; import org.datavec.spark.transform.SparkTransformExecutor; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.util.Arrays; import java.util.Collections; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestConvertToSequence extends BaseSparkTest { diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/util/TestSparkUtil.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/util/TestSparkUtil.java index c9546f5b8..a2dd04ce0 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/util/TestSparkUtil.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/util/TestSparkUtil.java @@ -28,7 +28,7 @@ import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.datavec.spark.BaseSparkTest; import org.datavec.spark.transform.utils.SparkUtils; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.io.File; import java.io.FileInputStream; @@ -36,7 +36,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestSparkUtil extends BaseSparkTest { diff --git a/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java index 98c0e328b..8baa8cc6c 100644 --- a/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java +++ b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java @@ -20,7 +20,6 @@ package org.deeplearning4j; import ch.qos.logback.classic.LoggerContext; -import lombok.extern.slf4j.Slf4j; import org.bytedeco.javacpp.Pointer; import org.junit.jupiter.api.*; @@ -32,6 +31,7 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.profiler.ProfilerConfig; import org.slf4j.ILoggerFactory; +import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.lang.management.ManagementFactory; import java.util.List; @@ -39,13 +39,11 @@ import java.util.Map; import java.util.Properties; import static org.junit.jupiter.api.Assumptions.assumeTrue; -import org.junit.jupiter.api.extension.ExtendWith; -@Slf4j @DisplayName("Base DL 4 J Test") public abstract class BaseDL4JTest { - + private static Logger log = LoggerFactory.getLogger(BaseDL4JTest.class.getName()); protected long startTime; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java index d35527076..dd9b3af6a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java @@ -43,7 +43,7 @@ import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class LayerHelperValidationUtil { @@ -145,7 +145,7 @@ public class LayerHelperValidationUtil { System.out.println(p1); System.out.println(p2); } - assertTrue(s + " - param changed during forward pass: " + p, maxRE < t.getMaxRelError()); + assertTrue(maxRE < t.getMaxRelError(),s + " - param changed during forward pass: " + p); } for( int i=0; i max relative error = " + t.getMaxRelError(), - maxRE < t.getMaxRelError()); + assertTrue(maxRE < t.getMaxRelError(), + t.getTestName() + " - Gradients are not equal: " + p + " - highest relative error = " + maxRE + " > max relative error = " + t.getMaxRelError()); } } @@ -283,7 +283,7 @@ public class LayerHelperValidationUtil { double d2 = listNew.get(j); double re = relError(d1, d2); String msg = "Scores at iteration " + j + " - relError = " + re + ", score1 = " + d1 + ", score2 = " + d2; - assertTrue(msg, re < t.getMaxRelError()); + assertTrue(re < t.getMaxRelError(), msg); System.out.println("j=" + j + ", d1 = " + d1 + ", d2 = " + d2); } } @@ -315,7 +315,7 @@ public class LayerHelperValidationUtil { try { if (keepAndAssertPresent) { Object o = f.get(l); - assertNotNull("Expect helper to be present for layer: " + l.getClass(), o); + assertNotNull(o,"Expect helper to be present for layer: " + l.getClass()); } else { f.set(l, null); Integer i = map.get(l.getClass()); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java index 50fab2c70..493aab237 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java @@ -26,8 +26,8 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.common.resources.Resources; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -38,7 +38,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.nio.file.Files; import java.util.concurrent.CountDownLatch; -@Ignore +@Disabled public class RandomTests extends BaseDL4JTest { @Test diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java index 7b1944802..f1e12d123 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java @@ -50,8 +50,8 @@ import java.lang.reflect.Field; import java.util.List; import java.util.Random; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; public class TestUtils { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java index 29041b5f5..575ff2ec6 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java @@ -27,7 +27,7 @@ import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; -import org.junit.rules.Timeout; + import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.dataset.DataSet; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java index 725bb402f..b6978c969 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java @@ -23,7 +23,7 @@ package org.deeplearning4j.datasets; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.fetchers.Cifar10Fetcher; import org.deeplearning4j.datasets.fetchers.TinyImageNetFetcher; -import org.junit.Test; +import org.junit.jupiter.api.Test; public class TestDataSets extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java index 966830a86..47558926a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java @@ -19,7 +19,7 @@ */ package org.deeplearning4j.datasets.datavec; -import org.junit.rules.Timeout; + import org.nd4j.shade.guava.io.Files; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; @@ -47,7 +47,7 @@ import org.deeplearning4j.datasets.datavec.exception.ZeroLengthSequenceException import org.deeplearning4j.datasets.datavec.tools.SpecialImageRecordReader; import org.nd4j.linalg.dataset.AsyncDataSetIterator; import org.junit.jupiter.api.Disabled; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.api.buffer.DataType; @@ -74,9 +74,6 @@ import static org.junit.jupiter.api.Assertions.assertThrows; @DisplayName("Record Reader Data Setiterator Test") class RecordReaderDataSetiteratorTest extends BaseDL4JTest { - @Rule - public Timeout timeout = Timeout.seconds(300); - @Override public DataType getDataType() { return DataType.FLOAT; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java index 507d80e9e..95049bcbe 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java @@ -19,7 +19,7 @@ */ package org.deeplearning4j.datasets.datavec; -import org.junit.rules.Timeout; + import org.nd4j.shade.guava.io.Files; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; @@ -44,7 +44,7 @@ import org.datavec.api.writable.Writable; import org.datavec.image.recordreader.ImageRecordReader; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.api.ndarray.INDArray; @@ -73,8 +73,7 @@ class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { @TempDir public Path temporaryFolder; - @Rule - public Timeout timeout = Timeout.seconds(300); + @Test @DisplayName("Tests Basic") diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java index 7a59ae012..d66ef3d46 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java @@ -20,9 +20,9 @@ package org.deeplearning4j.datasets.fetchers; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Rule; + import org.junit.jupiter.api.Test; -import org.junit.rules.Timeout; + import java.io.File; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assumptions.assumeTrue; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/CombinedPreProcessorTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/CombinedPreProcessorTests.java index 99f302f64..36ac3c338 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/CombinedPreProcessorTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/CombinedPreProcessorTests.java @@ -21,14 +21,14 @@ package org.deeplearning4j.datasets.iterator; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor; import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerMinMaxScaler; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class CombinedPreProcessorTests extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetSplitterTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetSplitterTests.java index 6ac61086e..f62037f6d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetSplitterTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetSplitterTests.java @@ -23,7 +23,7 @@ package org.deeplearning4j.datasets.iterator; import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; @@ -32,7 +32,7 @@ import java.util.Collections; import java.util.List; import java.util.Random; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class DataSetSplitterTests extends BaseDL4JTest { @Test @@ -54,7 +54,7 @@ public class DataSetSplitterTests extends BaseDL4JTest { while (train.hasNext()) { val data = train.next().getFeatures(); - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); + assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); gcntTrain++; global++; } @@ -64,7 +64,7 @@ public class DataSetSplitterTests extends BaseDL4JTest { while (test.hasNext()) { val data = test.next().getFeatures(); - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); + assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); gcntTest++; global++; } @@ -94,7 +94,7 @@ public class DataSetSplitterTests extends BaseDL4JTest { while (train.hasNext()) { val data = train.next().getFeatures(); - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); + assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); gcntTrain++; global++; } @@ -104,7 +104,7 @@ public class DataSetSplitterTests extends BaseDL4JTest { if (e % 2 == 0) while (test.hasNext()) { val data = test.next().getFeatures(); - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); + assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); gcntTest++; global++; } @@ -113,46 +113,50 @@ public class DataSetSplitterTests extends BaseDL4JTest { assertEquals(700 * numEpochs + (300 * numEpochs / 2), global); } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testSplitter_3() throws Exception { - val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); + assertThrows(ND4JIllegalStateException.class, () -> { + val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); - val splitter = new DataSetIteratorSplitter(back, 1000, 0.7); + val splitter = new DataSetIteratorSplitter(back, 1000, 0.7); - val train = splitter.getTrainIterator(); - val test = splitter.getTestIterator(); - val numEpochs = 10; + val train = splitter.getTrainIterator(); + val test = splitter.getTestIterator(); + val numEpochs = 10; - int gcntTrain = 0; - int gcntTest = 0; - int global = 0; - // emulating epochs here - for (int e = 0; e < numEpochs; e++) { - int cnt = 0; - while (train.hasNext()) { - val data = train.next().getFeatures(); + int gcntTrain = 0; + int gcntTest = 0; + int global = 0; + // emulating epochs here + for (int e = 0; e < numEpochs; e++) { + int cnt = 0; + while (train.hasNext()) { + val data = train.next().getFeatures(); - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); - gcntTrain++; - global++; - } + assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); + gcntTrain++; + global++; + } - train.reset(); + train.reset(); - while (test.hasNext()) { - val data = test.next().getFeatures(); - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); - gcntTest++; - global++; - } + while (test.hasNext()) { + val data = test.next().getFeatures(); + assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); + gcntTest++; + global++; + } + + // shifting underlying iterator by one + train.hasNext(); + back.shift(); + } + + assertEquals(1000 * numEpochs, global); + }); - // shifting underlying iterator by one - train.hasNext(); - back.shift(); - } - assertEquals(1000 * numEpochs, global); } @Test @@ -172,8 +176,7 @@ public class DataSetSplitterTests extends BaseDL4JTest { partIterator.reset(); while (partIterator.hasNext()) { val data = partIterator.next().getFeatures(); - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, - (float) perEpoch, data.getFloat(0), 1e-5); + assertEquals((float) perEpoch, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); //gcntTrain++; global++; cnt++; @@ -206,8 +209,7 @@ public class DataSetSplitterTests extends BaseDL4JTest { int cnt = 0; val data = partIterator.next().getFeatures(); - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, - (float) perEpoch, data.getFloat(0), 1e-5); + assertEquals((float) perEpoch, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); //gcntTrain++; global++; cnt++; @@ -247,10 +249,10 @@ public class DataSetSplitterTests extends BaseDL4JTest { val ds = trainIter.next(); assertNotNull(ds); - assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f); + assertEquals(globalIter, ds.getFeatures().getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]"); globalIter++; } - assertTrue("Failed at epoch [" + e + "]", trained); + assertTrue(trained,"Failed at epoch [" + e + "]"); assertEquals(800, globalIter); @@ -262,10 +264,10 @@ public class DataSetSplitterTests extends BaseDL4JTest { val ds = testIter.next(); assertNotNull(ds); - assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f); + assertEquals(globalIter, ds.getFeatures().getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]"); globalIter++; } - assertTrue("Failed at epoch [" + e + "]", tested); + assertTrue(tested,"Failed at epoch [" + e + "]"); assertEquals(900, globalIter); // validation set is used every 5 epochs @@ -277,10 +279,10 @@ public class DataSetSplitterTests extends BaseDL4JTest { val ds = validationIter.next(); assertNotNull(ds); - assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f); + assertEquals(globalIter, ds.getFeatures().getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]"); globalIter++; } - assertTrue("Failed at epoch [" + e + "]", validated); + assertTrue(validated,"Failed at epoch [" + e + "]"); } // all 3 iterators have exactly 1000 elements combined @@ -312,7 +314,7 @@ public class DataSetSplitterTests extends BaseDL4JTest { int farCnt = (1000 / 2) * (partNumber) + cnt; val data = iteratorList.get(partNumber).next().getFeatures(); - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) farCnt, data.getFloat(0), 1e-5); + assertEquals((float) farCnt, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); cnt++; global++; } @@ -322,7 +324,7 @@ public class DataSetSplitterTests extends BaseDL4JTest { while (iteratorList.get(0).hasNext()) { val data = iteratorList.get(0).next().getFeatures(); - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); + assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); global++; } } @@ -341,7 +343,7 @@ public class DataSetSplitterTests extends BaseDL4JTest { while (iteratorList.get(partNumber).hasNext()) { val data = iteratorList.get(partNumber).next().getFeatures(); - assertEquals("Train failed on iteration " + cnt, (float) (500*partNumber + cnt), data.getFloat(0), 1e-5); + assertEquals( (float) (500*partNumber + cnt), data.getFloat(0), 1e-5,"Train failed on iteration " + cnt); cnt++; } } @@ -365,7 +367,7 @@ public class DataSetSplitterTests extends BaseDL4JTest { while (iteratorList.get(partNumber).hasNext()) { val data = iteratorList.get(partNumber).next().getFeatures(); - assertEquals("Train failed on iteration " + cnt, (float) (500*partNumber + cnt), data.getFloat(0), 1e-5); + assertEquals( (float) (500*partNumber + cnt), data.getFloat(0), 1e-5,"Train failed on iteration " + cnt); cnt++; } } @@ -390,7 +392,7 @@ public class DataSetSplitterTests extends BaseDL4JTest { val ds = validationIter.next(); assertNotNull(ds); - assertEquals("Validation failed on iteration " + valCnt, (float) valCnt + 90, ds.getFeatures().getFloat(0), 1e-5); + assertEquals((float) valCnt + 90, ds.getFeatures().getFloat(0), 1e-5,"Validation failed on iteration " + valCnt); valCnt++; } assertEquals(5, valCnt); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DummyBlockDataSetIteratorTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DummyBlockDataSetIteratorTests.java index 228f54b5e..96efd4728 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DummyBlockDataSetIteratorTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DummyBlockDataSetIteratorTests.java @@ -25,15 +25,15 @@ import lombok.val; import lombok.var; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.tools.SimpleVariableGenerator; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.dataset.api.DataSet; import java.util.ArrayList; import java.util.Arrays; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j public class DummyBlockDataSetIteratorTests extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java index 40f2d8abe..0fe9528b8 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java @@ -21,7 +21,7 @@ package org.deeplearning4j.datasets.iterator; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.rules.ExpectedException; import org.nd4j.linalg.dataset.DataSet; @@ -43,8 +43,7 @@ class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest { int numExamples = 105; - @Rule - public final ExpectedException exception = ExpectedException.none(); + @Test @DisplayName("Test Next And Reset") @@ -86,14 +85,16 @@ class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest { } @Test - @DisplayName("Test Callsto Next Not Allowed") + @DisplayName("Test calls to Next Not Allowed") void testCallstoNextNotAllowed() throws IOException { - int terminateAfter = 1; - DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples); - EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter); - earlyEndIter.next(10); - iter.reset(); - exception.expect(RuntimeException.class); - earlyEndIter.next(10); + assertThrows(RuntimeException.class,() -> { + int terminateAfter = 1; + DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples); + EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter); + earlyEndIter.next(10); + iter.reset(); + earlyEndIter.next(10); + }); + } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIteratorTest.java index 06b55bfcb..6a953278b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIteratorTest.java @@ -21,7 +21,7 @@ package org.deeplearning4j.datasets.iterator; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.rules.ExpectedException; import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; @@ -30,11 +30,12 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import java.io.IOException; import java.util.ArrayList; import java.util.List; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; + import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.extension.ExtendWith; +import static org.junit.jupiter.api.Assertions.*; + @DisplayName("Early Termination Multi Data Set Iterator Test") class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest { @@ -42,8 +43,7 @@ class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest { int numExamples = 105; - @Rule - public final ExpectedException exception = ExpectedException.none(); + @Test @DisplayName("Test Next And Reset") @@ -91,14 +91,16 @@ class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest { } @Test - @DisplayName("Test Callsto Next Not Allowed") + @DisplayName("Test calls to Next Not Allowed") void testCallstoNextNotAllowed() throws IOException { - int terminateAfter = 1; - MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples)); - EarlyTerminationMultiDataSetIterator earlyEndIter = new EarlyTerminationMultiDataSetIterator(iter, terminateAfter); - earlyEndIter.next(10); - iter.reset(); - exception.expect(RuntimeException.class); - earlyEndIter.next(10); + assertThrows(RuntimeException.class,() -> { + int terminateAfter = 1; + MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples)); + EarlyTerminationMultiDataSetIterator earlyEndIter = new EarlyTerminationMultiDataSetIterator(iter, terminateAfter); + earlyEndIter.next(10); + iter.reset(); + earlyEndIter.next(10); + }); + } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/JointMultiDataSetIteratorTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/JointMultiDataSetIteratorTests.java index c4f66a5ac..c5871ee60 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/JointMultiDataSetIteratorTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/JointMultiDataSetIteratorTests.java @@ -24,14 +24,16 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class JointMultiDataSetIteratorTests extends BaseDL4JTest { - @Test (timeout = 20000L) + @Test () + @Timeout(20000L) public void testJMDSI_1() { val iter0 = new DataSetGenerator(32, new int[]{3, 3}, new int[]{2, 2}); val iter1 = new DataSetGenerator(32, new int[]{3, 3, 3}, new int[]{2, 2, 2}); @@ -75,7 +77,8 @@ public class JointMultiDataSetIteratorTests extends BaseDL4JTest { } - @Test (timeout = 20000L) + @Test () + @Timeout(20000L) public void testJMDSI_2() { val iter0 = new DataSetGenerator(32, new int[]{3, 3}, new int[]{2, 2}); val iter1 = new DataSetGenerator(32, new int[]{3, 3, 3}, new int[]{2, 2, 2}); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/LoaderIteratorTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/LoaderIteratorTests.java index f35780127..da97c5cd7 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/LoaderIteratorTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/LoaderIteratorTests.java @@ -23,7 +23,7 @@ package org.deeplearning4j.datasets.iterator; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.loader.DataSetLoaderIterator; import org.deeplearning4j.datasets.iterator.loader.MultiDataSetLoaderIterator; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.loader.Loader; import org.nd4j.common.loader.LocalFileSourceFactory; import org.nd4j.common.loader.Source; @@ -39,8 +39,8 @@ import java.util.Arrays; import java.util.List; import java.util.Random; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class LoaderIteratorTests extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultiDataSetSplitterTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultiDataSetSplitterTests.java index ec23dc31a..d8e2a39dd 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultiDataSetSplitterTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultiDataSetSplitterTests.java @@ -24,7 +24,7 @@ import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator; import org.deeplearning4j.datasets.iterator.tools.MultiDataSetGenerator; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -32,7 +32,7 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException; import java.util.List; import java.util.Random; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class MultiDataSetSplitterTests extends BaseDL4JTest { @@ -55,7 +55,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { while (train.hasNext()) { val data = train.next().getFeatures(0); - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); + assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); gcntTrain++; global++; } @@ -65,7 +65,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { while (test.hasNext()) { val data = test.next().getFeatures(0); - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); + assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); gcntTest++; global++; } @@ -96,7 +96,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { while (train.hasNext()) { val data = train.next().getFeatures(0); - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); + assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); gcntTrain++; global++; } @@ -106,7 +106,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { if (e % 2 == 0) while (test.hasNext()) { val data = test.next().getFeatures(0); - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); + assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); gcntTest++; global++; } @@ -115,46 +115,49 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { assertEquals(700 * numEpochs + (300 * numEpochs / 2), global); } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testSplitter_3() throws Exception { - val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); + assertThrows(ND4JIllegalStateException.class,() -> { + val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); - val splitter = new MultiDataSetIteratorSplitter(back, 1000, 0.7); + val splitter = new MultiDataSetIteratorSplitter(back, 1000, 0.7); - val train = splitter.getTrainIterator(); - val test = splitter.getTestIterator(); - val numEpochs = 10; + val train = splitter.getTrainIterator(); + val test = splitter.getTestIterator(); + val numEpochs = 10; - int gcntTrain = 0; - int gcntTest = 0; - int global = 0; - // emulating epochs here - for (int e = 0; e < numEpochs; e++){ - int cnt = 0; - while (train.hasNext()) { - val data = train.next().getFeatures(0); + int gcntTrain = 0; + int gcntTest = 0; + int global = 0; + // emulating epochs here + for (int e = 0; e < numEpochs; e++){ + int cnt = 0; + while (train.hasNext()) { + val data = train.next().getFeatures(0); - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); - gcntTrain++; - global++; - } + assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); + gcntTrain++; + global++; + } - train.reset(); + train.reset(); - while (test.hasNext()) { - val data = test.next().getFeatures(0); - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); - gcntTest++; - global++; - } + while (test.hasNext()) { + val data = test.next().getFeatures(0); + assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); + gcntTest++; + global++; + } - // shifting underlying iterator by one - train.hasNext(); - back.shift(); - } + // shifting underlying iterator by one + train.hasNext(); + back.shift(); + } + + assertEquals(1000 * numEpochs, global); + }); - assertEquals(1000 * numEpochs, global); } @Test @@ -185,11 +188,11 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { assertNotNull(ds); for (int i = 0; i < ds.getFeatures().length; ++i) { - assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f); + assertEquals( (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]"); } globalIter++; } - assertTrue("Failed at epoch [" + e + "]", trained); + assertTrue(trained,"Failed at epoch [" + e + "]"); assertEquals(800, globalIter); @@ -202,11 +205,11 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { assertNotNull(ds); for (int i = 0; i < ds.getFeatures().length; ++i) { - assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f); + assertEquals((double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]"); } globalIter++; } - assertTrue("Failed at epoch [" + e + "]", tested); + assertTrue(tested,"Failed at epoch [" + e + "]"); assertEquals(900, globalIter); // validation set is used every 5 epochs @@ -219,11 +222,11 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { assertNotNull(ds); for (int i = 0; i < ds.getFeatures().length; ++i) { - assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f); + assertEquals( (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]"); } globalIter++; } - assertTrue("Failed at epoch [" + e + "]", validated); + assertTrue(validated,"Failed at epoch [" + e + "]"); } // all 3 iterators have exactly 1000 elements combined @@ -256,8 +259,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { val data = partIterator.next().getFeatures(); for (int i = 0; i < data.length; ++i) { - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, - (float) perEpoch, data[i].getFloat(0), 1e-5); + assertEquals((float) perEpoch, data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); } //gcntTrain++; global++; @@ -299,12 +301,12 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { assertNotNull(ds); for (int i = 0; i < ds.getFeatures().length; ++i) { - assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, - ds.getFeatures()[i].getDouble(0), 1e-5f); + assertEquals((double) globalIter, + ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]"); } globalIter++; } - assertTrue("Failed at epoch [" + e + "]", trained); + assertTrue(trained,"Failed at epoch [" + e + "]"); assertEquals(800, globalIter); @@ -316,11 +318,11 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { val ds = testIter.next(); assertNotNull(ds); for (int i = 0; i < ds.getFeatures().length; ++i) { - assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f); + assertEquals((double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]"); } globalIter++; } - assertTrue("Failed at epoch [" + e + "]", tested); + assertTrue(tested,"Failed at epoch [" + e + "]"); assertEquals(900, globalIter); // validation set is used every 5 epochs @@ -333,12 +335,12 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { assertNotNull(ds); for (int i = 0; i < ds.getFeatures().length; ++i) { - assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, - ds.getFeatures()[i].getDouble(0), 1e-5f); + assertEquals((double) globalIter, + ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]"); } globalIter++; } - assertTrue("Failed at epoch [" + e + "]", validated); + assertTrue(validated,"Failed at epoch [" + e + "]"); } // all 3 iterators have exactly 1000 elements combined @@ -370,7 +372,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { int farCnt = (1000 / 2) * (partNumber) + cnt; val data = iteratorList.get(partNumber).next().getFeatures(); for (int i = 0; i < data.length; ++i) { - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) farCnt, data[i].getFloat(0), 1e-5); + assertEquals( (float) farCnt, data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); } cnt++; global++; @@ -381,8 +383,8 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { while (iteratorList.get(0).hasNext()) { val data = iteratorList.get(0).next().getFeatures(); for (int i = 0; i < data.length; ++i) { - assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, - data[i].getFloat(0), 1e-5); + assertEquals((float) cnt++, + data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e); } global++; } @@ -402,7 +404,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { while (iteratorList.get(partNumber).hasNext()) { val data = iteratorList.get(partNumber).next().getFeatures(); for (int i = 0; i < data.length; ++i) { - assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt), data[i].getFloat(0), 1e-5); + assertEquals( (float) (500 * partNumber + cnt), data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt); } cnt++; } @@ -427,8 +429,8 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { while (iteratorList.get(partNumber).hasNext()) { val data = iteratorList.get(partNumber).next().getFeatures(); for (int i = 0; i < data.length; ++i) { - assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt), - data[i].getFloat(0), 1e-5); + assertEquals( (float) (500 * partNumber + cnt), + data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt); } cnt++; } @@ -454,8 +456,8 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { val ds = validationIter.next(); assertNotNull(ds); for (int i = 0; i < ds.getFeatures().length; ++i) { - assertEquals("Validation failed on iteration " + valCnt, (float) valCnt + 90, - ds.getFeatures()[i].getFloat(0), 1e-5); + assertEquals((float) valCnt + 90, + ds.getFeatures()[i].getFloat(0), 1e-5,"Validation failed on iteration " + valCnt); } valCnt++; } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java index 97a4f491b..5c586285f 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java @@ -25,9 +25,9 @@ import org.datavec.api.split.FileSplit; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.nn.util.TestDataSetConsumer; -import org.junit.Rule; + import org.junit.jupiter.api.Test; -import org.junit.rules.Timeout; + import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; @@ -42,9 +42,6 @@ import org.junit.jupiter.api.extension.ExtendWith; @DisplayName("Multiple Epochs Iterator Test") class MultipleEpochsIteratorTest extends BaseDL4JTest { - @Rule - public Timeout timeout = Timeout.seconds(300); - @Test @DisplayName("Test Next And Reset") void testNextAndReset() throws Exception { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestAsyncIterator.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestAsyncIterator.java index 66837621d..8507df4c1 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestAsyncIterator.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestAsyncIterator.java @@ -22,8 +22,8 @@ package org.deeplearning4j.datasets.iterator; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.DataSetPreProcessor; @@ -33,9 +33,9 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.List; import static junit.framework.TestCase.assertTrue; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; -@Ignore +@Disabled public class TestAsyncIterator extends BaseDL4JTest { @Test diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java index 888e514de..09bdfa9bd 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java @@ -23,21 +23,20 @@ package org.deeplearning4j.datasets.iterator; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.Timeout; + +import org.junit.jupiter.api.Test; + import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class TestEmnistDataSetIterator extends BaseDL4JTest { - @Rule - public Timeout timeout = Timeout.seconds(600); + @Override public DataType getDataType(){ diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestFileIterators.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestFileIterators.java index 6edcb8296..a99a7c724 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestFileIterators.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestFileIterators.java @@ -23,9 +23,10 @@ package org.deeplearning4j.datasets.iterator; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.file.FileDataSetIterator; import org.deeplearning4j.datasets.iterator.file.FileMultiDataSetIterator; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -33,23 +34,20 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.factory.Nd4j; import java.io.File; +import java.nio.file.Path; import java.util.*; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestFileIterators extends BaseDL4JTest { - @Rule - public TemporaryFolder folder = new TemporaryFolder(); - @Rule - public TemporaryFolder folder2 = new TemporaryFolder(); @Test - public void testFileDataSetIterator() throws Exception { - folder.create(); - File f = folder.newFolder(); + public void testFileDataSetIterator(@TempDir Path folder, @TempDir Path testDir2) throws Exception { + + File f = folder.toFile(); DataSet d1 = new DataSet(Nd4j.linspace(1, 10, 10).reshape(10,1), Nd4j.linspace(101, 110, 10).reshape(10,1)); @@ -77,10 +75,13 @@ public class TestFileIterators extends BaseDL4JTest { assertEquals(exp, act); //Test multiple directories - folder2.create(); - File f2a = folder2.newFolder(); - File f2b = folder2.newFolder(); - File f2c = folder2.newFolder(); + + File f2a = new File(testDir2.toFile(),"folder1"); + f2a.mkdirs(); + File f2b = new File(testDir2.toFile(),"folder2"); + f2b.mkdirs(); + File f2c = new File(testDir2.toFile(),"folder3"); + f2c.mkdirs(); d1.save(new File(f2a, "d1.bin")); d2.save(new File(f2a, "d2.bin")); d3.save(new File(f2b, "d3.bin")); @@ -134,7 +135,9 @@ public class TestFileIterators extends BaseDL4JTest { //Test batch size != saved size - f = folder.newFolder(); + File f4 = new File(folder.toFile(),"newFolder"); + f4.mkdirs(); + f = f4; d1.save(new File(f, "d1.bin")); d2.save(new File(f, "d2.bin")); d3.save(new File(f, "d3.bin")); @@ -159,9 +162,8 @@ public class TestFileIterators extends BaseDL4JTest { } @Test - public void testFileMultiDataSetIterator() throws Exception { - folder.create(); - File f = folder.newFolder(); + public void testFileMultiDataSetIterator(@TempDir Path folder) throws Exception { + File f = folder.toFile(); MultiDataSet d1 = new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.linspace(1, 10, 10).reshape(10,1), Nd4j.linspace(101, 110, 10).reshape(10,1)); @@ -189,10 +191,11 @@ public class TestFileIterators extends BaseDL4JTest { assertEquals(exp, act); //Test multiple directories - folder2.create(); - File f2a = folder2.newFolder(); - File f2b = folder2.newFolder(); - File f2c = folder2.newFolder(); + File newDir = new File(folder.toFile(),"folder2"); + newDir.mkdirs(); + File f2a = new File(newDir,"folder-1"); + File f2b = new File(newDir,"folder-2"); + File f2c = new File(newDir,"folder-3"); d1.save(new File(f2a, "d1.bin")); d2.save(new File(f2a, "d2.bin")); d3.save(new File(f2b, "d3.bin")); @@ -243,7 +246,8 @@ public class TestFileIterators extends BaseDL4JTest { //Test batch size != saved size - f = folder.newFolder(); + f = new File(folder.toFile(),"newolder"); + f.mkdirs(); d1.save(new File(f, "d1.bin")); d2.save(new File(f, "d2.bin")); d3.save(new File(f, "d3.bin")); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java index de1c24df7..c506b78bd 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java @@ -50,9 +50,10 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.api.BaseTrainingListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.optimize.solvers.BaseOptimizer; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.ROCBinary; import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; @@ -71,16 +72,15 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; +import java.nio.file.Path; import java.util.*; import java.util.concurrent.TimeUnit; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class TestEarlyStopping extends BaseDL4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); @Override public DataType getDataType(){ @@ -92,7 +92,7 @@ public class TestEarlyStopping extends BaseDL4JTest { DataSetIterator irisIter = new IrisDataSetIterator(150, 150); - for( int i=0; i<6; i++ ) { + for( int i = 0; i < 6; i++ ) { Nd4j.getRandom().setSeed(12345); ScoreCalculator sc; @@ -181,8 +181,8 @@ public class TestEarlyStopping extends BaseDL4JTest { bestEpoch = j; } } - assertEquals(msg, bestEpoch, out.getEpochCount()); - assertEquals(msg, bestScore, result.getBestModelScore(), 1e-5); + assertEquals(bestEpoch, out.getEpochCount(),msg); + assertEquals( bestScore, result.getBestModelScore(), 1e-5,msg); //Check that best score actually matches (returned model vs. manually calculated score) MultiLayerNetwork bestNetwork = result.getBestModel(); @@ -213,7 +213,7 @@ public class TestEarlyStopping extends BaseDL4JTest { default: throw new RuntimeException(); } - assertEquals(msg, result.getBestModelScore(), score, 1e-2); + assertEquals(result.getBestModelScore(), score, 1e-2,msg); } } @@ -845,7 +845,7 @@ public class TestEarlyStopping extends BaseDL4JTest { } @Test - public void testEarlyStoppingMaximizeScore() throws Exception { + public void testEarlyStoppingMaximizeScore(@TempDir Path testDir) throws Exception { Nd4j.getRandom().setSeed(12345); int outputs = 2; @@ -883,7 +883,7 @@ public class TestEarlyStopping extends BaseDL4JTest { .build()) .build(); - File f = testDir.newFolder(); + File f = testDir.toFile(); EarlyStoppingModelSaver saver = new LocalFileModelSaver(f.getAbsolutePath()); EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder() diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java index 83af3ac95..1a02ffd7f 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java @@ -45,7 +45,7 @@ import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; import org.nd4j.linalg.activations.Activation; @@ -64,7 +64,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class TestEarlyStoppingCompGraph extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvaluationToolsTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvaluationToolsTests.java index 63406f221..70271cd95 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvaluationToolsTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvaluationToolsTests.java @@ -29,7 +29,7 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidConfigurations.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidConfigurations.java index 88ff4ccb1..128b62e91 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidConfigurations.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidConfigurations.java @@ -30,11 +30,12 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; @Slf4j public class TestInvalidConfigurations extends BaseDL4JTest { @@ -355,64 +356,100 @@ public class TestInvalidConfigurations extends BaseDL4JTest { } } - @Test(expected = IllegalStateException.class) + @Test() public void testCnnInvalidKernel() { - new ConvolutionLayer.Builder().kernelSize(3, 0).build(); + assertThrows(IllegalStateException.class, () -> { + new ConvolutionLayer.Builder().kernelSize(3, 0).build(); + + }); } - @Test(expected = IllegalArgumentException.class) + @Test() public void testCnnInvalidKernel2() { - new ConvolutionLayer.Builder().kernelSize(2, 2, 2).build(); + assertThrows(IllegalArgumentException.class, () -> { + new ConvolutionLayer.Builder().kernelSize(2, 2, 2).build(); + + }); } - @Test(expected = IllegalStateException.class) + @Test() public void testCnnInvalidStride() { - new ConvolutionLayer.Builder().kernelSize(3, 3).stride(0, 1).build(); + assertThrows(IllegalStateException.class,() -> { + new ConvolutionLayer.Builder().kernelSize(3, 3).stride(0, 1).build(); + + }); } - @Test(expected = IllegalArgumentException.class) + @Test() public void testCnnInvalidStride2() { - new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1).build(); + assertThrows(IllegalArgumentException.class,() -> { + new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1).build(); + + }); } - @Test(expected = IllegalArgumentException.class) + @Test() public void testCnnInvalidPadding() { - new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1).padding(-1, 0).build(); + assertThrows(IllegalArgumentException.class,() -> { + new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1).padding(-1, 0).build(); + + }); } - @Test(expected = IllegalArgumentException.class) + @Test() public void testCnnInvalidPadding2() { - new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1).padding(0, 0, 0).build(); + assertThrows(IllegalArgumentException.class,() -> { + new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1).padding(0, 0, 0).build(); + + }); } - @Test(expected = IllegalStateException.class) + @Test() public void testSubsamplingInvalidKernel() { - new SubsamplingLayer.Builder().kernelSize(3, 0).build(); + assertThrows(IllegalStateException.class,() -> { + new SubsamplingLayer.Builder().kernelSize(3, 0).build(); + + }); } - @Test(expected = IllegalArgumentException.class) + @Test() public void testSubsamplingInvalidKernel2() { - new SubsamplingLayer.Builder().kernelSize(2).build(); + assertThrows(IllegalArgumentException.class,() -> { + new SubsamplingLayer.Builder().kernelSize(2).build(); + + }); } - @Test(expected = IllegalStateException.class) + @Test() public void testSubsamplingInvalidStride() { - new SubsamplingLayer.Builder().kernelSize(3, 3).stride(0, 1).build(); + assertThrows(IllegalStateException.class,() -> { + new SubsamplingLayer.Builder().kernelSize(3, 3).stride(0, 1).build(); + + }); } - @Test(expected = RuntimeException.class) + @Test() public void testSubsamplingInvalidStride2() { - new SubsamplingLayer.Builder().kernelSize(3, 3).stride(1, 1, 1).build(); + assertThrows(RuntimeException.class,() -> { + new SubsamplingLayer.Builder().kernelSize(3, 3).stride(1, 1, 1).build(); + + }); } - @Test(expected = IllegalArgumentException.class) + @Test() public void testSubsamplingInvalidPadding() { - new SubsamplingLayer.Builder().kernelSize(3, 3).stride(1, 1).padding(-1, 0).build(); + assertThrows(IllegalArgumentException.class,() -> { + new SubsamplingLayer.Builder().kernelSize(3, 3).stride(1, 1).padding(-1, 0).build(); + + }); } - @Test(expected = RuntimeException.class) + @Test() public void testSubsamplingInvalidPadding2() { - new SubsamplingLayer.Builder().kernelSize(3, 3).stride(1, 1).padding(0).build(); + assertThrows(RuntimeException.class,() -> { + new SubsamplingLayer.Builder().kernelSize(3, 3).stride(1, 1).padding(0).build(); + + }); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java index 495241964..7d958355a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java @@ -29,14 +29,14 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class TestInvalidInput extends BaseDL4JTest { @@ -291,7 +291,7 @@ public class TestInvalidInput extends BaseDL4JTest { } catch (Exception e) { log.error("",e); String msg = e.getMessage(); - assertTrue(msg, msg != null && msg.contains("rnn") && msg.contains("batch")); + assertTrue(msg != null && msg.contains("rnn") && msg.contains("batch"), msg); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestRecordReaders.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestRecordReaders.java index e647b794b..7d5e13501 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestRecordReaders.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestRecordReaders.java @@ -29,7 +29,7 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator; import org.deeplearning4j.exception.DL4JException; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -38,15 +38,15 @@ import java.util.Arrays; import java.util.Collection; import static junit.framework.TestCase.fail; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestRecordReaders extends BaseDL4JTest { @Test public void testClassIndexOutsideOfRangeRRDSI() { Collection> c = new ArrayList<>(); - c.add(Arrays.asList(new DoubleWritable(0.5), new IntWritable(0))); - c.add(Arrays.asList(new DoubleWritable(1.0), new IntWritable(2))); + c.add(Arrays.asList(new DoubleWritable(0.5), new IntWritable(0))); + c.add(Arrays.asList(new DoubleWritable(1.0), new IntWritable(2))); CollectionRecordReader crr = new CollectionRecordReader(c); @@ -56,7 +56,7 @@ public class TestRecordReaders extends BaseDL4JTest { DataSet ds = iter.next(); fail("Expected exception"); } catch (Exception e) { - assertTrue(e.getMessage(), e.getMessage().contains("to one-hot")); + assertTrue( e.getMessage().contains("to one-hot"),e.getMessage()); } } @@ -65,13 +65,13 @@ public class TestRecordReaders extends BaseDL4JTest { Collection>> c = new ArrayList<>(); Collection> seq1 = new ArrayList<>(); - seq1.add(Arrays.asList(new DoubleWritable(0.0), new IntWritable(0))); - seq1.add(Arrays.asList(new DoubleWritable(0.0), new IntWritable(1))); + seq1.add(Arrays.asList(new DoubleWritable(0.0), new IntWritable(0))); + seq1.add(Arrays.asList(new DoubleWritable(0.0), new IntWritable(1))); c.add(seq1); Collection> seq2 = new ArrayList<>(); - seq2.add(Arrays.asList(new DoubleWritable(0.0), new IntWritable(0))); - seq2.add(Arrays.asList(new DoubleWritable(0.0), new IntWritable(2))); + seq2.add(Arrays.asList(new DoubleWritable(0.0), new IntWritable(0))); + seq2.add(Arrays.asList(new DoubleWritable(0.0), new IntWritable(2))); c.add(seq2); CollectionSequenceRecordReader csrr = new CollectionSequenceRecordReader(c); @@ -81,7 +81,7 @@ public class TestRecordReaders extends BaseDL4JTest { DataSet ds = dsi.next(); fail("Expected exception"); } catch (Exception e) { - assertTrue(e.getMessage(), e.getMessage().contains("to one-hot")); + assertTrue(e.getMessage().contains("to one-hot"),e.getMessage()); } } @@ -90,24 +90,24 @@ public class TestRecordReaders extends BaseDL4JTest { Collection>> c1 = new ArrayList<>(); Collection> seq1 = new ArrayList<>(); - seq1.add(Arrays.asList(new DoubleWritable(0.0))); - seq1.add(Arrays.asList(new DoubleWritable(0.0))); + seq1.add(Arrays.asList(new DoubleWritable(0.0))); + seq1.add(Arrays.asList(new DoubleWritable(0.0))); c1.add(seq1); Collection> seq2 = new ArrayList<>(); - seq2.add(Arrays.asList(new DoubleWritable(0.0))); - seq2.add(Arrays.asList(new DoubleWritable(0.0))); + seq2.add(Arrays.asList(new DoubleWritable(0.0))); + seq2.add(Arrays.asList(new DoubleWritable(0.0))); c1.add(seq2); Collection>> c2 = new ArrayList<>(); Collection> seq1a = new ArrayList<>(); - seq1a.add(Arrays.asList(new IntWritable(0))); - seq1a.add(Arrays.asList(new IntWritable(1))); + seq1a.add(Arrays.asList(new IntWritable(0))); + seq1a.add(Arrays.asList(new IntWritable(1))); c2.add(seq1a); Collection> seq2a = new ArrayList<>(); - seq2a.add(Arrays.asList(new IntWritable(0))); - seq2a.add(Arrays.asList(new IntWritable(2))); + seq2a.add(Arrays.asList(new IntWritable(0))); + seq2a.add(Arrays.asList(new IntWritable(2))); c2.add(seq2a); CollectionSequenceRecordReader csrr = new CollectionSequenceRecordReader(c1); @@ -118,7 +118,7 @@ public class TestRecordReaders extends BaseDL4JTest { DataSet ds = dsi.next(); fail("Expected exception"); } catch (Exception e) { - assertTrue(e.getMessage(), e.getMessage().contains("to one-hot")); + assertTrue(e.getMessage().contains("to one-hot"),e.getMessage()); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java index 2f9fbf18c..023f35449 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java @@ -32,7 +32,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.junit.jupiter.api.Disabled; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.rules.ExpectedException; import org.nd4j.linalg.activations.Activation; @@ -42,6 +42,8 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.extension.ExtendWith; @@ -50,8 +52,7 @@ import org.junit.jupiter.api.extension.ExtendWith; @DisplayName("Attention Layer Test") class AttentionLayerTest extends BaseDL4JTest { - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); + @Override public long getTimeoutMilliseconds() { @@ -178,21 +179,22 @@ class AttentionLayerTest extends BaseDL4JTest { @Test @DisplayName("Test Recurrent Attention Layer _ differing Time Steps") void testRecurrentAttentionLayer_differingTimeSteps() { - int nIn = 9; - int nOut = 5; - int layerSize = 8; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.IDENTITY).updater(new NoOp()).weightInit(WeightInit.XAVIER).list().layer(new LSTM.Builder().nOut(layerSize).build()).layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build()).layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build()).layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).setInputType(InputType.recurrent(nIn)).build(); - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - final INDArray initialInput = Nd4j.rand(new int[] { 8, nIn, 7 }); - final INDArray goodNextInput = Nd4j.rand(new int[] { 8, nIn, 7 }); - final INDArray badNextInput = Nd4j.rand(new int[] { 8, nIn, 12 }); - final INDArray labels = Nd4j.rand(new int[] { 8, nOut }); - net.fit(initialInput, labels); - net.fit(goodNextInput, labels); - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("This layer only supports fixed length mini-batches. Expected 7 time steps but got 12."); - net.fit(badNextInput, labels); + assertThrows(IllegalArgumentException.class, () -> { + int nIn = 9; + int nOut = 5; + int layerSize = 8; + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.IDENTITY).updater(new NoOp()).weightInit(WeightInit.XAVIER).list().layer(new LSTM.Builder().nOut(layerSize).build()).layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build()).layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build()).layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).setInputType(InputType.recurrent(nIn)).build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + final INDArray initialInput = Nd4j.rand(new int[] { 8, nIn, 7 }); + final INDArray goodNextInput = Nd4j.rand(new int[] { 8, nIn, 7 }); + final INDArray badNextInput = Nd4j.rand(new int[] { 8, nIn, 12 }); + final INDArray labels = Nd4j.rand(new int[] { 8, nOut }); + net.fit(initialInput, labels); + net.fit(goodNextInput, labels); + net.fit(badNextInput, labels); + }); + } @Test diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java index ec36fdd82..7ca1064b3 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java @@ -35,7 +35,7 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -44,7 +44,7 @@ import org.nd4j.common.function.Consumer; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j public class DropoutGradientCheck extends BaseDL4JTest { @@ -141,7 +141,7 @@ public class DropoutGradientCheck extends BaseDL4JTest { DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, l, null, null, false, -1, null, 12345); //Last arg: ensures RNG is reset at each iter... otherwise will fail due to randomness! - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(mln); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java index 34315df6d..214cb895e 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java @@ -31,7 +31,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -41,7 +41,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.Random; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class GlobalPoolingGradientCheckTests extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java index b5cad31fa..37667dc9f 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java @@ -36,8 +36,8 @@ import org.deeplearning4j.nn.conf.layers.misc.ElementWiseMultiplicationLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -55,7 +55,7 @@ import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Random; import static org.deeplearning4j.gradientcheck.GradientCheckUtil.checkGradients; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class GradientCheckTests extends BaseDL4JTest { @@ -136,7 +136,7 @@ public class GradientCheckTests extends BaseDL4JTest { String msg = "testMinibatchApplication() - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(mln); } @@ -216,7 +216,7 @@ public class GradientCheckTests extends BaseDL4JTest { String msg = "testGradMLP2LayerIrisSimple() - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(mln); } } @@ -294,7 +294,7 @@ public class GradientCheckTests extends BaseDL4JTest { + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", l2=" + l2 + ", l1=" + l1 + " (before=" + scoreBefore + ", scoreAfter=" + scoreAfter + ")"; - assertTrue(msg, scoreAfter < 0.8 * scoreBefore); + assertTrue(scoreAfter < 0.8 * scoreBefore, msg); } if (PRINT_RESULTS) { @@ -311,7 +311,7 @@ public class GradientCheckTests extends BaseDL4JTest { String msg = "testGradMLP2LayerIrisSimple() - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", l2=" + l2 + ", l1=" + l1; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(mln); } } @@ -354,7 +354,7 @@ public class GradientCheckTests extends BaseDL4JTest { DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); String msg = "testEmbeddingLayerSimple"; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); } @Test @@ -394,7 +394,7 @@ public class GradientCheckTests extends BaseDL4JTest { DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); String msg = "testEmbeddingLayerSimple"; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(mln); } @@ -468,7 +468,7 @@ public class GradientCheckTests extends BaseDL4JTest { + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", l2=" + l2 + ", l1=" + l1 + " (before=" + scoreBefore + ", scoreAfter=" + scoreAfter + ")"; - assertTrue(msg, scoreAfter < scoreBefore); + assertTrue(scoreAfter < scoreBefore, msg); } msg = "testGradMLP2LayerIrisSimple() - activationFn=" + afn + ", lossFn=" + lf @@ -482,7 +482,7 @@ public class GradientCheckTests extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(mln); } } @@ -541,7 +541,7 @@ public class GradientCheckTests extends BaseDL4JTest { + "Id" + ", lossFn=" + "Cos-sim" + ", outputActivation=" + "Id" + ", doLearningFirst=" + "true" + " (before=" + scoreBefore + ", scoreAfter=" + scoreAfter + ")"; - assertTrue(msg, scoreAfter < 0.8 * scoreBefore); + assertTrue(scoreAfter < 0.8 * scoreBefore, msg); // expectation in case linear regression(with only element wise multiplication layer): large weight for the fourth weight log.info("params after learning: " + netGraph.getLayer(1).paramTable()); @@ -551,7 +551,7 @@ public class GradientCheckTests extends BaseDL4JTest { msg = "elementWiseMultiplicationLayerTest() - activationFn=" + "ID" + ", lossFn=" + "Cos-sim" + ", outputActivation=" + "Id" + ", doLearningFirst=" + "true"; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(netGraph); } @@ -606,7 +606,7 @@ public class GradientCheckTests extends BaseDL4JTest { String msg = "mask=" + maskArray + ", inputRank=" + inputRank; boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) .labels(label).inputMask(fMask)); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(net); @@ -704,7 +704,7 @@ public class GradientCheckTests extends BaseDL4JTest { String msg = "testGradientWeightDecay() - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", l2=" + l2 + ", l1=" + l1; - assertTrue(msg, gradOK1); + assertTrue(gradOK1, msg); TestUtils.testModelSerialization(mln); } @@ -713,7 +713,7 @@ public class GradientCheckTests extends BaseDL4JTest { } @Test - @Ignore("AB 2019/06/24 - Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912") + @Disabled("AB 2019/06/24 - Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912") public void testGradientMLP2LayerIrisLayerNorm() { //Parameterized test, testing combinations of: // (a) activation function @@ -789,7 +789,7 @@ public class GradientCheckTests extends BaseDL4JTest { String msg = "testGradMLP2LayerIrisSimple() - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", layerNorm=" + layerNorm; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(mln); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java index 6f791c3d2..d8bd9ee0d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java @@ -42,7 +42,7 @@ import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -56,7 +56,7 @@ import java.util.Arrays; import java.util.Map; import java.util.Random; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class GradientCheckTestsComputationGraph extends BaseDL4JTest { @@ -118,7 +118,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { .labels(new INDArray[]{labels})); String msg = "testBasicIris()"; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(graph); } @@ -169,7 +169,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { .labels(new INDArray[]{labels})); String msg = "testBasicIrisWithMerging()"; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(graph); } @@ -226,7 +226,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { .labels(new INDArray[]{labels})); String msg = "testBasicIrisWithElementWiseVertex(op=" + op + ")"; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(graph); } } @@ -286,7 +286,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { .labels(new INDArray[]{labels})); String msg = "testBasicIrisWithElementWiseVertex(op=" + op + ")"; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(graph); } } @@ -333,7 +333,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in}) .labels(new INDArray[]{labels})); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(graph); } } @@ -387,7 +387,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) .labels(new INDArray[]{labels})); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(graph); } } @@ -450,7 +450,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) .labels(new INDArray[]{labels})); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(graph); } @@ -490,7 +490,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { .labels(new INDArray[]{labels})); String msg = "testLSTMWithSubset()"; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(graph); } @@ -528,7 +528,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { .labels(new INDArray[]{labels})); String msg = "testLSTMWithLastTimeStepVertex()"; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); //Second: test with input mask arrays. INDArray inMask = Nd4j.zeros(3, 4); @@ -538,7 +538,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) .labels(new INDArray[]{labels}).inputMask(new INDArray[]{inMask})); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(graph); } @@ -591,7 +591,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { .labels(new INDArray[]{labels})); String msg = "testLSTMWithDuplicateToTimeSeries()"; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(graph); } @@ -640,7 +640,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { .labels(new INDArray[]{labels})); String msg = "testLSTMWithDuplicateToTimeSeries()"; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); //Second: test with input mask arrays. INDArray inMask = Nd4j.zeros(3, 5); @@ -651,7 +651,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) .labels(new INDArray[]{labels})); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(graph); } @@ -694,7 +694,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(inputs) .labels(new INDArray[]{out})); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(graph); } } @@ -734,7 +734,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) .labels(new INDArray[]{out})); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(graph); } } @@ -780,7 +780,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(input) .labels(new INDArray[]{out})); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(graph); } } @@ -831,7 +831,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) .labels(new INDArray[]{out})); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(graph); } } @@ -900,7 +900,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { .labels(new INDArray[]{labels})); String msg = "testBasicIrisTripletStackingL2Loss()"; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(graph); } @@ -960,7 +960,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{example}) .labels(new INDArray[]{labels})); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(graph); } } @@ -1025,7 +1025,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, example, labels); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(net); } } @@ -1074,7 +1074,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2}) .labels(new INDArray[]{labels})); - assertTrue(testName, gradOK); + assertTrue(gradOK, testName); TestUtils.testModelSerialization(graph); } } @@ -1132,7 +1132,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2}) .labels(new INDArray[]{labels1, labels2})); - assertTrue(testName, gradOK); + assertTrue(gradOK, testName); TestUtils.testModelSerialization(graph); } } @@ -1190,7 +1190,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2}) .labels(new INDArray[]{labels1, labels2})); - assertTrue(testName, gradOK); + assertTrue(gradOK, testName); TestUtils.testModelSerialization(graph); } } @@ -1255,7 +1255,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2}) .labels(new INDArray[]{labels1, labels2}).inputMask(new INDArray[]{inMask1, inMask2})); - assertTrue(testName, gradOK); + assertTrue(gradOK, testName); TestUtils.testModelSerialization(graph); } } @@ -1311,7 +1311,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2}) .labels(new INDArray[]{labels1, labels2})); - assertTrue(testName, gradOK); + assertTrue(gradOK, testName); TestUtils.testModelSerialization(graph); } } @@ -1358,7 +1358,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1}) .labels(new INDArray[]{labels1})); - assertTrue(testName, gradOK); + assertTrue(gradOK, testName); TestUtils.testModelSerialization(graph); } } @@ -1409,7 +1409,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1}) .labels(new INDArray[]{labels1})); - assertTrue(testName, gradOK); + assertTrue(gradOK, testName); TestUtils.testModelSerialization(graph); } } @@ -1448,7 +1448,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { .labels(new INDArray[]{labels})); String msg = "testGraphEmbeddingLayerSimple"; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(cg); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java index 06b0ca1fd..1e8faae70 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java @@ -33,7 +33,7 @@ import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -47,8 +47,8 @@ import org.nd4j.linalg.lossfunctions.impl.*; import java.util.Arrays; import java.util.Random; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.nd4j.linalg.indexing.NDArrayIndex.*; public class GradientCheckTestsMasking extends BaseDL4JTest { @@ -139,7 +139,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { String msg = "gradientCheckMaskingOutputSimple() - timeSeriesLength=" + timeSeriesLength + ", miniBatchSize=" + 1; - assertTrue(msg, gradOK); + assertTrue(gradOK,msg); TestUtils.testModelSerialization(mln); } } @@ -269,7 +269,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(features) .labels(labels).labelMask(labelMask)); - assertTrue(msg, gradOK); + assertTrue(gradOK,msg); TestUtils.testModelSerialization(net); } } @@ -365,7 +365,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(features) .labels(labels).labelMask(labelMask)); - assertTrue(msg, gradOK); + assertTrue(gradOK,msg); //Check the equivalent compgraph: @@ -388,7 +388,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{features}) .labels(new INDArray[]{labels}).labelMask(new INDArray[]{labelMask})); - assertTrue(msg + " (compgraph)", gradOK); + assertTrue(gradOK,msg + " (compgraph)"); TestUtils.testModelSerialization(graph); } } @@ -424,7 +424,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { while(attempts++ < 1000 && lm.sumNumber().intValue() == 0){ lm = TestUtils.randomBernoulli(mb, 1); } - assertTrue("Could not generate non-zero mask after " + attempts + " attempts", lm.sumNumber().intValue() > 0); + assertTrue( lm.sumNumber().intValue() > 0,"Could not generate non-zero mask after " + attempts + " attempts"); boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f) .labels(l).labelMask(lm)); @@ -446,7 +446,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { double score2 = net.score(new DataSet(f,l,null,lm)); - assertEquals(String.valueOf(i), score, score2, 1e-8); + assertEquals( score, score2, 1e-8,String.valueOf(i)); } } @@ -481,7 +481,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { while(attempts++ < 1000 && lm.sumNumber().intValue() == 0){ lm = TestUtils.randomBernoulli(mb, 1); } - assertTrue("Could not generate non-zero mask after " + attempts + " attempts", lm.sumNumber().intValue() > 0); + assertTrue(lm.sumNumber().intValue() > 0,"Could not generate non-zero mask after " + attempts + " attempts"); boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{f}) .labels(new INDArray[]{l}).labelMask(new INDArray[]{lm})); @@ -503,7 +503,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { double score2 = net.score(new DataSet(f,l,null,lm)); - assertEquals(String.valueOf(i), score, score2, 1e-8); + assertEquals(score, score2, 1e-8,String.valueOf(i)); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java index c128dcfdf..3ab2efd59 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java @@ -30,7 +30,7 @@ import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.LocalResponseNormalization; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -40,7 +40,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.Random; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class LRNGradientCheckTests extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java index 2b3391828..00fef6150 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java @@ -31,7 +31,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -41,7 +41,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import java.util.Random; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class LSTMGradientCheckTests extends BaseDL4JTest { @@ -137,7 +137,7 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - assertTrue(testName, gradOK); + assertTrue(gradOK, testName); TestUtils.testModelSerialization(mln); } } @@ -226,7 +226,7 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) .labels(labels).subset(true).maxPerParam(128)); - assertTrue(testName, gradOK); + assertTrue(gradOK, testName); TestUtils.testModelSerialization(mln); } } @@ -276,7 +276,7 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { System.out.println(msg); boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(mln); } } @@ -356,7 +356,7 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { String msg = "testGradientGravesLSTMFull() - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", l2=" + l2 + ", l1=" + l1; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(mln); } } @@ -405,7 +405,7 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { String msg = "testGradientGravesLSTMEdgeCases() - timeSeriesLength=" + timeSeriesLength[i] + ", miniBatchSize=" + miniBatchSize[i]; - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(mln); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java index efaef6db3..e09197d69 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java @@ -34,7 +34,7 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.LossLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.api.buffer.DataType; @@ -56,8 +56,8 @@ import java.util.ArrayList; import java.util.List; import java.util.Random; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.point; @@ -242,7 +242,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { } } - assertEquals("Tests failed", 0, failed.size()); + assertEquals(0, failed.size(),"Tests failed"); } @Test @@ -349,7 +349,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { lossFunctions[i] = lf2; } catch (IOException ex) { ex.printStackTrace(); - assertEquals("Tests failed: serialization of " + lossFunctions[i], 0, 1); + assertEquals(0, 1,"Tests failed: serialization of " + lossFunctions[i]); } Nd4j.getRandom().setSeed(12345); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() @@ -410,7 +410,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { System.out.println(s); } - assertEquals("Tests failed", 0, failed.size()); + assertEquals(0, failed.size(),"Tests failed"); } @Test @@ -718,6 +718,6 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { System.out.println(s); } - assertEquals("Tests failed", 0, failed.size()); + assertEquals(0, failed.size(),"Tests failed"); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java index dee6ad81b..c9e65579b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java @@ -28,7 +28,7 @@ import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -37,8 +37,8 @@ import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class NoBiasGradientCheckTests extends BaseDL4JTest { @@ -123,7 +123,7 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(mln); } @@ -180,7 +180,7 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(mln); } @@ -242,7 +242,7 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(mln); } @@ -308,7 +308,7 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(net); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java index 3c86a4910..12a1340e2 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java @@ -28,7 +28,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -42,7 +42,7 @@ import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT; import java.util.Random; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class OutputLayerGradientChecks extends BaseDL4JTest { @@ -149,7 +149,7 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) .labels(labels).labelMask(labelMask)); - assertTrue(testName, gradOK); + assertTrue(gradOK, testName); TestUtils.testModelSerialization(mln); } } @@ -256,7 +256,7 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) .labels(labels).labelMask(labelMask)); - assertTrue(testName, gradOK); + assertTrue(gradOK, testName); TestUtils.testModelSerialization(mln); } } @@ -405,7 +405,7 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) .labels(labels).labelMask(labelMask)); - assertTrue(testName, gradOK); + assertTrue(gradOK, testName); TestUtils.testModelSerialization(mln); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java index bf4fed712..bcfa02f30 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java @@ -35,8 +35,8 @@ import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -46,7 +46,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.Random; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class RnnGradientChecks extends BaseDL4JTest { @@ -62,7 +62,7 @@ public class RnnGradientChecks extends BaseDL4JTest { } @Test - @Ignore("AB 2019/06/24 - Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912") + @Disabled("AB 2019/06/24 - Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912") public void testBidirectionalWrapper() { int nIn = 3; @@ -146,7 +146,7 @@ public class RnnGradientChecks extends BaseDL4JTest { } @Test - @Ignore("AB 2019/06/24 - Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912") + @Disabled("AB 2019/06/24 - Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912") public void testSimpleRnn() { int nOut = 5; @@ -226,7 +226,7 @@ public class RnnGradientChecks extends BaseDL4JTest { } @Test - @Ignore("AB 2019/06/24 - Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912") + @Disabled("AB 2019/06/24 - Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912") public void testLastTimeStepLayer(){ int nIn = 3; int nOut = 5; @@ -289,7 +289,7 @@ public class RnnGradientChecks extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) .labels(labels).inputMask(inMask).subset(true).maxPerParam(16)); - assertTrue(name, gradOK); + assertTrue(gradOK, name); TestUtils.testModelSerialization(net); } } @@ -353,7 +353,7 @@ public class RnnGradientChecks extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) .labels(labels).inputMask(inMask).subset(true).maxPerParam(16)); - assertTrue(name, gradOK); + assertTrue(gradOK, name); TestUtils.testModelSerialization(net); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java index e9770cedf..25d594d9a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java @@ -35,7 +35,7 @@ import org.deeplearning4j.nn.conf.layers.util.MaskLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -48,7 +48,7 @@ import java.util.Arrays; import java.util.HashSet; import java.util.Set; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class UtilLayerGradientChecks extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java index 4730de189..ec9fdab25 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java @@ -29,7 +29,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.variational.*; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationTanH; import org.nd4j.linalg.api.buffer.DataType; @@ -43,7 +43,7 @@ import org.nd4j.linalg.lossfunctions.impl.LossMSE; import java.util.Arrays; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class VaeGradientCheckTests extends BaseDL4JTest { @@ -135,7 +135,7 @@ public class VaeGradientCheckTests extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(mln); } } @@ -207,7 +207,7 @@ public class VaeGradientCheckTests extends BaseDL4JTest { DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, 12345); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(mln); } } @@ -295,7 +295,7 @@ public class VaeGradientCheckTests extends BaseDL4JTest { DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, data, 12345); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(mln); } } @@ -337,7 +337,7 @@ public class VaeGradientCheckTests extends BaseDL4JTest { DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, features, 12345); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(mln); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java index 0926c0c16..0a280d9f0 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java @@ -35,9 +35,10 @@ import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; @@ -53,9 +54,10 @@ import org.nd4j.linalg.learning.config.NoOp; import java.io.File; import java.io.FileOutputStream; import java.io.InputStream; +import java.nio.file.Path; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @RunWith(Parameterized.class) public class YoloGradientCheckTests extends BaseDL4JTest { @@ -73,8 +75,6 @@ public class YoloGradientCheckTests extends BaseDL4JTest { return CNN2DFormat.values(); } - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); @Override public long getTimeoutMilliseconds() { @@ -154,7 +154,7 @@ public class YoloGradientCheckTests extends BaseDL4JTest { .minAbsoluteError(1e-6) .labels(labels).subset(true).maxPerParam(100)); - assertTrue(msg, gradOK); + assertTrue(gradOK,msg); TestUtils.testModelSerialization(net); } } @@ -181,14 +181,15 @@ public class YoloGradientCheckTests extends BaseDL4JTest { @Test - public void yoloGradientCheckRealData() throws Exception { + public void yoloGradientCheckRealData(@TempDir Path testDir) throws Exception { Nd4j.getRandom().setSeed(12345); InputStream is1 = new ClassPathResource("yolo/VOC_TwoImage/JPEGImages/2007_009346.jpg").getInputStream(); InputStream is2 = new ClassPathResource("yolo/VOC_TwoImage/Annotations/2007_009346.xml").getInputStream(); InputStream is3 = new ClassPathResource("yolo/VOC_TwoImage/JPEGImages/2008_003344.jpg").getInputStream(); InputStream is4 = new ClassPathResource("yolo/VOC_TwoImage/Annotations/2008_003344.xml").getInputStream(); - File dir = testDir.newFolder("testYoloOverfitting"); + File dir = new File(testDir.toFile(),"testYoloOverfitting"); + dir.mkdirs(); File jpg = new File(dir, "JPEGImages"); File annot = new File(dir, "Annotations"); jpg.mkdirs(); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java index e08c01440..09a274b63 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java @@ -32,7 +32,7 @@ import org.deeplearning4j.nn.conf.weightnoise.DropConnect; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.activations.Activation; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java index 9a2e0cd61..fda02a451 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java @@ -41,7 +41,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -51,8 +51,8 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.Map; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestConstraints extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java index 981df8ee6..cd63d4c5e 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java @@ -34,7 +34,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.dataset.DataSet; @@ -50,8 +50,8 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.point; @@ -241,7 +241,7 @@ public class TestDropout extends BaseDL4JTest { if(i < 5){ countTwos = Nd4j.getExecutioner().exec(new MatchCondition(out, Conditions.equals(2))).getInt(0); - assertEquals(String.valueOf(i), 100, countZeros + countTwos); //Should only be 0 or 2 + assertEquals( 100, countZeros + countTwos,String.valueOf(i)); //Should only be 0 or 2 //Stochastic, but this should hold for most cases assertTrue(countZeros >= 25 && countZeros <= 75); assertTrue(countTwos >= 25 && countTwos <= 75); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java index dd6967371..28d9558cd 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java @@ -34,7 +34,7 @@ import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer; import org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer; import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -43,7 +43,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import java.util.Arrays; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestPreProcessors extends BaseDL4JTest { @@ -186,8 +186,8 @@ public class TestPreProcessors extends BaseDL4JTest { //Again epsilons and activations have same shape, we can do this (even though it's not the intended use) INDArray epsilon2d1 = proc.backprop(activations3dc, miniBatchSize, LayerWorkspaceMgr.noWorkspaces()); INDArray epsilon2d2 = proc.backprop(activations3df, miniBatchSize, LayerWorkspaceMgr.noWorkspaces()); - assertEquals(msg, activations2dc, epsilon2d1); - assertEquals(msg, activations2dc, epsilon2d2); + assertEquals(activations2dc, epsilon2d1, msg); + assertEquals(activations2dc, epsilon2d2, msg); //Also check backprop with 3d activations in f order vs. c order: INDArray act3d_c = Nd4j.create(activations3dc.shape(), 'c'); @@ -195,8 +195,8 @@ public class TestPreProcessors extends BaseDL4JTest { INDArray act3d_f = Nd4j.create(activations3dc.shape(), 'f'); act3d_f.assign(activations3dc); - assertEquals(msg, activations2dc, proc.backprop(act3d_c, miniBatchSize, LayerWorkspaceMgr.noWorkspaces())); - assertEquals(msg, activations2dc, proc.backprop(act3d_f, miniBatchSize, LayerWorkspaceMgr.noWorkspaces())); + assertEquals(activations2dc, proc.backprop(act3d_c, miniBatchSize, LayerWorkspaceMgr.noWorkspaces()), msg); + assertEquals(activations2dc, proc.backprop(act3d_f, miniBatchSize, LayerWorkspaceMgr.noWorkspaces()), msg); } } @@ -245,14 +245,14 @@ public class TestPreProcessors extends BaseDL4JTest { //Check shape of outputs: val prod = nChannels * inputHeight * inputWidth; INDArray activationsRnn = proc.preProcess(activationsCnn, miniBatchSize, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(msg, new long[] {miniBatchSize, prod, timeSeriesLength}, - activationsRnn.shape()); + assertArrayEquals(new long[] {miniBatchSize, prod, timeSeriesLength}, + activationsRnn.shape(),msg); //Check backward pass. Given that activations and epsilons have same shape, they should //be opposite operations - i.e., get the same thing back out INDArray twiceProcessed = proc.backprop(activationsRnn, miniBatchSize, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(msg, activationsCnn.shape(), twiceProcessed.shape()); - assertEquals(msg, activationsCnn, twiceProcessed); + assertArrayEquals(activationsCnn.shape(), twiceProcessed.shape(),msg); + assertEquals(activationsCnn, twiceProcessed, msg); //Second way to check: compare to ComposableInputPreProcessor(CNNtoFF, FFtoRNN) InputPreProcessor compProc = new ComposableInputPreProcessor( @@ -260,7 +260,7 @@ public class TestPreProcessors extends BaseDL4JTest { new FeedForwardToRnnPreProcessor()); INDArray activationsRnnComp = compProc.preProcess(activationsCnn, miniBatchSize, LayerWorkspaceMgr.noWorkspaces()); - assertEquals(msg, activationsRnnComp, activationsRnn); + assertEquals(activationsRnnComp, activationsRnn, msg); INDArray epsilonsRnn = Nd4j.rand(new int[] {miniBatchSize, nChannels * inputHeight * inputWidth, timeSeriesLength}); @@ -276,7 +276,7 @@ public class TestPreProcessors extends BaseDL4JTest { System.out.println(Arrays.toString(epsilonsCnn.shape())); System.out.println(epsilonsCnn); } - assertEquals(msg, epsilonsCnnComp, epsilonsCnn); + assertEquals(epsilonsCnnComp, epsilonsCnn, msg); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java index 525f736f8..1449c8d04 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java @@ -36,7 +36,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -52,7 +52,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestWeightNoise extends BaseDL4JTest { @@ -211,7 +211,7 @@ public class TestWeightNoise extends BaseDL4JTest { graph.output(trainData.get(0).getFeatures()); for (int i = 0; i < 3; i++) { - assertEquals(String.valueOf(i), expCalls.get(i), list.get(i).getAllCalls()); + assertEquals(expCalls.get(i), list.get(i).getAllCalls(), String.valueOf(i)); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java index 7b979af3d..400f8f199 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java @@ -20,8 +20,8 @@ package org.deeplearning4j.nn.dtypes; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; @@ -143,8 +143,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.junit.AfterClass; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationSoftmax; import org.nd4j.linalg.api.buffer.DataType; @@ -170,7 +170,7 @@ import java.util.Map; import java.util.Set; @Slf4j -@Ignore +@Disabled public class DTypeTests extends BaseDL4JTest { protected static Set> seenLayers = new HashSet<>(); @@ -542,9 +542,9 @@ public class DTypeTests extends BaseDL4JTest { net.init(); net.initGradientsView(); - assertEquals(msg, networkDtype, net.params().dataType()); - assertEquals(msg, networkDtype, net.getFlattenedGradients().dataType()); - assertEquals(msg, networkDtype, net.getUpdater(true).getStateViewArray().dataType()); + assertEquals(networkDtype, net.params().dataType(), msg); + assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg); + assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg); INDArray in = Nd4j.rand(networkDtype, 2, 8 * 8); INDArray label; @@ -561,11 +561,11 @@ public class DTypeTests extends BaseDL4JTest { } INDArray out = net.output(in); - assertEquals(msg, networkDtype, out.dataType()); + assertEquals(networkDtype, out.dataType(), msg); List ff = net.feedForward(in); for (int i = 0; i < ff.size(); i++) { String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? "input" : net.getLayer(i - 1).conf().getLayer().getClass().getSimpleName()); - assertEquals(s, networkDtype, ff.get(i).dataType()); + assertEquals(networkDtype, ff.get(i).dataType(), s); } net.setInput(in); @@ -647,9 +647,9 @@ public class DTypeTests extends BaseDL4JTest { net.init(); net.initGradientsView(); - assertEquals(msg, networkDtype, net.params().dataType()); - assertEquals(msg, networkDtype, net.getFlattenedGradients().dataType()); - assertEquals(msg, networkDtype, net.getUpdater(true).getStateViewArray().dataType()); + assertEquals(networkDtype, net.params().dataType(), msg); + assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg); + assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg); INDArray in = Nd4j.rand(networkDtype, 2, 1, 8, 8, 8); INDArray label; @@ -665,11 +665,11 @@ public class DTypeTests extends BaseDL4JTest { } INDArray out = net.output(in); - assertEquals(msg, networkDtype, out.dataType()); + assertEquals(networkDtype, out.dataType(), msg); List ff = net.feedForward(in); for (int i = 0; i < ff.size(); i++) { String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? "input" : net.getLayer(i - 1).conf().getLayer().getClass().getSimpleName()); - assertEquals(s, networkDtype, ff.get(i).dataType()); + assertEquals(networkDtype, ff.get(i).dataType(), s); } net.setInput(in); @@ -697,7 +697,7 @@ public class DTypeTests extends BaseDL4JTest { } @Test - @Ignore + @Disabled public void testDtypesModelVsGlobalDtypeCnn1d() { //Nd4jCpu.Environment.getInstance().setUseMKLDNN(false); Nd4j.getEnvironment().setDebug(true); @@ -760,9 +760,9 @@ public class DTypeTests extends BaseDL4JTest { net.init(); net.initGradientsView(); - assertEquals(msg, networkDtype, net.params().dataType()); - assertEquals(msg, networkDtype, net.getFlattenedGradients().dataType()); - assertEquals(msg, networkDtype, net.getUpdater(true).getStateViewArray().dataType()); + assertEquals(networkDtype, net.params().dataType(), msg); + assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg); + assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg); INDArray in = Nd4j.rand(networkDtype, 2, 5, 10); INDArray label; @@ -775,11 +775,11 @@ public class DTypeTests extends BaseDL4JTest { } INDArray out = net.output(in); - assertEquals(msg, networkDtype, out.dataType()); + assertEquals(networkDtype, out.dataType(), msg); List ff = net.feedForward(in); for (int i = 0; i < ff.size(); i++) { String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? "input" : net.getLayer(i - 1).conf().getLayer().getClass().getSimpleName()); - assertEquals(s, networkDtype, ff.get(i).dataType()); + assertEquals(networkDtype, ff.get(i).dataType(), s); } net.setInput(in); @@ -833,19 +833,19 @@ public class DTypeTests extends BaseDL4JTest { net.init(); net.initGradientsView(); - assertEquals(msg, networkDtype, net.params().dataType()); - assertEquals(msg, networkDtype, net.getFlattenedGradients().dataType()); - assertEquals(msg, networkDtype, net.getUpdater(true).getStateViewArray().dataType()); + assertEquals(networkDtype, net.params().dataType(), msg); + assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg); + assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg); INDArray in = Nd4j.rand(networkDtype, 2, 5, 28, 28); INDArray label = TestUtils.randomOneHot(2, 10).castTo(networkDtype); INDArray out = net.output(in); - assertEquals(msg, networkDtype, out.dataType()); + assertEquals(networkDtype, out.dataType(), msg); List ff = net.feedForward(in); for (int i = 0; i < ff.size(); i++) { String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? "input" : net.getLayer(i - 1).conf().getLayer().getClass().getSimpleName()); - assertEquals(s, networkDtype, ff.get(i).dataType()); + assertEquals(networkDtype, ff.get(i).dataType(), s); } net.setInput(in); @@ -922,9 +922,9 @@ public class DTypeTests extends BaseDL4JTest { net.init(); net.initGradientsView(); - assertEquals(msg, networkDtype, net.params().dataType()); - assertEquals(msg, networkDtype, net.getFlattenedGradients().dataType()); - assertEquals(msg, networkDtype, net.getUpdater(true).getStateViewArray().dataType()); + assertEquals(networkDtype, net.params().dataType(), msg); + assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg); + assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg); INDArray in = Nd4j.rand(networkDtype, 2, 5, 2); INDArray label; @@ -936,10 +936,10 @@ public class DTypeTests extends BaseDL4JTest { INDArray out = net.output(in); - assertEquals(msg, networkDtype, out.dataType()); + assertEquals(networkDtype, out.dataType(), msg); List ff = net.feedForward(in); for (int i = 0; i < ff.size(); i++) { - assertEquals(msg, networkDtype, ff.get(i).dataType()); + assertEquals(networkDtype, ff.get(i).dataType(), msg); } net.setInput(in); @@ -1014,11 +1014,11 @@ public class DTypeTests extends BaseDL4JTest { } INDArray out = net.output(in); - assertEquals(msg, networkDtype, out.dataType()); + assertEquals(networkDtype, out.dataType(), msg); List ff = net.feedForward(in); for (int i = 0; i < ff.size(); i++) { String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? "input" : net.getLayer(i - 1).conf().getLayer().getClass().getSimpleName()); - assertEquals(s, networkDtype, ff.get(i).dataType()); + assertEquals(networkDtype, ff.get(i).dataType(), s); } net.setInput(in); @@ -1102,13 +1102,13 @@ public class DTypeTests extends BaseDL4JTest { INDArray label = Nd4j.zeros(networkDtype, 10, 10); INDArray out = net.outputSingle(input); - assertEquals(msg, networkDtype, out.dataType()); + assertEquals(networkDtype, out.dataType(), msg); Map ff = net.feedForward(input, false); for (Map.Entry e : ff.entrySet()) { if (e.getKey().equals("in")) continue; String s = msg + " - layer: " + e.getKey(); - assertEquals(s, networkDtype, e.getValue().dataType()); + assertEquals(networkDtype, e.getValue().dataType(), s); } net.setInput(0, input); @@ -1257,13 +1257,13 @@ public class DTypeTests extends BaseDL4JTest { INDArray label = TestUtils.randomOneHot(2, 10).castTo(networkDtype); INDArray out = net.outputSingle(in); - assertEquals(msg, networkDtype, out.dataType()); + assertEquals(networkDtype, out.dataType(), msg); Map ff = net.feedForward(in, false); for (Map.Entry e : ff.entrySet()) { if (e.getKey().equals("in")) continue; String s = msg + " - layer: " + e.getKey(); - assertEquals(s, networkDtype, e.getValue().dataType()); + assertEquals(networkDtype, e.getValue().dataType(), s); } net.setInputs(in); @@ -1343,13 +1343,13 @@ public class DTypeTests extends BaseDL4JTest { net.init(); INDArray out = net.outputSingle(in); - assertEquals(msg, networkDtype, out.dataType()); + assertEquals(networkDtype, out.dataType(), msg); Map ff = net.feedForward(in, false); for (Map.Entry e : ff.entrySet()) { if (e.getKey().equals("in")) continue; String s = msg + " - layer: " + e.getKey(); - assertEquals(s, networkDtype, e.getValue().dataType()); + assertEquals(networkDtype, e.getValue().dataType(), s); } net.setInputs(in); @@ -1419,11 +1419,11 @@ public class DTypeTests extends BaseDL4JTest { net.init(); INDArray out = net.output(in); - assertEquals(msg, networkDtype, out.dataType()); + assertEquals(networkDtype, out.dataType(), msg); List ff = net.feedForward(in); for (int i = 0; i < ff.size(); i++) { String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? "input" : net.getLayer(i - 1).conf().getLayer().getClass().getSimpleName()); - assertEquals(s, networkDtype, ff.get(i).dataType()); + assertEquals(networkDtype, ff.get(i).dataType(), s); } net.setInput(in); @@ -1506,11 +1506,11 @@ public class DTypeTests extends BaseDL4JTest { net.init(); INDArray out = net.outputSingle(in); - assertEquals(msg, networkDtype, out.dataType()); + assertEquals(networkDtype, out.dataType(), msg); Map ff = net.feedForward(in, false); for(Map.Entry e : ff.entrySet()){ String s = msg + " - layer " + e.getKey(); - assertEquals(s, networkDtype, e.getValue().dataType()); + assertEquals(networkDtype, e.getValue().dataType(), s); } net.setInput(0, in); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java index f412ea68f..2bddca70a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java @@ -39,7 +39,7 @@ import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer; import org.deeplearning4j.nn.layers.recurrent.GravesLSTM; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -55,7 +55,7 @@ import org.nd4j.common.primitives.Pair; import java.util.Collections; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class ComputationGraphTestRNN extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphCNN.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphCNN.java index 30bebcf91..bd5a1ccfa 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphCNN.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphCNN.java @@ -30,9 +30,9 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -44,10 +44,10 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; -//@Ignore +//@Disabled public class TestCompGraphCNN extends BaseDL4JTest { protected ComputationGraphConfiguration conf; @@ -96,8 +96,8 @@ public class TestCompGraphCNN extends BaseDL4JTest { return 2 * (3 * 1 * 4 * 4 * 3 + 3) + (7 * 14 * 14 * 6 + 7) + (7 * 10 + 10); } - @Before - @Ignore + @BeforeEach + @Disabled public void beforeDo() { conf = getMultiInputGraphConfig(); graph = new ComputationGraph(conf); @@ -138,51 +138,54 @@ public class TestCompGraphCNN extends BaseDL4JTest { } - @Test(expected = DL4JInvalidConfigException.class) + @Test() public void testCNNComputationGraphKernelTooLarge() { - int imageWidth = 23; - int imageHeight = 19; - int nChannels = 1; - int classes = 2; - int numSamples = 200; + assertThrows(DL4JInvalidConfigException.class,() -> { + int imageWidth = 23; + int imageHeight = 19; + int nChannels = 1; + int classes = 2; + int numSamples = 200; - int kernelHeight = 3; - int kernelWidth = imageWidth; + int kernelHeight = 3; + int kernelWidth = imageWidth; - DataSet trainInput; + DataSet trainInput; - ComputationGraphConfiguration conf = - new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .seed(123).graphBuilder().addInputs("input") - .setInputTypes(InputType.convolutional(nChannels, imageWidth, - imageHeight)) - .addLayer("conv1", new ConvolutionLayer.Builder() - .kernelSize(kernelHeight, kernelWidth).stride(1, 1) - .dataFormat(CNN2DFormat.NCHW) - .nIn(nChannels).nOut(2).weightInit(WeightInit.XAVIER) - .activation(Activation.RELU).build(), "input") - .addLayer("pool1", - new SubsamplingLayer.Builder() - .dataFormat(CNN2DFormat.NCHW) - .poolingType(SubsamplingLayer.PoolingType.MAX) - .kernelSize(imageHeight - kernelHeight + 1, 1) - .stride(1, 1).build(), - "conv1") - .addLayer("output", new OutputLayer.Builder().nOut(classes).activation(Activation.SOFTMAX).build(), "pool1") - .setOutputs("output").build(); + ComputationGraphConfiguration conf = + new NeuralNetConfiguration.Builder() + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + .seed(123).graphBuilder().addInputs("input") + .setInputTypes(InputType.convolutional(nChannels, imageWidth, + imageHeight)) + .addLayer("conv1", new ConvolutionLayer.Builder() + .kernelSize(kernelHeight, kernelWidth).stride(1, 1) + .dataFormat(CNN2DFormat.NCHW) + .nIn(nChannels).nOut(2).weightInit(WeightInit.XAVIER) + .activation(Activation.RELU).build(), "input") + .addLayer("pool1", + new SubsamplingLayer.Builder() + .dataFormat(CNN2DFormat.NCHW) + .poolingType(SubsamplingLayer.PoolingType.MAX) + .kernelSize(imageHeight - kernelHeight + 1, 1) + .stride(1, 1).build(), + "conv1") + .addLayer("output", new OutputLayer.Builder().nOut(classes).activation(Activation.SOFTMAX).build(), "pool1") + .setOutputs("output").build(); - ComputationGraph model = new ComputationGraph(conf); - model.init(); + ComputationGraph model = new ComputationGraph(conf); + model.init(); - INDArray emptyFeatures = Nd4j.zeros(numSamples, imageWidth * imageHeight * nChannels); - INDArray emptyLables = Nd4j.zeros(numSamples, classes); + INDArray emptyFeatures = Nd4j.zeros(numSamples, imageWidth * imageHeight * nChannels); + INDArray emptyLables = Nd4j.zeros(numSamples, classes); - trainInput = new DataSet(emptyFeatures, emptyLables); + trainInput = new DataSet(emptyFeatures, emptyLables); + + model.fit(trainInput); + }); - model.fit(trainInput); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java index dcdd56f05..a17979bf2 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java @@ -32,7 +32,7 @@ import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistr import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -46,8 +46,8 @@ import java.util.Arrays; import java.util.HashMap; import java.util.Map; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; public class TestCompGraphUnsupervised extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java index 03206a6cd..563a67cf5 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java @@ -61,8 +61,8 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.util.ModelSerializer; -import org.junit.*; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.*; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.api.buffer.DataType; @@ -86,17 +86,15 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.IOException; +import java.nio.file.Path; import java.util.*; import static org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional.Mode.CONCAT; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class TestComputationGraphNetwork extends BaseDL4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - private static ComputationGraphConfiguration getIrisGraphConfiguration() { return new NeuralNetConfiguration.Builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() @@ -120,17 +118,16 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { private static OpExecutioner.ProfilingMode origMode; - @BeforeClass - public static void beforeClass(){ + @BeforeAll public static void beforeClass(){ origMode = Nd4j.getExecutioner().getProfilingMode(); } - @Before + @BeforeEach public void before(){ Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); } - @AfterClass + @AfterAll public static void afterClass(){ Nd4j.getExecutioner().setProfilingMode(origMode); } @@ -322,7 +319,8 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { assertEquals(paramsMLN, paramsGraph); } - @Test(timeout = 300000) + @Test() + @Timeout(300000) public void testIrisFitMultiDataSetIterator() throws Exception { RecordReader rr = new CSVRecordReader(0, ','); @@ -1174,23 +1172,26 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { g.calcRegularizationScore(false); } - @Test(expected = DL4JException.class) + @Test() public void testErrorNoOutputLayer() { + assertThrows(DL4JException.class,() -> { + ComputationGraphConfiguration c = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + .addLayer("dense", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in").setOutputs("dense") + .build(); - ComputationGraphConfiguration c = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") - .addLayer("dense", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in").setOutputs("dense") - .build(); + ComputationGraph cg = new ComputationGraph(c); + cg.init(); - ComputationGraph cg = new ComputationGraph(c); - cg.init(); + INDArray f = Nd4j.create(1, 10); + INDArray l = Nd4j.create(1, 10); - INDArray f = Nd4j.create(1, 10); - INDArray l = Nd4j.create(1, 10); + cg.setInputs(f); + cg.setLabels(l); + + cg.computeGradientAndScore(); + }); - cg.setInputs(f); - cg.setLabels(l); - cg.computeGradientAndScore(); } @@ -1514,7 +1515,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { //Hack output layer to be identity mapping graph.getOutputLayer(0).setParam("W", Nd4j.eye(input.length())); graph.getOutputLayer(0).setParam("b", Nd4j.zeros(input.length())); - assertEquals("Incorrect output", Nd4j.create(expected).reshape(1,expected.length), graph.outputSingle(input)); + assertEquals(Nd4j.create(expected).reshape(1,expected.length), graph.outputSingle(input),"Incorrect output"); } private static INDArray getInputArray4d(float[] inputArr) { @@ -1771,7 +1772,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { for(String s : exp.keySet()){ boolean allowed = ((org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer)cg.getLayer(s)).isInputModificationAllowed(); // System.out.println(s + "\t" + allowed); - assertEquals(s, exp.get(s), allowed); + assertEquals( exp.get(s), allowed,s); } } @@ -2188,7 +2189,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { } @Test - public void testMergeNchw() throws Exception { + public void testMergeNchw(@TempDir Path testDir) throws Exception { ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() .convolutionMode(ConvolutionMode.Same) .graphBuilder() @@ -2215,7 +2216,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { INDArray[] in = new INDArray[]{Nd4j.rand(DataType.FLOAT, 1, 32, 32, 3)}; INDArray out = cg.outputSingle(in); - File dir = testDir.newFolder(); + File dir = testDir.toFile(); File f = new File(dir, "net.zip"); cg.save(f); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java index 563b1f051..ec5c47894 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java @@ -24,7 +24,7 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.*; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -32,7 +32,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestSetGetParameters extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java index e6b7fd1dc..b11d783dc 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java @@ -35,7 +35,7 @@ import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -48,8 +48,8 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.Map; import java.util.Random; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; public class TestVariableLengthTSCG extends BaseDL4JTest { @@ -122,7 +122,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { for (String s : g1map.keySet()) { INDArray g1s = g1map.get(s); INDArray g2s = g2map.get(s); - assertEquals(s, g1s, g2s); + assertEquals(g1s, g2s, s); } //Finally: check that the values at the masked outputs don't actually make any difference to: @@ -140,7 +140,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { for (String s : g2map.keySet()) { INDArray g2s = g2map.get(s); INDArray g2sa = g2a.getGradientFor(s); - assertEquals(s, g2s, g2sa); + assertEquals(g2s, g2sa, s); } } } @@ -225,7 +225,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { INDArray g1s = g1map.get(s); INDArray g2s = g2map.get(s); - assertNotEquals(s, g1s, g2s); + assertNotEquals(g1s, g2s, s); } //Modify the values at the masked time step, and check that neither the gradients, score or activations change @@ -331,8 +331,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { net.computeGradientAndScore(); double score = net.score(); - - assertEquals(msg, expScore, score, 0.1); + assertEquals( expScore, score, 0.1,msg); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java index 3e83bbe4e..de4010554 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java @@ -40,7 +40,7 @@ import org.deeplearning4j.nn.graph.vertex.GraphVertex; import org.deeplearning4j.nn.graph.vertex.impl.*; import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.MultiDataSet; @@ -55,7 +55,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import java.util.Arrays; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestGraphNodes extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/TestDropout.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/TestDropout.java index 86f3b7cb7..8e912bef8 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/TestDropout.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/TestDropout.java @@ -29,8 +29,8 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; @@ -42,8 +42,8 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.lang.reflect.Field; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; public class TestDropout extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java index a25344a70..d549ec3ad 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java @@ -38,7 +38,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.util.ConvolutionUtils; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; @@ -51,8 +51,8 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @RunWith(Parameterized.class) public class ConvDataFormatTests extends BaseDL4JTest { @@ -865,13 +865,13 @@ public class ConvDataFormatTests extends BaseDL4JTest { INDArray l0_3 = tc.net3.feedForward(inNHWC).get(tc.testLayerIdx + 1); INDArray l0_4 = tc.net4.feedForward(inNHWC).get(tc.testLayerIdx + 1); - assertEquals(tc.msg, l0_1, l0_2); + assertEquals(l0_1, l0_2,tc.msg); if(l0_1.rank() == 4) { - assertEquals(tc.msg, l0_1, l0_3.permute(0, 3, 1, 2)); - assertEquals(tc.msg, l0_1, l0_4.permute(0, 3, 1, 2)); + assertEquals(l0_1, l0_3.permute(0, 3, 1, 2),tc.msg); + assertEquals(l0_1, l0_4.permute(0, 3, 1, 2),tc.msg); } else { - assertEquals(tc.msg, l0_1, l0_3); - assertEquals(tc.msg, l0_1, l0_4); + assertEquals(l0_1, l0_3,tc.msg); + assertEquals( l0_1, l0_4,tc.msg); } @@ -880,13 +880,13 @@ public class ConvDataFormatTests extends BaseDL4JTest { INDArray out3 = tc.net3.output(inNHWC); INDArray out4 = tc.net4.output(inNHWC); - assertEquals(tc.msg, out1, out2); + assertEquals(out1, out2,tc.msg); if(!tc.nhwcOutput) { - assertEquals(tc.msg, out1, out3); - assertEquals(tc.msg, out1, out4); + assertEquals(out1, out3,tc.msg); + assertEquals( out1, out4,tc.msg); } else { - assertEquals(tc.msg, out1, out3.permute(0,3,1,2)); //NHWC to NCHW - assertEquals(tc.msg, out1, out4.permute(0,3,1,2)); + assertEquals(out1, out3.permute(0,3,1,2),tc.msg); //NHWC to NCHW + assertEquals(out1, out4.permute(0,3,1,2),tc.msg); } //Test backprop @@ -896,29 +896,29 @@ public class ConvDataFormatTests extends BaseDL4JTest { Pair p4 = tc.net4.calculateGradients(inNHWC, tc.labelsNHWC, null, null); //Inpput gradients - assertEquals(tc.msg, p1.getSecond(), p2.getSecond()); - assertEquals(tc.msg, p1.getSecond(), p3.getSecond().permute(0,3,1,2)); //Input gradients for NHWC input are also in NHWC format - assertEquals(tc.msg, p1.getSecond(), p4.getSecond().permute(0,3,1,2)); + assertEquals( p1.getSecond(), p2.getSecond(),tc.msg); + assertEquals(p1.getSecond(), p3.getSecond().permute(0,3,1,2),tc.msg); //Input gradients for NHWC input are also in NHWC format + assertEquals( p1.getSecond(), p4.getSecond().permute(0,3,1,2),tc.msg); List diff12 = differentGrads(p1.getFirst(), p2.getFirst()); List diff13 = differentGrads(p1.getFirst(), p3.getFirst()); List diff14 = differentGrads(p1.getFirst(), p4.getFirst()); - assertEquals(tc.msg + " " + diff12, 0, diff12.size()); - assertEquals(tc.msg + " " + diff13, 0, diff13.size()); - assertEquals(tc.msg + " " + diff14, 0, diff14.size()); + assertEquals( 0, diff12.size(),tc.msg + " " + diff12); + assertEquals( 0, diff13.size(),tc.msg + " " + diff13); + assertEquals(0, diff14.size(),tc.msg + " " + diff14); - assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable()); - assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable()); - assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable()); + assertEquals(p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable(),tc.msg); + assertEquals(p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable(),tc.msg); + assertEquals( p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable(),tc.msg); tc.net1.fit(inNCHW, tc.labelsNCHW); tc.net2.fit(inNCHW, tc.labelsNCHW); tc.net3.fit(inNHWC, tc.labelsNHWC); tc.net4.fit(inNHWC, tc.labelsNHWC); - assertEquals(tc.msg, tc.net1.params(), tc.net2.params()); - assertEquals(tc.msg, tc.net1.params(), tc.net3.params()); - assertEquals(tc.msg, tc.net1.params(), tc.net4.params()); + assertEquals(tc.net1.params(), tc.net2.params(),tc.msg); + assertEquals(tc.net1.params(), tc.net3.params(),tc.msg); + assertEquals(tc.net1.params(), tc.net4.params(),tc.msg); //Test serialization MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1); @@ -927,14 +927,14 @@ public class ConvDataFormatTests extends BaseDL4JTest { MultiLayerNetwork net4a = TestUtils.testModelSerialization(tc.net4); out1 = tc.net1.output(inNCHW); - assertEquals(tc.msg, out1, net1a.output(inNCHW)); - assertEquals(tc.msg, out1, net2a.output(inNCHW)); + assertEquals(out1, net1a.output(inNCHW),tc.msg); + assertEquals(out1, net2a.output(inNCHW),tc.msg); if(!tc.nhwcOutput) { - assertEquals(tc.msg, out1, net3a.output(inNHWC)); - assertEquals(tc.msg, out1, net4a.output(inNHWC)); + assertEquals( out1, net3a.output(inNHWC),tc.msg); + assertEquals(out1, net4a.output(inNHWC),tc.msg); } else { - assertEquals(tc.msg, out1, net3a.output(inNHWC).permute(0,3,1,2)); //NHWC to NCHW - assertEquals(tc.msg, out1, net4a.output(inNHWC).permute(0,3,1,2)); + assertEquals(out1, net3a.output(inNHWC).permute(0,3,1,2),tc.msg); //NHWC to NCHW + assertEquals(out1, net4a.output(inNHWC).permute(0,3,1,2),tc.msg); } } @@ -1033,7 +1033,7 @@ public class ConvDataFormatTests extends BaseDL4JTest { } catch (DL4JInvalidInputException e) { // e.printStackTrace(); String msg = e.getMessage(); - assertTrue(msg, msg.contains(ConvolutionUtils.NCHW_NHWC_ERROR_MSG) || msg.contains("input array channels does not match CNN layer configuration")); + assertTrue(msg.contains(ConvolutionUtils.NCHW_NHWC_ERROR_MSG) || msg.contains("input array channels does not match CNN layer configuration"),msg); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java index d5b20a072..6cc561ceb 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java @@ -35,7 +35,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.util.ConvolutionUtils; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -45,7 +45,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class TestConvolutionModes extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomActivation.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomActivation.java index 0e777ea1e..2af9576fc 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomActivation.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomActivation.java @@ -26,7 +26,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.layers.custom.testclasses.CustomActivation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.learning.config.Sgd; @@ -37,8 +37,8 @@ import org.nd4j.shade.jackson.databind.jsontype.NamedType; import java.util.Collection; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestCustomActivation extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomLayers.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomLayers.java index cd73c0850..47a4338a8 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomLayers.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomLayers.java @@ -33,7 +33,7 @@ import org.deeplearning4j.nn.layers.custom.testclasses.CustomOutputLayer; import org.deeplearning4j.nn.layers.custom.testclasses.CustomOutputLayerImpl; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -47,8 +47,8 @@ import java.util.Collection; import java.util.HashSet; import java.util.Set; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestCustomLayers extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java index 72efd60b7..24f0b06fb 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java @@ -25,9 +25,10 @@ import org.apache.commons.io.IOUtils; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.FileSplit; import org.deeplearning4j.nn.conf.GradientNormalization; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Disabled; + + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.common.io.ClassPathResource; import org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader; @@ -45,7 +46,7 @@ import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -61,17 +62,17 @@ import java.io.InputStream; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.net.URI; +import java.nio.file.Path; import java.util.Collections; import java.util.Comparator; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.linalg.indexing.NDArrayIndex.*; public class TestYolo2OutputLayer extends BaseDL4JTest { - @Rule - public TemporaryFolder tempDir = new TemporaryFolder(); + @Test public void testYoloActivateScoreBasic() { @@ -225,12 +226,13 @@ public class TestYolo2OutputLayer extends BaseDL4JTest { } @Test - public void testIOUCalc() throws Exception { + public void testIOUCalc(@TempDir Path tempDir) throws Exception { InputStream is1 = new ClassPathResource("yolo/VOC_SingleImage/JPEGImages/2007_009346.jpg").getInputStream(); InputStream is2 = new ClassPathResource("yolo/VOC_SingleImage/Annotations/2007_009346.xml").getInputStream(); - File dir = tempDir.newFolder("testYoloOverfitting"); + File dir = new File(tempDir.toFile(),"testYoloOverfitting"); + dir.mkdirs(); File jpg = new File(dir, "JPEGImages"); File annot = new File(dir, "Annotations"); jpg.mkdirs(); @@ -428,8 +430,8 @@ public class TestYolo2OutputLayer extends BaseDL4JTest { @Test - @Ignore //TODO UNIGNORE THIS - IGNORED AS CRASHING JVM HENCE GETTING IN THE WAY OF FIXING OTHER PROBLEMS - public void testYoloOverfitting() throws Exception { + @Disabled //TODO UNIGNORE THIS - IGNORED AS CRASHING JVM HENCE GETTING IN THE WAY OF FIXING OTHER PROBLEMS + public void testYoloOverfitting(@TempDir Path tempDir) throws Exception { Nd4j.getRandom().setSeed(12345); InputStream is1 = new ClassPathResource("yolo/VOC_TwoImage/JPEGImages/2007_009346.jpg").getInputStream(); @@ -437,7 +439,7 @@ public class TestYolo2OutputLayer extends BaseDL4JTest { InputStream is3 = new ClassPathResource("yolo/VOC_TwoImage/JPEGImages/2008_003344.jpg").getInputStream(); InputStream is4 = new ClassPathResource("yolo/VOC_TwoImage/Annotations/2008_003344.xml").getInputStream(); - File dir = tempDir.newFolder(); + File dir = tempDir.toFile(); File jpg = new File(dir, "JPEGImages"); File annot = new File(dir, "Annotations"); jpg.mkdirs(); @@ -584,8 +586,8 @@ public class TestYolo2OutputLayer extends BaseDL4JTest { double p1 = o1.getClassPredictions().getDouble(idxCat); double c1 = o1.getConfidence(); assertEquals(idxCat, o1.getPredictedClass() ); - assertTrue(String.valueOf(p1), p1 >= 0.85); - assertTrue(String.valueOf(c1), c1 >= 0.85); + assertTrue(p1 >= 0.85,String.valueOf(p1)); + assertTrue(c1 >= 0.85,String.valueOf(c1)); assertEquals(cx1, o1.getCenterX(), 0.1); assertEquals(cy1, o1.getCenterY(), 0.1); assertEquals(wGrid1, o1.getWidth(), 0.2); @@ -596,8 +598,8 @@ public class TestYolo2OutputLayer extends BaseDL4JTest { double p2 = o2.getClassPredictions().getDouble(idxCat); double c2 = o2.getConfidence(); assertEquals(idxCat, o2.getPredictedClass() ); - assertTrue(String.valueOf(p2), p2 >= 0.85); - assertTrue(String.valueOf(c2), c2 >= 0.85); + assertTrue(p2 >= 0.85,String.valueOf(p2)); + assertTrue(c2 >= 0.85,String.valueOf(c2)); assertEquals(cx2, o2.getCenterX(), 0.1); assertEquals(cy2, o2.getCenterY(), 0.1); assertEquals(wGrid2, o2.getWidth(), 0.2); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java index 4d0f9bc66..2c6907137 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java @@ -29,7 +29,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.util.ModelSerializer; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.activations.impl.ActivationIdentity; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java index 6733bdaab..a5096cf2f 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java @@ -29,7 +29,7 @@ import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -41,8 +41,8 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.Random; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.nd4j.linalg.indexing.NDArrayIndex.*; public class GlobalPoolingMaskingTests extends BaseDL4JTest { @@ -288,7 +288,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { INDArray outSubset = net.output(subset); INDArray outMaskedSubset = outMasked.getRow(i, true); - assertEquals("minibatch: " + i, outSubset, outMaskedSubset); + assertEquals(outSubset, outMaskedSubset, "minibatch: " + i); } } } @@ -347,7 +347,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { INDArray outSubset = net.output(subset); INDArray outMaskedSubset = outMasked.getRow(i, true); - assertEquals("minibatch: " + i, outSubset, outMaskedSubset); + assertEquals(outSubset, outMaskedSubset, "minibatch: " + i); } } } @@ -412,7 +412,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { INDArray outSubset = net.output(subset); INDArray outMaskedSubset = outMasked.getRow(i,true); - assertEquals("minibatch: " + i + ", " + pt, outSubset, outMaskedSubset); + assertEquals(outSubset, outMaskedSubset, "minibatch: " + i + ", " + pt); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java index e420ece1d..118fbf6b3 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java @@ -39,7 +39,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; @@ -52,7 +52,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @RunWith(Parameterized.class) @AllArgsConstructor @@ -302,28 +302,28 @@ public class RnnDataFormatTests extends BaseDL4JTest { INDArray l0_4 = tc.net4.feedForward(inNWC).get(tc.testLayerIdx + 1); boolean rank3Out = tc.labelsNCW.rank() == 3; - assertEquals(tc.msg, l0_1, l0_2); + assertEquals(l0_1, l0_2, tc.msg); if (rank3Out){ - assertEquals(tc.msg, l0_1, l0_3.permute(0, 2, 1)); - assertEquals(tc.msg, l0_1, l0_4.permute(0, 2, 1)); + assertEquals(l0_1, l0_3.permute(0, 2, 1), tc.msg); + assertEquals(l0_1, l0_4.permute(0, 2, 1), tc.msg); } else{ - assertEquals(tc.msg, l0_1, l0_3); - assertEquals(tc.msg, l0_1, l0_4); + assertEquals(l0_1, l0_3, tc.msg); + assertEquals(l0_1, l0_4, tc.msg); } INDArray out1 = tc.net1.output(inNCW); INDArray out2 = tc.net2.output(inNCW); INDArray out3 = tc.net3.output(inNWC); INDArray out4 = tc.net4.output(inNWC); - assertEquals(tc.msg, out1, out2); + assertEquals(out1, out2, tc.msg); if (rank3Out){ - assertEquals(tc.msg, out1, out3.permute(0, 2, 1)); //NWC to NCW - assertEquals(tc.msg, out1, out4.permute(0, 2, 1)); + assertEquals(out1, out3.permute(0, 2, 1), tc.msg); //NWC to NCW + assertEquals(out1, out4.permute(0, 2, 1), tc.msg); } else{ - assertEquals(tc.msg, out1, out3); //NWC to NCW - assertEquals(tc.msg, out1, out4); + assertEquals(out1, out3, tc.msg); //NWC to NCW + assertEquals(out1, out4, tc.msg); } @@ -334,31 +334,31 @@ public class RnnDataFormatTests extends BaseDL4JTest { Pair p4 = tc.net4.calculateGradients(inNWC, tc.labelsNWC, null, null); //Inpput gradients - assertEquals(tc.msg, p1.getSecond(), p2.getSecond()); + assertEquals(p1.getSecond(), p2.getSecond(), tc.msg); - assertEquals(tc.msg, p1.getSecond(), p3.getSecond().permute(0, 2, 1)); //Input gradients for NWC input are also in NWC format - assertEquals(tc.msg, p1.getSecond(), p4.getSecond().permute(0, 2, 1)); + assertEquals(p1.getSecond(), p3.getSecond().permute(0, 2, 1), tc.msg); //Input gradients for NWC input are also in NWC format + assertEquals(p1.getSecond(), p4.getSecond().permute(0, 2, 1), tc.msg); List diff12 = differentGrads(p1.getFirst(), p2.getFirst()); List diff13 = differentGrads(p1.getFirst(), p3.getFirst()); List diff14 = differentGrads(p1.getFirst(), p4.getFirst()); - assertEquals(tc.msg + " " + diff12, 0, diff12.size()); - assertEquals(tc.msg + " " + diff13, 0, diff13.size()); - assertEquals(tc.msg + " " + diff14, 0, diff14.size()); + assertEquals(0, diff12.size(),tc.msg + " " + diff12); + assertEquals(0, diff13.size(),tc.msg + " " + diff13); + assertEquals( 0, diff14.size(),tc.msg + " " + diff14); - assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable()); - assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable()); - assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable()); + assertEquals(p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable(), tc.msg); + assertEquals(p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable(), tc.msg); + assertEquals(p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable(), tc.msg); tc.net1.fit(inNCW, tc.labelsNCW); tc.net2.fit(inNCW, tc.labelsNCW); tc.net3.fit(inNWC, tc.labelsNWC); tc.net4.fit(inNWC, tc.labelsNWC); - assertEquals(tc.msg, tc.net1.params(), tc.net2.params()); - assertEquals(tc.msg, tc.net1.params(), tc.net3.params()); - assertEquals(tc.msg, tc.net1.params(), tc.net4.params()); + assertEquals(tc.net1.params(), tc.net2.params(), tc.msg); + assertEquals(tc.net1.params(), tc.net3.params(), tc.msg); + assertEquals(tc.net1.params(), tc.net4.params(), tc.msg); //Test serialization MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1); @@ -367,16 +367,16 @@ public class RnnDataFormatTests extends BaseDL4JTest { MultiLayerNetwork net4a = TestUtils.testModelSerialization(tc.net4); out1 = tc.net1.output(inNCW); - assertEquals(tc.msg, out1, net1a.output(inNCW)); - assertEquals(tc.msg, out1, net2a.output(inNCW)); + assertEquals(out1, net1a.output(inNCW), tc.msg); + assertEquals(out1, net2a.output(inNCW), tc.msg); if (rank3Out) { - assertEquals(tc.msg, out1, net3a.output(inNWC).permute(0, 2, 1)); //NWC to NCW - assertEquals(tc.msg, out1, net4a.output(inNWC).permute(0, 2, 1)); + assertEquals(out1, net3a.output(inNWC).permute(0, 2, 1), tc.msg); //NWC to NCW + assertEquals(out1, net4a.output(inNWC).permute(0, 2, 1), tc.msg); } else{ - assertEquals(tc.msg, out1, net3a.output(inNWC)); //NWC to NCW - assertEquals(tc.msg, out1, net4a.output(inNWC)); + assertEquals(out1, net3a.output(inNWC), tc.msg); //NWC to NCW + assertEquals(out1, net4a.output(inNWC), tc.msg); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java index 614bf1d65..66a87c872 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java @@ -33,7 +33,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.ndarray.INDArray; @@ -44,7 +44,7 @@ import org.nd4j.linalg.learning.config.AdaGrad; import static org.deeplearning4j.nn.api.OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT; import static org.deeplearning4j.nn.weights.WeightInit.XAVIER_UNIFORM; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.linalg.activations.Activation.IDENTITY; import static org.nd4j.linalg.activations.Activation.TANH; import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java index 39b73b11e..74cfd2010 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java @@ -27,10 +27,10 @@ import org.deeplearning4j.nn.conf.layers.GravesLSTM; import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestRecurrentWeightInit extends BaseDL4JTest { @@ -86,11 +86,11 @@ public class TestRecurrentWeightInit extends BaseDL4JTest { double min = rw.minNumber().doubleValue(); double max = rw.maxNumber().doubleValue(); if(rwInit){ - assertTrue(String.valueOf(min), min >= 2.0); - assertTrue(String.valueOf(max), max <= 3.0); + assertTrue(min >= 2.0, String.valueOf(min)); + assertTrue(max <= 3.0, String.valueOf(max)); } else { - assertTrue(String.valueOf(min), min >= 0.0); - assertTrue(String.valueOf(max), max <= 1.0); + assertTrue(min >= 0.0, String.valueOf(min)); + assertTrue(max <= 1.0, String.valueOf(max)); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java index cf425991b..dba4ae308 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java @@ -35,7 +35,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; @@ -50,9 +50,9 @@ import java.util.Arrays; import java.util.List; import java.util.Random; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @RunWith(Parameterized.class) public class TestRnnLayers extends BaseDL4JTest { @@ -178,8 +178,8 @@ public class TestRnnLayers extends BaseDL4JTest { MultiLayerNetwork netD2 = new MultiLayerNetwork(confD2); netD2.init(); - assertEquals(s, net.params(), netD.params()); - assertEquals(s, net.params(), netD2.params()); + assertEquals(net.params(), netD.params(), s); + assertEquals(net.params(), netD2.params(), s); INDArray f = Nd4j.rand(DataType.FLOAT, new int[]{3, 10, 10}); @@ -187,18 +187,18 @@ public class TestRnnLayers extends BaseDL4JTest { INDArray out1 = net.output(f); INDArray out1D = netD.output(f); INDArray out1D2 = netD2.output(f); - assertEquals(s, out1, out1D); - assertEquals(s, out1, out1D2); + assertEquals(out1, out1D, s); + assertEquals(out1, out1D2, s); INDArray out2 = net.output(f, true); INDArray out2D = netD.output(f, true); - assertNotEquals(s, out2, out2D); + assertNotEquals(out2, out2D, s); INDArray l = TestUtils.randomOneHotTimeSeries(3, 10, 10, 12345); net.fit(f.dup(), l); netD.fit(f.dup(), l); - assertNotEquals(s, net.params(), netD.params()); + assertNotEquals(net.params(), netD.params(), s); netD2.fit(f.dup(), l); netD2.fit(f.dup(), l); @@ -210,7 +210,7 @@ public class TestRnnLayers extends BaseDL4JTest { new Pair<>(1, 0), new Pair<>(2, 0)); - assertEquals(s, expected, cd.getAllCalls()); + assertEquals(expected, cd.getAllCalls(), s); } } @@ -248,7 +248,7 @@ public class TestRnnLayers extends BaseDL4JTest { if(msg == null) t.printStackTrace(); System.out.println(i); - assertTrue(msg, msg != null && msg.contains("sequence length") && msg.contains("input") && msg.contains("label")); + assertTrue(msg != null && msg.contains("sequence length") && msg.contains("input") && msg.contains("label"), msg); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java index c609d54be..a316ac858 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java @@ -28,7 +28,7 @@ import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; @@ -38,7 +38,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.ops.transforms.Transforms; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.interval; import static org.nd4j.linalg.indexing.NDArrayIndex.point; @@ -115,7 +115,7 @@ public class TestSimpleRnn extends BaseDL4JTest { else{ outActCurrent = out.get(all(), point(i), all()); } - assertEquals(String.valueOf(i), outExpCurrent, outActCurrent); + assertEquals(outExpCurrent, outActCurrent, String.valueOf(i)); outLast = outExpCurrent; } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java index c1b157ff9..acae4faf3 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java @@ -36,7 +36,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; @@ -47,7 +47,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.lossfunctions.LossFunctions; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @RunWith(Parameterized.class) public class TestTimeDistributed extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java index e15525a98..c9dc7bc5b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java @@ -34,11 +34,10 @@ import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; @@ -51,14 +50,14 @@ import org.nd4j.nativeblas.NativeOpsHolder; import java.util.Map; +import static org.junit.jupiter.api.Assertions.assertThrows; + @Slf4j public class SameDiffCustomLayerTests extends BaseDL4JTest { private DataType initialType; - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); - @Before + @BeforeEach public void before() { Nd4j.create(1); initialType = Nd4j.dataType(); @@ -67,7 +66,7 @@ public class SameDiffCustomLayerTests extends BaseDL4JTest { Nd4j.getRandom().setSeed(123); } - @After + @AfterEach public void after() { Nd4j.setDataType(initialType); @@ -77,46 +76,48 @@ public class SameDiffCustomLayerTests extends BaseDL4JTest { @Test public void testInputValidationSameDiffLayer(){ - final MultiLayerConfiguration config = new NeuralNetConfiguration.Builder().list() - .layer(new ValidatingSameDiffLayer()) - .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.SIGMOID).nOut(2).build()) - .setInputType(InputType.feedForward(2)) - .build(); + assertThrows(IllegalArgumentException.class,() -> { + final MultiLayerConfiguration config = new NeuralNetConfiguration.Builder().list() + .layer(new ValidatingSameDiffLayer()) + .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.SIGMOID).nOut(2).build()) + .setInputType(InputType.feedForward(2)) + .build(); - final MultiLayerNetwork net = new MultiLayerNetwork(config); - net.init(); + final MultiLayerNetwork net = new MultiLayerNetwork(config); + net.init(); - final INDArray goodInput = Nd4j.rand(1, 2); - final INDArray badInput = Nd4j.rand(2, 2); + final INDArray goodInput = Nd4j.rand(1, 2); + final INDArray badInput = Nd4j.rand(2, 2); - net.fit(goodInput, goodInput); + net.fit(goodInput, goodInput); + net.fit(badInput, badInput); + + + }); - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Expected Message"); - net.fit(badInput, badInput); } @Test public void testInputValidationSameDiffVertex(){ - final ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder().graphBuilder() - .addVertex("a", new ValidatingSameDiffVertex(), "input") - .addLayer("output", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.SIGMOID).nOut(2).build(), "a") - .addInputs("input") - .setInputTypes(InputType.feedForward(2)) - .setOutputs("output") - .build(); + assertThrows(IllegalArgumentException.class,() -> { + final ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder().graphBuilder() + .addVertex("a", new ValidatingSameDiffVertex(), "input") + .addLayer("output", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.SIGMOID).nOut(2).build(), "a") + .addInputs("input") + .setInputTypes(InputType.feedForward(2)) + .setOutputs("output") + .build(); - final ComputationGraph net = new ComputationGraph(config); - net.init(); + final ComputationGraph net = new ComputationGraph(config); + net.init(); - final INDArray goodInput = Nd4j.rand(1, 2); - final INDArray badInput = Nd4j.rand(2, 2); + final INDArray goodInput = Nd4j.rand(1, 2); + final INDArray badInput = Nd4j.rand(2, 2); - net.fit(new INDArray[]{goodInput}, new INDArray[]{goodInput}); + net.fit(new INDArray[]{goodInput}, new INDArray[]{goodInput}); + net.fit(new INDArray[]{badInput}, new INDArray[]{badInput}); + }); - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Expected Message"); - net.fit(new INDArray[]{badInput}, new INDArray[]{badInput}); } private class ValidatingSameDiffLayer extends org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java index dd109a6fa..e1510e56a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java @@ -35,7 +35,7 @@ import org.deeplearning4j.nn.layers.samediff.testlayers.SameDiffConv; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.params.ConvolutionParamInitializer; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -47,7 +47,7 @@ import java.util.Arrays; import java.util.Map; import java.util.Random; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.junit.Assume.assumeTrue; @Slf4j @@ -210,13 +210,13 @@ public class TestSameDiffConv extends BaseDL4JTest { INDArray out = net.output(in); INDArray outExp = net2.output(in); - assertEquals(msg, outExp, out); + assertEquals(outExp, out, msg); //Also check serialization: MultiLayerNetwork netLoaded = TestUtils.testModelSerialization(net); INDArray outLoaded = netLoaded.output(in); - assertEquals(msg, outExp, outLoaded); + assertEquals(outExp, outLoaded, msg); //Sanity check on different minibatch sizes: INDArray newIn = Nd4j.vstack(in, in); @@ -313,7 +313,7 @@ public class TestSameDiffConv extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f) .labels(l).subset(true).maxPerParam(50)); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(net); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java index c7e9cc046..a28230d48 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java @@ -35,7 +35,7 @@ import org.deeplearning4j.nn.layers.samediff.testlayers.SameDiffDense; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -48,7 +48,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class TestSameDiffDense extends BaseDL4JTest { @@ -316,7 +316,7 @@ public class TestSameDiffDense extends BaseDL4JTest { INDArray i1 = m1.get(s); INDArray i2 = m2.get(s); - assertEquals(s, i2, i1); + assertEquals(i2, i1, s); } assertEquals(gStd.gradient(), gSD.gradient()); @@ -398,9 +398,9 @@ public class TestSameDiffDense extends BaseDL4JTest { netSD.fit(ds); netStandard.fit(ds); String s = String.valueOf(i); - assertEquals(s, netStandard.getFlattenedGradients(), netSD.getFlattenedGradients()); - assertEquals(s, netStandard.params(), netSD.params()); - assertEquals(s, netStandard.getUpdater().getStateViewArray(), netSD.getUpdater().getStateViewArray()); + assertEquals(netStandard.getFlattenedGradients(), netSD.getFlattenedGradients(), s); + assertEquals(netStandard.params(), netSD.params(), s); + assertEquals(netStandard.getUpdater().getStateViewArray(), netSD.getUpdater().getStateViewArray(), s); } //Sanity check on different minibatch sizes: @@ -446,7 +446,7 @@ public class TestSameDiffDense extends BaseDL4JTest { boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, l); - assertTrue(msg, gradOK); + assertTrue(gradOK, msg); TestUtils.testModelSerialization(net); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java index 8c38177db..fd7ec1d32 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java @@ -32,7 +32,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.samediff.testlayers.SameDiffDenseVertex; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -43,7 +43,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.Map; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j public class TestSameDiffDenseVertex extends BaseDL4JTest { @@ -134,7 +134,7 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest { INDArray i1 = m1.get(s); INDArray i2 = m2.get(s); - assertEquals(s, i2, i1); + assertEquals(i2, i1, s); } assertEquals(gStd.gradient(), gSD.gradient()); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java index 6c5a79547..4afbc7e37 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java @@ -34,7 +34,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.samediff.testlayers.SameDiffSimpleLambdaLayer; import org.deeplearning4j.nn.layers.samediff.testlayers.SameDiffSimpleLambdaVertex; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -44,7 +44,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.lossfunctions.LossFunctions; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j public class TestSameDiffLambda extends BaseDL4JTest { @@ -119,8 +119,8 @@ public class TestSameDiffLambda extends BaseDL4JTest { std.fit(ds); String s = String.valueOf(i); - assertEquals(s, std.params(), lambda.params()); - assertEquals(s, std.getFlattenedGradients(), lambda.getFlattenedGradients()); + assertEquals(std.params(), lambda.params(), s); + assertEquals(std.getFlattenedGradients(), lambda.getFlattenedGradients(), s); } ComputationGraph loaded = TestUtils.testModelSerialization(lambda); @@ -204,8 +204,8 @@ public class TestSameDiffLambda extends BaseDL4JTest { std.fit(mds); String s = String.valueOf(i); - assertEquals(s, std.params(), lambda.params()); - assertEquals(s, std.getFlattenedGradients(), lambda.getFlattenedGradients()); + assertEquals(std.params(), lambda.params(), s); + assertEquals(std.getFlattenedGradients(), lambda.getFlattenedGradients(), s); } ComputationGraph loaded = TestUtils.testModelSerialization(lambda); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java index a5307121a..0207341b5 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java @@ -31,7 +31,7 @@ import org.deeplearning4j.nn.layers.samediff.testlayers.SameDiffMSELossLayer; import org.deeplearning4j.nn.layers.samediff.testlayers.SameDiffMSEOutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -39,7 +39,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.lossfunctions.LossFunctions; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j public class TestSameDiffOutput extends BaseDL4JTest { @@ -166,8 +166,8 @@ public class TestSameDiffOutput extends BaseDL4JTest { netSD.fit(ds); netStd.fit(ds); String s = String.valueOf(i); - assertEquals(s, netStd.params(), netSD.params()); - assertEquals(s, netStd.getFlattenedGradients(), netSD.getFlattenedGradients()); + assertEquals(netStd.params(), netSD.params(), s); + assertEquals(netStd.getFlattenedGradients(), netSD.getFlattenedGradients(), s); } //Test fit before output: diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestReconstructionDistributions.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestReconstructionDistributions.java index 9491f9ee7..934ba63a8 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestReconstructionDistributions.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestReconstructionDistributions.java @@ -30,7 +30,7 @@ import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDist import org.deeplearning4j.nn.conf.layers.variational.ExponentialReconstructionDistribution; import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution; import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -41,7 +41,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Random; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class TestReconstructionDistributions extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java index e4fa96753..e61614a1b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java @@ -32,7 +32,7 @@ import org.deeplearning4j.nn.conf.weightnoise.WeightNoise; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationTanH; import org.nd4j.linalg.api.ndarray.INDArray; @@ -49,7 +49,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestVAE extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/CloseNetworkTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/CloseNetworkTests.java index b4de50f49..c041154ba 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/CloseNetworkTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/CloseNetworkTests.java @@ -28,14 +28,14 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class CloseNetworkTests extends BaseDL4JTest { @@ -92,14 +92,14 @@ public class CloseNetworkTests extends BaseDL4JTest { net.output(f); } catch (IllegalStateException e) { String msg = e.getMessage(); - assertTrue(msg, msg.contains("released")); + assertTrue(msg.contains("released"),msg); } try { net.fit(f, l); } catch (IllegalStateException e) { String msg = e.getMessage(); - assertTrue(msg, msg.contains("released")); + assertTrue(msg.contains("released"),msg); } } } @@ -140,14 +140,14 @@ public class CloseNetworkTests extends BaseDL4JTest { net.output(f); } catch (IllegalStateException e) { String msg = e.getMessage(); - assertTrue(msg, msg.contains("released")); + assertTrue( msg.contains("released"),msg); } try { net.fit(new INDArray[]{f}, new INDArray[]{l}); } catch (IllegalStateException e) { String msg = e.getMessage(); - assertTrue(msg, msg.contains("released")); + assertTrue(msg.contains("released"),msg); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/TestLrChanges.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/TestLrChanges.java index 744a89e98..77f3a2342 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/TestLrChanges.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/TestLrChanges.java @@ -29,7 +29,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.weightnoise.DropConnect; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -42,7 +42,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.schedule.ExponentialSchedule; import org.nd4j.linalg.schedule.ScheduleType; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestLrChanges extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/TestMemoryReports.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/TestMemoryReports.java index 8a2593dd2..a7fcee172 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/TestMemoryReports.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/TestMemoryReports.java @@ -35,7 +35,7 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryType; import org.deeplearning4j.nn.conf.memory.MemoryUseMode; import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -48,8 +48,8 @@ import java.nio.charset.Charset; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestMemoryReports extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/TestNetConversion.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/TestNetConversion.java index 02e857d99..cd1ca1a28 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/TestNetConversion.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/TestNetConversion.java @@ -29,14 +29,14 @@ import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestNetConversion extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java index 14dc2faf2..ad57a4688 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java @@ -33,9 +33,9 @@ import org.deeplearning4j.nn.misc.iter.WSTestDataSetIterator; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; @@ -53,17 +53,17 @@ import org.nd4j.common.primitives.Pair; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class WorkspaceTests extends BaseDL4JTest { - @Before + @BeforeEach public void before() { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); } - @After + @AfterEach public void after() { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java index 2c9ece0fd..5d952bf6d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java @@ -34,8 +34,8 @@ import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -195,7 +195,7 @@ public class ValidateMKLDNN extends BaseDL4JTest { } } - @Test @Ignore //https://github.com/deeplearning4j/deeplearning4j/issues/7272 + @Test @Disabled //https://github.com/deeplearning4j/deeplearning4j/issues/7272 public void validateLRN() { //Only run test if using nd4j-native backend diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java index bd1a1d540..d93938555 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java @@ -46,7 +46,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Arrays; import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.fail; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.extension.ExtendWith; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java index 6c3ad1855..60f5b91a6 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java @@ -52,8 +52,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.api.BaseTrainingListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.util.ModelSerializer; -import org.junit.*; -import org.junit.Test; +import org.junit.jupiter.api.*;import org.junit.jupiter.api.Test; import org.junit.jupiter.api.*; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java index f90ad98b4..6f1b3f732 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java @@ -42,7 +42,7 @@ import org.deeplearning4j.nn.layers.recurrent.LSTM; import org.deeplearning4j.nn.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.params.GravesLSTMParamInitializer; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -57,7 +57,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class MultiLayerTestRNN extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestMasking.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestMasking.java index dd0dd103d..420417296 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestMasking.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestMasking.java @@ -38,7 +38,7 @@ import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; @@ -54,7 +54,7 @@ import org.nd4j.linalg.lossfunctions.impl.*; import java.util.Collections; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestMasking extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java index c37600b86..30b9fdba7 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java @@ -25,7 +25,7 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.layers.*; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -33,7 +33,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestSetGetParameters extends BaseDL4JTest { @@ -63,7 +63,7 @@ public class TestSetGetParameters extends BaseDL4JTest { Map initParams2After = net.paramTable(); for (String s : initParams2.keySet()) { - assertTrue("Params differ: " + s, initParams2.get(s).equals(initParams2After.get(s))); + assertTrue( initParams2.get(s).equals(initParams2After.get(s)),"Params differ: " + s); } assertEquals(initParams, initParamsAfter); @@ -100,7 +100,7 @@ public class TestSetGetParameters extends BaseDL4JTest { Map initParams2After = net.paramTable(); for (String s : initParams2.keySet()) { - assertTrue("Params differ: " + s, initParams2.get(s).equals(initParams2After.get(s))); + assertTrue( initParams2.get(s).equals(initParams2After.get(s)),"Params differ: " + s); } assertEquals(initParams, initParamsAfter); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java index 02285588a..b113006ec 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java @@ -35,7 +35,7 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.util.TimeSeriesUtils; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; @@ -53,7 +53,7 @@ import java.util.List; import java.util.Map; import java.util.Random; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestVariableLengthTS extends BaseDL4JTest { @@ -124,7 +124,7 @@ public class TestVariableLengthTS extends BaseDL4JTest { for (String s : g1map.keySet()) { INDArray g1s = g1map.get(s); INDArray g2s = g2map.get(s); - assertEquals(s, g1s, g2s); + assertEquals(g1s, g2s,s); } //Finally: check that the values at the masked outputs don't actually make any differente to: @@ -142,7 +142,7 @@ public class TestVariableLengthTS extends BaseDL4JTest { for (String s : g2map.keySet()) { INDArray g2s = g2map.get(s); INDArray g2sa = g2a.getGradientFor(s); - assertEquals(s, g2s, g2sa); + assertEquals(g2s, g2sa,s); } } } @@ -231,7 +231,7 @@ public class TestVariableLengthTS extends BaseDL4JTest { // System.out.println("Variable: " + s); // System.out.println(Arrays.toString(g1s.dup().data().asFloat())); // System.out.println(Arrays.toString(g2s.dup().data().asFloat())); - assertNotEquals(s, g1s, g2s); + assertNotEquals(g1s, g2s,s); } //Modify the values at the masked time step, and check that neither the gradients, score or activations change @@ -331,7 +331,7 @@ public class TestVariableLengthTS extends BaseDL4JTest { mln.computeGradientAndScore(); double score = mln.score(); - assertEquals(msg, expScore, score, 0.1); + assertEquals(expScore, score, 0.1,msg); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java index a75d806f0..410abf970 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java @@ -30,7 +30,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -41,8 +41,8 @@ import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; public class TestMultiModelGradientApplication extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java index 938792924..52b86b276 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java @@ -34,7 +34,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.FrozenLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -44,7 +44,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.LinkedHashMap; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestFrozenLayers extends BaseDL4JTest { @@ -90,9 +90,9 @@ public class TestFrozenLayers extends BaseDL4JTest { String s = msg + " - " + entry.getKey(); if(entry.getKey().startsWith("5_")){ //Non-frozen layer - assertNotEquals(s, paramsBefore.get(entry.getKey()), entry.getValue()); + assertNotEquals(paramsBefore.get(entry.getKey()), entry.getValue(), s); } else { - assertEquals(s, paramsBefore.get(entry.getKey()), entry.getValue()); + assertEquals(paramsBefore.get(entry.getKey()), entry.getValue(), s); } } } @@ -142,9 +142,9 @@ public class TestFrozenLayers extends BaseDL4JTest { String s = msg + " - " + entry.getKey(); if(entry.getKey().startsWith("5_")){ //Non-frozen layer - assertNotEquals(s, paramsBefore.get(entry.getKey()), entry.getValue()); + assertNotEquals(paramsBefore.get(entry.getKey()), entry.getValue(), s); } else { - assertEquals(s, paramsBefore.get(entry.getKey()), entry.getValue()); + assertEquals(paramsBefore.get(entry.getKey()), entry.getValue(), s); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningJson.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningJson.java index 6c5e1ae93..a1598647a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningJson.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningJson.java @@ -21,11 +21,11 @@ package org.deeplearning4j.nn.transferlearning; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.learning.config.AdaGrad; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestTransferLearningJson extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java index 795d2950c..ad92a7c47 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java @@ -32,7 +32,7 @@ import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.FrozenLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -41,7 +41,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestTransferLearningModelSerializer extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningComplex.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningComplex.java index 3b6e319c3..6e9851ab6 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningComplex.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningComplex.java @@ -33,7 +33,7 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.FrozenLayer; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.MultiDataSet; @@ -42,7 +42,7 @@ import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class TransferLearningComplex extends BaseDL4JTest { @@ -92,9 +92,9 @@ public class TransferLearningComplex extends BaseDL4JTest { if ("C".equals(l.conf().getLayer().getLayerName())) { //Only C should be frozen in this config cFound = true; - assertTrue(name, l instanceof FrozenLayer); + assertTrue(l instanceof FrozenLayer, name); } else { - assertFalse(name, l instanceof FrozenLayer); + assertFalse(l instanceof FrozenLayer, name); } //Also check config: diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestGradientNormalization.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestGradientNormalization.java index 083d32eff..02616d66d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestGradientNormalization.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestGradientNormalization.java @@ -30,7 +30,7 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.params.DefaultParamInitializer; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; @@ -38,7 +38,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.config.NoOp; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestGradientNormalization extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java index 9b4ae74ce..28e27511b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java @@ -38,8 +38,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.params.PretrainParamInitializer; import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -53,7 +53,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import java.lang.reflect.Method; import java.util.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.linalg.indexing.NDArrayIndex.interval; import static org.nd4j.linalg.indexing.NDArrayIndex.point; @@ -70,7 +70,7 @@ public class TestUpdaters extends BaseDL4JTest { protected String key; - @Before + @BeforeEach public void beforeDo() { gradients = Nd4j.ones(1, nIn * nOut + nOut); weightGradient = gradients.get(point(0), interval(0, nIn * nOut)); @@ -320,7 +320,7 @@ public class TestUpdaters extends BaseDL4JTest { count++; } - assertEquals("Count should be equal to 2, one for weight gradient and one for bias gradient", 2, count); + assertEquals(2, count,"Count should be equal to 2, one for weight gradient and one for bias gradient"); /* * Check that we are not erroneously mutating moving avg gradient while calculating @@ -340,7 +340,7 @@ public class TestUpdaters extends BaseDL4JTest { actualM[i] = Math.round(actualM[i] * 1e2) / 1e2; } - assertEquals("Wrong weight gradient after first iteration's update", Arrays.equals(expectedM, actualM), true); + assertEquals(Arrays.equals(expectedM, actualM), true, "Wrong weight gradient after first iteration's update"); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java index e2d7b46e3..703d56eb2 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java @@ -27,15 +27,15 @@ import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestCustomUpdater extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java index 1dea22f32..5c774c3f4 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java @@ -42,7 +42,7 @@ import org.deeplearning4j.optimize.solvers.LBFGS; import org.deeplearning4j.optimize.solvers.LineGradientDescent; import org.deeplearning4j.optimize.solvers.StochasticGradientDescent; import org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -65,7 +65,7 @@ import java.util.Collection; import java.util.Collections; import java.util.Map; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestOptimizers extends BaseDL4JTest { @@ -118,13 +118,13 @@ public class TestOptimizers extends BaseDL4JTest { double[] scores = new double[nCallsToOptimizer + 1]; scores[0] = score; for (int i = 0; i < nCallsToOptimizer; i++) { - for( int j=0; j getNetAndData(){ MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() @@ -72,8 +72,8 @@ public class TestCheckpointListener extends BaseDL4JTest { } @Test - public void testCheckpointListenerEvery2Epochs() throws Exception { - File f = tempDir.newFolder(); + public void testCheckpointListenerEvery2Epochs(@TempDir Path tempDir) throws Exception { + File f = tempDir.toFile(); Pair p = getNetAndData(); MultiLayerNetwork net = p.getFirst(); DataSetIterator iter = p.getSecond(); @@ -121,8 +121,8 @@ public class TestCheckpointListener extends BaseDL4JTest { } @Test - public void testCheckpointListenerEvery5Iter() throws Exception { - File f = tempDir.newFolder(); + public void testCheckpointListenerEvery5Iter(@TempDir Path tempDir) throws Exception { + File f = tempDir.toFile(); Pair p = getNetAndData(); MultiLayerNetwork net = p.getFirst(); DataSetIterator iter = p.getSecond(); @@ -159,7 +159,7 @@ public class TestCheckpointListener extends BaseDL4JTest { count++; } - assertEquals(ns.toString(), 3, ns.size()); + assertEquals( 3, ns.size(),ns.toString()); assertTrue(ns.contains(25)); assertTrue(ns.contains(30)); assertTrue(ns.contains(35)); @@ -178,8 +178,8 @@ public class TestCheckpointListener extends BaseDL4JTest { } @Test - public void testCheckpointListenerEveryTimeUnit() throws Exception { - File f = tempDir.newFolder(); + public void testCheckpointListenerEveryTimeUnit(@TempDir Path tempDir) throws Exception { + File f = tempDir.toFile(); Pair p = getNetAndData(); MultiLayerNetwork net = p.getFirst(); DataSetIterator iter = p.getSecond(); @@ -216,14 +216,14 @@ public class TestCheckpointListener extends BaseDL4JTest { } assertEquals(2, l.availableCheckpoints().size()); - assertEquals(ns.toString(), 2, ns.size()); + assertEquals(2, ns.size(),ns.toString()); System.out.println(ns); assertTrue(ns.containsAll(Arrays.asList(2,4))); } @Test - public void testCheckpointListenerKeepLast3AndEvery3() throws Exception { - File f = tempDir.newFolder(); + public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path tempDir) throws Exception { + File f = tempDir.toFile(); Pair p = getNetAndData(); MultiLayerNetwork net = p.getFirst(); DataSetIterator iter = p.getSecond(); @@ -261,15 +261,15 @@ public class TestCheckpointListener extends BaseDL4JTest { count++; } - assertEquals(ns.toString(), 5, ns.size()); - assertTrue(ns.toString(), ns.containsAll(Arrays.asList(5, 11, 15, 17, 19))); + assertEquals(5, ns.size(),ns.toString()); + assertTrue(ns.containsAll(Arrays.asList(5, 11, 15, 17, 19)),ns.toString()); assertEquals(5, l.availableCheckpoints().size()); } @Test - public void testDeleteExisting() throws Exception { - File f = tempDir.newFolder(); + public void testDeleteExisting(@TempDir Path tempDir) throws Exception { + File f = tempDir.toFile(); Pair p = getNetAndData(); MultiLayerNetwork net = p.getFirst(); DataSetIterator iter = p.getSecond(); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestFailureListener.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestFailureListener.java index cf604d365..f8a40034a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestFailureListener.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestFailureListener.java @@ -27,8 +27,8 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.listeners.FailureTestingListener; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.learning.config.Adam; @@ -36,17 +36,17 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.net.InetAddress; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; /** * WARNING: DO NOT ENABLE (UN-IGNORE) THESE TESTS. * They should be run manually, not as part of standard unit test run. */ -@Ignore +@Disabled public class TestFailureListener extends BaseDL4JTest { - @Ignore + @Disabled @Test public void testFailureIter5() throws Exception { @@ -68,7 +68,7 @@ public class TestFailureListener extends BaseDL4JTest { net.fit(iter); } - @Ignore + @Disabled @Test public void testFailureRandom_OR(){ @@ -96,7 +96,7 @@ public class TestFailureListener extends BaseDL4JTest { net.fit(iter); } - @Ignore + @Disabled @Test public void testFailureRandom_AND() throws Exception { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java index 6897ea540..6798a2094 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java @@ -44,9 +44,10 @@ import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.optimize.listeners.TimeIterationListener; import org.deeplearning4j.optimize.listeners.CheckpointListener; import org.deeplearning4j.optimize.solvers.BaseOptimizer; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -57,20 +58,18 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j public class TestListeners extends BaseDL4JTest { - @Rule - public TemporaryFolder tempDir = new TemporaryFolder(); - @Override public long getTimeoutMilliseconds() { return 90000L; @@ -91,7 +90,7 @@ public class TestListeners extends BaseDL4JTest { for (Layer l : net.getLayers()) { Collection layerListeners = l.getListeners(); - assertEquals(l.getClass().toString(), 2, layerListeners.size()); + assertEquals(2, layerListeners.size(),l.getClass().toString()); TrainingListener[] lArr = layerListeners.toArray(new TrainingListener[2]); assertTrue(lArr[0] instanceof ScoreIterationListener); assertTrue(lArr[1] instanceof TestRoutingListener); @@ -168,7 +167,7 @@ public class TestListeners extends BaseDL4JTest { @Test - public void testListenerSerialization() throws Exception { + public void testListenerSerialization(@TempDir Path tempDir) throws Exception { //Note: not all listeners are (or should be) serializable. But some should be - for Spark etc List listeners = new ArrayList<>(); @@ -176,7 +175,7 @@ public class TestListeners extends BaseDL4JTest { listeners.add(new PerformanceListener(1, true, true)); listeners.add(new TimeIterationListener(10000)); listeners.add(new ComposableIterationListener(new ScoreIterationListener(), new PerformanceListener(1, true, true))); - listeners.add(new CheckpointListener.Builder(tempDir.newFolder()).keepAll().saveEveryNIterations(3).build()); //Doesn't usually need to be serialized, but no reason it can't be... + listeners.add(new CheckpointListener.Builder(tempDir.toFile()).keepAll().saveEveryNIterations(3).build()); //Doesn't usually need to be serialized, but no reason it can't be... DataSetIterator iter = new IrisDataSetIterator(10, 150); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/FancyBlockingQueueTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/FancyBlockingQueueTests.java index 91ef1d72f..82dfedaeb 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/FancyBlockingQueueTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/FancyBlockingQueueTests.java @@ -24,12 +24,12 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.RandomUtils; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.optimize.solvers.accumulation.FancyBlockingQueue; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicLong; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j public class FancyBlockingQueueTests extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/ParallelExistingMiniBatchDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/ParallelExistingMiniBatchDataSetIteratorTest.java index aa8c5984f..bbdd861eb 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/ParallelExistingMiniBatchDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/ParallelExistingMiniBatchDataSetIteratorTest.java @@ -20,7 +20,7 @@ package org.deeplearning4j.parallelism; import lombok.extern.slf4j.Slf4j; -import org.junit.Rule; + import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; import org.deeplearning4j.BaseDL4JTest; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/RandomTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/RandomTests.java index e0eea0f27..97a1cb799 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/RandomTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/RandomTests.java @@ -32,7 +32,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Nesterovs; @@ -41,7 +41,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class RandomTests extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/SystemPollingTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/SystemPollingTest.java index 97ef03af9..8d9bf2ed2 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/SystemPollingTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/SystemPollingTest.java @@ -24,7 +24,7 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.core.listener.HardwareMetric; import org.deeplearning4j.core.listener.SystemPolling; import org.junit.jupiter.api.Disabled; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.factory.Nd4j; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/TestHardWareMetric.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/TestHardWareMetric.java index de3d9a63c..b99c222bf 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/TestHardWareMetric.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/TestHardWareMetric.java @@ -22,14 +22,14 @@ package org.deeplearning4j.perf.listener; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.core.listener.HardwareMetric; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import oshi.json.SystemInfo; import static junit.framework.TestCase.assertNotNull; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; -@Ignore("AB 2019/05/24 - Failing on CI - \"Could not initialize class oshi.jna.platform.linux.Libc\" - Issue #7657") +@Disabled("AB 2019/05/24 - Failing on CI - \"Could not initialize class oshi.jna.platform.linux.Libc\" - Issue #7657") public class TestHardWareMetric extends BaseDL4JTest { @Test diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java index 270515bd6..4fe2be4e0 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java @@ -28,30 +28,32 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import java.io.File; +import java.nio.file.Files; +import java.nio.file.Path; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; -@Ignore("AB 2019/05/24 - Failing on CI - \"Could not initialize class oshi.jna.platform.linux.Libc\" - Issue #7657") +@Disabled("AB 2019/05/24 - Failing on CI - \"Could not initialize class oshi.jna.platform.linux.Libc\" - Issue #7657") public class TestSystemInfoPrintListener extends BaseDL4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Test - public void testListener() throws Exception { + public void testListener(@TempDir Path testDir) throws Exception { SystemInfoPrintListener systemInfoPrintListener = SystemInfoPrintListener.builder() .printOnEpochStart(true).printOnEpochEnd(true) .build(); - File tmpFile = testDir.newFile("tmpfile-log.txt"); + File tmpFile = Files.createTempFile(testDir,"tmpfile-log","txt").toFile(); assertEquals(0, tmpFile.length() ); SystemInfoFilePrintListener systemInfoFilePrintListener = SystemInfoFilePrintListener.builder() diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/MiscRegressionTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/MiscRegressionTests.java index b4ea0e0cd..686501ff8 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/MiscRegressionTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/MiscRegressionTests.java @@ -30,16 +30,16 @@ import org.deeplearning4j.nn.conf.graph.LayerVertex; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; import java.io.File; import java.nio.charset.StandardCharsets; import java.util.Map; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; public class MiscRegressionTests extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java index 0216ce3a2..5a6567fa7 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java @@ -32,9 +32,9 @@ import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.deeplearning4j.nn.weights.WeightInitRelu; import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.util.ModelSerializer; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.Timeout; + +import org.junit.jupiter.api.Test; + import org.nd4j.linalg.activations.impl.ActivationLReLU; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; @@ -48,7 +48,7 @@ import org.nd4j.common.resources.Resources; import java.io.File; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class RegressionTest050 extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java index b18933adb..da6976b6a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java @@ -37,7 +37,7 @@ import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.deeplearning4j.nn.weights.WeightInitRelu; import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.util.ModelSerializer; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.impl.ActivationLReLU; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; @@ -51,7 +51,7 @@ import org.nd4j.common.resources.Resources; import java.io.File; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class RegressionTest060 extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java index b804451c2..e2ef4b233 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java @@ -37,7 +37,7 @@ import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.deeplearning4j.nn.weights.WeightInitRelu; import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.util.ModelSerializer; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.activations.impl.ActivationLReLU; import org.nd4j.linalg.api.buffer.DataType; @@ -52,7 +52,7 @@ import org.nd4j.common.resources.Resources; import java.io.File; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class RegressionTest071 extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java index 532a8ebc0..b2af73f06 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java @@ -37,7 +37,7 @@ import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.deeplearning4j.nn.weights.WeightInitRelu; import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.util.ModelSerializer; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.impl.*; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; @@ -51,7 +51,7 @@ import org.nd4j.common.resources.Resources; import java.io.File; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class RegressionTest080 extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java index 3a1642466..2da9d084b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java @@ -35,8 +35,8 @@ import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInitXavier; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.activations.impl.ActivationLReLU; import org.nd4j.linalg.activations.impl.ActivationSoftmax; @@ -53,10 +53,10 @@ import java.io.DataInputStream; import java.io.File; import java.io.FileInputStream; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j -@Ignore +@Disabled public class RegressionTest100a extends BaseDL4JTest { @Override @@ -82,7 +82,7 @@ public class RegressionTest100a extends BaseDL4JTest { fail("Expected exception"); } catch (Exception e){ String msg = e.getMessage(); - assertTrue(msg, msg.contains("custom") && msg.contains("1.0.0-beta") && msg.contains("saved again")); + assertTrue(msg.contains("custom") && msg.contains("1.0.0-beta") && msg.contains("saved again"), msg); } } @@ -169,7 +169,7 @@ public class RegressionTest100a extends BaseDL4JTest { @Test - @Ignore("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") + @Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") public void testYoloHouseNumber() throws Exception { File f = Resources.asFile("regression_testing/100a/HouseNumberDetection_100a.bin"); @@ -215,12 +215,12 @@ public class RegressionTest100a extends BaseDL4JTest { log.info("Expected: {}", outExp); log.info("Actual: {}", outAct); } - assertTrue("Output not equal", eq); + assertTrue(eq, "Output not equal"); } @Test - @Ignore("Ignoring due to new set input types changes. Loading a network isn't a problem, but we need to set the input types yet.") + @Disabled("Ignoring due to new set input types changes. Loading a network isn't a problem, but we need to set the input types yet.") public void testUpsampling2d() throws Exception { File f = Resources.asFile("regression_testing/100a/upsampling/net.bin"); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java index 95455cca6..fce6cae1c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java @@ -34,8 +34,8 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.regressiontest.customlayer100a.CustomLayer; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.impl.*; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -50,8 +50,8 @@ import java.io.File; import java.io.FileInputStream; import java.util.List; -import static org.junit.Assert.*; -@Ignore +import static org.junit.jupiter.api.Assertions.*; +@Disabled public class RegressionTest100b3 extends BaseDL4JTest { @Override @@ -115,7 +115,7 @@ public class RegressionTest100b3 extends BaseDL4JTest { assertEquals(dt, net.getLayerWiseConfigurations().getDataType()); assertEquals(dt, net.params().dataType()); - assertEquals(dtype, outExp, outAct); + assertEquals(outExp, outAct, dtype); } } @@ -202,7 +202,7 @@ public class RegressionTest100b3 extends BaseDL4JTest { @Test - @Ignore("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") + @Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") public void testYoloHouseNumber() throws Exception { File f = Resources.asFile("regression_testing/100b3/HouseNumberDetection_100b3.bin"); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java index 41a06bdc8..fb11070c1 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java @@ -20,9 +20,9 @@ package org.deeplearning4j.regressiontest; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import java.io.DataInputStream; import java.io.File; @@ -54,8 +54,8 @@ import org.deeplearning4j.nn.graph.vertex.impl.MergeVertex; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.regressiontest.customlayer100a.CustomLayer; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.activations.impl.ActivationLReLU; import org.nd4j.linalg.activations.impl.ActivationReLU; @@ -71,7 +71,7 @@ import org.nd4j.linalg.learning.regularization.L2Regularization; import org.nd4j.linalg.lossfunctions.impl.LossMAE; import org.nd4j.linalg.lossfunctions.impl.LossMCXENT; import org.nd4j.common.resources.Resources; -@Ignore +@Disabled public class RegressionTest100b4 extends BaseDL4JTest { @Override @@ -134,7 +134,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { assertEquals(dtype, net.getLayerWiseConfigurations().getDataType()); assertEquals(dtype, net.params().dataType()); boolean eq = outExp.equalsWithEps(outAct, 0.01); - assertTrue("Test for dtype: " + dtypeName + "\n" + outExp + " vs " + outAct, eq); + assertTrue(eq,"Test for dtype: " + dtypeName + "\n" + outExp + " vs " + outAct); } } @@ -221,7 +221,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { @Test - @Ignore("Failing due to new data format changes. Sept 10,2020") + @Disabled("Failing due to new data format changes. Sept 10,2020") public void testYoloHouseNumber() throws Exception { File f = Resources.asFile("regression_testing/100b4/HouseNumberDetection_100b4.bin"); @@ -257,7 +257,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { } @Test - @Ignore("failing due to new input data format changes.") + @Disabled("failing due to new input data format changes.") public void testSyntheticCNN() throws Exception { File f = Resources.asFile("regression_testing/100b4/SyntheticCNN_100b4.bin"); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java index 859872d5d..9e2dad7fe 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java @@ -35,8 +35,8 @@ import org.deeplearning4j.nn.graph.vertex.impl.MergeVertex; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.regressiontest.customlayer100a.CustomLayer; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.impl.*; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -52,8 +52,8 @@ import java.io.DataInputStream; import java.io.File; import java.io.FileInputStream; -import static org.junit.Assert.*; -@Ignore +import static org.junit.jupiter.api.Assertions.*; +@Disabled public class RegressionTest100b6 extends BaseDL4JTest { @Override @@ -116,7 +116,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { assertEquals(dtype, net.getLayerWiseConfigurations().getDataType()); assertEquals(dtype, net.params().dataType()); boolean eq = outExp.equalsWithEps(outAct, 0.01); - assertTrue("Test for dtype: " + dtypeName + " - " + outExp + " vs " + outAct, eq); + assertTrue(eq, "Test for dtype: " + dtypeName + " - " + outExp + " vs " + outAct); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java index d4e2b015b..e4403e5de 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java @@ -23,11 +23,11 @@ package org.deeplearning4j.regressiontest; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.*; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.shade.jackson.databind.ObjectMapper; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestDistributionDeserializer extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java index b0fbf1225..396fbe778 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java @@ -30,7 +30,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.TrainingConfig; @@ -53,8 +53,8 @@ import org.nd4j.weightinit.impl.XavierInitScheme; import java.util.*; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; @Slf4j public class CompareTrainingImplementations extends BaseDL4JTest { @@ -181,7 +181,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest { INDArray outSd = map.get(a1.name()); INDArray outDl4j = net.output(f); - assertEquals(testName, outDl4j, outSd); + assertEquals(outDl4j, outSd, testName); net.setInput(f); net.setLabels(l); @@ -193,7 +193,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest { //Check score double scoreDl4j = net.score(); double scoreSd = map.get(lossMse.name()).getDouble(0) + sd.calcRegularizationScore(); - assertEquals(testName, scoreDl4j, scoreSd, 1e-6); + assertEquals( scoreDl4j, scoreSd, 1e-6,testName); double lossRegScoreSD = sd.calcRegularizationScore(); double lossRegScoreDL4J = net.calcRegularizationScore(true); @@ -207,10 +207,10 @@ public class CompareTrainingImplementations extends BaseDL4JTest { //Note that the SameDiff gradients don't include the L1/L2 terms at present just from execBackwards()... these are added in fitting only //We can check correctness though with training param checks later if(l1Val == 0 && l2Val == 0 && wdVal == 0) { - assertEquals(testName, grads.get("1_b"), gm.get(b1.name())); - assertEquals(testName, grads.get("1_W"), gm.get(w1.name())); - assertEquals(testName, grads.get("0_b"), gm.get(b0.name())); - assertEquals(testName, grads.get("0_W"), gm.get(w0.name())); + assertEquals(grads.get("1_b"), gm.get(b1.name()), testName); + assertEquals(grads.get("1_W"), gm.get(w1.name()), testName); + assertEquals(grads.get("0_b"), gm.get(b0.name()), testName); + assertEquals(grads.get("0_W"), gm.get(w0.name()), testName); } @@ -237,10 +237,10 @@ public class CompareTrainingImplementations extends BaseDL4JTest { String s = testName + " - " + j; INDArray dl4j_0W = net.getParam("0_W"); INDArray sd_0W = w0.getArr(); - assertEquals(s, dl4j_0W, sd_0W); - assertEquals(s, net.getParam("0_b"), b0.getArr()); - assertEquals(s, net.getParam("1_W"), w1.getArr()); - assertEquals(s, net.getParam("1_b"), b1.getArr()); + assertEquals(dl4j_0W, sd_0W, s); + assertEquals(net.getParam("0_b"), b0.getArr(), s); + assertEquals(net.getParam("1_W"), w1.getArr(), s); + assertEquals(net.getParam("1_b"), b1.getArr(), s); } //Compare evaluations diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java index ebf8510ce..06c6a8f8a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java @@ -35,7 +35,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.junit.jupiter.api.AfterEach; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.activations.Activation; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java index fdac1af4e..3125c7099 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java @@ -31,10 +31,10 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.junit.jupiter.api.Disabled; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; -import org.junit.rules.Timeout; + import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.preprocessor.Normalizer; @@ -59,8 +59,7 @@ class ModelGuesserTest extends BaseDL4JTest { @TempDir public Path testDir; - @Rule - public Timeout timeout = Timeout.seconds(300); + @Test @DisplayName("Test Model Guess File") diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java index e2d128bb8..4f2c3b380 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java @@ -32,7 +32,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.activations.Activation; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelValidatorTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelValidatorTests.java index 57e3b4407..4188e642e 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelValidatorTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelValidatorTests.java @@ -29,9 +29,10 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.common.validation.ValidationResult; @@ -52,16 +53,15 @@ import java.util.zip.ZipInputStream; import java.util.zip.ZipOutputStream; import static junit.framework.TestCase.assertTrue; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class ModelValidatorTests extends BaseDL4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Test - public void testMultiLayerNetworkValidation() throws Exception { - File f = testDir.newFolder(); + public void testMultiLayerNetworkValidation(@TempDir Path testDir) throws Exception { + File f = testDir.toFile(); //Test non-existent file File f0 = new File(f, "doesntExist.bin"); @@ -178,8 +178,8 @@ public class ModelValidatorTests extends BaseDL4JTest { @Test - public void testComputationGraphNetworkValidation() throws Exception { - File f = testDir.newFolder(); + public void testComputationGraphNetworkValidation(@TempDir Path testDir) throws Exception { + File f = testDir.toFile(); //Test non-existent file File f0 = new File(f, "doesntExist.bin"); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/SerializationUtilsTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/SerializationUtilsTest.java index cabbdf369..fe708f04c 100755 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/SerializationUtilsTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/SerializationUtilsTest.java @@ -21,7 +21,7 @@ package org.deeplearning4j.util; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.util.SerializationUtils; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/TestUIDProvider.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/TestUIDProvider.java index 5be40b14a..9f0ecd29c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/TestUIDProvider.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/TestUIDProvider.java @@ -22,11 +22,11 @@ package org.deeplearning4j.util; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.core.util.UIDProvider; -import org.junit.Test; +import org.junit.jupiter.api.Test; import static junit.framework.TestCase.assertTrue; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; public class TestUIDProvider extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/TestDataTypes.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/TestDataTypes.java index b0a60a76c..f055ea0f6 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/TestDataTypes.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/TestDataTypes.java @@ -36,7 +36,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.junit.AfterClass; import org.junit.BeforeClass; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; @@ -53,7 +53,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class TestDataTypes extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/TestUtils.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/TestUtils.java index 6954f57c1..c33ecc6a9 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/TestUtils.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/TestUtils.java @@ -49,8 +49,8 @@ import java.lang.reflect.Field; import java.util.List; import java.util.Random; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; public class TestUtils { diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/ValidateCuDNN.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/ValidateCuDNN.java index 5840ae697..cb2d27c9f 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/ValidateCuDNN.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/ValidateCuDNN.java @@ -32,8 +32,8 @@ import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.cuda.util.CuDNNValidationUtil; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationELU; import org.nd4j.linalg.activations.impl.ActivationIdentity; @@ -190,7 +190,7 @@ public class ValidateCuDNN extends BaseDL4JTest { validateLayers(net, classesToTest, false, fShape, lShape, CuDNNValidationUtil.MAX_REL_ERROR, CuDNNValidationUtil.MIN_ABS_ERROR); } - @Test @Ignore //AB 2019/05/20 - https://github.com/eclipse/deeplearning4j/issues/5088 - ignored to get to "all passing" state for CI, and revisit later + @Test @Disabled //AB 2019/05/20 - https://github.com/eclipse/deeplearning4j/issues/5088 - ignored to get to "all passing" state for CI, and revisit later public void validateConvLayersLRN() { //Test ONLY LRN - no other CuDNN functionality (i.e., DL4J impls for everything else) Nd4j.getRandom().setSeed(12345); diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/convolution/ConvDataFormatTests.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/convolution/ConvDataFormatTests.java index 412f1b7ca..dad81fbd0 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/convolution/ConvDataFormatTests.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/convolution/ConvDataFormatTests.java @@ -38,7 +38,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; @@ -51,7 +51,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @RunWith(Parameterized.class) public class ConvDataFormatTests extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/convolution/TestConvolution.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/convolution/TestConvolution.java index c5c12c5b8..e11f6d1f3 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/convolution/TestConvolution.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/convolution/TestConvolution.java @@ -43,9 +43,9 @@ import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + import org.nd4j.common.resources.Resources; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -63,7 +63,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; /** * Created by Alex on 15/11/2016. diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/gradientcheck/CuDNNGradientChecks.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/gradientcheck/CuDNNGradientChecks.java index 3ed85ad14..8782fbb6c 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/gradientcheck/CuDNNGradientChecks.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/gradientcheck/CuDNNGradientChecks.java @@ -47,7 +47,7 @@ import org.deeplearning4j.cuda.recurrent.CudnnLSTMHelper; import org.deeplearning4j.nn.layers.recurrent.LSTMHelper; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.function.Consumer; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -63,8 +63,8 @@ import java.util.HashSet; import java.util.Random; import java.util.Set; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; /** * Created by Alex on 09/09/2016. diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/lstm/ValidateCudnnDropout.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/lstm/ValidateCudnnDropout.java index 30aa80f0a..51bd7adc3 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/lstm/ValidateCudnnDropout.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/lstm/ValidateCudnnDropout.java @@ -22,15 +22,15 @@ package org.deeplearning4j.cuda.lstm; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.cuda.dropout.CudnnDropoutHelper; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.conditions.Conditions; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class ValidateCudnnDropout extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/lstm/ValidateCudnnLSTM.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/lstm/ValidateCudnnLSTM.java index 07af247f1..7dae3adaf 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/lstm/ValidateCudnnLSTM.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/lstm/ValidateCudnnLSTM.java @@ -30,7 +30,7 @@ import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.cuda.recurrent.CudnnLSTMHelper; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.concurrency.AffinityManager; @@ -44,7 +44,7 @@ import java.lang.reflect.Field; import java.util.Map; import java.util.Random; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; /** * Created by Alex on 18/07/2017. diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/util/CuDNNValidationUtil.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/util/CuDNNValidationUtil.java index 184b03cb2..5750b2d1b 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/util/CuDNNValidationUtil.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/util/CuDNNValidationUtil.java @@ -41,7 +41,7 @@ import org.nd4j.linalg.ops.transforms.Transforms; import java.lang.reflect.Field; import java.util.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class CuDNNValidationUtil { diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoading.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoading.java index 7c3c46dd7..e7f5f05c0 100644 --- a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoading.java +++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoading.java @@ -29,17 +29,19 @@ import org.deeplearning4j.graph.data.impl.DelimitedVertexLoader; import org.deeplearning4j.graph.graph.Graph; import org.deeplearning4j.graph.vertexfactory.StringVertexFactory; import org.deeplearning4j.graph.vertexfactory.VertexFactory; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.io.ClassPathResource; import java.io.IOException; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestGraphLoading extends BaseDL4JTest { - @Test(timeout = 10000L) + @Test() + @Timeout(10000) public void testEdgeListGraphLoading() throws IOException { ClassPathResource cpr = new ClassPathResource("deeplearning4j-graph/testgraph_7vertices.txt"); @@ -59,7 +61,8 @@ public class TestGraphLoading extends BaseDL4JTest { } } - @Test(timeout = 10000L) + @Test() + @Timeout(10000) public void testGraphLoading() throws IOException { ClassPathResource cpr = new ClassPathResource("deeplearning4j-graph/simplegraph.txt"); @@ -102,7 +105,8 @@ public class TestGraphLoading extends BaseDL4JTest { } } - @Test(timeout = 10000L) + @Test() + @Timeout(10000) public void testGraphLoadingWithVertices() throws IOException { ClassPathResource verticesCPR = new ClassPathResource("deeplearning4j-graph/test_graph_vertices.txt"); diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoadingWeighted.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoadingWeighted.java index 0ac95ab28..3d295cb3d 100644 --- a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoadingWeighted.java +++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoadingWeighted.java @@ -28,18 +28,20 @@ import org.deeplearning4j.graph.data.impl.WeightedEdgeLineProcessor; import org.deeplearning4j.graph.graph.Graph; import org.deeplearning4j.graph.vertexfactory.StringVertexFactory; import org.deeplearning4j.graph.vertexfactory.VertexFactory; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.io.ClassPathResource; import java.io.IOException; import java.util.List; -import static junit.framework.TestCase.assertTrue; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestGraphLoadingWeighted extends BaseDL4JTest { - @Test(timeout = 10000L) + @Test() + @Timeout(10000) public void testWeightedDirected() throws IOException { String path = new ClassPathResource("deeplearning4j-graph/WeightedGraph.txt").getTempFileFromArchive().getAbsolutePath(); @@ -79,7 +81,8 @@ public class TestGraphLoadingWeighted extends BaseDL4JTest { } - @Test(timeout = 10000L) + @Test() + @Timeout(10000) public void testWeightedDirectedV2() throws Exception { String path = new ClassPathResource("deeplearning4j-graph/WeightedGraph.txt").getTempFileFromArchive().getAbsolutePath(); diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/graph/TestGraph.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/graph/TestGraph.java index 5ab8ca342..cc486e3db 100644 --- a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/graph/TestGraph.java +++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/graph/TestGraph.java @@ -27,20 +27,21 @@ import org.deeplearning4j.graph.data.GraphLoader; import org.deeplearning4j.graph.iterator.RandomWalkIterator; import org.deeplearning4j.graph.iterator.WeightedRandomWalkIterator; import org.deeplearning4j.graph.vertexfactory.VertexFactory; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.io.ClassPathResource; import java.util.HashSet; import java.util.List; import java.util.Set; -import static junit.framework.TestCase.assertTrue; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestGraph extends BaseDL4JTest { - @Test(timeout = 10000L) + @Test() + @Timeout(10000) public void testSimpleGraph() { Graph graph = new Graph<>(10, false, new VFactory()); @@ -94,7 +95,8 @@ public class TestGraph extends BaseDL4JTest { } - @Test(timeout = 10000L) + @Test() + @Timeout(10000) public void testRandomWalkIterator() { Graph graph = new Graph<>(10, false, new VFactory()); assertEquals(10, graph.numVertices()); @@ -124,8 +126,8 @@ public class TestGraph extends BaseDL4JTest { int left = (previous - 1 + 10) % 10; int right = (previous + 1) % 10; int current = sequence.next().vertexID(); - assertTrue("expected: " + left + " or " + right + ", got " + current, - current == left || current == right); + assertTrue(current == left || current == right, + "expected: " + left + " or " + right + ", got " + current); seqCount++; previous = current; } @@ -137,7 +139,8 @@ public class TestGraph extends BaseDL4JTest { assertEquals(10, startIdxSet.size()); } - @Test(timeout = 10000L) + @Test() + @Timeout(10000) public void testWeightedRandomWalkIterator() throws Exception { //Load a directed, weighted graph from file diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java index cf9f08549..9e19fd53c 100644 --- a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java +++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java @@ -26,8 +26,9 @@ import org.deeplearning4j.graph.graph.Graph; import org.deeplearning4j.graph.iterator.GraphWalkIterator; import org.deeplearning4j.graph.iterator.RandomWalkIterator; import org.deeplearning4j.graph.models.embeddings.InMemoryGraphLookupTable; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -35,7 +36,7 @@ import org.nd4j.common.io.ClassPathResource; import java.io.IOException; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class DeepWalkGradientCheck extends BaseDL4JTest { @@ -43,12 +44,13 @@ public class DeepWalkGradientCheck extends BaseDL4JTest { public static final double MAX_REL_ERROR = 1e-3; public static final double MIN_ABS_ERROR = 1e-5; - @Before + @BeforeEach public void before() { Nd4j.setDataType(DataType.DOUBLE); } - @Test(timeout = 10000L) + @Test() + @Timeout(10000) public void checkGradients() throws IOException { ClassPathResource cpr = new ClassPathResource("deeplearning4j-graph/testgraph_7vertices.txt"); @@ -89,7 +91,7 @@ public class DeepWalkGradientCheck extends BaseDL4JTest { assertTrue(probs[j] >= 0.0 && probs[j] <= 1.0); sumProb += probs[j]; } - assertTrue("Output probabilities do not sum to 1.0", Math.abs(sumProb - 1.0) < 1e-5); + assertTrue(Math.abs(sumProb - 1.0) < 1e-5, "Output probabilities do not sum to 1.0"); for (int j = 0; j < 7; j++) { //out //p(j|i) @@ -195,7 +197,8 @@ public class DeepWalkGradientCheck extends BaseDL4JTest { - @Test(timeout = 60000L) + @Test() + @Timeout(60000) public void checkGradients2() throws IOException { double minAbsError = 1e-5; @@ -239,8 +242,8 @@ public class DeepWalkGradientCheck extends BaseDL4JTest { assertTrue(probs[j] >= 0.0 && probs[j] <= 1.0); sumProb += probs[j]; } - assertTrue("Output probabilities do not sum to 1.0 (i=" + i + "), sum=" + sumProb, - Math.abs(sumProb - 1.0) < 1e-5); + assertTrue(Math.abs(sumProb - 1.0) < 1e-5, + "Output probabilities do not sum to 1.0 (i=" + i + "), sum=" + sumProb); for (int j = 0; j < nVertices; j++) { //out //p(j|i) diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java index 83f3db319..1415a4dde 100644 --- a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java +++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java @@ -32,30 +32,32 @@ import org.deeplearning4j.graph.iterator.parallel.WeightedRandomWalkGraphIterato import org.deeplearning4j.graph.models.GraphVectors; import org.deeplearning4j.graph.models.loader.GraphVectorSerializer; import org.deeplearning4j.graph.vertexfactory.StringVertexFactory; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.io.ClassPathResource; import java.io.File; import java.io.IOException; +import java.nio.file.Path; import java.util.Random; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestDeepWalk extends BaseDL4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); @Override public long getTimeoutMilliseconds() { return 120_000L; //Increase timeout due to intermittently slow CI machines } - @Test(timeout = 60000L) + @Test() + @Timeout(60000) public void testBasic() throws IOException { //Very basic test. Load graph, build tree, call fit, make sure it doesn't throw any exceptions @@ -93,7 +95,8 @@ public class TestDeepWalk extends BaseDL4JTest { } } - @Test(timeout = 180000L) + @Test() + @Timeout(180000) public void testParallel() { IGraph graph = generateRandomGraph(30, 4); @@ -127,7 +130,8 @@ public class TestDeepWalk extends BaseDL4JTest { } - @Test(timeout = 60000L) + @Test() + @Timeout(60000) public void testVerticesNearest() { int nVertices = 20; @@ -172,8 +176,9 @@ public class TestDeepWalk extends BaseDL4JTest { } } - @Test(timeout = 60000L) - public void testLoadingSaving() throws IOException { + @Test() + @Timeout(60000) + public void testLoadingSaving(@TempDir Path testDir) throws IOException { String out = "dl4jdwtestout.txt"; int nVertices = 20; @@ -187,7 +192,7 @@ public class TestDeepWalk extends BaseDL4JTest { deepWalk.fit(graph, 10); - File f = testDir.newFile(out); + File f = new File(testDir.toFile(),out); GraphVectorSerializer.writeGraphVectors(deepWalk, f.getAbsolutePath()); GraphVectors vectors = @@ -209,7 +214,8 @@ public class TestDeepWalk extends BaseDL4JTest { } } - @Test(timeout = 180000L) + @Test() + @Timeout(180000) public void testDeepWalk13Vertices() throws IOException { int nVertices = 13; @@ -245,7 +251,8 @@ public class TestDeepWalk extends BaseDL4JTest { deepWalk.getVertexVector(i); } - @Test(timeout = 60000L) + @Test() + @Timeout(60000) public void testDeepWalkWeightedParallel() throws IOException { //Load graph diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestGraphHuffman.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestGraphHuffman.java index 5ac89c984..3a5d66b69 100644 --- a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestGraphHuffman.java +++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestGraphHuffman.java @@ -21,17 +21,19 @@ package org.deeplearning4j.graph.models.deepwalk; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import java.util.Arrays; import java.util.HashSet; import java.util.Set; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestGraphHuffman extends BaseDL4JTest { - @Test(timeout = 10000L) + @Test() + @Timeout(10000) public void testGraphHuffman() { //Simple test case from Weiss - Data Structires and Algorithm Analysis in Java 3ed pg436 //Huffman code is non-unique, but length of code for each node is same for all Huffman codes diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/ltr/model/ScoringModelTest.java b/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/ltr/model/ScoringModelTest.java index 271523704..2c293b5d5 100644 --- a/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/ltr/model/ScoringModelTest.java +++ b/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/ltr/model/ScoringModelTest.java @@ -50,7 +50,7 @@ import org.deeplearning4j.util.ModelSerializer; import org.nd4j.linalg.activations.Activation; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.fail; import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/KerasTestUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/KerasTestUtils.java index 7b9d25445..12a00d4f7 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/KerasTestUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/KerasTestUtils.java @@ -28,7 +28,7 @@ import org.nd4j.linalg.learning.regularization.Regularization; import java.util.List; -import static org.junit.Assert.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotNull; public class KerasTestUtils { diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/MiscTests.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/MiscTests.java index c6d3a7d84..d61fb283e 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/MiscTests.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/MiscTests.java @@ -24,10 +24,12 @@ import org.apache.commons.io.FileUtils; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.utils.DL4JKerasModelValidator; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.resources.Resources; import org.nd4j.common.validation.ValidationResult; @@ -36,19 +38,19 @@ import java.io.File; import java.io.FileInputStream; import java.io.InputStream; import java.nio.charset.StandardCharsets; +import java.nio.file.Path; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; -@Ignore +@Disabled public class MiscTests extends BaseDL4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - @Test(timeout = 60000L) + @Test() + @Timeout(60000L) public void testMultiThreadedLoading() throws Exception { final File f = Resources.asFile("modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5"); @@ -87,11 +89,12 @@ public class MiscTests extends BaseDL4JTest { } boolean result = latch.await(30000, TimeUnit.MILLISECONDS); - assertTrue("Latch did not get to 0", result); - assertEquals("Number of errors", 0, errors.get()); + assertTrue(result,"Latch did not get to 0"); + assertEquals( 0, errors.get(),"Number of errors"); } - @Test(timeout = 60000L) + @Test() + @Timeout(60000L) public void testLoadFromStream() throws Exception { final File f = Resources.asFile("modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5"); @@ -101,16 +104,17 @@ public class MiscTests extends BaseDL4JTest { } } - @Test(timeout = 60000L) - public void testModelValidatorSequential() throws Exception { - File f = testDir.newFolder(); + @Test() + @Timeout(60000L) + public void testModelValidatorSequential(@TempDir Path testDir) throws Exception { + File f = testDir.toFile(); //Test not existent file: File fNonExistent = new File("doesntExist.h5"); ValidationResult vr0 = DL4JKerasModelValidator.validateKerasSequential(fNonExistent); assertFalse(vr0.isValid()); assertEquals("Keras Sequential Model HDF5", vr0.getFormatType()); - assertTrue(vr0.getIssues().get(0), vr0.getIssues().get(0).contains("exist")); + assertTrue( vr0.getIssues().get(0).contains("exist"),vr0.getIssues().get(0)); System.out.println(vr0.toString()); //Test empty file: @@ -120,7 +124,7 @@ public class MiscTests extends BaseDL4JTest { ValidationResult vr1 = DL4JKerasModelValidator.validateKerasSequential(fEmpty); assertEquals("Keras Sequential Model HDF5", vr1.getFormatType()); assertFalse(vr1.isValid()); - assertTrue(vr1.getIssues().get(0), vr1.getIssues().get(0).contains("empty")); + assertTrue(vr1.getIssues().get(0).contains("empty"),vr1.getIssues().get(0)); System.out.println(vr1.toString()); //Test directory (not zip file) @@ -130,7 +134,7 @@ public class MiscTests extends BaseDL4JTest { ValidationResult vr2 = DL4JKerasModelValidator.validateKerasSequential(directory); assertEquals("Keras Sequential Model HDF5", vr2.getFormatType()); assertFalse(vr2.isValid()); - assertTrue(vr2.getIssues().get(0), vr2.getIssues().get(0).contains("directory")); + assertTrue( vr2.getIssues().get(0).contains("directory"),vr2.getIssues().get(0)); System.out.println(vr2.toString()); //Test Keras HDF5 format: @@ -140,7 +144,7 @@ public class MiscTests extends BaseDL4JTest { assertEquals("Keras Sequential Model HDF5", vr3.getFormatType()); assertFalse(vr3.isValid()); String s = vr3.getIssues().get(0); - assertTrue(s, s.contains("Keras") && s.contains("Sequential") && s.contains("corrupt")); + assertTrue(s.contains("Keras") && s.contains("Sequential") && s.contains("corrupt"),s); System.out.println(vr3.toString()); //Test corrupted npy format: @@ -156,7 +160,7 @@ public class MiscTests extends BaseDL4JTest { assertEquals("Keras Sequential Model HDF5", vr4.getFormatType()); assertFalse(vr4.isValid()); s = vr4.getIssues().get(0); - assertTrue(s, s.contains("Keras") && s.contains("Sequential") && s.contains("corrupt")); + assertTrue(s.contains("Keras") && s.contains("Sequential") && s.contains("corrupt"),s); System.out.println(vr4.toString()); @@ -169,9 +173,10 @@ public class MiscTests extends BaseDL4JTest { System.out.println(vr4.toString()); } - @Test(timeout = 60000L) - public void testModelValidatorFunctional() throws Exception { - File f = testDir.newFolder(); + @Test() + @Timeout(60000L) + public void testModelValidatorFunctional(@TempDir Path testDir) throws Exception { + File f = testDir.toFile(); //String modelPath = "modelimport/keras/examples/functional_lstm/lstm_functional_tf_keras_2.h5"; //Test not existent file: @@ -179,7 +184,7 @@ public class MiscTests extends BaseDL4JTest { ValidationResult vr0 = DL4JKerasModelValidator.validateKerasFunctional(fNonExistent); assertFalse(vr0.isValid()); assertEquals("Keras Functional Model HDF5", vr0.getFormatType()); - assertTrue(vr0.getIssues().get(0), vr0.getIssues().get(0).contains("exist")); + assertTrue( vr0.getIssues().get(0).contains("exist"),vr0.getIssues().get(0)); System.out.println(vr0.toString()); //Test empty file: @@ -189,7 +194,7 @@ public class MiscTests extends BaseDL4JTest { ValidationResult vr1 = DL4JKerasModelValidator.validateKerasFunctional(fEmpty); assertEquals("Keras Functional Model HDF5", vr1.getFormatType()); assertFalse(vr1.isValid()); - assertTrue(vr1.getIssues().get(0), vr1.getIssues().get(0).contains("empty")); + assertTrue( vr1.getIssues().get(0).contains("empty"),vr1.getIssues().get(0)); System.out.println(vr1.toString()); //Test directory (not zip file) @@ -199,7 +204,7 @@ public class MiscTests extends BaseDL4JTest { ValidationResult vr2 = DL4JKerasModelValidator.validateKerasFunctional(directory); assertEquals("Keras Functional Model HDF5", vr2.getFormatType()); assertFalse(vr2.isValid()); - assertTrue(vr2.getIssues().get(0), vr2.getIssues().get(0).contains("directory")); + assertTrue( vr2.getIssues().get(0).contains("directory"),vr2.getIssues().get(0)); System.out.println(vr2.toString()); //Test Keras HDF5 format: @@ -209,13 +214,13 @@ public class MiscTests extends BaseDL4JTest { assertEquals("Keras Functional Model HDF5", vr3.getFormatType()); assertFalse(vr3.isValid()); String s = vr3.getIssues().get(0); - assertTrue(s, s.contains("Keras") && s.contains("Functional") && s.contains("corrupt")); + assertTrue(s.contains("Keras") && s.contains("Functional") && s.contains("corrupt"),s); System.out.println(vr3.toString()); //Test corrupted npy format: File fValid = Resources.asFile("modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5"); byte[] numpyBytes = FileUtils.readFileToByteArray(fValid); - for( int i=0; i<30; i++ ){ + for( int i = 0; i < 30; i++) { numpyBytes[i] = 0; } File fCorrupt = new File(f, "corrupt.h5"); @@ -225,7 +230,7 @@ public class MiscTests extends BaseDL4JTest { assertEquals("Keras Functional Model HDF5", vr4.getFormatType()); assertFalse(vr4.isValid()); s = vr4.getIssues().get(0); - assertTrue(s, s.contains("Keras") && s.contains("Functional") && s.contains("corrupt")); + assertTrue( s.contains("Keras") && s.contains("Functional") && s.contains("corrupt"),s); System.out.println(vr4.toString()); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java index b8ebd522c..062866adb 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java @@ -20,7 +20,6 @@ package org.deeplearning4j.nn.modelimport.keras.configurations; -import junit.framework.TestCase; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; import org.datavec.api.split.NumberedFileInputSplit; @@ -33,11 +32,11 @@ import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.junit.rules.Timeout; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.activations.impl.ActivationHardSigmoid; import org.nd4j.linalg.activations.impl.ActivationTanH; import org.nd4j.linalg.api.ndarray.INDArray; @@ -51,24 +50,21 @@ import org.nd4j.common.resources.Resources; import java.io.File; import java.io.IOException; import java.io.InputStream; +import java.nio.file.Path; import java.util.Arrays; import java.util.LinkedList; import java.util.List; -import static junit.framework.TestCase.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class FullModelComparisons extends BaseDL4JTest { ClassLoader classLoader = FullModelComparisons.class.getClassLoader(); - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - - @Rule - public Timeout timeout = Timeout.seconds(300); @Test - public void lstmTest() throws IOException, UnsupportedKerasConfigurationException, + public void lstmTest(@TempDir Path testDir) throws IOException, UnsupportedKerasConfigurationException, InvalidKerasConfigurationException, InterruptedException { String modelPath = "modelimport/keras/fullconfigs/lstm/lstm_th_keras_2_config.json"; @@ -106,26 +102,26 @@ public class FullModelComparisons extends BaseDL4JTest { // INDArray W = firstLstm.getParam("W"); assertTrue(Arrays.equals(W.shape(), new long[]{nIn, 4 * nOut})); - TestCase.assertEquals(W.getDouble(0, 288), -0.30737767, 1e-7); - TestCase.assertEquals(W.getDouble(0, 289), -0.5845409, 1e-7); - TestCase.assertEquals(W.getDouble(1, 288), -0.44083247, 1e-7); - TestCase.assertEquals(W.getDouble(11, 288), 0.017539706, 1e-7); - TestCase.assertEquals(W.getDouble(0, 96), 0.2707935, 1e-7); - TestCase.assertEquals(W.getDouble(0, 192), -0.19856165, 1e-7); - TestCase.assertEquals(W.getDouble(0, 0), 0.15368782, 1e-7); + assertEquals(W.getDouble(0, 288), -0.30737767, 1e-7); + assertEquals(W.getDouble(0, 289), -0.5845409, 1e-7); + assertEquals(W.getDouble(1, 288), -0.44083247, 1e-7); + assertEquals(W.getDouble(11, 288), 0.017539706, 1e-7); + assertEquals(W.getDouble(0, 96), 0.2707935, 1e-7); + assertEquals(W.getDouble(0, 192), -0.19856165, 1e-7); + assertEquals(W.getDouble(0, 0), 0.15368782, 1e-7); INDArray RW = firstLstm.getParam("RW"); assertTrue(Arrays.equals(RW.shape(), new long[]{nOut, 4 * nOut})); - TestCase.assertEquals(RW.getDouble(0, 288), 0.15112677, 1e-7); + assertEquals(RW.getDouble(0, 288), 0.15112677, 1e-7); INDArray b = firstLstm.getParam("b"); assertTrue(Arrays.equals(b.shape(), new long[]{1, 4 * nOut})); - TestCase.assertEquals(b.getDouble(0, 288), -0.36940336, 1e-7); // Keras I - TestCase.assertEquals(b.getDouble(0, 96), 0.6031118, 1e-7); // Keras F - TestCase.assertEquals(b.getDouble(0, 192), -0.13569744, 1e-7); // Keras O - TestCase.assertEquals(b.getDouble(0, 0), -0.2587392, 1e-7); // Keras C + assertEquals(b.getDouble(0, 288), -0.36940336, 1e-7); // Keras I + assertEquals(b.getDouble(0, 96), 0.6031118, 1e-7); // Keras F + assertEquals(b.getDouble(0, 192), -0.13569744, 1e-7); // Keras O + assertEquals(b.getDouble(0, 0), -0.2587392, 1e-7); // Keras C // 2. Layer LSTM secondLstm = (LSTM) ((LastTimeStepLayer) model.getLayer(1)).getUnderlying(); @@ -142,21 +138,21 @@ public class FullModelComparisons extends BaseDL4JTest { W = secondLstm.getParam("W"); assertTrue(Arrays.equals(W.shape(), new long[]{nIn, 4 * nOut})); - TestCase.assertEquals(W.getDouble(0, 288), -0.7559755, 1e-7); + assertEquals(W.getDouble(0, 288), -0.7559755, 1e-7); RW = secondLstm.getParam("RW"); assertTrue(Arrays.equals(RW.shape(), new long[]{nOut, 4 * nOut})); - TestCase.assertEquals(RW.getDouble(0, 288), -0.33184892, 1e-7); + assertEquals(RW.getDouble(0, 288), -0.33184892, 1e-7); b = secondLstm.getParam("b"); assertTrue(Arrays.equals(b.shape(), new long[]{1, 4 * nOut})); - TestCase.assertEquals(b.getDouble(0, 288), -0.2223678, 1e-7); - TestCase.assertEquals(b.getDouble(0, 96), 0.73556226, 1e-7); - TestCase.assertEquals(b.getDouble(0, 192), -0.63227624, 1e-7); - TestCase.assertEquals(b.getDouble(0, 0), 0.06636357, 1e-7); + assertEquals(b.getDouble(0, 288), -0.2223678, 1e-7); + assertEquals(b.getDouble(0, 96), 0.73556226, 1e-7); + assertEquals(b.getDouble(0, 192), -0.63227624, 1e-7); + assertEquals(b.getDouble(0, 0), 0.06636357, 1e-7); - File dataDir = testDir.newFolder(); + File dataDir = testDir.toFile(); SequenceRecordReader reader = new CSVSequenceRecordReader(0, ";"); new ClassPathResource("deeplearning4j-modelimport/data/", classLoader).copyDirectory(dataDir); @@ -179,19 +175,19 @@ public class FullModelComparisons extends BaseDL4JTest { INDArray kerasPredictions = Nd4j.createFromNpyFile(Resources.asFile("modelimport/keras/fullconfigs/lstm/predictions.npy")); for (int i = 0; i < 283; i++) { - TestCase.assertEquals(kerasPredictions.getDouble(i), dl4jPredictions.getDouble(i), 1e-7); + assertEquals(kerasPredictions.getDouble(i), dl4jPredictions.getDouble(i), 1e-7); } INDArray ones = Nd4j.ones(1, 4, 12); INDArray predOnes = model.output(ones); - TestCase.assertEquals(predOnes.getDouble(0, 0), 0.7216, 1e-4); + assertEquals(predOnes.getDouble(0, 0), 0.7216, 1e-4); } @Test() - @Ignore("Data and channel layout mismatch. We don't support permuting the weights yet.") + @Disabled("Data and channel layout mismatch. We don't support permuting the weights yet.") public void cnnBatchNormTest() throws IOException, UnsupportedKerasConfigurationException, @@ -217,13 +213,13 @@ public class FullModelComparisons extends BaseDL4JTest { INDArray kerasOutput = Nd4j.createFromNpyFile(Resources.asFile("modelimport/keras/fullconfigs/cnn/predictions.npy")); for (int i = 0; i < 5; i++) { - TestCase.assertEquals(output.getDouble(i), kerasOutput.getDouble(i), 1e-4); + assertEquals(output.getDouble(i), kerasOutput.getDouble(i), 1e-4); } } @Test() - @Ignore("Data and channel layout mismatch. We don't support permuting the weights yet.") + @Disabled("Data and channel layout mismatch. We don't support permuting the weights yet.") public void cnnBatchNormLargerTest() throws IOException, UnsupportedKerasConfigurationException, InvalidKerasConfigurationException { @@ -249,7 +245,7 @@ public class FullModelComparisons extends BaseDL4JTest { for (int i = 0; i < 5; i++) { // TODO this should be a little closer - TestCase.assertEquals(output.getDouble(i), kerasOutput.getDouble(i), 1e-2); + assertEquals(output.getDouble(i), kerasOutput.getDouble(i), 1e-2); } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java index 23a873c13..b741d80d6 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java @@ -40,12 +40,12 @@ import java.io.File; import java.io.IOException; import java.io.InputStream; import java.util.Arrays; -import static junit.framework.TestCase.assertTrue; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; + import org.junit.jupiter.api.DisplayName; -import static org.junit.jupiter.api.Assertions.assertThrows; import org.junit.jupiter.api.extension.ExtendWith; +import static org.junit.jupiter.api.Assertions.*; + @Slf4j @DisplayName("Keras 2 Model Configuration Test") class Keras2ModelConfigurationTest extends BaseDL4JTest { @@ -301,7 +301,7 @@ class Keras2ModelConfigurationTest extends BaseDL4JTest { @Test @DisplayName("Reshape Embedding Concat Test") - // @Ignore("AB 2019/11/23 - known issue - see https://github.com/eclipse/deeplearning4j/issues/8373 and https://github.com/eclipse/deeplearning4j/issues/8441") + // @Disabled("AB 2019/11/23 - known issue - see https://github.com/eclipse/deeplearning4j/issues/8373 and https://github.com/eclipse/deeplearning4j/issues/8441") void ReshapeEmbeddingConcatTest() throws Exception { try (InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) { ComputationGraphConfiguration config = new KerasModel().modelBuilder().modelJsonInputStream(is).enforceTrainingConfig(false).buildModel().getComputationGraphConfiguration(); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLayerTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLayerTest.java index 22dbc2ba3..02cfc49fa 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLayerTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLayerTest.java @@ -30,7 +30,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.custom.KerasLRN; import org.deeplearning4j.nn.modelimport.keras.layers.custom.KerasPoolHelper; import org.deeplearning4j.util.ModelSerializer; import org.junit.jupiter.api.Disabled; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import java.io.File; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java index 4e7fe2e21..9a61e90a0 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java @@ -23,7 +23,7 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.autodiff.samediff.SDVariable; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasLambdaTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasLambdaTest.java index 2a03cebd9..7f2288254 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasLambdaTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasLambdaTest.java @@ -27,7 +27,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.autodiff.samediff.SDVariable; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index c0d1c4051..1c859252b 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -44,7 +44,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.junit.jupiter.api.Disabled; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.activations.Activation; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000Test.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000Test.java index 29617de12..f07486439 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000Test.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000Test.java @@ -26,7 +26,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth; import org.junit.jupiter.api.Disabled; -import org.junit.Rule; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.resources.Resources; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasActivationLayer.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasActivationLayer.java index be16c50bd..ad73a4c00 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasActivationLayer.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasActivationLayer.java @@ -25,12 +25,12 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class KerasActivationLayer extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/optimizers/OptimizerImport.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/optimizers/OptimizerImport.java index 47bcaa328..380e93a52 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/optimizers/OptimizerImport.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/optimizers/OptimizerImport.java @@ -25,7 +25,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel; import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder; import org.deeplearning4j.common.util.DL4JFileUtils; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.resources.Resources; import java.io.File; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorImportTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorImportTest.java index 8748de768..712e3e41c 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorImportTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorImportTest.java @@ -22,7 +22,8 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessing.sequence; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.resources.Resources; import java.io.IOException; @@ -34,7 +35,8 @@ import java.io.IOException; */ public class TimeSeriesGeneratorImportTest extends BaseDL4JTest { - @Test(timeout=300000) + @Test() + @Timeout(300000) public void importTimeSeriesTest() throws IOException, InvalidKerasConfigurationException { String path = "modelimport/keras/preprocessing/timeseries_generator.json"; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorTest.java index 20a9d18e3..91c86ba81 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorTest.java @@ -22,12 +22,12 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessing.sequence; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.primitives.Pair; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TimeSeriesGeneratorTest extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerImportTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerImportTest.java index 26fc8c932..3f7d43723 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerImportTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerImportTest.java @@ -22,12 +22,13 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessing.text; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.resources.Resources; import java.io.IOException; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; /** * Import Keras Tokenizer @@ -39,7 +40,8 @@ public class TokenizerImportTest extends BaseDL4JTest { ClassLoader classLoader = getClass().getClassLoader(); - @Test(timeout=300000) + @Test() + @Timeout(300000) public void importTest() throws IOException, InvalidKerasConfigurationException { String path = "modelimport/keras/preprocessing/tokenizer.json"; @@ -55,7 +57,9 @@ public class TokenizerImportTest extends BaseDL4JTest { } - @Test(timeout=300000) + + @Test() + @Timeout(300000) public void importNumWordsNullTest() throws IOException, InvalidKerasConfigurationException { String path = "modelimport/keras/preprocessing/tokenizer_num_words_null.json"; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerTest.java index 5a173a9d2..6e421b404 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerTest.java @@ -21,14 +21,14 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessing.text; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.HashMap; import java.util.Map; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * Tests for Keras Tokenizer diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java index 830b4cde0..32da2101f 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java @@ -28,9 +28,10 @@ import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.resources.Resources; @@ -39,17 +40,16 @@ import java.io.File; import java.io.IOException; import java.io.InputStream; import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.StandardCopyOption; import java.util.Arrays; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j public class KerasWeightSettingTests extends BaseDL4JTest { - @Rule - public final TemporaryFolder testDir = new TemporaryFolder(); @Override public long getTimeoutMilliseconds() { @@ -57,84 +57,84 @@ public class KerasWeightSettingTests extends BaseDL4JTest { } @Test - public void testSimpleLayersWithWeights() throws Exception { + public void testSimpleLayersWithWeights(@TempDir Path tempDir) throws Exception { int[] kerasVersions = new int[]{1, 2}; String[] backends = new String[]{"tensorflow", "theano"}; for (int version : kerasVersions) { for (String backend : backends) { String densePath = "modelimport/keras/weights/dense_" + backend + "_" + version + ".h5"; - importDense(densePath); + importDense(tempDir,densePath); String conv2dPath = "modelimport/keras/weights/conv2d_" + backend + "_" + version + ".h5"; - importConv2D(conv2dPath); + importConv2D(tempDir,conv2dPath); if (version == 2 && backend.equals("tensorflow")) { // TODO should work for theano String conv2dReshapePath = "modelimport/keras/weights/conv2d_reshape_" + backend + "_" + version + ".h5"; System.out.println(backend + "_" + version); - importConv2DReshape(conv2dReshapePath); + importConv2DReshape(tempDir,conv2dReshapePath); } if (version == 2) { String conv1dFlattenPath = "modelimport/keras/weights/embedding_conv1d_flatten_" + backend + "_" + version + ".h5"; - importConv1DFlatten(conv1dFlattenPath); + importConv1DFlatten(tempDir,conv1dFlattenPath); } String lstmPath = "modelimport/keras/weights/lstm_" + backend + "_" + version + ".h5"; - importLstm(lstmPath); + importLstm(tempDir,lstmPath); String embeddingLstmPath = "modelimport/keras/weights/embedding_lstm_" + backend + "_" + version + ".h5"; - importEmbeddingLstm(embeddingLstmPath); + importEmbeddingLstm(tempDir,embeddingLstmPath); if (version == 2) { String embeddingConv1dExtendedPath = "modelimport/keras/weights/embedding_conv1d_extended_" + backend + "_" + version + ".h5"; - importEmbeddingConv1DExtended(embeddingConv1dExtendedPath); + importEmbeddingConv1DExtended(tempDir,embeddingConv1dExtendedPath); } if (version == 2) { String embeddingConv1dPath = "modelimport/keras/weights/embedding_conv1d_" + backend + "_" + version + ".h5"; - importEmbeddingConv1D(embeddingConv1dPath); + importEmbeddingConv1D(tempDir,embeddingConv1dPath); } String simpleRnnPath = "modelimport/keras/weights/simple_rnn_" + backend + "_" + version + ".h5"; - importSimpleRnn(simpleRnnPath); + importSimpleRnn(tempDir,simpleRnnPath); String bidirectionalLstmPath = "modelimport/keras/weights/bidirectional_lstm_" + backend + "_" + version + ".h5"; - importBidirectionalLstm(bidirectionalLstmPath); + importBidirectionalLstm(tempDir,bidirectionalLstmPath); String bidirectionalLstmNoSequencesPath = "modelimport/keras/weights/bidirectional_lstm_no_return_sequences_" + backend + "_" + version + ".h5"; - importBidirectionalLstm(bidirectionalLstmNoSequencesPath); + importBidirectionalLstm(tempDir,bidirectionalLstmNoSequencesPath); if (version == 2 && backend.equals("tensorflow")) { String batchToConv2dPath = "modelimport/keras/weights/batch_to_conv2d_" + backend + "_" + version + ".h5"; - importBatchNormToConv2D(batchToConv2dPath); + importBatchNormToConv2D(tempDir,batchToConv2dPath); } if (backend.equals("tensorflow") && version == 2) { // TODO should work for theano String simpleSpaceToBatchPath = "modelimport/keras/weights/space_to_depth_simple_" + backend + "_" + version + ".h5"; - importSimpleSpaceToDepth(simpleSpaceToBatchPath); + importSimpleSpaceToDepth(tempDir,simpleSpaceToBatchPath); } if (backend.equals("tensorflow") && version == 2) { String graphSpaceToBatchPath = "modelimport/keras/weights/space_to_depth_graph_" + backend + "_" + version + ".h5"; - importGraphSpaceToDepth(graphSpaceToBatchPath); + importGraphSpaceToDepth(tempDir,graphSpaceToBatchPath); } if (backend.equals("tensorflow") && version == 2) { String sepConvPath = "modelimport/keras/weights/sepconv2d_" + backend + "_" + version + ".h5"; - importSepConv2D(sepConvPath); + importSepConv2D(tempDir,sepConvPath); } } } @@ -144,8 +144,8 @@ public class KerasWeightSettingTests extends BaseDL4JTest { log.info("***** Successfully imported " + modelPath); } - private void importDense(String modelPath) throws Exception { - MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, true); + private void importDense(Path tempDir,String modelPath) throws Exception { + MultiLayerNetwork model = loadMultiLayerNetwork(tempDir,modelPath, true); INDArray weights = model.getLayer(0).getParam("W"); val weightShape = weights.shape(); @@ -157,8 +157,8 @@ public class KerasWeightSettingTests extends BaseDL4JTest { logSuccess(modelPath); } - private void importSepConv2D(String modelPath) throws Exception { - MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, false); + private void importSepConv2D(Path tempDir,String modelPath) throws Exception { + MultiLayerNetwork model = loadMultiLayerNetwork(tempDir,modelPath, false); INDArray depthWeights = model.getLayer(0).getParam("W"); val depthWeightShape = depthWeights.shape(); @@ -193,8 +193,8 @@ public class KerasWeightSettingTests extends BaseDL4JTest { logSuccess(modelPath); } - private void importConv2D(String modelPath) throws Exception { - MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, false); + private void importConv2D(Path tempDir,String modelPath) throws Exception { + MultiLayerNetwork model = loadMultiLayerNetwork(tempDir,modelPath, false); INDArray weights = model.getLayer(0).getParam("W"); val weightShape = weights.shape(); @@ -209,8 +209,8 @@ public class KerasWeightSettingTests extends BaseDL4JTest { } - private void importConv2DReshape(String modelPath) throws Exception { - MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, false); + private void importConv2DReshape(Path tempDir,String modelPath) throws Exception { + MultiLayerNetwork model = loadMultiLayerNetwork(tempDir,modelPath, false); int nOut = 12; @@ -223,8 +223,8 @@ public class KerasWeightSettingTests extends BaseDL4JTest { logSuccess(modelPath); } - private void importConv1DFlatten(String modelPath) throws Exception { - MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, false); + private void importConv1DFlatten(Path tempDir,String modelPath) throws Exception { + MultiLayerNetwork model = loadMultiLayerNetwork(tempDir,modelPath, false); int nOut = 6; int inputLength = 10; @@ -242,15 +242,15 @@ public class KerasWeightSettingTests extends BaseDL4JTest { logSuccess(modelPath); } - private void importBatchNormToConv2D(String modelPath) throws Exception { - MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, false); + private void importBatchNormToConv2D(Path tempDir,String modelPath) throws Exception { + MultiLayerNetwork model = loadMultiLayerNetwork(tempDir,modelPath, false); model.summary(); logSuccess(modelPath); } - private void importSimpleSpaceToDepth(String modelPath) throws Exception { + private void importSimpleSpaceToDepth(Path tempDir,String modelPath) throws Exception { KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class); - MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, false); + MultiLayerNetwork model = loadMultiLayerNetwork(tempDir,modelPath, false); INDArray input = Nd4j.zeros(10, 6, 6, 4); INDArray output = model.output(input); @@ -258,9 +258,9 @@ public class KerasWeightSettingTests extends BaseDL4JTest { logSuccess(modelPath); } - private void importGraphSpaceToDepth(String modelPath) throws Exception { + private void importGraphSpaceToDepth(Path tempDir,String modelPath) throws Exception { KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class); - ComputationGraph model = loadComputationalGraph(modelPath, false); + ComputationGraph model = loadComputationalGraph(tempDir,modelPath, false); // INDArray input[] = new INDArray[]{Nd4j.zeros(10, 4, 6, 6), Nd4j.zeros(10, 16, 3, 3)}; INDArray input[] = new INDArray[]{Nd4j.zeros(10, 6, 6, 4), Nd4j.zeros(10, 3, 3, 16)}; @@ -270,15 +270,15 @@ public class KerasWeightSettingTests extends BaseDL4JTest { logSuccess(modelPath); } - private void importLstm(String modelPath) throws Exception { - MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, false); + private void importLstm(Path tempDir,String modelPath) throws Exception { + MultiLayerNetwork model = loadMultiLayerNetwork(tempDir,modelPath, false); model.summary(); // TODO: check weights logSuccess(modelPath); } - private void importEmbeddingLstm(String modelPath) throws Exception { - MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, false); + private void importEmbeddingLstm(Path tempDir,String modelPath) throws Exception { + MultiLayerNetwork model = loadMultiLayerNetwork(tempDir,modelPath, false); int nIn = 4; int nOut = 6; @@ -297,13 +297,13 @@ public class KerasWeightSettingTests extends BaseDL4JTest { logSuccess(modelPath); } - private void importEmbeddingConv1DExtended(String modelPath) throws Exception { - MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, false); + private void importEmbeddingConv1DExtended(Path tempDir,String modelPath) throws Exception { + MultiLayerNetwork model = loadMultiLayerNetwork(tempDir,modelPath, false); logSuccess(modelPath); } - private void importEmbeddingConv1D(String modelPath) throws Exception { - MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, false); + private void importEmbeddingConv1D(Path tempDir,String modelPath) throws Exception { + MultiLayerNetwork model = loadMultiLayerNetwork(tempDir,modelPath, false); int nIn = 4; int nOut = 6; @@ -327,22 +327,22 @@ public class KerasWeightSettingTests extends BaseDL4JTest { logSuccess(modelPath); } - private void importSimpleRnn(String modelPath) throws Exception { - MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, false); + private void importSimpleRnn(Path tempDir,String modelPath) throws Exception { + MultiLayerNetwork model = loadMultiLayerNetwork(tempDir,modelPath, false); model.summary(); logSuccess(modelPath); // TODO: check weights } - private void importBidirectionalLstm(String modelPath) throws Exception { - MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, false); + private void importBidirectionalLstm(Path tempDir,String modelPath) throws Exception { + MultiLayerNetwork model = loadMultiLayerNetwork(tempDir,modelPath, false); model.summary(); logSuccess(modelPath); // TODO: check weights } - private MultiLayerNetwork loadMultiLayerNetwork(String modelPath, boolean training) throws Exception { - File modelFile = createTempFile("temp", ".h5"); + private MultiLayerNetwork loadMultiLayerNetwork(Path tempDir, String modelPath, boolean training) throws Exception { + File modelFile = createTempFile(tempDir,"temp", ".h5"); try(InputStream is = Resources.asStream(modelPath)) { Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING); return new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath()) @@ -350,8 +350,8 @@ public class KerasWeightSettingTests extends BaseDL4JTest { } } - private ComputationGraph loadComputationalGraph(String modelPath, boolean training) throws Exception { - File modelFile = createTempFile("temp", ".h5"); + private ComputationGraph loadComputationalGraph(Path tempDir,String modelPath, boolean training) throws Exception { + File modelFile = createTempFile(tempDir,"temp", ".h5"); try(InputStream is = Resources.asStream(modelPath)) { Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING); return new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath()) @@ -359,8 +359,9 @@ public class KerasWeightSettingTests extends BaseDL4JTest { } } - private File createTempFile(String prefix, String suffix) throws IOException { - return testDir.newFile(prefix + "-" + System.nanoTime() + suffix); + private File createTempFile(Path tempDir,String prefix, String suffix) throws IOException { + File createTempFile = Files.createTempFile(tempDir,prefix + "-" + System.nanoTime(),suffix).toFile(); + return createTempFile; } } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml index 0cd3e8071..a4ea94d8b 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml @@ -70,6 +70,12 @@ ${junit.version} test + + org.hamcrest + hamcrest-core + 1.3 + test + org.mockito mockito-core diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java deleted file mode 100644 index 764f735bf..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j; - -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.time.StopWatch; -import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; -import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; -import org.deeplearning4j.models.word2vec.wordstore.VocabCache; -import org.deeplearning4j.nn.conf.WorkspaceMode; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.io.ClassPathResource; -import org.nd4j.common.primitives.Pair; - -import java.io.File; -import java.util.ArrayList; -import java.util.List; - -@Slf4j -public class TsneTest extends BaseDL4JTest { - - @Override - public long getTimeoutMilliseconds() { - return 180000L; - } - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - - @Override - public DataType getDataType() { - return DataType.FLOAT; - } - - @Override - public DataType getDefaultFPDataType() { - return DataType.FLOAT; - } - -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/bagofwords/vectorizer/BagOfWordsVectorizerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/bagofwords/vectorizer/BagOfWordsVectorizerTest.java index ae89c6d1c..c7e94b2fa 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/bagofwords/vectorizer/BagOfWordsVectorizerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/bagofwords/vectorizer/BagOfWordsVectorizerTest.java @@ -24,8 +24,10 @@ package org.deeplearning4j.bagofwords.vectorizer; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Rule; -import org.junit.rules.TemporaryFolder; + + +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax; import org.deeplearning4j.models.word2vec.VocabWord; @@ -34,7 +36,7 @@ import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenc import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator; import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -42,13 +44,13 @@ import org.nd4j.common.util.SerializationUtils; import java.io.File; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; -import static org.junit.Assume.assumeNotNull; +import static org.junit.jupiter.api.Assertions.*; /** *@author Adam Gibson @@ -56,14 +58,10 @@ import static org.junit.Assume.assumeNotNull; @Slf4j public class BagOfWordsVectorizerTest extends BaseDL4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - - - - @Test(timeout = 60000L) - public void testBagOfWordsVectorizer() throws Exception { - val rootDir = testDir.newFolder(); + @Test() + @Timeout(60000L) + public void testBagOfWordsVectorizer(@TempDir Path testDir) throws Exception { + val rootDir = testDir.toFile(); ClassPathResource resource = new ClassPathResource("rootdir/"); resource.copyDirectory(rootDir); @@ -72,15 +70,15 @@ public class BagOfWordsVectorizerTest extends BaseDL4JTest { TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory(); BagOfWordsVectorizer vectorizer = new BagOfWordsVectorizer.Builder().setMinWordFrequency(1) - .setStopWords(new ArrayList()).setTokenizerFactory(tokenizerFactory).setIterator(iter) - .allowParallelTokenization(false) - // .labels(labels) - // .cleanup(true) - .build(); + .setStopWords(new ArrayList()).setTokenizerFactory(tokenizerFactory).setIterator(iter) + .allowParallelTokenization(false) + // .labels(labels) + // .cleanup(true) + .build(); vectorizer.fit(); VocabWord word = vectorizer.getVocabCache().wordFor("file."); - assumeNotNull(word); + assertNotNull(word); assertEquals(word, vectorizer.getVocabCache().tokenFor("file.")); assertEquals(2, vectorizer.getVocabCache().totalNumberOfDocs()); @@ -138,7 +136,7 @@ public class BagOfWordsVectorizerTest extends BaseDL4JTest { assertNotEquals(idx2, idx1); // Serialization check - File tempFile = createTempFile("fdsf", "fdfsdf"); + File tempFile = createTempFile(testDir,"fdsf", "fdfsdf"); tempFile.deleteOnExit(); SerializationUtils.saveObject(vectorizer, tempFile); @@ -150,8 +148,9 @@ public class BagOfWordsVectorizerTest extends BaseDL4JTest { assertEquals(array, dataSet.getFeatures()); } - private File createTempFile(String prefix, String suffix) throws IOException { - return testDir.newFile(prefix + "-" + System.nanoTime() + suffix); + private File createTempFile(Path tempDir,String prefix, String suffix) throws IOException { + File newFile = Files.createTempFile(tempDir,prefix + "-" + System.nanoTime(),suffix).toFile(); + return newFile; } } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/bagofwords/vectorizer/TfidfVectorizerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/bagofwords/vectorizer/TfidfVectorizerTest.java index 9c1c31c85..2d5ce3bb7 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/bagofwords/vectorizer/TfidfVectorizerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/bagofwords/vectorizer/TfidfVectorizerTest.java @@ -23,8 +23,10 @@ package org.deeplearning4j.bagofwords.vectorizer; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Rule; -import org.junit.rules.TemporaryFolder; + + +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.wordstore.VocabCache; @@ -39,20 +41,21 @@ import org.deeplearning4j.text.tokenization.tokenizer.DefaultTokenizer; import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer; import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.common.util.SerializationUtils; import java.io.File; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.concurrent.atomic.AtomicLong; -import static org.junit.Assert.*; -import static org.junit.Assume.assumeNotNull; +import static org.junit.jupiter.api.Assertions.*; /** * @author Adam Gibson @@ -60,31 +63,31 @@ import static org.junit.Assume.assumeNotNull; @Slf4j public class TfidfVectorizerTest extends BaseDL4JTest { - @Rule - public final TemporaryFolder testDir = new TemporaryFolder(); - @Test(timeout = 60000L) - public void testTfIdfVectorizer() throws Exception { - val rootDir = testDir.newFolder(); + + @Test() + @Timeout(60000L) + public void testTfIdfVectorizer(@TempDir Path testDir) throws Exception { + val rootDir = testDir.toFile(); ClassPathResource resource = new ClassPathResource("tripledir/"); resource.copyDirectory(rootDir); - + assertTrue(rootDir.isDirectory()); LabelAwareSentenceIterator iter = new LabelAwareFileSentenceIterator(rootDir); TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory(); TfidfVectorizer vectorizer = new TfidfVectorizer.Builder().setMinWordFrequency(1) - .setStopWords(new ArrayList()).setTokenizerFactory(tokenizerFactory).setIterator(iter) - .allowParallelTokenization(false) - // .labels(labels) - // .cleanup(true) - .build(); + .setStopWords(new ArrayList()).setTokenizerFactory(tokenizerFactory).setIterator(iter) + .allowParallelTokenization(false) + // .labels(labels) + // .cleanup(true) + .build(); vectorizer.fit(); VocabWord word = vectorizer.getVocabCache().wordFor("file."); - assumeNotNull(word); + assertNotNull(word); assertEquals(word, vectorizer.getVocabCache().tokenFor("file.")); assertEquals(3, vectorizer.getVocabCache().totalNumberOfDocs()); @@ -128,7 +131,8 @@ public class TfidfVectorizerTest extends BaseDL4JTest { assertEquals(1, cnt); - File tempFile = testDir.newFile("somefile.bin"); + + File tempFile = Files.createTempFile(testDir,"somefile","bin").toFile(); tempFile.delete(); SerializationUtils.saveObject(vectorizer, tempFile); @@ -152,24 +156,24 @@ public class TfidfVectorizerTest extends BaseDL4JTest { List docs = new ArrayList<>(2); docs.add(doc1); docs.add(doc2); - + LabelAwareIterator iterator = new SimpleLabelAwareIterator(docs); TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory(); TfidfVectorizer vectorizer = new TfidfVectorizer - .Builder() - .setMinWordFrequency(1) - .setStopWords(new ArrayList()) - .setTokenizerFactory(tokenizerFactory) - .setIterator(iterator) - .allowParallelTokenization(false) - .build(); + .Builder() + .setMinWordFrequency(1) + .setStopWords(new ArrayList()) + .setTokenizerFactory(tokenizerFactory) + .setIterator(iterator) + .allowParallelTokenization(false) + .build(); vectorizer.fit(); DataSet dataset = vectorizer.vectorize("it meows like a cat", "cat"); assertNotNull(dataset); - + LabelsSource source = vectorizer.getLabelsSource(); assertEquals(2, source.getNumberOfLabelsUsed()); List labels = source.getLabels(); @@ -177,7 +181,8 @@ public class TfidfVectorizerTest extends BaseDL4JTest { assertEquals("cat", labels.get(1)); } - @Test(timeout = 10000L) + @Test() + @Timeout(10000L) public void testParallelFlag1() throws Exception { val vectorizer = new TfidfVectorizer.Builder() .allowParallelTokenization(false) @@ -187,53 +192,61 @@ public class TfidfVectorizerTest extends BaseDL4JTest { } - @Test(expected = ND4JIllegalStateException.class, timeout = 20000L) + @Test() + @Timeout(20000L) public void testParallelFlag2() throws Exception { - val collection = new ArrayList(); - collection.add("First string"); - collection.add("Second string"); - collection.add("Third string"); - collection.add(""); - collection.add("Fifth string"); + assertThrows(ND4JIllegalStateException.class,() -> { + val collection = new ArrayList(); + collection.add("First string"); + collection.add("Second string"); + collection.add("Third string"); + collection.add(""); + collection.add("Fifth string"); // collection.add("caboom"); - val vectorizer = new TfidfVectorizer.Builder() - .allowParallelTokenization(false) - .setIterator(new CollectionSentenceIterator(collection)) - .setTokenizerFactory(new ExplodingTokenizerFactory(8, -1)) - .build(); + val vectorizer = new TfidfVectorizer.Builder() + .allowParallelTokenization(false) + .setIterator(new CollectionSentenceIterator(collection)) + .setTokenizerFactory(new ExplodingTokenizerFactory(8, -1)) + .build(); - vectorizer.buildVocab(); + vectorizer.buildVocab(); - log.info("Fitting vectorizer..."); + log.info("Fitting vectorizer..."); + + vectorizer.fit(); + }); - vectorizer.fit(); } - @Test(expected = ND4JIllegalStateException.class, timeout = 20000L) + @Test() + @Timeout(20000L) public void testParallelFlag3() throws Exception { - val collection = new ArrayList(); - collection.add("First string"); - collection.add("Second string"); - collection.add("Third string"); - collection.add(""); - collection.add("Fifth string"); - collection.add("Long long long string"); - collection.add("Sixth string"); + assertThrows(ND4JIllegalStateException.class,() -> { + val collection = new ArrayList(); + collection.add("First string"); + collection.add("Second string"); + collection.add("Third string"); + collection.add(""); + collection.add("Fifth string"); + collection.add("Long long long string"); + collection.add("Sixth string"); - val vectorizer = new TfidfVectorizer.Builder() - .allowParallelTokenization(false) - .setIterator(new CollectionSentenceIterator(collection)) - .setTokenizerFactory(new ExplodingTokenizerFactory(-1, 4)) - .build(); + val vectorizer = new TfidfVectorizer.Builder() + .allowParallelTokenization(false) + .setIterator(new CollectionSentenceIterator(collection)) + .setTokenizerFactory(new ExplodingTokenizerFactory(-1, 4)) + .build(); - vectorizer.buildVocab(); + vectorizer.buildVocab(); - log.info("Fitting vectorizer..."); + log.info("Fitting vectorizer..."); + + vectorizer.fit(); + }); - vectorizer.fit(); } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java index 405ebede5..76b4bc64d 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java @@ -27,7 +27,8 @@ import org.deeplearning4j.iterator.bert.BertMaskedLMMasker; import org.deeplearning4j.iterator.provider.CollectionLabeledPairSentenceProvider; import org.deeplearning4j.iterator.provider.CollectionLabeledSentenceProvider; import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.MultiDataSet; @@ -43,7 +44,7 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.util.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestBertIterator extends BaseDL4JTest { @@ -132,7 +133,8 @@ public class TestBertIterator extends BaseDL4JTest { assertEquals(segmentId, b.featurizeSentences(testHelper.getSentences()).getFirst()[1]); } - @Test(timeout = 20000L) + @Test() + @Timeout(20000) public void testBertUnsupervised() throws Exception { int minibatchSize = 2; TestSentenceHelper testHelper = new TestSentenceHelper(); @@ -163,7 +165,8 @@ public class TestBertIterator extends BaseDL4JTest { assertTrue(b.hasNext()); } - @Test(timeout = 20000L) + @Test() + @Timeout(20000) public void testLengthHandling() throws Exception { int minibatchSize = 2; TestSentenceHelper testHelper = new TestSentenceHelper(); @@ -232,7 +235,8 @@ public class TestBertIterator extends BaseDL4JTest { assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape()); } - @Test(timeout = 20000L) + @Test() + @Timeout(20000) public void testMinibatchPadding() throws Exception { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); int minibatchSize = 3; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestCnnSentenceDataSetIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestCnnSentenceDataSetIterator.java index da3e5f8a4..1a274766a 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestCnnSentenceDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestCnnSentenceDataSetIterator.java @@ -21,13 +21,13 @@ package org.deeplearning4j.iterator; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Before; +import org.junit.jupiter.api.BeforeEach; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.common.io.ClassPathResource; import org.deeplearning4j.iterator.provider.CollectionLabeledSentenceProvider; import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -37,11 +37,11 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestCnnSentenceDataSetIterator extends BaseDL4JTest { - @Before + @BeforeEach public void before(){ Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); } @@ -278,7 +278,7 @@ public class TestCnnSentenceDataSetIterator extends BaseDL4JTest { fail("Expected exception"); } catch (Throwable t){ String m = t.getMessage(); - assertTrue(m, m.contains("RemoveWord") && m.contains("vocabulary")); + assertTrue(m.contains("RemoveWord") && m.contains("vocabulary"), m); } } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTableTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTableTest.java index c94ea4597..8b058ee6f 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTableTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTableTest.java @@ -22,8 +22,10 @@ package org.deeplearning4j.models.embeddings.inmemory; import lombok.val; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Rule; -import org.junit.rules.TemporaryFolder; + + +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator; import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer; @@ -35,25 +37,25 @@ import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.common.resources.Resources; import java.io.File; +import java.nio.file.Path; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class InMemoryLookupTableTest extends BaseDL4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - @Before + @BeforeEach public void setUp() throws Exception { } - @Test(timeout = 300000) + @Test() + @Timeout(300000) public void testConsumeOnEqualVocabs() throws Exception { TokenizerFactory t = new DefaultTokenizerFactory(); t.setTokenPreProcessor(new CommonPreprocessor()); @@ -80,14 +82,14 @@ public class InMemoryLookupTableTest extends BaseDL4JTest { assertEquals(244, cacheSource.numWords()); InMemoryLookupTable mem1 = - (InMemoryLookupTable) new InMemoryLookupTable.Builder().vectorLength(100) - .cache(cacheSource).seed(17).build(); + new InMemoryLookupTable.Builder().vectorLength(100) + .cache(cacheSource).seed(17).build(); mem1.resetWeights(true); InMemoryLookupTable mem2 = - (InMemoryLookupTable) new InMemoryLookupTable.Builder().vectorLength(100) - .cache(cacheSource).seed(15).build(); + new InMemoryLookupTable.Builder().vectorLength(100) + .cache(cacheSource).seed(15).build(); mem2.resetWeights(true); @@ -100,8 +102,9 @@ public class InMemoryLookupTableTest extends BaseDL4JTest { } - @Test(timeout = 300000) - public void testConsumeOnNonEqualVocabs() throws Exception { + @Test() + @Timeout(300000) + public void testConsumeOnNonEqualVocabs(@TempDir Path testDir) throws Exception { TokenizerFactory t = new DefaultTokenizerFactory(); t.setTokenPreProcessor(new CommonPreprocessor()); @@ -127,8 +130,8 @@ public class InMemoryLookupTableTest extends BaseDL4JTest { assertEquals(244, cacheSource.numWords()); InMemoryLookupTable mem1 = - (InMemoryLookupTable) new InMemoryLookupTable.Builder().vectorLength(100) - .cache(cacheSource).build(); + new InMemoryLookupTable.Builder().vectorLength(100) + .cache(cacheSource).build(); mem1.resetWeights(true); @@ -137,7 +140,7 @@ public class InMemoryLookupTableTest extends BaseDL4JTest { AbstractCache cacheTarget = new AbstractCache.Builder().build(); - val dir = testDir.newFolder(); + val dir = testDir.toFile(); new ClassPathResource("/paravec/labeled/").copyDirectory(dir); FileLabelAwareIterator labelAwareIterator = new FileLabelAwareIterator.Builder() diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializerTest.java index ecf5ac93b..422d6525b 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializerTest.java @@ -34,10 +34,11 @@ import org.deeplearning4j.models.sequencevectors.SequenceVectors; import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.Word2Vec; import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.BeforeEach; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -46,22 +47,22 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.IOException; +import java.nio.file.Path; import java.util.Collections; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; @Slf4j public class WordVectorSerializerTest extends BaseDL4JTest { private AbstractCache cache; - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - @Before + + @BeforeEach public void setUp() throws Exception { cache = new AbstractCache.Builder().build(); @@ -253,7 +254,7 @@ public class WordVectorSerializerTest extends BaseDL4JTest { } @Test - public void weightLookupTable_Correct_WhenDeserialized() throws Exception { + public void weightLookupTable_Correct_WhenDeserialized(@TempDir Path testDir) throws Exception { INDArray syn0 = Nd4j.rand(DataType.FLOAT, 10, 2), syn1 = Nd4j.rand(DataType.FLOAT, 10, 2), @@ -269,7 +270,7 @@ public class WordVectorSerializerTest extends BaseDL4JTest { lookupTable.setSyn1(syn1); lookupTable.setSyn1Neg(syn1Neg); - File dir = testDir.newFolder(); + File dir = testDir.toFile(); File file = new File(dir, "lookupTable.txt"); WeightLookupTable deser = null; @@ -302,12 +303,12 @@ public class WordVectorSerializerTest extends BaseDL4JTest { } @Test - public void FastText_Correct_WhenDeserialized() throws IOException { + public void FastText_Correct_WhenDeserialized(@TempDir Path testDir) throws IOException { FastText fastText = FastText.builder().cbow(true).build(); - File dir = testDir.newFolder(); + File dir = testDir.toFile(); WordVectorSerializer.writeWordVectors(fastText, new File(dir, "some.data")); FastText deser = null; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/reader/impl/FlatModelUtilsTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/reader/impl/FlatModelUtilsTest.java index c0b73eddd..8ffd88219 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/reader/impl/FlatModelUtilsTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/reader/impl/FlatModelUtilsTest.java @@ -25,9 +25,9 @@ import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.Word2Vec; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.ops.transforms.Transforms; import org.slf4j.Logger; @@ -35,14 +35,14 @@ import org.slf4j.LoggerFactory; import java.util.Collection; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; -@Ignore +@Disabled public class FlatModelUtilsTest extends BaseDL4JTest { private Word2Vec vec; private static final Logger log = LoggerFactory.getLogger(FlatModelUtilsTest.class); - @Before + @BeforeEach public void setUp() throws Exception { if (vec == null) { //vec = WordVectorSerializer.loadFullModel("/Users/raver119/develop/model.dat"); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImplTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImplTest.java index 00ff56be8..4e604c07a 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImplTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImplTest.java @@ -25,13 +25,13 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; import org.deeplearning4j.models.word2vec.wordstore.VocabCache; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.mockito.Mockito; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.when; public class WordVectorsImplTest extends BaseDL4JTest { @@ -39,7 +39,7 @@ public class WordVectorsImplTest extends BaseDL4JTest { private WeightLookupTable weightLookupTable; private WordVectorsImpl wordVectors; - @Before + @BeforeEach public void init() throws Exception { vocabCache = Mockito.mock(VocabCache.class); weightLookupTable = Mockito.mock(WeightLookupTable.class); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java index 2d093df41..17d26a674 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java @@ -26,55 +26,54 @@ import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; import org.deeplearning4j.models.word2vec.Word2Vec; import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; import org.deeplearning4j.text.sentenceiterator.SentenceIterator; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.junit.rules.Timeout; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; + + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.primitives.Pair; import org.nd4j.common.resources.Resources; import java.io.File; import java.io.FileNotFoundException; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; import static org.hamcrest.CoreMatchers.hasItems; import static org.hamcrest.MatcherAssert.assertThat; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.*; @Slf4j -@Ignore +@Disabled public class FastTextTest extends BaseDL4JTest { - @Rule - public Timeout timeout = Timeout.seconds(300); + private File inputFile = Resources.asFile("models/fasttext/data/labeled_data.txt"); private File supModelFile = Resources.asFile("models/fasttext/supervised.model.bin"); private File cbowModelFile = Resources.asFile("models/fasttext/cbow.model.bin"); private File supervisedVectors = Resources.asFile("models/fasttext/supervised.model.vec"); - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); @Test - public void testTrainSupervised() throws IOException { + public void testTrainSupervised(@TempDir Path testDir) throws IOException { - File output = testDir.newFile(); + File output = testDir.toFile(); FastText fastText = - FastText.builder().supervised(true). - inputFile(inputFile.getAbsolutePath()). - outputFile(output.getAbsolutePath()).build(); + FastText.builder().supervised(true). + inputFile(inputFile.getAbsolutePath()). + outputFile(output.getAbsolutePath()).build(); log.info("\nTraining supervised model ...\n"); fastText.fit(); } @Test - public void testTrainSkipgram() throws IOException { + public void testTrainSkipgram(@TempDir Path testDir) throws IOException { - File output = testDir.newFile(); + File output = testDir.toFile(); FastText fastText = FastText.builder().skipgram(true). @@ -85,9 +84,9 @@ public class FastTextTest extends BaseDL4JTest { } @Test - public void testTrainSkipgramWithBuckets() throws IOException { + public void testTrainSkipgramWithBuckets(@TempDir Path testDir) throws IOException { - File output = testDir.newFile(); + File output = Files.createTempFile(testDir,"newFile","bin").toFile(); FastText fastText = FastText.builder().skipgram(true). @@ -99,9 +98,9 @@ public class FastTextTest extends BaseDL4JTest { } @Test - public void testTrainCBOW() throws IOException { + public void testTrainCBOW(@TempDir Path testDir) throws IOException { - File output = testDir.newFile(); + File output = Files.createTempFile(testDir,"newFile","bin").toFile(); FastText fastText = FastText.builder().cbow(true). @@ -126,21 +125,6 @@ public class FastTextTest extends BaseDL4JTest { @Test public void testPredict() { - String text = "I like soccer"; - - FastText fastText = new FastText(supModelFile); - assertEquals(48, fastText.vocab().numWords()); - assertEquals("association", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1)); - - double[] expected = {-0.006423053797334433, 0.007660661358386278, 0.006068876478821039, -0.004772625397890806, -0.007143457420170307, -0.007735592778772116, -0.005607823841273785, -0.00836215727031231, 0.0011235733982175589, 2.599214785732329E-4, 0.004131870809942484, 0.007203693501651287, 0.0016768622444942594, 0.008694255724549294, -0.0012487826170399785, -0.00393667770549655, -0.006292815785855055, 0.0049359360709786415, -3.356488887220621E-4, -0.009407570585608482, -0.0026168026961386204, -0.00978928804397583, 0.0032913016621023417, -0.0029464277904480696, -0.008649969473481178, 8.056449587456882E-4, 0.0043088337406516075, -0.008980576880276203, 0.008716211654245853, 0.0073893265798687935, -0.007388216909021139, 0.003814412746578455, -0.005518500227481127, 0.004668557550758123, 0.006603693123906851, 0.003820829326286912, 0.007174000144004822, -0.006393063813447952, -0.0019381389720365405, -0.0046371882781386375, -0.006193376146256924, -0.0036685809027403593, 7.58899434003979E-4, -0.003185075242072344, -0.008330358192324638, 3.3206873922608793E-4, -0.005389622412621975, 0.009706716984510422, 0.0037855932023376226, -0.008665262721478939, -0.0032511046156287193, 4.4134497875347733E-4, -0.008377416990697384, -0.009110655635595322, 0.0019723298028111458, 0.007486093323677778, 0.006400121841579676, 0.00902814231812954, 0.00975200068205595, 0.0060582347214221954, -0.0075621469877660275, 1.0270809434587136E-4, -0.00673140911385417, -0.007316927425563335, 0.009916870854794979, -0.0011407854035496712, -4.502215306274593E-4, -0.007612560410052538, 0.008726916275918484, -3.0280642022262327E-5, 0.005529289599508047, -0.007944817654788494, 0.005593308713287115, 0.003423960180953145, 4.1348213562741876E-4, 0.009524818509817123, -0.0025129399728029966, -0.0030074280221015215, -0.007503866218030453, -0.0028124507516622543, -0.006841592025011778, -2.9375351732596755E-4, 0.007195258513092995, -0.007775942329317331, 3.951996040996164E-4, -0.006887971889227629, 0.0032655203249305487, -0.007975360378623009, -4.840183464693837E-6, 0.004651934839785099, 0.0031739831902086735, 0.004644941072911024, -0.007461248897016048, 0.003057275665923953, 0.008903342299163342, 0.006857945583760738, 0.007567950990051031, 0.001506582135334611, 0.0063307867385447025, 0.005645462777465582}; - assertArrayEquals(expected, fastText.getWordVector("association"), 2e-3); - - String label = fastText.predict(text); - assertEquals("__label__soccer", label); - } - - @Test(expected = IllegalStateException.class) - public void testIllegalState() { String text = "I like soccer"; FastText fastText = new FastText(supModelFile); @@ -151,7 +135,25 @@ public class FastTextTest extends BaseDL4JTest { assertArrayEquals(expected, fastText.getWordVector("association"), 2e-3); String label = fastText.predict(text); - fastText.wordsNearest("test",1); + assertEquals("__label__soccer", label); + } + + @Test() + public void testIllegalState() { + assertThrows(IllegalStateException.class,() -> { + String text = "I like soccer"; + + FastText fastText = new FastText(supModelFile); + assertEquals(48, fastText.vocab().numWords()); + assertEquals("association", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1)); + + double[] expected = {-0.006423053797334433, 0.007660661358386278, 0.006068876478821039, -0.004772625397890806, -0.007143457420170307, -0.007735592778772116, -0.005607823841273785, -0.00836215727031231, 0.0011235733982175589, 2.599214785732329E-4, 0.004131870809942484, 0.007203693501651287, 0.0016768622444942594, 0.008694255724549294, -0.0012487826170399785, -0.00393667770549655, -0.006292815785855055, 0.0049359360709786415, -3.356488887220621E-4, -0.009407570585608482, -0.0026168026961386204, -0.00978928804397583, 0.0032913016621023417, -0.0029464277904480696, -0.008649969473481178, 8.056449587456882E-4, 0.0043088337406516075, -0.008980576880276203, 0.008716211654245853, 0.0073893265798687935, -0.007388216909021139, 0.003814412746578455, -0.005518500227481127, 0.004668557550758123, 0.006603693123906851, 0.003820829326286912, 0.007174000144004822, -0.006393063813447952, -0.0019381389720365405, -0.0046371882781386375, -0.006193376146256924, -0.0036685809027403593, 7.58899434003979E-4, -0.003185075242072344, -0.008330358192324638, 3.3206873922608793E-4, -0.005389622412621975, 0.009706716984510422, 0.0037855932023376226, -0.008665262721478939, -0.0032511046156287193, 4.4134497875347733E-4, -0.008377416990697384, -0.009110655635595322, 0.0019723298028111458, 0.007486093323677778, 0.006400121841579676, 0.00902814231812954, 0.00975200068205595, 0.0060582347214221954, -0.0075621469877660275, 1.0270809434587136E-4, -0.00673140911385417, -0.007316927425563335, 0.009916870854794979, -0.0011407854035496712, -4.502215306274593E-4, -0.007612560410052538, 0.008726916275918484, -3.0280642022262327E-5, 0.005529289599508047, -0.007944817654788494, 0.005593308713287115, 0.003423960180953145, 4.1348213562741876E-4, 0.009524818509817123, -0.0025129399728029966, -0.0030074280221015215, -0.007503866218030453, -0.0028124507516622543, -0.006841592025011778, -2.9375351732596755E-4, 0.007195258513092995, -0.007775942329317331, 3.951996040996164E-4, -0.006887971889227629, 0.0032655203249305487, -0.007975360378623009, -4.840183464693837E-6, 0.004651934839785099, 0.0031739831902086735, 0.004644941072911024, -0.007461248897016048, 0.003057275665923953, 0.008903342299163342, 0.006857945583760738, 0.007567950990051031, 0.001506582135334611, 0.0063307867385447025, 0.005645462777465582}; + assertArrayEquals(expected, fastText.getWordVector("association"), 2e-3); + + String label = fastText.predict(text); + fastText.wordsNearest("test",1); + }); + } @Test @@ -183,19 +185,19 @@ public class FastTextTest extends BaseDL4JTest { assertEquals(48, fastText.vocabSize()); String[] expected = {"", ".", "is", "game", "the", "soccer", "?", "football", "3", "12", "takes", "usually", "A", "US", - "in", "popular", "most", "hours", "and", "clubs", "minutes", "Do", "you", "like", "Is", "your", "favorite", "games", - "Premier", "Soccer", "a", "played", "by", "two", "teams", "of", "eleven", "players", "The", "Football", "League", "an", - "English", "professional", "league", "for", "men's", "association"}; + "in", "popular", "most", "hours", "and", "clubs", "minutes", "Do", "you", "like", "Is", "your", "favorite", "games", + "Premier", "Soccer", "a", "played", "by", "two", "teams", "of", "eleven", "players", "The", "Football", "League", "an", + "English", "professional", "league", "for", "men's", "association"}; for (int i = 0; i < fastText.vocabSize(); ++i) { - assertEquals(expected[i], fastText.vocab().wordAtIndex(i)); + assertEquals(expected[i], fastText.vocab().wordAtIndex(i)); } } @Test public void testLoadIterator() throws FileNotFoundException { SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); - FastText + FastText .builder() .supervised(true) .iterator(iter) @@ -203,16 +205,19 @@ public class FastTextTest extends BaseDL4JTest { .loadIterator(); } - @Test(expected=IllegalStateException.class) + @Test() public void testState() { - FastText fastText = new FastText(); - fastText.predict("something"); + assertThrows(IllegalStateException.class,() -> { + FastText fastText = new FastText(); + fastText.predict("something"); + }); + } @Test - public void testPretrainedVectors() throws IOException { - File output = testDir.newFile(); - + public void testPretrainedVectors(@TempDir Path testDir) throws IOException { + File output = new File(testDir.toFile(),"newfile.bin"); + output.deleteOnExit(); FastText fastText = FastText .builder() .supervised(true) @@ -226,8 +231,8 @@ public class FastTextTest extends BaseDL4JTest { } @Test - public void testWordsStatistics() throws IOException { - File output = testDir.newFile(); + public void testWordsStatistics(@TempDir Path testDir) throws IOException { + File output = Files.createTempFile(testDir,"output","bin").toFile(); FastText fastText = FastText .builder() @@ -243,9 +248,9 @@ public class FastTextTest extends BaseDL4JTest { Word2Vec word2Vec = WordVectorSerializer.readAsCsv(file); assertEquals(48, word2Vec.getVocab().numWords()); - assertEquals("", 0.1667751520872116, word2Vec.similarity("Football", "teams"), 2e-3); - assertEquals("", 0.10083991289138794, word2Vec.similarity("professional", "minutes"), 2e-3); - assertEquals("", Double.NaN, word2Vec.similarity("java","cpp"), 0.0); + assertEquals( 0.1667751520872116, word2Vec.similarity("Football", "teams"), 2e-3); + assertEquals( 0.10083991289138794, word2Vec.similarity("professional", "minutes"), 2e-3); + assertEquals( Double.NaN, word2Vec.similarity("java","cpp"), 0.0); assertThat(word2Vec.wordsNearest("association", 3), hasItems("Football", "Soccer", "men's")); } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java index 335225da7..9a2782fed 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java @@ -31,8 +31,10 @@ import org.deeplearning4j.models.sequencevectors.sequence.Sequence; import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer; import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.BasicTransformerIterator; import org.deeplearning4j.text.sentenceiterator.*; -import org.junit.Rule; -import org.junit.rules.TemporaryFolder; + + +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.common.io.ClassPathResource; import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; @@ -55,8 +57,8 @@ import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIterato import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.io.CollectionUtils; @@ -66,9 +68,10 @@ import org.nd4j.common.resources.Resources; import java.io.*; import java.nio.charset.StandardCharsets; +import java.nio.file.Path; import java.util.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class ParagraphVectorsTest extends BaseDL4JTest { @@ -78,8 +81,6 @@ public class ParagraphVectorsTest extends BaseDL4JTest { return isIntegrationTests() ? 600_000 : 240_000; } - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); @Override public DataType getDataType() { @@ -124,7 +125,8 @@ public class ParagraphVectorsTest extends BaseDL4JTest { * * @throws Exception */ - @Test(timeout = 2400000) + @Test() + @Timeout(2400000) public void testParagraphVectorsVocabBuilding1() throws Exception { File file = Resources.asFile("/big/raw_sentences.txt"); SentenceIterator iter = new BasicLineIterator(file); //UimaSentenceIterator.createWithPath(file.getAbsolutePath()); @@ -170,8 +172,9 @@ public class ParagraphVectorsTest extends BaseDL4JTest { * * @throws Exception */ - @Test(timeout = 3000000) - @Ignore("AB 2019/05/21 - Failing on linux-x86_64-cuda-9.2 - Issue #7657") + @Test() + @Timeout(3000000) + @Disabled("AB 2019/05/21 - Failing on linux-x86_64-cuda-9.2 - Issue #7657") public void testParagraphVectorsModelling1() throws Exception { File file = Resources.asFile("/big/raw_sentences.txt"); SentenceIterator iter = new BasicLineIterator(file); @@ -432,7 +435,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { } - @Test(timeout = 300000) + @Timeout(300000) public void testParagraphVectorsDBOW() throws Exception { skipUnlessIntegrationTests(); @@ -509,7 +512,8 @@ public class ParagraphVectorsTest extends BaseDL4JTest { } - @Test(timeout = 300000) + @Test() + @Timeout(300000) public void testParagraphVectorsWithWordVectorsModelling1() throws Exception { String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) { @@ -601,9 +605,9 @@ public class ParagraphVectorsTest extends BaseDL4JTest { * @throws Exception */ @Test - @Ignore - public void testParagraphVectorsReducedLabels1() throws Exception { - val tempDir = testDir.newFolder(); + @Disabled + public void testParagraphVectorsReducedLabels1(@TempDir Path testDir) throws Exception { + val tempDir = testDir.toFile(); ClassPathResource resource = new ClassPathResource("/labeled"); resource.copyDirectory(tempDir); @@ -650,7 +654,9 @@ public class ParagraphVectorsTest extends BaseDL4JTest { log.info("Similarity positive: " + simV); } - @Test(timeout = 300000) + + @Test() + @Timeout(300000) public void testParallelIterator() throws IOException { TokenizerFactory factory = new DefaultTokenizerFactory(); SentenceIterator iterator = new BasicLineIterator(Resources.asFile("big/raw_sentences.txt")); @@ -674,9 +680,9 @@ public class ParagraphVectorsTest extends BaseDL4JTest { } @Test - public void testIterator() throws IOException { - val folder_labeled = testDir.newFolder(); - val folder_unlabeled = testDir.newFolder(); + public void testIterator(@TempDir Path testDir) throws IOException { + val folder_labeled = new File(testDir.toFile(),"labeled"); + val folder_unlabeled = new File(testDir.toFile(),"unlabeled"); new ClassPathResource("/paravec/labeled/").copyDirectory(folder_labeled); new ClassPathResource("/paravec/unlabeled/").copyDirectory(folder_unlabeled); @@ -721,7 +727,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { there's no need in this test within travis, use it manually only for problems detection */ @Test - public void testParagraphVectorsOverExistingWordVectorsModel() throws Exception { + public void testParagraphVectorsOverExistingWordVectorsModel(@TempDir Path testDir) throws Exception { String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) { skipUnlessIntegrationTests(); //Skip CUDA except for integration tests due to very slow test speed @@ -730,7 +736,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { // we build w2v from multiple sources, to cover everything File resource_sentences = Resources.asFile("/big/raw_sentences.txt"); - val folder_mixed = testDir.newFolder(); + val folder_mixed = testDir.toFile(); ClassPathResource resource_mixed = new ClassPathResource("paravec/"); resource_mixed.copyDirectory(folder_mixed); @@ -756,8 +762,8 @@ public class ParagraphVectorsTest extends BaseDL4JTest { // At this moment we have ready w2v model. It's time to use it for ParagraphVectors - val folder_labeled = testDir.newFolder(); - val folder_unlabeled = testDir.newFolder(); + val folder_labeled = new File(testDir.toFile(),"labeled"); + val folder_unlabeled = new File(testDir.toFile(),"unlabeled"); new ClassPathResource("/paravec/labeled/").copyDirectory(folder_labeled); new ClassPathResource("/paravec/unlabeled/").copyDirectory(folder_unlabeled); @@ -866,7 +872,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { /** * Special test to check d2v inference against pre-trained gensim model and */ - @Ignore + @Disabled @Test public void testGensimEquality() throws Exception { @@ -1017,14 +1023,14 @@ public class ParagraphVectorsTest extends BaseDL4JTest { } @Test - @Ignore //AB 2020/02/06 - https://github.com/eclipse/deeplearning4j/issues/8677 - public void testDirectInference() throws Exception { + @Disabled //AB 2020/02/06 - https://github.com/eclipse/deeplearning4j/issues/8677 + public void testDirectInference(@TempDir Path testDir) throws Exception { boolean isIntegration = isIntegrationTests(); File resource = Resources.asFile("/big/raw_sentences.txt"); SentenceIterator sentencesIter = getIterator(isIntegration, resource); ClassPathResource resource_mixed = new ClassPathResource("paravec/"); - File local_resource_mixed = testDir.newFolder(); + File local_resource_mixed = testDir.toFile(); resource_mixed.copyDirectory(local_resource_mixed); SentenceIterator iter = new AggregatingSentenceIterator.Builder() .addSentenceIterator(sentencesIter) @@ -1050,7 +1056,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { log.info("vec1/vec2: {}", Transforms.cosineSim(vec1, vec2)); } - @Ignore + @Disabled @Test public void testGoogleModelForInference() throws Exception { WordVectors googleVectors = WordVectorSerializer.readWord2VecModel(new File("/ext/GoogleNews-vectors-negative300.bin.gz")); @@ -1069,7 +1075,8 @@ public class ParagraphVectorsTest extends BaseDL4JTest { log.info("vec1/vec2: {}", Transforms.cosineSim(vec1, vec2)); } - @Test(timeout = 300000) + @Test() + @Timeout(300000) public void testHash() { VocabWord w1 = new VocabWord(1.0, "D1"); VocabWord w2 = new VocabWord(1.0, "Bo"); @@ -1088,7 +1095,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { * * @throws Exception */ - @Ignore + @Disabled @Test public void testsParallelFit1() throws Exception { final File file = Resources.asFile("big/raw_sentences.txt"); @@ -1134,7 +1141,8 @@ public class ParagraphVectorsTest extends BaseDL4JTest { } } - @Test(timeout = 300000) + @Test() + @Timeout(300000) public void testJSONSerialization() { ParagraphVectors paragraphVectors = new ParagraphVectors.Builder().build(); AbstractCache cache = new AbstractCache.Builder().build(); @@ -1173,7 +1181,8 @@ public class ParagraphVectorsTest extends BaseDL4JTest { } } - @Test(timeout = 300000) + @Test() + @Timeout(300000) public void testDoubleFit() throws Exception { boolean isIntegration = isIntegrationTests(); File resource = Resources.asFile("/big/raw_sentences.txt"); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java index 8add7ac24..260ad9c8b 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java @@ -54,9 +54,9 @@ import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.heartbeat.Heartbeat; @@ -69,14 +69,14 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; -@Ignore +@Disabled public class SequenceVectorsTest extends BaseDL4JTest { protected static final Logger logger = LoggerFactory.getLogger(SequenceVectorsTest.class); - @Before + @BeforeEach public void setUp() throws Exception { } @@ -270,7 +270,7 @@ public class SequenceVectorsTest extends BaseDL4JTest { } @Test - @Ignore + @Disabled public void testDeepWalk() throws Exception { Heartbeat.getInstance().disableHeartbeat(); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/PopularityWalkerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/PopularityWalkerTest.java index 06db2db26..50001500c 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/PopularityWalkerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/PopularityWalkerTest.java @@ -30,19 +30,19 @@ import org.deeplearning4j.models.sequencevectors.graph.vertex.AbstractVertexFact import org.deeplearning4j.models.sequencevectors.graph.walkers.GraphWalker; import org.deeplearning4j.models.sequencevectors.sequence.Sequence; import org.deeplearning4j.models.word2vec.VocabWord; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import java.util.concurrent.atomic.AtomicBoolean; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class PopularityWalkerTest extends BaseDL4JTest { private static Graph graph; - @Before + @BeforeEach public void setUp() { if (graph == null) { graph = new Graph<>(10, false, new AbstractVertexFactory()); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/RandomWalkerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/RandomWalkerTest.java index 0687ab18b..7c150a610 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/RandomWalkerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/RandomWalkerTest.java @@ -31,12 +31,12 @@ import org.deeplearning4j.models.sequencevectors.graph.vertex.AbstractVertexFact import org.deeplearning4j.models.sequencevectors.graph.walkers.GraphWalker; import org.deeplearning4j.models.sequencevectors.sequence.Sequence; import org.deeplearning4j.models.word2vec.VocabWord; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class RandomWalkerTest extends BaseDL4JTest { @@ -46,7 +46,7 @@ public class RandomWalkerTest extends BaseDL4JTest { protected static final Logger logger = LoggerFactory.getLogger(RandomWalkerTest.class); - @Before + @BeforeEach public void setUp() throws Exception { if (graph == null) { graph = new Graph<>(10, false, new AbstractVertexFactory()); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/WeightedWalkerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/WeightedWalkerTest.java index bd69dea7b..7cf36eb9e 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/WeightedWalkerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/WeightedWalkerTest.java @@ -28,16 +28,16 @@ import org.deeplearning4j.models.sequencevectors.graph.vertex.AbstractVertexFact import org.deeplearning4j.models.sequencevectors.graph.walkers.GraphWalker; import org.deeplearning4j.models.sequencevectors.sequence.Sequence; import org.deeplearning4j.models.word2vec.VocabWord; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; public class WeightedWalkerTest extends BaseDL4JTest { private static Graph basicGraph; - @Before + @BeforeEach public void setUp() throws Exception { if (basicGraph == null) { // we don't really care about this graph, since it's just basic graph for iteration checks diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/AbstractElementFactoryTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/AbstractElementFactoryTest.java index e99c45c35..ee7d33022 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/AbstractElementFactoryTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/AbstractElementFactoryTest.java @@ -22,14 +22,14 @@ package org.deeplearning4j.models.sequencevectors.serialization; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.word2vec.VocabWord; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class AbstractElementFactoryTest extends BaseDL4JTest { - @Before + @BeforeEach public void setUp() throws Exception { } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/VocabWordFactoryTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/VocabWordFactoryTest.java index 16a770891..6ad958e79 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/VocabWordFactoryTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/VocabWordFactoryTest.java @@ -22,14 +22,14 @@ package org.deeplearning4j.models.sequencevectors.serialization; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.word2vec.VocabWord; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class VocabWordFactoryTest extends BaseDL4JTest { - @Before + @BeforeEach public void setUp() throws Exception { } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/GraphTransformerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/GraphTransformerTest.java index 9dc5e60d5..e0cb1c5cc 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/GraphTransformerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/GraphTransformerTest.java @@ -29,18 +29,18 @@ import org.deeplearning4j.models.sequencevectors.graph.walkers.GraphWalker; import org.deeplearning4j.models.sequencevectors.graph.walkers.impl.RandomWalker; import org.deeplearning4j.models.sequencevectors.sequence.Sequence; import org.deeplearning4j.models.word2vec.VocabWord; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import java.util.Iterator; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class GraphTransformerTest extends BaseDL4JTest { private static IGraph graph; - @Before + @BeforeEach public void setUp() throws Exception { if (graph == null) { graph = new Graph<>(10, false, new AbstractVertexFactory()); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIteratorTest.java index 34349e58b..eaf7022de 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIteratorTest.java @@ -33,27 +33,29 @@ import org.deeplearning4j.text.sentenceiterator.MutipleEpochsSentenceIterator; import org.deeplearning4j.text.sentenceiterator.SentenceIterator; import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.resources.Resources; import java.io.InputStream; import java.util.Iterator; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; @Slf4j public class ParallelTransformerIteratorTest extends BaseDL4JTest { private TokenizerFactory factory = new DefaultTokenizerFactory(); - @Before + @BeforeEach public void setUp() throws Exception { } - @Test(timeout = 300000) + @Test() + @Timeout(30000) public void hasNext() throws Exception { SentenceIterator iterator = new BasicLineIterator(Resources.asFile("big/raw_sentences.txt")); @@ -65,8 +67,8 @@ public class ParallelTransformerIteratorTest extends BaseDL4JTest { Sequence sequence = null; while (iter.hasNext()) { sequence = iter.next(); - assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence); - assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size()); + assertNotEquals( null, sequence,"Failed on [" + cnt + "] iteration"); + assertNotEquals(0, sequence.size(),"Failed on [" + cnt + "] iteration"); cnt++; } @@ -75,7 +77,8 @@ public class ParallelTransformerIteratorTest extends BaseDL4JTest { assertEquals(97162, cnt); } - @Test(timeout = 300000) + @Test() + @Timeout(30000) public void testSpeedComparison1() throws Exception { SentenceIterator iterator = new MutipleEpochsSentenceIterator( new BasicLineIterator(Resources.asFile("big/raw_sentences.txt")), 25); @@ -88,8 +91,8 @@ public class ParallelTransformerIteratorTest extends BaseDL4JTest { long time1 = System.currentTimeMillis(); while (iter.hasNext()) { Sequence sequence = iter.next(); - assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence); - assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size()); + assertNotEquals(null, sequence,"Failed on [" + cnt + "] iteration"); + assertNotEquals( 0, sequence.size(),"Failed on [" + cnt + "] iteration"); cnt++; } long time2 = System.currentTimeMillis(); @@ -105,8 +108,8 @@ public class ParallelTransformerIteratorTest extends BaseDL4JTest { time1 = System.currentTimeMillis(); while (iter.hasNext()) { Sequence sequence = iter.next(); - assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence); - assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size()); + assertNotEquals(null, sequence,"Failed on [" + cnt + "] iteration"); + assertNotEquals(0, sequence.size(),"Failed on [" + cnt + "] iteration"); cnt++; } time2 = System.currentTimeMillis(); @@ -129,8 +132,8 @@ public class ParallelTransformerIteratorTest extends BaseDL4JTest { time1 = System.currentTimeMillis(); while (iter.hasNext()) { Sequence sequence = iter.next(); - assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence); - assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size()); + assertNotEquals(null, sequence, "Failed on [" + cnt + "] iteration"); + assertNotEquals(0, sequence.size(),"Failed on [" + cnt + "] iteration"); cnt++; } time2 = System.currentTimeMillis(); @@ -147,8 +150,8 @@ public class ParallelTransformerIteratorTest extends BaseDL4JTest { time1 = System.currentTimeMillis(); while (iter.hasNext()) { Sequence sequence = iter.next(); - assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence); - assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size()); + assertNotEquals(null, sequence, "Failed on [" + cnt + "] iteration"); + assertNotEquals(0, sequence.size(),"Failed on [" + cnt + "] iteration"); cnt++; } time2 = System.currentTimeMillis(); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java index 8bb1f4d0d..bd59aa1a9 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java @@ -35,6 +35,7 @@ import org.deeplearning4j.text.documentiterator.LabelAwareIterator; import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; import org.deeplearning4j.util.ModelSerializer; +import org.junit.jupiter.api.Timeout; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -42,8 +43,8 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.io.ClassPathResource; import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.common.resources.Resources; @@ -53,8 +54,8 @@ import java.io.File; import java.util.Collection; import java.util.concurrent.Callable; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j @@ -66,7 +67,7 @@ public class Word2VecTestsSmall extends BaseDL4JTest { return isIntegrationTests() ? 240000 : 60000; } - @Before + @BeforeEach public void setUp() throws Exception { word2vec = WordVectorSerializer.readWord2VecModel(new ClassPathResource("vec.bin").getFile()); } @@ -92,7 +93,8 @@ public class Word2VecTestsSmall extends BaseDL4JTest { assertEquals(neighbours, nearestWords.size()); } - @Test(timeout = 300000) + @Test() + @Timeout(300000) public void testUnkSerialization_1() throws Exception { val inputFile = Resources.asFile("big/raw_sentences.txt"); // val iter = new BasicLineIterator(inputFile); @@ -152,7 +154,8 @@ public class Word2VecTestsSmall extends BaseDL4JTest { } - @Test(timeout = 300000) + @Test() + @Timeout(300000) public void testW2VEmbeddingLayerInit() throws Exception { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecVisualizationTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecVisualizationTests.java index 35c4af5ad..f14ddbed5 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecVisualizationTests.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecVisualizationTests.java @@ -23,16 +23,16 @@ package org.deeplearning4j.models.word2vec; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; -@Ignore +@Disabled public class Word2VecVisualizationTests extends BaseDL4JTest { private static WordVectors vectors; - @Before + @BeforeEach public synchronized void setUp() throws Exception { if (vectors == null) { vectors = WordVectorSerializer.loadFullModel("/ext/Temp/Models/raw_sentences.dat"); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java index c282a4215..a18e13c5b 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java @@ -32,8 +32,8 @@ import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIte import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.common.resources.Resources; @@ -44,7 +44,7 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; public class Word2VecDataSetIteratorTest extends BaseDL4JTest { @@ -57,7 +57,7 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest { * Basically all we want from this test - being able to finish without exceptions. */ @Test - @Ignore + @Disabled public void testIterator1() throws Exception { File inputFile = Resources.asFile("big/raw_sentences.txt"); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructorTest.java index d238b6474..c20528973 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructorTest.java @@ -22,9 +22,10 @@ package org.deeplearning4j.models.word2vec.wordstore; import lombok.val; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Rule; -import org.junit.rules.TemporaryFolder; -import org.junit.rules.Timeout; + + +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator; import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator; @@ -39,32 +40,31 @@ import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer; import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.common.resources.Resources; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; +import java.nio.file.Path; import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class VocabConstructorTest extends BaseDL4JTest { - @Rule - public Timeout timeout = Timeout.seconds(300); + protected static final Logger log = LoggerFactory.getLogger(VocabConstructorTest.class); TokenizerFactory t = new DefaultTokenizerFactory(); - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + - @Before + @BeforeEach public void setUp() throws Exception { t.setTokenPreProcessor(new CommonPreprocessor()); } @@ -290,7 +290,7 @@ public class VocabConstructorTest extends BaseDL4JTest { } @Test - public void testMergedVocabWithLabels1() throws Exception { + public void testMergedVocabWithLabels1(@TempDir Path testDir) throws Exception { AbstractCache cacheSource = new AbstractCache.Builder().build(); AbstractCache cacheTarget = new AbstractCache.Builder().build(); @@ -314,7 +314,7 @@ public class VocabConstructorTest extends BaseDL4JTest { int sourceSize = cacheSource.numWords(); log.info("Source Vocab size: " + sourceSize); - val dir = testDir.newFolder(); + val dir = testDir.toFile(); new ClassPathResource("/paravec/labeled/").copyDirectory(dir); @@ -435,7 +435,8 @@ public class VocabConstructorTest extends BaseDL4JTest { } - @Test(timeout=5000) // 5s timeout + @Test() // 5s timeout + @Timeout(5000) public void testParallelTokenizationDisabled_Completes() throws Exception { File inputFile = Resources.asFile("big/raw_sentences.txt"); SentenceIterator iter = new BasicLineIterator(inputFile); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabularyHolderTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabularyHolderTest.java index 2b8032ca8..41c1fc0ba 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabularyHolderTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabularyHolderTest.java @@ -22,9 +22,9 @@ package org.deeplearning4j.models.word2vec.wordstore; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache; -import org.junit.Test; +import org.junit.jupiter.api.Test; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class VocabularyHolderTest extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/AbstractCacheTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/AbstractCacheTest.java index 79c3bb568..9cdf38363 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/AbstractCacheTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/AbstractCacheTest.java @@ -27,17 +27,17 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.sequencevectors.serialization.ExtVocabWord; import org.deeplearning4j.models.word2vec.Huffman; import org.deeplearning4j.models.word2vec.VocabWord; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import java.util.Collection; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class AbstractCacheTest extends BaseDL4JTest { - @Before + @BeforeEach public void setUp() throws Exception { } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/AsyncLabelAwareIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/AsyncLabelAwareIteratorTest.java index 8b060db16..c40e4bcdc 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/AsyncLabelAwareIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/AsyncLabelAwareIteratorTest.java @@ -23,13 +23,15 @@ package org.deeplearning4j.text.documentiterator; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; import org.deeplearning4j.text.sentenceiterator.SentenceIterator; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.resources.Resources; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class AsyncLabelAwareIteratorTest extends BaseDL4JTest { - @Test(timeout = 300000) + @Test() + @Timeout(30000) public void nextDocument() throws Exception { SentenceIterator sentence = new BasicLineIterator(Resources.asFile("big/raw_sentences.txt")); BasicLabelAwareIterator backed = new BasicLabelAwareIterator.Builder(sentence).build(); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/BasicLabelAwareIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/BasicLabelAwareIteratorTest.java index c292def03..e2d635108 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/BasicLabelAwareIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/BasicLabelAwareIteratorTest.java @@ -21,24 +21,23 @@ package org.deeplearning4j.text.documentiterator; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Rule; -import org.junit.rules.Timeout; + + import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; import org.deeplearning4j.text.sentenceiterator.SentenceIterator; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.common.resources.Resources; import java.io.File; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class BasicLabelAwareIteratorTest extends BaseDL4JTest { - @Rule - public Timeout timeout = Timeout.seconds(300); - @Before + + @BeforeEach public void setUp() throws Exception { } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/DefaultDocumentIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/DefaultDocumentIteratorTest.java index 64d99b37b..06cdb5fcb 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/DefaultDocumentIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/DefaultDocumentIteratorTest.java @@ -24,13 +24,13 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer; import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; import java.io.File; import java.io.InputStream; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class DefaultDocumentIteratorTest extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/FileDocumentIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/FileDocumentIteratorTest.java index 3dff102ba..9e4edffa2 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/FileDocumentIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/FileDocumentIteratorTest.java @@ -25,31 +25,34 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Rule; -import org.junit.rules.TemporaryFolder; + + +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import java.io.File; import java.io.InputStream; import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.HashSet; import java.util.Set; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j -@Ignore +@Disabled public class FileDocumentIteratorTest extends BaseDL4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - @Before + + @BeforeEach public void setUp() throws Exception { } @@ -107,9 +110,10 @@ public class FileDocumentIteratorTest extends BaseDL4JTest { assertEquals(48, cnt); } - @Test(timeout = 5000L) - public void testEmptyDocument() throws Exception { - File f = testDir.newFile(); + @Test() + @Timeout(5000) + public void testEmptyDocument(@TempDir Path testDir) throws Exception { + File f = Files.createTempFile(testDir,"newfile","bin").toFile(); assertTrue(f.exists()); assertEquals(0, f.length()); @@ -121,9 +125,10 @@ public class FileDocumentIteratorTest extends BaseDL4JTest { } } - @Test(timeout = 5000L) - public void testEmptyDocument2() throws Exception { - File dir = testDir.newFolder(); + @Test() + @Timeout(5000) + public void testEmptyDocument2(@TempDir Path testDir) throws Exception { + File dir = testDir.toFile(); File f1 = new File(dir, "1.txt"); FileUtils.writeStringToFile(f1, "line 1\nline2", StandardCharsets.UTF_8); File f2 = new File(dir, "2.txt"); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/FileLabelAwareIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/FileLabelAwareIteratorTest.java index 14f0bfe0a..c94eaf747 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/FileLabelAwareIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/FileLabelAwareIteratorTest.java @@ -22,27 +22,29 @@ package org.deeplearning4j.text.documentiterator; import lombok.val; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Rule; -import org.junit.rules.TemporaryFolder; -import org.nd4j.common.io.ClassPathResource; -import org.junit.Before; -import org.junit.Test; -import static org.junit.Assert.*; + +import org.junit.jupiter.api.io.TempDir; +import org.nd4j.common.io.ClassPathResource; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.nio.file.Path; + +import static org.junit.jupiter.api.Assertions.*; public class FileLabelAwareIteratorTest extends BaseDL4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - @Before + @BeforeEach public void setUp() throws Exception { } @Test - public void testExtractLabelFromPath1() throws Exception { - val dir = testDir.newFolder(); + public void testExtractLabelFromPath1(@TempDir Path testDir) throws Exception { + val dir = testDir.toFile(); val resource = new ClassPathResource("/labeled/"); resource.copyDirectory(dir); @@ -69,9 +71,9 @@ public class FileLabelAwareIteratorTest extends BaseDL4JTest { @Test - public void testExtractLabelFromPath2() throws Exception { - val dir0 = testDir.newFolder(); - val dir1 = testDir.newFolder(); + public void testExtractLabelFromPath2(@TempDir Path testDir) throws Exception { + val dir0 = new File(testDir.toFile(),"dir-0"); + val dir1 = new File(testDir.toFile(),"dir-1"); val resource = new ClassPathResource("/labeled/"); val resource2 = new ClassPathResource("/rootdir/"); resource.copyDirectory(dir0); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/FilenamesLabelAwareIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/FilenamesLabelAwareIteratorTest.java index 13fade701..68c4677c0 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/FilenamesLabelAwareIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/FilenamesLabelAwareIteratorTest.java @@ -22,31 +22,31 @@ package org.deeplearning4j.text.documentiterator; import lombok.val; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Rule; -import org.junit.rules.TemporaryFolder; -import org.junit.Before; -import org.junit.Test; + + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.resources.Resources; +import java.nio.file.Path; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; public class FilenamesLabelAwareIteratorTest extends BaseDL4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - @Before + @BeforeEach public void setUp() throws Exception { } @Test - public void testNextDocument() throws Exception { - val tempDir = testDir.newFolder(); + public void testNextDocument(@TempDir Path testDir) throws Exception { + val tempDir = testDir.toFile(); Resources.copyDirectory("/big/", tempDir); FilenamesLabelAwareIterator iterator = new FilenamesLabelAwareIterator.Builder() diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/LabelsSourceTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/LabelsSourceTest.java index 8f8a78f10..673b38485 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/LabelsSourceTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/LabelsSourceTest.java @@ -21,17 +21,17 @@ package org.deeplearning4j.text.documentiterator; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class LabelsSourceTest extends BaseDL4JTest { - @Before + @BeforeEach public void setUp() throws Exception { } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/AggregatingSentenceIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/AggregatingSentenceIteratorTest.java index 61935aec9..6f8acb8a7 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/AggregatingSentenceIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/AggregatingSentenceIteratorTest.java @@ -21,16 +21,18 @@ package org.deeplearning4j.text.sentenceiterator; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.resources.Resources; import java.io.File; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class AggregatingSentenceIteratorTest extends BaseDL4JTest { - @Test(timeout = 300000) + @Test() + @Timeout(30000) public void testHasNext() throws Exception { File file = Resources.asFile("/big/raw_sentences.txt"); BasicLineIterator iterator = new BasicLineIterator(file); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/BasicLineIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/BasicLineIteratorTest.java index 9ca2642c4..1a1a0a685 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/BasicLineIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/BasicLineIteratorTest.java @@ -21,23 +21,22 @@ package org.deeplearning4j.text.sentenceiterator; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Rule; -import org.junit.rules.Timeout; -import org.junit.Before; -import org.junit.Test; + + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.common.resources.Resources; import java.io.File; import java.io.FileInputStream; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class BasicLineIteratorTest extends BaseDL4JTest { - @Rule - public Timeout timeout = Timeout.seconds(300); - @Before + + @BeforeEach public void setUp() throws Exception { } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/BasicResultSetIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/BasicResultSetIteratorTest.java index 294a97dac..a73e4da89 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/BasicResultSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/BasicResultSetIteratorTest.java @@ -21,17 +21,17 @@ package org.deeplearning4j.text.sentenceiterator; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.mockito.Mockito; import java.sql.ResultSet; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class BasicResultSetIteratorTest extends BaseDL4JTest { - @Before + @BeforeEach public void setUp() throws Exception { } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/MutipleEpochsSentenceIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/MutipleEpochsSentenceIteratorTest.java index 88116bb57..67774e97f 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/MutipleEpochsSentenceIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/MutipleEpochsSentenceIteratorTest.java @@ -21,13 +21,15 @@ package org.deeplearning4j.text.sentenceiterator; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.resources.Resources; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class MutipleEpochsSentenceIteratorTest extends BaseDL4JTest { - @Test(timeout = 300000) + @Test() + @Timeout(30000) public void hasNext() throws Exception { SentenceIterator iterator = new MutipleEpochsSentenceIterator( new BasicLineIterator(Resources.asFile("big/raw_sentences.txt")), 100); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIteratorTest.java index af2e8e7a5..cd8ca169f 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIteratorTest.java @@ -21,22 +21,21 @@ package org.deeplearning4j.text.sentenceiterator; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Rule; -import org.junit.rules.Timeout; -import org.junit.Test; + + +import org.junit.jupiter.api.Test; import org.nd4j.common.resources.Resources; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class PrefetchingSentenceIteratorTest extends BaseDL4JTest { - @Rule - public Timeout timeout = Timeout.seconds(300); + protected static final Logger log = LoggerFactory.getLogger(PrefetchingSentenceIteratorTest.class); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/StreamLineIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/StreamLineIteratorTest.java index 5fe9dfe56..0f447ab64 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/StreamLineIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/StreamLineIteratorTest.java @@ -22,15 +22,15 @@ package org.deeplearning4j.text.sentenceiterator; import org.deeplearning4j.BaseDL4JTest; import org.nd4j.common.io.ClassPathResource; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; import java.io.FileInputStream; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; public class StreamLineIteratorTest extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java index 976fe57fd..e7933d1be 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java @@ -26,8 +26,9 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.BertWordPiecePreProcessor; import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.resources.Resources; @@ -39,10 +40,10 @@ import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j -@Ignore +@Disabled public class BertWordPieceTokenizerTests extends BaseDL4JTest { private File pathToVocab = Resources.asFile("other/vocab.txt"); @@ -112,7 +113,7 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest { } @Test - @Ignore("AB 2019/05/24 - Disabled until dev branch merged - see issue #7657") + @Disabled("AB 2019/05/24 - Disabled until dev branch merged - see issue #7657") public void testBertWordPieceTokenizer5() throws Exception { // Longest Token in Vocab is 22 chars long, so make sure splits on the edge are properly handled String toTokenize = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; @@ -194,7 +195,7 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest { String m = e.getMessage(); assertNotNull(m); m = m.toLowerCase(); - assertTrue(m, m.contains("invalid") && m.contains("token") && m.contains("preprocessor")); + assertTrue(m.contains("invalid") && m.contains("token") && m.contains("preprocessor"), m); } try { @@ -204,13 +205,14 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest { String m = e.getMessage(); assertNotNull(m); m = m.toLowerCase(); - assertTrue(m, m.contains("invalid") && m.contains("token") && m.contains("preprocessor")); + assertTrue(m.contains("invalid") && m.contains("token") && m.contains("preprocessor"), m); } } } - @Test(timeout = 300000) + @Test() + @Timeout(300000) public void testBertWordPieceTokenizer10() throws Exception { File f = Resources.asFile("deeplearning4j-nlp/bert/uncased_L-12_H-768_A-12/vocab.txt"); BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(f, true, true, StandardCharsets.UTF_8); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/DefaulTokenizerTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/DefaulTokenizerTests.java index 302f6d1ef..d3ab0bfc8 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/DefaulTokenizerTests.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/DefaulTokenizerTests.java @@ -25,14 +25,14 @@ import org.deeplearning4j.BaseDL4JTest; import org.nd4j.common.io.ClassPathResource; import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.ByteArrayInputStream; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class DefaulTokenizerTests extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/NGramTokenizerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/NGramTokenizerTest.java index 738d96c23..6d36889cf 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/NGramTokenizerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/NGramTokenizerTest.java @@ -24,12 +24,12 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.NGramTokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; /** diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/tokenprepreprocessor/EndingPreProcessorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/tokenprepreprocessor/EndingPreProcessorTest.java index 0df619752..03db99995 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/tokenprepreprocessor/EndingPreProcessorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/tokenprepreprocessor/EndingPreProcessorTest.java @@ -23,9 +23,9 @@ package org.deeplearning4j.text.tokenization.tokenizer.tokenprepreprocessor; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess; import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.EndingPreProcessor; -import org.junit.Test; +import org.junit.jupiter.api.Test; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class EndingPreProcessorTest extends BaseDL4JTest { @Test diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizerfactory/NGramTokenizerFactoryTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizerfactory/NGramTokenizerFactoryTest.java index 8e604c11b..32ccee306 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizerfactory/NGramTokenizerFactoryTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizerfactory/NGramTokenizerFactoryTest.java @@ -24,9 +24,9 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; -import org.junit.Test; +import org.junit.jupiter.api.Test; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class NGramTokenizerFactoryTest extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/wordstore/InMemoryVocabStoreTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/wordstore/InMemoryVocabStoreTests.java index 3b8e055b2..fab3d2e89 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/wordstore/InMemoryVocabStoreTests.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/wordstore/InMemoryVocabStoreTests.java @@ -24,11 +24,11 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.wordstore.VocabCache; import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class InMemoryVocabStoreTests extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerParallelWrapperTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerParallelWrapperTest.java index 322a8e28f..d92cdf753 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerParallelWrapperTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerParallelWrapperTest.java @@ -33,7 +33,7 @@ import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.parallelism.ParallelWrapper; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.learning.config.Nesterovs; diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java index e29ed17de..a8db019b4 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java @@ -27,12 +27,12 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.parallelism.inference.InferenceMode; import org.deeplearning4j.parallelism.inference.LoadBalanceMode; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class InplaceParallelInferenceTest extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java index af74a7ed2..0ee5d5f26 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java @@ -32,8 +32,7 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.junit.*; -import org.junit.rules.Timeout; +import org.junit.jupiter.api.*; import org.nd4j.linalg.activations.Activation; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.eval.Evaluation; @@ -58,17 +57,16 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class ParallelInferenceTest extends BaseDL4JTest { private static MultiLayerNetwork model; private static DataSetIterator iterator; - @Rule - public Timeout timeout = Timeout.seconds(300); - @Before + + @BeforeEach public void setUp() throws Exception { if (model == null) { File file = Resources.asFile("models/LenetMnistMLN.zip"); @@ -78,12 +76,13 @@ public class ParallelInferenceTest extends BaseDL4JTest { } } - @After + @AfterEach public void tearDown() throws Exception { iterator.reset(); } - @Test(timeout = 30000L) + @Test() + @Timeout(30000) public void testInferenceSequential1() throws Exception { long count0 = 0; @@ -128,7 +127,8 @@ public class ParallelInferenceTest extends BaseDL4JTest { assertTrue(count1 > 0L); } - @Test(timeout = 30000L) + @Test() + @Timeout(30000) public void testInferenceSequential2() throws Exception { long count0 = 0; @@ -173,7 +173,8 @@ public class ParallelInferenceTest extends BaseDL4JTest { } - @Test(timeout = 30000L) + @Test() + @Timeout(30000) public void testInferenceBatched1() throws Exception { long count0 = 0; long count1 = 0; @@ -405,7 +406,8 @@ public class ParallelInferenceTest extends BaseDL4JTest { } - @Test(timeout = 120000L) + @Test() + @Timeout(120000) public void testParallelInferenceVariableLengthTS() throws Exception { Nd4j.getRandom().setSeed(12345); @@ -451,7 +453,8 @@ public class ParallelInferenceTest extends BaseDL4JTest { } } - @Test(timeout = 120000L) + @Test() + @Timeout(120000) public void testParallelInferenceVariableLengthTS2() throws Exception { Nd4j.getRandom().setSeed(12345); @@ -506,8 +509,8 @@ public class ParallelInferenceTest extends BaseDL4JTest { } - - @Test(timeout = 30000L) + @Test() + @Timeout(30000) public void testParallelInferenceVariableSizeCNN() throws Exception { //Variable size input for CNN model - for example, YOLO models //In these cases, we can't batch and have to execute the different size inputs separately @@ -562,8 +565,8 @@ public class ParallelInferenceTest extends BaseDL4JTest { } } - - @Test(timeout = 30000L) + @Test() + @Timeout(30000) public void testParallelInferenceVariableSizeCNN2() throws Exception { //Variable size input for CNN model - for example, YOLO models //In these cases, we can't batch and have to execute the different size inputs separately @@ -617,7 +620,8 @@ public class ParallelInferenceTest extends BaseDL4JTest { } } - @Test(timeout = 20000L) + @Test() + @Timeout(20000) public void testParallelInferenceErrorPropagation(){ int nIn = 10; @@ -751,7 +755,8 @@ public class ParallelInferenceTest extends BaseDL4JTest { } } - @Test(timeout = 20000L) + @Test() + @Timeout(20000) public void testModelUpdate_1() throws Exception { int nIn = 5; @@ -789,7 +794,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { // model can be null for some of the workers yet, due to race condition if (m != null) { Thread.sleep(500); - assertEquals("Failed at model [" + cnt0 + "]", net.params(), m.params()); + assertEquals(net.params(), m.params(), "Failed at model [" + cnt0 + "]"); passed = true; } cnt0++; @@ -816,14 +821,15 @@ public class ParallelInferenceTest extends BaseDL4JTest { cnt0 = 0; for (val m:modelsAfter) { - assertNotNull("Failed at model [" + cnt0 + "]", m); - assertEquals("Failed at model [" + cnt0++ + "]", net2.params(), m.params()); + assertNotNull(m,"Failed at model [" + cnt0 + "]"); + assertEquals(net2.params(), m.params(), "Failed at model [" + cnt0++ + "]"); } inf.shutdown(); } - @Test(timeout = 120000L) + @Test() + @Timeout(120000) public void testMultiOutputNet() throws Exception { int nIn = 5; @@ -936,7 +942,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { // System.out.println(Arrays.toString(e.shape()) + " vs " + Arrays.toString(a.shape())); // assertArrayEquals(e.shape(), a.shape()); - assertEquals("Failed at iteration [" + i + "]", e, a); + assertEquals(e, a, "Failed at iteration [" + i + "]"); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java index a50dfe8fa..7f751df54 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java @@ -35,7 +35,7 @@ import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -45,7 +45,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class ParallelWrapperTest extends BaseDL4JTest { @@ -137,7 +137,7 @@ public class ParallelWrapperTest extends BaseDL4JTest { mnistTest.reset(); double acc = eval.accuracy(); - assertTrue(String.valueOf(acc), acc > 0.5); + assertTrue(acc > 0.5, String.valueOf(acc)); wrapper.shutdown(); } diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestListeners.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestListeners.java index 352352365..eb3ccfef8 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestListeners.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestListeners.java @@ -36,7 +36,7 @@ import org.deeplearning4j.optimize.api.BaseTrainingListener; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.ui.model.stats.StatsListener; import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -48,7 +48,7 @@ import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestListeners extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStopping.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStopping.java index 5746c2214..2eaf2e850 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStopping.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStopping.java @@ -38,7 +38,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; @@ -47,7 +47,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.concurrent.TimeUnit; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestParallelEarlyStopping extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStoppingUI.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStoppingUI.java index b00787d5a..07dea5739 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStoppingUI.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStoppingUI.java @@ -40,19 +40,19 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.ui.api.UIServer; import org.deeplearning4j.ui.model.stats.StatsListener; import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestParallelEarlyStoppingUI extends BaseDL4JTest { @Test - @Ignore //To be run manually + @Disabled //To be run manually public void testParallelStatsListenerCompatibility() throws Exception { UIServer uiServer = UIServer.getInstance(); diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/DefaultTrainerContextTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/DefaultTrainerContextTest.java index 69064a160..3a85b4b34 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/DefaultTrainerContextTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/DefaultTrainerContextTest.java @@ -34,12 +34,12 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.parallelism.ParallelWrapper; import org.deeplearning4j.parallelism.trainer.SymmetricTrainer; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.lossfunctions.LossFunctions; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class DefaultTrainerContextTest extends BaseDL4JTest { int nChannels = 1; diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContextTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContextTest.java index 261718369..ec82896df 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContextTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContextTest.java @@ -34,12 +34,12 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.parallelism.ParallelWrapper; import org.deeplearning4j.parallelism.trainer.SymmetricTrainer; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.lossfunctions.LossFunctions; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class SymmetricTrainerContextTest extends BaseDL4JTest { int nChannels = 1; diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/inference/observers/BatchedInferenceObservableTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/inference/observers/BatchedInferenceObservableTest.java index 6c0b6b297..1a49fa3b1 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/inference/observers/BatchedInferenceObservableTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/inference/observers/BatchedInferenceObservableTest.java @@ -22,9 +22,9 @@ package org.deeplearning4j.parallelism.inference.observers; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -34,15 +34,15 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j public class BatchedInferenceObservableTest extends BaseDL4JTest { - @Before + @BeforeEach public void setUp() throws Exception {} - @After + @AfterEach public void tearDown() throws Exception {} @Test diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/main/ParallelWrapperMainTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/main/ParallelWrapperMainTest.java index dabbc9469..472bf86b6 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/main/ParallelWrapperMainTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/main/ParallelWrapperMainTest.java @@ -33,24 +33,25 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.util.ModelSerializer; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.File; +import java.nio.file.Files; +import java.nio.file.Path; @Slf4j public class ParallelWrapperMainTest extends BaseDL4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); @Test - public void runParallelWrapperMain() throws Exception { + public void runParallelWrapperMain(@TempDir Path testDir) throws Exception { int nChannels = 1; int outputNum = 10; @@ -87,10 +88,10 @@ public class ParallelWrapperMainTest extends BaseDL4JTest { MultiLayerConfiguration conf = builder.build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); - File tempModel = testDir.newFile("tmpmodel.zip"); + File tempModel = Files.createTempFile(testDir,"tmpmodel","zip").toFile(); tempModel.deleteOnExit(); ModelSerializer.writeModel(model, tempModel, false); - File tmp = testDir.newFile("tmpmodel.bin"); + File tmp = Files.createTempFile(testDir,"tmpmodel","bin").toFile(); tmp.deleteOnExit(); ParallelWrapperMain parallelWrapperMain = new ParallelWrapperMain(); try { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java index 4f0da8ca0..2892b1653 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java @@ -33,16 +33,16 @@ import org.deeplearning4j.spark.models.sequencevectors.export.ExportContainer; import org.deeplearning4j.spark.models.sequencevectors.export.SparkModelExporter; import org.deeplearning4j.spark.models.word2vec.SparkWord2VecTest; import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.common.primitives.Counter; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; public class SparkSequenceVectorsTest extends BaseDL4JTest { @@ -54,7 +54,7 @@ public class SparkSequenceVectorsTest extends BaseDL4JTest { protected static List> sequencesCyclic; private JavaSparkContext sc; - @Before + @BeforeEach public void setUp() throws Exception { if (sequencesCyclic == null) { sequencesCyclic = new ArrayList<>(); @@ -81,7 +81,7 @@ public class SparkSequenceVectorsTest extends BaseDL4JTest { sc = new JavaSparkContext(sparkConf); } - @After + @AfterEach public void tearDown() throws Exception { sc.stop(); } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/export/ExportContainerTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/export/ExportContainerTest.java index 3b7e7865f..3ca30f7d1 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/export/ExportContainerTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/export/ExportContainerTest.java @@ -22,14 +22,14 @@ package org.deeplearning4j.spark.models.sequencevectors.export; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.word2vec.VocabWord; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class ExportContainerTest extends BaseDL4JTest { - @Before + @BeforeEach public void setUp() throws Exception { } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java index c981c7a31..a6ce8d2eb 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java @@ -34,17 +34,17 @@ import org.deeplearning4j.spark.models.sequencevectors.export.ExportContainer; import org.deeplearning4j.spark.models.sequencevectors.export.SparkModelExporter; import org.deeplearning4j.spark.models.sequencevectors.learning.elements.SparkSkipGram; import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; import java.io.Serializable; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class SparkWord2VecTest extends BaseDL4JTest { @@ -56,7 +56,7 @@ public class SparkWord2VecTest extends BaseDL4JTest { private static List sentences; private JavaSparkContext sc; - @Before + @BeforeEach public void setUp() throws Exception { if (sentences == null) { sentences = new ArrayList<>(); @@ -72,13 +72,13 @@ public class SparkWord2VecTest extends BaseDL4JTest { sc = new JavaSparkContext(sparkConf); } - @After + @AfterEach public void tearDown() throws Exception { sc.stop(); } @Test - @Ignore("AB 2019/05/21 - Failing - Issue #7657") + @Disabled("AB 2019/05/21 - Failing - Issue #7657") public void testStringsTokenization1() throws Exception { JavaRDD rddSentences = sc.parallelize(sentences); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecTest.java index dc77915ea..e2bd741fb 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecTest.java @@ -24,8 +24,9 @@ import com.sun.jna.Platform; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.junit.Rule; -import org.junit.rules.TemporaryFolder; + + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; @@ -37,24 +38,25 @@ import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreproc import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.LowCasePreProcessor; import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import java.io.File; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.Arrays; import java.util.Collection; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; -@Ignore +@Disabled public class Word2VecTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Test - public void testConcepts() throws Exception { + public void testConcepts(@TempDir Path testDir) throws Exception { if(Platform.isWindows()) { //Spark tests don't run on windows return; @@ -132,7 +134,8 @@ public class Word2VecTest { // test serialization - File tempFile = testDir.newFile("temp" + System.currentTimeMillis() + ".tmp"); + + File tempFile = Files.createTempFile(testDir,"temp" + System.currentTimeMillis(),"tmp").toFile(); int idx1 = word2Vec.vocab().wordFor("day").getIndex(); @@ -158,7 +161,7 @@ public class Word2VecTest { assertEquals(array1, array2); } - @Ignore + @Disabled @Test public void testSparkW2VonBiggerCorpus() throws Exception { SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest") @@ -197,7 +200,7 @@ public class Word2VecTest { } @Test - @Ignore + @Disabled public void testPortugeseW2V() throws Exception { WordVectors word2Vec = WordVectorSerializer.loadTxtVectors(new File("/ext/Temp/para.txt")); word2Vec.setModelUtils(new FlatModelUtils()); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java index 9282a8a82..d998ddde4 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java @@ -24,8 +24,8 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.spark.models.embeddings.word2vec.Word2VecVariables; -import org.junit.After; -import org.junit.Before; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import java.io.Serializable; import java.lang.reflect.Field; @@ -40,12 +40,12 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable return 120000L; } - @Before + @BeforeEach public void before() throws Exception { sc = getContext(); } - @After + @AfterEach public void after() { if(sc != null) { sc.close(); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java index b3bd10b2c..aa559f90a 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java @@ -35,9 +35,9 @@ import org.deeplearning4j.spark.models.embeddings.word2vec.Word2Vec; import org.deeplearning4j.spark.text.functions.CountCumSum; import org.deeplearning4j.spark.text.functions.TextPipeline; import org.deeplearning4j.text.stopwords.StopWords; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Counter; import org.nd4j.common.primitives.Pair; @@ -48,8 +48,8 @@ import scala.Tuple2; import java.util.*; import java.util.concurrent.atomic.AtomicLong; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; /** * @author Jeffrey Tang @@ -67,7 +67,7 @@ public class TextPipelineTest extends BaseSparkTest { return sc.parallelize(sentenceList, 2); } - @Before + @BeforeEach public void before() throws Exception { conf = new SparkConf().setMaster("local[4]").setAppName("sparktest").set("spark.driver.host", "localhost"); @@ -335,7 +335,7 @@ public class TextPipelineTest extends BaseSparkTest { sc.stop(); } - @Test @Ignore //AB 2020/04/20 https://github.com/eclipse/deeplearning4j/issues/8849 + @Test @Disabled //AB 2020/04/20 https://github.com/eclipse/deeplearning4j/issues/8849 public void testCountCumSum() throws Exception { JavaSparkContext sc = getContext(); JavaRDD corpusRDD = getCorpusRDD(sc); @@ -360,7 +360,7 @@ public class TextPipelineTest extends BaseSparkTest { * * @throws Exception */ - @Test @Ignore //AB 2020/04/19 https://github.com/eclipse/deeplearning4j/issues/8849 + @Test @Disabled //AB 2020/04/19 https://github.com/eclipse/deeplearning4j/issues/8849 public void testZipFunction1() throws Exception { JavaSparkContext sc = getContext(); JavaRDD corpusRDD = getCorpusRDD(sc); @@ -398,7 +398,7 @@ public class TextPipelineTest extends BaseSparkTest { sc.stop(); } - @Test @Ignore //AB 2020/04/19 https://github.com/eclipse/deeplearning4j/issues/8849 + @Test @Disabled //AB 2020/04/19 https://github.com/eclipse/deeplearning4j/issues/8849 public void testZipFunction2() throws Exception { JavaSparkContext sc = getContext(); JavaRDD corpusRDD = getCorpusRDD(sc); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java index 4727378cc..d110e41bd 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java @@ -28,8 +28,8 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; -import org.junit.After; -import org.junit.Before; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -60,7 +60,7 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable return 120000L; } - @Before + @BeforeEach public void before() { sc = getContext(); @@ -78,7 +78,7 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable sparkData = getBasicSparkDataSet(nRows, input, labels); } - @After + @AfterEach public void after() { sc.close(); sc = null; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationFunctionTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationFunctionTest.java index 2f3f0f952..c5731af74 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationFunctionTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationFunctionTest.java @@ -21,15 +21,15 @@ package org.deeplearning4j.spark.parameterserver.accumulation; import com.sun.jna.Platform; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class SharedTrainingAccumulationFunctionTest { - @Before + @BeforeEach public void setUp() throws Exception {} @Test diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAggregateFunctionTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAggregateFunctionTest.java index 8d65bd693..25ef434bd 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAggregateFunctionTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAggregateFunctionTest.java @@ -22,15 +22,15 @@ package org.deeplearning4j.spark.parameterserver.accumulation; import com.sun.jna.Platform; import org.deeplearning4j.spark.parameterserver.training.SharedTrainingResult; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class SharedTrainingAggregateFunctionTest { - @Before + @BeforeEach public void setUp() throws Exception { // } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualDataSetIteratorTest.java index 7be5f6105..b837efe5e 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualDataSetIteratorTest.java @@ -21,8 +21,8 @@ package org.deeplearning4j.spark.parameterserver.iterators; import com.sun.jna.Platform; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -31,10 +31,10 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class VirtualDataSetIteratorTest { - @Before + @BeforeEach public void setUp() throws Exception {} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualIteratorTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualIteratorTest.java index 43849d939..4e56b575a 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualIteratorTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualIteratorTest.java @@ -21,16 +21,16 @@ package org.deeplearning4j.spark.parameterserver.iterators; import com.sun.jna.Platform; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class VirtualIteratorTest { - @Before + @BeforeEach public void setUp() throws Exception { // } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/modelimport/elephas/TestElephasImport.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/modelimport/elephas/TestElephasImport.java index 3e9c7d3e0..16429f41e 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/modelimport/elephas/TestElephasImport.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/modelimport/elephas/TestElephasImport.java @@ -27,7 +27,7 @@ import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; import org.deeplearning4j.spark.parameterserver.BaseSparkTest; import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; import java.io.File; @@ -35,7 +35,7 @@ import java.nio.file.Files; import java.nio.file.StandardCopyOption; import static java.io.File.createTempFile; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestElephasImport extends BaseSparkTest { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java index d98a9561a..5ea8ac321 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java @@ -46,10 +46,11 @@ import org.deeplearning4j.spark.impl.graph.SparkComputationGraph; import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; import org.deeplearning4j.spark.parameterserver.BaseSparkTest; import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -65,17 +66,18 @@ import org.nd4j.parameterserver.distributed.v2.enums.MeshBuildMode; import java.io.File; import java.io.Serializable; import java.net.Inet4Address; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.*; import java.util.concurrent.ConcurrentHashMap; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j -//@Ignore("AB 2019/05/21 - Failing - Issue #7657") +//@Disabled("AB 2019/05/21 - Failing - Issue #7657") public class GradientSharingTrainingTest extends BaseSparkTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Override public long getTimeoutMilliseconds() { @@ -83,7 +85,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest { } @Test - public void trainSanityCheck() throws Exception { + public void trainSanityCheck(@TempDir Path testDir) throws Exception { for(boolean mds : new boolean[]{false, true}) { INDArray last = null; @@ -108,7 +110,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest { throw new RuntimeException(); } - File temp = testDir.newFolder(); + File temp = testDir.toFile(); //TODO this probably won't work everywhere... @@ -146,7 +148,8 @@ public class GradientSharingTrainingTest extends BaseSparkTest { sparkNet.setCollectTrainingStats(tm.getIsCollectTrainingStats()); // System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); - File f = testDir.newFolder(); + File f = new File(testDir.toFile(),"test-dir-1"); + f.mkdirs(); DataSetIterator iter = new MnistDataSetIterator(16, true, 12345); int count = 0; List paths = new ArrayList<>(); @@ -224,7 +227,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest { double accAfter = eAfter.accuracy(); double accBefore = eBefore.accuracy(); - assertTrue("after: " + accAfter + ", before=" + accBefore, accAfter >= accBefore + 0.005); + assertTrue(accAfter >= accBefore + 0.005, "after: " + accAfter + ", before=" + accBefore); if (i == 0) { acc[0] = eBefore.accuracy(); @@ -239,11 +242,11 @@ public class GradientSharingTrainingTest extends BaseSparkTest { } - @Test @Ignore //AB https://github.com/eclipse/deeplearning4j/issues/8985 - public void differentNetsTrainingTest() throws Exception { + @Test @Disabled //AB https://github.com/eclipse/deeplearning4j/issues/8985 + public void differentNetsTrainingTest(@TempDir Path testDir) throws Exception { int batch = 3; - File temp = testDir.newFolder(); + File temp = testDir.toFile(); DataSet ds = new IrisDataSetIterator(150, 150).next(); List list = ds.asList(); Collections.shuffle(list, new Random(12345)); @@ -327,11 +330,11 @@ public class GradientSharingTrainingTest extends BaseSparkTest { } - @Test @Ignore - public void testEpochUpdating() throws Exception { + @Test @Disabled + public void testEpochUpdating(@TempDir Path testDir) throws Exception { //Ensure that epoch counter is incremented properly on the workers - File temp = testDir.newFolder(); + File temp = testDir.toFile(); //TODO this probably won't work everywhere... String controller = Inet4Address.getLocalHost().getHostAddress(); @@ -370,7 +373,8 @@ public class GradientSharingTrainingTest extends BaseSparkTest { int count = 0; List paths = new ArrayList<>(); List ds = new ArrayList<>(); - File f = testDir.newFolder(); + File f = new File(testDir.toFile(),"test-dir-1"); + f.mkdirs(); while (iter.hasNext() && count++ < 8) { DataSet d = iter.next(); File out = new File(f, count + ".bin"); @@ -386,7 +390,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest { sparkNet.fitPaths(pathRdd); //Check also that threshold algorithm was updated/averaged ThresholdAlgorithm taAfter = tm.getThresholdAlgorithm(); - assertTrue("Threshold algorithm should have been updated with different instance after averaging", ta != taAfter); + assertTrue(ta != taAfter, "Threshold algorithm should have been updated with different instance after averaging"); AdaptiveThresholdAlgorithm ataAfter = (AdaptiveThresholdAlgorithm) taAfter; assertFalse(Double.isNaN(ataAfter.getLastSparsity())); assertFalse(Double.isNaN(ataAfter.getLastThreshold())); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java index a48833e22..e00f8d6d3 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java @@ -30,8 +30,8 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; -import org.junit.After; -import org.junit.Before; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -61,7 +61,7 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable return 120000L; } - @Before + @BeforeEach public void before() { sc = getContext(); @@ -79,7 +79,7 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable sparkData = getBasicSparkDataSet(nRows, input, labels); } - @After + @AfterEach public void after() { if(sc != null) { sc.close(); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java index 7a038fabd..ed8de3623 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java @@ -44,7 +44,7 @@ import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.spark.earlystopping.SparkDataSetLossCalculator; import org.deeplearning4j.spark.earlystopping.SparkEarlyStoppingTrainer; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -58,7 +58,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestEarlyStoppingSpark extends BaseSparkTest { @@ -191,8 +191,8 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { long endTime = System.currentTimeMillis(); int durationSeconds = (int) (endTime - startTime) / 1000; - assertTrue("durationSeconds = " + durationSeconds, durationSeconds >= 3); - assertTrue("durationSeconds = " + durationSeconds, durationSeconds <= 20); + assertTrue(durationSeconds >= 3, "durationSeconds = " + durationSeconds); + assertTrue(durationSeconds <= 20, "durationSeconds = " + durationSeconds); assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, result.getTerminationReason()); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java index ac25bbc92..3de17a742 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java @@ -46,7 +46,7 @@ import org.deeplearning4j.spark.earlystopping.SparkEarlyStoppingGraphTrainer; import org.deeplearning4j.spark.earlystopping.SparkLossCalculatorComputationGraph; import org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -60,7 +60,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java index a041568fb..33023d605 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java @@ -30,7 +30,7 @@ import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex; import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.classification.*; import org.nd4j.evaluation.regression.RegressionEvaluation; @@ -47,7 +47,7 @@ import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestKryo extends BaseSparkKryoTest { @@ -56,7 +56,7 @@ public class TestKryo extends BaseSparkKryoTest { T deserialized = (T)si.deserialize(bb, null); boolean equals = in.equals(deserialized); - assertTrue(in.getClass() + "\t" + in.toString(), equals); + assertTrue(equals, in.getClass() + "\t" + in.toString()); } @Test diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/common/AddTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/common/AddTest.java index 384c89926..f366de5b4 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/common/AddTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/common/AddTest.java @@ -23,14 +23,14 @@ package org.deeplearning4j.spark.common; import org.apache.spark.api.java.JavaRDD; import org.deeplearning4j.spark.BaseSparkTest; import org.deeplearning4j.spark.impl.common.Add; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class AddTest extends BaseSparkTest { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/data/TestShuffleExamples.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/data/TestShuffleExamples.java index 36744425f..f879cfd29 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/data/TestShuffleExamples.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/data/TestShuffleExamples.java @@ -25,7 +25,7 @@ import org.apache.spark.Partitioner; import org.apache.spark.api.java.JavaRDD; import org.deeplearning4j.spark.BaseSparkTest; import org.deeplearning4j.spark.util.SparkUtils; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -35,8 +35,8 @@ import java.util.Arrays; import java.util.List; import java.util.Random; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestShuffleExamples extends BaseSparkTest { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/data/TestSparkDataUtils.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/data/TestSparkDataUtils.java index 927d14508..4e9f12dd9 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/data/TestSparkDataUtils.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/data/TestSparkDataUtils.java @@ -21,7 +21,7 @@ package org.deeplearning4j.spark.data; import org.deeplearning4j.spark.BaseSparkTest; -import org.junit.Test; +import org.junit.jupiter.api.Test; public class TestSparkDataUtils extends BaseSparkTest { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/MiniBatchTests.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/MiniBatchTests.java index fde796b86..43c50fdeb 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/MiniBatchTests.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/MiniBatchTests.java @@ -26,7 +26,7 @@ import org.datavec.api.conf.Configuration; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.impl.misc.SVMLightRecordReader; import org.deeplearning4j.spark.BaseSparkTest; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.common.io.ClassPathResource; import org.slf4j.Logger; @@ -34,8 +34,8 @@ import org.slf4j.LoggerFactory; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class MiniBatchTests extends BaseSparkTest { private static final Logger log = LoggerFactory.getLogger(MiniBatchTests.class); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestDataVecDataSetFunctions.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestDataVecDataSetFunctions.java index bebeaca56..e8153debc 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestDataVecDataSetFunctions.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestDataVecDataSetFunctions.java @@ -45,9 +45,10 @@ import org.datavec.spark.util.DataVecSparkUtil; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator; import org.deeplearning4j.spark.BaseSparkTest; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -60,22 +61,21 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestDataVecDataSetFunctions extends BaseSparkTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Test - public void testDataVecDataSetFunction() throws Exception { + public void testDataVecDataSetFunction(@TempDir Path testDir) throws Exception { if(Platform.isWindows()) { //Spark tests don't run on windows return; } JavaSparkContext sc = getContext(); - File f = testDir.newFolder(); + File f = testDir.toFile(); ClassPathResource cpr = new ClassPathResource("dl4j-spark/imagetest/"); cpr.copyDirectory(f); @@ -182,14 +182,14 @@ public class TestDataVecDataSetFunctions extends BaseSparkTest { } @Test - public void testDataVecSequenceDataSetFunction() throws Exception { + public void testDataVecSequenceDataSetFunction(@TempDir Path testDir) throws Exception { if(Platform.isWindows()) { //Spark tests don't run on windows return; } JavaSparkContext sc = getContext(); //Test Spark record reader functionality vs. local - File dir = testDir.newFolder(); + File dir = testDir.toFile(); ClassPathResource cpr = new ClassPathResource("dl4j-spark/csvsequence/"); cpr.copyDirectory(dir); @@ -244,14 +244,14 @@ public class TestDataVecDataSetFunctions extends BaseSparkTest { } @Test - public void testDataVecSequencePairDataSetFunction() throws Exception { + public void testDataVecSequencePairDataSetFunction(@TempDir Path testDir) throws Exception { if(Platform.isWindows()) { //Spark tests don't run on windows return; } JavaSparkContext sc = getContext(); - File f = testDir.newFolder(); + File f = new File(testDir.toFile(),"f"); ClassPathResource cpr = new ClassPathResource("dl4j-spark/csvsequence/"); cpr.copyDirectory(f); String path = f.getAbsolutePath() + "/*"; @@ -260,7 +260,7 @@ public class TestDataVecDataSetFunctions extends BaseSparkTest { JavaPairRDD toWrite = DataVecSparkUtil.combineFilesForSequenceFile(sc, path, path, pathConverter); - Path p = testDir.newFolder("dl4j_testSeqPairFn").toPath(); + Path p = new File(testDir.toFile(),"dl4j_testSeqPairFn").toPath(); p.toFile().deleteOnExit(); String outPath = p.toString() + "/out"; new File(outPath).deleteOnExit(); @@ -343,17 +343,17 @@ public class TestDataVecDataSetFunctions extends BaseSparkTest { } @Test - public void testDataVecSequencePairDataSetFunctionVariableLength() throws Exception { + public void testDataVecSequencePairDataSetFunctionVariableLength(@TempDir Path testDir) throws Exception { //Same sort of test as testDataVecSequencePairDataSetFunction() but with variable length time series (labels shorter, align end) if(Platform.isWindows()) { //Spark tests don't run on windows return; } - File dirFeatures = testDir.newFolder(); + File dirFeatures = new File(testDir.toFile(),"dirFeatures"); ClassPathResource cpr = new ClassPathResource("dl4j-spark/csvsequence/"); cpr.copyDirectory(dirFeatures); - File dirLabels = testDir.newFolder(); + File dirLabels = new File(testDir.toFile(),"dirLables"); ClassPathResource cpr2 = new ClassPathResource("dl4j-spark/csvsequencelabels/"); cpr2.copyDirectory(dirLabels); @@ -362,7 +362,7 @@ public class TestDataVecDataSetFunctions extends BaseSparkTest { JavaPairRDD toWrite = DataVecSparkUtil.combineFilesForSequenceFile(sc, dirFeatures.getAbsolutePath(), dirLabels.getAbsolutePath(), pathConverter); - Path p = testDir.newFolder("dl4j_testSeqPairFnVarLength").toPath(); + Path p = new File(testDir.toFile(),"dl4j_testSeqPairFnVarLength").toPath(); p.toFile().deleteOnExit(); String outPath = p.toFile().getAbsolutePath() + "/out"; new File(outPath).deleteOnExit(); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestExport.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestExport.java index 8c8cb3224..b9eef9113 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestExport.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestExport.java @@ -27,7 +27,7 @@ import org.apache.spark.api.java.JavaRDD; import org.deeplearning4j.spark.BaseSparkTest; import org.deeplearning4j.spark.data.BatchAndExportDataSetsFunction; import org.deeplearning4j.spark.data.BatchAndExportMultiDataSetsFunction; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; @@ -38,8 +38,8 @@ import java.util.Collections; import java.util.List; import java.util.Random; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; public class TestExport extends BaseSparkTest { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestPreProcessedData.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestPreProcessedData.java index 0ffe63a1f..714c3ffb6 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestPreProcessedData.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestPreProcessedData.java @@ -41,7 +41,7 @@ import org.deeplearning4j.spark.impl.graph.SparkComputationGraph; import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; import org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -56,8 +56,8 @@ import java.net.URI; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestPreProcessedData extends BaseSparkTest { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java index 32669c3e6..30ce34c6b 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java @@ -28,7 +28,7 @@ import org.datavec.api.writable.Writable; import org.datavec.spark.transform.misc.StringToWritablesFunction; import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator; import org.deeplearning4j.spark.BaseSparkTest; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.common.io.ClassPathResource; @@ -36,7 +36,7 @@ import org.nd4j.common.io.ClassPathResource; import java.io.File; import java.util.*; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestIteratorUtils extends BaseSparkTest { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java index 43c7c6f4d..9bc828abb 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java @@ -30,8 +30,8 @@ import org.deeplearning4j.spark.api.TrainingMaster; import org.deeplearning4j.spark.impl.graph.SparkComputationGraph; import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; public class TestKryoWarning { @@ -70,7 +70,7 @@ public class TestKryoWarning { } @Test - @Ignore + @Disabled public void testKryoMessageMLNIncorrectConfig() { //Should print warning message SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest") @@ -81,7 +81,7 @@ public class TestKryoWarning { } @Test - @Ignore + @Disabled public void testKryoMessageMLNCorrectConfigKryo() { //Should NOT print warning message SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest") @@ -93,7 +93,7 @@ public class TestKryoWarning { } @Test - @Ignore + @Disabled public void testKryoMessageMLNCorrectConfigNoKryo() { //Should NOT print warning message SparkConf sparkConf = new SparkConf().setMaster("local[*]") @@ -106,7 +106,7 @@ public class TestKryoWarning { @Test - @Ignore + @Disabled public void testKryoMessageCGIncorrectConfig() { //Should print warning message SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest") @@ -117,7 +117,7 @@ public class TestKryoWarning { } @Test - @Ignore + @Disabled public void testKryoMessageCGCorrectConfigKryo() { //Should NOT print warning message SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest") @@ -129,7 +129,7 @@ public class TestKryoWarning { } @Test - @Ignore + @Disabled public void testKryoMessageCGCorrectConfigNoKryo() { //Should NOT print warning message SparkConf sparkConf = new SparkConf().setMaster("local[*]") diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/BalancedPartitionerTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/BalancedPartitionerTest.java index 70a3ed4b8..1bab0defe 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/BalancedPartitionerTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/BalancedPartitionerTest.java @@ -20,9 +20,9 @@ package org.deeplearning4j.spark.impl.common.repartition; -import org.junit.Test; +import org.junit.jupiter.api.Test; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class BalancedPartitionerTest { @@ -34,7 +34,7 @@ public class BalancedPartitionerTest { // the 10 first elements should go in the 1st partition for (int i = 0; i < 10; i++) { int p = bp.getPartition(i); - assertEquals("Found wrong partition output " + p + ", not 0", 0, p); + assertEquals(0, p,"Found wrong partition output " + p + ", not 0"); } } @@ -44,7 +44,7 @@ public class BalancedPartitionerTest { // the 10 first elements should go in the 1st partition for (int i = 0; i < 10; i++) { int p = bp.getPartition(i); - assertEquals("Found wrong partition output " + p + ", not 0", 0, p); + assertEquals( 0, p,"Found wrong partition output " + p + ", not 0"); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitionerTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitionerTest.java index 7a87c3868..74e8f03be 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitionerTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitionerTest.java @@ -27,12 +27,12 @@ import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; import org.deeplearning4j.spark.BaseSparkTest; import org.deeplearning4j.spark.impl.common.repartition.HashingBalancedPartitioner.LinearCongruentialGenerator; -import org.junit.Test; +import org.junit.jupiter.api.Test; import scala.Tuple2; import java.util.*; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class HashingBalancedPartitionerTest extends BaseSparkTest { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/TestCustomLayer.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/TestCustomLayer.java index ae89e44b3..b3c96333d 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/TestCustomLayer.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/TestCustomLayer.java @@ -30,7 +30,7 @@ import org.deeplearning4j.spark.BaseSparkTest; import org.deeplearning4j.spark.impl.customlayer.layer.CustomLayer; import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java index 9bd944c8d..f003b5171 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java @@ -47,8 +47,9 @@ import org.deeplearning4j.spark.api.RDDTrainingApproach; import org.deeplearning4j.spark.api.Repartition; import org.deeplearning4j.spark.api.TrainingMaster; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.ROC; @@ -69,9 +70,9 @@ import scala.Tuple2; import java.util.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; -@Ignore("AB 2019/05/24 - Rarely getting stuck on CI - see issue #7657") +@Disabled("AB 2019/05/24 - Rarely getting stuck on CI - see issue #7657") public class TestSparkComputationGraph extends BaseSparkTest { public static ComputationGraph getBasicNetIris2Class() { @@ -213,7 +214,7 @@ public class TestSparkComputationGraph extends BaseSparkTest { } } - @Ignore("AB 2019/05/23 - Failing on CI only - passing locally. Possible precision or threading issue") + @Disabled("AB 2019/05/23 - Failing on CI only - passing locally. Possible precision or threading issue") public void testSeedRepeatability() throws Exception { ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(Updater.RMSPROP) @@ -281,12 +282,13 @@ public class TestSparkComputationGraph extends BaseSparkTest { boolean eq1 = p1.equalsWithEps(p2, 0.01); boolean eq2 = p1.equalsWithEps(p3, 0.01); - assertTrue("Model 1 and 2 params should be equal", eq1); - assertFalse("Model 1 and 3 params shoud be different", eq2); + assertTrue(eq1, "Model 1 and 2 params should be equal"); + assertFalse(eq2, "Model 1 and 3 params shoud be different"); } - @Test(timeout = 60000L) + @Test() + @Timeout(60000L) public void testEvaluationAndRoc() { for( int evalWorkers : new int[]{1, 4, 8}) { DataSetIterator iter = new IrisDataSetIterator(5, 150); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java index 05eadfc2d..887696af3 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java @@ -33,7 +33,7 @@ import org.deeplearning4j.spark.api.RDDTrainingApproach; import org.deeplearning4j.spark.impl.graph.SparkComputationGraph; import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -46,7 +46,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestFrozenLayers extends BaseSparkTest { @@ -119,10 +119,10 @@ public class TestFrozenLayers extends BaseSparkTest { if (isFrozen) { //Layer should be frozen -> no change - assertEquals(entry.getKey(), orig, now); + assertEquals(orig, now, entry.getKey()); } else { //Not frozen -> should be different - assertNotEquals(entry.getKey(), orig, now); + assertNotEquals(orig, now, entry.getKey()); } } } @@ -196,10 +196,10 @@ public class TestFrozenLayers extends BaseSparkTest { if (isFrozen) { //Layer should be frozen -> no change - assertEquals(entry.getKey(), orig, now); + assertEquals(orig, now, entry.getKey()); } else { //Not frozen -> should be different - assertNotEquals(entry.getKey(), orig, now); + assertNotEquals(orig, now, entry.getKey()); } } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java index 5881b5f41..550ccc9b2 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java @@ -36,7 +36,7 @@ import org.deeplearning4j.spark.BaseSparkTest; import org.deeplearning4j.spark.impl.graph.SparkComputationGraph; import org.deeplearning4j.spark.impl.multilayer.scoring.VaeReconstructionErrorWithKeyFunction; import org.deeplearning4j.spark.impl.multilayer.scoring.VaeReconstructionProbWithKeyFunction; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; @@ -49,8 +49,8 @@ import scala.Tuple2; import java.util.*; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestMiscFunctions extends BaseSparkTest { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java index 19a024d49..c64618557 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java @@ -33,7 +33,7 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.spark.BaseSparkTest; import org.deeplearning4j.spark.api.TrainingMaster; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -47,8 +47,8 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.ArrayList; import java.util.List; -import static junit.framework.TestCase.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j public class TestSparkDl4jMultiLayer extends BaseSparkTest { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java index 673ff05c4..cbe7247bd 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java @@ -39,8 +39,8 @@ import org.deeplearning4j.spark.api.RDDTrainingApproach; import org.deeplearning4j.spark.api.TrainingMaster; import org.deeplearning4j.spark.impl.graph.SparkComputationGraph; import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -54,10 +54,10 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestCompareParameterAveragingSparkVsSingleMachine { - @Before + @BeforeEach public void setUp() { //CudaEnvironment.getInstance().getConfiguration().allowMultiGPU(false); } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestJsonYaml.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestJsonYaml.java index 3cd2056f5..64c984ad7 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestJsonYaml.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestJsonYaml.java @@ -22,9 +22,9 @@ package org.deeplearning4j.spark.impl.paramavg; import org.apache.spark.storage.StorageLevel; import org.deeplearning4j.spark.api.TrainingMaster; -import org.junit.Test; +import org.junit.jupiter.api.Test; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestJsonYaml { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java index a266b9809..d4a73020f 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java @@ -55,10 +55,12 @@ import org.deeplearning4j.spark.impl.graph.SparkComputationGraph; import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; import org.deeplearning4j.spark.stats.EventStats; import org.deeplearning4j.spark.stats.ExampleCountEventStats; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.ROC; import org.nd4j.evaluation.classification.ROCMultiClass; @@ -81,7 +83,7 @@ import java.io.File; import java.nio.file.Path; import java.util.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { @@ -93,8 +95,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { } } - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Override @@ -427,12 +428,12 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { @Test - public void testFitViaStringPaths() throws Exception { + public void testFitViaStringPaths(@TempDir Path testDir) throws Exception { if(Platform.isWindows()) { //Spark tests don't run on windows return; } - Path tempDir = testDir.newFolder("DL4J-testFitViaStringPaths").toPath(); + Path tempDir = new File(testDir.toFile(),"DL4J-testFitViaStringPaths").toPath(); File tempDirF = tempDir.toFile(); tempDirF.deleteOnExit(); @@ -494,12 +495,12 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { } @Test - public void testFitViaStringPathsSize1() throws Exception { + public void testFitViaStringPathsSize1(@TempDir Path testDir) throws Exception { if(Platform.isWindows()) { //Spark tests don't run on windows return; } - Path tempDir = testDir.newFolder("DL4J-testFitViaStringPathsSize1").toPath(); + Path tempDir = new File(testDir.toFile(),"DL4J-testFitViaStringPathsSize1").toPath(); File tempDirF = tempDir.toFile(); tempDirF.deleteOnExit(); @@ -578,13 +579,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { @Test - public void testFitViaStringPathsCompGraph() throws Exception { + public void testFitViaStringPathsCompGraph(@TempDir Path testDir) throws Exception { if(Platform.isWindows()) { //Spark tests don't run on windows return; } - Path tempDir = testDir.newFolder("DL4J-testFitViaStringPathsCG").toPath(); - Path tempDir2 = testDir.newFolder("DL4J-testFitViaStringPathsCG-MDS").toPath(); + Path tempDir = new File(testDir.toFile(),"DL4J-testFitViaStringPathsCG").toPath(); + Path tempDir2 = new File(testDir.toFile(),"DL4J-testFitViaStringPathsCG-MDS").toPath(); File tempDirF = tempDir.toFile(); File tempDirF2 = tempDir2.toFile(); tempDirF.deleteOnExit(); @@ -676,7 +677,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { @Test - @Ignore("AB 2019/05/23 - Failing on CI only - passing locally. Possible precision or threading issue") + @Disabled("AB 2019/05/23 - Failing on CI only - passing locally. Possible precision or threading issue") public void testSeedRepeatability() throws Exception { if(Platform.isWindows()) { //Spark tests don't run on windows @@ -746,8 +747,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { boolean eq1 = p1.equalsWithEps(p2, 0.01); boolean eq2 = p1.equalsWithEps(p3, 0.01); - assertTrue("Model 1 and 2 params should be equal", eq1); - assertFalse("Model 1 and 3 params shoud be different", eq2); + assertTrue(eq1, "Model 1 and 2 params should be equal"); + assertFalse(eq2, "Model 1 and 3 params shoud be different"); } @@ -852,7 +853,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { @Test - @Ignore //Ignored 2019/04/09 - low priority: https://github.com/eclipse/deeplearning4j/issues/6656 + @Disabled //Ignored 2019/04/09 - low priority: https://github.com/eclipse/deeplearning4j/issues/6656 public void testVaePretrainSimple() { //Simple sanity check on pretraining int nIn = 8; @@ -888,7 +889,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { } @Test - @Ignore //Ignored 2019/04/09 - low priority: https://github.com/eclipse/deeplearning4j/issues/6656 + @Disabled //Ignored 2019/04/09 - low priority: https://github.com/eclipse/deeplearning4j/issues/6656 public void testVaePretrainSimpleCG() { //Simple sanity check on pretraining int nIn = 8; @@ -1036,7 +1037,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { } - @Test(timeout = 120000L) + @Test() + @Timeout(120000) public void testEpochCounter() throws Exception { if(Platform.isWindows()) { //Spark tests don't run on windows diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/util/ExportSupportTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/util/ExportSupportTest.java index 20ceb2540..0fdeaaabf 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/util/ExportSupportTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/util/ExportSupportTest.java @@ -24,14 +24,14 @@ import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; /** * @author Ede Meijer diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java index 78ab9a229..a287eb836 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java @@ -40,7 +40,7 @@ import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingMa import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingWorkerStats; import org.deeplearning4j.spark.stats.EventStats; import org.deeplearning4j.spark.stats.StatsUtils; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -49,9 +49,7 @@ import java.io.ByteArrayOutputStream; import java.lang.reflect.Field; import java.util.*; -import static junit.framework.TestCase.assertNotNull; -import static junit.framework.TestCase.assertTrue; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestTrainingStatsCollection extends BaseSparkTest { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/time/TestTimeSource.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/time/TestTimeSource.java index 62462b802..85a73aab4 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/time/TestTimeSource.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/time/TestTimeSource.java @@ -20,10 +20,10 @@ package org.deeplearning4j.spark.time; -import org.junit.Test; +import org.junit.jupiter.api.Test; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestTimeSource { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java index dc7b64a68..6f79d7595 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java @@ -38,7 +38,7 @@ import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; import org.deeplearning4j.ui.model.stats.StatsListener; import org.deeplearning4j.ui.model.storage.mapdb.MapDBStatsStorage; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.lossfunctions.LossFunctions; @@ -46,8 +46,8 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.Collections; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestListeners extends BaseSparkTest { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/MLLIbUtilTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/MLLIbUtilTest.java index 1aacc222d..ef7c0788f 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/MLLIbUtilTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/MLLIbUtilTest.java @@ -27,7 +27,7 @@ import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.regression.LabeledPoint; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.spark.BaseSparkTest; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -38,7 +38,7 @@ import java.util.Arrays; import java.util.List; import java.util.Random; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class MLLIbUtilTest extends BaseSparkTest { private static final Logger log = LoggerFactory.getLogger(MLLIbUtilTest.class); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java index 8bc9d442c..75653f9ad 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java @@ -29,8 +29,7 @@ import org.deeplearning4j.spark.api.Repartition; import org.deeplearning4j.spark.api.RepartitionStrategy; import org.deeplearning4j.spark.impl.common.CountPartitionsFunction; import org.deeplearning4j.spark.impl.repartitioner.DefaultRepartitioner; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import scala.Tuple2; import java.util.ArrayList; @@ -38,9 +37,7 @@ import java.util.Arrays; import java.util.List; import java.util.Random; -import static junit.framework.TestCase.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.*; public class TestRepartitioning extends BaseSparkTest { @@ -192,7 +189,7 @@ public class TestRepartitioning extends BaseSparkTest { new Tuple2<>(4,34), new Tuple2<>(5,35), new Tuple2<>(6,34)); - Assert.assertEquals(initialExpected, partitionCounts); + assertEquals(initialExpected, partitionCounts); JavaRDD afterRepartition = SparkUtils.repartitionBalanceIfRequired(initial.values(), Repartition.Always, 2, 112); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestValidation.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestValidation.java index d5a81d0ef..3ff88365e 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestValidation.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestValidation.java @@ -25,33 +25,32 @@ import org.apache.commons.io.FileUtils; import org.deeplearning4j.spark.BaseSparkTest; import org.deeplearning4j.spark.util.data.SparkDataValidation; import org.deeplearning4j.spark.util.data.ValidationResult; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; import java.io.File; +import java.nio.file.Path; import java.util.Arrays; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestValidation extends BaseSparkTest { - @Rule - public TemporaryFolder folder = new TemporaryFolder(); - @Test - public void testDataSetValidation() throws Exception { + public void testDataSetValidation(@TempDir Path folder) throws Exception { if(Platform.isWindows()) { //Spark tests don't run on windows return; } - File f = folder.newFolder(); + File f = folder.toFile(); for( int i = 0; i < 3; i++ ) { DataSet ds = new DataSet(Nd4j.create(1,10), Nd4j.create(1,10)); @@ -113,12 +112,12 @@ public class TestValidation extends BaseSparkTest { } @Test - public void testMultiDataSetValidation() throws Exception { + public void testMultiDataSetValidation(@TempDir Path folder) throws Exception { if(Platform.isWindows()) { //Spark tests don't run on windows return; } - File f = folder.newFolder(); + File f = folder.toFile(); for( int i = 0; i < 3; i++ ) { MultiDataSet ds = new MultiDataSet(Nd4j.create(1,10), Nd4j.create(1,10)); diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/src/test/java/org/deeplearning4j/ui/TestComponentSerialization.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/src/test/java/org/deeplearning4j/ui/TestComponentSerialization.java index 3c1b34e15..6596c526c 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/src/test/java/org/deeplearning4j/ui/TestComponentSerialization.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/src/test/java/org/deeplearning4j/ui/TestComponentSerialization.java @@ -34,14 +34,14 @@ import org.deeplearning4j.ui.components.table.ComponentTable; import org.deeplearning4j.ui.components.table.style.StyleTable; import org.deeplearning4j.ui.components.text.ComponentText; import org.deeplearning4j.ui.components.text.style.StyleText; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.shade.jackson.databind.ObjectMapper; import java.awt.*; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestComponentSerialization extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/src/test/java/org/deeplearning4j/ui/TestRendering.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/src/test/java/org/deeplearning4j/ui/TestRendering.java index c3ad694aa..b72c0b4e2 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/src/test/java/org/deeplearning4j/ui/TestRendering.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/src/test/java/org/deeplearning4j/ui/TestRendering.java @@ -35,8 +35,8 @@ import org.deeplearning4j.ui.components.table.ComponentTable; import org.deeplearning4j.ui.components.table.style.StyleTable; import org.deeplearning4j.ui.components.text.ComponentText; import org.deeplearning4j.ui.components.text.style.StyleText; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.shade.jackson.databind.ObjectMapper; import java.awt.*; @@ -48,7 +48,7 @@ import java.util.Random; public class TestRendering extends BaseDL4JTest { - @Ignore + @Disabled @Test public void test() throws Exception { diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/src/test/java/org/deeplearning4j/ui/TestStandAlone.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/src/test/java/org/deeplearning4j/ui/TestStandAlone.java index 95bd2df96..2073d806f 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/src/test/java/org/deeplearning4j/ui/TestStandAlone.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/src/test/java/org/deeplearning4j/ui/TestStandAlone.java @@ -28,8 +28,8 @@ import org.deeplearning4j.ui.components.chart.style.StyleChart; import org.deeplearning4j.ui.components.table.ComponentTable; import org.deeplearning4j.ui.components.table.style.StyleTable; import org.deeplearning4j.ui.standalone.StaticPageUtil; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import java.awt.*; diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/TestStorageMetaData.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/TestStorageMetaData.java index 1e3a17056..e31fc1541 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/TestStorageMetaData.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/TestStorageMetaData.java @@ -23,11 +23,11 @@ package org.deeplearning4j.ui; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.core.storage.StorageMetaData; import org.deeplearning4j.ui.model.storage.impl.SbeStorageMetaData; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.io.Serializable; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestStorageMetaData extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestStatsClasses.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestStatsClasses.java index ec775b4d5..0d9f60a24 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestStatsClasses.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestStatsClasses.java @@ -25,8 +25,7 @@ import org.deeplearning4j.ui.model.stats.api.*; import org.deeplearning4j.ui.model.stats.impl.SbeStatsInitializationReport; import org.deeplearning4j.ui.model.stats.impl.SbeStatsReport; import org.deeplearning4j.ui.model.stats.impl.java.JavaStatsInitializationReport; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.primitives.Pair; import java.io.*; @@ -35,7 +34,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestStatsClasses extends BaseDL4JTest { @@ -547,9 +546,9 @@ public class TestStatsClasses extends BaseDL4JTest { assertEquals(perfTotalMB, report2.getTotalMinibatches()); assertEquals(perfEPS, report2.getExamplesPerSecond(), 0.0); assertEquals(perfMBPS, report2.getMinibatchesPerSecond(), 0.0); - Assert.assertTrue(report2.hasPerformance()); + assertTrue(report2.hasPerformance()); } else { - Assert.assertFalse(report2.hasPerformance()); + assertFalse(report2.hasPerformance()); } if (collectMemoryStats) { @@ -560,30 +559,30 @@ public class TestStatsClasses extends BaseDL4JTest { assertArrayEquals(memDC, report2.getDeviceCurrentBytes()); assertArrayEquals(memDM, report2.getDeviceMaxBytes()); - Assert.assertTrue(report2.hasMemoryUse()); + assertTrue(report2.hasMemoryUse()); } else { - Assert.assertFalse(report2.hasMemoryUse()); + assertFalse(report2.hasMemoryUse()); } if (collectGCStats) { List> gcs = report2.getGarbageCollectionStats(); - Assert.assertEquals(2, gcs.size()); - Assert.assertEquals(gc1Name, gcs.get(0).getFirst()); - Assert.assertArrayEquals(new int[] {gcdc1, gcdt1}, + assertEquals(2, gcs.size()); + assertEquals(gc1Name, gcs.get(0).getFirst()); + assertArrayEquals(new int[] {gcdc1, gcdt1}, gcs.get(0).getSecond()); - Assert.assertEquals(gc2Name, gcs.get(1).getFirst()); - Assert.assertArrayEquals(new int[] {gcdc2, gcdt2}, + assertEquals(gc2Name, gcs.get(1).getFirst()); + assertArrayEquals(new int[] {gcdc2, gcdt2}, gcs.get(1).getSecond()); - Assert.assertTrue(report2.hasGarbageCollection()); + assertTrue(report2.hasGarbageCollection()); } else { - Assert.assertFalse(report2.hasGarbageCollection()); + assertFalse(report2.hasGarbageCollection()); } if (collectScore) { assertEquals(score, report2.getScore(), 0.0); - Assert.assertTrue(report2.hasScore()); + assertTrue(report2.hasScore()); } else { - Assert.assertFalse(report2.hasScore()); + assertFalse(report2.hasScore()); } if (collectLearningRates) { @@ -592,9 +591,9 @@ public class TestStatsClasses extends BaseDL4JTest { assertEquals(lrByParam.get(s), report2.getLearningRates().get(s), 1e-6); } - Assert.assertTrue(report2.hasLearningRates()); + assertTrue(report2.hasLearningRates()); } else { - Assert.assertFalse(report2.hasLearningRates()); + assertFalse(report2.hasLearningRates()); } if (collectMetaData) { @@ -609,112 +608,112 @@ public class TestStatsClasses extends BaseDL4JTest { if (collectHistograms[0]) { assertEquals(pHist, report2.getHistograms(StatsType.Parameters)); - Assert.assertTrue(report2.hasHistograms(StatsType.Parameters)); + assertTrue(report2.hasHistograms(StatsType.Parameters)); } else { - Assert.assertFalse(report2.hasHistograms(StatsType.Parameters)); + assertFalse(report2.hasHistograms(StatsType.Parameters)); } if (collectHistograms[1]) { assertEquals(gHist, report2.getHistograms(StatsType.Gradients)); - Assert.assertTrue(report2.hasHistograms(StatsType.Gradients)); + assertTrue(report2.hasHistograms(StatsType.Gradients)); } else { - Assert.assertFalse(report2.hasHistograms(StatsType.Gradients)); + assertFalse(report2.hasHistograms(StatsType.Gradients)); } if (collectHistograms[2]) { assertEquals(uHist, report2.getHistograms(StatsType.Updates)); - Assert.assertTrue(report2.hasHistograms(StatsType.Updates)); + assertTrue(report2.hasHistograms(StatsType.Updates)); } else { - Assert.assertFalse(report2.hasHistograms(StatsType.Updates)); + assertFalse(report2.hasHistograms(StatsType.Updates)); } if (collectHistograms[3]) { assertEquals(aHist, report2.getHistograms(StatsType.Activations)); - Assert.assertTrue(report2.hasHistograms(StatsType.Activations)); + assertTrue(report2.hasHistograms(StatsType.Activations)); } else { - Assert.assertFalse(report2.hasHistograms(StatsType.Activations)); + assertFalse(report2.hasHistograms(StatsType.Activations)); } if (collectMeanStdev[0]) { assertEquals(pMean, report2.getMean(StatsType.Parameters)); assertEquals(pStd, report2.getStdev(StatsType.Parameters)); - Assert.assertTrue(report2.hasSummaryStats(StatsType.Parameters, + assertTrue(report2.hasSummaryStats(StatsType.Parameters, SummaryType.Mean)); - Assert.assertTrue(report2.hasSummaryStats(StatsType.Parameters, + assertTrue(report2.hasSummaryStats(StatsType.Parameters, SummaryType.Stdev)); } else { - Assert.assertFalse(report2.hasSummaryStats(StatsType.Parameters, + assertFalse(report2.hasSummaryStats(StatsType.Parameters, SummaryType.Mean)); - Assert.assertFalse(report2.hasSummaryStats(StatsType.Parameters, + assertFalse(report2.hasSummaryStats(StatsType.Parameters, SummaryType.Stdev)); } if (collectMeanStdev[1]) { assertEquals(gMean, report2.getMean(StatsType.Gradients)); assertEquals(gStd, report2.getStdev(StatsType.Gradients)); - Assert.assertTrue(report2.hasSummaryStats(StatsType.Gradients, + assertTrue(report2.hasSummaryStats(StatsType.Gradients, SummaryType.Mean)); - Assert.assertTrue(report2.hasSummaryStats(StatsType.Gradients, + assertTrue(report2.hasSummaryStats(StatsType.Gradients, SummaryType.Stdev)); } else { - Assert.assertFalse(report2.hasSummaryStats(StatsType.Gradients, + assertFalse(report2.hasSummaryStats(StatsType.Gradients, SummaryType.Mean)); - Assert.assertFalse(report2.hasSummaryStats(StatsType.Gradients, + assertFalse(report2.hasSummaryStats(StatsType.Gradients, SummaryType.Stdev)); } if (collectMeanStdev[2]) { assertEquals(uMean, report2.getMean(StatsType.Updates)); assertEquals(uStd, report2.getStdev(StatsType.Updates)); - Assert.assertTrue(report2.hasSummaryStats(StatsType.Updates, + assertTrue(report2.hasSummaryStats(StatsType.Updates, SummaryType.Mean)); - Assert.assertTrue(report2.hasSummaryStats(StatsType.Updates, + assertTrue(report2.hasSummaryStats(StatsType.Updates, SummaryType.Stdev)); } else { - Assert.assertFalse(report2.hasSummaryStats(StatsType.Updates, + assertFalse(report2.hasSummaryStats(StatsType.Updates, SummaryType.Mean)); - Assert.assertFalse(report2.hasSummaryStats(StatsType.Updates, + assertFalse(report2.hasSummaryStats(StatsType.Updates, SummaryType.Stdev)); } if (collectMeanStdev[3]) { assertEquals(aMean, report2.getMean(StatsType.Activations)); assertEquals(aStd, report2.getStdev(StatsType.Activations)); - Assert.assertTrue(report2.hasSummaryStats(StatsType.Activations, + assertTrue(report2.hasSummaryStats(StatsType.Activations, SummaryType.Mean)); - Assert.assertTrue(report2.hasSummaryStats(StatsType.Activations, + assertTrue(report2.hasSummaryStats(StatsType.Activations, SummaryType.Stdev)); } else { - Assert.assertFalse(report2.hasSummaryStats(StatsType.Activations, + assertFalse(report2.hasSummaryStats(StatsType.Activations, SummaryType.Mean)); - Assert.assertFalse(report2.hasSummaryStats(StatsType.Activations, + assertFalse(report2.hasSummaryStats(StatsType.Activations, SummaryType.Stdev)); } if (collectMM[0]) { assertEquals(pMM, report2.getMeanMagnitudes(StatsType.Parameters)); - Assert.assertTrue(report2.hasSummaryStats(StatsType.Parameters, + assertTrue(report2.hasSummaryStats(StatsType.Parameters, SummaryType.MeanMagnitudes)); } else { - Assert.assertFalse(report2.hasSummaryStats(StatsType.Parameters, + assertFalse(report2.hasSummaryStats(StatsType.Parameters, SummaryType.MeanMagnitudes)); } if (collectMM[1]) { assertEquals(gMM, report2.getMeanMagnitudes(StatsType.Gradients)); - Assert.assertTrue(report2.hasSummaryStats(StatsType.Gradients, + assertTrue(report2.hasSummaryStats(StatsType.Gradients, SummaryType.MeanMagnitudes)); } else { - Assert.assertFalse(report2.hasSummaryStats(StatsType.Gradients, + assertFalse(report2.hasSummaryStats(StatsType.Gradients, SummaryType.MeanMagnitudes)); } if (collectMM[2]) { assertEquals(uMM, report2.getMeanMagnitudes(StatsType.Updates)); - Assert.assertTrue(report2.hasSummaryStats(StatsType.Updates, + assertTrue(report2.hasSummaryStats(StatsType.Updates, SummaryType.MeanMagnitudes)); } else { - Assert.assertFalse(report2.hasSummaryStats(StatsType.Updates, + assertFalse(report2.hasSummaryStats(StatsType.Updates, SummaryType.MeanMagnitudes)); } if (collectMM[3]) { assertEquals(aMM, report2.getMeanMagnitudes(StatsType.Activations)); - Assert.assertTrue(report2.hasSummaryStats(StatsType.Activations, + assertTrue(report2.hasSummaryStats(StatsType.Activations, SummaryType.MeanMagnitudes)); } else { - Assert.assertFalse(report2.hasSummaryStats(StatsType.Activations, + assertFalse(report2.hasSummaryStats(StatsType.Activations, SummaryType.MeanMagnitudes)); } @@ -742,7 +741,7 @@ public class TestStatsClasses extends BaseDL4JTest { } } - Assert.assertEquals(13824, testCount); + assertEquals(13824, testCount); } @Test @@ -903,9 +902,9 @@ public class TestStatsClasses extends BaseDL4JTest { assertEquals(perfTotalMB, report2.getTotalMinibatches()); assertEquals(perfEPS, report2.getExamplesPerSecond(), 0.0); assertEquals(perfMBPS, report2.getMinibatchesPerSecond(), 0.0); - Assert.assertTrue(report2.hasPerformance()); + assertTrue(report2.hasPerformance()); } else { - Assert.assertFalse(report2.hasPerformance()); + assertFalse(report2.hasPerformance()); } if (collectMemoryStats) { @@ -916,23 +915,23 @@ public class TestStatsClasses extends BaseDL4JTest { assertArrayEquals(memDC, report2.getDeviceCurrentBytes()); assertArrayEquals(memDM, report2.getDeviceMaxBytes()); - Assert.assertTrue(report2.hasMemoryUse()); + assertTrue(report2.hasMemoryUse()); } else { - Assert.assertFalse(report2.hasMemoryUse()); + assertFalse(report2.hasMemoryUse()); } if (collectGCStats) { List> gcs = report2.getGarbageCollectionStats(); - Assert.assertEquals(2, gcs.size()); + assertEquals(2, gcs.size()); assertNullOrZeroLength(gcs.get(0).getFirst()); - Assert.assertArrayEquals(new int[] {gcdc1, gcdt1}, + assertArrayEquals(new int[] {gcdc1, gcdt1}, gcs.get(0).getSecond()); assertNullOrZeroLength(gcs.get(1).getFirst()); - Assert.assertArrayEquals(new int[] {gcdc2, gcdt2}, + assertArrayEquals(new int[] {gcdc2, gcdt2}, gcs.get(1).getSecond()); - Assert.assertTrue(report2.hasGarbageCollection()); + assertTrue(report2.hasGarbageCollection()); } else { - Assert.assertFalse(report2.hasGarbageCollection()); + assertFalse(report2.hasGarbageCollection()); } if (collectDataSetMetaData) { @@ -941,71 +940,71 @@ public class TestStatsClasses extends BaseDL4JTest { if (collectScore) { assertEquals(score, report2.getScore(), 0.0); - Assert.assertTrue(report2.hasScore()); + assertTrue(report2.hasScore()); } else { - Assert.assertFalse(report2.hasScore()); + assertFalse(report2.hasScore()); } if (collectLearningRates) { assertNull(report2.getLearningRates()); } else { - Assert.assertFalse(report2.hasLearningRates()); + assertFalse(report2.hasLearningRates()); } assertNull(report2.getHistograms(StatsType.Parameters)); - Assert.assertFalse(report2.hasHistograms(StatsType.Parameters)); + assertFalse(report2.hasHistograms(StatsType.Parameters)); assertNull(report2.getHistograms(StatsType.Gradients)); - Assert.assertFalse(report2.hasHistograms(StatsType.Gradients)); + assertFalse(report2.hasHistograms(StatsType.Gradients)); assertNull(report2.getHistograms(StatsType.Updates)); - Assert.assertFalse(report2.hasHistograms(StatsType.Updates)); + assertFalse(report2.hasHistograms(StatsType.Updates)); assertNull(report2.getHistograms(StatsType.Activations)); - Assert.assertFalse(report2.hasHistograms(StatsType.Activations)); + assertFalse(report2.hasHistograms(StatsType.Activations)); assertNull(report2.getMean(StatsType.Parameters)); assertNull(report2.getStdev(StatsType.Parameters)); - Assert.assertFalse(report2.hasSummaryStats(StatsType.Parameters, + assertFalse(report2.hasSummaryStats(StatsType.Parameters, SummaryType.Mean)); - Assert.assertFalse(report2.hasSummaryStats(StatsType.Parameters, + assertFalse(report2.hasSummaryStats(StatsType.Parameters, SummaryType.Stdev)); assertNull(report2.getMean(StatsType.Gradients)); assertNull(report2.getStdev(StatsType.Gradients)); - Assert.assertFalse(report2.hasSummaryStats(StatsType.Gradients, + assertFalse(report2.hasSummaryStats(StatsType.Gradients, SummaryType.Mean)); - Assert.assertFalse(report2.hasSummaryStats(StatsType.Gradients, + assertFalse(report2.hasSummaryStats(StatsType.Gradients, SummaryType.Stdev)); assertNull(report2.getMean(StatsType.Updates)); assertNull(report2.getStdev(StatsType.Updates)); - Assert.assertFalse(report2.hasSummaryStats(StatsType.Updates, + assertFalse(report2.hasSummaryStats(StatsType.Updates, SummaryType.Mean)); - Assert.assertFalse(report2.hasSummaryStats(StatsType.Updates, + assertFalse(report2.hasSummaryStats(StatsType.Updates, SummaryType.Stdev)); assertNull(report2.getMean(StatsType.Activations)); assertNull(report2.getStdev(StatsType.Activations)); - Assert.assertFalse(report2.hasSummaryStats(StatsType.Activations, + assertFalse(report2.hasSummaryStats(StatsType.Activations, SummaryType.Mean)); - Assert.assertFalse(report2.hasSummaryStats(StatsType.Activations, + assertFalse(report2.hasSummaryStats(StatsType.Activations, SummaryType.Stdev)); assertNull(report2.getMeanMagnitudes(StatsType.Parameters)); - Assert.assertFalse(report2.hasSummaryStats(StatsType.Parameters, + assertFalse(report2.hasSummaryStats(StatsType.Parameters, SummaryType.MeanMagnitudes)); assertNull(report2.getMeanMagnitudes(StatsType.Gradients)); - Assert.assertFalse(report2.hasSummaryStats(StatsType.Gradients, + assertFalse(report2.hasSummaryStats(StatsType.Gradients, SummaryType.MeanMagnitudes)); assertNull(report2.getMeanMagnitudes(StatsType.Updates)); - Assert.assertFalse(report2.hasSummaryStats(StatsType.Updates, + assertFalse(report2.hasSummaryStats(StatsType.Updates, SummaryType.MeanMagnitudes)); assertNull(report2.getMeanMagnitudes(StatsType.Activations)); - Assert.assertFalse(report2.hasSummaryStats(StatsType.Activations, + assertFalse(report2.hasSummaryStats(StatsType.Activations, SummaryType.MeanMagnitudes)); //Check standard Java serialization @@ -1032,7 +1031,7 @@ public class TestStatsClasses extends BaseDL4JTest { } } - Assert.assertEquals(13824, testCount); + assertEquals(13824, testCount); } } diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestStatsListener.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestStatsListener.java index 61b659d53..56952d870 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestStatsListener.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestStatsListener.java @@ -32,15 +32,15 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.ui.model.stats.J7StatsListener; import org.deeplearning4j.ui.model.stats.StatsListener; import org.deeplearning4j.ui.model.storage.mapdb.MapDBStatsStorage; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; public class TestStatsListener extends BaseDL4JTest { diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java index 3117ff835..b32c4f143 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java @@ -31,9 +31,9 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.ui.model.stats.StatsListener; import org.deeplearning4j.ui.model.storage.FileStatsStorage; import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/storage/TestStatsStorage.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/storage/TestStatsStorage.java index 8a10300c5..18d819a80 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/storage/TestStatsStorage.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/storage/TestStatsStorage.java @@ -36,26 +36,27 @@ import org.deeplearning4j.ui.model.stats.impl.java.JavaStatsReport; import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage; import org.deeplearning4j.ui.model.storage.mapdb.MapDBStatsStorage; import org.deeplearning4j.ui.model.storage.sqlite.J7FileStatsStorage; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + import java.io.File; import java.io.IOException; +import java.nio.file.Path; import java.util.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestStatsStorage extends BaseDL4JTest { - @Rule - public final TemporaryFolder testDir = new TemporaryFolder(); + @Test - @Ignore("AB 2019/05/21 - Failing on linux-x86_64-cuda-9.2 only - Issue #7657") - public void testStatsStorage() throws IOException { + @Disabled("AB 2019/05/21 - Failing on linux-x86_64-cuda-9.2 only - Issue #7657") + public void testStatsStorage(@TempDir Path testDir) throws IOException { for (boolean useJ7Storage : new boolean[] {false, true}) { for (int i = 0; i < 3; i++) { @@ -63,12 +64,12 @@ public class TestStatsStorage extends BaseDL4JTest { StatsStorage ss; switch (i) { case 0: - File f = createTempFile("TestMapDbStatsStore", ".db"); + File f = createTempFile(testDir,"TestMapDbStatsStore", ".db"); f.delete(); //Don't want file to exist... ss = new MapDBStatsStorage.Builder().file(f).build(); break; case 1: - File f2 = createTempFile("TestJ7FileStatsStore", ".db"); + File f2 = createTempFile(testDir,"TestJ7FileStatsStore", ".db"); f2.delete(); //Don't want file to exist... ss = new J7FileStatsStorage(f2); break; @@ -120,7 +121,7 @@ public class TestStatsStorage extends BaseDL4JTest { assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0")); assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0")); assertEquals(Collections.singletonList(getReport(0, 0, 0, 12345, useJ7Storage)), - ss.getAllUpdatesAfter("sid0", "tid0", "wid0", 0)); + ss.getAllUpdatesAfter("sid0", "tid0", "wid0", 0)); assertEquals(1, ss.getNumUpdateRecordsFor("sid0")); assertEquals(1, ss.getNumUpdateRecordsFor("sid0", "tid0", "wid0")); @@ -158,17 +159,17 @@ public class TestStatsStorage extends BaseDL4JTest { ss.putUpdate(getReport(100, 200, 300, 12346, useJ7Storage)); assertEquals(Collections.singletonList(getReport(100, 200, 300, 12346, useJ7Storage)), - ss.getLatestUpdateAllWorkers("sid100", "tid200")); + ss.getLatestUpdateAllWorkers("sid100", "tid200")); assertEquals(Collections.singletonList("tid200"), ss.listTypeIDsForSession("sid100")); List temp = ss.listWorkerIDsForSession("sid100"); System.out.println("temp: " + temp); assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSession("sid100")); assertEquals(Collections.singletonList("wid300"), - ss.listWorkerIDsForSessionAndType("sid100", "tid200")); + ss.listWorkerIDsForSessionAndType("sid100", "tid200")); assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), - ss.getLatestUpdate("sid100", "tid200", "wid300")); + ss.getLatestUpdate("sid100", "tid200", "wid300")); assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), - ss.getUpdate("sid100", "tid200", "wid300", 12346)); + ss.getUpdate("sid100", "tid200", "wid300", 12346)); assertEquals(2, l.countNewSession); assertEquals(3, l.countNewWorkerId); @@ -209,16 +210,16 @@ public class TestStatsStorage extends BaseDL4JTest { @Test - @Ignore("AB 2019/05/21 - Failing on linux-x86_64-cuda-9.2 only - Issue #7657") - public void testFileStatsStore() throws IOException { + @Disabled("AB 2019/05/21 - Failing on linux-x86_64-cuda-9.2 only - Issue #7657") + public void testFileStatsStore(@TempDir Path testDir) throws IOException { for (boolean useJ7Storage : new boolean[] {false, true}) { for (int i = 0; i < 2; i++) { File f; if (i == 0) { - f = createTempFile("TestMapDbStatsStore", ".db"); + f = createTempFile(testDir,"TestMapDbStatsStore", ".db"); } else { - f = createTempFile("TestSqliteStatsStore", ".db"); + f = createTempFile(testDir,"TestSqliteStatsStore", ".db"); } f.delete(); //Don't want file to exist... @@ -270,7 +271,7 @@ public class TestStatsStorage extends BaseDL4JTest { assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0")); assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0")); assertEquals(Collections.singletonList(getReport(0, 0, 0, 12345, useJ7Storage)), - ss.getAllUpdatesAfter("sid0", "tid0", "wid0", 0)); + ss.getAllUpdatesAfter("sid0", "tid0", "wid0", 0)); assertEquals(1, ss.getNumUpdateRecordsFor("sid0")); assertEquals(1, ss.getNumUpdateRecordsFor("sid0", "tid0", "wid0")); @@ -308,17 +309,17 @@ public class TestStatsStorage extends BaseDL4JTest { ss.putUpdate(getReport(100, 200, 300, 12346, useJ7Storage)); assertEquals(Collections.singletonList(getReport(100, 200, 300, 12346, useJ7Storage)), - ss.getLatestUpdateAllWorkers("sid100", "tid200")); + ss.getLatestUpdateAllWorkers("sid100", "tid200")); assertEquals(Collections.singletonList("tid200"), ss.listTypeIDsForSession("sid100")); List temp = ss.listWorkerIDsForSession("sid100"); System.out.println("temp: " + temp); assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSession("sid100")); assertEquals(Collections.singletonList("wid300"), - ss.listWorkerIDsForSessionAndType("sid100", "tid200")); + ss.listWorkerIDsForSessionAndType("sid100", "tid200")); assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), - ss.getLatestUpdate("sid100", "tid200", "wid300")); + ss.getLatestUpdate("sid100", "tid200", "wid300")); assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), - ss.getUpdate("sid100", "tid200", "wid300", 12346)); + ss.getUpdate("sid100", "tid200", "wid300", 12346)); assertEquals(2, l.countNewSession); assertEquals(3, l.countNewWorkerId); @@ -349,11 +350,11 @@ public class TestStatsStorage extends BaseDL4JTest { assertEquals(Collections.singletonList("tid200"), ss.listTypeIDsForSession("sid100")); assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSession("sid100")); assertEquals(Collections.singletonList("wid300"), - ss.listWorkerIDsForSessionAndType("sid100", "tid200")); + ss.listWorkerIDsForSessionAndType("sid100", "tid200")); assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), - ss.getLatestUpdate("sid100", "tid200", "wid300")); + ss.getLatestUpdate("sid100", "tid200", "wid300")); assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), - ss.getUpdate("sid100", "tid200", "wid300", 12346)); + ss.getUpdate("sid100", "tid200", "wid300", 12346)); } } } @@ -373,7 +374,7 @@ public class TestStatsStorage extends BaseDL4JTest { envInfo.put("envInfo0", "value0"); envInfo.put("envInfo1", "value1"); rep.reportSoftwareInfo("arch", "osName", "jvmName", "jvmVersion", "1.8", "backend", "dtype", "hostname", - "jvmuid", envInfo); + "jvmuid", envInfo); return rep; } @@ -428,8 +429,11 @@ public class TestStatsStorage extends BaseDL4JTest { } } - private File createTempFile(String prefix, String suffix) throws IOException { - return testDir.newFile(prefix + "-" + System.nanoTime() + suffix); + private File createTempFile(Path testDir, String prefix, String suffix) throws IOException { + File newFile = new File(testDir.toFile(),prefix + "-" + System.nanoTime() + suffix); + newFile.createNewFile(); + newFile.deleteOnExit(); + return newFile; } } diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java index 2ab861b76..8ca018ef8 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java @@ -38,8 +38,8 @@ import org.deeplearning4j.ui.model.stats.StatsListener; import org.deeplearning4j.ui.model.stats.impl.SbeStatsInitializationReport; import org.deeplearning4j.ui.model.stats.impl.SbeStatsReport; import org.deeplearning4j.ui.model.storage.impl.SbeStorageMetaData; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.lossfunctions.LossFunctions; @@ -49,13 +49,13 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; -@Ignore +@Disabled public class TestRemoteReceiver extends BaseDL4JTest { @Test - @Ignore + @Disabled public void testRemoteBasic() throws Exception { List updates = new ArrayList<>(); @@ -123,7 +123,7 @@ public class TestRemoteReceiver extends BaseDL4JTest { @Test - @Ignore + @Disabled public void testRemoteFull() throws Exception { //Use this in conjunction with startRemoteUI() @@ -150,7 +150,7 @@ public class TestRemoteReceiver extends BaseDL4JTest { } @Test - @Ignore + @Disabled public void startRemoteUI() throws Exception { UIServer s = UIServer.getInstance(); diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java index d46e26818..fccd4bb5d 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java @@ -20,35 +20,32 @@ package org.deeplearning4j.ui; -import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.ui.api.UIServer; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.graph.ui.LogFileWriter; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.File; +import java.nio.file.Path; import java.util.Arrays; -@Ignore -@Slf4j +@Disabled public class TestSameDiffUI extends BaseDL4JTest { + private static Logger log = LoggerFactory.getLogger(TestSameDiffUI.class.getName()); - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - - - @Ignore + @Disabled @Test - public void testSameDiff() throws Exception { - File dir = testDir.newFolder(); + public void testSameDiff(@TempDir Path testDir) throws Exception { + File dir = testDir.toFile(); File f = new File(dir, "ui_data.bin"); log.info("File path: {}", f.getAbsolutePath()); diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java index 526d8c25e..1f62def9a 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java @@ -23,7 +23,6 @@ package org.deeplearning4j.ui; import io.vertx.core.Future; import io.vertx.core.Promise; import io.vertx.core.Vertx; -import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.core.storage.StatsStorage; @@ -45,29 +44,31 @@ import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.ui.api.UIServer; import org.deeplearning4j.ui.model.stats.StatsListener; import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.common.function.Function; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.net.URL; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicReference; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; -@Slf4j -@Ignore +@Disabled public class TestVertxUI extends BaseDL4JTest { + private static Logger log = LoggerFactory.getLogger(TestVertxUI.class.getName()); - @Before + + @BeforeEach public void setUp() throws Exception { UIServer.stopInstance(); } @@ -307,23 +308,26 @@ public class TestVertxUI extends BaseDL4JTest { uiServer.stop(); } - @Test (expected = DL4JException.class) + @Test () public void testUIStartPortAlreadyBound() throws InterruptedException { - CountDownLatch latch = new CountDownLatch(1); - //Create HttpServer that binds the same port - int port = VertxUIServer.DEFAULT_UI_PORT; - Vertx vertx = Vertx.vertx(); - vertx.createHttpServer() - .requestHandler(event -> {}) - .listen(port, result -> latch.countDown()); - latch.await(); + assertThrows(DL4JException.class,() -> { + CountDownLatch latch = new CountDownLatch(1); + //Create HttpServer that binds the same port + int port = VertxUIServer.DEFAULT_UI_PORT; + Vertx vertx = Vertx.vertx(); + vertx.createHttpServer() + .requestHandler(event -> {}) + .listen(port, result -> latch.countDown()); + latch.await(); + + try { + //DL4JException signals that the port cannot be bound, UI server cannot start + UIServer.getInstance(); + } finally { + vertx.close(); + } + }); - try { - //DL4JException signals that the port cannot be bound, UI server cannot start - UIServer.getInstance(); - } finally { - vertx.close(); - } } @Test diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java index 9d68a773b..f442135ad 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java @@ -38,13 +38,15 @@ import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.ui.api.UIServer; import org.deeplearning4j.ui.model.stats.StatsListener; import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.common.function.Function; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.UnsupportedEncodingException; import java.net.HttpURLConnection; @@ -53,19 +55,21 @@ import java.net.URLEncoder; import java.util.HashMap; import java.util.concurrent.CountDownLatch; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; -@Slf4j -@Ignore +@Disabled public class TestVertxUIManual extends BaseDL4JTest { + private static Logger log = LoggerFactory.getLogger(TestVertxUIManual.class.getName()); + + @Override public long getTimeoutMilliseconds() { return 3600_000L; } @Test - @Ignore + @Disabled public void testUI() throws Exception { VertxUIServer uiServer = (VertxUIServer) UIServer.getInstance(); assertEquals(9000, uiServer.getPort()); @@ -75,7 +79,7 @@ public class TestVertxUIManual extends BaseDL4JTest { } @Test - @Ignore + @Disabled public void testUISequentialSessions() throws Exception { UIServer uiServer = UIServer.getInstance(); StatsStorage ss = null; @@ -118,7 +122,7 @@ public class TestVertxUIManual extends BaseDL4JTest { } @Test - @Ignore + @Disabled public void testUIServerStop() throws Exception { UIServer uiServer = UIServer.getInstance(true, null); assertTrue(uiServer.isMultiSession()); @@ -144,7 +148,7 @@ public class TestVertxUIManual extends BaseDL4JTest { @Test - @Ignore + @Disabled public void testUIServerStopAsync() throws Exception { UIServer uiServer = UIServer.getInstance(true, null); assertTrue(uiServer.isMultiSession()); @@ -176,7 +180,7 @@ public class TestVertxUIManual extends BaseDL4JTest { } @Test - @Ignore + @Disabled public void testUIAutoAttachDetach() throws Exception { long detachTimeoutMillis = 15_000; AutoDetachingStatsStorageProvider statsProvider = new AutoDetachingStatsStorageProvider(detachTimeoutMillis); diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java index 32c92dd85..db1c7808d 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java @@ -36,14 +36,16 @@ import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.ui.api.UIServer; import org.deeplearning4j.ui.model.stats.StatsListener; import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.common.function.Function; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.lossfunctions.LossFunctions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.IOException; import java.io.UnsupportedEncodingException; @@ -52,15 +54,16 @@ import java.net.URL; import java.net.URLEncoder; import java.util.HashMap; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; /** * @author Tamas Fenyvesi */ -@Slf4j @Ignore //https://github.com/eclipse/deeplearning4j/issues/8891 + @Disabled //https://github.com/eclipse/deeplearning4j/issues/8891 public class TestVertxUIMultiSession extends BaseDL4JTest { + private static Logger log = LoggerFactory.getLogger(TestVertxUIMultiSession.class.getName()); - @Before + @BeforeEach public void setUp() throws Exception { UIServer.stopInstance(); } @@ -184,19 +187,26 @@ public class TestVertxUIMultiSession extends BaseDL4JTest { } } - @Test (expected = DL4JException.class) + @Test () public void testUIServerGetInstanceMultipleCalls1() { - UIServer uiServer = UIServer.getInstance(); - assertFalse(uiServer.isMultiSession()); - UIServer.getInstance(true, null); + assertThrows(DL4JException.class,() -> { + UIServer uiServer = UIServer.getInstance(); + assertFalse(uiServer.isMultiSession()); + UIServer.getInstance(true, null); + }); + + } - @Test (expected = DL4JException.class) + @Test () public void testUIServerGetInstanceMultipleCalls2() { - UIServer uiServer = UIServer.getInstance(true, null); - assertTrue(uiServer.isMultiSession()); - UIServer.getInstance(false, null); + assertThrows(DL4JException.class,() -> { + UIServer uiServer = UIServer.getInstance(true, null); + assertTrue(uiServer.isMultiSession()); + UIServer.getInstance(false, null); + }); + } /** diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java index 7354e0792..e1af69fab 100644 --- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java +++ b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java @@ -26,15 +26,15 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.zoo.model.VGG16; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.File; -@Ignore("Times out too often") +@Disabled("Times out too often") public class MiscTests extends BaseDL4JTest { @Override diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestDownload.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestDownload.java index 52a29df1f..5e768d6f3 100644 --- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestDownload.java +++ b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestDownload.java @@ -31,38 +31,38 @@ import org.deeplearning4j.zoo.model.UNet; import org.deeplearning4j.zoo.util.darknet.COCOLabels; import org.deeplearning4j.zoo.util.darknet.DarknetLabels; import org.deeplearning4j.zoo.util.imagenet.ImageNetLabels; -import org.junit.*; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.*; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.factory.Nd4j; import java.io.File; +import java.nio.file.Path; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j -@Ignore("Times out too often") +@Disabled("Times out too often") public class TestDownload extends BaseDL4JTest { + @TempDir + static Path sharedTempDir; @Override public long getTimeoutMilliseconds() { return isIntegrationTests() ? 480000L : 60000L; } - @ClassRule - public static TemporaryFolder testDir = new TemporaryFolder(); - private static File f; - @BeforeClass + + @BeforeAll public static void before() throws Exception { - f = testDir.newFolder(); - DL4JResources.setBaseDirectory(f); + DL4JResources.setBaseDirectory(sharedTempDir.toFile()); } - @AfterClass + @AfterAll public static void after(){ DL4JResources.resetBaseDirectoryLocation(); } diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestImageNet.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestImageNet.java index 44c43047f..434d5e25d 100644 --- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestImageNet.java +++ b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestImageNet.java @@ -37,8 +37,8 @@ import org.deeplearning4j.zoo.util.darknet.COCOLabels; import org.deeplearning4j.zoo.util.darknet.DarknetLabels; import org.deeplearning4j.zoo.util.darknet.VOCLabels; import org.deeplearning4j.zoo.util.imagenet.ImageNetLabels; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; @@ -50,11 +50,11 @@ import java.io.IOException; import java.util.List; import static org.bytedeco.opencv.global.opencv_imgproc.COLOR_BGR2RGB; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j -@Ignore("Times out too often") +@Disabled("Times out too often") public class TestImageNet extends BaseDL4JTest { @Override @@ -92,7 +92,7 @@ public class TestImageNet extends BaseDL4JTest { } @Test - @Ignore("AB 2019/05/30 - Failing (intermittently?) on CI linux - see issue 7657") + @Disabled("AB 2019/05/30 - Failing (intermittently?) on CI linux - see issue 7657") public void testDarknetLabels() throws IOException { // set up model ZooModel model = Darknet19.builder().numClasses(0).build(); //num labels doesn't matter since we're getting pretrained imagenet diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java index 9548495e7..cfcd3fdf0 100644 --- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java +++ b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java @@ -34,9 +34,9 @@ import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.deeplearning4j.nn.transferlearning.TransferLearningHelper; import org.deeplearning4j.zoo.model.*; import org.deeplearning4j.zoo.model.helper.DarknetHelper; -import org.junit.After; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -47,12 +47,12 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.IOException; -import static junit.framework.TestCase.assertTrue; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assume.assumeTrue; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; @Slf4j -@Ignore("Times out too often") +@Disabled("Times out too often") public class TestInstantiation extends BaseDL4JTest { protected static void ignoreIfCuda(){ @@ -63,7 +63,7 @@ public class TestInstantiation extends BaseDL4JTest { } } - @After + @AfterEach public void after() throws Exception { Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); System.gc(); @@ -86,7 +86,7 @@ public class TestInstantiation extends BaseDL4JTest { runTest(TinyYOLO.builder().numClasses(10).build(), "TinyYOLO", 10); } - @Test @Ignore("AB 2019/05/28 - Crashing on CI linux-x86-64 CPU only - Issue #7657") + @Test @Disabled("AB 2019/05/28 - Crashing on CI linux-x86-64 CPU only - Issue #7657") public void testCnnTrainingYOLO2() throws Exception { runTest(YOLO2.builder().numClasses(10).build(), "YOLO2", 10); } @@ -162,12 +162,12 @@ public class TestInstantiation extends BaseDL4JTest { testInitPretrained(VGG19.builder().numClasses(0).build(), new long[]{1,3,224,224}, new long[]{1,1000}); } - @Test @Ignore("AB 2019/05/28 - JVM crash on linux CUDA CI machines - Issue 7657") + @Test @Disabled("AB 2019/05/28 - JVM crash on linux CUDA CI machines - Issue 7657") public void testInitPretrainedDarknet19() throws Exception { testInitPretrained(Darknet19.builder().numClasses(0).build(), new long[]{1,3,224,224}, new long[]{1,1000}); } - @Test @Ignore("AB 2019/05/28 - JVM crash on linux CUDA CI machines - Issue 7657") + @Test @Disabled("AB 2019/05/28 - JVM crash on linux CUDA CI machines - Issue 7657") public void testInitPretrainedDarknet19S2() throws Exception { testInitPretrained(Darknet19.builder().numClasses(0).inputShape(new int[]{3,448,448}).build(), new long[]{1,3,448,448}, new long[]{1,1000}); } @@ -240,7 +240,7 @@ public class TestInstantiation extends BaseDL4JTest { testInitRandomModel(Xception.builder().numClasses(1000).build(), new long[]{1,3,299,299}, new long[]{1,1000}); } - @Test @Ignore("AB - 2019/05/28 - JVM crash on CI - intermittent? Issue 7657") + @Test @Disabled("AB - 2019/05/28 - JVM crash on CI - intermittent? Issue 7657") public void testInitRandomModelSqueezenet() throws IOException { testInitRandomModel(SqueezeNet.builder().numClasses(1000).build(), new long[]{1,3,227,227}, new long[]{1,1000}); } @@ -250,7 +250,7 @@ public class TestInstantiation extends BaseDL4JTest { testInitRandomModel(FaceNetNN4Small2.builder().embeddingSize(100).numClasses(10).build(), new long[]{1,3,64,64}, new long[]{1,10}); } - @Test @Ignore("AB 2019/05/29 - Crashing on CI linux-x86-64 CPU only - Issue #7657") + @Test @Disabled("AB 2019/05/29 - Crashing on CI linux-x86-64 CPU only - Issue #7657") public void testInitRandomModelUNet() throws IOException { testInitRandomModel(UNet.builder().build(), new long[]{1,3,512,512}, new long[]{1,1,512,512}); } diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestUtils.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestUtils.java index 8a457dfef..a61ae386d 100644 --- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestUtils.java +++ b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestUtils.java @@ -31,7 +31,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.io.*; import java.util.Random; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestUtils { diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java index d9198ecfb..3c6d81e9a 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java @@ -52,7 +52,7 @@ import java.nio.charset.StandardCharsets; import java.util.*; import java.util.stream.Collectors; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j public class IntegrationTestBaselineGenerator { diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java index 516f7cf83..4b5df6ead 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java @@ -44,7 +44,7 @@ import org.deeplearning4j.optimize.listeners.CollectScoresListener; import org.deeplearning4j.parallelism.ParallelInference; import org.deeplearning4j.parallelism.inference.InferenceMode; import org.deeplearning4j.util.ModelSerializer; -import org.junit.rules.TemporaryFolder; + import org.nd4j.autodiff.listeners.records.History; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -74,10 +74,11 @@ import org.nd4j.shade.guava.reflect.ClassPath; import java.io.*; import java.lang.reflect.Modifier; import java.nio.charset.StandardCharsets; +import java.nio.file.Path; import java.util.*; import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class IntegrationTestRunner { @@ -155,7 +156,7 @@ public class IntegrationTestRunner { evaluationClassesSeen = new HashMap<>(); } - public static void runTest(TestCase tc, TemporaryFolder testDir) throws Exception { + public static void runTest(TestCase tc, Path testDir) throws Exception { BaseDL4JTest.skipUnlessIntegrationTests(); //Tests will ONLY be run if integration test profile is enabled. //This could alternatively be done via maven surefire configuration @@ -163,10 +164,10 @@ public class IntegrationTestRunner { log.info("Starting test case: {} - type = {}", tc.getTestName(), modelType); long start = System.currentTimeMillis(); - File workingDir = testDir.newFolder(); + File workingDir = new File(testDir.toFile(),"workingDir"); tc.initialize(workingDir); - File testBaseDir = testDir.newFolder(); + File testBaseDir = new File(testDir.toFile(),"baseDir"); // new ClassPathResource("dl4j-integration-tests/" + tc.getTestName()).copyDirectory(testBaseDir); Resources.copyDirectory((modelType == ModelType.SAMEDIFF ? "samediff-integration-tests/" : "dl4j-integration-tests/") + tc.getTestName(), testBaseDir); @@ -187,9 +188,9 @@ public class IntegrationTestRunner { m = mln; MultiLayerNetwork loaded = MultiLayerNetwork.load(savedModel, true); - assertEquals("Configs not equal", loaded.getLayerWiseConfigurations(), mln.getLayerWiseConfigurations()); - assertEquals("Params not equal", loaded.params(), mln.params()); - assertEquals("Param table not equal", loaded.paramTable(), mln.paramTable()); + assertEquals(loaded.getLayerWiseConfigurations(), mln.getLayerWiseConfigurations(), "Configs not equal"); + assertEquals(loaded.params(), mln.params(), "Params not equal"); + assertEquals(loaded.paramTable(), mln.paramTable(), "Param table not equal"); } else if(config instanceof ComputationGraphConfiguration ){ ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config; cg = new ComputationGraph(cgc); @@ -197,9 +198,9 @@ public class IntegrationTestRunner { m = cg; ComputationGraph loaded = ComputationGraph.load(savedModel, true); - assertEquals("Configs not equal", loaded.getConfiguration(), cg.getConfiguration()); - assertEquals("Params not equal", loaded.params(), cg.params()); - assertEquals("Param table not equal", loaded.paramTable(), cg.paramTable()); + assertEquals(loaded.getConfiguration(), cg.getConfiguration(), "Configs not equal"); + assertEquals(loaded.params(), cg.params(), "Params not equal"); + assertEquals(loaded.paramTable(), cg.paramTable(), "Param table not equal"); } else if(config instanceof SameDiff){ sd = (SameDiff)config; SameDiff loaded = SameDiff.load(savedModel, true); @@ -256,7 +257,7 @@ public class IntegrationTestRunner { INDArray predictionExceedsRE = exceedsRelError(outSaved, out, tc.getMaxRelativeErrorOutput(), tc.getMinAbsErrorOutput()); int countExceeds = predictionExceedsRE.sumNumber().intValue(); - assertEquals("Predictions do not match saved predictions - output", 0, countExceeds); + assertEquals(0, countExceeds,"Predictions do not match saved predictions - output"); } } else if(modelType == ModelType.CG){ for (Pair p : inputs) { @@ -274,7 +275,7 @@ public class IntegrationTestRunner { for( int i=0; i 0) { logFailedParams(20, "Gradient", layers, gradExceedsRE, gradientFlatSaved, gradientFlat); } - assertEquals("Saved flattened gradients: not equal (using relative error)", 0, count); + assertEquals( 0, count,"Saved flattened gradients: not equal (using relative error)"); } //Load the gradient table: @@ -367,7 +368,7 @@ public class IntegrationTestRunner { INDArray gradExceedsRE = exceedsRelError(loaded, now, tc.getMaxRelativeErrorGradients(), tc.getMinAbsErrorGradients()); int count = gradExceedsRE.sumNumber().intValue(); - assertEquals("Gradients: not equal (using relative error) for parameter: " + key, 0, count); + assertEquals(0, count,"Gradients: not equal (using relative error) for parameter: " + key); } } @@ -410,7 +411,7 @@ public class IntegrationTestRunner { if(count > 0){ logFailedParams(20, "Parameter", layers, exceedsRelError, expParams, paramsPostTraining); } - assertEquals("Number of parameters exceeding relative error", 0, count); + assertEquals(0, count,"Number of parameters exceeding relative error"); //Set params to saved ones - to avoid accumulation of roundoff errors causing later failures... m.setParams(expParams); @@ -496,7 +497,7 @@ public class IntegrationTestRunner { String[] s = FileUtils.readFileToString(f, StandardCharsets.UTF_8).split(","); if(tc.isTestTrainingCurves()) { - assertEquals("Different number of scores", s.length, scores.length); + assertEquals(s.length, scores.length,"Different number of scores"); boolean pass = true; for (int i = 0; i < s.length; i++) { @@ -521,7 +522,7 @@ public class IntegrationTestRunner { if (count > 0) { logFailedParams(20, "Parameter", layers, z, paramsExp, m.params()); } - assertEquals("Number of params exceeded max relative error", 0, count); + assertEquals( 0, count,"Number of params exceeded max relative error"); } else { File dir = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_SAMEDIFF_DIR); for(SDVariable v : sd.variables()){ @@ -535,7 +536,7 @@ public class IntegrationTestRunner { if (count > 0) { logFailedParams(20, "Parameter: " + v.name(), layers, z, exp, paramNow); } - assertEquals("Number of params exceeded max relative error for parameter: \"" + v.name() + "\"", 0, count); + assertEquals(0, count,"Number of params exceeded max relative error for parameter: \"" + v.name() + "\""); } } } @@ -582,7 +583,7 @@ public class IntegrationTestRunner { } - assertEquals("Evaluation not equal: " + evals[i].getClass(), e, evals[i]); + assertEquals(e, evals[i], "Evaluation not equal: " + evals[i].getClass()); //Evaluation coverage information: evaluationClassesSeen.put(evals[i].getClass(), evaluationClassesSeen.getOrDefault(evals[i].getClass(), 0) + 1); @@ -597,8 +598,8 @@ public class IntegrationTestRunner { { log.info("Testing model serialization"); - File f = testDir.newFile(); - f.delete(); + File f = new File(testDir.toFile(),"test-file"); + f.deleteOnExit(); if (modelType == ModelType.MLN) { ModelSerializer.writeModel(m, f, true); @@ -704,7 +705,7 @@ public class IntegrationTestRunner { System.out.println("Relative error:"); System.out.println(re); } - assertEquals("Number of outputs exceeded max relative error", 0, count); + assertEquals(0, count,"Number of outputs exceeded max relative error"); } if(modelType != ModelType.SAMEDIFF) { @@ -808,8 +809,8 @@ public class IntegrationTestRunner { } for(org.deeplearning4j.nn.api.Layer l : layers){ - assertEquals("Epoch count", expEpoch, l.getEpochCount()); - assertEquals("Iteration count", expIter, l.getIterationCount()); + assertEquals(expEpoch, l.getEpochCount(),"Epoch count"); + assertEquals(expIter, l.getIterationCount(),"Iteration count"); } } @@ -854,18 +855,18 @@ public class IntegrationTestRunner { public static void checkFrozenParams(Map copiesBeforeTraining, Model m){ for(Map.Entry e : copiesBeforeTraining.entrySet()){ INDArray actual = m.getParam(e.getKey()); - assertEquals(e.getKey(), e.getValue(), actual); + assertEquals(e.getValue(), actual, e.getKey()); } } public static void checkConstants(Map copiesBefore, SameDiff sd){ for(Map.Entry e : copiesBefore.entrySet()){ INDArray actual = sd.getArrForVarName(e.getKey()); - assertEquals(e.getKey(), e.getValue(), actual); + assertEquals(e.getValue(), actual, e.getKey()); } } - public static void printCoverageInformation(){ + public static void printCoverageInformation() { log.info("||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"); @@ -1111,11 +1112,11 @@ public class IntegrationTestRunner { //Check constant and variable arrays: for(SDVariable v : sd1.variables()){ String n = v.name(); - assertEquals(n, v.getVariableType(), sd2.getVariable(n).getVariableType()); + assertEquals(v.getVariableType(), sd2.getVariable(n).getVariableType(), n); if(v.isConstant() || v.getVariableType() == VariableType.VARIABLE){ INDArray a1 = v.getArr(); INDArray a2 = sd2.getVariable(n).getArr(); - assertEquals(n, a1, a2); + assertEquals(a1, a2, n); } } diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsDL4J.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsDL4J.java index ea6672e3b..3180f7d38 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsDL4J.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsDL4J.java @@ -22,23 +22,26 @@ package org.deeplearning4j.integration; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.integration.testcases.dl4j.*; -import org.junit.AfterClass; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -//@Ignore("AB - 2019/05/27 - Integration tests need to be updated") +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.nio.file.Path; + + +//@Disabled("AB - 2019/05/27 - Integration tests need to be updated") public class IntegrationTestsDL4J extends BaseDL4JTest { + @TempDir + static Path testDir; @Override public long getTimeoutMilliseconds() { return 300_000L; } - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - @AfterClass + @AfterEach public static void afterClass(){ IntegrationTestRunner.printCoverageInformation(); } diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java index f1bb83922..0cc6672a8 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java @@ -22,19 +22,24 @@ package org.deeplearning4j.integration; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.integration.testcases.samediff.SameDiffCNNCases; import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.nio.file.Path; + public class IntegrationTestsSameDiff extends BaseDL4JTest { + @TempDir + static Path testDir; + @Override public long getTimeoutMilliseconds() { return 300_000L; } - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Test diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java index c566400fc..5c16cc908 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java @@ -33,7 +33,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.io.*; import java.util.Random; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestUtils { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java index dd1ba0130..915d6981a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java @@ -20,14 +20,14 @@ package org.nd4j.jita.allocator; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.util.DeviceLocalNDArray; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j public class DeviceLocalNDArrayTests extends BaseND4JTest { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/impl/MemoryTrackerTest.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/impl/MemoryTrackerTest.java index 872afb7ad..0f5560d0f 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/impl/MemoryTrackerTest.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/impl/MemoryTrackerTest.java @@ -20,7 +20,7 @@ package org.nd4j.jita.allocator.impl; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/workspace/CudaWorkspaceTest.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/workspace/CudaWorkspaceTest.java index e24db6811..14c862702 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/workspace/CudaWorkspaceTest.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/workspace/CudaWorkspaceTest.java @@ -19,7 +19,7 @@ package org.nd4j.jita.workspace; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; @@ -29,7 +29,7 @@ import org.nd4j.linalg.api.memory.enums.ResetPolicy; import org.nd4j.linalg.api.memory.enums.SpillPolicy; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class CudaWorkspaceTest { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java index fec115300..2d90c71ce 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java @@ -22,8 +22,8 @@ package org.nd4j.linalg.jcublas.buffer; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.linalg.api.buffer.DataType; @@ -36,12 +36,12 @@ import java.io.ByteArrayOutputStream; import java.util.ArrayList; import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class BaseCudaDataBufferTest extends BaseND4JTest { - @Before + @BeforeEach public void setUp() { // } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/OpValidationSuite.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/OpValidationSuite.java index 12d8584b6..2f918d639 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/OpValidationSuite.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/OpValidationSuite.java @@ -23,7 +23,8 @@ package org.nd4j; import org.bytedeco.javacpp.Loader; import org.junit.AfterClass; import org.junit.BeforeClass; -import org.junit.Ignore; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; import org.junit.runner.RunWith; import org.junit.runners.Suite; import org.nd4j.autodiff.opvalidation.*; @@ -53,7 +54,7 @@ import static org.junit.Assume.assumeFalse; }) //IMPORTANT: This ignore is added to avoid maven surefire running both the suite AND the individual tests in "mvn test" // With it ignored here, the individual tests will run outside (i.e., separately/independently) of the suite in both "mvn test" and IntelliJ -@Ignore +@Disabled public class OpValidationSuite { /* @@ -78,7 +79,7 @@ public class OpValidationSuite { private static DataType initialType; - @BeforeClass + @BeforeAll public static void beforeClass() { Nd4j.create(1); initialType = Nd4j.dataType(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java index 85515321b..9700ed253 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java @@ -20,13 +20,13 @@ package org.nd4j.autodiff; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.converters.ImportClassMapping; @@ -121,7 +121,7 @@ import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; -@Ignore("No longer relevant after model import rewrite.") +@Disabled("No longer relevant after model import rewrite.") public class TestOpMapping extends BaseNd4jTest { Set> subTypes; @@ -166,16 +166,16 @@ public class TestOpMapping extends BaseNd4jTest { } String opName = df.opName(); - assertTrue("Op is missing - not defined in ImportClassMapping: " + opName + - "\nInstructions to fix: Add class to org.nd4j.imports.converters.ImportClassMapping", opNameMapping.containsKey(opName) + assertTrue( opNameMapping.containsKey(opName),"Op is missing - not defined in ImportClassMapping: " + opName + + "\nInstructions to fix: Add class to org.nd4j.imports.converters.ImportClassMapping" ); try{ String[] tfNames = df.tensorflowNames(); - for(String s : tfNames ){ - assertTrue("Tensorflow mapping not found: " + s, tfOpNameMapping.containsKey(s)); - assertEquals("Tensorflow mapping: " + s, df.getClass(), tfOpNameMapping.get(s).getClass()); + for(String s : tfNames ) { + assertTrue( tfOpNameMapping.containsKey(s),"Tensorflow mapping not found: " + s); + assertEquals(df.getClass(), tfOpNameMapping.get(s).getClass(),"Tensorflow mapping: " + s); } } catch (NoOpNameFoundException e){ //OK, skip @@ -186,8 +186,8 @@ public class TestOpMapping extends BaseNd4jTest { String[] onnxNames = df.onnxNames(); for(String s : onnxNames ){ - assertTrue("Onnx mapping not found: " + s, onnxOpNameMapping.containsKey(s)); - assertEquals("Onnx mapping: " + s, df.getClass(), onnxOpNameMapping.get(s).getClass()); + assertTrue( onnxOpNameMapping.containsKey(s),"Onnx mapping not found: " + s); + assertEquals(df.getClass(), onnxOpNameMapping.get(s).getClass(),"Onnx mapping: " + s); } } catch (NoOpNameFoundException e){ //OK, skip @@ -354,7 +354,7 @@ public class TestOpMapping extends BaseNd4jTest { s.add(Assign.class); } - @Test @Ignore + @Test @Disabled public void generateOpClassList() throws Exception{ Reflections reflections = new Reflections("org.nd4j"); Set> subTypes = reflections.getSubTypesOf(DifferentialFunction.class); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java index 3cee7866e..c48727018 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java @@ -20,7 +20,8 @@ package org.nd4j.autodiff; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SDVariable; @@ -43,7 +44,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestSessions extends BaseNd4jTest { @@ -202,7 +203,8 @@ public class TestSessions extends BaseNd4jTest { assertEquals(expFalse, outMap.get(n)); } - @Test(timeout = 20000L) + @Test() + @Timeout(20000L) public void testSwitchWhile() throws Exception{ /* diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/internal/TestDependencyTracker.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/internal/TestDependencyTracker.java index 666b89bb1..68374d35b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/internal/TestDependencyTracker.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/internal/TestDependencyTracker.java @@ -20,7 +20,7 @@ package org.nd4j.autodiff.internal; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.samediff.internal.DependencyList; import org.nd4j.autodiff.samediff.internal.DependencyTracker; import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker; @@ -33,7 +33,7 @@ import org.nd4j.common.primitives.Pair; import java.util.Collections; import static junit.framework.TestCase.assertNotNull; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestDependencyTracker extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ActivationGradChecks.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ActivationGradChecks.java index 93cc41304..c3ed6099d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ActivationGradChecks.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ActivationGradChecks.java @@ -20,7 +20,7 @@ package org.nd4j.autodiff.opvalidation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.validation.GradCheckUtil; @@ -34,7 +34,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class ActivationGradChecks extends BaseOpValidation { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/BaseOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/BaseOpValidation.java index 69f11dac2..3d22d0ded 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/BaseOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/BaseOpValidation.java @@ -20,7 +20,7 @@ package org.nd4j.autodiff.opvalidation; -import org.junit.Before; +import org.junit.jupiter.api.BeforeEach; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; @@ -39,7 +39,7 @@ public abstract class BaseOpValidation extends BaseNd4jTest { return 'c'; } - @Before + @BeforeEach public void beforeClass() { Nd4j.create(1); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index 09ea0d93e..e104242d5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -26,7 +26,7 @@ import java.util.Collections; import java.util.List; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.validation.OpValidation; @@ -61,7 +61,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.ops.transforms.Transforms; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class LayerOpValidation extends BaseOpValidation { @@ -297,7 +297,7 @@ public class LayerOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.toString()); } @Test @@ -338,7 +338,7 @@ public class LayerOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.toString()); } @Test @@ -371,7 +371,7 @@ public class LayerOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals( 0, failed.size(),failed.toString()); } @@ -571,7 +571,7 @@ public class LayerOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals( 0, failed.size(),failed.toString()); } @@ -1325,81 +1325,89 @@ public class LayerOpValidation extends BaseOpValidation { assertNull(err, err); } - @Test(expected = IllegalArgumentException.class) + @Test() public void exceptionThrown_WhenConv1DConfigInvalid() { - int nIn = 3; - int nOut = 4; - int k = 2; - int mb = 3; - int img = 28; + assertThrows(IllegalArgumentException.class,() -> { + int nIn = 3; + int nOut = 4; + int k = 2; + int mb = 3; + int img = 28; - SameDiff sd = SameDiff.create(); - INDArray wArr = Nd4j.create(k, nIn, nOut); - INDArray inArr = Nd4j.create(mb, nIn, img); + SameDiff sd = SameDiff.create(); + INDArray wArr = Nd4j.create(k, nIn, nOut); + INDArray inArr = Nd4j.create(mb, nIn, img); - SDVariable in = sd.var("in", inArr); - SDVariable w = sd.var("W", wArr); + SDVariable in = sd.var("in", inArr); + SDVariable w = sd.var("W", wArr); - SDVariable[] vars = new SDVariable[]{in, w}; + SDVariable[] vars = new SDVariable[]{in, w}; - Conv1DConfig conv1DConfig = Conv1DConfig.builder() - .k(k).p(-1).s(0) - .paddingMode(PaddingMode.VALID) - .build(); + Conv1DConfig conv1DConfig = Conv1DConfig.builder() + .k(k).p(-1).s(0) + .paddingMode(PaddingMode.VALID) + .build(); - SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig); + SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig); + + }); } - @Test(expected = IllegalArgumentException.class) + @Test() public void exceptionThrown_WhenConv2DConfigInvalid() { - - Nd4j.getRandom().setSeed(12345); - - SameDiff sd = SameDiff.create(); - SDVariable in = null; - - int[] inSizeNCHW = {1, 3, 8, 8}; - - String msg = "0 - conv2d+bias, nchw - input " + Arrays.toString(inSizeNCHW); - SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{3, 3, inSizeNCHW[1], 3}).muli(10)); //kH,kW,iC,oC - SDVariable b0 = sd.var("b0", Nd4j.rand(new long[]{3}).muli(10)); - SDVariable out = sd.cnn().conv2d(in, w0, b0, Conv2DConfig.builder() - .dataFormat(Conv2DConfig.NCHW) - .isSameMode(true) - .kH(3).kW(-3) - .sH(1).sW(0) - .build()); - } - - @Test(expected = IllegalArgumentException.class) - public void exceptionThrown_WhenConf3DInvalid() { - Nd4j.getRandom().setSeed(12345); - - //NCDHW format - int[] inSizeNCDHW = {2, 3, 4, 5, 5}; - - List failed = new ArrayList<>(); - - for (boolean ncdhw : new boolean[]{true, false}) { - int nIn = inSizeNCDHW[1]; - int[] shape = (ncdhw ? inSizeNCDHW : ncdhwToNdhwc(inSizeNCDHW)); + assertThrows(IllegalArgumentException.class,() -> { + Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); - SDVariable in = sd.var("in", shape); + SDVariable in = null; - SDVariable out; - String msg = "0 - conv3d+bias+same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape); + int[] inSizeNCHW = {1, 3, 8, 8}; - SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10)); //[kD, kH, kW, iC, oC] + String msg = "0 - conv2d+bias, nchw - input " + Arrays.toString(inSizeNCHW); + SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{3, 3, inSizeNCHW[1], 3}).muli(10)); //kH,kW,iC,oC SDVariable b0 = sd.var("b0", Nd4j.rand(new long[]{3}).muli(10)); - out = sd.cnn().conv3d(in, w0, b0, Conv3DConfig.builder() - .dataFormat(ncdhw ? Conv3DConfig.NCDHW : Conv3DConfig.NDHWC) + SDVariable out = sd.cnn().conv2d(in, w0, b0, Conv2DConfig.builder() + .dataFormat(Conv2DConfig.NCHW) .isSameMode(true) - .kH(2).kW(2).kD(2) - .sD(1).sH(1).sW(-1).dW(-1) + .kH(3).kW(-3) + .sH(1).sW(0) .build()); - } + }); + + } + + @Test() + public void exceptionThrown_WhenConf3DInvalid() { + assertThrows(IllegalArgumentException.class,() -> { + Nd4j.getRandom().setSeed(12345); + + //NCDHW format + int[] inSizeNCDHW = {2, 3, 4, 5, 5}; + + List failed = new ArrayList<>(); + + for (boolean ncdhw : new boolean[]{true, false}) { + int nIn = inSizeNCDHW[1]; + int[] shape = (ncdhw ? inSizeNCDHW : ncdhwToNdhwc(inSizeNCDHW)); + + SameDiff sd = SameDiff.create(); + SDVariable in = sd.var("in", shape); + + SDVariable out; + String msg = "0 - conv3d+bias+same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape); + + SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10)); //[kD, kH, kW, iC, oC] + SDVariable b0 = sd.var("b0", Nd4j.rand(new long[]{3}).muli(10)); + out = sd.cnn().conv3d(in, w0, b0, Conv3DConfig.builder() + .dataFormat(ncdhw ? Conv3DConfig.NCDHW : Conv3DConfig.NDHWC) + .isSameMode(true) + .kH(2).kW(2).kD(2) + .sD(1).sH(1).sW(-1).dW(-1) + .build()); + } + }); + } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java index 4746a8faa..7f5cf0884 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java @@ -21,7 +21,7 @@ package org.nd4j.autodiff.opvalidation; import lombok.extern.slf4j.Slf4j; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; @@ -39,7 +39,7 @@ import org.nd4j.linalg.ops.transforms.Transforms; import java.util.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class LossOpValidation extends BaseOpValidation { @@ -363,7 +363,7 @@ public class LossOpValidation extends BaseOpValidation { } } - assertEquals(failed.size() + " of " + totalRun + " failed: " + failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.size() + " of " + totalRun + " failed: " + failed.toString()); } @@ -461,7 +461,7 @@ public class LossOpValidation extends BaseOpValidation { .build(); Nd4j.getExecutioner().exec(op); - assertNotEquals(lossOp + " returns zero result. Reduction Mode " + reductionMode, out, zero); + assertNotEquals(out, zero,lossOp + " returns zero result. Reduction Mode " + reductionMode); } } @@ -480,7 +480,7 @@ public class LossOpValidation extends BaseOpValidation { .build(); Nd4j.getExecutioner().exec(op); - assertNotEquals(lossOp + "_grad returns zero result. Reduction Mode " + reductionMode, outBP, zeroBp); + assertNotEquals(outBP, zeroBp,lossOp + "_grad returns zero result. Reduction Mode " + reductionMode); } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java index 38bef1c8d..58c8f0825 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java @@ -22,7 +22,7 @@ package org.nd4j.autodiff.opvalidation; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -72,7 +72,7 @@ import org.nd4j.common.util.ArrayUtil; import java.util.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.junit.Assume.assumeNotNull; @Slf4j @@ -167,7 +167,7 @@ public class MiscOpValidation extends BaseOpValidation { } } - assertEquals("Failed: " + failed, 0, failed.size()); + assertEquals(0, failed.size(),"Failed: " + failed); } @Test @@ -256,7 +256,7 @@ public class MiscOpValidation extends BaseOpValidation { } } - assertEquals("Failed: " + failed, 0, failed.size()); + assertEquals(0, failed.size(),"Failed: " + failed); } @Test @@ -358,7 +358,7 @@ public class MiscOpValidation extends BaseOpValidation { } } - assertEquals("Failed: " + failed, 0, failed.size()); + assertEquals(0, failed.size(),"Failed: " + failed); } @@ -466,7 +466,7 @@ public class MiscOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.toString()); } @Test @@ -537,7 +537,7 @@ public class MiscOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.toString()); } @@ -739,7 +739,7 @@ public class MiscOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.toString()); } private static int[] t(boolean transpose, int[] orig){ @@ -1061,7 +1061,7 @@ public class MiscOpValidation extends BaseOpValidation { } } - assertEquals(failing.toString(), 0, failing.size()); + assertEquals(0, failing.size(),failing.toString()); } @@ -1130,7 +1130,7 @@ public class MiscOpValidation extends BaseOpValidation { } } - assertEquals(failing.toString(), 0, failing.size()); + assertEquals(0, failing.size(),failing.toString()); } @Test @@ -1160,7 +1160,7 @@ public class MiscOpValidation extends BaseOpValidation { failed.add(err); } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals( 0, failed.size(),failed.toString()); } @Test @@ -1388,7 +1388,7 @@ public class MiscOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.toString()); } @Test @@ -1502,7 +1502,7 @@ public class MiscOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals( 0, failed.size(),failed.toString()); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java index 484741969..8c6b62dcd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java @@ -21,7 +21,7 @@ package org.nd4j.autodiff.opvalidation; import lombok.extern.slf4j.Slf4j; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -46,7 +46,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class RandomOpValidation extends BaseOpValidation { @@ -153,7 +153,7 @@ public class RandomOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.toString()); } @Test @@ -279,7 +279,7 @@ public class RandomOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.toString()); } @Test @@ -342,8 +342,8 @@ public class RandomOpValidation extends BaseOpValidation { double expStd = 1.0/lambda; assertTrue(min >= 0.0); - assertEquals("mean", expMean, mean, 0.1); - assertEquals("std", expStd, std, 0.1); + assertEquals(expMean, mean, 0.1,"mean"); + assertEquals( expStd, std, 0.1,"std"); } @Test @@ -437,7 +437,7 @@ public class RandomOpValidation extends BaseOpValidation { double min = out.minNumber().doubleValue(); double max = out.maxNumber().doubleValue(); - assertTrue(String.valueOf(min), min > 0.0); - assertTrue(String.valueOf(max), max > 1.0); + assertTrue(min > 0.0,String.valueOf(min)); + assertTrue( max > 1.0,String.valueOf(max)); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java index 5b1ca243a..72a0dcadf 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java @@ -21,9 +21,9 @@ package org.nd4j.autodiff.opvalidation; import lombok.extern.slf4j.Slf4j; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.validation.OpTestCase; import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.linalg.api.buffer.DataType; @@ -44,7 +44,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.nativeblas.NativeOpsHolder; -import static org.junit.Assert.assertNull; +import static org.junit.jupiter.api.Assertions.assertNull; @Slf4j public class ReductionBpOpValidation extends BaseOpValidation { @@ -55,7 +55,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { super(backend); } - @Before + @BeforeEach public void before() { Nd4j.create(1); initialType = Nd4j.dataType(); @@ -64,13 +64,13 @@ public class ReductionBpOpValidation extends BaseOpValidation { Nd4j.getRandom().setSeed(123); } - @After + @AfterEach public void after() { Nd4j.setDataType(initialType); } - @After + @AfterEach public void tearDown() { NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java index 3e0692154..2f1cea2d7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java @@ -21,9 +21,9 @@ package org.nd4j.autodiff.opvalidation; import lombok.extern.slf4j.Slf4j; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.OpValidationSuite; @@ -73,14 +73,12 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j @RunWith(Parameterized.class) public class ReductionOpValidation extends BaseOpValidation { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); public ReductionOpValidation(Nd4jBackend backend) { super(backend); @@ -109,7 +107,7 @@ public class ReductionOpValidation extends BaseOpValidation { } } } - assertEquals(errors.toString(), 0, errors.size()); + assertEquals(0, errors.size(),errors.toString()); } @Test @@ -142,7 +140,7 @@ public class ReductionOpValidation extends BaseOpValidation { if (error != null) allFailed.add(error); } - assertEquals(allFailed.toString(), 0, allFailed.size()); + assertEquals(0, allFailed.size(),allFailed.toString()); } @@ -173,7 +171,7 @@ public class ReductionOpValidation extends BaseOpValidation { allFailed.add(error); } - assertEquals(allFailed.toString(), 0, allFailed.size()); + assertEquals(0, allFailed.size(),allFailed.toString()); } @Test @@ -342,7 +340,7 @@ public class ReductionOpValidation extends BaseOpValidation { failed.add(error); } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.toString()); } @Test @@ -465,7 +463,7 @@ public class ReductionOpValidation extends BaseOpValidation { } } - assertEquals("Failed: " + failed, 0, failed.size()); + assertEquals(0, failed.size(),"Failed: " + failed); } @Override @@ -647,7 +645,7 @@ public class ReductionOpValidation extends BaseOpValidation { } } - assertEquals("Failed: " + failed, 0, failed.size()); + assertEquals( 0, failed.size(),"Failed: " + failed); } @@ -753,7 +751,7 @@ public class ReductionOpValidation extends BaseOpValidation { } } - assertEquals("Failed: " + failed, 0, failed.size()); + assertEquals(0, failed.size(),"Failed: " + failed); } @Test @@ -938,7 +936,7 @@ public class ReductionOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals( 0, failed.size(),failed.toString()); } @@ -1032,10 +1030,10 @@ public class ReductionOpValidation extends BaseOpValidation { log.info(msg + " - expected shape: " + Arrays.toString(expShape) + ", out=" + Arrays.toString(out.shape()) + ", outExp=" + Arrays.toString(expOut.shape())); - assertArrayEquals(msg, expShape, out.shape()); - assertArrayEquals(msg, expShape, expOut.shape()); + assertArrayEquals( expShape, out.shape(),msg); + assertArrayEquals(expShape, expOut.shape(),msg); - assertEquals(msg, expOut, out); + assertEquals(expOut, out,msg); } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java index 12991b1a2..53ea7d095 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java @@ -21,7 +21,7 @@ package org.nd4j.autodiff.opvalidation; import lombok.extern.slf4j.Slf4j; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; @@ -39,7 +39,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j public class RnnOpValidation extends BaseOpValidation { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index 420f0abe0..46e03f3e3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -26,8 +26,8 @@ import lombok.Data; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.math3.linear.LUDecomposition; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -62,7 +62,7 @@ import org.nd4j.common.util.ArrayUtil; import java.util.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.linalg.indexing.NDArrayIndex.*; @Slf4j @@ -119,7 +119,7 @@ public class ShapeOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals( 0, failed.size(),failed.toString()); } @Test @@ -155,7 +155,7 @@ public class ShapeOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.toString()); } @Test @@ -193,7 +193,7 @@ public class ShapeOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.toString()); } @Test @@ -276,7 +276,7 @@ public class ShapeOpValidation extends BaseOpValidation { } } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.toString()); } @Test @@ -339,7 +339,7 @@ public class ShapeOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals( 0, failed.size(),failed.toString()); } @@ -392,7 +392,7 @@ public class ShapeOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.toString()); } @@ -492,7 +492,7 @@ public class ShapeOpValidation extends BaseOpValidation { failed.add(error); } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals( 0, failed.size(),failed.toString()); } @@ -564,7 +564,7 @@ public class ShapeOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.toString()); } @Override @@ -659,7 +659,7 @@ public class ShapeOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.toString()); } @@ -731,7 +731,7 @@ public class ShapeOpValidation extends BaseOpValidation { Map m = sd.outputAll(null); for (SDVariable v : unstacked) { - assertArrayEquals(msg, shape, m.get(v.name()).shape()); + assertArrayEquals(shape, m.get(v.name()).shape(),msg); } TestCase tc = new TestCase(sd).testName(msg); @@ -748,7 +748,7 @@ public class ShapeOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals( 0, failed.size(),failed.toString()); } @Test @@ -819,7 +819,7 @@ public class ShapeOpValidation extends BaseOpValidation { } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals( 0, failed.size(),failed.toString()); } @@ -1358,7 +1358,7 @@ public class ShapeOpValidation extends BaseOpValidation { failed.add(err); } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.toString()); } @Test @@ -1467,7 +1467,7 @@ public class ShapeOpValidation extends BaseOpValidation { failed.add(err); } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals( 0, failed.size(),failed.toString()); } @@ -2517,7 +2517,7 @@ public class ShapeOpValidation extends BaseOpValidation { } - @Test @Ignore //AB 2020/04/01 - https://github.com/eclipse/deeplearning4j/issues/8592 + @Test @Disabled //AB 2020/04/01 - https://github.com/eclipse/deeplearning4j/issues/8592 public void testReshapeZeros(){ int[][] shapes = new int[][]{{2,0}, {10,0}, {10, 0}, {2,0,0,10}, {10, 0}, {0, 0, 10}, {0,2,10}, {1,2,0}}; int[][] reshape = new int[][]{{2,-1}, {2,0,-1}, {5,2,-1}, {2,0,-1}, {-1, 2, 0}, {2, -1, 0}, {2, 0, 0, 0, -1}, {2,0,-1}}; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index fcb78a9ae..70e263740 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -22,10 +22,10 @@ package org.nd4j.autodiff.opvalidation; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; @@ -87,7 +87,7 @@ import org.nd4j.nativeblas.NativeOpsHolder; import java.util.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class TransformOpValidation extends BaseOpValidation { @@ -98,7 +98,7 @@ public class TransformOpValidation extends BaseOpValidation { super(backend); } - @Before + @BeforeEach public void before() { Nd4j.create(1); initialType = Nd4j.dataType(); @@ -107,13 +107,13 @@ public class TransformOpValidation extends BaseOpValidation { Nd4j.getRandom().setSeed(123); } - @After + @AfterEach public void after() { Nd4j.setDataType(initialType); } - @After + @AfterEach public void tearDown() { NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); @@ -213,7 +213,7 @@ public class TransformOpValidation extends BaseOpValidation { } } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.toString()); } @Test @@ -1309,7 +1309,7 @@ public class TransformOpValidation extends BaseOpValidation { failed.add(err); } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(0, failed.size(),failed.toString()); } @Test @@ -1420,7 +1420,7 @@ public class TransformOpValidation extends BaseOpValidation { val outShapes = Nd4j.getExecutioner().calculateOutputShape(op); assertEquals(1, outShapes.size()); - assertArrayEquals(Arrays.toString(outShapes.get(0).getShape()), new long[]{3, 2, 4}, outShapes.get(0).getShape()); + assertArrayEquals(new long[]{3, 2, 4}, outShapes.get(0).getShape(),Arrays.toString(outShapes.get(0).getShape())); } @Test @@ -1476,12 +1476,12 @@ public class TransformOpValidation extends BaseOpValidation { Nd4j.getExecutioner().exec(op); - assertEquals(s, exp, out); + assertEquals(exp, out,s); } } - @Ignore("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540") + @Disabled("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540") @Test public void testPad() { INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java index eeeb9db3a..108fd4ab6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java @@ -20,10 +20,10 @@ package org.nd4j.autodiff.samediff; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java index 5bf4cbc11..e1473414a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java @@ -21,8 +21,8 @@ package org.nd4j.autodiff.samediff; import lombok.val; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.OpValidationSuite; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; @@ -36,10 +36,10 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; -@Ignore("AB 2019/05/21 - JVM Crash on ppc64 - Issue #7657") +@Disabled("AB 2019/05/21 - JVM Crash on ppc64 - Issue #7657") public class FailingSameDiffTests extends BaseNd4jTest { public FailingSameDiffTests(Nd4jBackend b){ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java index fb8fedb90..de2249359 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java @@ -22,9 +22,10 @@ package org.nd4j.autodiff.samediff; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.graph.FlatConfiguration; import org.nd4j.graph.FlatGraph; @@ -57,14 +58,16 @@ import org.nd4j.linalg.learning.regularization.WeightDecay; import java.io.*; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import static junit.framework.TestCase.assertNotNull; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j public class FlatBufferSerdeTest extends BaseNd4jTest { @@ -78,11 +81,10 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { return 'c'; } - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Test - public void testBasic() throws Exception { + public void testBasic(@TempDir Path testDir) throws Exception { SameDiff sd = SameDiff.create(); INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4); SDVariable in = sd.placeHolder("in", arr.dataType(), arr.shape() ); @@ -91,7 +93,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { ByteBuffer bb = sd.asFlatBuffers(true); - File f = testDir.newFile(); + File f = Files.createTempFile(testDir,"some-file","bin").toFile(); f.delete(); try(FileChannel fc = new FileOutputStream(f, false).getChannel()){ @@ -137,8 +139,8 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { } @Test - public void testSimple() throws Exception { - for( int i=0; i<10; i++ ) { + public void testSimple(@TempDir Path testDir) throws Exception { + for( int i = 0; i < 10; i++ ) { for(boolean execFirst : new boolean[]{false, true}) { log.info("Starting test: i={}, execFirst={}", i, execFirst); SameDiff sd = SameDiff.create(); @@ -194,7 +196,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { sd.output(Collections.singletonMap("in", arr), Collections.singletonList(x.name())); } - File f = testDir.newFile(); + File f = Files.createTempFile(testDir,"some-file","fb").toFile(); f.delete(); sd.asFlatFile(f); @@ -223,7 +225,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { Map m2 = restored.output(Collections.singletonMap("in", arr), Collections.singletonList(x.name())); INDArray outRestored = m2.get(x.name()); - assertEquals(String.valueOf(i), outOrig, outRestored); + assertEquals(outOrig, outRestored,String.valueOf(i)); //Check placeholders @@ -231,15 +233,15 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { Map vAfter = restored.variableMap(); assertEquals(vBefore.keySet(), vAfter.keySet()); for(String s : vBefore.keySet()){ - assertEquals(s, vBefore.get(s).isPlaceHolder(), vAfter.get(s).isPlaceHolder()); - assertEquals(s, vBefore.get(s).isConstant(), vAfter.get(s).isConstant()); + assertEquals(vBefore.get(s).isPlaceHolder(), vAfter.get(s).isPlaceHolder(),s); + assertEquals(vBefore.get(s).isConstant(), vAfter.get(s).isConstant(),s); } //Check save methods for(boolean withUpdaterState : new boolean[]{false, true}) { - File f2 = testDir.newFile(); + File f2 = Files.createTempFile(testDir,"some-file-2","fb").toFile(); sd.save(f2, withUpdaterState); SameDiff r2 = SameDiff.load(f2, withUpdaterState); assertEquals(varsOrig.size(), r2.variables().size()); @@ -247,8 +249,8 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { assertEquals(sd.getLossVariables(), r2.getLossVariables()); //Save via stream: - File f3 = testDir.newFile(); - try(OutputStream os = new BufferedOutputStream(new FileOutputStream(f3))){ + File f3 = Files.createTempFile(testDir,"some-file-3","fb").toFile(); + try(OutputStream os = new BufferedOutputStream(new FileOutputStream(f3))) { sd.save(os, withUpdaterState); } @@ -266,7 +268,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { @Test - public void testTrainingSerde() throws Exception { + public void testTrainingSerde(@TempDir Path testDir) throws Exception { //Ensure 2 things: //1. Training config is serialized/deserialized correctly @@ -301,12 +303,12 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { DataSet ds = new DataSet(inArr, labelArr); - for (int i = 0; i < 10; i++){ + for (int i = 0; i < 10; i++) { sd.fit(ds); } - File dir = testDir.newFolder(); + File dir = testDir.toFile(); File f = new File(dir, "samediff.bin"); sd.asFlatFile(f); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java index 9118aab84..384d6eb22 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java @@ -21,7 +21,7 @@ package org.nd4j.autodiff.samediff; import lombok.extern.slf4j.Slf4j; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.samediff.transform.GraphTransformUtil; import org.nd4j.autodiff.samediff.transform.OpPredicate; import org.nd4j.autodiff.samediff.transform.SubGraph; @@ -37,9 +37,9 @@ import org.nd4j.linalg.factory.Nd4jBackend; import java.util.Collections; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j public class GraphTransformUtilTests extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/MemoryMgrTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/MemoryMgrTest.java index 227662f12..68d6a0905 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/MemoryMgrTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/MemoryMgrTest.java @@ -20,7 +20,7 @@ package org.nd4j.autodiff.samediff; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; @@ -30,7 +30,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import java.lang.reflect.Field; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class MemoryMgrTest extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java index 0e584e320..0811f140c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java @@ -20,7 +20,7 @@ package org.nd4j.autodiff.samediff; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.linalg.BaseNd4jTest; @@ -32,8 +32,8 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class NameScopeTests extends BaseNd4jTest { @@ -101,10 +101,10 @@ public class NameScopeTests extends BaseNd4jTest { assertEquals("s1/s2/z", z.name()); assertEquals("a", a.name()); - assertTrue(add.name(), add.name().startsWith("s1/")); + assertTrue(add.name().startsWith("s1/"),add.name()); assertEquals("s1/addxy", addWithName.name()); - assertTrue(merge.name(), merge.name().startsWith("s1/s2/")); + assertTrue(merge.name().startsWith("s1/s2/"),merge.name()); assertEquals("s1/s2/mmax", mergeWithName.name()); Set allowedVarNames = new HashSet<>(Arrays.asList("x", "s1/y", "s1/s2/z", "a", @@ -116,17 +116,17 @@ public class NameScopeTests extends BaseNd4jTest { System.out.println(ops.keySet()); for(String s : ops.keySet()){ - assertTrue(s, s.startsWith("s1") || s.startsWith("s1/s2")); + assertTrue(s.startsWith("s1") || s.startsWith("s1/s2"),s); allowedOpNames.add(s); } //Check fields - Variable, SDOp, etc for(Variable v : sd.getVariables().values()){ - assertTrue(v.getVariable().name(), allowedVarNames.contains(v.getVariable().name())); + assertTrue( allowedVarNames.contains(v.getVariable().name()),v.getVariable().name()); assertEquals(v.getName(), v.getVariable().name()); if(v.getInputsForOp() != null){ for(String s : v.getInputsForOp()){ - assertTrue(s, allowedOpNames.contains(s)); + assertTrue(allowedOpNames.contains(s),s); } } @@ -164,7 +164,7 @@ public class NameScopeTests extends BaseNd4jTest { scope.close(); - assertTrue("Var with name test/argmax exists", SD.variableMap().containsKey("test/argmax")); + assertTrue(SD.variableMap().containsKey("test/argmax"),"Var with name test/argmax exists"); } @Test @@ -182,6 +182,6 @@ public class NameScopeTests extends BaseNd4jTest { scope.close(); - assertTrue("Var with name test/switch:1 exists", SD.variableMap().containsKey("test/switch:1")); + assertTrue( SD.variableMap().containsKey("test/switch:1"),"Var with name test/switch:1 exists"); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java index be7897cad..fae729e6d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java @@ -21,10 +21,11 @@ package org.nd4j.autodiff.samediff; import lombok.extern.slf4j.Slf4j; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.imports.tfgraphs.TFGraphTestZooModels; import org.nd4j.linalg.api.buffer.DataType; @@ -34,19 +35,19 @@ import org.nd4j.common.primitives.AtomicBoolean; import org.nd4j.common.resources.Resources; import java.io.File; +import java.nio.file.Path; import java.util.Collections; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; @Slf4j public class SameDiffMultiThreadTests extends BaseND4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Override public long getTimeoutMilliseconds() { @@ -93,19 +94,19 @@ public class SameDiffMultiThreadTests extends BaseND4JTest { s.release(nThreads); latch.await(); - for(int i=0; i m = sd.output(Collections.emptyMap(), "out"); INDArray outAct = m.get("out"); - assertEquals(a.toString(), outExp, outAct); + assertEquals(outExp, outAct,a.toString()); // L = sum_i (label - out)^2 //dL/dOut = 2(out - label) @@ -1048,8 +1046,8 @@ public class SameDiffTests extends BaseNd4jTest { INDArray dLdOutAct = grads.get("out"); INDArray dLdInAct = grads.get("in"); - assertEquals(a.toString(), dLdOutExp, dLdOutAct); - assertEquals(a.toString(), dLdInExp, dLdInAct); + assertEquals(dLdOutExp, dLdOutAct,a.toString()); + assertEquals(dLdInExp, dLdInAct,a.toString()); } } @@ -1650,7 +1648,7 @@ public class SameDiffTests extends BaseNd4jTest { } } - @Ignore(/*AS - 20191114 https://github.com/eclipse/deeplearning4j/issues/8393*/) + @Disabled(/*AS - 20191114 https://github.com/eclipse/deeplearning4j/issues/8393*/) @Test public void testIsStrictlyIncShape() { int nOut = 0; @@ -1694,7 +1692,7 @@ public class SameDiffTests extends BaseNd4jTest { String msg = "expandDim=" + i + ", source=" + p.getSecond(); - assertEquals(msg, out, expOut); + assertEquals(out, expOut,msg); } } } @@ -1735,7 +1733,7 @@ public class SameDiffTests extends BaseNd4jTest { String msg = "squeezeDim=" + i + ", source=" + p.getSecond(); - assertEquals(msg, out, expOut); + assertEquals(out, expOut,msg); } } } @@ -1759,7 +1757,7 @@ public class SameDiffTests extends BaseNd4jTest { String msg = "expand/Squeeze=" + i + ", source=" + p.getSecond(); - assertEquals(msg, out, inArr); //expand -> squeeze: should be opposite ops + assertEquals(out, inArr,msg); //expand -> squeeze: should be opposite ops } } } @@ -1787,7 +1785,7 @@ public class SameDiffTests extends BaseNd4jTest { String msg = "expand/Squeeze=" + i + ", source=" + p.getSecond(); - assertEquals(msg, out, inArr); //squeeze -> expand: should be opposite ops + assertEquals(out, inArr,msg); //squeeze -> expand: should be opposite ops } } } @@ -2427,7 +2425,7 @@ public class SameDiffTests extends BaseNd4jTest { sd.createGradFunction(); fail("Expected exception"); } catch (IllegalStateException e) { - assertTrue(e.getMessage(), e.getMessage().contains("No loss variables")); + assertTrue(e.getMessage().contains("No loss variables"),e.getMessage()); } SDVariable add = mean.add(sum); @@ -2445,7 +2443,7 @@ public class SameDiffTests extends BaseNd4jTest { sd.createGradFunction(); fail("Expected exception"); } catch (IllegalStateException e) { - assertTrue(e.getMessage(), e.getMessage().contains("No loss variables")); + assertTrue( e.getMessage().contains("No loss variables"),e.getMessage()); } SDVariable add = in.add(in2); @@ -2863,47 +2861,47 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable x2 = i == 0 ? sd.placeHolder("b", DataType.FLOAT, 5, 3) : sd.var("b", DataType.FLOAT, 5, 3); try { sd.placeHolder("a", DataType.FLOAT, 5, 3); - fail("Expected execption"); + fail("Expected exception"); } catch (Throwable t) { String m = t.getMessage(); assertNotNull(m); - assertTrue(m, m.contains("already exists")); + assertTrue(m.contains("already exists"),m); } try { sd.var("a", DataType.FLOAT, 1, 2); - fail("Expected execption"); + fail("Expected exception"); } catch (Throwable t) { String m = t.getMessage(); assertNotNull(m); - assertTrue(m, m.contains("already exists")); + assertTrue(m.contains("already exists"),m); } try { sd.var("a", Nd4j.zeros(1)); - fail("Expected execption"); + fail("Expected exception"); } catch (Throwable t) { String m = t.getMessage(); assertNotNull(m); - assertTrue(m, m.contains("already exists")); + assertTrue(m.contains("already exists"),m); } try { sd.var("a", LongShapeDescriptor.fromShape(new long[]{1}, DataType.FLOAT)); - fail("Expected execption"); + fail("Expected exception"); } catch (Throwable t) { String m = t.getMessage(); assertNotNull(m); - assertTrue(m, m.contains("already exists")); + assertTrue(m.contains("already exists"),m); } try { sd.constant("a", Nd4j.zeros(1)); - fail("Expected execption"); + fail("Expected exception"); } catch (Throwable t) { String m = t.getMessage(); assertNotNull(m); - assertTrue(m, m.contains("already exists")); + assertTrue(m.contains("already exists"),m); } } } @@ -2982,8 +2980,8 @@ public class SameDiffTests extends BaseNd4jTest { fail("Expected exception"); } catch (Exception t) { String msg = t.getMessage(); - assertTrue(msg, msg.contains("shape") && msg.contains("[2, 3]") && msg - .contains(Arrays.toString(v.placeholderShape()))); + assertTrue(msg.contains("shape") && msg.contains("[2, 3]") && msg + .contains(Arrays.toString(v.placeholderShape())),msg); } } @@ -2992,8 +2990,8 @@ public class SameDiffTests extends BaseNd4jTest { fail("Expected exception"); } catch (Exception t) { String msg = t.getMessage(); - assertTrue(msg, msg.contains("shape") && msg.contains("[1]") && msg - .contains(Arrays.toString(v.placeholderShape()))); + assertTrue(msg.contains("shape") && msg.contains("[1]") && msg + .contains(Arrays.toString(v.placeholderShape())),msg); } try { @@ -3001,8 +2999,8 @@ public class SameDiffTests extends BaseNd4jTest { fail("Expected exception"); } catch (Exception t) { String msg = t.getMessage(); - assertTrue(msg, msg.contains("shape") && msg.contains("[3, 4, 5]") && msg - .contains(Arrays.toString(v.placeholderShape()))); + assertTrue(msg.contains("shape") && msg.contains("[3, 4, 5]") && msg + .contains(Arrays.toString(v.placeholderShape())),msg); } } @@ -3020,7 +3018,7 @@ public class SameDiffTests extends BaseNd4jTest { sd.fit(mds); } catch (Exception t) { String msg = t.getMessage(); - assertTrue(msg, msg.contains("shape") && msg.contains("[2, 3]")); + assertTrue( msg.contains("shape") && msg.contains("[2, 3]"),msg); } } @@ -3122,7 +3120,7 @@ public class SameDiffTests extends BaseNd4jTest { Map out = sd.output((Map)null, "x", "y", "z", "tanh", "stdev"); for (Map.Entry e : out.entrySet()) { - assertEquals(e.getKey(), DataType.FLOAT, e.getValue().dataType()); + assertEquals(DataType.FLOAT, e.getValue().dataType(),e.getKey()); } assertEquals(DataType.FLOAT, x.getArr().dataType()); @@ -3141,7 +3139,7 @@ public class SameDiffTests extends BaseNd4jTest { out = sd.output((Map)null, "x", "y", "z", "tanh", "stdev"); for (Map.Entry e : out.entrySet()) { - assertEquals(e.getKey(), DataType.DOUBLE, e.getValue().dataType()); + assertEquals(DataType.DOUBLE, e.getValue().dataType(),e.getKey()); } assertEquals(DataType.DOUBLE, x.getArr().dataType()); @@ -3171,9 +3169,9 @@ public class SameDiffTests extends BaseNd4jTest { Map out = sd.output(ph, "x", "y", "xD", "yD", "a", "r"); for (Map.Entry e : out.entrySet()) { if (e.getKey().equals("x") || e.getKey().equals("y")) { - assertEquals(e.getKey(), DataType.FLOAT, e.getValue().dataType()); + assertEquals(DataType.FLOAT, e.getValue().dataType(),e.getKey()); } else { - assertEquals(e.getKey(), DataType.DOUBLE, e.getValue().dataType()); + assertEquals(DataType.DOUBLE, e.getValue().dataType(),e.getKey()); } } @@ -3193,7 +3191,7 @@ public class SameDiffTests extends BaseNd4jTest { out = sd.output(ph, "x", "y", "xD", "yD", "a", "r"); for (Map.Entry e : out.entrySet()) { - assertEquals(e.getKey(), DataType.DOUBLE, e.getValue().dataType()); + assertEquals(DataType.DOUBLE, e.getValue().dataType(),e.getKey()); } assertEquals(DataType.DOUBLE, y.getArr().dataType()); @@ -3312,7 +3310,7 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testNestedWhile() throws IOException { SameDiff sd = SameDiff.create(); SDVariable countIn = sd.constant(5); @@ -3385,7 +3383,7 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - @Ignore // casted shape is null + @Disabled // casted shape is null public void castShapeTestEmpty(){ SameDiff sd = SameDiff.create(); SDVariable x = sd.constant(Nd4j.empty(DataType.INT)); @@ -3405,7 +3403,7 @@ public class SameDiffTests extends BaseNd4jTest { fail("Expected exception"); } catch (IllegalArgumentException e){ String m = e.getMessage(); - assertTrue(m, m.contains("variable") && m.contains("empty") && m.contains("0")); + assertTrue(m.contains("variable") && m.contains("empty") && m.contains("0"),m); } try { @@ -3413,7 +3411,7 @@ public class SameDiffTests extends BaseNd4jTest { fail("Expected exception"); } catch (IllegalArgumentException e){ String m = e.getMessage().toLowerCase(); - assertTrue(m, m.contains("variable") && m.contains("empty") && m.contains("0")); + assertTrue(m.contains("variable") && m.contains("empty") && m.contains("0"),m); } } @@ -3561,7 +3559,7 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testMissingPlaceholderError(){ + public void testMissingPlaceholderError() { SameDiff sd = SameDiff.create(); @@ -3577,9 +3575,9 @@ public class SameDiffTests extends BaseNd4jTest { try { loss.eval(); fail("Exception should have been thrown"); - } catch (IllegalStateException e){ + } catch (IllegalStateException e) { String msg = e.getMessage(); - assertTrue(msg, msg.contains("\"labels\"") && msg.contains("No array was provided")); + assertTrue(msg.contains("\"labels\"") && msg.contains("No array was provided"),msg); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java index 7ff307c6e..0673429e0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java @@ -20,8 +20,8 @@ package org.nd4j.autodiff.samediff; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.Collections; import java.util.HashMap; @@ -29,7 +29,7 @@ import java.util.List; import java.util.Map; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.listeners.impl.ScoreListener; import org.nd4j.autodiff.listeners.records.History; import org.nd4j.evaluation.IEvaluation; @@ -128,7 +128,7 @@ public class SameDiffTrainingTest extends BaseNd4jTest { System.out.println(e.stats()); double acc = e.accuracy(); - assertTrue(u + " - " + acc, acc >= 0.75); + assertTrue( acc >= 0.75,u + " - " + acc); } } @@ -179,7 +179,7 @@ public class SameDiffTrainingTest extends BaseNd4jTest { double acc = e.accuracy(); - assertTrue("Accuracy bad: " + acc, acc >= 0.75); + assertTrue(acc >= 0.75,"Accuracy bad: " + acc); } @@ -234,7 +234,7 @@ public class SameDiffTrainingTest extends BaseNd4jTest { double acc = e.accuracy(); - assertTrue("Accuracy bad: " + acc, acc >= 0.75); + assertTrue(acc >= 0.75,"Accuracy bad: " + acc); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java index e8565405b..195d8eb8c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java @@ -21,9 +21,10 @@ package org.nd4j.autodiff.samediff.listeners; import org.junit.Assert; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.autodiff.listeners.checkpoint.CheckpointListener; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -37,6 +38,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.learning.config.Adam; import java.io.File; +import java.nio.file.Path; import java.util.Arrays; import java.util.HashSet; import java.util.List; @@ -44,7 +46,7 @@ import java.util.Set; import java.util.concurrent.TimeUnit; import static junit.framework.TestCase.assertTrue; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class CheckpointListenerTest extends BaseNd4jTest { @@ -57,9 +59,6 @@ public class CheckpointListenerTest extends BaseNd4jTest { return 'c'; } - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - @Override public long getTimeoutMilliseconds() { return 90000L; @@ -97,8 +96,8 @@ public class CheckpointListenerTest extends BaseNd4jTest { @Test - public void testCheckpointEveryEpoch() throws Exception { - File dir = testDir.newFolder(); + public void testCheckpointEveryEpoch(@TempDir Path testDir) throws Exception { + File dir = testDir.toFile(); SameDiff sd = getModel(); CheckpointListener l = CheckpointListener.builder(dir) @@ -131,8 +130,8 @@ public class CheckpointListenerTest extends BaseNd4jTest { } @Test - public void testCheckpointEvery5Iter() throws Exception { - File dir = testDir.newFolder(); + public void testCheckpointEvery5Iter(@TempDir Path testDir) throws Exception { + File dir = testDir.toFile(); SameDiff sd = getModel(); CheckpointListener l = CheckpointListener.builder(dir) @@ -170,8 +169,8 @@ public class CheckpointListenerTest extends BaseNd4jTest { @Test - public void testCheckpointListenerEveryTimeUnit() throws Exception { - File dir = testDir.newFolder(); + public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir) throws Exception { + File dir = testDir.toFile(); SameDiff sd = getModel(); CheckpointListener l = new CheckpointListener.Builder(dir) @@ -208,14 +207,14 @@ public class CheckpointListenerTest extends BaseNd4jTest { } } - for( int i=0; i cpNums = new HashSet<>(); Set epochNums = new HashSet<>(); for(File f2 : files){ - if(!f2.getPath().endsWith(".bin")){ + if(!f2.getPath().endsWith(".bin")) { continue; } count++; @@ -251,7 +250,7 @@ public class CheckpointListenerTest extends BaseNd4jTest { cpNums.add(epochNum); } - assertEquals(cpNums.toString(), 5, cpNums.size()); + assertEquals(5, cpNums.size(),cpNums.toString()); Assert.assertTrue(cpNums.toString(), cpNums.containsAll(Arrays.asList(2, 5, 7, 8, 9))); Assert.assertTrue(epochNums.toString(), epochNums.containsAll(Arrays.asList(5, 11, 15, 17, 19))); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ExecDebuggingListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ExecDebuggingListenerTest.java index 4fa0139a8..e4250fdd5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ExecDebuggingListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ExecDebuggingListenerTest.java @@ -20,7 +20,7 @@ package org.nd4j.autodiff.samediff.listeners; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.listeners.debugging.ExecDebuggingListener; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java index d8dc54ca2..6346f87d4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java @@ -20,7 +20,7 @@ package org.nd4j.autodiff.samediff.listeners; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.BaseListener; import org.nd4j.autodiff.listeners.Listener; @@ -59,7 +59,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class ListenerTest extends BaseNd4jTest { @@ -132,7 +132,7 @@ public class ListenerTest extends BaseNd4jTest { System.out.println("Losses: " + Arrays.toString(losses)); double acc = hist.finalTrainingEvaluations().getValue(Metric.ACCURACY); - assertTrue("Accuracy < 75%, was " + acc, acc >= 0.75); + assertTrue(acc >= 0.75,"Accuracy < 75%, was " + acc); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ProfilingListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ProfilingListenerTest.java index 7221591bc..c62e0fbd4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ProfilingListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ProfilingListenerTest.java @@ -22,9 +22,10 @@ package org.nd4j.autodiff.samediff.listeners; import org.apache.commons.io.FileUtils; import org.apache.commons.lang3.StringUtils; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.autodiff.listeners.profiler.ProfilingListener; import org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer; import org.nd4j.autodiff.samediff.SDVariable; @@ -37,11 +38,12 @@ import org.nd4j.linalg.factory.Nd4jBackend; import java.io.File; import java.nio.charset.StandardCharsets; +import java.nio.file.Path; import java.util.HashMap; import java.util.Map; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; public class ProfilingListenerTest extends BaseNd4jTest { @@ -54,11 +56,10 @@ public class ProfilingListenerTest extends BaseNd4jTest { return 'c'; } - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Test - public void testProfilingListenerSimple() throws Exception { + public void testProfilingListenerSimple(@TempDir Path testDir) throws Exception { SameDiff sd = SameDiff.create(); SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3); @@ -72,7 +73,7 @@ public class ProfilingListenerTest extends BaseNd4jTest { INDArray l = Nd4j.rand(DataType.FLOAT, 1, 2); - File dir = testDir.newFolder(); + File dir = testDir.toFile(); File f = new File(dir, "test.json"); ProfilingListener listener = ProfilingListener.builder(f) .recordAll() @@ -96,7 +97,7 @@ public class ProfilingListenerTest extends BaseNd4jTest { //5 warmup iterations, 5 profile iterations, x2 for both the op name and the op "instance" name String[] opNames = {"matmul", "add", "softmax"}; for(String s : opNames) { - assertEquals(s, 10, StringUtils.countMatches(content, s)); + assertEquals( 10, StringUtils.countMatches(content, s),s); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java index 5013f2a40..c227e66e9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java @@ -22,10 +22,11 @@ package org.nd4j.autodiff.ui; import com.google.flatbuffers.Table; import lombok.extern.slf4j.Slf4j; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.BeforeEach; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.VariableType; @@ -49,13 +50,14 @@ import org.nd4j.common.primitives.Pair; import java.io.File; import java.io.IOException; +import java.nio.file.Path; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j public class FileReadWriteTests extends BaseNd4jTest { @@ -70,10 +72,8 @@ public class FileReadWriteTests extends BaseNd4jTest { } - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - @Before + @BeforeEach public void before() { Nd4j.create(1); Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); @@ -81,12 +81,12 @@ public class FileReadWriteTests extends BaseNd4jTest { } @Test - public void testSimple() throws IOException { + public void testSimple(@TempDir Path testDir) throws IOException { SameDiff sd = SameDiff.create(); SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4); SDVariable sum = v.sum(); - File f = testDir.newFile(); + File f = testDir.toFile(); if (f.exists()) f.delete(); System.out.println(f.getAbsolutePath()); @@ -185,8 +185,8 @@ public class FileReadWriteTests extends BaseNd4jTest { } @Test - public void testNullBinLabels() throws Exception{ - File dir = testDir.newFolder(); + public void testNullBinLabels(@TempDir Path testDir) throws Exception{ + File dir = testDir.toFile(); File f = new File(dir, "temp.bin"); LogFileWriter w = new LogFileWriter(f); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java index 2b07a2769..72adb0bff 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java @@ -21,9 +21,10 @@ package org.nd4j.autodiff.ui; import com.google.flatbuffers.Table; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.autodiff.listeners.impl.UIListener; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -43,11 +44,12 @@ import org.nd4j.linalg.learning.config.Adam; import org.nd4j.common.primitives.Pair; import java.io.File; +import java.nio.file.Path; import java.util.HashMap; import java.util.List; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class UIListenerTest extends BaseNd4jTest { @@ -60,18 +62,17 @@ public class UIListenerTest extends BaseNd4jTest { return 'c'; } - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Test - public void testUIListenerBasic() throws Exception { + public void testUIListenerBasic(@TempDir Path testDir) throws Exception { Nd4j.getRandom().setSeed(12345); IrisDataSetIterator iter = new IrisDataSetIterator(150, 150); SameDiff sd = getSimpleNet(); - File dir = testDir.newFolder(); + File dir = testDir.toFile(); File f = new File(dir, "logFile.bin"); UIListener l = UIListener.builder(f) .plotLosses(1) @@ -100,13 +101,13 @@ public class UIListenerTest extends BaseNd4jTest { } @Test - public void testUIListenerContinue() throws Exception { + public void testUIListenerContinue(@TempDir Path testDir) throws Exception { IrisDataSetIterator iter = new IrisDataSetIterator(150, 150); SameDiff sd1 = getSimpleNet(); SameDiff sd2 = getSimpleNet(); - File dir = testDir.newFolder(); + File dir = testDir.toFile(); File f = new File(dir, "logFileNoContinue.bin"); f.delete(); UIListener l1 = UIListener.builder(f) @@ -191,11 +192,11 @@ public class UIListenerTest extends BaseNd4jTest { } @Test - public void testUIListenerBadContinue() throws Exception { + public void testUIListenerBadContinue(@TempDir Path testDir) throws Exception { IrisDataSetIterator iter = new IrisDataSetIterator(150, 150); SameDiff sd1 = getSimpleNet(); - File dir = testDir.newFolder(); + File dir = testDir.toFile(); File f = new File(dir, "logFile.bin"); f.delete(); UIListener l1 = UIListener.builder(f) @@ -233,8 +234,8 @@ public class UIListenerTest extends BaseNd4jTest { fail("Expected exception"); } catch (Throwable t){ String m = t.getMessage(); - assertTrue(m, m.contains("placeholder")); - assertTrue(m, m.contains("FileMode.CREATE_APPEND_NOCHECK")); + assertTrue(m.contains("placeholder"),m); + assertTrue(m.contains("FileMode.CREATE_APPEND_NOCHECK"),m); } @@ -254,8 +255,8 @@ public class UIListenerTest extends BaseNd4jTest { fail("Expected exception"); } catch (Throwable t){ String m = t.getMessage(); - assertTrue(m, m.contains("variable")); - assertTrue(m, m.contains("FileMode.CREATE_APPEND_NOCHECK")); + assertTrue(m.contains("variable"),m); + assertTrue(m.contains("FileMode.CREATE_APPEND_NOCHECK"),m); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/CustomEvaluationTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/CustomEvaluationTest.java index ec8cfbc74..3c7b93779 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/CustomEvaluationTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/CustomEvaluationTest.java @@ -20,9 +20,9 @@ package org.nd4j.evaluation; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.evaluation.custom.CustomEvaluation; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; @@ -61,7 +61,7 @@ public class CustomEvaluationTest extends BaseNd4jTest { } )); - assertEquals("Accuracy", acc, 3.0/5, 0.001); + assertEquals(acc, 3.0/5, 0.001,"Accuracy"); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java index 0c08bfe03..80dc7920a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java @@ -20,7 +20,7 @@ package org.nd4j.evaluation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.evaluation.classification.EvaluationCalibration; @@ -32,8 +32,8 @@ import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; public class EmptyEvaluationTests extends BaseNd4jTest { @@ -56,7 +56,7 @@ public class EmptyEvaluationTests extends BaseNd4jTest { e.scoreForMetric(m); fail("Expected exception"); } catch (Throwable t){ - assertTrue(t.getMessage(), t.getMessage().contains("no evaluation has been performed")); + assertTrue(t.getMessage().contains("no evaluation has been performed"),t.getMessage()); } } } @@ -70,7 +70,7 @@ public class EmptyEvaluationTests extends BaseNd4jTest { try { re.scoreForMetric(m); } catch (Throwable t){ - assertTrue(t.getMessage(), t.getMessage().contains("eval must be called")); + assertTrue(t.getMessage().contains("eval must be called"),t.getMessage()); } } } @@ -85,7 +85,7 @@ public class EmptyEvaluationTests extends BaseNd4jTest { eb.scoreForMetric(m, 0); fail("Expected exception"); } catch (Throwable t) { - assertTrue(t.getMessage(), t.getMessage().contains("eval must be called")); + assertTrue( t.getMessage().contains("eval must be called"),t.getMessage()); } } } @@ -100,7 +100,7 @@ public class EmptyEvaluationTests extends BaseNd4jTest { roc.scoreForMetric(m); fail("Expected exception"); } catch (Throwable t) { - assertTrue(t.getMessage(), t.getMessage().contains("no evaluation")); + assertTrue(t.getMessage().contains("no evaluation"),t.getMessage()); } } } @@ -115,7 +115,7 @@ public class EmptyEvaluationTests extends BaseNd4jTest { rb.scoreForMetric(m, 0); fail("Expected exception"); } catch (Throwable t) { - assertTrue(t.getMessage(), t.getMessage().contains("eval must be called")); + assertTrue(t.getMessage().contains("eval must be called"),t.getMessage()); } } } @@ -130,7 +130,7 @@ public class EmptyEvaluationTests extends BaseNd4jTest { r.scoreForMetric(m, 0); fail("Expected exception"); } catch (Throwable t) { - assertTrue(t.getMessage(), t.getMessage().contains("no data")); + assertTrue(t.getMessage().contains("no data"),t.getMessage()); } } } @@ -144,19 +144,19 @@ public class EmptyEvaluationTests extends BaseNd4jTest { ec.getResidualPlot(0); fail("Expected exception"); } catch (Throwable t) { - assertTrue(t.getMessage(), t.getMessage().contains("no data")); + assertTrue( t.getMessage().contains("no data"),t.getMessage()); } try { ec.getProbabilityHistogram(0); fail("Expected exception"); } catch (Throwable t) { - assertTrue(t.getMessage(), t.getMessage().contains("no data")); + assertTrue( t.getMessage().contains("no data"),t.getMessage()); } try { ec.getReliabilityDiagram(0); fail("Expected exception"); } catch (Throwable t) { - assertTrue(t.getMessage(), t.getMessage().contains("no data")); + assertTrue(t.getMessage().contains("no data"),t.getMessage()); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java index 4934f6c2d..c40c38678 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java @@ -20,7 +20,7 @@ package org.nd4j.evaluation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.linalg.BaseNd4jTest; @@ -33,8 +33,8 @@ import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Random; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class EvalCustomThreshold extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalJsonTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalJsonTest.java index 5cd0765f9..0d8ab24ab 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalJsonTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalJsonTest.java @@ -20,7 +20,7 @@ package org.nd4j.evaluation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.evaluation.classification.EvaluationCalibration; @@ -38,8 +38,8 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import static junit.framework.TestCase.assertNull; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class EvalJsonTest extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java index 706abcad3..25f606061 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java @@ -20,7 +20,7 @@ package org.nd4j.evaluation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; @@ -35,7 +35,7 @@ import org.nd4j.linalg.util.FeatureUtil; import java.text.DecimalFormat; import java.util.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.interval; @@ -737,8 +737,8 @@ public class EvalTest extends BaseNd4jTest { String s2 = " 2 0 0 | 1 = 1"; //Second row: predicted 0, actual 1 - 2 times String stats = e.stats(); - assertTrue(stats, stats.contains(s1)); - assertTrue(stats, stats.contains(s2)); + assertTrue(stats.contains(s1),stats); + assertTrue(stats.contains(s2),stats); } @@ -831,10 +831,10 @@ public class EvalTest extends BaseNd4jTest { //System.out.println(evals[i].stats()); - assertEquals(m, tp, tpAct); - assertEquals(m, tn, tnAct); - assertEquals(m, fp, fpAct); - assertEquals(m, fn, fnAct); + assertEquals(tp, tpAct,m); + assertEquals( tn, tnAct,m); + assertEquals(fp, fpAct,m); + assertEquals(fn, fnAct,m); } double acc = (tp+tn) / (double)(tp+fn+tn+fp); @@ -844,10 +844,10 @@ public class EvalTest extends BaseNd4jTest { for( int i=0; i { + int specCols = 5; + INDArray labels = Nd4j.ones(3); + INDArray preds = Nd4j.ones(6); + RegressionEvaluation eval = new RegressionEvaluation(specCols); + + eval.eval(labels, preds); + }); - eval.eval(labels, preds); } @Test @@ -261,7 +265,7 @@ public class RegressionEvalTest extends BaseNd4jTest { for (Metric m : Metric.values()) { double d1 = e3d.scoreForMetric(m); double d2 = e2d.scoreForMetric(m); - assertEquals(m.toString(), d2, d1, 1e-6); + assertEquals(d2, d1, 1e-6,m.toString()); } } @@ -293,7 +297,7 @@ public class RegressionEvalTest extends BaseNd4jTest { for (Metric m : Metric.values()) { double d1 = e4d.scoreForMetric(m); double d2 = e2d.scoreForMetric(m); - assertEquals(m.toString(), d2, d1, 1e-5); + assertEquals(d2, d1, 1e-5,m.toString()); } } @@ -352,7 +356,7 @@ public class RegressionEvalTest extends BaseNd4jTest { for(Metric m : Metric.values()){ double d1 = e4d_m2.scoreForMetric(m); double d2 = e2d_m2.scoreForMetric(m); - assertEquals(m.toString(), d2, d1, 1e-5); + assertEquals(d2, d1, 1e-5,m.toString()); } } @@ -387,7 +391,7 @@ public class RegressionEvalTest extends BaseNd4jTest { for(Metric m : Metric.values()){ double d1 = e4d_m1.scoreForMetric(m); double d2 = e2d_m1.scoreForMetric(m); - assertEquals(m.toString(), d2, d1, 1e-5); + assertEquals(d2, d1, 1e-5,m.toString()); } //Check per-output masking: @@ -414,7 +418,7 @@ public class RegressionEvalTest extends BaseNd4jTest { for(Metric m : Metric.values()){ double d1 = e4d_m2.scoreForMetric(m); double d2 = e2d_m2.scoreForMetric(m); - assertEquals(m.toString(), d2, d1, 1e-5); + assertEquals(d2, d1, 1e-5,m.toString()); } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/TestLegacyJsonLoading.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/TestLegacyJsonLoading.java index 95eda58d5..1aaae65d5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/TestLegacyJsonLoading.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/TestLegacyJsonLoading.java @@ -21,7 +21,7 @@ package org.nd4j.evaluation; import org.apache.commons.io.FileUtils; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.evaluation.regression.RegressionEvaluation; @@ -32,7 +32,7 @@ import org.nd4j.common.io.ClassPathResource; import java.io.File; import java.nio.charset.StandardCharsets; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestLegacyJsonLoading extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java index 96f91e359..5b5470bda 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java @@ -23,8 +23,8 @@ package org.nd4j.imports; import com.google.flatbuffers.FlatBufferBuilder; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.After; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.graph.FlatArray; @@ -37,8 +37,8 @@ import org.nd4j.nativeblas.NativeOpsHolder; import java.util.Arrays; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j @RunWith(Parameterized.class) @@ -48,7 +48,7 @@ public class ByteOrderTests extends BaseNd4jTest { super(backend); } - @After + @AfterEach public void tearDown() { NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java index 872438495..b1cb771db 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java @@ -22,8 +22,8 @@ package org.nd4j.imports; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.After; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.OpValidationSuite; @@ -39,7 +39,7 @@ import org.nd4j.nativeblas.NativeOpsHolder; import java.util.Map; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j @RunWith(Parameterized.class) @@ -49,7 +49,7 @@ public class ExecutionTests extends BaseNd4jTest { super(backend); } - @After + @AfterEach public void tearDown() { NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/NameTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/NameTests.java index f829f808e..c92370f5d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/NameTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/NameTests.java @@ -22,14 +22,14 @@ package org.nd4j.imports; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j @RunWith(Parameterized.class) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java index f415afa6f..2545dd8fe 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java @@ -21,8 +21,8 @@ package org.nd4j.imports.tfgraphs; import lombok.extern.slf4j.Slf4j; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.TrainingConfig; @@ -49,11 +49,11 @@ import java.io.File; import java.net.URL; import java.util.*; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j -@Ignore("AB 2019/05/21 - JVM Crash on linux-x86_64-cuda-9.2, linux-ppc64le-cpu - Issue #7657") +@Disabled("AB 2019/05/21 - JVM Crash on linux-x86_64-cuda-9.2, linux-ppc64le-cpu - Issue #7657") public class BERTGraphTest extends BaseNd4jTest { public BERTGraphTest(Nd4jBackend b){ @@ -156,7 +156,7 @@ public class BERTGraphTest extends BaseNd4jTest { List subGraphs = GraphTransformUtil.getSubgraphsMatching(sd, p); int subGraphCount = subGraphs.size(); - assertTrue("Subgraph count: " + subGraphCount, subGraphCount > 0); + assertTrue(subGraphCount > 0,"Subgraph count: " + subGraphCount); /* @@ -274,7 +274,7 @@ public class BERTGraphTest extends BaseNd4jTest { assertEquals(exp3, softmax.getRow(3)); } - @Test //@Ignore //AB ignored 08/04/2019 until fixed + @Test //@Disabled //AB ignored 08/04/2019 until fixed public void testBertTraining() throws Exception { String url = "https://dl4jdata.blob.core.windows.net/testresources/bert_mrpc_frozen_v1.zip"; File saveDir = new File(TFGraphTestZooModels.getBaseModelDir(), ".nd4jtests/bert_mrpc_frozen_v1"); @@ -413,10 +413,10 @@ public class BERTGraphTest extends BaseNd4jTest { double scoreAfter = lossArr.getDouble(0); String s = "Before: " + scoreBefore + "; after: " + scoreAfter; - assertTrue(s, scoreAfter < scoreBefore); + assertTrue( scoreAfter < scoreBefore,s); } - @Test @Ignore + @Test @Disabled public void writeBertUI() throws Exception { //Test used to generate graph for visualization to work out appropriate subgraph structure to replace File f = new File("C:/Temp/TF_Graphs/mrpc_output/frozen/bert_mrpc_frozen.pb"); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/CustomOpTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/CustomOpTests.java index 95862dab7..64006120e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/CustomOpTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/CustomOpTests.java @@ -21,7 +21,7 @@ package org.nd4j.imports.tfgraphs; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -29,8 +29,8 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class CustomOpTests extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/NodeReaderTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/NodeReaderTests.java index ace75d08e..268acae1c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/NodeReaderTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/NodeReaderTests.java @@ -22,13 +22,13 @@ package org.nd4j.imports.tfgraphs; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; @Slf4j public class NodeReaderTests extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java index 7fd8b44c6..faeb04d22 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java @@ -26,9 +26,9 @@ import lombok.val; import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.math.NumberUtils; -import org.junit.After; -import org.junit.Before; -import org.junit.BeforeClass; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; import org.nd4j.autodiff.execution.NativeGraphExecutioner; import org.nd4j.autodiff.execution.conf.ExecutionMode; import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; @@ -76,7 +76,7 @@ import java.nio.charset.StandardCharsets; import java.util.*; import java.util.regex.Pattern; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.imports.tfgraphs.TFGraphsSkipNodes.skipNode; @Slf4j @@ -111,17 +111,17 @@ public class TFGraphTestAllHelper { public static final DefaultGraphLoader LOADER = new DefaultGraphLoader(); - @BeforeClass + @BeforeAll public void beforeClass(){ log.info("Starting tests for class: " + getClass().getName()); } - @Before + @BeforeEach public void setup(){ Nd4j.setDataType(DataType.FLOAT); } - @After + @AfterEach public void tearDown() { NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); @@ -182,7 +182,7 @@ public class TFGraphTestAllHelper { OpValidation.collectTensorflowImportCoverage(graph); if (!execType.equals(ExecuteWith.JUST_PRINT)) { - assertTrue("No predictions to validate", predictions.keySet().size() > 0); + assertTrue(predictions.keySet().size() > 0,"No predictions to validate"); for (String outputNode : predictions.keySet()) { INDArray nd4jPred = null; INDArray tfPred = null; @@ -211,7 +211,7 @@ public class TFGraphTestAllHelper { if(maxRelErrorOverride == null) { long[] sTf = tfPred.shape(); long[] sNd4j = nd4jPred.shape(); - assertArrayEquals("Shapes for node \"" + outputNode + "\" are not equal: TF: " + Arrays.toString(sTf) + " vs SD: " + Arrays.toString(sNd4j), sTf, sNd4j); + assertArrayEquals(sTf, sNd4j,"Shapes for node \"" + outputNode + "\" are not equal: TF: " + Arrays.toString(sTf) + " vs SD: " + Arrays.toString(sNd4j)); // TODO: once we add more dtypes files - this should be removed if (tfPred.dataType() != nd4jPred.dataType()) @@ -253,7 +253,7 @@ public class TFGraphTestAllHelper { } } - assertTrue("Predictions do not match on " + modelName + ", node " + outputNode, eq); + assertTrue(eq,"Predictions do not match on " + modelName + ", node " + outputNode); } else { if(!tfPred.equalShapes(nd4jPred)) { @@ -302,8 +302,8 @@ public class TFGraphTestAllHelper { } - assertEquals( outputNode + ": " + countExceeds + " values exceed maxRelError=" + maxRelErrorOverride - + " with minAbsError=" + minAbsErrorOverride + "; largest observed relError=" + maxRE, 0, countExceeds); + assertEquals( 0, countExceeds,outputNode + ": " + countExceeds + " values exceed maxRelError=" + maxRelErrorOverride + + " with minAbsError=" + minAbsErrorOverride + "; largest observed relError=" + maxRE); } } log.info("TEST {} PASSED with {} arrays compared...", modelName, predictions.keySet().size()); @@ -383,8 +383,8 @@ public class TFGraphTestAllHelper { } - assertEquals( varName + ": " + countExceeds + " values exceed maxRelError=" + maxRelErrorOverride - + " with minAbsError=" + minAbsErrorOverride + "; largest observed relError=" + maxRE, 0, countExceeds); + assertEquals( 0, countExceeds,varName + ": " + countExceeds + " values exceed maxRelError=" + maxRelErrorOverride + + " with minAbsError=" + minAbsErrorOverride + "; largest observed relError=" + maxRE); } else { // assertEquals("Value not equal on node " + varName, tfValue, sdVal); if(tfValue.equals(sdVal)){ @@ -403,7 +403,7 @@ public class TFGraphTestAllHelper { } } - assertTrue("No intermediate variables were checked", count > 0); + assertTrue(count > 0,"No intermediate variables were checked"); } Nd4j.EPS_THRESHOLD = 1e-5; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java index 87297a7b8..8a77be345 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java @@ -22,7 +22,9 @@ package org.nd4j.imports.tfgraphs; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.*; +import org.junit.jupiter.api.*;import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; import org.junit.rules.TestWatcher; import org.junit.runner.Description; import org.junit.runner.RunWith; @@ -41,21 +43,9 @@ import java.util.*; @RunWith(Parameterized.class) @Slf4j -@Ignore("AB 2019/05/21 - JVM Crashes - Issue #7657") +@Disabled("AB 2019/05/21 - JVM Crashes - Issue #7657") public class TFGraphTestAllLibnd4j { //Note: Can't extend BaseNd4jTest here as we need no-arg constructor for parameterized tests - @Rule - public TestWatcher testWatcher = new TestWatcher() { - - @Override - protected void starting(Description description){ - log.info("TFGraphTestAllLibnd4j: Starting parameterized test: " + description.getDisplayName()); - } - - //protected void failed(Throwable e, Description description) { - //protected void succeeded(Description description) { - }; - private Map inputs; private Map predictions; private String modelName; @@ -109,18 +99,17 @@ public class TFGraphTestAllLibnd4j { //Note: Can't extend BaseNd4jTest here as "rnn/lstmblockfusedcell/.*", }; - @BeforeClass - public static void beforeClass() { + @BeforeAll public static void beforeClass() { Nd4j.setDataType(DataType.FLOAT); Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); } - @Before + @BeforeEach public void setup(){ Nd4j.setDataType(DataType.FLOAT); } - @After + @AfterEach public void tearDown() { NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index 7e5ccc517..c2a916d42 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -22,8 +22,7 @@ package org.nd4j.imports.tfgraphs; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.*; -import org.junit.rules.TestWatcher; +import org.junit.jupiter.api.*; import org.junit.runner.Description; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -40,20 +39,9 @@ import java.util.*; @Slf4j @RunWith(Parameterized.class) -@Ignore +@Disabled public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here as we need no-arg constructor for parameterized tests - @Rule - public TestWatcher testWatcher = new TestWatcher() { - - @Override - protected void starting(Description description){ - log.info("TFGraphTestAllSameDiff: Starting parameterized test: " + description.getDisplayName()); - } - - //protected void failed(Throwable e, Description description) { - //protected void succeeded(Description description) { - }; private Map inputs; private Map predictions; @@ -155,21 +143,21 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a */ private final List debugModeRegexes = Arrays.asList("fused_batch_norm/float16_nhwc"); - @BeforeClass - public static void beforeClass() { + @BeforeAll + public static void beforeClass() { Nd4j.scalar(1.0); Nd4j.setDataType(DataType.FLOAT); Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); } - @Before + @BeforeEach public void setup() { Nd4j.setDataType(DataType.FLOAT); Nd4j.getExecutioner().enableDebugMode(true); Nd4j.getExecutioner().enableVerboseMode(true); } - @After + @AfterEach public void tearDown() { } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java index e33dc407e..455734817 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java @@ -20,8 +20,11 @@ package org.nd4j.imports.tfgraphs; -import org.junit.*; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.*; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.io.TempDir; + import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.ndarray.INDArray; @@ -32,17 +35,16 @@ import org.nd4j.nativeblas.NativeOpsHolder; import java.io.File; import java.io.IOException; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; @RunWith(Parameterized.class) -@Ignore +@Disabled public class TFGraphTestList { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); //Only enable this for debugging, and leave it disabled for normal testing and CI - it prints all arrays for every execution step //Implemented internally using ExecPrintListener @@ -52,7 +54,7 @@ public class TFGraphTestList { "resize_nearest_neighbor/int32" }; - @After + @AfterEach public void tearDown() { NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); @@ -66,8 +68,8 @@ public class TFGraphTestList { public static final String MODEL_DIR = "tf_graphs/examples"; public static final String MODEL_FILENAME = "frozen_model.pb"; - @BeforeClass - public static void beforeClass(){ + @BeforeAll + public static void beforeClass() { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); } @@ -88,9 +90,9 @@ public class TFGraphTestList { } @Test - public void testOutputOnly() throws IOException { + public void testOutputOnly(@TempDir Path testDir) throws IOException { //Nd4jCpu.Environment.getInstance().setUseMKLDNN(false); - File dir = testDir.newFolder(); + File dir = testDir.toFile(); Map inputs = TFGraphTestAllHelper.inputVars(modelName, MODEL_DIR, dir); Map predictions = TFGraphTestAllHelper.outputVars(modelName, MODEL_DIR, dir); Pair precisionOverride = TFGraphTestAllHelper.testPrecisionOverride(modelName); @@ -101,10 +103,10 @@ public class TFGraphTestList { TFGraphTestAllHelper.LOADER, maxRE, minAbs, printArraysDebugging); } - @Test @Ignore - public void testAlsoIntermediate() throws IOException { + @Test @Disabled + public void testAlsoIntermediate(@TempDir Path testDir) throws IOException { //Nd4jCpu.Environment.getInstance().setUseMKLDNN(false); - File dir = testDir.newFolder(); + File dir = testDir.toFile(); Map inputs = TFGraphTestAllHelper.inputVars(modelName, MODEL_DIR, dir); TFGraphTestAllHelper.checkIntermediate(inputs, modelName, MODEL_DIR, MODEL_FILENAME, executeWith, dir, printArraysDebugging); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java index 2f0b6a2f6..f5e0f1130 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java @@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; import org.apache.commons.lang3.ArrayUtils; -import org.junit.*; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.*; +import org.junit.jupiter.api.io.TempDir; + import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.OpValidationSuite; @@ -43,6 +44,7 @@ import java.io.File; import java.io.IOException; import java.net.URL; import java.nio.charset.StandardCharsets; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -50,11 +52,11 @@ import java.util.Map; @RunWith(Parameterized.class) @Slf4j -@Ignore +@Disabled public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we need no-arg constructor for parameterized tests + @TempDir + static Path classTestDir; - @ClassRule - public static TemporaryFolder classTestDir = new TemporaryFolder(); public static final String[] IGNORE_REGEXES = { //2019/07/22 - Result value failure @@ -95,8 +97,7 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we "deeplabv3_pascal_train_aug_2018_01_04" }; - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + public static File currentTestDir; public static final File BASE_MODEL_DL_DIR = new File(getBaseModelDir(), ".nd4jtests"); @@ -204,7 +205,7 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we } } - @BeforeClass + @BeforeAll public static void beforeClass(){ Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); @@ -212,8 +213,8 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we @Parameterized.Parameters(name="{2}") public static Collection data() throws IOException { - classTestDir.create(); - File baseDir = classTestDir.newFolder(); // new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString()); + classTestDir.toFile().mkdir(); + File baseDir = classTestDir.toFile(); // new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString()); List params = TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, TFGraphTestAllHelper.ExecuteWith.SAMEDIFF, baseDir); return params; } @@ -239,7 +240,7 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we } @Test //(timeout = 360000L) - public void testOutputOnly() throws Exception { + public void testOutputOnly(@TempDir Path testDir) throws Exception { if(isPPC()){ /* Ugly hack to temporarily disable tests on PPC only on CI @@ -256,7 +257,7 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we // if(!modelName.startsWith("faster_rcnn_resnet101_coco_2018_01_28")){ // OpValidationSuite.ignoreFailing(); // } - currentTestDir = testDir.newFolder(); + currentTestDir = testDir.toFile(); // Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.NAN_PANIC); Nd4j.getMemoryManager().setAutoGcWindow(2000); @@ -269,7 +270,7 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we Double maxRE = 1e-3; Double minAbs = 1e-4; - currentTestDir = testDir.newFolder(); + currentTestDir = testDir.toFile(); log.info("----- SameDiff Exec: {} -----", modelName); TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, TFGraphTestAllHelper.ExecuteWith.SAMEDIFF, LOADER, maxRE, minAbs, false); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java index ed6db6c7f..17a2cd3b2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java @@ -22,11 +22,12 @@ package org.nd4j.imports.tfgraphs; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.BaseNd4jTest; @@ -38,13 +39,14 @@ import org.nd4j.common.io.ClassPathResource; import java.io.File; import java.nio.charset.StandardCharsets; +import java.nio.file.Path; import java.util.*; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j -@Ignore +@Disabled public class ValidateZooModelPredictions extends BaseNd4jTest { public ValidateZooModelPredictions(Nd4jBackend backend) { @@ -56,10 +58,9 @@ public class ValidateZooModelPredictions extends BaseNd4jTest { return 'c'; } - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - @Before + + @BeforeEach public void before() { Nd4j.create(1); Nd4j.setDataType(DataType.DOUBLE); @@ -72,7 +73,7 @@ public class ValidateZooModelPredictions extends BaseNd4jTest { } @Test - public void testMobilenetV1() throws Exception { + public void testMobilenetV1(@TempDir Path testDir) throws Exception { if(TFGraphTestZooModels.isPPC()){ /* Ugly hack to temporarily disable tests on PPC only on CI @@ -84,7 +85,7 @@ public class ValidateZooModelPredictions extends BaseNd4jTest { OpValidationSuite.ignoreFailing(); } - TFGraphTestZooModels.currentTestDir = testDir.newFolder(); + TFGraphTestZooModels.currentTestDir = testDir.toFile(); //Load model String path = "tf_graphs/zoo_models/mobilenet_v1_0.5_128/tf_model.txt"; @@ -137,7 +138,7 @@ public class ValidateZooModelPredictions extends BaseNd4jTest { @Test - public void testResnetV2() throws Exception { + public void testResnetV2(@TempDir Path testDir) throws Exception { if(TFGraphTestZooModels.isPPC()){ /* Ugly hack to temporarily disable tests on PPC only on CI @@ -149,7 +150,7 @@ public class ValidateZooModelPredictions extends BaseNd4jTest { OpValidationSuite.ignoreFailing(); } - TFGraphTestZooModels.currentTestDir = testDir.newFolder(); + TFGraphTestZooModels.currentTestDir = testDir.toFile(); //Load model String path = "tf_graphs/zoo_models/resnetv2_imagenet_frozen_graph/tf_model.txt"; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java index 786eb6185..efb6b5820 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java @@ -22,10 +22,10 @@ package org.nd4j.imports; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.autodiff.execution.conf.ExecutionMode; @@ -62,11 +62,11 @@ import java.util.HashMap; import java.util.Map; import java.util.stream.Collectors; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j -@Ignore +@Disabled @RunWith(Parameterized.class) public class TensorFlowImportTest extends BaseNd4jTest { private static ExecutorConfiguration configuration = ExecutorConfiguration.builder() @@ -86,11 +86,11 @@ public class TensorFlowImportTest extends BaseNd4jTest { return 'c'; } - @Before + @BeforeEach public void setUp() { } - @After + @AfterEach public void tearDown() { NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); @@ -161,7 +161,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void importGraph1() throws Exception { SameDiff graph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/max_add_2.pb.txt").getInputStream()); @@ -184,7 +184,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test - @Ignore + @Disabled public void importGraph2() throws Exception { SameDiff graph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensorflow_inception_graph.pb").getInputStream()); @@ -193,7 +193,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test - @Ignore + @Disabled public void importGraph3() throws Exception { SameDiff graph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/max_log_reg.pb.txt").getInputStream()); @@ -201,7 +201,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testImportIris() throws Exception { SameDiff graph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/train_iris.pb").getInputStream()); assertNotNull(graph); @@ -210,7 +210,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test - @Ignore + @Disabled public void importGraph4() throws Exception { SameDiff graph = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/max_multiply.pb.txt").getInputStream()); @@ -302,7 +302,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test - @Ignore + @Disabled public void testWeirdConvImport() { val tg = TFGraphMapper.importGraph(new File("/home/agibsonccc/code/raver_tfimport_test1/profiling_conv.pb.txt")); assertNotNull(tg); @@ -335,7 +335,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test - @Ignore + @Disabled public void testIntermediateStridedSlice1() throws Exception { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_slice.pb.txt").getInputStream()); @@ -411,7 +411,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testIntermediateTensorArraySimple1() throws Exception { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array.pb.txt").getInputStream()); @@ -438,7 +438,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testIntermediateTensorArrayLoop1() throws Exception { val input = Nd4j.linspace(1, 10, 10, DataType.FLOAT).reshape(5, 2); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array_loop.pb.txt").getInputStream()); @@ -677,7 +677,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { val ownName = func.getOwnName(); String outName = func.outputVariables()[0].name(); - assertTrue("Missing ownName: [" + ownName +"]",variables.containsKey(ownName)); + assertTrue(variables.containsKey(ownName),"Missing ownName: [" + ownName +"]"); assertEquals(ownName, outName); } } @@ -835,7 +835,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testProfConv() throws Exception { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new File("/home/raver119/develop/workspace/models/profiling_conv.pb.txt")); @@ -845,7 +845,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testCrash_119_matrix_diag() throws Exception { Nd4j.create(1); @@ -864,7 +864,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testCrash_119_tensor_dot_misc() throws Exception { Nd4j.create(1); @@ -881,7 +881,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testCrash_119_transpose() throws Exception { Nd4j.create(1); @@ -898,7 +898,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test - //@Ignore + //@Disabled public void testCrash_119_simpleif_0() throws Exception { Nd4j.create(1); @@ -915,7 +915,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testCrash_119_ae_00() throws Exception { Nd4j.create(1); @@ -930,7 +930,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testCrash_119_expand_dim() throws Exception { Nd4j.create(1); @@ -945,7 +945,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test - //@Ignore + //@Disabled public void testCrash_119_reduce_dim_false() throws Exception { Nd4j.create(1); @@ -957,7 +957,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test - //@Ignore + //@Disabled public void testCrash_119_reduce_dim_true() throws Exception { Nd4j.create(1); @@ -1069,9 +1069,12 @@ public class TensorFlowImportTest extends BaseNd4jTest { assertNotNull(tg); } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testNonFrozenGraph1() throws Exception { - val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/unfrozen_simple_ae.pb").getInputStream()); + assertThrows(ND4JIllegalStateException.class,() -> { + val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/unfrozen_simple_ae.pb").getInputStream()); + + }); } @Test @@ -1091,7 +1094,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testRandomGraph3() throws Exception { val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/assert_equal/3,4_3,4_float32/frozen_model.pb").getInputStream()); assertNotNull(tg); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TestReverse.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TestReverse.java index 6e4920bcf..16e6de0ff 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TestReverse.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TestReverse.java @@ -20,7 +20,7 @@ package org.nd4j.imports; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java index 026c3608c..4f742ba09 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java @@ -21,8 +21,8 @@ package org.nd4j.imports.listeners; import org.apache.commons.io.FileUtils; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.ndarray.INDArray; @@ -34,11 +34,11 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -@Ignore +@Disabled public class ImportModelDebugger { @Test - @Ignore + @Disabled public void doTest(){ main(new String[0]); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/AveragingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/AveragingTests.java index 3b4854e3c..ebef6af1d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/AveragingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/AveragingTests.java @@ -21,9 +21,9 @@ package org.nd4j.linalg; import lombok.extern.slf4j.Slf4j; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.buffer.DataType; @@ -35,7 +35,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j @RunWith(Parameterized.class) @@ -50,12 +50,12 @@ public class AveragingTests extends BaseNd4jTest { this.initialType = Nd4j.dataType(); } - @Before + @BeforeEach public void setUp() { DataTypeUtil.setDTypeForContext(DataType.DOUBLE); } - @After + @AfterEach public void shutUp() { DataTypeUtil.setDTypeForContext(initialType); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java index 22985fd0f..c1061f1a6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java @@ -21,7 +21,7 @@ package org.nd4j.linalg; import lombok.extern.slf4j.Slf4j; -import org.junit.Before; +import org.junit.jupiter.api.BeforeEach; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.common.config.ND4JClassLoading; @@ -81,7 +81,7 @@ public abstract class BaseNd4jTest extends BaseND4JTest { return ret; } - @Before + @BeforeEach public void beforeTest2(){ Nd4j.factory().setOrder(ordering()); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/DataTypeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/DataTypeTest.java index f99f4c7de..5f01c8526 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/DataTypeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/DataTypeTest.java @@ -22,7 +22,7 @@ package org.nd4j.linalg; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.buffer.DataType; @@ -32,7 +32,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import java.io.*; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @RunWith(Parameterized.class) @Slf4j @@ -61,7 +61,7 @@ public class DataTypeTest extends BaseNd4jTest { val ois = new ObjectInputStream(bios); try { val in2 = (INDArray) ois.readObject(); - assertEquals("Failed for data type [" + type + "]", in1, in2); + assertEquals( in1, in2,"Failed for data type [" + type + "]"); } catch (Exception e) { throw new RuntimeException("Failed for data type [" + type + "]", e); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/InputValidationTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/InputValidationTests.java index b8ff2095a..f2c8b5419 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/InputValidationTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/InputValidationTests.java @@ -20,14 +20,14 @@ package org.nd4j.linalg; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.fail; @RunWith(Parameterized.class) public class InputValidationTests extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java index 72e9b8fc7..e9175a0cf 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java @@ -23,7 +23,7 @@ package org.nd4j.linalg; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.lang3.RandomUtils; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.buffer.DataType; @@ -43,8 +43,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.*; @Slf4j @@ -223,7 +222,7 @@ public class LoneTest extends BaseNd4jTest { for (int e = 0; e < 32; e++) { val tad = res.tensorAlongDimension(e, 1, 2); - assertEquals("Failed for TAD [" + e + "]",(double) e, tad.meanNumber().doubleValue(), 1e-5); + assertEquals((double) e, tad.meanNumber().doubleValue(), 1e-5,"Failed for TAD [" + e + "]"); assertEquals((double) e, tad.getDouble(0), 1e-5); } } @@ -256,13 +255,16 @@ public class LoneTest extends BaseNd4jTest { // log.info("p50: {}; avg: {};", times.get(times.size() / 2), time); } - @Test(expected = Exception.class) + @Test() public void checkIllegalElementOps() { - INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5); - INDArray B = A.dup().reshape(2, 2, 5); + assertThrows(Exception.class,() -> { + INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5); + INDArray B = A.dup().reshape(2, 2, 5); + + //multiplication of arrays of different rank should throw exception + INDArray C = A.mul(B); + }); - //multiplication of arrays of different rank should throw exception - INDArray C = A.mul(B); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/MmulBug.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/MmulBug.java index 05cfeaed1..5511c9b6d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/MmulBug.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/MmulBug.java @@ -20,13 +20,13 @@ package org.nd4j.linalg; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class MmulBug extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java index 634be157d..59736d936 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java @@ -23,8 +23,8 @@ package org.nd4j.linalg; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.buffer.DataBuffer; @@ -51,7 +51,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; /** * NDArrayTests for fortran ordering @@ -243,7 +243,7 @@ public class NDArrayTestsFortran extends BaseNd4jTest { INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, false); assertEquals(sorted[1], sorted2); INDArray shouldIndex = Nd4j.create(new double[] {1, 1, 0, 0}, new long[] {2, 2}); - assertEquals(getFailureMessage(), shouldIndex, sorted[0]); + assertEquals(shouldIndex, sorted[0],getFailureMessage()); } @Test @@ -262,7 +262,7 @@ public class NDArrayTestsFortran extends BaseNd4jTest { INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, true); assertEquals(sorted[1], sorted2); INDArray shouldIndex = Nd4j.create(new double[] {0, 0, 1, 1}, new long[] {2, 2}); - assertEquals(getFailureMessage(), shouldIndex, sorted[0]); + assertEquals(shouldIndex, sorted[0],getFailureMessage()); } @Test @@ -319,13 +319,13 @@ public class NDArrayTestsFortran extends BaseNd4jTest { public void testDivide() { INDArray two = Nd4j.create(new float[] {2, 2, 2, 2}); INDArray div = two.div(two); - assertEquals(getFailureMessage(), Nd4j.ones(DataType.FLOAT, 4), div); + assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage()); INDArray half = Nd4j.create(new float[] {0.5f, 0.5f, 0.5f, 0.5f}, new long[] {2, 2}); INDArray divi = Nd4j.create(new float[] {0.3f, 0.6f, 0.9f, 0.1f}, new long[] {2, 2}); INDArray assertion = Nd4j.create(new float[] {1.6666666f, 0.8333333f, 0.5555556f, 5}, new long[] {2, 2}); INDArray result = half.div(divi); - assertEquals(getFailureMessage(), assertion, result); + assertEquals( assertion, result,getFailureMessage()); } @@ -334,7 +334,7 @@ public class NDArrayTestsFortran extends BaseNd4jTest { INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f}); INDArray sigmoid = Transforms.sigmoid(n, false); - assertEquals(getFailureMessage(), assertion, sigmoid); + assertEquals( assertion, sigmoid,getFailureMessage()); } @@ -343,7 +343,7 @@ public class NDArrayTestsFortran extends BaseNd4jTest { INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4}); INDArray neg = Transforms.neg(n); - assertEquals(getFailureMessage(), assertion, neg); + assertEquals(assertion, neg,getFailureMessage()); } @@ -353,12 +353,12 @@ public class NDArrayTestsFortran extends BaseNd4jTest { INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4}); double sim = Transforms.cosineSim(vec1, vec2); - assertEquals(getFailureMessage(), 1, sim, 1e-1); + assertEquals(1, sim, 1e-1,getFailureMessage()); INDArray vec3 = Nd4j.create(new float[] {0.2f, 0.3f, 0.4f, 0.5f}); INDArray vec4 = Nd4j.create(new float[] {0.6f, 0.7f, 0.8f, 0.9f}); sim = Transforms.cosineSim(vec3, vec4); - assertEquals(getFailureMessage(), 0.98, sim, 1e-1); + assertEquals(0.98, sim, 1e-1,getFailureMessage()); } @@ -597,7 +597,7 @@ public class NDArrayTestsFortran extends BaseNd4jTest { INDArray innerProduct = n.mmul(transposed); INDArray scalar = Nd4j.scalar(385.0).reshape(1,1); - assertEquals(getFailureMessage(), scalar, innerProduct); + assertEquals(scalar, innerProduct,getFailureMessage()); } @@ -651,7 +651,7 @@ public class NDArrayTestsFortran extends BaseNd4jTest { INDArray five = Nd4j.ones(5); five.addi(five.dup()); INDArray twos = Nd4j.valueArrayOf(5, 2); - assertEquals(getFailureMessage(), twos, five); + assertEquals(twos, five,getFailureMessage()); } @@ -664,7 +664,7 @@ public class NDArrayTestsFortran extends BaseNd4jTest { INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}}); INDArray test = arr.mmul(arr.transpose()); - assertEquals(getFailureMessage(), assertion, test); + assertEquals(assertion, test,getFailureMessage()); } @@ -675,7 +675,7 @@ public class NDArrayTestsFortran extends BaseNd4jTest { Nd4j.exec(new PrintVariable(newSlice)); log.info("Slice: {}", newSlice); n.putSlice(0, newSlice); - assertEquals(getFailureMessage(), newSlice, n.slice(0)); + assertEquals( newSlice, n.slice(0),getFailureMessage()); } @@ -683,7 +683,7 @@ public class NDArrayTestsFortran extends BaseNd4jTest { public void testRowVectorMultipleIndices() { INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4); linear.putScalar(new long[] {0, 1}, 1); - assertEquals(getFailureMessage(), linear.getDouble(0, 1), 1, 1e-1); + assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage()); } @@ -1004,7 +1004,7 @@ public class NDArrayTestsFortran extends BaseNd4jTest { INDArray nClone = n1.add(n2); assertEquals(Nd4j.scalar(3), nClone); INDArray n1PlusN2 = n1.add(n2); - assertFalse(getFailureMessage(), n1PlusN2.equals(n1)); + assertFalse(n1PlusN2.equals(n1),getFailureMessage()); INDArray n3 = Nd4j.scalar(3.0); INDArray n4 = Nd4j.scalar(4.0); @@ -1029,7 +1029,7 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testTensorDot() { INDArray oneThroughSixty = Nd4j.arange(60).reshape('f', 3, 4, 5).castTo(DataType.DOUBLE); INDArray oneThroughTwentyFour = Nd4j.arange(24).reshape('f', 4, 3, 2).castTo(DataType.DOUBLE); @@ -1081,12 +1081,12 @@ public class NDArrayTestsFortran extends BaseNd4jTest { INDArray dupc = in.dup('c'); INDArray dupf = in.dup('f'); - assertEquals(msg, in, dup); - assertEquals(msg, dup.ordering(), (char) Nd4j.order()); - assertEquals(msg, dupc.ordering(), 'c'); - assertEquals(msg, dupf.ordering(), 'f'); - assertEquals(msg, in, dupc); - assertEquals(msg, in, dupf); + assertEquals(in, dup,msg); + assertEquals(dup.ordering(), (char) Nd4j.order(),msg); + assertEquals(dupc.ordering(), 'c',msg); + assertEquals(dupf.ordering(), 'f',msg); + assertEquals( in, dupc,msg); + assertEquals(in, dupf,msg); count++; } } @@ -1104,12 +1104,12 @@ public class NDArrayTestsFortran extends BaseNd4jTest { INDArray dupf = Shape.toOffsetZeroCopy(in, 'f'); INDArray dupany = Shape.toOffsetZeroCopyAnyOrder(in); - assertEquals(msg + ": " + cnt, in, dup); - assertEquals(msg, in, dupc); - assertEquals(msg, in, dupf); - assertEquals(msg, dupc.ordering(), 'c'); - assertEquals(msg, dupf.ordering(), 'f'); - assertEquals(msg, in, dupany); + assertEquals( in, dup,msg + ": " + cnt); + assertEquals(in, dupc,msg); + assertEquals(in, dupf,msg); + assertEquals(dupc.ordering(), 'c',msg); + assertEquals(dupf.ordering(), 'f',msg); + assertEquals( in, dupany,msg); assertEquals(dup.offset(), 0); assertEquals(dupc.offset(), 0); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 674bd0ba2..660ef4a8e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -20,26 +20,19 @@ package org.nd4j.linalg; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; - import lombok.extern.slf4j.Slf4j; import lombok.val; import lombok.var; import org.apache.commons.io.FilenameUtils; import org.apache.commons.math3.stat.descriptive.rank.Percentile; import org.apache.commons.math3.util.FastMath; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.common.io.ClassPathResource; @@ -141,6 +134,7 @@ import java.io.ObjectOutputStream; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.Arrays; @@ -149,6 +143,8 @@ import java.util.HashSet; import java.util.Iterator; import java.util.List; +import static org.junit.jupiter.api.Assertions.*; + /** * NDArrayTests * @@ -161,8 +157,6 @@ public class Nd4jTestsC extends BaseNd4jTest { DataType initialType; Level1 l1; - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); public Nd4jTestsC(Nd4jBackend backend) { super(backend); @@ -175,7 +169,7 @@ public class Nd4jTestsC extends BaseNd4jTest { return 90000; } - @Before + @BeforeEach public void before() throws Exception { Nd4j.setDataType(DataType.DOUBLE); Nd4j.getRandom().setSeed(123); @@ -183,28 +177,28 @@ public class Nd4jTestsC extends BaseNd4jTest { Nd4j.getExecutioner().enableVerboseMode(false); } - @After + @AfterEach public void after() throws Exception { Nd4j.setDataType(initialType); } @Test public void testArangeNegative() { - INDArray arr = Nd4j.arange(-2,2).castTo(DataType.DOUBLE); - INDArray assertion = Nd4j.create(new double[]{-2, -1, 0, 1}); - assertEquals(assertion,arr); + INDArray arr = Nd4j.arange(-2,2).castTo(DataType.DOUBLE); + INDArray assertion = Nd4j.create(new double[]{-2, -1, 0, 1}); + assertEquals(assertion,arr); } @Test public void testTri() { - INDArray assertion = Nd4j.create(new double[][]{ - {1,1,1,0,0}, - {1,1,1,1,0}, - {1,1,1,1,1} - }); + INDArray assertion = Nd4j.create(new double[][]{ + {1,1,1,0,0}, + {1,1,1,1,0}, + {1,1,1,1,1} + }); - INDArray tri = Nd4j.tri(3,5,2); - assertEquals(assertion,tri); + INDArray tri = Nd4j.tri(3,5,2); + assertEquals(assertion,tri); } @@ -225,8 +219,8 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test public void testDiag() { - INDArray diag = Nd4j.diag(Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(4,1)); - assertArrayEquals(new long[] {4,4},diag.shape()); + INDArray diag = Nd4j.diag(Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(4,1)); + assertArrayEquals(new long[] {4,4},diag.shape()); } @Test @@ -243,9 +237,9 @@ public class Nd4jTestsC extends BaseNd4jTest { double dRow = row.getDouble(0); String s = String.valueOf(i); - assertEquals(s, d, d2, 0.0); - assertEquals(s, d, dRowDup, 0.0); //Fails - assertEquals(s, d, dRow, 0.0); //Fails + assertEquals(d, d2, 0.0,s); + assertEquals(d, dRowDup, 0.0,s); //Fails + assertEquals(d, dRow, 0.0,s); //Fails } } @@ -260,11 +254,11 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testSerialization() throws Exception { + public void testSerialization(@TempDir Path testDir) throws Exception { Nd4j.getRandom().setSeed(12345); INDArray arr = Nd4j.rand(1, 20); - File dir = testDir.newFolder(); + File dir = testDir.toFile(); String outPath = FilenameUtils.concat(dir.getAbsolutePath(), "dl4jtestserialization.bin"); @@ -290,10 +284,13 @@ public class Nd4jTestsC extends BaseNd4jTest { } - @Ignore // with broadcastables mechanic it'll be ok - @Test(expected = IllegalStateException.class) + @Disabled // with broadcastables mechanic it'll be ok + @Test public void testShapeEqualsOnElementWise() { - Nd4j.ones(10000, 1).sub(Nd4j.ones(1, 2)); + assertThrows(IllegalStateException.class,() -> { + Nd4j.ones(10000, 1).sub(Nd4j.ones(1, 2)); + + }); } @Test @@ -336,7 +333,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - @Ignore //temporary till libnd4j implements general broadcasting + @Disabled //temporary till libnd4j implements general broadcasting public void testAutoBroadcastAdd() { INDArray left = Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(2,1,2,1); INDArray right = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,1,5); @@ -364,9 +361,9 @@ public class Nd4jTestsC extends BaseNd4jTest { n.divi(Nd4j.scalar(1.0d)); n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3}); - assertEquals(getFailureMessage(), 27, n.sumNumber().doubleValue(), 1e-1); + assertEquals(27, n.sumNumber().doubleValue(), 1e-1,getFailureMessage()); INDArray a = n.slice(2); - assertEquals(getFailureMessage(), true, Arrays.equals(new long[] {3, 3}, a.shape())); + assertEquals( true, Arrays.equals(new long[] {3, 3}, a.shape()),getFailureMessage()); } @@ -465,23 +462,23 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}}); INDArray test = arr.mmul(arr.transpose()); - assertEquals(getFailureMessage(), assertion, test); + assertEquals(assertion, test,getFailureMessage()); } @Test - @Ignore + @Disabled public void testMmulOp() throws Exception { INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}}); INDArray z = Nd4j.create(2, 2); INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}}); MMulTranspose mMulTranspose = MMulTranspose.builder() - .transposeB(true) - .build(); + .transposeB(true) + .build(); DynamicCustomOp op = new Mmul(arr, arr, z, mMulTranspose); Nd4j.getExecutioner().execAndReturn(op); - - assertEquals(getFailureMessage(), assertion, z); + + assertEquals(assertion, z,getFailureMessage()); } @@ -491,7 +488,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray row1 = oneThroughFour.getRow(1).dup(); oneThroughFour.subiRowVector(row1); INDArray result = Nd4j.create(new double[] {-2, -2, 0, 0}, new long[] {2, 2}); - assertEquals(getFailureMessage(), result, oneThroughFour); + assertEquals(result, oneThroughFour,getFailureMessage()); } @@ -658,7 +655,7 @@ public class Nd4jTestsC extends BaseNd4jTest { for (int i = 0; i < 50; i++) { double second = arr.sumNumber().doubleValue(); assertEquals(assertion, second, 1e-1); - assertEquals(String.valueOf(i), first, second, 1e-2); + assertEquals( first, second, 1e-2,String.valueOf(i)); } } @@ -1020,7 +1017,7 @@ public class Nd4jTestsC extends BaseNd4jTest { // System.out.println(merged.data()); // System.out.println(expected); - assertEquals("Failed for [" + order + "] order", expected, merged); + assertEquals( expected, merged,"Failed for [" + order + "] order"); } } @@ -1051,7 +1048,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testSumAlongDim1sEdgeCases() { val shapes = new long[][] { //Standard case: @@ -1132,14 +1129,12 @@ public class Nd4jTestsC extends BaseNd4jTest { double[] cBuffer = resC.data().asDouble(); double[] fBuffer = resF.data().asDouble(); for (int i = 0; i < length; i++) { - assertTrue("c buffer value at [" + i + "]=" + cBuffer[i] + ", expected 0 or 1; dimension = " - + alongDimension + ", rank = " + rank + ", shape=" + Arrays.toString(shape), - cBuffer[i] == 0.0 || cBuffer[i] == 1.0); + assertTrue(cBuffer[i] == 0.0 || cBuffer[i] == 1.0,"c buffer value at [" + i + "]=" + cBuffer[i] + ", expected 0 or 1; dimension = " + + alongDimension + ", rank = " + rank + ", shape=" + Arrays.toString(shape)); } for (int i = 0; i < length; i++) { - assertTrue("f buffer value at [" + i + "]=" + fBuffer[i] + ", expected 0 or 1; dimension = " - + alongDimension + ", rank = " + rank + ", shape=" + Arrays.toString(shape), - fBuffer[i] == 0.0 || fBuffer[i] == 1.0); + assertTrue(fBuffer[i] == 0.0 || fBuffer[i] == 1.0,"f buffer value at [" + i + "]=" + fBuffer[i] + ", expected 0 or 1; dimension = " + + alongDimension + ", rank = " + rank + ", shape=" + Arrays.toString(shape)); } } } @@ -1183,7 +1178,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray row1 = oneThroughFour.getRow(1); row1.addi(1); INDArray result = Nd4j.create(new double[] {1, 2, 4, 5}, new long[] {2, 2}); - assertEquals(getFailureMessage(), result, oneThroughFour); + assertEquals(result, oneThroughFour,getFailureMessage()); } @@ -1196,8 +1191,8 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray linear = test.reshape(-1); linear.putScalar(2, 6); linear.putScalar(3, 7); - assertEquals(getFailureMessage(), 6, linear.getFloat(2), 1e-1); - assertEquals(getFailureMessage(), 7, linear.getFloat(3), 1e-1); + assertEquals(6, linear.getFloat(2), 1e-1,getFailureMessage()); + assertEquals(7, linear.getFloat(3), 1e-1,getFailureMessage()); } @@ -1235,8 +1230,8 @@ public class Nd4jTestsC extends BaseNd4jTest { Nd4j.gemm(a.dup('f'), b.dup('f'), result2, false, false, 1.0, 0.0); Nd4j.gemm(a, b, result3, false, false, 1.0, 0.0); - assertEquals(msg, result1, result2); - assertEquals(msg, result1, result3); // Fails here + assertEquals(result1, result2,msg); + assertEquals(result1, result3,msg); // Fails here } } } @@ -1547,7 +1542,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4}); INDArray neg = Transforms.neg(n); - assertEquals(getFailureMessage(), assertion, neg); + assertEquals(assertion, neg,getFailureMessage()); } @@ -1559,13 +1554,13 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray n = Nd4j.create(new double[] {1, 2, 3, 4}); double assertion = 5.47722557505; double norm3 = n.norm2Number().doubleValue(); - assertEquals(getFailureMessage(), assertion, norm3, 1e-1); + assertEquals(assertion, norm3, 1e-1,getFailureMessage()); INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2}); INDArray row1 = row.getRow(1); double norm2 = row1.norm2Number().doubleValue(); double assertion2 = 5.0f; - assertEquals(getFailureMessage(), assertion2, norm2, 1e-1); + assertEquals(assertion2, norm2, 1e-1,getFailureMessage()); Nd4j.setDataType(initialType); } @@ -1576,14 +1571,14 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); float assertion = 5.47722557505f; float norm3 = n.norm2Number().floatValue(); - assertEquals(getFailureMessage(), assertion, norm3, 1e-1); + assertEquals(assertion, norm3, 1e-1,getFailureMessage()); INDArray row = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2}); INDArray row1 = row.getRow(1); float norm2 = row1.norm2Number().floatValue(); float assertion2 = 5.0f; - assertEquals(getFailureMessage(), assertion2, norm2, 1e-1); + assertEquals(assertion2, norm2, 1e-1,getFailureMessage()); } @@ -1594,7 +1589,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4}); double sim = Transforms.cosineSim(vec1, vec2); - assertEquals(getFailureMessage(), 1, sim, 1e-1); + assertEquals(1, sim, 1e-1,getFailureMessage()); INDArray vec3 = Nd4j.create(new float[] {0.2f, 0.3f, 0.4f, 0.5f}); INDArray vec4 = Nd4j.create(new float[] {0.6f, 0.7f, 0.8f, 0.9f}); @@ -1609,14 +1604,14 @@ public class Nd4jTestsC extends BaseNd4jTest { double assertion = 2; INDArray answer = Nd4j.create(new double[] {2, 4, 6, 8}); INDArray scal = Nd4j.getBlasWrapper().scal(assertion, answer); - assertEquals(getFailureMessage(), answer, scal); + assertEquals(answer, scal,getFailureMessage()); INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2}); INDArray row1 = row.getRow(1); double assertion2 = 5.0; INDArray answer2 = Nd4j.create(new double[] {15, 20}); INDArray scal2 = Nd4j.getBlasWrapper().scal(assertion2, row1); - assertEquals(getFailureMessage(), answer2, scal2); + assertEquals(answer2, scal2,getFailureMessage()); } @@ -1994,17 +1989,17 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray innerProduct = n.mmul(transposed); INDArray scalar = Nd4j.scalar(385.0).reshape(1,1); - assertEquals(getFailureMessage(), scalar, innerProduct); + assertEquals(scalar, innerProduct,getFailureMessage()); INDArray outerProduct = transposed.mmul(n); - assertEquals(getFailureMessage(), true, Shape.shapeEquals(new long[] {10, 10}, outerProduct.shape())); + assertEquals(true, Shape.shapeEquals(new long[] {10, 10}, outerProduct.shape()),getFailureMessage()); INDArray three = Nd4j.create(new double[] {3, 4}); INDArray test = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2}); INDArray sliceRow = test.slice(0).getRow(1); - assertEquals(getFailureMessage(), three, sliceRow); + assertEquals(three, sliceRow,getFailureMessage()); INDArray twoSix = Nd4j.create(new double[] {2, 6}, new long[] {2, 1}); INDArray threeTwoSix = three.mmul(twoSix); @@ -2032,7 +2027,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray k1 = n1.transpose(); INDArray testVectorVector = k1.mmul(n1); - assertEquals(getFailureMessage(), vectorVector, testVectorVector); + assertEquals(vectorVector, testVectorVector,getFailureMessage()); } @@ -2116,15 +2111,18 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(linear.getDouble(0, 1), 1, 1e-1); } - @Test(expected = IllegalArgumentException.class) + @Test() public void testSize() { - INDArray arr = Nd4j.create(4, 5); + assertThrows(IllegalArgumentException.class,() -> { + INDArray arr = Nd4j.create(4, 5); - for (int i = 0; i < 6; i++) { - //This should fail for i >= 2, but doesn't + for (int i = 0; i < 6; i++) { + //This should fail for i >= 2, but doesn't // System.out.println(arr.size(i)); - arr.size(i); - } + arr.size(i); + } + }); + } @Test @@ -2257,7 +2255,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testTensorDot() { INDArray oneThroughSixty = Nd4j.arange(60).reshape(3, 4, 5).castTo(DataType.DOUBLE); INDArray oneThroughTwentyFour = Nd4j.arange(24).reshape(4, 3, 2).castTo(DataType.DOUBLE); @@ -2919,10 +2917,10 @@ public class Nd4jTestsC extends BaseNd4jTest { public void testMeans() { INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray mean1 = a.mean(1); - assertEquals(getFailureMessage(), Nd4j.create(new double[] {1.5, 3.5}), mean1); - assertEquals(getFailureMessage(), Nd4j.create(new double[] {2, 3}), a.mean(0)); - assertEquals(getFailureMessage(), 2.5, Nd4j.linspace(1, 4, 4, DataType.DOUBLE).meanNumber().doubleValue(), 1e-1); - assertEquals(getFailureMessage(), 2.5, a.meanNumber().doubleValue(), 1e-1); + assertEquals(Nd4j.create(new double[] {1.5, 3.5}), mean1,getFailureMessage()); + assertEquals(Nd4j.create(new double[] {2, 3}), a.mean(0),getFailureMessage()); + assertEquals(2.5, Nd4j.linspace(1, 4, 4, DataType.DOUBLE).meanNumber().doubleValue(), 1e-1,getFailureMessage()); + assertEquals(2.5, a.meanNumber().doubleValue(), 1e-1,getFailureMessage()); } @@ -2930,9 +2928,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test public void testSums() { INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); - assertEquals(getFailureMessage(), Nd4j.create(new double[] {3, 7}), a.sum(1)); - assertEquals(getFailureMessage(), Nd4j.create(new double[] {4, 6}), a.sum(0)); - assertEquals(getFailureMessage(), 10, a.sumNumber().doubleValue(), 1e-1); + assertEquals(Nd4j.create(new double[] {3, 7}), a.sum(1),getFailureMessage()); + assertEquals(Nd4j.create(new double[] {4, 6}), a.sum(0),getFailureMessage()); + assertEquals(10, a.sumNumber().doubleValue(), 1e-1,getFailureMessage()); } @@ -3244,8 +3242,8 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(dup.ordering(), ordering()); assertEquals(dupc.ordering(), 'c'); assertEquals(dupf.ordering(), 'f'); - assertEquals(msg, in, dupc); - assertEquals(msg, in, dupf); + assertEquals(in, dupc,msg); + assertEquals(in, dupf,msg); } } @@ -3264,12 +3262,12 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray dupf = Shape.toOffsetZeroCopy(in, 'f'); INDArray dupany = Shape.toOffsetZeroCopyAnyOrder(in); - assertEquals(msg, in, dup); - assertEquals(msg, in, dupc); - assertEquals(msg, in, dupf); - assertEquals(msg, dupc.ordering(), 'c'); - assertEquals(msg, dupf.ordering(), 'f'); - assertEquals(msg, in, dupany); + assertEquals(in, dup,msg); + assertEquals(in, dupc,msg); + assertEquals(in, dupf,msg); + assertEquals(dupc.ordering(), 'c',msg); + assertEquals(dupf.ordering(), 'f',msg); + assertEquals(in, dupany,msg); assertEquals(dup.offset(), 0); assertEquals(dupc.offset(), 0); @@ -3283,7 +3281,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void largeInstantiation() { Nd4j.ones((1024 * 1024 * 511) + (1024 * 1024 - 1)); // Still works; this can even be called as often as I want, allowing me even to spill over on disk Nd4j.ones((1024 * 1024 * 511) + (1024 * 1024)); // Crashes @@ -3330,7 +3328,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - @Ignore //not relevant anymore + @Disabled //not relevant anymore public void testAssignMixedC() { int[] shape1 = {3, 2, 2, 2, 2, 2}; int[] shape2 = {12, 8}; @@ -3617,29 +3615,44 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(assertion, result); } - @Test(expected = IllegalStateException.class) + @Test() public void testPullRowsValidation1() { - Nd4j.pullRows(Nd4j.create(10, 10), 2, new int[] {0, 1, 2}); + assertThrows(IllegalStateException.class,() -> { + Nd4j.pullRows(Nd4j.create(10, 10), 2, new int[] {0, 1, 2}); + + }); } - @Test(expected = IllegalStateException.class) + @Test() public void testPullRowsValidation2() { - Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, -1, 2}); + assertThrows(IllegalStateException.class,() -> { + Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, -1, 2}); + + }); } - @Test(expected = IllegalStateException.class) + @Test() public void testPullRowsValidation3() { - Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, 1, 10}); + assertThrows(IllegalStateException.class,() -> { + Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, 1, 10}); + + }); } - @Test(expected = IllegalStateException.class) + @Test() public void testPullRowsValidation4() { - Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2, 3}); + assertThrows(IllegalStateException.class,() -> { + Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2, 3}); + + }); } - @Test(expected = IllegalStateException.class) + @Test() public void testPullRowsValidation5() { - Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2}, 'e'); + assertThrows(IllegalStateException.class,() -> { + Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2}, 'e'); + + }); } @@ -3859,7 +3872,7 @@ public class Nd4jTestsC extends BaseNd4jTest { for (int i = 0; i < 3; i++) { INDArray subset = result12.tensorAlongDimension(i, 1, 2);//result12.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.all()); - assertEquals("Failed for subset [" + i + "] orders [" + orderArr + "/" + orderbc + "]", bc12, subset); + assertEquals( bc12, subset,"Failed for subset [" + i + "] orders [" + orderArr + "/" + orderbc + "]"); } } } @@ -4362,7 +4375,7 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(4.5, result.meanNumber().doubleValue(), 0.01); for (int i = 0; i < 10; i++) { - assertEquals("Failed on iteration " + i, result, arrays.get(i)); + assertEquals(result, arrays.get(i),"Failed on iteration " + i); } } @@ -4382,7 +4395,7 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(4.5, result.meanNumber().doubleValue(), 0.01); for (int i = 0; i < 10; i++) { - assertEquals("Failed on iteration " + i, result, arrays.get(i)); + assertEquals(result, arrays.get(i),"Failed on iteration " + i); } } @@ -4446,7 +4459,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } for (int i = 0; i < out.size(); i++) { - assertEquals("Failed at iteration: [" + i + "]", out.get(i), comp.get(i)); + assertEquals(out.get(i), comp.get(i),"Failed at iteration: [" + i + "]"); } } @@ -4524,7 +4537,7 @@ public class Nd4jTestsC extends BaseNd4jTest { // log.info("Comparison ----------------------------------------------"); for (int i = 0; i < initial.rows(); i++) { val row = result.getRow(i); - assertEquals("Failed at row " + i, exp, row); + assertEquals(exp, row,"Failed at row " + i); // log.info("-------------------"); } } @@ -4550,7 +4563,7 @@ public class Nd4jTestsC extends BaseNd4jTest { for (int i = 0; i < initial.rows(); i++) { - assertEquals("Failed at row " + i, exp, result.getRow(i)); + assertEquals(exp, result.getRow(i),"Failed at row " + i); } } @@ -4573,7 +4586,7 @@ public class Nd4jTestsC extends BaseNd4jTest { for (int i = 0; i < initial.rows(); i++) { - assertEquals("Failed at row " + i,exp, result.getRow(i)); + assertEquals(exp, result.getRow(i),"Failed at row " + i); } } @@ -4595,7 +4608,7 @@ public class Nd4jTestsC extends BaseNd4jTest { for (int i = 0; i < initial.rows(); i++) { - assertEquals("Failed at row " + i, exp, result.getRow(i)); + assertEquals( exp, result.getRow(i),"Failed at row " + i); } } @@ -4631,7 +4644,7 @@ public class Nd4jTestsC extends BaseNd4jTest { for (int i = 0; i < haystack.rows(); i++) { val row = haystack.getRow(i).dup(); double res = Nd4j.getExecutioner().execAndReturn(new CosineDistance(row, needle)).z().getDouble(0); - assertEquals("Failed at " + i, reduced.getDouble(i), res, 1e-5); + assertEquals(reduced.getDouble(i), res, 1e-5,"Failed at " + i); } } @@ -4663,7 +4676,7 @@ public class Nd4jTestsC extends BaseNd4jTest { for (int i = 0; i < initial.rows(); i++) { double res = Nd4j.getExecutioner().execAndReturn(new CosineSimilarity(initial.getRow(i).dup(), needle)) .getFinalResult().doubleValue(); - assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001); + assertEquals( reduced.getDouble(i), res, 0.001,"Failed at " + i); } } @@ -4681,7 +4694,7 @@ public class Nd4jTestsC extends BaseNd4jTest { for (int i = 0; i < initial.rows(); i++) { double res = Nd4j.getExecutioner().execAndReturn(new ManhattanDistance(initial.getRow(i).dup(), needle)) .getFinalResult().doubleValue(); - assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001); + assertEquals(reduced.getDouble(i), res, 0.001,"Failed at " + i); } } @@ -4700,7 +4713,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray x = initial.getRow(i).dup(); double res = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(x, needle)).getFinalResult() .doubleValue(); - assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001); + assertEquals( reduced.getDouble(i), res, 0.001,"Failed at " + i); } } @@ -4719,7 +4732,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray x = initial.getRow(i).dup(); double res = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(x, needle)).getFinalResult() .doubleValue(); - assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001); + assertEquals(reduced.getDouble(i), res, 0.001,"Failed at " + i); } } @@ -4739,18 +4752,21 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray x = initial.getRow(i).dup(); double res = Nd4j.getExecutioner().execAndReturn(new CosineSimilarity(x, needle)).getFinalResult() .doubleValue(); - assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001); + assertEquals(reduced.getDouble(i), res, 0.001,"Failed at " + i); } } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testTadReduce3_5() { - INDArray initial = Nd4j.create(5, 10); - for (int i = 0; i < initial.rows(); i++) { - initial.getRow(i).assign(i + 1); - } - INDArray needle = Nd4j.create(2, 10).assign(1.0); - INDArray reduced = Nd4j.getExecutioner().exec(new EuclideanDistance(initial, needle, 1)); + assertThrows(ND4JIllegalStateException.class,() -> { + INDArray initial = Nd4j.create(5, 10); + for (int i = 0; i < initial.rows(); i++) { + initial.getRow(i).assign(i + 1); + } + INDArray needle = Nd4j.create(2, 10).assign(1.0); + INDArray reduced = Nd4j.getExecutioner().exec(new EuclideanDistance(initial, needle, 1)); + }); + } @@ -4769,7 +4785,7 @@ public class Nd4jTestsC extends BaseNd4jTest { double res = Nd4j.getExecutioner() .execAndReturn(new ManhattanDistance(initial.tensorAlongDimension(i, 1, 2).dup(), needle)) .getFinalResult().doubleValue(); - assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001); + assertEquals(reduced.getDouble(i), res, 0.001,"Failed at " + i); } } @@ -4850,9 +4866,9 @@ public class Nd4jTestsC extends BaseNd4jTest { for (int r = 0; r < x.rows(); r++) { if (r == 4) { - assertEquals("Failed at " + r, 0.0, res.getDouble(r), 1e-5); + assertEquals(0.0, res.getDouble(r), 1e-5,"Failed at " + r); } else { - assertEquals("Failed at " + r, 2.0 / 6, res.getDouble(r), 1e-5); + assertEquals(2.0 / 6, res.getDouble(r), 1e-5,"Failed at " + r); } } } @@ -4884,7 +4900,7 @@ public class Nd4jTestsC extends BaseNd4jTest { double res = result.getDouble(x, y); double exp = Transforms.euclideanDistance(rowX, initialY.getRow(y).dup()); - assertEquals("Failed for [" + x + ", " + y + "]", exp, res, 0.001); + assertEquals(exp, res, 0.001,"Failed for [" + x + ", " + y + "]"); } } } @@ -4914,7 +4930,7 @@ public class Nd4jTestsC extends BaseNd4jTest { double res = result.getDouble(x, y); double exp = Transforms.manhattanDistance(rowX, initialY.getRow(y).dup()); - assertEquals("Failed for [" + x + ", " + y + "]", exp, res, 0.001); + assertEquals( exp, res, 0.001,"Failed for [" + x + ", " + y + "]"); } } } @@ -4944,7 +4960,7 @@ public class Nd4jTestsC extends BaseNd4jTest { double res = result.getDouble(x, y); double exp = Transforms.manhattanDistance(rowX, initialY.getRow(y).dup()); - assertEquals("Failed for [" + x + ", " + y + "]", exp, res, 0.001); + assertEquals(exp, res, 0.001,"Failed for [" + x + ", " + y + "]"); } } } @@ -4976,7 +4992,7 @@ public class Nd4jTestsC extends BaseNd4jTest { double res = result.getDouble(x, y); double exp = Transforms.euclideanDistance(rowX, initialY.getRow(y).dup()); - assertEquals("Failed for [" + x + ", " + y + "]", exp, res, 0.001); + assertEquals(exp, res, 0.001,"Failed for [" + x + ", " + y + "]"); } } } @@ -5006,7 +5022,7 @@ public class Nd4jTestsC extends BaseNd4jTest { double res = result.getDouble(x, y); double exp = Transforms.euclideanDistance(colX, initialY.getColumn(y).dup()); - assertEquals("Failed for [" + x + ", " + y + "]", exp, res, 0.001); + assertEquals(exp, res, 0.001,"Failed for [" + x + ", " + y + "]"); } } } @@ -5036,7 +5052,7 @@ public class Nd4jTestsC extends BaseNd4jTest { double res = result.getDouble(x, y); double exp = Transforms.manhattanDistance(colX, initialY.getColumn(y).dup()); - assertEquals("Failed for [" + x + ", " + y + "]", exp, res, 0.001); + assertEquals(exp, res, 0.001,"Failed for [" + x + ", " + y + "]"); } } } @@ -5066,7 +5082,7 @@ public class Nd4jTestsC extends BaseNd4jTest { double res = result.getDouble(x, y); double exp = Transforms.cosineDistance(colX, initialY.getColumn(y).dup()); - assertEquals("Failed for [" + x + ", " + y + "]", exp, res, 0.001); + assertEquals(exp, res, 0.001,"Failed for [" + x + ", " + y + "]"); } } } @@ -5095,7 +5111,7 @@ public class Nd4jTestsC extends BaseNd4jTest { double res = result.getDouble(x, y); double exp = Transforms.manhattanDistance(colX, initialY.getColumn(y).dup()); - assertEquals("Failed for [" + x + ", " + y + "]", exp, res, 0.001); + assertEquals(exp, res, 0.001,"Failed for [" + x + ", " + y + "]"); } } } @@ -5120,7 +5136,7 @@ public class Nd4jTestsC extends BaseNd4jTest { double res = result.getDouble(x, y); double exp = Transforms.cosineSim(rowX, initialY.getRow(y).dup()); - assertEquals("Failed for [" + x + ", " + y + "]", exp, res, 0.001); + assertEquals(exp, res, 0.001,"Failed for [" + x + ", " + y + "]"); } } } @@ -5464,7 +5480,7 @@ public class Nd4jTestsC extends BaseNd4jTest { val d = res.getRow(r).dup(); assertArrayEquals(e, d.toDoubleVector(), 1e-5); - assertEquals("Failed at " + r, exp1, d); + assertEquals(exp1, d,"Failed at " + r); } } @@ -5526,8 +5542,8 @@ public class Nd4jTestsC extends BaseNd4jTest { for (int r = 0; r < array.rows(); r++) { val jrow = res.getRow(r).toFloatVector(); //log.info("jrow: {}", jrow); - assertArrayEquals("Failed at " + r, jexp, jrow, 1e-5f); - assertEquals("Failed at " + r, exp1, res.getRow(r)); + assertArrayEquals(jexp, jrow, 1e-5f,"Failed at " + r); + assertEquals( exp1, res.getRow(r),"Failed at " + r); //assertArrayEquals("Failed at " + r, exp1.data().asDouble(), res.getRow(r).dup().data().asDouble(), 1e-5); } } @@ -5544,7 +5560,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray res = Nd4j.sort(array, 1, false); for (int r = 0; r < array.rows(); r++) { - assertEquals("Failed at " + r, exp1, res.getRow(r).dup()); + assertEquals(exp1, res.getRow(r).dup(),"Failed at " + r); } } @@ -5713,7 +5729,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testLogExpSum1() { INDArray matrix = Nd4j.create(3, 3); for (int r = 0; r < matrix.rows(); r++) { @@ -5728,7 +5744,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testLogExpSum2() { INDArray row = Nd4j.create(new double[]{1, 2, 3}); @@ -5941,13 +5957,16 @@ public class Nd4jTestsC extends BaseNd4jTest { } } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testReshapeFailure() { - val a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2,2); - val b = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2,2); - val score = a.mmul(b); - val reshaped1 = score.reshape(2,100); - val reshaped2 = score.reshape(2,1); + assertThrows(ND4JIllegalStateException.class,() -> { + val a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2,2); + val b = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2,2); + val score = a.mmul(b); + val reshaped1 = score.reshape(2,100); + val reshaped2 = score.reshape(2,1); + }); + } @@ -6031,32 +6050,38 @@ public class Nd4jTestsC extends BaseNd4jTest { assertArrayEquals(new long[]{3, 2}, newShape.shape()); } - @Test(expected = IllegalStateException.class) + @Test() public void testTranspose1() { - val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6}); + assertThrows(IllegalStateException.class,() -> { + val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6}); - assertArrayEquals(new long[]{6}, vector.shape()); - assertArrayEquals(new long[]{1}, vector.stride()); + assertArrayEquals(new long[]{6}, vector.shape()); + assertArrayEquals(new long[]{1}, vector.stride()); - val transposed = vector.transpose(); + val transposed = vector.transpose(); + + assertArrayEquals(vector.shape(), transposed.shape()); + }); - assertArrayEquals(vector.shape(), transposed.shape()); } - @Test(expected = IllegalStateException.class) + @Test() public void testTranspose2() { - val scalar = Nd4j.scalar(2.f); + assertThrows(IllegalStateException.class,() -> { + val scalar = Nd4j.scalar(2.f); - assertArrayEquals(new long[]{}, scalar.shape()); - assertArrayEquals(new long[]{}, scalar.stride()); + assertArrayEquals(new long[]{}, scalar.shape()); + assertArrayEquals(new long[]{}, scalar.stride()); - val transposed = scalar.transpose(); + val transposed = scalar.transpose(); + + assertArrayEquals(scalar.shape(), transposed.shape()); + }); - assertArrayEquals(scalar.shape(), transposed.shape()); } @Test - //@Ignore + //@Disabled public void testMatmul_128by256() { val mA = Nd4j.create(128, 156).assign(1.0f); val mB = Nd4j.create(156, 256).assign(1.0f); @@ -6312,11 +6337,14 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(exp1, out1); } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testBadReduce3Call() { - val x = Nd4j.create(400,20); - val y = Nd4j.ones(1, 20); - x.distance2(y); + assertThrows(ND4JIllegalStateException.class,() -> { + val x = Nd4j.create(400,20); + val y = Nd4j.ones(1, 20); + x.distance2(y); + }); + } @@ -6351,7 +6379,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray out = Nd4j.concat(0, arr1, arr2); Nd4j.getExecutioner().commit(); INDArray exp = Nd4j.create(new double[][]{{1, 2}, {3, 4}}); - assertEquals(String.valueOf(order), exp, out); + assertEquals(exp, out,String.valueOf(order)); } } @@ -6422,7 +6450,7 @@ public class Nd4jTestsC extends BaseNd4jTest { for(int i=0;i<100;i++){ INDArray out2 = fwd(in, W, b); //l.activate(inToLayer1, false, LayerWorkspaceMgr.noWorkspaces()); - assertEquals("Failed at iteration [" + String.valueOf(i) + "]", out, out2); + assertEquals( out, out2,"Failed at iteration [" + String.valueOf(i) + "]"); } } @@ -6441,7 +6469,7 @@ public class Nd4jTestsC extends BaseNd4jTest { int cnt = 0; for (val f : fArray) - assertTrue("Failed for element [" + cnt++ +"]",f > 0.0f); + assertTrue(f > 0.0f,"Failed for element [" + cnt++ +"]"); } @@ -6460,7 +6488,7 @@ public class Nd4jTestsC extends BaseNd4jTest { int cnt = 0; for (val f : fArray) - assertTrue("Failed for element [" + cnt++ +"]",f > 0.0f); + assertTrue(f > 0.0f,"Failed for element [" + cnt++ +"]"); } @Test @@ -6956,7 +6984,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testMatmul_vs_tf() throws Exception { // uncomment this line to initialize & propagate sgemm/dgemm pointer @@ -7033,14 +7061,17 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(ez, z); } - @Test(expected = IllegalStateException.class) + @Test() public void testBroadcastInvalid(){ - INDArray arr1 = Nd4j.ones(3,4,1); + assertThrows(IllegalStateException.class,() -> { + INDArray arr1 = Nd4j.ones(3,4,1); + + //Invalid op: y must match x/z dimensions 0 and 2 + INDArray arrInvalid = Nd4j.create(3,12); + Nd4j.getExecutioner().exec(new BroadcastMulOp(arr1, arrInvalid, arr1, 0, 2)); + fail("Excepted exception on invalid input"); + }); - //Invalid op: y must match x/z dimensions 0 and 2 - INDArray arrInvalid = Nd4j.create(3,12); - Nd4j.getExecutioner().exec(new BroadcastMulOp(arr1, arrInvalid, arr1, 0, 2)); - fail("Excepted exception on invalid input"); } @Test @@ -7173,7 +7204,7 @@ public class Nd4jTestsC extends BaseNd4jTest { default: throw new RuntimeException(String.valueOf(i)); } - assertArrayEquals(String.valueOf(i), expShape, out.shape()); + assertArrayEquals(expShape, out.shape(),String.valueOf(i)); } } @@ -7204,7 +7235,7 @@ public class Nd4jTestsC extends BaseNd4jTest { target.put(targetIndexes, source); final INDArray expected = Nd4j.concat(0, Nd4j.ones(shapeSource), Nd4j.zeros(diffShape)); - assertEquals("Expected array to be set!", expected, target); + assertEquals(expected, target,"Expected array to be set!"); } } @@ -7281,17 +7312,20 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(exp, array); } - @Test(expected = IllegalStateException.class) + @Test() public void testScatterUpdateShortcut_f1() { - val array = Nd4j.create(DataType.FLOAT, 5, 2); - val updates = Nd4j.createFromArray(new float[][] {{1,1}, {2,2}, {3, 3}}); - val indices = Nd4j.createFromArray(new int[]{1, 2, 3}); - val exp = Nd4j.createFromArray(new float[][] {{0,0}, {1,1}, {2,2}, {3, 3}, {0,0}}); + assertThrows(IllegalStateException.class,() -> { + val array = Nd4j.create(DataType.FLOAT, 5, 2); + val updates = Nd4j.createFromArray(new float[][] {{1,1}, {2,2}, {3, 3}}); + val indices = Nd4j.createFromArray(new int[]{1, 2, 3}); + val exp = Nd4j.createFromArray(new float[][] {{0,0}, {1,1}, {2,2}, {3, 3}, {0,0}}); - assertArrayEquals(exp.shape(), array.shape()); - Nd4j.scatterUpdate(ScatterUpdate.UpdateOp.ADD, array, indices, updates, 0); + assertArrayEquals(exp.shape(), array.shape()); + Nd4j.scatterUpdate(ScatterUpdate.UpdateOp.ADD, array, indices, updates, 0); + + assertEquals(exp, array); + }); - assertEquals(exp, array); } @Test @@ -7413,7 +7447,7 @@ public class Nd4jTestsC extends BaseNd4jTest { arr2.assign(arr1); fail("Expected exception"); } catch (IllegalStateException e){ - assertTrue(e.getMessage(), e.getMessage().contains("shape")); + assertTrue( e.getMessage().contains("shape"),e.getMessage()); } } @@ -7432,13 +7466,13 @@ public class Nd4jTestsC extends BaseNd4jTest { String str = from + " -> " + to; - assertEquals(str, from, emptyFrom.dataType()); - assertTrue(str, emptyFrom.isEmpty()); - assertEquals(str,0, emptyFrom.length()); + assertEquals(from, emptyFrom.dataType(),str); + assertTrue(emptyFrom.isEmpty(),str); + assertEquals(0, emptyFrom.length(),str); - assertEquals(str, to, emptyTo.dataType()); - assertTrue(str, emptyTo.isEmpty()); - assertEquals(str,0, emptyTo.length()); + assertEquals(to, emptyTo.dataType(),str); + assertTrue(emptyTo.isEmpty(),str); + assertEquals(0, emptyTo.length(),str); } } } @@ -7501,7 +7535,7 @@ public class Nd4jTestsC extends BaseNd4jTest { stepped = Nd4j.linspace(DataType.DOUBLE, lower, step, 10); for (int i = 0; i < 10; ++i) { - assertEquals(lower + i * step, stepped.getDouble(i), 1e-5); + assertEquals(lower + i * step, stepped.getDouble(i), 1e-5); } } @@ -7605,7 +7639,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray arr4 = arr1a.reshape('c', true, 4,1); fail("Expected exception"); } catch (ND4JIllegalStateException e){ - assertTrue(e.getMessage(), e.getMessage().contains("Unable to reshape array as view")); + assertTrue(e.getMessage().contains("Unable to reshape array as view"),e.getMessage()); } } @@ -7646,10 +7680,13 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(exp, out); //Failing here } - @Test(expected = IllegalArgumentException.class) + @Test() public void testPullRowsFailure() { - val idxs = new int[]{0,2,3,4}; - val out = Nd4j.pullRows(Nd4j.createFromArray(0.0, 1.0, 2.0, 3.0, 4.0), 0, idxs); + assertThrows(IllegalArgumentException.class,() -> { + val idxs = new int[]{0,2,3,4}; + val out = Nd4j.pullRows(Nd4j.createFromArray(0.0, 1.0, 2.0, 3.0, 4.0), 0, idxs); + }); + } @Test @@ -7740,20 +7777,26 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(exp1, out1); //This is OK } - @Test(expected = IllegalArgumentException.class) + @Test() public void testPutRowValidation() { - val matrix = Nd4j.create(5, 10); - val row = Nd4j.create(25); + assertThrows(IllegalArgumentException.class,() -> { + val matrix = Nd4j.create(5, 10); + val row = Nd4j.create(25); + + matrix.putRow(1, row); + }); - matrix.putRow(1, row); } - @Test(expected = IllegalArgumentException.class) + @Test() public void testPutColumnValidation() { - val matrix = Nd4j.create(5, 10); - val column = Nd4j.create(25); + assertThrows(IllegalArgumentException.class,() -> { + val matrix = Nd4j.create(5, 10); + val column = Nd4j.create(25); + + matrix.putColumn(1, column); + }); - matrix.putColumn(1, column); } @Test @@ -7834,7 +7877,7 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(scalarRank2, scalarRank2.dup()); } - //@Ignore // https://github.com/eclipse/deeplearning4j/issues/7632 + //@Disabled // https://github.com/eclipse/deeplearning4j/issues/7632 @Test public void testGetWhereINDArray() { INDArray input = Nd4j.create(new double[] { 1, -3, 4, 8, -2, 5 }); @@ -7855,10 +7898,10 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testType1() throws IOException { + public void testType1(@TempDir Path testDir) throws IOException { for (int i = 0; i < 10; ++i) { INDArray in1 = Nd4j.rand(DataType.DOUBLE, new int[]{100, 100}); - File dir = testDir.newFolder(); + File dir = testDir.toFile(); ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File(dir,"test.bin"))); oos.writeObject(in1); @@ -7896,10 +7939,10 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testType2() throws IOException { + public void testType2(@TempDir Path testDir) throws IOException { for (int i = 0; i < 10; ++i) { INDArray in1 = Nd4j.ones(DataType.UINT16); - File dir = testDir.newFolder(); + File dir = testDir.toFile(); ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File(dir, "test1.bin"))); oos.writeObject(in1); @@ -7916,7 +7959,7 @@ public class Nd4jTestsC extends BaseNd4jTest { for (int i = 0; i < 10; ++i) { INDArray in1 = Nd4j.ones(DataType.UINT32); - File dir = testDir.newFolder(); + File dir = testDir.toFile(); ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File(dir, "test2.bin"))); oos.writeObject(in1); @@ -7933,7 +7976,7 @@ public class Nd4jTestsC extends BaseNd4jTest { for (int i = 0; i < 10; ++i) { INDArray in1 = Nd4j.ones(DataType.UINT64); - File dir = testDir.newFolder(); + File dir = testDir.toFile(); ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File(dir, "test3.bin"))); oos.writeObject(in1); @@ -8043,7 +8086,7 @@ public class Nd4jTestsC extends BaseNd4jTest { public void mmulToScalar() { final INDArray arr1 = Nd4j.create(new float[] {1,2,3}).reshape(1,3); final INDArray arr2 = arr1.reshape(3,1); - assertEquals("Incorrect type!", DataType.FLOAT, arr1.mmul(arr2).dataType()); + assertEquals( DataType.FLOAT, arr1.mmul(arr2).dataType(),"Incorrect type!"); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java index 5a93069c3..52bb00738 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java @@ -20,9 +20,9 @@ package org.nd4j.linalg; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.buffer.DataType; @@ -38,7 +38,7 @@ import org.slf4j.LoggerFactory; import java.util.List; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; @@ -56,12 +56,12 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest { } - @Before + @BeforeEach public void before() throws Exception { DataTypeUtil.setDTypeForContext(DataType.DOUBLE); } - @After + @AfterEach public void after() throws Exception { DataTypeUtil.setDTypeForContext(initialType); } @@ -106,14 +106,14 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest { String errorMsgtf = getGemmErrorMsg(i, j, true, false, a, b, p1T, p2); String errorMsgtt = getGemmErrorMsg(i, j, true, true, a, b, p1T, p2T); //System.out.println((String.format("Running iteration %d %d %d %d", i, j, k, m))); - assertTrue(errorMsgff, CheckUtil.checkGemm(p1.getFirst(), p2.getFirst(), cff, false, false, a, - b, 1e-4, 1e-6)); - assertTrue(errorMsgft, CheckUtil.checkGemm(p1.getFirst(), p2T.getFirst(), cft, false, true, a, - b, 1e-4, 1e-6)); - assertTrue(errorMsgtf, CheckUtil.checkGemm(p1T.getFirst(), p2.getFirst(), ctf, true, false, a, - b, 1e-4, 1e-6)); - assertTrue(errorMsgtt, CheckUtil.checkGemm(p1T.getFirst(), p2T.getFirst(), ctt, true, true, a, - b, 1e-4, 1e-6)); + assertTrue( CheckUtil.checkGemm(p1.getFirst(), p2.getFirst(), cff, false, false, a, + b, 1e-4, 1e-6),errorMsgff); + assertTrue(CheckUtil.checkGemm(p1.getFirst(), p2T.getFirst(), cft, false, true, a, + b, 1e-4, 1e-6),errorMsgft); + assertTrue(CheckUtil.checkGemm(p1T.getFirst(), p2.getFirst(), ctf, true, false, a, + b, 1e-4, 1e-6),errorMsgtf); + assertTrue( CheckUtil.checkGemm(p1T.getFirst(), p2T.getFirst(), ctt, true, true, a, + b, 1e-4, 1e-6),errorMsgtt); //Also: Confirm that if the C array is uninitialized and beta is 0.0, we don't have issues like 0*NaN = NaN if (b == 0.0) { @@ -122,14 +122,14 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest { ctf.assign(Double.NaN); ctt.assign(Double.NaN); - assertTrue(errorMsgff, CheckUtil.checkGemm(p1.getFirst(), p2.getFirst(), cff, false, false, - a, b, 1e-4, 1e-6)); - assertTrue(errorMsgft, CheckUtil.checkGemm(p1.getFirst(), p2T.getFirst(), cft, false, true, - a, b, 1e-4, 1e-6)); - assertTrue(errorMsgtf, CheckUtil.checkGemm(p1T.getFirst(), p2.getFirst(), ctf, true, false, - a, b, 1e-4, 1e-6)); - assertTrue(errorMsgtt, CheckUtil.checkGemm(p1T.getFirst(), p2T.getFirst(), ctt, true, true, - a, b, 1e-4, 1e-6)); + assertTrue( CheckUtil.checkGemm(p1.getFirst(), p2.getFirst(), cff, false, false, + a, b, 1e-4, 1e-6),errorMsgff); + assertTrue( CheckUtil.checkGemm(p1.getFirst(), p2T.getFirst(), cft, false, true, + a, b, 1e-4, 1e-6),errorMsgft); + assertTrue(CheckUtil.checkGemm(p1T.getFirst(), p2.getFirst(), ctf, true, false, + a, b, 1e-4, 1e-6),errorMsgtf); + assertTrue(CheckUtil.checkGemm(p1T.getFirst(), p2T.getFirst(), ctt, true, true, + a, b, 1e-4, 1e-6),errorMsgtt); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java index 19690d56a..a45cebc75 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java @@ -22,9 +22,9 @@ package org.nd4j.linalg; import org.apache.commons.math3.linear.BlockRealMatrix; import org.apache.commons.math3.linear.RealMatrix; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.buffer.DataType; @@ -41,7 +41,7 @@ import org.slf4j.LoggerFactory; import java.util.List; import java.util.Random; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @RunWith(Parameterized.class) public class Nd4jTestsComparisonFortran extends BaseNd4jTest { @@ -57,14 +57,14 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { } - @Before + @BeforeEach public void before() throws Exception { DataTypeUtil.setDTypeForContext(DataType.DOUBLE); Nd4j.getRandom().setSeed(SEED); } - @After + @AfterEach public void after() throws Exception { DataTypeUtil.setDTypeForContext(initialType); } @@ -94,7 +94,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { Pair p1 = first.get(i); Pair p2 = second.get(j); String errorMsg = getTestWithOpsErrorMsg(i, j, "mmul", p1, p2); - assertTrue(errorMsg, CheckUtil.checkMmul(p1.getFirst(), p2.getFirst(), 1e-4, 1e-6)); + assertTrue(CheckUtil.checkMmul(p1.getFirst(), p2.getFirst(), 1e-4, 1e-6),errorMsg); } } } @@ -141,14 +141,14 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { String errorMsgtf = getGemmErrorMsg(i, j, true, false, a, b, p1T, p2); String errorMsgtt = getGemmErrorMsg(i, j, true, true, a, b, p1T, p2T); - assertTrue(errorMsgff, CheckUtil.checkGemm(p1.getFirst(), p2.getFirst(), cff, false, false, a, - b, 1e-4, 1e-6)); - assertTrue(errorMsgft, CheckUtil.checkGemm(p1.getFirst(), p2T.getFirst(), cft, false, true, a, - b, 1e-4, 1e-6)); - assertTrue(errorMsgtf, CheckUtil.checkGemm(p1T.getFirst(), p2.getFirst(), ctf, true, false, a, - b, 1e-4, 1e-6)); - assertTrue(errorMsgtt, CheckUtil.checkGemm(p1T.getFirst(), p2T.getFirst(), ctt, true, true, a, - b, 1e-4, 1e-6)); + assertTrue(CheckUtil.checkGemm(p1.getFirst(), p2.getFirst(), cff, false, false, a, + b, 1e-4, 1e-6),errorMsgff); + assertTrue(CheckUtil.checkGemm(p1.getFirst(), p2T.getFirst(), cft, false, true, a, + b, 1e-4, 1e-6),errorMsgft); + assertTrue(CheckUtil.checkGemm(p1T.getFirst(), p2.getFirst(), ctf, true, false, a, + b, 1e-4, 1e-6),errorMsgtf); + assertTrue(CheckUtil.checkGemm(p1T.getFirst(), p2T.getFirst(), ctt, true, true, a, + b, 1e-4, 1e-6),errorMsgtt); } } } @@ -203,7 +203,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { for (int r = 0; r < rows; r++) { double exp = gemv2.getEntry(r, 0); double act = gemv.getDouble(r, 0); - assertEquals(errorMsg, exp, act, 1e-5); + assertEquals(exp, act, 1e-5,errorMsg); } } } @@ -221,9 +221,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { String errorMsg1 = getTestWithOpsErrorMsg(i, j, "add", p1, p2); String errorMsg2 = getTestWithOpsErrorMsg(i, j, "sub", p1, p2); boolean addFail = CheckUtil.checkAdd(p1.getFirst(), p2.getFirst(), 1e-4, 1e-6); - assertTrue(errorMsg1, addFail); + assertTrue(addFail,errorMsg1); boolean subFail = CheckUtil.checkSubtract(p1.getFirst(), p2.getFirst(), 1e-4, 1e-6); - assertTrue(errorMsg2, subFail); + assertTrue(subFail,errorMsg2); } } } @@ -238,8 +238,8 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { Pair p2 = second.get(j); String errorMsg1 = getTestWithOpsErrorMsg(i, j, "mul", p1, p2); String errorMsg2 = getTestWithOpsErrorMsg(i, j, "div", p1, p2); - assertTrue(errorMsg1, CheckUtil.checkMulManually(p1.getFirst(), p2.getFirst(), 1e-4, 1e-6)); - assertTrue(errorMsg2, CheckUtil.checkDivManually(p1.getFirst(), p2.getFirst(), 1e-4, 1e-6)); + assertTrue( CheckUtil.checkMulManually(p1.getFirst(), p2.getFirst(), 1e-4, 1e-6),errorMsg1); + assertTrue(CheckUtil.checkDivManually(p1.getFirst(), p2.getFirst(), 1e-4, 1e-6),errorMsg2); } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsF.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsF.java index 4f127b9bc..20c783031 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsF.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsF.java @@ -22,7 +22,7 @@ package org.nd4j.linalg; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.buffer.DataType; @@ -33,7 +33,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j @RunWith(Parameterized.class) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java index 53fdadfdd..e31f9fbf8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java @@ -21,7 +21,7 @@ package org.nd4j.linalg; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.ndarray.INDArray; @@ -32,7 +32,7 @@ import org.nd4j.common.util.ArrayUtil; import java.util.*; import static junit.framework.TestCase.assertTrue; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @RunWith(Parameterized.class) public class ShufflesTests extends BaseNd4jTest { @@ -247,11 +247,11 @@ public class ShufflesTests extends BaseNd4jTest { for (int i = 0; i < array1.length; i++) { if (i >= array1.length / 2) { - assertEquals("Failed on element [" + i + "]", -1, array1[i]); - assertEquals("Failed on element [" + i + "]", -1, array2[i]); + assertEquals(-1, array1[i],"Failed on element [" + i + "]"); + assertEquals(-1, array2[i],"Failed on element [" + i + "]"); } else { - assertNotEquals("Failed on element [" + i + "]", -1, array1[i]); - assertNotEquals("Failed on element [" + i + "]", -1, array2[i]); + assertNotEquals(-1, array1[i],"Failed on element [" + i + "]"); + assertNotEquals(-1, array2[i],"Failed on element [" + i + "]"); } } } @@ -268,11 +268,11 @@ public class ShufflesTests extends BaseNd4jTest { for (int i = 0; i < array1.length; i++) { if (i % 2 != 0) { - assertEquals("Failed on element [" + i + "]", -1, array1[i]); - assertEquals("Failed on element [" + i + "]", -1, array2[i]); + assertEquals( -1, array1[i],"Failed on element [" + i + "]"); + assertEquals(-1, array2[i],"Failed on element [" + i + "]"); } else { - assertNotEquals("Failed on element [" + i + "]", -1, array1[i]); - assertNotEquals("Failed on element [" + i + "]", -1, array2[i]); + assertNotEquals(-1, array1[i],"Failed on element [" + i + "]"); + assertNotEquals( -1, array2[i],"Failed on element [" + i + "]"); } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/TestEigen.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/TestEigen.java index 1c200151d..ef0ac7afe 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/TestEigen.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/TestEigen.java @@ -21,9 +21,9 @@ package org.nd4j.linalg; import lombok.extern.slf4j.Slf4j; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.buffer.DataType; @@ -33,7 +33,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.common.util.ArrayUtil; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @RunWith(Parameterized.class) @Slf4j @@ -46,12 +46,12 @@ public class TestEigen extends BaseNd4jTest { initialType = Nd4j.dataType(); } - @Before + @BeforeEach public void before() { Nd4j.setDataType(DataType.DOUBLE); } - @After + @AfterEach public void after() { Nd4j.setDataType(initialType); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ToStringTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ToStringTest.java index eab18991c..cbd99c8cb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ToStringTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ToStringTest.java @@ -20,10 +20,10 @@ package org.nd4j.linalg; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import lombok.extern.slf4j.Slf4j; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.buffer.DataType; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/activations/TestActivation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/activations/TestActivation.java index c7e5458d0..97a8270d4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/activations/TestActivation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/activations/TestActivation.java @@ -20,8 +20,8 @@ package org.nd4j.linalg.activations; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -53,7 +53,7 @@ import java.util.Iterator; import java.util.List; import static junit.framework.TestCase.assertTrue; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @RunWith(Parameterized.class) public class TestActivation extends BaseNd4jTest { @@ -69,7 +69,7 @@ public class TestActivation extends BaseNd4jTest { private ObjectMapper mapper; - @Before + @BeforeEach public void initMapper() { mapper = new ObjectMapper(); mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); @@ -173,7 +173,7 @@ public class TestActivation extends BaseNd4jTest { String msg = activations[i].toString() + "\tExpected fields: " + Arrays.toString(expFields) + "\tActual fields: " + actualFieldsByName; - assertEquals(msg, expFields.length, actualFieldsByName.size()); + assertEquals(expFields.length, actualFieldsByName.size(),msg); for (String s : expFields) { msg = "Expected field \"" + s + "\", was not found in " + activations[i].toString(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestBackend.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestBackend.java index 317118a3f..a2229aa0f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestBackend.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestBackend.java @@ -19,13 +19,13 @@ */ package org.nd4j.linalg.api; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.factory.Environment; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertFalse; +import static org.junit.jupiter.api.Assertions.assertFalse; public class TestBackend extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestEnvironment.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestEnvironment.java index 95bf9bc03..8ee444adf 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestEnvironment.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestEnvironment.java @@ -19,13 +19,13 @@ */ package org.nd4j.linalg.api; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.factory.Environment; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertFalse; +import static org.junit.jupiter.api.Assertions.assertFalse; public class TestEnvironment extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreation.java index 2657f4348..1a3ce86f6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreation.java @@ -24,8 +24,8 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.bytedeco.javacpp.FloatPointer; import org.bytedeco.javacpp.Pointer; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; @@ -37,7 +37,7 @@ import org.nd4j.nativeblas.NativeOpsHolder; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class TestNDArrayCreation extends BaseNd4jTest { @@ -48,7 +48,7 @@ public class TestNDArrayCreation extends BaseNd4jTest { } @Test - @Ignore("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") + @Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") public void testBufferCreation() { DataBuffer dataBuffer = Nd4j.createBuffer(new float[] {1, 2}); Pointer pointer = dataBuffer.pointer(); @@ -68,7 +68,7 @@ public class TestNDArrayCreation extends BaseNd4jTest { @Test - @Ignore + @Disabled public void testCreateNpy() throws Exception { INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/test.npy").getFile()); assertEquals(2, arrCreate.size(0)); @@ -81,7 +81,7 @@ public class TestNDArrayCreation extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testCreateNpz() throws Exception { Map map = Nd4j.createFromNpzFile(new ClassPathResource("nd4j-tests/test.npz").getFile()); assertEquals(true, map.containsKey("x")); @@ -100,7 +100,7 @@ public class TestNDArrayCreation extends BaseNd4jTest { } @Test - @Ignore("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") + @Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") public void testCreateNpy3() throws Exception { INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile()); assertEquals(8, arrCreate.length()); @@ -112,7 +112,7 @@ public class TestNDArrayCreation extends BaseNd4jTest { } @Test - @Ignore // this is endless test + @Disabled // this is endless test public void testEndlessAllocation() { Nd4j.getEnvironment().setMaxSpecialMemory(1); while (true) { @@ -122,7 +122,7 @@ public class TestNDArrayCreation extends BaseNd4jTest { } @Test - @Ignore("This test is designed to run in isolation. With parallel gc it makes no real sense since allocated amount changes at any time") + @Disabled("This test is designed to run in isolation. With parallel gc it makes no real sense since allocated amount changes at any time") public void testAllocationLimits() throws Exception { Nd4j.create(1); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreationUtil.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreationUtil.java index 71b736d76..9d6dc2988 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreationUtil.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreationUtil.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.api; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -29,7 +29,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.common.primitives.Pair; import org.nd4j.common.util.ArrayUtil; -import static org.junit.Assert.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; public class TestNDArrayCreationUtil extends BaseNd4jTest { @@ -43,27 +43,27 @@ public class TestNDArrayCreationUtil extends BaseNd4jTest { long[] shape2d = {2, 3}; for (Pair p : NDArrayCreationUtil.getAllTestMatricesWithShape(2, 3, 12345, DataType.DOUBLE)) { - assertArrayEquals(p.getSecond(), shape2d, p.getFirst().shape()); + assertArrayEquals(shape2d, p.getFirst().shape(),p.getSecond()); } long[] shape3d = {2, 3, 4}; for (Pair p : NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, shape3d, DataType.DOUBLE)) { - assertArrayEquals(p.getSecond(), shape3d, p.getFirst().shape()); + assertArrayEquals( shape3d, p.getFirst().shape(),p.getSecond()); } long[] shape4d = {2, 3, 4, 5}; for (Pair p : NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, ArrayUtil.toInts(shape4d), DataType.DOUBLE)) { - assertArrayEquals(p.getSecond(), shape4d, p.getFirst().shape()); + assertArrayEquals(shape4d, p.getFirst().shape(),p.getSecond()); } long[] shape5d = {2, 3, 4, 5, 6}; for (Pair p : NDArrayCreationUtil.getAll5dTestArraysWithShape(12345, ArrayUtil.toInts(shape5d), DataType.DOUBLE)) { - assertArrayEquals(p.getSecond(), shape5d, p.getFirst().shape()); + assertArrayEquals( shape5d, p.getFirst().shape(),p.getSecond()); } long[] shape6d = {2, 3, 4, 5, 6, 7}; for (Pair p : NDArrayCreationUtil.getAll6dTestArraysWithShape(12345, ArrayUtil.toInts(shape6d), DataType.DOUBLE)) { - assertArrayEquals(p.getSecond(), shape6d, p.getFirst().shape()); + assertArrayEquals( shape6d, p.getFirst().shape(),p.getSecond()); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java index d013bae6d..3e8990d63 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.api; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/LapackTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/LapackTest.java index a16a538a3..6d34fec58 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/LapackTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/LapackTest.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.api.blas; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -29,7 +29,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @RunWith(Parameterized.class) public class LapackTest extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level1Test.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level1Test.java index d5f8b918c..466af9744 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level1Test.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level1Test.java @@ -21,7 +21,7 @@ package org.nd4j.linalg.api.blas; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -30,7 +30,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson @@ -59,7 +59,7 @@ public class Level1Test extends BaseNd4jTest { INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray row = matrix.getRow(1); Nd4j.getBlasWrapper().level1().axpy(row.length(), 1.0, row, row); - assertEquals(getFailureMessage(), Nd4j.create(new double[] {4, 8}), row); + assertEquals(Nd4j.create(new double[] {4, 8}), row,getFailureMessage()); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level2Test.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level2Test.java index 3f2a44b81..3cab5d94a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level2Test.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level2Test.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.api.blas; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -28,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @RunWith(Parameterized.class) public class Level2Test extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level3Test.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level3Test.java index 2cb543107..c26b3e9fb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level3Test.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level3Test.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.api.blas; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -28,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @RunWith(Parameterized.class) public class Level3Test extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/params/ParamsTestsF.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/params/ParamsTestsF.java index f57e929ed..24c8a8ea8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/params/ParamsTestsF.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/params/ParamsTestsF.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.api.blas.params; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -28,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java index 6cb734650..30f426216 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java @@ -23,8 +23,8 @@ package org.nd4j.linalg.api.buffer; import lombok.extern.slf4j.Slf4j; import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.indexer.*; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -42,7 +42,7 @@ import org.nd4j.nativeblas.NativeOpsHolder; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j @RunWith(Parameterized.class) @@ -53,7 +53,7 @@ public class DataBufferTests extends BaseNd4jTest { } @Test - @Ignore("AB 2019/06/03 - CI issue: \"CUDA stream synchronization failed\" - see issue 7657") + @Disabled("AB 2019/06/03 - CI issue: \"CUDA stream synchronization failed\" - see issue 7657") public void testNoArgCreateBufferFromArray() { //Tests here: diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java index dad065d96..1719ce084 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java @@ -20,9 +20,9 @@ package org.nd4j.linalg.api.buffer; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -31,6 +31,8 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import static org.junit.jupiter.api.Assertions.assertThrows; + @RunWith(Parameterized.class) public class DataTypeValidationTests extends BaseNd4jTest { DataType initialType; @@ -39,13 +41,13 @@ public class DataTypeValidationTests extends BaseNd4jTest { super(backend); } - @Before + @BeforeEach public void setUp() { initialType = Nd4j.dataType(); Nd4j.setDataType(DataType.FLOAT); } - @After + @AfterEach public void shutUp() { Nd4j.setDataType(initialType); } @@ -70,44 +72,53 @@ public class DataTypeValidationTests extends BaseNd4jTest { /** * Testing level1 blas */ - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testBlasValidation1() { - INDArray x = Nd4j.create(10); + assertThrows(ND4JIllegalStateException.class,() -> { + INDArray x = Nd4j.create(10); - Nd4j.setDataType(DataType.DOUBLE); + Nd4j.setDataType(DataType.DOUBLE); - INDArray y = Nd4j.create(10); + INDArray y = Nd4j.create(10); + + Nd4j.getBlasWrapper().dot(x, y); + }); - Nd4j.getBlasWrapper().dot(x, y); } /** * Testing level2 blas */ - @Test(expected = RuntimeException.class) + @Test() public void testBlasValidation2() { - INDArray a = Nd4j.create(100, 10); - INDArray x = Nd4j.create(100); + assertThrows(RuntimeException.class,() -> { + INDArray a = Nd4j.create(100, 10); + INDArray x = Nd4j.create(100); - Nd4j.setDataType(DataType.DOUBLE); + Nd4j.setDataType(DataType.DOUBLE); - INDArray y = Nd4j.create(100); + INDArray y = Nd4j.create(100); + + Nd4j.getBlasWrapper().gemv(1.0, a, x, 1.0, y); + }); - Nd4j.getBlasWrapper().gemv(1.0, a, x, 1.0, y); } /** * Testing level3 blas */ - @Test(expected = IllegalStateException.class) + @Test() public void testBlasValidation3() { - INDArray x = Nd4j.create(100, 100); + assertThrows(IllegalStateException.class,() -> { + INDArray x = Nd4j.create(100, 100); - Nd4j.setDataType(DataType.DOUBLE); + Nd4j.setDataType(DataType.DOUBLE); - INDArray y = Nd4j.create(100, 100); + INDArray y = Nd4j.create(100, 100); + + x.mmul(y); + }); - x.mmul(y); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DoubleDataBufferTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DoubleDataBufferTest.java index 4b9a6b9eb..ccaa1f4d1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DoubleDataBufferTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DoubleDataBufferTest.java @@ -23,8 +23,9 @@ package org.nd4j.linalg.api.buffer; import org.bytedeco.javacpp.DoublePointer; import org.bytedeco.javacpp.indexer.DoubleIndexer; import org.bytedeco.javacpp.indexer.Indexer; -import org.junit.*; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.*; +import org.junit.jupiter.api.io.TempDir; + import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -40,9 +41,10 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.common.util.SerializationUtils; import java.io.*; +import java.nio.file.Path; import java.util.Arrays; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; /** * Double data buffer tests @@ -53,11 +55,10 @@ import static org.junit.Assert.*; * @author Adam Gibson */ @RunWith(Parameterized.class) -@Ignore("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") +@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") public class DoubleDataBufferTest extends BaseNd4jTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + DataType initialType; @@ -68,13 +69,13 @@ public class DoubleDataBufferTest extends BaseNd4jTest { - @Before + @BeforeEach public void before() { DataTypeUtil.setDTypeForContext(DataType.DOUBLE); } - @After + @AfterEach public void after() { DataTypeUtil.setDTypeForContext(initialType); } @@ -128,8 +129,8 @@ public class DoubleDataBufferTest extends BaseNd4jTest { @Test - public void testSerialization() throws Exception { - File dir = testDir.newFolder(); + public void testSerialization(@TempDir Path testDir) throws Exception { + File dir = testDir.toFile(); DataBuffer buf = Nd4j.createBuffer(5); String fileName = "buf.ser"; File file = new File(dir, fileName); @@ -257,12 +258,12 @@ public class DoubleDataBufferTest extends BaseNd4jTest { DataBuffer wrappedBuffer = Nd4j.createBuffer(buffer, 1, 2); DoublePointer pointer = (DoublePointer) wrappedBuffer.addressPointer(); - Assert.assertEquals(buffer.getDouble(1), pointer.get(0), 1e-1); - Assert.assertEquals(buffer.getDouble(2), pointer.get(1), 1e-1); + assertEquals(buffer.getDouble(1), pointer.get(0), 1e-1); + assertEquals(buffer.getDouble(2), pointer.get(1), 1e-1); try { pointer.asBuffer().get(3); // Try to access element outside pointer capacity. - Assert.fail("Accessing this address should not be allowed!"); + fail("Accessing this address should not be allowed!"); } catch (IndexOutOfBoundsException e) { // do nothing } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/FloatDataBufferTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/FloatDataBufferTest.java index 7fe812440..1dce2d107 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/FloatDataBufferTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/FloatDataBufferTest.java @@ -24,8 +24,9 @@ import lombok.val; import org.bytedeco.javacpp.FloatPointer; import org.bytedeco.javacpp.indexer.FloatIndexer; import org.bytedeco.javacpp.indexer.Indexer; -import org.junit.*; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.*; +import org.junit.jupiter.api.io.TempDir; + import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -40,8 +41,9 @@ import org.nd4j.common.util.SerializationUtils; import java.io.*; import java.nio.ByteBuffer; +import java.nio.file.Path; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; /** * Float data buffer tests @@ -51,12 +53,9 @@ import static org.junit.Assert.*; * * @author Adam Gibson */ -@Ignore("AB 2019/05/21 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") +@Disabled("AB 2019/05/21 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") public class FloatDataBufferTest extends BaseNd4jTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - DataType initialType; public FloatDataBufferTest(Nd4jBackend backend) { @@ -64,13 +63,13 @@ public class FloatDataBufferTest extends BaseNd4jTest { initialType = Nd4j.dataType(); } - @Before + @BeforeEach public void before() { DataTypeUtil.setDTypeForContext(DataType.FLOAT); System.out.println("DATATYPE HERE: " + Nd4j.dataType()); } - @After + @AfterEach public void after() { DataTypeUtil.setDTypeForContext(initialType); } @@ -90,15 +89,15 @@ public class FloatDataBufferTest extends BaseNd4jTest { float[] d1 = new float[] {1, 2, 3, 4}; DataBuffer d = Nd4j.createBuffer(d1); float[] d2 = d.asFloat(); - assertArrayEquals(getFailureMessage(), d1, d2, 1e-1f); + assertArrayEquals( d1, d2, 1e-1f,getFailureMessage()); } @Test - public void testSerialization() throws Exception { - File dir = testDir.newFolder(); + public void testSerialization(@TempDir Path tempDir) throws Exception { + File dir = tempDir.toFile(); DataBuffer buf = Nd4j.createBuffer(5); String fileName = "buf.ser"; File file = new File(dir, fileName); @@ -144,7 +143,7 @@ public class FloatDataBufferTest extends BaseNd4jTest { d.put(0, 0.0); float[] result = new float[] {0, 2, 3, 4}; d1 = d.asFloat(); - assertArrayEquals(getFailureMessage(), d1, result, 1e-1f); + assertArrayEquals(d1, result, 1e-1f,getFailureMessage()); } @@ -153,12 +152,12 @@ public class FloatDataBufferTest extends BaseNd4jTest { DataBuffer buffer = Nd4j.linspace(1, 5, 5).data(); float[] get = buffer.getFloatsAt(0, 3); float[] data = new float[] {1, 2, 3}; - assertArrayEquals(getFailureMessage(), get, data, 1e-1f); + assertArrayEquals(get, data, 1e-1f,getFailureMessage()); float[] get2 = buffer.asFloat(); float[] allData = buffer.getFloatsAt(0, (int) buffer.length()); - assertArrayEquals(getFailureMessage(), get2, allData, 1e-1f); + assertArrayEquals(get2, allData, 1e-1f,getFailureMessage()); } @@ -169,13 +168,13 @@ public class FloatDataBufferTest extends BaseNd4jTest { DataBuffer buffer = Nd4j.linspace(1, 5, 5).data(); float[] get = buffer.getFloatsAt(1, 3); float[] data = new float[] {2, 3, 4}; - assertArrayEquals(getFailureMessage(), get, data, 1e-1f); + assertArrayEquals(get, data, 1e-1f,getFailureMessage()); float[] allButLast = new float[] {2, 3, 4, 5}; float[] allData = buffer.getFloatsAt(1, (int) buffer.length()); - assertArrayEquals(getFailureMessage(), allButLast, allData, 1e-1f); + assertArrayEquals(allButLast, allData, 1e-1f,getFailureMessage()); } @@ -185,7 +184,7 @@ public class FloatDataBufferTest extends BaseNd4jTest { public void testAsBytes() { INDArray arr = Nd4j.create(5); byte[] d = arr.data().asBytes(); - assertEquals(getFailureMessage(), 4 * 5, d.length); + assertEquals(4 * 5, d.length,getFailureMessage()); INDArray rand = Nd4j.rand(3, 3); rand.data().asBytes(); @@ -263,12 +262,12 @@ public class FloatDataBufferTest extends BaseNd4jTest { DataBuffer wrappedBuffer = Nd4j.createBuffer(buffer, 1, 2); FloatPointer pointer = (FloatPointer) wrappedBuffer.addressPointer(); - Assert.assertEquals(buffer.getFloat(1), pointer.get(0), 1e-1); - Assert.assertEquals(buffer.getFloat(2), pointer.get(1), 1e-1); + assertEquals(buffer.getFloat(1), pointer.get(0), 1e-1); + assertEquals(buffer.getFloat(2), pointer.get(1), 1e-1); try { pointer.asBuffer().get(3); // Try to access element outside pointer capacity. - Assert.fail("Accessing this address should not be allowed!"); + fail("Accessing this address should not be allowed!"); } catch (IndexOutOfBoundsException e) { // do nothing } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/IntDataBufferTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/IntDataBufferTests.java index 26ff7162c..af3f277f8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/IntDataBufferTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/IntDataBufferTests.java @@ -22,7 +22,7 @@ package org.nd4j.linalg.api.buffer; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; @@ -35,7 +35,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import java.io.*; import java.util.Arrays; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class IntDataBufferTests extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java index 0feeacaa8..0f984c9d5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.api.indexing; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -31,8 +31,8 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java index 58dfcb5b8..2639b2048 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.api.indexing; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.common.base.Preconditions; @@ -43,7 +43,7 @@ import org.nd4j.common.util.ArrayUtil; import java.util.Arrays; import java.util.Random; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.linalg.indexing.NDArrayIndex.*; /** @@ -450,7 +450,7 @@ public class IndexingTestsC extends BaseNd4jTest { long[] expShape = getShape(arr, indexes); long[] subShape = sub.shape(); - assertArrayEquals(msg, expShape, subShape); + assertArrayEquals(expShape, subShape,msg); msg = "Test case: rank = " + rank + ", order = " + order + ", inShape = " + Arrays.toString(inShape) + ", outShape = " + Arrays.toString(expShape) + @@ -462,7 +462,7 @@ public class IndexingTestsC extends BaseNd4jTest { double act = sub.getDouble(outIdxs); double exp = getDouble(indexes, arr, outIdxs); - assertEquals(msg, exp, act, 1e-6); + assertEquals(exp, act, 1e-6,msg); } totalTestCaseCount++; } @@ -470,7 +470,7 @@ public class IndexingTestsC extends BaseNd4jTest { } } - assertTrue(String.valueOf(totalTestCaseCount), totalTestCaseCount > 5000); + assertTrue( totalTestCaseCount > 5000,String.valueOf(totalTestCaseCount)); } private static long[] getShape(INDArray in, INDArrayIndex[] idxs){ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/resolve/NDArrayIndexResolveTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/resolve/NDArrayIndexResolveTests.java index 7ed8abe6d..721e5925e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/resolve/NDArrayIndexResolveTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/resolve/NDArrayIndexResolveTests.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.api.indexing.resolve; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -31,7 +31,7 @@ import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.PointIndex; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests.java index cf0db936a..923911f20 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.api.indexing.shape; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -29,7 +29,7 @@ import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.Indices; import org.nd4j.linalg.indexing.NDArrayIndex; -import static org.junit.Assert.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests2d.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests2d.java index 561688e5f..b70af316e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests2d.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests2d.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.api.indexing.shape; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -28,7 +28,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.Indices; import org.nd4j.linalg.indexing.NDArrayIndex; -import static org.junit.Assert.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/iterator/NDIndexIteratorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/iterator/NDIndexIteratorTest.java index b82e2453a..5c4eebb1a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/iterator/NDIndexIteratorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/iterator/NDIndexIteratorTest.java @@ -21,14 +21,14 @@ package org.nd4j.linalg.api.iterator; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxt.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxt.java index 95e0bf16c..d74759bc0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxt.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxt.java @@ -24,9 +24,10 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.io.FileUtils; import org.apache.commons.lang3.ArrayUtils; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -37,25 +38,23 @@ import org.nd4j.common.primitives.Pair; import java.io.File; import java.nio.charset.StandardCharsets; +import java.nio.file.Path; import java.util.Iterator; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j @RunWith(Parameterized.class) public class TestNdArrReadWriteTxt extends BaseNd4jTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - public TestNdArrReadWriteTxt(Nd4jBackend backend) { super(backend); } @Test - public void compareAfterWrite() throws Exception { + public void compareAfterWrite(@TempDir Path testDir) throws Exception { int [] ranksToCheck = new int[] {0,1,2,3,4}; for (int i=0; i> all = NDArrayCreationUtil.getTestMatricesWithVaryingShapes(rank,ordering, Nd4j.defaultFloatingPointType()); Iterator> iter = all.iterator(); int cnt = 0; while (iter.hasNext()) { - File dir = testDir.newFolder(); + File dir = testDir.toFile(); Pair currentPair = iter.next(); INDArray origArray = currentPair.getFirst(); //adding elements outside the bounds where print switches to scientific notation @@ -79,15 +78,15 @@ public class TestNdArrReadWriteTxt extends BaseNd4jTest { // log.info("F:\n"+ origArray.toString()); Nd4j.writeTxt(origArray, new File(dir, "someArr.txt").getAbsolutePath()); INDArray readBack = Nd4j.readTxt(new File(dir, "someArr.txt").getAbsolutePath()); - assertEquals("\nNot equal on shape " + ArrayUtils.toString(origArray.shape()), origArray, readBack); + assertEquals(origArray, readBack,"\nNot equal on shape " + ArrayUtils.toString(origArray.shape())); cnt++; } } @Test - public void testNd4jReadWriteText() throws Exception { + public void testNd4jReadWriteText(@TempDir Path testDir) throws Exception { - File dir = testDir.newFolder(); + File dir = testDir.toFile(); int count = 0; for(val testShape : new long[][]{{1,1}, {3,1}, {4,5}, {1,2,3}, {2,1,3}, {2,3,1}, {2,3,4}, {1,2,3,4}, {2,3,4,2}}){ List> l = null; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxtC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxtC.java index 13c6c7f78..f70655b7a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxtC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxtC.java @@ -21,22 +21,23 @@ package org.nd4j.linalg.api.ndarray; import lombok.extern.slf4j.Slf4j; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.factory.Nd4jBackend; +import java.nio.file.Path; + import static org.nd4j.linalg.api.ndarray.TestNdArrReadWriteTxt.compareArrays; @Slf4j @RunWith(Parameterized.class) public class TestNdArrReadWriteTxtC extends BaseNd4jTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); public TestNdArrReadWriteTxtC(Nd4jBackend backend) { @@ -44,7 +45,7 @@ public class TestNdArrReadWriteTxtC extends BaseNd4jTest { } @Test - public void compareAfterWrite() throws Exception { + public void compareAfterWrite(@TempDir Path testDir) throws Exception { int[] ranksToCheck = new int[]{0, 1, 2, 3, 4}; for (int i = 0; i < ranksToCheck.length; i++) { log.info("Checking read write arrays with rank " + ranksToCheck[i]); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestSerialization.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestSerialization.java index 28a5e275f..6928b550b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestSerialization.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestSerialization.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.api.ndarray; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -30,7 +30,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex; import java.io.*; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @RunWith(Parameterized.class) public class TestSerialization extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestSerializationDoubleToFloat.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestSerializationDoubleToFloat.java index 5b5f2293e..e76f724ed 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestSerializationDoubleToFloat.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestSerializationDoubleToFloat.java @@ -22,8 +22,8 @@ package org.nd4j.linalg.api.ndarray; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.After; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -36,7 +36,7 @@ import org.nd4j.linalg.ops.transforms.Transforms; import java.io.*; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j @RunWith(Parameterized.class) @@ -49,7 +49,7 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest { this.initialType = Nd4j.dataType(); } - @After + @AfterEach public void after() { DataTypeUtil.setDTypeForContext(this.initialType); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestSerializationFloatToDouble.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestSerializationFloatToDouble.java index 366f6fb8d..518fc341a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestSerializationFloatToDouble.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestSerializationFloatToDouble.java @@ -20,8 +20,8 @@ package org.nd4j.linalg.api.ndarray; -import org.junit.After; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -34,8 +34,8 @@ import org.nd4j.linalg.ops.transforms.Transforms; import java.io.*; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @RunWith(Parameterized.class) public class TestSerializationFloatToDouble extends BaseNd4jTest { @@ -47,7 +47,7 @@ public class TestSerializationFloatToDouble extends BaseNd4jTest { this.initialType = Nd4j.dataType(); } - @After + @AfterEach public void after() { Nd4j.setDataType(this.initialType); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java index 5fb055da4..f8f025ad1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.api.rng; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -28,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/string/TestFormatting.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/string/TestFormatting.java index 00331184f..d5ffb365b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/string/TestFormatting.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/string/TestFormatting.java @@ -22,7 +22,7 @@ package org.nd4j.linalg.api.string; import lombok.extern.slf4j.Slf4j; import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/tad/TestTensorAlongDimension.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/tad/TestTensorAlongDimension.java index a2ceb2249..defe4ee1c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/tad/TestTensorAlongDimension.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/tad/TestTensorAlongDimension.java @@ -22,7 +22,7 @@ package org.nd4j.linalg.api.tad; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.time.StopWatch; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -36,8 +36,8 @@ import org.nd4j.common.primitives.Pair; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j @RunWith(Parameterized.class) @@ -126,7 +126,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { INDArray tadTest = arr.tensorAlongDimension(0, 0); assertEquals(javaTad, tadTest); //Along dimension 0: expect row vector with length 'rows' - assertEquals("Failed on " + p.getValue(), cols * dim2, arr.tensorsAlongDimension(0)); + assertEquals(cols * dim2, arr.tensorsAlongDimension(0),"Failed on " + p.getValue()); for (int i = 0; i < cols * dim2; i++) { INDArray tad = arr.tensorAlongDimension(i, 0); assertArrayEquals(new long[] {rows}, tad.shape()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/blas/BlasTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/blas/BlasTests.java index cb2e03044..a4bb53bd1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/blas/BlasTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/blas/BlasTests.java @@ -23,8 +23,8 @@ package org.nd4j.linalg.blas; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -36,7 +36,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import java.util.ArrayList; import java.util.Collections; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j @RunWith(Parameterized.class) @@ -213,7 +213,7 @@ public class BlasTests extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testHalfPrecision() { val a = Nd4j.create(DataType.HALF, 64, 768); val b = Nd4j.create(DataType.HALF, 768, 1024); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java index 579feb0c2..5eb2357bb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java @@ -22,7 +22,8 @@ package org.nd4j.linalg.broadcast; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -36,8 +37,9 @@ import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j @RunWith(Parameterized.class) @@ -130,46 +132,60 @@ public class BasicBroadcastTests extends BaseNd4jTest { assertEquals(e, z); } - @Test(expected = IllegalStateException.class) + @Test() public void basicBroadcastFailureTest_1() { - val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); - val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); - val z = x.subi(y); + assertThrows(IllegalStateException.class,() -> { + val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); + val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); + val z = x.subi(y); + }); } - @Test(expected = IllegalStateException.class) + @Test() public void basicBroadcastFailureTest_2() { - val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); - val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); - val z = x.divi(y); + assertThrows(IllegalStateException.class,() -> { + val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); + val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); + val z = x.divi(y); + }); + } - @Test(expected = IllegalStateException.class) + @Test() public void basicBroadcastFailureTest_3() { - val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); - val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); - val z = x.muli(y); + assertThrows(IllegalStateException.class, () -> { + val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); + val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); + val z = x.muli(y); + }); + } - @Test(expected = IllegalStateException.class) + @Test() public void basicBroadcastFailureTest_4() { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); val z = x.addi(y); } - @Test(expected = IllegalStateException.class) + @Test() public void basicBroadcastFailureTest_5() { - val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); - val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); - val z = x.rsubi(y); + assertThrows(IllegalStateException.class,() -> { + val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); + val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); + val z = x.rsubi(y); + }); + } - @Test(expected = IllegalStateException.class) + @Test() public void basicBroadcastFailureTest_6() { - val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); - val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); - val z = x.rdivi(y); + assertThrows(IllegalStateException.class,() -> { + val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); + val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); + val z = x.rdivi(y); + }); + } @Test @@ -214,13 +230,14 @@ public class BasicBroadcastTests extends BaseNd4jTest { assertEquals(y, z); } - @Test(expected = IllegalStateException.class) + @Test() public void emptyBroadcastTest_2() { val x = Nd4j.create(DataType.FLOAT, 1, 2); val y = Nd4j.create(DataType.FLOAT, 0, 2); val z = x.addi(y); assertEquals(y, z); + } @Test @@ -246,7 +263,7 @@ public class BasicBroadcastTests extends BaseNd4jTest { x.addi(y); } catch (Exception e){ String s = e.getMessage(); - assertTrue(s, s.contains("broadcast") && s.contains("shape")); + assertTrue(s.contains("broadcast") && s.contains("shape"),s); } x.sub(y); @@ -255,7 +272,7 @@ public class BasicBroadcastTests extends BaseNd4jTest { x.subi(y); } catch (Exception e){ String s = e.getMessage(); - assertTrue(s, s.contains("broadcast") && s.contains("shape")); + assertTrue(s.contains("broadcast") && s.contains("shape"),s); } x.mul(y); @@ -264,7 +281,7 @@ public class BasicBroadcastTests extends BaseNd4jTest { x.muli(y); } catch (Exception e){ String s = e.getMessage(); - assertTrue(s, s.contains("broadcast") && s.contains("shape")); + assertTrue(s.contains("broadcast") && s.contains("shape"),s); } x.div(y); @@ -273,7 +290,7 @@ public class BasicBroadcastTests extends BaseNd4jTest { x.divi(y); } catch (Exception e){ String s = e.getMessage(); - assertTrue(s, s.contains("broadcast") && s.contains("shape")); + assertTrue(s.contains("broadcast") && s.contains("shape"),s); } x.rsub(y); @@ -282,7 +299,7 @@ public class BasicBroadcastTests extends BaseNd4jTest { x.rsubi(y); } catch (Exception e){ String s = e.getMessage(); - assertTrue(s, s.contains("broadcast") && s.contains("shape")); + assertTrue(s.contains("broadcast") && s.contains("shape"),s); } x.rdiv(y); @@ -291,7 +308,7 @@ public class BasicBroadcastTests extends BaseNd4jTest { x.rdivi(y); } catch (Exception e){ String s = e.getMessage(); - assertTrue(s, s.contains("broadcast") && s.contains("shape")); + assertTrue(s.contains("broadcast") && s.contains("shape"),s); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionMagicTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionMagicTests.java index fbb56c04d..e0c76eac6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionMagicTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionMagicTests.java @@ -20,8 +20,8 @@ package org.nd4j.linalg.compression; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -30,7 +30,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @RunWith(Parameterized.class) public class CompressionMagicTests extends BaseNd4jTest { @@ -38,7 +38,7 @@ public class CompressionMagicTests extends BaseNd4jTest { super(backend); } - @Before + @BeforeEach public void setUp() { } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionPerformanceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionPerformanceTests.java index 928de6da7..fd271faa0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionPerformanceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionPerformanceTests.java @@ -22,8 +22,8 @@ package org.nd4j.linalg.compression; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -36,7 +36,7 @@ import org.nd4j.common.util.SerializationUtils; import java.io.ByteArrayOutputStream; @Slf4j -@Ignore +@Disabled @RunWith(Parameterized.class) public class CompressionPerformanceTests extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionSerDeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionSerDeTests.java index 61347b732..535db0317 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionSerDeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionSerDeTests.java @@ -21,7 +21,7 @@ package org.nd4j.linalg.compression; import org.apache.commons.io.output.ByteArrayOutputStream; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -32,7 +32,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import java.io.ByteArrayInputStream; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @RunWith(Parameterized.class) public class CompressionSerDeTests extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionTests.java index 59dbae887..ee57fd951 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionTests.java @@ -22,8 +22,8 @@ package org.nd4j.linalg.compression; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -44,7 +44,7 @@ import java.nio.ByteBuffer; import java.util.Arrays; import static junit.framework.TestCase.assertFalse; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j @RunWith(Parameterized.class) @@ -141,7 +141,7 @@ public class CompressionTests extends BaseNd4jTest { } - @Ignore + @Disabled @Test public void testThresholdCompression0() { INDArray initial = Nd4j.rand(new int[] {1, 150000000}, 119L); @@ -173,7 +173,7 @@ public class CompressionTests extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testThresholdCompression1() { INDArray initial = Nd4j.create(new float[] {0.0f, 0.0f, 1e-3f, -1e-3f, 0.0f, 0.0f}); INDArray exp_0 = Nd4j.create(DataType.FLOAT, 6); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTests.java index f69ac8491..53bf93d3a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTests.java @@ -21,8 +21,8 @@ package org.nd4j.linalg.convolution; import lombok.val; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -43,8 +43,8 @@ import org.nd4j.common.util.ArrayUtil; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.point; @@ -1321,7 +1321,7 @@ public class ConvolutionTests extends BaseNd4jTest { @Test - @Ignore + @Disabled public void testCompareIm2ColImpl() { int[] miniBatches = {1, 3, 5}; @@ -1405,7 +1405,7 @@ public class ConvolutionTests extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testCompareIm2Col() { int[] miniBatches = {1, 3, 5}; @@ -1648,8 +1648,8 @@ public class ConvolutionTests extends BaseNd4jTest { String msg = "inOrder=" + inputOrder + ", outOrder=" + outputOrder; val vr = actDl4j.get(point(0), point(0), all(), all()); - assertEquals(msg, expDl4j, vr); - assertEquals(msg, expEnabled, actEnabled.get(point(0), point(0), all(), all())); + assertEquals(expDl4j, vr,msg); + assertEquals(expEnabled, actEnabled.get(point(0), point(0), all(), all()),msg); } } } @@ -1672,7 +1672,7 @@ public class ConvolutionTests extends BaseNd4jTest { INDArray out = op.getOutputArgument(0); - assertEquals("Output order: " + outputOrder, exp, out); + assertEquals(exp, out,"Output order: " + outputOrder); /* k=2, s=2, p=0, d=1, same mode, divisor = 1 @@ -1734,7 +1734,7 @@ public class ConvolutionTests extends BaseNd4jTest { INDArray out = op.getOutputArgument(0); - assertEquals("Output order: " + outputOrder, exp, out); + assertEquals(exp, out,"Output order: " + outputOrder); } } @@ -1756,7 +1756,7 @@ public class ConvolutionTests extends BaseNd4jTest { INDArray out = op.getOutputArgument(0); - assertEquals("Output order: " + outputOrder, exp, out); + assertEquals( exp, out,"Output order: " + outputOrder); } } @@ -1779,7 +1779,7 @@ public class ConvolutionTests extends BaseNd4jTest { INDArray out = op.getOutputArgument(0); - assertEquals("Output order: " + outputOrder, exp, out); + assertEquals(exp, out,"Output order: " + outputOrder); } } @@ -1802,7 +1802,7 @@ public class ConvolutionTests extends BaseNd4jTest { INDArray out = op.getOutputArgument(0); - assertEquals("Output order: " + outputOrder, exp, out); + assertEquals(exp, out,"Output order: " + outputOrder); } } @@ -1825,7 +1825,7 @@ public class ConvolutionTests extends BaseNd4jTest { INDArray out = op.getOutputArgument(0); - assertEquals("Output order: " + outputOrder, exp, out); + assertEquals(exp, out,"Output order: " + outputOrder); } } @@ -1848,7 +1848,7 @@ public class ConvolutionTests extends BaseNd4jTest { INDArray out = op.getOutputArgument(0); - assertEquals("Output order: " + outputOrder, exp, out); + assertEquals(exp, out,"Output order: " + outputOrder); } } @@ -1870,7 +1870,7 @@ public class ConvolutionTests extends BaseNd4jTest { INDArray out = op.getOutputArgument(0); - assertEquals("Output order: " + outputOrder, exp, out); + assertEquals(exp, out,"Output order: " + outputOrder); } } @@ -1892,7 +1892,7 @@ public class ConvolutionTests extends BaseNd4jTest { INDArray out = op.getOutputArgument(0); - assertEquals("Output order: " + outputOrder, exp, out); + assertEquals(exp, out,"Output order: " + outputOrder); } } @@ -1914,7 +1914,7 @@ public class ConvolutionTests extends BaseNd4jTest { INDArray out = op.getOutputArgument(0); - assertEquals("Output order: " + outputOrder, exp, out); + assertEquals(exp, out,"Output order: " + outputOrder); } } @@ -1936,7 +1936,7 @@ public class ConvolutionTests extends BaseNd4jTest { INDArray out = op.getOutputArgument(0); - assertEquals("Output order: " + outputOrder, exp, out); + assertEquals(exp, out,"Output order: " + outputOrder); } } @@ -1958,7 +1958,7 @@ public class ConvolutionTests extends BaseNd4jTest { INDArray out = op.getOutputArgument(0); - assertEquals("Output order: " + outputOrder, exp, out); + assertEquals( exp, out,"Output order: " + outputOrder); } } @@ -1981,7 +1981,7 @@ public class ConvolutionTests extends BaseNd4jTest { INDArray out = op.getOutputArgument(0); - assertEquals("Output order: " + outputOrder, exp, out); + assertEquals(exp, out,"Output order: " + outputOrder); } } @@ -2113,7 +2113,7 @@ public class ConvolutionTests extends BaseNd4jTest { } String msg = "TestNum=" + testNum + ", Mode: " + mode + ", " + pIn.getSecond(); - assertEquals(msg, exp, out); + assertEquals(exp, out,msg); testNum++; } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTestsC.java index 635cfcd5d..f48acb810 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTestsC.java @@ -21,8 +21,8 @@ package org.nd4j.linalg.convolution; import lombok.extern.slf4j.Slf4j; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -43,7 +43,7 @@ import org.nd4j.common.primitives.Pair; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j @RunWith(Parameterized.class) @@ -106,7 +106,7 @@ public class ConvolutionTestsC extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testCompareIm2ColImpl() { int[] miniBatches = {1, 3, 5}; @@ -271,7 +271,7 @@ public class ConvolutionTestsC extends BaseNd4jTest { reduced = reduced.reshape('c',m,d, outSize[0], outSize[1]).dup('c'); - assertEquals("Failed opType: " + type, reduced, output); + assertEquals(reduced, output,"Failed opType: " + type); } } } @@ -345,7 +345,7 @@ public class ConvolutionTestsC extends BaseNd4jTest { @Test - @Ignore + @Disabled public void testMaxPoolBackprop(){ Nd4j.getRandom().setSeed(12345); @@ -401,7 +401,7 @@ public class ConvolutionTestsC extends BaseNd4jTest { INDArray expEpsNext = expGradMaxPoolBackPropSame(input, epsilon, kernel, strides, same); String msg = "input=" + pIn.getSecond() + ", eps=" + pEps.getSecond(); - assertEquals(msg, expEpsNext, epsNext); + assertEquals( expEpsNext, epsNext,msg); } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java index 4765a712a..f88ee0cc1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java @@ -20,12 +20,13 @@ package org.nd4j.linalg.convolution; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.resources.Resources; import org.nd4j.linalg.BaseNd4jTest; @@ -37,6 +38,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import java.io.File; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; @@ -45,9 +47,6 @@ import java.util.Set; public class DeconvTests extends BaseNd4jTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - public DeconvTests(Nd4jBackend backend) { super(backend); } @@ -58,8 +57,8 @@ public class DeconvTests extends BaseNd4jTest { } @Test - public void compareKeras() throws Exception { - File newFolder = testDir.newFolder(); + public void compareKeras(@TempDir Path testDir) throws Exception { + File newFolder = testDir.toFile(); new ClassPathResource("keras/deconv/").copyDirectory(newFolder); File[] files = newFolder.listFiles(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java index 2f61c54f0..5736a5577 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java @@ -22,8 +22,8 @@ package org.nd4j.linalg.crash; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.RandomUtils; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -41,7 +41,7 @@ import org.nd4j.linalg.indexing.conditions.Conditions; @Slf4j @RunWith(Parameterized.class) -@Ignore +@Disabled public class CrashTest extends BaseNd4jTest { public CrashTest(Nd4jBackend backend) { super(backend); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/SpecialTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/SpecialTests.java index f5b51699c..4e0c28c89 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/SpecialTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/SpecialTests.java @@ -24,7 +24,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import lombok.var; import org.apache.commons.lang3.RandomUtils; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -51,8 +51,7 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.Executors; import java.util.concurrent.ThreadPoolExecutor; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.linalg.indexing.NDArrayIndex.*; @Slf4j @@ -100,17 +99,20 @@ public class SpecialTests extends BaseNd4jTest { } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testScalarShuffle1() { - List listData = new ArrayList<>(); - for (int i = 0; i < 3; i++) { - INDArray features = Nd4j.ones(25, 25); - INDArray label = Nd4j.create(new float[] {1}, new int[] {1}); - DataSet dataset = new DataSet(features, label); - listData.add(dataset); - } - DataSet data = DataSet.merge(listData); - data.shuffle(); + assertThrows(ND4JIllegalStateException.class,() -> { + List listData = new ArrayList<>(); + for (int i = 0; i < 3; i++) { + INDArray features = Nd4j.ones(25, 25); + INDArray label = Nd4j.create(new float[] {1}, new int[] {1}); + DataSet dataset = new DataSet(features, label); + listData.add(dataset); + } + DataSet data = DataSet.merge(listData); + data.shuffle(); + }); + } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index 32719805f..3713947ef 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -22,8 +22,8 @@ package org.nd4j.linalg.custom; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.DataType; @@ -92,7 +92,7 @@ import java.util.ArrayList; import java.util.List; import static java.lang.Float.NaN; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class CustomOpsTests extends BaseNd4jTest { @@ -151,7 +151,7 @@ public class CustomOpsTests extends BaseNd4jTest { } @Test - @Ignore // it's noop, we dont care anymore + @Disabled // it's noop, we dont care anymore public void testNoOp1() { val arrayX = Nd4j.create(10, 10); val arrayY = Nd4j.create(5, 3); @@ -191,24 +191,27 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(exp, arrayX); } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testInplaceOp1() { - val arrayX = Nd4j.create(10, 10); - val arrayY = Nd4j.create(10, 10); + assertThrows(ND4JIllegalStateException.class,() -> { + val arrayX = Nd4j.create(10, 10); + val arrayY = Nd4j.create(10, 10); - arrayX.assign(4.0); - arrayY.assign(2.0); + arrayX.assign(4.0); + arrayY.assign(2.0); - val exp = Nd4j.create(10,10).assign(6.0); + val exp = Nd4j.create(10,10).assign(6.0); - CustomOp op = DynamicCustomOp.builder("add") - .addInputs(arrayX, arrayY) - .callInplace(true) - .build(); + CustomOp op = DynamicCustomOp.builder("add") + .addInputs(arrayX, arrayY) + .callInplace(true) + .build(); - Nd4j.getExecutioner().exec(op); + Nd4j.getExecutioner().exec(op); + + assertEquals(exp, arrayX); + }); - assertEquals(exp, arrayX); } @Test @@ -604,21 +607,24 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(e, z); } - @Test(expected = RuntimeException.class) + @Test() public void testInputValidationMergeMax(){ - INDArray[] inputs = new INDArray[]{ - Nd4j.createFromArray(0.0f, 1.0f, 2.0f).reshape('c', 1, 3), - Nd4j.createFromArray(1.0f).reshape('c', 1, 1)}; + assertThrows(RuntimeException.class,() -> { + INDArray[] inputs = new INDArray[]{ + Nd4j.createFromArray(0.0f, 1.0f, 2.0f).reshape('c', 1, 3), + Nd4j.createFromArray(1.0f).reshape('c', 1, 1)}; - INDArray out = Nd4j.create(DataType.FLOAT, 1, 3).assign(Double.NaN); - CustomOp op = DynamicCustomOp.builder("mergemax") - .addInputs(inputs) - .addOutputs(out) - .callInplace(false) - .build(); + INDArray out = Nd4j.create(DataType.FLOAT, 1, 3).assign(Double.NaN); + CustomOp op = DynamicCustomOp.builder("mergemax") + .addInputs(inputs) + .addOutputs(out) + .callInplace(false) + .build(); - Nd4j.exec(op); + Nd4j.exec(op); // System.out.println(out); + }); + } @Test @@ -786,9 +792,9 @@ public class CustomOpsTests extends BaseNd4jTest { public void test() throws Exception { INDArray in1 = Nd4j.create(DataType.BFLOAT16, 2, 3, 10, 1);//Nd4j.createFromArray(0.2019043,0.6464844,0.9116211,0.60058594,0.34033203,0.7036133,0.6772461,0.3815918,0.87353516,0.04650879,0.67822266,0.8618164,0.88378906,0.7573242,0.66796875,0.63427734,0.33764648,0.46923828,0.62939453,0.76464844,-0.8618164,-0.94873047,-0.9902344,-0.88916016,-0.86572266,-0.92089844,-0.90722656,-0.96533203,-0.97509766,-0.4975586,-0.84814453,-0.984375,-0.98828125,-0.95458984,-0.9472656,-0.91064453,-0.80859375,-0.83496094,-0.9140625,-0.82470703,0.4802246,0.45361328,0.28125,0.28320312,0.79345703,0.44604492,-0.30273438,0.11730957,0.56396484,0.73583984,0.1418457,-0.44848633,0.6923828,-0.40234375,0.40185547,0.48632812,0.14538574,0.4638672,0.13000488,0.5058594) - //.castTo(DataType.BFLOAT16).reshape(2,3,10,1); + //.castTo(DataType.BFLOAT16).reshape(2,3,10,1); INDArray in2 = Nd4j.create(DataType.BFLOAT16, 2, 3, 10, 1); //Nd4j.createFromArray(0.0,-0.13391113,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.1751709,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.51904297,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.5107422,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0) - //.castTo(DataType.BFLOAT16).reshape(2,3,10,1); + //.castTo(DataType.BFLOAT16).reshape(2,3,10,1); INDArray out = in1.ulike(); @@ -870,43 +876,43 @@ public class CustomOpsTests extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testDrawBoundingBoxesShape() { INDArray images = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, - 0.1804f,0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f, - 0.3087f,0.1548f,0.4695f,0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f, - 0.4601f,0.8284f,0.2354f,0.9752f,0.8361f,0.2585f,0.4189f,0.7028f,0.7679f,0.5373f, - 0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f,0.0755f,0.6245f,0.3491f, - 0.5793f,0.5730f,0.1822f,0.6420f,0.9143f}).reshape(2,5,5,1); + 0.1804f,0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f, + 0.3087f,0.1548f,0.4695f,0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f, + 0.4601f,0.8284f,0.2354f,0.9752f,0.8361f,0.2585f,0.4189f,0.7028f,0.7679f,0.5373f, + 0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f,0.0755f,0.6245f,0.3491f, + 0.5793f,0.5730f,0.1822f,0.6420f,0.9143f}).reshape(2,5,5,1); INDArray boxes = Nd4j.createFromArray(new float[]{0.7717f, 0.9281f, 0.9846f, 0.4838f, - 0.6433f, 0.6041f, 0.6501f, 0.7612f, - 0.7605f, 0.3948f, 0.9493f, 0.8600f, - 0.7876f, 0.8945f, 0.4638f, 0.7157f}).reshape(2,2,4); + 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f, 0.3948f, 0.9493f, 0.8600f, + 0.7876f, 0.8945f, 0.4638f, 0.7157f}).reshape(2,2,4); INDArray colors = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f}).reshape(1,2); INDArray output = Nd4j.create(DataType.FLOAT, images.shape()); val op = new DrawBoundingBoxes(images, boxes, colors, output); Nd4j.exec(op); INDArray expected = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, - 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.9441f, - 0.9441f, 0.1596f, 0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, - 0.1800f, 0.6750f, 0.2246f, 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, + 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.9441f, + 0.9441f, 0.1596f, 0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, + 0.1800f, 0.6750f, 0.2246f, 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, 0.2585f, 0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f, - 0.8428f, 0.9441f,0.9441f,0.9441f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f}); + 0.8428f, 0.9441f,0.9441f,0.9441f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f}); assertEquals(expected, output); } @Test - @Ignore("Failing with results that are close") + @Disabled("Failing with results that are close") public void testFakeQuantAgainstTF_1() { INDArray x = Nd4j.createFromArray(new double[]{ 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, - 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, - 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}).reshape(3,5); + 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, + 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}).reshape(3,5); INDArray min = Nd4j.createFromArray(new double[]{ -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); INDArray max = Nd4j.createFromArray(new double[]{ 0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); INDArray expected = Nd4j.createFromArray(new double[]{0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f, - 0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f, - 0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5); + 0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f, + 0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5); val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max); INDArray[] output = Nd4j.exec(op); @@ -972,7 +978,7 @@ public class CustomOpsTests extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testDrawBoundingBoxes() { INDArray images = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 2*4*5*3).reshape(2,4,5,3); INDArray boxes = Nd4j.createFromArray(new float[]{ 0.0f , 0.0f , 1.0f , 1.0f, @@ -1082,7 +1088,7 @@ public class CustomOpsTests extends BaseNd4jTest { INDArray batchVar = Nd4j.create(4); FusedBatchNorm op = new FusedBatchNorm(x,scale,offset,0,1, - y, batchMean, batchVar); + y, batchMean, batchVar); INDArray expectedY = Nd4j.createFromArray(new double[]{1.20337462, 1.20337462, 1.20337462, 1.20337462, 1.34821558, 1.34821558, 1.34821558, 1.34821558, 1.49305654, 1.49305654, @@ -1103,11 +1109,11 @@ public class CustomOpsTests extends BaseNd4jTest { @Test public void testFusedBatchNorm1() { INDArray x = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f, - 0.7271f, 0.1804f, 0.5056f, 0.8925f, - 0.5461f, 0.9234f, 0.0856f, 0.7938f, - 0.6591f, 0.5555f, 0.1596f, 0.3087f, - 0.1548f, 0.4695f, 0.9939f, 0.6113f, - 0.6765f, 0.1800f, 0.6750f, 0.2246f}).reshape(1,2,3,4); + 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, + 0.6591f, 0.5555f, 0.1596f, 0.3087f, + 0.1548f, 0.4695f, 0.9939f, 0.6113f, + 0.6765f, 0.1800f, 0.6750f, 0.2246f}).reshape(1,2,3,4); INDArray scale = Nd4j.createFromArray(new float[]{ 0.7717f, 0.9281f, 0.9846f, 0.4838f}); INDArray offset = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f}); @@ -1119,11 +1125,11 @@ public class CustomOpsTests extends BaseNd4jTest { y, batchMean, batchVar); INDArray expectedY = Nd4j.createFromArray(new float[]{1.637202024f, 1.521406889f, 1.48303616f, -0.147269756f, - 1.44721508f, -0.51030159f, 0.810390055f, 1.03076458f, - 0.781284988f, 1.921229601f, -0.481337309f, 0.854952335f, - 1.196854949f, 0.717398405f, -0.253610134f, -0.00865117f, - -0.658405781f,0.43602103f, 2.311818838f, 0.529999137f, - 1.260738254f, -0.511638165f, 1.331095099f, -0.158477545f}).reshape(x.shape()); + 1.44721508f, -0.51030159f, 0.810390055f, 1.03076458f, + 0.781284988f, 1.921229601f, -0.481337309f, 0.854952335f, + 1.196854949f, 0.717398405f, -0.253610134f, -0.00865117f, + -0.658405781f,0.43602103f, 2.311818838f, 0.529999137f, + 1.260738254f, -0.511638165f, 1.331095099f, -0.158477545f}).reshape(x.shape()); Nd4j.exec(op); assertArrayEquals(expectedY.shape(), y.shape()); } @@ -1159,7 +1165,7 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, x); } - @Ignore("AS failed 2019/12/04") + @Disabled("AS failed 2019/12/04") @Test public void testPolygamma() { INDArray n = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3,3); @@ -1217,12 +1223,12 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, result[0]); } - @Ignore("AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8449") + @Disabled("AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8449") @Test public void testNonMaxSuppression() { INDArray boxes = Nd4j.createFromArray(new float[] {0.8115f, 0.4121f, 0.0771f, 0.4863f, - 0.7412f, 0.7607f, 0.1543f, 0.5479f, - 0.8223f, 0.2246f, 0.0049f, 0.6465f}).reshape(3,4); + 0.7412f, 0.7607f, 0.1543f, 0.5479f, + 0.8223f, 0.2246f, 0.0049f, 0.6465f}).reshape(3,4); INDArray scores = Nd4j.createFromArray(new float[]{0.0029f, 0.8135f, 0.4873f}); val op = new NonMaxSuppression(boxes,scores,2,0.5,0.5); val res = Nd4j.exec(op); @@ -1232,14 +1238,14 @@ public class CustomOpsTests extends BaseNd4jTest { @Test public void testMatrixBand() { INDArray input = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f, - 0.7271f,0.1804f,0.5056f,0.8925f, - 0.5461f,0.9234f,0.0856f,0.7938f}).reshape(3,4); + 0.7271f,0.1804f,0.5056f,0.8925f, + 0.5461f,0.9234f,0.0856f,0.7938f}).reshape(3,4); MatrixBandPart op = new MatrixBandPart(input,1,-1); List lsd = op.calculateOutputShape(); assertEquals(1, lsd.size()); } - @Ignore("Failed AS 11.26.2019 - https://github.com/eclipse/deeplearning4j/issues/8450") + @Disabled("Failed AS 11.26.2019 - https://github.com/eclipse/deeplearning4j/issues/8450") @Test public void testBetaInc1() { INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f}); @@ -1251,15 +1257,15 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, ret[0]); } - @Ignore("Failure AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8452") + @Disabled("Failure AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8452") @Test public void testPolygamma1() { INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f, - 0.7271f, 0.1804f, 0.5056f, 0.8925f, - 0.5461f, 0.9234f, 0.0856f, 0.7938f}).reshape(3,4); + 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f}).reshape(3,4); INDArray b = Nd4j.createFromArray(new float[]{0.7717f, 0.9281f, 0.9846f, 0.4838f, - 0.6433f, 0.6041f, 0.6501f, 0.7612f, - 0.7605f, 0.3948f, 0.9493f, 0.8600f}).reshape(3,4); + 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f, 0.3948f, 0.9493f, 0.8600f}).reshape(3,4); INDArray expected = Nd4j.createFromArray(new float[]{NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN, }).reshape(3,4); Polygamma op = new Polygamma(a,b); INDArray[] ret = Nd4j.exec(op); @@ -1282,38 +1288,38 @@ public class CustomOpsTests extends BaseNd4jTest { @Test public void testAdjustHueShape(){ INDArray image = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, - 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, 0.5461f, - 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f, - 0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, - 0.1800f, 0.6750f, 0.2246f, 0.0509f, 0.4601f, 0.8284f, - 0.2354f, 0.9752f, 0.8361f, 0.2585f, 0.4189f, 0.7028f, - 0.7679f, 0.5373f, 0.7234f, 0.2690f, 0.0062f, 0.0327f, - 0.0644f, 0.8428f, 0.7494f, 0.0755f, 0.6245f, 0.3491f, - 0.5793f, 0.5730f, 0.1822f, 0.6420f, 0.9143f, 0.3019f, - 0.3574f, 0.1704f, 0.8395f, 0.5468f, 0.0744f, 0.9011f, - 0.6574f, 0.4124f, 0.2445f, 0.4248f, 0.5219f, 0.6952f, - 0.4900f, 0.2158f, 0.9549f, 0.1386f, 0.1544f, 0.5365f, - 0.0134f, 0.4163f, 0.1456f, 0.4109f, 0.2484f, 0.3330f, - 0.2974f, 0.6636f, 0.3808f, 0.8664f, 0.1896f, 0.7530f, - 0.7215f, 0.6612f, 0.7270f, 0.5704f, 0.2666f, 0.7453f, - 0.0444f, 0.3024f, 0.4850f, 0.7982f, 0.0965f, 0.7843f, - 0.5075f, 0.0844f, 0.8370f, 0.6103f, 0.4604f, 0.6087f, - 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f, - 0.7821f, 0.3505f, 0.5040f, 0.1180f, 0.8307f, 0.1817f, - 0.8442f, 0.5074f, 0.4471f, 0.5105f, 0.6666f, 0.2576f, - 0.2341f, 0.6801f, 0.2652f, 0.5394f, 0.4690f, 0.6146f, - 0.1210f, 0.2576f, 0.0769f, 0.4643f, 0.1628f, 0.2026f, - 0.3774f, 0.0506f, 0.3462f, 0.5720f, 0.0838f, 0.4228f, - 0.0588f, 0.5362f, 0.4756f, 0.2530f, 0.1778f, 0.0751f, - 0.8977f, 0.3648f, 0.3065f, 0.4739f, 0.7014f, 0.4473f, - 0.5171f, 0.1744f, 0.3487f, 0.7759f, 0.9491f, 0.2072f, - 0.2182f, 0.6520f, 0.3092f, 0.9545f, 0.1881f, 0.9579f, - 0.1785f, 0.9636f, 0.4830f, 0.6569f, 0.3353f, 0.9997f, - 0.5869f, 0.5747f, 0.0238f, 0.2943f, 0.5248f, 0.5879f, - 0.7266f, 0.1965f, 0.9167f, 0.9726f, 0.9206f, 0.0519f, - 0.2997f, 0.0039f, 0.7652f, 0.5498f, 0.3794f, 0.3791f, - 0.3528f, 0.2873f, 0.8082f, 0.4732f, 0.4399f, 0.6606f, - 0.5991f, 0.0034f, 0.4874f}).reshape(8,8,3); + 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, 0.5461f, + 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f, + 0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, + 0.1800f, 0.6750f, 0.2246f, 0.0509f, 0.4601f, 0.8284f, + 0.2354f, 0.9752f, 0.8361f, 0.2585f, 0.4189f, 0.7028f, + 0.7679f, 0.5373f, 0.7234f, 0.2690f, 0.0062f, 0.0327f, + 0.0644f, 0.8428f, 0.7494f, 0.0755f, 0.6245f, 0.3491f, + 0.5793f, 0.5730f, 0.1822f, 0.6420f, 0.9143f, 0.3019f, + 0.3574f, 0.1704f, 0.8395f, 0.5468f, 0.0744f, 0.9011f, + 0.6574f, 0.4124f, 0.2445f, 0.4248f, 0.5219f, 0.6952f, + 0.4900f, 0.2158f, 0.9549f, 0.1386f, 0.1544f, 0.5365f, + 0.0134f, 0.4163f, 0.1456f, 0.4109f, 0.2484f, 0.3330f, + 0.2974f, 0.6636f, 0.3808f, 0.8664f, 0.1896f, 0.7530f, + 0.7215f, 0.6612f, 0.7270f, 0.5704f, 0.2666f, 0.7453f, + 0.0444f, 0.3024f, 0.4850f, 0.7982f, 0.0965f, 0.7843f, + 0.5075f, 0.0844f, 0.8370f, 0.6103f, 0.4604f, 0.6087f, + 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f, + 0.7821f, 0.3505f, 0.5040f, 0.1180f, 0.8307f, 0.1817f, + 0.8442f, 0.5074f, 0.4471f, 0.5105f, 0.6666f, 0.2576f, + 0.2341f, 0.6801f, 0.2652f, 0.5394f, 0.4690f, 0.6146f, + 0.1210f, 0.2576f, 0.0769f, 0.4643f, 0.1628f, 0.2026f, + 0.3774f, 0.0506f, 0.3462f, 0.5720f, 0.0838f, 0.4228f, + 0.0588f, 0.5362f, 0.4756f, 0.2530f, 0.1778f, 0.0751f, + 0.8977f, 0.3648f, 0.3065f, 0.4739f, 0.7014f, 0.4473f, + 0.5171f, 0.1744f, 0.3487f, 0.7759f, 0.9491f, 0.2072f, + 0.2182f, 0.6520f, 0.3092f, 0.9545f, 0.1881f, 0.9579f, + 0.1785f, 0.9636f, 0.4830f, 0.6569f, 0.3353f, 0.9997f, + 0.5869f, 0.5747f, 0.0238f, 0.2943f, 0.5248f, 0.5879f, + 0.7266f, 0.1965f, 0.9167f, 0.9726f, 0.9206f, 0.0519f, + 0.2997f, 0.0039f, 0.7652f, 0.5498f, 0.3794f, 0.3791f, + 0.3528f, 0.2873f, 0.8082f, 0.4732f, 0.4399f, 0.6606f, + 0.5991f, 0.0034f, 0.4874f}).reshape(8,8,3); AdjustHue op = new AdjustHue(image, 0.2f); INDArray[] res = Nd4j.exec(op); @@ -1361,7 +1367,7 @@ public class CustomOpsTests extends BaseNd4jTest { // Exact copy of libnd4j test @Test - @Ignore + @Disabled public void testRgbToHsv() { INDArray expected = Nd4j.createFromArray(new float[]{ 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/ExpandableOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/ExpandableOpsTests.java index b4119bc36..eef78e1e9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/ExpandableOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/ExpandableOpsTests.java @@ -22,15 +22,15 @@ package org.nd4j.linalg.custom; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ops.compat.CompatStringSplit; import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; @Slf4j public class ExpandableOpsTests extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/BalanceMinibatchesTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/BalanceMinibatchesTest.java index d72b14373..704519e79 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/BalanceMinibatchesTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/BalanceMinibatchesTest.java @@ -20,34 +20,34 @@ package org.nd4j.linalg.dataset; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4jBackend; import java.io.File; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Collections; import java.util.Map; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class BalanceMinibatchesTest extends BaseNd4jTest { public BalanceMinibatchesTest(Nd4jBackend backend) { super(backend); } - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); @Test - public void testBalance() throws Exception { + public void testBalance(@TempDir Path testDir) throws Exception { DataSetIterator iterator = new IrisDataSetIterator(10, 150); - File minibatches = testDir.newFolder(); - File saveDir = testDir.newFolder(); + File minibatches = new File(testDir.toFile(),"mini-batch-dir"); + File saveDir = new File(testDir.toFile(),"save-dir"); BalanceMinibatches balanceMinibatches = BalanceMinibatches.builder().dataSetIterator(iterator).miniBatchSize(10) .numLabels(3).rootDir(minibatches).rootSaveDir(saveDir).build(); @@ -60,13 +60,13 @@ public class BalanceMinibatchesTest extends BaseNd4jTest { } @Test - public void testMiniBatchBalanced() throws Exception { + public void testMiniBatchBalanced(@TempDir Path testDir) throws Exception { int miniBatchSize = 100; DataSetIterator iterator = new IrisDataSetIterator(miniBatchSize, 150); - File minibatches = testDir.newFolder(); - File saveDir = testDir.newFolder(); + File minibatches = new File(testDir.toFile(),"mini-batch-dir"); + File saveDir = new File(testDir.toFile(),"save-dir"); BalanceMinibatches balanceMinibatches = BalanceMinibatches.builder().dataSetIterator(iterator) .miniBatchSize(miniBatchSize).numLabels(iterator.totalOutcomes()) @@ -100,10 +100,9 @@ public class BalanceMinibatchesTest extends BaseNd4jTest { Map balancedCounts = balanced.next().labelCounts(); for (int i = 0; i < iterator.totalOutcomes(); i++) { double bCounts = (balancedCounts.containsKey(i) ? balancedCounts.get(i) : 0); - assertTrue("key " + i + " totalOutcomes: " + iterator.totalOutcomes() + " balancedCounts : " - + balancedCounts.containsKey(i) + " val : " + bCounts, - balancedCounts.containsKey(i) && balancedCounts.get(i) >= (double) miniBatchSize - / iterator.totalOutcomes()); + assertTrue( balancedCounts.containsKey(i) && balancedCounts.get(i) >= (double) miniBatchSize + / iterator.totalOutcomes(),"key " + i + " totalOutcomes: " + iterator.totalOutcomes() + " balancedCounts : " + + balancedCounts.containsKey(i) + " val : " + bCounts); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/CachingDataSetIteratorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/CachingDataSetIteratorTest.java index 525e89896..a89fc43c7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/CachingDataSetIteratorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/CachingDataSetIteratorTest.java @@ -22,7 +22,7 @@ package org.nd4j.linalg.dataset; import org.apache.commons.io.FileUtils; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -40,7 +40,7 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @RunWith(Parameterized.class) public class CachingDataSetIteratorTest extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java index da0c13baf..ee927b330 100755 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.dataset; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -40,17 +41,17 @@ import org.nd4j.common.util.ArrayUtil; import org.nd4j.linalg.util.FeatureUtil; import java.io.*; +import java.nio.file.Path; import java.util.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.linalg.indexing.NDArrayIndex.*; @Slf4j @RunWith(Parameterized.class) public class DataSetTest extends BaseNd4jTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + public DataSetTest(Nd4jBackend backend) { @@ -117,7 +118,7 @@ public class DataSetTest extends BaseNd4jTest { assertEquals(train.getTrain().getLabels().length(), 6); SplitTestAndTrain train2 = data.splitTestAndTrain(6, new Random(1)); - assertEquals(getFailureMessage(), train.getTrain().getFeatures(), train2.getTrain().getFeatures()); + assertEquals(train.getTrain().getFeatures(), train2.getTrain().getFeatures(),getFailureMessage()); DataSet x0 = new IrisDataSetIterator(150, 150).next(); SplitTestAndTrain testAndTrain = x0.splitTestAndTrain(10); @@ -153,13 +154,13 @@ public class DataSetTest extends BaseNd4jTest { @Test public void testLabelCounts() { DataSet x0 = new IrisDataSetIterator(150, 150).next(); - assertEquals(getFailureMessage(), 0, x0.get(0).outcome()); - assertEquals(getFailureMessage(), 0, x0.get(1).outcome()); - assertEquals(getFailureMessage(), 2, x0.get(149).outcome()); + assertEquals(0, x0.get(0).outcome(),getFailureMessage()); + assertEquals( 0, x0.get(1).outcome(),getFailureMessage()); + assertEquals(2, x0.get(149).outcome(),getFailureMessage()); Map counts = x0.labelCounts(); - assertEquals(getFailureMessage(), 50, counts.get(0), 1e-1); - assertEquals(getFailureMessage(), 50, counts.get(1), 1e-1); - assertEquals(getFailureMessage(), 50, counts.get(2), 1e-1); + assertEquals(50, counts.get(0), 1e-1,getFailureMessage()); + assertEquals(50, counts.get(1), 1e-1,getFailureMessage()); + assertEquals(50, counts.get(2), 1e-1,getFailureMessage()); } @@ -1078,7 +1079,7 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testDataSetMetaDataSerialization() throws IOException { + public void testDataSetMetaDataSerialization(@TempDir Path testDir) throws IOException { for(boolean withMeta : new boolean[]{false, true}) { // create simple data set with meta data object @@ -1092,7 +1093,7 @@ public class DataSetTest extends BaseNd4jTest { } // check if the meta data was serialized and deserialized - File dir = testDir.newFolder(); + File dir = testDir.toFile(); File saved = new File(dir, "ds.bin"); ds.save(saved); DataSet loaded = new DataSet(); @@ -1108,7 +1109,7 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testMultiDataSetMetaDataSerialization() throws IOException { + public void testMultiDataSetMetaDataSerialization(@TempDir Path testDir) throws IOException { for(boolean withMeta : new boolean[]{false, true}) { // create simple data set with meta data object @@ -1121,7 +1122,7 @@ public class DataSetTest extends BaseNd4jTest { } // check if the meta data was serialized and deserialized - File dir = testDir.newFolder(); + File dir = testDir.toFile(); File saved = new File(dir, "ds.bin"); ds.save(saved); MultiDataSet loaded = new MultiDataSet(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java index d65b1dfdd..8c43c5f30 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.dataset; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -33,8 +33,8 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.ops.transforms.Transforms; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @RunWith(Parameterized.class) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/KFoldIteratorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/KFoldIteratorTest.java index 9f251a4cc..95ea38171 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/KFoldIteratorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/KFoldIteratorTest.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.dataset; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -32,8 +32,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import java.util.HashSet; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.*; @RunWith(Parameterized.class) public class KFoldIteratorTest extends BaseNd4jTest { @@ -107,13 +106,16 @@ public class KFoldIteratorTest extends BaseNd4jTest { } - @Test(expected = IllegalArgumentException.class) + @Test() public void checkCornerCaseException() { - DataSet allData = new DataSet(Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1), - Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1)); - int k = 1; - //this will throw illegal argument exception - new KFoldIterator(k, allData); + assertThrows(IllegalArgumentException.class,() -> { + DataSet allData = new DataSet(Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1), + Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1)); + int k = 1; + //this will throw illegal argument exception + new KFoldIterator(k, allData); + }); + } @Test @@ -248,10 +250,10 @@ public class KFoldIteratorTest extends BaseNd4jTest { } String s = String.valueOf(count); DataSet test = iter.testFold(); - assertEquals(s, testFold, test.getFeatures()); - assertEquals(s, testFold, test.getLabels()); - assertEquals(s, countTrain, fold.getFeatures().length()); - assertEquals(s, countTrain, fold.getLabels().length()); + assertEquals(testFold, test.getFeatures(),s); + assertEquals( testFold, test.getLabels(),s); + assertEquals(countTrain, fold.getFeatures().length(),s); + assertEquals(countTrain, fold.getLabels().length(),s); count++; } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MinMaxStatsTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MinMaxStatsTest.java index 967ec5393..857e95c6d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MinMaxStatsTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MinMaxStatsTest.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.dataset; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -29,7 +29,7 @@ import org.nd4j.linalg.dataset.api.preprocessor.stats.MinMaxStats; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Ede Meijer diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIteratorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIteratorTest.java index 3c362e2a4..3391af730 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIteratorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIteratorTest.java @@ -20,23 +20,24 @@ package org.nd4j.linalg.dataset; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import java.nio.file.Path; + +import static org.junit.jupiter.api.Assertions.assertEquals; @RunWith(Parameterized.class) public class MiniBatchFileDataSetIteratorTest extends BaseNd4jTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); public MiniBatchFileDataSetIteratorTest(Nd4jBackend backend) { super(backend); @@ -44,9 +45,9 @@ public class MiniBatchFileDataSetIteratorTest extends BaseNd4jTest { @Test - public void testMiniBatches() throws Exception { + public void testMiniBatches(@TempDir Path testDir) throws Exception { DataSet load = new IrisDataSetIterator(150, 150).next(); - final MiniBatchFileDataSetIterator iter = new MiniBatchFileDataSetIterator(load, 10, false, testDir.newFolder()); + final MiniBatchFileDataSetIterator iter = new MiniBatchFileDataSetIterator(load, 10, false, testDir.toFile()); while (iter.hasNext()) assertEquals(10, iter.next().numExamples()); if (iter.getRootDir() == null) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiDataSetTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiDataSetTest.java index 314c76e05..ced615d55 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiDataSetTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiDataSetTest.java @@ -21,7 +21,7 @@ package org.nd4j.linalg.dataset; import lombok.extern.slf4j.Slf4j; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -39,7 +39,7 @@ import java.util.Arrays; import java.util.List; import java.util.Random; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.interval; @@ -700,7 +700,7 @@ public class MultiDataSetTest extends BaseNd4jTest { MultiDataSet mds2 = new MultiDataSet(); mds2.load(dis); - assertEquals("Failed at [" + numF + "]/[" + numL + "]",mds, mds2); + assertEquals(mds, mds2,"Failed at [" + numF + "]/[" + numL + "]"); } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerHybridTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerHybridTest.java index 3d7edc72e..020b524a7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerHybridTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerHybridTest.java @@ -20,8 +20,8 @@ package org.nd4j.linalg.dataset; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -30,7 +30,7 @@ import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerHybrid; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @RunWith(Parameterized.class) public class MultiNormalizerHybridTest extends BaseNd4jTest { @@ -38,7 +38,7 @@ public class MultiNormalizerHybridTest extends BaseNd4jTest { private MultiDataSet data; private MultiDataSet dataCopy; - @Before + @BeforeEach public void setUp() { SUT = new MultiNormalizerHybrid(); data = new MultiDataSet( diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerMinMaxScalerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerMinMaxScalerTest.java index c4423e9c4..da87004a5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerMinMaxScalerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerMinMaxScalerTest.java @@ -20,8 +20,8 @@ package org.nd4j.linalg.dataset; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -33,7 +33,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.ops.transforms.Transforms; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @RunWith(Parameterized.class) public class MultiNormalizerMinMaxScalerTest extends BaseNd4jTest { @@ -46,7 +46,7 @@ public class MultiNormalizerMinMaxScalerTest extends BaseNd4jTest { private double naturalMin; private double naturalMax; - @Before + @BeforeEach public void setUp() { SUT = new MultiNormalizerMinMaxScaler(); SUT.fitLabel(true); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerStandardizeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerStandardizeTest.java index fc417593d..899c96b46 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerStandardizeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerStandardizeTest.java @@ -20,8 +20,8 @@ package org.nd4j.linalg.dataset; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -33,7 +33,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.ops.transforms.Transforms; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @RunWith(Parameterized.class) public class MultiNormalizerStandardizeTest extends BaseNd4jTest { @@ -45,7 +45,7 @@ public class MultiNormalizerStandardizeTest extends BaseNd4jTest { private double meanNaturalNums; private double stdNaturalNums; - @Before + @BeforeEach public void setUp() { SUT = new MultiNormalizerStandardize(); SUT.fitLabel(true); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerMinMaxScalerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerMinMaxScalerTest.java index 59c4f5a45..5e7cff650 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerMinMaxScalerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerMinMaxScalerTest.java @@ -21,7 +21,7 @@ package org.nd4j.linalg.dataset; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -34,7 +34,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.ops.transforms.Transforms; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @RunWith(Parameterized.class) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerSerializerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerSerializerTest.java index f5e7979f0..3c88e2256 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerSerializerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerSerializerTest.java @@ -21,8 +21,8 @@ package org.nd4j.linalg.dataset; import lombok.Getter; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -48,7 +48,8 @@ import java.util.HashMap; import java.util.Map; import static java.util.Arrays.asList; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; /** * @author Ede Meijer @@ -62,7 +63,7 @@ public class NormalizerSerializerTest extends BaseNd4jTest { super(backend); } - @Before + @BeforeEach public void setUp() throws IOException { tmpFile = File.createTempFile("test", "preProcessor"); tmpFile.deleteOnExit(); @@ -82,7 +83,7 @@ public class NormalizerSerializerTest extends BaseNd4jTest { @Test public void testNormalizerStandardizeNotFitLabels() throws Exception { NormalizerStandardize original = new NormalizerStandardize(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), - Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1)); + Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1)); SUT.write(original, tmpFile); NormalizerStandardize restored = SUT.restore(tmpFile); @@ -93,8 +94,8 @@ public class NormalizerSerializerTest extends BaseNd4jTest { @Test public void testNormalizerStandardizeFitLabels() throws Exception { NormalizerStandardize original = new NormalizerStandardize(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), - Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1), Nd4j.create(new double[] {4.5, 5.5}).reshape(1, -1), - Nd4j.create(new double[] {6.5, 7.5}).reshape(1, -1)); + Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1), Nd4j.create(new double[] {4.5, 5.5}).reshape(1, -1), + Nd4j.create(new double[] {6.5, 7.5}).reshape(1, -1)); original.fitLabel(true); SUT.write(original, tmpFile); @@ -131,10 +132,10 @@ public class NormalizerSerializerTest extends BaseNd4jTest { public void testMultiNormalizerStandardizeNotFitLabels() throws Exception { MultiNormalizerStandardize original = new MultiNormalizerStandardize(); original.setFeatureStats(asList( - new DistributionStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), - Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1)), - new DistributionStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}).reshape(1, -1), - Nd4j.create(new double[] {7.5, 8.5, 9.5}).reshape(1, -1)))); + new DistributionStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), + Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1)), + new DistributionStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}).reshape(1, -1), + Nd4j.create(new double[] {7.5, 8.5, 9.5}).reshape(1, -1)))); SUT.write(original, tmpFile); MultiNormalizerStandardize restored = SUT.restore(tmpFile); @@ -146,16 +147,16 @@ public class NormalizerSerializerTest extends BaseNd4jTest { public void testMultiNormalizerStandardizeFitLabels() throws Exception { MultiNormalizerStandardize original = new MultiNormalizerStandardize(); original.setFeatureStats(asList( - new DistributionStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), - Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1)), - new DistributionStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}).reshape(1, -1), - Nd4j.create(new double[] {7.5, 8.5, 9.5}).reshape(1, -1)))); + new DistributionStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), + Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1)), + new DistributionStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}).reshape(1, -1), + Nd4j.create(new double[] {7.5, 8.5, 9.5}).reshape(1, -1)))); original.setLabelStats(asList( - new DistributionStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), - Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1)), - new DistributionStats(Nd4j.create(new double[] {4.5}).reshape(1, -1), Nd4j.create(new double[] {7.5}).reshape(1, -1)), - new DistributionStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}).reshape(1, -1), - Nd4j.create(new double[] {7.5, 8.5, 9.5}).reshape(1, -1)))); + new DistributionStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), + Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1)), + new DistributionStats(Nd4j.create(new double[] {4.5}).reshape(1, -1), Nd4j.create(new double[] {7.5}).reshape(1, -1)), + new DistributionStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}).reshape(1, -1), + Nd4j.create(new double[] {7.5, 8.5, 9.5}).reshape(1, -1)))); original.fitLabel(true); SUT.write(original, tmpFile); @@ -168,9 +169,9 @@ public class NormalizerSerializerTest extends BaseNd4jTest { public void testMultiNormalizerMinMaxScalerNotFitLabels() throws Exception { MultiNormalizerMinMaxScaler original = new MultiNormalizerMinMaxScaler(0.1, 0.9); original.setFeatureStats(asList( - new MinMaxStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})), - new MinMaxStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}), - Nd4j.create(new double[] {7.5, 8.5, 9.5})))); + new MinMaxStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})), + new MinMaxStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}), + Nd4j.create(new double[] {7.5, 8.5, 9.5})))); SUT.write(original, tmpFile); MultiNormalizerMinMaxScaler restored = SUT.restore(tmpFile); @@ -182,14 +183,14 @@ public class NormalizerSerializerTest extends BaseNd4jTest { public void testMultiNormalizerMinMaxScalerFitLabels() throws Exception { MultiNormalizerMinMaxScaler original = new MultiNormalizerMinMaxScaler(0.1, 0.9); original.setFeatureStats(asList( - new MinMaxStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})), - new MinMaxStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}), - Nd4j.create(new double[] {7.5, 8.5, 9.5})))); + new MinMaxStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})), + new MinMaxStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}), + Nd4j.create(new double[] {7.5, 8.5, 9.5})))); original.setLabelStats(asList( - new MinMaxStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})), - new MinMaxStats(Nd4j.create(new double[] {4.5}), Nd4j.create(new double[] {7.5})), - new MinMaxStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}), - Nd4j.create(new double[] {7.5, 8.5, 9.5})))); + new MinMaxStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})), + new MinMaxStats(Nd4j.create(new double[] {4.5}), Nd4j.create(new double[] {7.5})), + new MinMaxStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}), + Nd4j.create(new double[] {7.5, 8.5, 9.5})))); original.fitLabel(true); SUT.write(original, tmpFile); @@ -234,7 +235,7 @@ public class NormalizerSerializerTest extends BaseNd4jTest { @Test public void testMultiNormalizerHybridGlobalAndSpecificStats() throws Exception { MultiNormalizerHybrid original = new MultiNormalizerHybrid().standardizeAllInputs().minMaxScaleInput(0, -5, 5) - .minMaxScaleAllOutputs(-10, 10).standardizeOutput(1); + .minMaxScaleAllOutputs(-10, 10).standardizeOutput(1); Map inputStats = new HashMap<>(); inputStats.put(0, new MinMaxStats(Nd4j.create(new float[] {1, 2}).reshape(1, -1), Nd4j.create(new float[] {3, 4}).reshape(1, -1))); @@ -253,9 +254,12 @@ public class NormalizerSerializerTest extends BaseNd4jTest { assertEquals(original, restored); } - @Test(expected = RuntimeException.class) + @Test() public void testCustomNormalizerWithoutRegisteredStrategy() throws Exception { - SUT.write(new MyNormalizer(123), tmpFile); + assertThrows(RuntimeException.class, () -> { + SUT.write(new MyNormalizer(123), tmpFile); + + }); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeLabelsTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeLabelsTest.java index a96ed6250..1725b9ebe 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeLabelsTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeLabelsTest.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.dataset; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -32,8 +32,8 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.ops.transforms.Transforms; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @RunWith(Parameterized.class) public class NormalizerStandardizeLabelsTest extends BaseNd4jTest { @@ -141,7 +141,7 @@ public class NormalizerStandardizeLabelsTest extends BaseNd4jTest { assertTrue(sampleMeanDelta.mul(100).div(normData.theoreticalMean).max().getDouble(0) < tolerancePerc); //sanity check to see if it's within the theoretical standard error of mean sampleMeanSEM = sampleMeanDelta.div(normData.theoreticalSEM).max().getDouble(0); - assertTrue(String.valueOf(sampleMeanSEM), sampleMeanSEM < 2.6); //99% of the time it should be within this many SEMs + assertTrue(sampleMeanSEM < 2.6,String.valueOf(sampleMeanSEM)); //99% of the time it should be within this many SEMs tolerancePerc = 5; //within 5% sampleStd = myNormalizer.getStd(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeTest.java index 6fdd9227e..25cd555f3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeTest.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.dataset; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -33,7 +33,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.ops.transforms.Transforms; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @RunWith(Parameterized.class) public class NormalizerStandardizeTest extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerTests.java index 741218d0b..317b8c806 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerTests.java @@ -20,8 +20,8 @@ package org.nd4j.linalg.dataset; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -46,8 +46,8 @@ import org.nd4j.linalg.ops.transforms.Transforms; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @RunWith(Parameterized.class) public class NormalizerTests extends BaseNd4jTest { @@ -64,7 +64,7 @@ public class NormalizerTests extends BaseNd4jTest { private int lastBatch; private final float thresholdPerc = 2.0f; //this is the difference in percentage! - @Before + @BeforeEach public void randomData() { Nd4j.getRandom().setSeed(12345); batchSize = 13; @@ -81,10 +81,10 @@ public class NormalizerTests extends BaseNd4jTest { public void testPreProcessors() { System.out.println("Running iterator vs non-iterator std scaler.."); double d1 = testItervsDataset(stdScaler); - assertTrue(d1 + " < " + thresholdPerc, d1 < thresholdPerc); + assertTrue( d1 < thresholdPerc,d1 + " < " + thresholdPerc); System.out.println("Running iterator vs non-iterator min max scaler.."); double d2 = testItervsDataset(minMaxScaler); - assertTrue(d2 + " < " + thresholdPerc, d2 < thresholdPerc); + assertTrue(d2 < thresholdPerc,d2 + " < " + thresholdPerc); } public float testItervsDataset(DataNormalization preProcessor) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessor3D4DTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessor3D4DTest.java index e43920490..0c0808a07 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessor3D4DTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessor3D4DTest.java @@ -21,7 +21,7 @@ package org.nd4j.linalg.dataset; import lombok.extern.slf4j.Slf4j; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -40,7 +40,7 @@ import org.nd4j.linalg.ops.transforms.Transforms; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j @RunWith(Parameterized.class) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessorTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessorTests.java index 286d7842e..7fc3363bb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessorTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessorTests.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.dataset; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -30,7 +30,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.NDArrayIndex; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class PreProcessorTests extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/StandardScalerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/StandardScalerTest.java index 7ae48bf77..930b33763 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/StandardScalerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/StandardScalerTest.java @@ -20,8 +20,8 @@ package org.nd4j.linalg.dataset; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -36,7 +36,7 @@ public class StandardScalerTest extends BaseNd4jTest { super(backend); } - @Ignore + @Disabled @Test public void testScale() { StandardScaler scaler = new StandardScaler(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java index bc9a32126..efc3c06f0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java @@ -20,15 +20,14 @@ package org.nd4j.linalg.dataset.api.preprocessor; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.DataSetPreProcessor; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.*; public class CompositeDataSetPreProcessorTest extends BaseNd4jTest { @@ -41,13 +40,16 @@ public class CompositeDataSetPreProcessorTest extends BaseNd4jTest { return 'c'; } - @Test(expected = NullPointerException.class) + @Test() public void when_preConditionsIsNull_expect_NullPointerException() { - // Assemble - CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor(); + assertThrows(NullPointerException.class,() -> { + // Assemble + CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor(); + + // Act + sut.preProcess(null); + }); - // Act - sut.preProcess(null); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java index d93c5b3a8..ec353d2d9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.dataset.api.preprocessor; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -29,7 +29,7 @@ import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTest { @@ -42,48 +42,72 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTest { return 'c'; } - @Test(expected = IllegalArgumentException.class) + @Test() public void when_originalHeightIsZero_expect_IllegalArgumentException() { - CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(0, 15, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + assertThrows(IllegalArgumentException.class,() -> { + CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(0, 15, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + + }); } - @Test(expected = IllegalArgumentException.class) + @Test() public void when_originalWidthIsZero_expect_IllegalArgumentException() { - CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 0, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + assertThrows(IllegalArgumentException.class,() -> { + CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 0, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + + }); } - @Test(expected = IllegalArgumentException.class) + @Test() public void when_yStartIsNegative_expect_IllegalArgumentException() { - CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, -1, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + assertThrows(IllegalArgumentException.class,() -> { + CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, -1, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + + }); } - @Test(expected = IllegalArgumentException.class) + @Test() public void when_xStartIsNegative_expect_IllegalArgumentException() { - CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, -1, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + assertThrows(IllegalArgumentException.class,() -> { + CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, -1, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + + }); } - @Test(expected = IllegalArgumentException.class) + @Test() public void when_heightIsNotGreaterThanZero_expect_IllegalArgumentException() { - CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 0, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + assertThrows(IllegalArgumentException.class,() -> { + CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 0, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + + }); } - @Test(expected = IllegalArgumentException.class) + @Test() public void when_widthIsNotGreaterThanZero_expect_IllegalArgumentException() { - CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 4, 0, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + assertThrows(IllegalArgumentException.class,() -> { + CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 4, 0, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + + }); } - @Test(expected = IllegalArgumentException.class) + @Test() public void when_numChannelsIsNotGreaterThanZero_expect_IllegalArgumentException() { - CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 4, 3, 0, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + assertThrows(IllegalArgumentException.class,() -> { + CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 4, 3, 0, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + + }); } - @Test(expected = NullPointerException.class) + @Test() public void when_dataSetIsNull_expect_NullPointerException() { // Assemble - CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + assertThrows(NullPointerException.class,() -> { + CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + + // Act + sut.preProcess(null); + }); - // Act - sut.preProcess(null); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/MinMaxStrategyTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/MinMaxStrategyTest.java index 46464c0c8..5a22457c9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/MinMaxStrategyTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/MinMaxStrategyTest.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.dataset.api.preprocessor; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -29,7 +29,7 @@ import org.nd4j.linalg.dataset.api.preprocessor.stats.MinMaxStats; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Ede Meijer diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java index b1fd85506..81881fc42 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java @@ -22,13 +22,13 @@ package org.nd4j.linalg.dataset.api.preprocessor; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.dataset.api.preprocessor.PermuteDataSetPreProcessor; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class PermuteDataSetPreProcessorTest extends BaseNd4jTest { @@ -41,13 +41,16 @@ public class PermuteDataSetPreProcessorTest extends BaseNd4jTest { return 'c'; } - @Test(expected = NullPointerException.class) + @Test() public void when_dataSetIsNull_expect_NullPointerException() { - // Assemble - PermuteDataSetPreProcessor sut = new PermuteDataSetPreProcessor(PermuteDataSetPreProcessor.PermutationTypes.NCHWtoNHWC); + assertThrows(NullPointerException.class,() -> { + // Assemble + PermuteDataSetPreProcessor sut = new PermuteDataSetPreProcessor(PermuteDataSetPreProcessor.PermutationTypes.NCHWtoNHWC); + + // Act + sut.preProcess(null); + }); - // Act - sut.preProcess(null); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java index 4ba6a0886..305c87855 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java @@ -20,15 +20,14 @@ package org.nd4j.linalg.dataset.api.preprocessor; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.*; public class RGBtoGrayscaleDataSetPreProcessorTest extends BaseNd4jTest { @@ -41,13 +40,16 @@ public class RGBtoGrayscaleDataSetPreProcessorTest extends BaseNd4jTest { return 'c'; } - @Test(expected = NullPointerException.class) + @Test() public void when_dataSetIsNull_expect_NullPointerException() { - // Assemble - RGBtoGrayscaleDataSetPreProcessor sut = new RGBtoGrayscaleDataSetPreProcessor(); + assertThrows(NullPointerException.class,() -> { + // Assemble + RGBtoGrayscaleDataSetPreProcessor sut = new RGBtoGrayscaleDataSetPreProcessor(); + + // Act + sut.preProcess(null); + }); - // Act - sut.preProcess(null); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/UnderSamplingPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/UnderSamplingPreProcessorTest.java index 14e19c724..ed5568fea 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/UnderSamplingPreProcessorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/UnderSamplingPreProcessorTest.java @@ -22,7 +22,7 @@ package org.nd4j.linalg.dataset.api.preprocessor; import lombok.extern.slf4j.Slf4j; import net.jcip.annotations.NotThreadSafe; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -41,8 +41,8 @@ import java.util.HashMap; import java.util.List; import static java.lang.Math.min; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; /** * @author susaneraly @@ -151,19 +151,19 @@ public class UnderSamplingPreProcessorTest extends BaseNd4jTest { INDArray minorityDist = labelWindow.mul(maskWindow).sum(1).div(maskWindow.sum(1)); if (j < shortSeq / window) { - assertEquals("Failed on window " + j + " batch 0, loop " + i, targetDist, - minorityDist.getFloat(0), tolerancePerc); //should now be close to target dist - assertEquals("Failed on window " + j + " batch 1, loop " + i, targetDist, - minorityDist.getFloat(1), tolerancePerc); //should now be close to target dist - assertEquals("Failed on window " + j + " batch 2, loop " + i, 0.8, minorityDist.getFloat(2), - tolerancePerc); //should be unchanged as it was already above target dist + assertEquals(targetDist, + minorityDist.getFloat(0), tolerancePerc,"Failed on window " + j + " batch 0, loop " + i); //should now be close to target dist + assertEquals( targetDist, + minorityDist.getFloat(1), tolerancePerc,"Failed on window " + j + " batch 1, loop " + i); //should now be close to target dist + assertEquals(0.8, minorityDist.getFloat(2), + tolerancePerc,"Failed on window " + j + " batch 2, loop " + i); //should be unchanged as it was already above target dist } - assertEquals("Failed on window " + j + " batch 3, loop " + i, targetDist, minorityDist.getFloat(3), - tolerancePerc); //should now be close to target dist - assertEquals("Failed on window " + j + " batch 4, loop " + i, targetDist, minorityDist.getFloat(4), - tolerancePerc); //should now be close to target dist - assertEquals("Failed on window " + j + " batch 5, loop " + i, 0.8, minorityDist.getFloat(5), - tolerancePerc); //should be unchanged as it was already above target dist + assertEquals(targetDist, minorityDist.getFloat(3), + tolerancePerc,"Failed on window " + j + " batch 3, loop " + i); //should now be close to target dist + assertEquals(targetDist, minorityDist.getFloat(4), + tolerancePerc,"Failed on window " + j + " batch 4, loop " + i); //should now be close to target dist + assertEquals( 0.8, minorityDist.getFloat(5), + tolerancePerc,"Failed on window " + j + " batch 5, loop " + i); //should be unchanged as it was already above target dist } } } @@ -214,19 +214,19 @@ public class UnderSamplingPreProcessorTest extends BaseNd4jTest { INDArray minorityDist = minorityClass.sum(1).div(majorityClass.add(minorityClass).sum(1)); if (j < shortSeq / window) { - assertEquals("Failed on window " + j + " batch 0, loop " + i, targetDist, - minorityDist.getFloat(0), tolerancePerc); //should now be close to target dist - assertEquals("Failed on window " + j + " batch 1, loop " + i, targetDist, - minorityDist.getFloat(1), tolerancePerc); //should now be close to target dist - assertEquals("Failed on window " + j + " batch 2, loop " + i, 0.8, minorityDist.getFloat(2), - tolerancePerc); //should be unchanged as it was already above target dist + assertEquals(targetDist, + minorityDist.getFloat(0), tolerancePerc,"Failed on window " + j + " batch 0, loop " + i); //should now be close to target dist + assertEquals(targetDist, + minorityDist.getFloat(1), tolerancePerc,"Failed on window " + j + " batch 1, loop " + i); //should now be close to target dist + assertEquals(0.8, minorityDist.getFloat(2), + tolerancePerc,"Failed on window " + j + " batch 2, loop " + i); //should be unchanged as it was already above target dist } - assertEquals("Failed on window " + j + " batch 3, loop " + i, targetDist, minorityDist.getFloat(3), - tolerancePerc); //should now be close to target dist - assertEquals("Failed on window " + j + " batch 4, loop " + i, targetDist, minorityDist.getFloat(4), - tolerancePerc); //should now be close to target dist - assertEquals("Failed on window " + j + " batch 5, loop " + i, 0.8, minorityDist.getFloat(5), - tolerancePerc); //should be unchanged as it was already above target dist + assertEquals(targetDist, minorityDist.getFloat(3), + tolerancePerc,"Failed on window " + j + " batch 3, loop " + i); //should now be close to target dist + assertEquals( targetDist, minorityDist.getFloat(4), + tolerancePerc,"Failed on window " + j + " batch 4, loop " + i); //should now be close to target dist + assertEquals(0.8, minorityDist.getFloat(5), + tolerancePerc,"Failed on window " + j + " batch 5, loop " + i); //should be unchanged as it was already above target dist } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestPCA.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestPCA.java index 5bdeb1d2f..1a4300f96 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestPCA.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestPCA.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.dimensionalityreduction; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -29,8 +29,8 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.string.NDArrayStrings; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @RunWith(Parameterized.class) public class TestPCA extends BaseNd4jTest { @@ -59,7 +59,7 @@ public class TestPCA extends BaseNd4jTest { INDArray Reconstructed = Reduced.mmul(Factor.transpose()); INDArray Diff = Reconstructed.sub(A1); for (int i = 0; i < m * n; i++) { - assertEquals("Reconstructed matrix is very different from the original.", 0.0, Diff.getDouble(i), 1.0); + assertEquals(0.0, Diff.getDouble(i), 1.0,"Reconstructed matrix is very different from the original."); } } @@ -82,7 +82,7 @@ public class TestPCA extends BaseNd4jTest { INDArray reconstructed = reduced.mmul(factor.transpose()); INDArray diff = reconstructed.sub(A1); for (int i = 0; i < m * n; i++) { - assertEquals("Reconstructed matrix is very different from the original.", 0.0, diff.getDouble(i), 1.0); + assertEquals(0.0, diff.getDouble(i), 1.0,"Reconstructed matrix is very different from the original."); } } @@ -104,11 +104,11 @@ public class TestPCA extends BaseNd4jTest { INDArray Reconstructed1 = Reduced1.mmul(Factor1.transpose()); INDArray Diff1 = Reconstructed1.sub(A1); for (int i = 0; i < m * n; i++) { - assertEquals("Reconstructed matrix is very different from the original.", 0.0, Diff1.getDouble(i), 0.1); + assertEquals( 0.0, Diff1.getDouble(i), 0.1,"Reconstructed matrix is very different from the original."); } INDArray A2 = A.dup('f'); INDArray Factor2 = org.nd4j.linalg.dimensionalityreduction.PCA.pca_factor(A2, 0.50, true); - assertTrue("Variance differences should change factor sizes.", Factor1.columns() > Factor2.columns()); + assertTrue(Factor1.columns() > Factor2.columns(),"Variance differences should change factor sizes."); } @@ -145,10 +145,9 @@ public class TestPCA extends BaseNd4jTest { PCA myPCA = new PCA(m); INDArray reduced70 = myPCA.reducedBasis(0.70); INDArray reduced99 = myPCA.reducedBasis(0.99); - assertTrue("Major variance differences should change number of basis vectors", - reduced99.columns() > reduced70.columns()); + assertTrue( reduced99.columns() > reduced70.columns(),"Major variance differences should change number of basis vectors"); INDArray reduced100 = myPCA.reducedBasis(1.0); - assertTrue("100% variance coverage should include all eigenvectors", reduced100.columns() == m.columns()); + assertTrue(reduced100.columns() == m.columns(),"100% variance coverage should include all eigenvectors"); NDArrayStrings ns = new NDArrayStrings(5); // System.out.println("Eigenvectors:\n" + ns.format(myPCA.getEigenvectors())); // System.out.println("Eigenvalues:\n" + ns.format(myPCA.getEigenvalues())); @@ -159,22 +158,21 @@ public class TestPCA extends BaseNd4jTest { variance += myPCA.estimateVariance(m.getRow(i), reduced70.columns()); variance /= 1000.0; System.out.println("Fraction of variance using 70% variance with " + reduced70.columns() + " columns: " + variance); - assertTrue("Variance does not cover intended 70% variance", variance > 0.70); + assertTrue(variance > 0.70,"Variance does not cover intended 70% variance"); // create "dummy" data with the same exact trends INDArray testSample = myPCA.generateGaussianSamples(10000); PCA analyzePCA = new PCA(testSample); - assertTrue("Means do not agree accurately enough", - myPCA.getMean().equalsWithEps(analyzePCA.getMean(), 0.2 * myPCA.getMean().columns())); - assertTrue("Covariance is not reproduced accurately enough", myPCA.getCovarianceMatrix().equalsWithEps( - analyzePCA.getCovarianceMatrix(), 1.0 * analyzePCA.getCovarianceMatrix().length())); - assertTrue("Eigenvalues are not close enough", myPCA.getEigenvalues().equalsWithEps(analyzePCA.getEigenvalues(), - 0.5 * myPCA.getEigenvalues().columns())); - assertTrue("Eigenvectors are not close enough", myPCA.getEigenvectors() - .equalsWithEps(analyzePCA.getEigenvectors(), 0.1 * analyzePCA.getEigenvectors().length())); + assertTrue( myPCA.getMean().equalsWithEps(analyzePCA.getMean(), 0.2 * myPCA.getMean().columns()),"Means do not agree accurately enough"); + assertTrue(myPCA.getCovarianceMatrix().equalsWithEps( + analyzePCA.getCovarianceMatrix(), 1.0 * analyzePCA.getCovarianceMatrix().length()),"Covariance is not reproduced accurately enough"); + assertTrue( myPCA.getEigenvalues().equalsWithEps(analyzePCA.getEigenvalues(), + 0.5 * myPCA.getEigenvalues().columns()),"Eigenvalues are not close enough"); + assertTrue(myPCA.getEigenvectors() + .equalsWithEps(analyzePCA.getEigenvectors(), 0.1 * analyzePCA.getEigenvectors().length()),"Eigenvectors are not close enough"); // System.out.println("Original cov:\n" + ns.format(myPCA.getCovarianceMatrix()) + "\nDummy cov:\n" // + ns.format(analyzePCA.getCovarianceMatrix())); INDArray testSample2 = analyzePCA.convertBackToFeatures(analyzePCA.convertToComponents(testSample)); - assertTrue("Transformation does not work.", testSample.equalsWithEps(testSample2, 1e-5 * testSample.length())); + assertTrue( testSample.equalsWithEps(testSample2, 1e-5 * testSample.length()),"Transformation does not work."); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestRandomProjection.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestRandomProjection.java index 62dbb40d0..0df37d632 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestRandomProjection.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestRandomProjection.java @@ -20,9 +20,9 @@ package org.nd4j.linalg.dimensionalityreduction; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -38,19 +38,16 @@ import org.nd4j.linalg.ops.transforms.Transforms; import java.util.ArrayList; import java.util.Arrays; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.linalg.dimensionalityreduction.RandomProjection.johnsonLindenStraussMinDim; import static org.nd4j.linalg.dimensionalityreduction.RandomProjection.targetShape; -@Ignore +@Disabled @RunWith(Parameterized.class) public class TestRandomProjection extends BaseNd4jTest { INDArray z1 = Nd4j.createUninitialized(new int[]{(int)1e6, 1000}); - @Rule - public final ExpectedException exception = ExpectedException.none(); - public TestRandomProjection(Nd4jBackend backend) { super(backend); @@ -79,22 +76,26 @@ public class TestRandomProjection extends BaseNd4jTest { @Test public void testTargetEpsilonChecks() { - exception.expect(IllegalArgumentException.class); - // wrong rel. error - targetShape(z1, 0.0); + assertThrows(IllegalArgumentException.class,() -> { + // wrong rel. error + targetShape(z1, 0.0); + }); + } @Test public void testTargetShapeTooHigh() { - exception.expect(ND4JIllegalStateException.class); - // original dimension too small - targetShape(Nd4j.createUninitialized(new int[]{(int)1e2, 1}), 0.5); - // target dimension too high - targetShape(z1, 1001); - // suggested dimension too high - targetShape(z1, 0.1); - // original samples too small - targetShape(Nd4j.createUninitialized(new int[]{1, 1000}), 0.5); + assertThrows(ND4JIllegalStateException.class,() -> { + // original dimension too small + targetShape(Nd4j.createUninitialized(new int[]{(int)1e2, 1}), 0.5); + // target dimension too high + targetShape(z1, 1001); + // suggested dimension too high + targetShape(z1, 0.1); + // original samples too small + targetShape(Nd4j.createUninitialized(new int[]{1, 1000}), 0.5); + }); + } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java index 318fb9587..58f226c64 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java @@ -23,8 +23,8 @@ package org.nd4j.linalg.factory; import lombok.val; import org.bytedeco.javacpp.FloatPointer; import org.bytedeco.javacpp.Pointer; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -44,8 +44,8 @@ import java.util.Arrays; import java.util.List; import java.util.UUID; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** */ @@ -133,7 +133,7 @@ public class Nd4jTest extends BaseNd4jTest { INDArray actualResult = data.mean(0); INDArray expectedResult = Nd4j.create(new double[] {3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6., 6.}, new int[] {2, 4, 4}); - assertEquals(getFailureMessage(), expectedResult, actualResult); + assertEquals(expectedResult, actualResult,getFailureMessage()); } @@ -147,7 +147,7 @@ public class Nd4jTest extends BaseNd4jTest { INDArray actualResult = data.var(false, 0); INDArray expectedResult = Nd4j.create(new double[] {1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4., 4.}, new long[] {2, 4, 4}); - assertEquals(getFailureMessage(), expectedResult, actualResult); + assertEquals(expectedResult, actualResult,getFailureMessage()); } @Test @@ -178,12 +178,12 @@ public class Nd4jTest extends BaseNd4jTest { val tmR = testMatrix.ravel(); val expR = expanded.ravel(); - assertEquals(message, 1, expanded.shape()[i < 0 ? i + rank : i]); - assertEquals(message, tmR, expR); - assertEquals(message, ordering, expanded.ordering()); + assertEquals( 1, expanded.shape()[i < 0 ? i + rank : i],message); + assertEquals(tmR, expR,message); + assertEquals( ordering, expanded.ordering(),message); testMatrix.assign(Nd4j.rand(DataType.DOUBLE, shape)); - assertEquals(message, testMatrix.ravel(), expanded.ravel()); + assertEquals(testMatrix.ravel(), expanded.ravel(),message); } } } @@ -204,19 +204,19 @@ public class Nd4jTest extends BaseNd4jTest { final long[] expShape = ArrayUtil.removeIndex(shape, 1); final String message = "Squeezing in dimension 1; Shape before squeezing: " + Arrays.toString(shape) + " " + ordering + " Order; Shape after expanding: " + Arrays.toString(squeezed.shape()) + " "+squeezed.ordering()+"; Input Created via: " + recreation; - assertArrayEquals(message, expShape, squeezed.shape()); - assertEquals(message, ordering, squeezed.ordering()); - assertEquals(message, testMatrix.ravel(), squeezed.ravel()); + assertArrayEquals(expShape, squeezed.shape(),message); + assertEquals(ordering, squeezed.ordering(),message); + assertEquals(testMatrix.ravel(), squeezed.ravel(),message); testMatrix.assign(Nd4j.rand(shape)); - assertEquals(message, testMatrix.ravel(), squeezed.ravel()); + assertEquals(testMatrix.ravel(), squeezed.ravel(),message); } } @Test - @Ignore("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") + @Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") public void testNumpyConversion() throws Exception { INDArray linspace = Nd4j.linspace(1,4,4, DataType.FLOAT); Pointer convert = Nd4j.getNDArrayFactory().convertToNumpy(linspace); @@ -254,7 +254,7 @@ public class Nd4jTest extends BaseNd4jTest { @Test - @Ignore("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") + @Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") public void testNumpyWrite() throws Exception { INDArray linspace = Nd4j.linspace(1,4,4, Nd4j.dataType()); File tmpFile = new File(System.getProperty("java.io.tmpdir"),"nd4j-numpy-tmp-" + UUID.randomUUID().toString() + ".bin"); @@ -266,7 +266,7 @@ public class Nd4jTest extends BaseNd4jTest { } @Test - @Ignore("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") + @Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") public void testNpyByteArray() throws Exception { INDArray linspace = Nd4j.linspace(1,4,4, Nd4j.dataType()); byte[] bytes = Nd4j.toNpyByteArray(linspace); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDBaseTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDBaseTest.java index 0e642dcf6..745dc00d9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDBaseTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDBaseTest.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.factory.ops; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -28,7 +28,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.conditions.Conditions; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class NDBaseTest extends BaseNd4jTest { public NDBaseTest(Nd4jBackend backend) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java index 60c179282..a4c6f0527 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.factory.ops; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -31,8 +31,8 @@ import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class NDLossTest extends BaseNd4jTest { public NDLossTest(Nd4jBackend backend) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/generated/SDLinalgTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/generated/SDLinalgTest.java index 1c5a19293..ba26e181a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/generated/SDLinalgTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/generated/SDLinalgTest.java @@ -19,8 +19,8 @@ */ package org.nd4j.linalg.generated; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.BaseNd4jTest; @@ -29,8 +29,8 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class SDLinalgTest extends BaseNd4jTest { public SDLinalgTest(Nd4jBackend backend) { @@ -44,7 +44,7 @@ public class SDLinalgTest extends BaseNd4jTest { private SameDiff sameDiff; - @Before + @BeforeEach public void setup() { sameDiff = SameDiff.create(); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java index ac6a4c873..5c465317c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java @@ -21,7 +21,7 @@ package org.nd4j.linalg.indexing; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -41,7 +41,7 @@ import org.nd4j.nativeblas.NativeOpsHolder; import java.util.Arrays; import java.util.Collections; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @RunWith(Parameterized.class) public class BooleanIndexingTest extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/TransformsTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/TransformsTest.java index 670eb7d84..f4b59fdb1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/TransformsTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/TransformsTest.java @@ -22,7 +22,7 @@ package org.nd4j.linalg.indexing; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -32,8 +32,8 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.ops.transforms.Transforms; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j @RunWith(Parameterized.class) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/inverse/TestInvertMatrices.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/inverse/TestInvertMatrices.java index ebdef9b9d..22e0de225 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/inverse/TestInvertMatrices.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/inverse/TestInvertMatrices.java @@ -24,7 +24,7 @@ import org.apache.commons.math3.linear.Array2DRowRealMatrix; import org.apache.commons.math3.linear.LUDecomposition; import org.apache.commons.math3.linear.MatrixUtils; import org.apache.commons.math3.linear.RealMatrix; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -38,7 +38,7 @@ import org.nd4j.common.primitives.Pair; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @RunWith(Parameterized.class) public class TestInvertMatrices extends BaseNd4jTest { @@ -74,7 +74,7 @@ public class TestInvertMatrices extends BaseNd4jTest { RealMatrix rmInverse = new LUDecomposition(rm).getSolver().getInverse(); INDArray expected = CheckUtil.convertFromApacheMatrix(rmInverse, orig.dataType()); - assertTrue(p.getSecond(), CheckUtil.checkEntries(expected, inverse, 1e-3, 1e-4)); + assertTrue(CheckUtil.checkEntries(expected, inverse, 1e-3, 1e-4),p.getSecond()); } } @@ -190,19 +190,25 @@ public class TestInvertMatrices extends BaseNd4jTest { /** * Try to compute the right pseudo inverse of a matrix without full row rank (x1 = 2*x2) */ - @Test(expected = IllegalArgumentException.class) + @Test() public void testRightPseudoInvertWithNonFullRowRank() { - INDArray X = Nd4j.create(new double[][]{{1, 2}, {3, 6}, {5, 10}}).transpose(); - INDArray rightInverse = InvertMatrix.pRightInvert(X, false); + assertThrows(IllegalArgumentException.class,() -> { + INDArray X = Nd4j.create(new double[][]{{1, 2}, {3, 6}, {5, 10}}).transpose(); + INDArray rightInverse = InvertMatrix.pRightInvert(X, false); + }); + } /** * Try to compute the left pseudo inverse of a matrix without full column rank (x1 = 2*x2) */ - @Test(expected = IllegalArgumentException.class) + @Test() public void testLeftPseudoInvertWithNonFullColumnRank() { - INDArray X = Nd4j.create(new double[][]{{1, 2}, {3, 6}, {5, 10}}); - INDArray leftInverse = InvertMatrix.pLeftInvert(X, false); + assertThrows(IllegalArgumentException.class,() -> { + INDArray X = Nd4j.create(new double[][]{{1, 2}, {3, 6}, {5, 10}}); + INDArray leftInverse = InvertMatrix.pLeftInvert(X, false); + }); + } @Override diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lapack/LapackTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lapack/LapackTestsC.java index c6e4551f3..7fc8e85d7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lapack/LapackTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lapack/LapackTestsC.java @@ -21,9 +21,9 @@ package org.nd4j.linalg.lapack; import lombok.extern.slf4j.Slf4j; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -32,7 +32,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j @RunWith(Parameterized.class) @@ -44,12 +44,12 @@ public class LapackTestsC extends BaseNd4jTest { initialType = Nd4j.dataType(); } - @Before + @BeforeEach public void setUp() { Nd4j.setDataType(DataType.DOUBLE); } - @After + @AfterEach public void after() { Nd4j.setDataType(initialType); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lapack/LapackTestsF.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lapack/LapackTestsF.java index 8c0171e88..1c27010ae 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lapack/LapackTestsF.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lapack/LapackTestsF.java @@ -21,9 +21,9 @@ package org.nd4j.linalg.lapack; import lombok.extern.slf4j.Slf4j; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -32,7 +32,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j @RunWith(Parameterized.class) @@ -44,12 +44,12 @@ public class LapackTestsF extends BaseNd4jTest { initialType = Nd4j.dataType(); } - @Before + @BeforeEach public void setUp() { Nd4j.setDataType(DataType.DOUBLE); } - @After + @AfterEach public void after() { Nd4j.setDataType(initialType); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterTest.java index e9427c485..897098ace 100755 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterTest.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.learning; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -35,7 +35,7 @@ import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Nadam; import org.nd4j.linalg.learning.config.Nesterovs; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @RunWith(Parameterized.class) public class UpdaterTest extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java index 3c18f74f1..27409e0d6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.learning; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -40,7 +40,7 @@ import org.nd4j.linalg.learning.config.Sgd; import java.util.HashMap; import java.util.Map; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class UpdaterValidation extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionJson.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionJson.java index 2ae63f428..c1a75fbbd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionJson.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionJson.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.lossfunctions; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -45,7 +45,7 @@ import org.nd4j.shade.jackson.databind.MapperFeature; import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.SerializationFeature; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class LossFunctionJson extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java index 566178d6c..a39f715f0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.lossfunctions; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationSigmoid; @@ -45,7 +45,7 @@ import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT; import static junit.framework.TestCase.assertFalse; import static junit.framework.TestCase.assertTrue; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class LossFunctionTest extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java index ce301a077..a4ef0632d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java @@ -21,7 +21,7 @@ package org.nd4j.linalg.lossfunctions; import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java index a816451e4..6f4213554 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java @@ -22,8 +22,8 @@ package org.nd4j.linalg.memory; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -37,10 +37,10 @@ import org.nd4j.linalg.api.memory.enums.LearningPolicy; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j -@Ignore +@Disabled @RunWith(Parameterized.class) public class AccountingTests extends BaseNd4jTest { public AccountingTests(Nd4jBackend backend) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java index 230ac40d9..5ded208b6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java @@ -22,7 +22,7 @@ package org.nd4j.linalg.memory; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -31,8 +31,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.NDArrayIndex; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.*; @Slf4j @RunWith(Parameterized.class) @@ -82,21 +81,26 @@ public class CloseableTests extends BaseNd4jTest { } } - @Test(expected = IllegalStateException.class) + @Test() public void testAccessException_1() { - val array = Nd4j.create(5, 5); - array.close(); + assertThrows(IllegalStateException.class,() -> { + val array = Nd4j.create(5, 5); + array.close(); + + array.data().pointer(); + }); - array.data().pointer(); } - @Test(expected = IllegalStateException.class) + @Test() public void testAccessException_2() { - val array = Nd4j.create(5, 5); - val view = array.getRow(0); - array.close(); + assertThrows(IllegalStateException.class,() -> { + val array = Nd4j.create(5, 5); + val view = array.getRow(0); + array.close(); - view.data().pointer(); + view.data().pointer(); + }); } @Override diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java index 75d45655b..9dc02b36a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java @@ -22,7 +22,7 @@ package org.nd4j.linalg.memory; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -35,8 +35,8 @@ import org.nd4j.linalg.util.DeviceLocalNDArray; import java.util.Arrays; import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j @RunWith(Parameterized.class) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java index 40a6706cf..b60941bc6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java @@ -23,8 +23,8 @@ package org.nd4j.linalg.mixed; import com.google.flatbuffers.FlatBufferBuilder; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.graph.FlatArray; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; @@ -48,7 +48,7 @@ import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.nativeblas.NativeOpsHolder; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class MixedDataTypesTests extends BaseNd4jTest { @@ -359,33 +359,42 @@ public class MixedDataTypesTests extends BaseNd4jTest { assertEquals(exp, arrayZ); } - @Test(expected = IllegalArgumentException.class) + @Test() public void testTypesValidation_1() { - val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.LONG); - val arrayY = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); - val exp = new long[]{1, 0, 0, 1}; + assertThrows(IllegalArgumentException.class,() -> { + val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.LONG); + val arrayY = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); + val exp = new long[]{1, 0, 0, 1}; + + val op = new CosineSimilarity(arrayX, arrayY); + val result = Nd4j.getExecutioner().exec(op); + }); - val op = new CosineSimilarity(arrayX, arrayY); - val result = Nd4j.getExecutioner().exec(op); } - @Test(expected = RuntimeException.class) + @Test() public void testTypesValidation_2() { - val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); - val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.LONG); - val exp = new long[]{1, 0, 0, 1}; + assertThrows(RuntimeException.class,() -> { + val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); + val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.LONG); + val exp = new long[]{1, 0, 0, 1}; - val result = Nd4j.getExecutioner().exec(new EqualTo(arrayX, arrayY, arrayX.ulike().castTo(DataType.BOOL)))[0]; - val arr = result.data().asLong(); + val result = Nd4j.getExecutioner().exec(new EqualTo(arrayX, arrayY, arrayX.ulike().castTo(DataType.BOOL)))[0]; + val arr = result.data().asLong(); + + assertArrayEquals(exp, arr); + }); - assertArrayEquals(exp, arr); } - @Test(expected = RuntimeException.class) + @Test() public void testTypesValidation_3() { - val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); + assertThrows(RuntimeException.class,() -> { + val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); + + val result = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(arrayX, arrayX, -1)); + }); - val result = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(arrayX, arrayX, -1)); } public void testTypesValidation_4() { @@ -533,7 +542,7 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - @Ignore("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") + @Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") public void testArrayCreationFromPointer() { val source = Nd4j.create(new double[]{1, 2, 3, 4, 5}); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/StringArrayTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/StringArrayTests.java index 28770add9..c14020d22 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/StringArrayTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/StringArrayTests.java @@ -23,14 +23,14 @@ package org.nd4j.linalg.mixed; import com.google.flatbuffers.FlatBufferBuilder; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.graph.FlatArray; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class StringArrayTests extends BaseNd4jTest { @@ -55,7 +55,7 @@ public class StringArrayTests extends BaseNd4jTest { assertEquals("alpha", array.getString(0)); String s = array.toString(); - assertTrue(s, s.contains("alpha")); + assertTrue(s.contains("alpha"),s); System.out.println(s); } @@ -72,9 +72,9 @@ public class StringArrayTests extends BaseNd4jTest { assertEquals("beta", array.getString(1)); assertEquals("gamma", array.getString(2)); String s = array.toString(); - assertTrue(s, s.contains("alpha")); - assertTrue(s, s.contains("beta")); - assertTrue(s, s.contains("gamma")); + assertTrue(s.contains("alpha"),s); + assertTrue(s.contains("beta"),s); + assertTrue(s.contains("gamma"),s); System.out.println(s); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/multithreading/MultithreadedTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/multithreading/MultithreadedTests.java index 30a813efc..821499afc 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/multithreading/MultithreadedTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/multithreading/MultithreadedTests.java @@ -21,7 +21,7 @@ package org.nd4j.linalg.multithreading; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -31,7 +31,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import java.util.ArrayList; import java.util.HashSet; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class MultithreadedTests extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/NativeBlasTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/NativeBlasTests.java index 4a4edea17..e5752cff4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/NativeBlasTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/NativeBlasTests.java @@ -22,16 +22,16 @@ package org.nd4j.linalg.nativ; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j public class NativeBlasTests extends BaseNd4jTest { @@ -40,13 +40,13 @@ public class NativeBlasTests extends BaseNd4jTest { super(backend); } - @Before + @BeforeEach public void setUp() { Nd4j.getExecutioner().enableDebugMode(true); Nd4j.getExecutioner().enableVerboseMode(true); } - @After + @AfterEach public void setDown() { Nd4j.getExecutioner().enableDebugMode(false); Nd4j.getExecutioner().enableVerboseMode(false); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java index 86c89a82a..c2dca756e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java @@ -23,7 +23,7 @@ package org.nd4j.linalg.nativ; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; import org.nd4j.imports.NoOpNameFoundException; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java index 5e233b45e..f623847e1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java @@ -21,9 +21,9 @@ package org.nd4j.linalg.ops; import org.apache.commons.math3.util.FastMath; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -42,7 +42,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.ops.transforms.Transforms; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @RunWith(Parameterized.class) @@ -58,12 +58,12 @@ public class DerivativeTests extends BaseNd4jTest { this.initialType = Nd4j.dataType(); } - @Before + @BeforeEach public void before() { Nd4j.setDataType(DataType.DOUBLE); } - @After + @AfterEach public void after() { Nd4j.setDataType(this.initialType); } @@ -257,7 +257,7 @@ public class DerivativeTests extends BaseNd4jTest { if (d1 == 0.0 && d2 == 0.0) relError = 0.0; String str = "exp=" + expOut[i] + ", act=" + zPrime.getDouble(i) + "; relError = " + relError; - assertTrue(str, relError < REL_ERROR_TOLERANCE); + assertTrue(relError < REL_ERROR_TOLERANCE,str); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpConstructorTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpConstructorTests.java index 20d8b684d..1addead6c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpConstructorTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpConstructorTests.java @@ -20,8 +20,8 @@ package org.nd4j.linalg.ops; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.linalg.BaseNd4jTest; @@ -38,9 +38,9 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Modifier; import java.util.*; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; -@Ignore //AB 2019/08/23 Ignored for now +@Disabled //AB 2019/08/23 Ignored for now public class OpConstructorTests extends BaseNd4jTest { public OpConstructorTests(Nd4jBackend backend) { @@ -119,7 +119,7 @@ public class OpConstructorTests extends BaseNd4jTest { System.out.println("No INDArray constructor: " + c.getName()); } } - assertEquals("Found " + classes.size() + " (non-ignored) op classes with no INDArray/INDArray[] constructors", 0, classes.size()); + assertEquals(0, classes.size(),"Found " + classes.size() + " (non-ignored) op classes with no INDArray/INDArray[] constructors"); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java index 7f366e2f9..1b9de3efa 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java @@ -21,8 +21,8 @@ package org.nd4j.linalg.ops; import lombok.val; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -64,7 +64,7 @@ import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.Executors; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @RunWith(Parameterized.class) @@ -82,7 +82,7 @@ public class OpExecutionerTests extends BaseNd4jTest { INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4, 5}); INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4, 5}); double sim = Transforms.cosineSim(vec1, vec2); - assertEquals(getFailureMessage(), 1, sim, 1e-1); + assertEquals( 1, sim, 1e-1,getFailureMessage()); } @@ -92,7 +92,7 @@ public class OpExecutionerTests extends BaseNd4jTest { INDArray vec2 = Nd4j.create(new float[] {3, 5, 7}); // 1-17*sqrt(2/581) double distance = Transforms.cosineDistance(vec1, vec2); - assertEquals(getFailureMessage(), 0.0025851, distance, 1e-7); + assertEquals(0.0025851, distance, 1e-7,getFailureMessage()); } @Test @@ -100,7 +100,7 @@ public class OpExecutionerTests extends BaseNd4jTest { INDArray arr = Nd4j.create(new double[] {55, 55}); INDArray arr2 = Nd4j.create(new double[] {60, 60}); double result = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(arr, arr2)).z().getDouble(0); - assertEquals(getFailureMessage(), 7.0710678118654755, result, 1e-1); + assertEquals(7.0710678118654755, result, 1e-1,getFailureMessage()); } @Test @@ -118,7 +118,7 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test - @Ignore + @Disabled public void testDistance() throws Exception { INDArray matrix = Nd4j.rand(new int[] {400,10}); INDArray rowVector = matrix.getRow(70); @@ -141,7 +141,7 @@ public class OpExecutionerTests extends BaseNd4jTest { INDArray scalarMax = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).negi(); INDArray postMax = Nd4j.ones(DataType.DOUBLE, 6); Nd4j.getExecutioner().exec(new ScalarMax(scalarMax, 1)); - assertEquals(getFailureMessage(), scalarMax, postMax); + assertEquals(scalarMax, postMax,getFailureMessage()); } @Test @@ -150,14 +150,14 @@ public class OpExecutionerTests extends BaseNd4jTest { Nd4j.getExecutioner().exec(new SetRange(linspace, 0, 1)); for (int i = 0; i < linspace.length(); i++) { double val = linspace.getDouble(i); - assertTrue(getFailureMessage(), val >= 0 && val <= 1); + assertTrue( val >= 0 && val <= 1,getFailureMessage()); } INDArray linspace2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); Nd4j.getExecutioner().exec(new SetRange(linspace2, 2, 4)); for (int i = 0; i < linspace2.length(); i++) { double val = linspace2.getDouble(i); - assertTrue(getFailureMessage(), val >= 2 && val <= 4); + assertTrue( val >= 2 && val <= 4,getFailureMessage()); } } @@ -165,7 +165,7 @@ public class OpExecutionerTests extends BaseNd4jTest { public void testNormMax() { INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4}); double normMax = Nd4j.getExecutioner().execAndReturn(new NormMax(arr)).z().getDouble(0); - assertEquals(getFailureMessage(), 4, normMax, 1e-1); + assertEquals(4, normMax, 1e-1,getFailureMessage()); } @Test @@ -187,7 +187,7 @@ public class OpExecutionerTests extends BaseNd4jTest { public void testNorm2() { INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4}); double norm2 = Nd4j.getExecutioner().execAndReturn(new Norm2(arr)).z().getDouble(0); - assertEquals(getFailureMessage(), 5.4772255750516612, norm2, 1e-1); + assertEquals(5.4772255750516612, norm2, 1e-1,getFailureMessage()); } @Test @@ -197,7 +197,7 @@ public class OpExecutionerTests extends BaseNd4jTest { INDArray xDup = x.dup(); INDArray solution = Nd4j.valueArrayOf(5, 2.0); opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x})); - assertEquals(getFailureMessage(), solution, x); + assertEquals(solution, x,getFailureMessage()); } @Test @@ -218,13 +218,13 @@ public class OpExecutionerTests extends BaseNd4jTest { INDArray xDup = x.dup(); INDArray solution = Nd4j.valueArrayOf(5, 2.0); opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x})); - assertEquals(getFailureMessage(), solution, x); + assertEquals(solution, x,getFailureMessage()); Sum acc = new Sum(x.dup()); opExecutioner.exec(acc); - assertEquals(getFailureMessage(), 10.0, acc.getFinalResult().doubleValue(), 1e-1); + assertEquals(10.0, acc.getFinalResult().doubleValue(), 1e-1,getFailureMessage()); Prod prod = new Prod(x.dup()); opExecutioner.exec(prod); - assertEquals(getFailureMessage(), 32.0, prod.getFinalResult().doubleValue(), 1e-1); + assertEquals(32.0, prod.getFinalResult().doubleValue(), 1e-1,getFailureMessage()); } @@ -268,7 +268,7 @@ public class OpExecutionerTests extends BaseNd4jTest { Variance variance = new Variance(x.dup(), true); opExecutioner.exec(variance); - assertEquals(getFailureMessage(), 2.5, variance.getFinalResult().doubleValue(), 1e-1); + assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage()); } @@ -276,13 +276,13 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test public void testIamax() { INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); - assertEquals(getFailureMessage(), 3, Nd4j.getBlasWrapper().iamax(linspace)); + assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage()); } @Test public void testIamax2() { INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); - assertEquals(getFailureMessage(), 3, Nd4j.getBlasWrapper().iamax(linspace)); + assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage()); val op = new ArgAmax(linspace); int iamax = Nd4j.getExecutioner().exec(op)[0].getInt(0); @@ -297,11 +297,11 @@ public class OpExecutionerTests extends BaseNd4jTest { Mean mean = new Mean(x); opExecutioner.exec(mean); - assertEquals(getFailureMessage(), 3.0, mean.getFinalResult().doubleValue(), 1e-1); + assertEquals( 3.0, mean.getFinalResult().doubleValue(), 1e-1,getFailureMessage()); Variance variance = new Variance(x.dup(), true); opExecutioner.exec(variance); - assertEquals(getFailureMessage(), 2.5, variance.getFinalResult().doubleValue(), 1e-1); + assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage()); } @Test @@ -310,7 +310,7 @@ public class OpExecutionerTests extends BaseNd4jTest { val arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); val softMax = new SoftMax(arr); opExecutioner.exec((CustomOp) softMax); - assertEquals(getFailureMessage(), 1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1); + assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage()); } @@ -320,7 +320,7 @@ public class OpExecutionerTests extends BaseNd4jTest { Pow pow = new Pow(oneThroughSix, 2); Nd4j.getExecutioner().exec(pow); INDArray answer = Nd4j.create(new double[] {1, 4, 9, 16, 25, 36}); - assertEquals(getFailureMessage(), answer, pow.z()); + assertEquals(answer, pow.z(),getFailureMessage()); } @@ -368,7 +368,7 @@ public class OpExecutionerTests extends BaseNd4jTest { Log log = new Log(slice); opExecutioner.exec(log); INDArray assertion = Nd4j.create(new double[] {0., 1.09861229, 1.60943791}); - assertEquals(getFailureMessage(), assertion, slice); + assertEquals(assertion, slice,getFailureMessage()); } @Test @@ -551,7 +551,7 @@ public class OpExecutionerTests extends BaseNd4jTest { expected[i] = (float) Math.exp(slice.getDouble(i)); Exp exp = new Exp(slice); opExecutioner.exec(exp); - assertEquals(getFailureMessage(), Nd4j.create(expected), slice); + assertEquals( Nd4j.create(expected), slice,getFailureMessage()); } @Test @@ -560,7 +560,7 @@ public class OpExecutionerTests extends BaseNd4jTest { INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); val softMax = new SoftMax(arr); opExecutioner.exec((CustomOp) softMax); - assertEquals(getFailureMessage(), 1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1); + assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage()); } @Test @@ -658,7 +658,7 @@ public class OpExecutionerTests extends BaseNd4jTest { exp.putScalar(0, 0, i, j, sum); } } - assertEquals("Failed for [" + order + "] order", exp, arr6s); + assertEquals(exp, arr6s,"Failed for [" + order + "] order"); // System.out.println("ORDER: " + order); // for (int i = 0; i < 6; i++) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java index 1c5eb3f97..d52c39755 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java @@ -22,9 +22,9 @@ package org.nd4j.linalg.ops; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.After; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -73,7 +73,7 @@ import org.nd4j.common.util.ArrayUtil; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.point; @@ -88,7 +88,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { DataType initialType; - @After + @AfterEach public void after() { Nd4j.setDataType(this.initialType); } @@ -141,7 +141,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4, 5}); INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4, 5}); double sim = Transforms.cosineSim(vec1, vec2); - assertEquals(getFailureMessage(), 1, sim, 1e-1); + assertEquals(1, sim, 1e-1,getFailureMessage()); } @Test @@ -150,7 +150,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { INDArray vec2 = Nd4j.create(new float[] {3, 5, 7}); // 1-17*sqrt(2/581) double distance = Transforms.cosineDistance(vec1, vec2); - assertEquals(getFailureMessage(), 0.0025851, distance, 1e-7); + assertEquals( 0.0025851, distance, 1e-7,getFailureMessage()); } @Test @@ -176,7 +176,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { INDArray arr2 = Nd4j.create(new double[] {60, 60}); double result = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(arr, arr2)).getFinalResult() .doubleValue(); - assertEquals(getFailureMessage(), 7.0710678118654755, result, 1e-1); + assertEquals(7.0710678118654755, result, 1e-1,getFailureMessage()); } @Test @@ -184,7 +184,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { INDArray scalarMax = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).negi(); INDArray postMax = Nd4j.ones(DataType.DOUBLE, 6); Nd4j.getExecutioner().exec(new ScalarMax(scalarMax, 1)); - assertEquals(getFailureMessage(), postMax, scalarMax); + assertEquals(postMax, scalarMax,getFailureMessage()); } @Test @@ -193,14 +193,14 @@ public class OpExecutionerTestsC extends BaseNd4jTest { Nd4j.getExecutioner().exec(new SetRange(linspace, 0, 1)); for (int i = 0; i < linspace.length(); i++) { double val = linspace.getDouble(i); - assertTrue(getFailureMessage(), val >= 0 && val <= 1); + assertTrue( val >= 0 && val <= 1,getFailureMessage()); } INDArray linspace2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); Nd4j.getExecutioner().exec(new SetRange(linspace2, 2, 4)); for (int i = 0; i < linspace2.length(); i++) { double val = linspace2.getDouble(i); - assertTrue(getFailureMessage(), val >= 2 && val <= 4); + assertTrue(val >= 2 && val <= 4,getFailureMessage()); } } @@ -209,7 +209,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { public void testNormMax() { INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4}); double normMax = Nd4j.getExecutioner().execAndReturn(new NormMax(arr)).getFinalResult().doubleValue(); - assertEquals(getFailureMessage(), 4, normMax, 1e-1); + assertEquals(4, normMax, 1e-1,getFailureMessage()); } @@ -217,7 +217,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { public void testNorm2() { INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4}); double norm2 = Nd4j.getExecutioner().execAndReturn(new Norm2(arr)).getFinalResult().doubleValue(); - assertEquals(getFailureMessage(), 5.4772255750516612, norm2, 1e-1); + assertEquals( 5.4772255750516612, norm2, 1e-1,getFailureMessage()); } @Test @@ -227,7 +227,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { INDArray xDup = x.dup(); INDArray solution = Nd4j.valueArrayOf(5, 2.0); opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x})); - assertEquals(getFailureMessage(), solution, x); + assertEquals(solution, x,getFailureMessage()); } @Test @@ -248,13 +248,13 @@ public class OpExecutionerTestsC extends BaseNd4jTest { INDArray xDup = x.dup(); INDArray solution = Nd4j.valueArrayOf(5, 2.0); opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{ x})); - assertEquals(getFailureMessage(), solution, x); + assertEquals(solution, x,getFailureMessage()); Sum acc = new Sum(x.dup()); opExecutioner.exec(acc); - assertEquals(getFailureMessage(), 10.0, acc.getFinalResult().doubleValue(), 1e-1); + assertEquals(10.0, acc.getFinalResult().doubleValue(), 1e-1,getFailureMessage()); Prod prod = new Prod(x.dup()); opExecutioner.exec(prod); - assertEquals(getFailureMessage(), 32.0, prod.getFinalResult().doubleValue(), 1e-1); + assertEquals(32.0, prod.getFinalResult().doubleValue(), 1e-1,getFailureMessage()); } @@ -302,7 +302,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { Variance variance = new Variance(x.dup(), true); opExecutioner.exec(variance); - assertEquals(getFailureMessage(), 2.5, variance.getFinalResult().doubleValue(), 1e-1); + assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage()); } @@ -313,11 +313,11 @@ public class OpExecutionerTestsC extends BaseNd4jTest { Mean mean = new Mean(x); opExecutioner.exec(mean); - assertEquals(getFailureMessage(), 3.0, mean.getFinalResult().doubleValue(), 1e-1); + assertEquals(3.0, mean.getFinalResult().doubleValue(), 1e-1,getFailureMessage()); Variance variance = new Variance(x.dup(), true); opExecutioner.exec(variance); - assertEquals(getFailureMessage(), 2.5, variance.getFinalResult().doubleValue(), 1e-1); + assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage()); } @Test @@ -326,7 +326,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); val softMax = new SoftMax(arr); opExecutioner.exec((CustomOp) softMax); - assertEquals(getFailureMessage(), 1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1); + assertEquals( 1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage()); } @Test @@ -354,7 +354,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { Pow pow = new Pow(oneThroughSix, 2); Nd4j.getExecutioner().exec(pow); INDArray answer = Nd4j.create(new double[] {1, 4, 9, 16, 25, 36}); - assertEquals(getFailureMessage(), answer, pow.z()); + assertEquals(answer, pow.z(),getFailureMessage()); } @@ -404,7 +404,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { Log exp = new Log(slice); opExecutioner.exec(exp); INDArray assertion = Nd4j.create(new double[] {0.0, 0.6931471824645996, 1.0986123085021973}); - assertEquals(getFailureMessage(), assertion, slice); + assertEquals(assertion, slice,getFailureMessage()); } @Test @@ -417,7 +417,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { expected[i] = (float) Math.exp(slice.getDouble(i)); Exp exp = new Exp(slice); opExecutioner.exec(exp); - assertEquals(getFailureMessage(), Nd4j.create(expected), slice); + assertEquals(Nd4j.create(expected), slice,getFailureMessage()); } @Test @@ -425,12 +425,12 @@ public class OpExecutionerTestsC extends BaseNd4jTest { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); val softMax = new SoftMax(arr); - opExecutioner.exec((CustomOp) softMax); - assertEquals(getFailureMessage(), 1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1); + opExecutioner.exec(softMax); + assertEquals( 1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage()); INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); val softmax = new SoftMax(linspace.dup()); - Nd4j.getExecutioner().exec((CustomOp) softmax); + Nd4j.getExecutioner().exec(softmax); assertEquals(linspace.rows(), softmax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1); } @@ -439,9 +439,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { public void testDimensionSoftMax() { INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); val max = new SoftMax(linspace); - Nd4j.getExecutioner().exec((CustomOp) max); + Nd4j.getExecutioner().exec(max); linspace.assign(max.outputArguments().get(0)); - assertEquals(getFailureMessage(), linspace.getRow(0).sumNumber().doubleValue(), 1.0, 1e-1); + assertEquals(linspace.getRow(0).sumNumber().doubleValue(), 1.0, 1e-1,getFailureMessage()); } @Test @@ -1059,7 +1059,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { * @throws Exception */ @Test - @Ignore + @Disabled public void testTadEws() { INDArray array = Nd4j.create(32, 5, 10); assertEquals(1, array.tensorAlongDimension(0, 1, 2).elementWiseStride()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/RationalTanhTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/RationalTanhTest.java index db8096c4b..045e3e78c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/RationalTanhTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/RationalTanhTest.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.ops; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -29,7 +29,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; @RunWith(Parameterized.class) public class RationalTanhTest extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/broadcast/row/RowVectorOpsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/broadcast/row/RowVectorOpsC.java index 9ad664970..8971e80b2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/broadcast/row/RowVectorOpsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/broadcast/row/RowVectorOpsC.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.ops.broadcast.row; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -29,7 +29,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/copy/CopyTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/copy/CopyTest.java index 90d476c7a..21f70bc67 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/copy/CopyTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/copy/CopyTest.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.ops.copy; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -28,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @RunWith(Parameterized.class) public class CopyTest extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/options/ArrayOptionsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/options/ArrayOptionsTests.java index 12e0a2292..a27a92c76 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/options/ArrayOptionsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/options/ArrayOptionsTests.java @@ -21,8 +21,8 @@ package org.nd4j.linalg.options; import lombok.extern.slf4j.Slf4j; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -31,8 +31,8 @@ import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.api.shape.options.ArrayType; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; @Slf4j @RunWith(Parameterized.class) @@ -44,7 +44,7 @@ public class ArrayOptionsTests extends BaseNd4jTest { } - @Before + @BeforeEach public void setUp() { shapeInfo = new long[]{2, 2, 2, 2, 1, 0, 1, 99}; } @@ -84,9 +84,9 @@ public class ArrayOptionsTests extends BaseNd4jTest { String s = dt.toString(); long l = 0; l = ArrayOptionsHelper.setOptionBit(l, dt); - assertNotEquals(s, 0, l); + assertNotEquals(0, l,s); DataType dt2 = ArrayOptionsHelper.dataType(l); - assertEquals(s, dt, dt2); + assertEquals(dt, dt2,s); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/InfNanTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/InfNanTests.java index 3d603b9b9..898232888 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/InfNanTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/InfNanTests.java @@ -20,9 +20,9 @@ package org.nd4j.linalg.profiling; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -33,6 +33,8 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import static org.junit.jupiter.api.Assertions.assertThrows; + @RunWith(Parameterized.class) public class InfNanTests extends BaseNd4jTest { @@ -40,37 +42,43 @@ public class InfNanTests extends BaseNd4jTest { super(backend); } - @Before + @BeforeEach public void setUp() { } - @After + @AfterEach public void cleanUp() { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED); } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testInf1() { - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.INF_PANIC); + assertThrows(ND4JIllegalStateException.class,() -> { + Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.INF_PANIC); - INDArray x = Nd4j.create(100); + INDArray x = Nd4j.create(100); - x.putScalar(2, Float.NEGATIVE_INFINITY); + x.putScalar(2, Float.NEGATIVE_INFINITY); + + OpExecutionerUtil.checkForAny(x); + }); - OpExecutionerUtil.checkForAny(x); } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testInf2() { - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC); + assertThrows(ND4JIllegalStateException.class,() -> { + Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC); - INDArray x = Nd4j.create(100); + INDArray x = Nd4j.create(100); - x.putScalar(2, Float.NEGATIVE_INFINITY); + x.putScalar(2, Float.NEGATIVE_INFINITY); + + OpExecutionerUtil.checkForAny(x); + }); - OpExecutionerUtil.checkForAny(x); } @Test @@ -91,27 +99,33 @@ public class InfNanTests extends BaseNd4jTest { OpExecutionerUtil.checkForAny(x); } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testNaN1() { - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.NAN_PANIC); + assertThrows(ND4JIllegalStateException.class,() -> { + Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.NAN_PANIC); - INDArray x = Nd4j.create(100); + INDArray x = Nd4j.create(100); - x.putScalar(2, Float.NaN); + x.putScalar(2, Float.NaN); + + OpExecutionerUtil.checkForAny(x); + }); - OpExecutionerUtil.checkForAny(x); } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testNaN2() { - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC); + assertThrows(ND4JIllegalStateException.class,() -> { + Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC); - INDArray x = Nd4j.create(100); + INDArray x = Nd4j.create(100); - x.putScalar(2, Float.NaN); + x.putScalar(2, Float.NaN); + + OpExecutionerUtil.checkForAny(x); + }); - OpExecutionerUtil.checkForAny(x); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java index 8dd92bb69..5312f7adf 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java @@ -23,10 +23,10 @@ package org.nd4j.linalg.profiling; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.lang3.ArrayUtils; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; @@ -45,7 +45,7 @@ import org.nd4j.common.primitives.Pair; import org.nd4j.linalg.profiler.OpProfiler; import org.nd4j.linalg.profiler.ProfilerConfig; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class OperationProfilerTests extends BaseNd4jTest { @@ -59,13 +59,13 @@ public class OperationProfilerTests extends BaseNd4jTest { return 'c'; } - @Before + @BeforeEach public void setUp() { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.OPERATIONS); OpProfiler.getInstance().reset(); } - @After + @AfterEach public void tearDown() { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED); } @@ -162,7 +162,7 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testBadCombos6() { INDArray x = Nd4j.create(27).reshape('f', 3, 3, 3).slice(1); INDArray y = Nd4j.create(100).reshape('f', 10, 10); @@ -219,7 +219,7 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testBadTad4() { INDArray x = Nd4j.create(2, 4, 5, 6); @@ -292,71 +292,86 @@ public class OperationProfilerTests extends BaseNd4jTest { } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testNaNPanic1() { - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.NAN_PANIC); + assertThrows(ND4JIllegalStateException.class,() -> { + Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.NAN_PANIC); - INDArray a = Nd4j.create(new float[] {1f, 2f, 3f, Float.NaN}); + INDArray a = Nd4j.create(new float[] {1f, 2f, 3f, Float.NaN}); + + a.muli(3f); + }); - a.muli(3f); } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testNaNPanic2() { - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.INF_PANIC); + assertThrows(ND4JIllegalStateException.class,() -> { + Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.INF_PANIC); - INDArray a = Nd4j.create(new float[] {1f, 2f, 3f, Float.POSITIVE_INFINITY}); + INDArray a = Nd4j.create(new float[] {1f, 2f, 3f, Float.POSITIVE_INFINITY}); + + a.muli(3f); + }); - a.muli(3f); } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testNaNPanic3() { - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC); + assertThrows(ND4JIllegalStateException.class,() -> { + Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC); - INDArray a = Nd4j.create(new float[] {1f, 2f, 3f, Float.NEGATIVE_INFINITY}); + INDArray a = Nd4j.create(new float[] {1f, 2f, 3f, Float.NEGATIVE_INFINITY}); + + a.muli(3f); + }); - a.muli(3f); } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testScopePanic1() { - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); + assertThrows(ND4JIllegalStateException.class,() -> { + Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); - INDArray array; + INDArray array; - try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS119")) { - array = Nd4j.create(10); + try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS119")) { + array = Nd4j.create(10); - assertTrue(array.isAttached()); - } - - array.add(1.0); - } - - - @Test(expected = ND4JIllegalStateException.class) - public void testScopePanic2() { - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); - - INDArray array; - - try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS120")) { - array = Nd4j.create(10); - assertTrue(array.isAttached()); - - assertEquals(1, workspace.getGenerationId()); - } - - - try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS120")) { - assertEquals(2, workspace.getGenerationId()); + assertTrue(array.isAttached()); + } array.add(1.0); + }); + + } + + + @Test() + public void testScopePanic2() { + assertThrows(ND4JIllegalStateException.class,() -> { + Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); + + INDArray array; + + try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS120")) { + array = Nd4j.create(10); + assertTrue(array.isAttached()); + + assertEquals(1, workspace.getGenerationId()); + } + + + try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS120")) { + assertEquals(2, workspace.getGenerationId()); + + array.add(1.0); + + assertTrue(array.isAttached()); + } + }); - assertTrue(array.isAttached()); - } } @@ -448,7 +463,7 @@ public class OperationProfilerTests extends BaseNd4jTest { } catch (Exception e){ //throw new RuntimeException(e); log.info("Message: {}", e.getMessage()); - assertTrue(e.getMessage(), e.getMessage().contains("NaN")); + assertTrue(e.getMessage().contains("NaN"),e.getMessage()); } INDArray in = op.getInputArgument(0); @@ -478,7 +493,7 @@ public class OperationProfilerTests extends BaseNd4jTest { fail(); } catch (Exception e){ log.error("",e); - assertTrue(e.getMessage(), e.getMessage().contains("Inf")); + assertTrue(e.getMessage().contains("Inf"),e.getMessage()); } INDArray in = op.getInputArgument(0); @@ -513,7 +528,7 @@ public class OperationProfilerTests extends BaseNd4jTest { fail("Expected op profiler exception"); } catch (Throwable t) { //OK - assertTrue(t.getMessage(), t.getMessage().contains(nan ? "NaN" : "Inf")); + assertTrue(t.getMessage().contains(nan ? "NaN" : "Inf"),t.getMessage()); } } } @@ -540,7 +555,7 @@ public class OperationProfilerTests extends BaseNd4jTest { fail("Expected op profiler exception"); } catch (Throwable t) { //OK - assertTrue(t.getMessage(), t.getMessage().contains(nan ? "NaN" : "Inf")); + assertTrue(t.getMessage().contains(nan ? "NaN" : "Inf"),t.getMessage()); } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java index fae07a2fd..9e7c68979 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java @@ -22,10 +22,10 @@ package org.nd4j.linalg.profiling; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -36,8 +36,8 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.api.memory.MemcpyDirection; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j @RunWith(Parameterized.class) @@ -46,13 +46,13 @@ public class PerformanceTrackerTests extends BaseNd4jTest { super(backend); } - @Before + @BeforeEach public void setUp() { PerformanceTracker.getInstance().clear(); Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.BANDWIDTH); } - @After + @AfterEach public void tearDown() { PerformanceTracker.getInstance().clear(); Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); @@ -107,7 +107,7 @@ public class PerformanceTrackerTests extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testTrackerCpu_1() { if (!Nd4j.getExecutioner().getClass().getCanonicalName().toLowerCase().contains("native")) return; @@ -125,7 +125,7 @@ public class PerformanceTrackerTests extends BaseNd4jTest { } @Test - @Ignore("useless these days") + @Disabled("useless these days") public void testTrackerGpu_1() { if (!Nd4j.getExecutioner().getClass().getCanonicalName().toLowerCase().contains("cuda")) return; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/StackAggregatorTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/StackAggregatorTests.java index 058acb4b3..c0f5470a8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/StackAggregatorTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/StackAggregatorTests.java @@ -21,10 +21,10 @@ package org.nd4j.linalg.profiling; import lombok.extern.slf4j.Slf4j; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; @@ -35,8 +35,8 @@ import org.nd4j.linalg.profiler.ProfilerConfig; import org.nd4j.linalg.profiler.data.StackAggregator; import org.nd4j.linalg.profiler.data.primitives.StackDescriptor; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j public class StackAggregatorTests extends BaseNd4jTest { @@ -50,14 +50,14 @@ public class StackAggregatorTests extends BaseNd4jTest { return 'c'; } - @Before + @BeforeEach public void setUp() { Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().stackTrace(true).build()); Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ALL); OpProfiler.getInstance().reset(); } - @After + @AfterEach public void tearDown() { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED); } @@ -129,7 +129,7 @@ public class StackAggregatorTests extends BaseNd4jTest { }*/ @Test - @Ignore + @Disabled public void testScalarAggregator() { INDArray x = Nd4j.create(10); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/HalfTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/HalfTests.java index d1ec78756..f198db33d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/HalfTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/HalfTests.java @@ -22,9 +22,9 @@ package org.nd4j.linalg.rng; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -45,7 +45,7 @@ public class HalfTests extends BaseNd4jTest { super(backend); } - @Before + @BeforeEach public void setUp() { if (!Nd4j.getExecutioner().getClass().getSimpleName().toLowerCase().contains("cuda")) return; @@ -54,7 +54,7 @@ public class HalfTests extends BaseNd4jTest { Nd4j.setDataType(DataType.HALF); } - @After + @AfterEach public void tearDown() { if (!Nd4j.getExecutioner().getClass().getSimpleName().toLowerCase().contains("cuda")) return; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomPerformanceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomPerformanceTests.java index 8843b2c5e..6e5c966b8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomPerformanceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomPerformanceTests.java @@ -21,7 +21,7 @@ package org.nd4j.linalg.rng; import lombok.extern.slf4j.Slf4j; -import org.junit.Ignore; +import org.junit.jupiter.api.Disabled; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -29,7 +29,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; @Slf4j @RunWith(Parameterized.class) -@Ignore +@Disabled public class RandomPerformanceTests extends BaseNd4jTest { public RandomPerformanceTests(Nd4jBackend backend) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java index 0413b7677..34ed252ed 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java @@ -24,10 +24,10 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.math3.random.JDKRandomGenerator; import org.apache.commons.math3.util.FastMath; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -67,7 +67,7 @@ import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j @RunWith(Parameterized.class) @@ -79,13 +79,13 @@ public class RandomTests extends BaseNd4jTest { super(backend); } - @Before + @BeforeEach public void setUp() { initialType = Nd4j.dataType(); Nd4j.setDataType(DataType.DOUBLE); } - @After + @AfterEach public void tearDown() { Nd4j.setDataType(initialType); } @@ -177,7 +177,7 @@ public class RandomTests extends BaseNd4jTest { INDArray z2 = Nd4j.randn('c', new int[] {1, 1000}); - assertEquals("Failed on iteration " + i, z1, z2); + assertEquals(z1, z2,"Failed on iteration " + i); } } @@ -192,7 +192,7 @@ public class RandomTests extends BaseNd4jTest { INDArray z2 = Nd4j.rand('c', new int[] {1, 1000}); - assertEquals("Failed on iteration " + i, z1, z2); + assertEquals( z1, z2,"Failed on iteration " + i); } } @@ -207,7 +207,7 @@ public class RandomTests extends BaseNd4jTest { INDArray z2 = Nd4j.getExecutioner().exec(new BinomialDistribution(Nd4j.createUninitialized(1000), 10, 0.2)); - assertEquals("Failed on iteration " + i, z1, z2); + assertEquals(z1, z2,"Failed on iteration " + i); } } @@ -242,7 +242,7 @@ public class RandomTests extends BaseNd4jTest { for (int x = 0; x < z1.length(); x++) { - assertEquals("Failed on element: [" + x + "]", z1.getFloat(x), z2.getFloat(x), 0.01f); + assertEquals(z1.getFloat(x), z2.getFloat(x), 0.01f,"Failed on element: [" + x + "]"); } assertEquals(z1, z2); } @@ -421,7 +421,7 @@ public class RandomTests extends BaseNd4jTest { A = A / n - n; A *= (1 + 4.0/n - 25.0/(n*n)); - assertTrue("Critical (max) value for 1000 points and confidence α = 0.0001 is 1.8692, received: "+ A, A < 1.8692); + assertTrue(A < 1.8692,"Critical (max) value for 1000 points and confidence α = 0.0001 is 1.8692, received: "+ A); } @Test @@ -491,13 +491,13 @@ public class RandomTests extends BaseNd4jTest { for (int x = 0; x < z01.length(); x++) { - assertEquals("Failed on element: [" + x + "]", z01.getFloat(x), z11.getFloat(x), 0.01f); + assertEquals(z11.getFloat(x), z01.getFloat(x),0.01f,"Failed on element: [" + x + "]"); } assertEquals(z01, z11); for (int x = 0; x < z02.length(); x++) { - assertEquals("Failed on element: [" + x + "]", z02.getFloat(x), z12.getFloat(x), 0.01f); + assertEquals(z02.getFloat(x), z12.getFloat(x), 0.01f,"Failed on element: [" + x + "]"); } assertEquals(z02, z12); @@ -891,7 +891,7 @@ public class RandomTests extends BaseNd4jTest { assertEquals(exp, sampled); } - @Ignore + @Disabled @Test public void testDeallocation1() throws Exception { @@ -1254,7 +1254,7 @@ public class RandomTests extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testTruncatedNormal1() { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -1273,7 +1273,7 @@ public class RandomTests extends BaseNd4jTest { log.info("Truncated: {} ms; Gaussian: {} ms", time2 - time1, time3 - time2); for (int e = 0; e < z01.length(); e++) { - assertTrue("Value: " + z01.getDouble(e) + " at " + e,FastMath.abs(z01.getDouble(e)) <= 2.0); + assertTrue(FastMath.abs(z01.getDouble(e)) <= 2.0,"Value: " + z01.getDouble(e) + " at " + e); assertNotEquals(-119119d, z01.getDouble(e), 1e-3); } @@ -1364,7 +1364,7 @@ public class RandomTests extends BaseNd4jTest { INDArray arr = Nd4j.create(DataType.DOUBLE, 100); Nd4j.exec(new BernoulliDistribution(arr, 0.5)); double sum = arr.sumNumber().doubleValue(); - assertTrue(String.valueOf(sum), sum > 0.0 && sum < 100.0); + assertTrue(sum > 0.0 && sum < 100.0,String.valueOf(sum)); } private List getList(int numBatches){ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RngValidationTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RngValidationTests.java index 1801d5a0a..5f74d8be4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RngValidationTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RngValidationTests.java @@ -20,13 +20,13 @@ package org.nd4j.linalg.rng; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; import lombok.Builder; import lombok.Data; import lombok.extern.slf4j.Slf4j; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.OpValidationSuite; import org.nd4j.common.base.Preconditions; import org.nd4j.common.util.ArrayUtil; @@ -281,8 +281,8 @@ public class RngValidationTests extends BaseNd4jTest { //Check for NaNs, Infs, etc int countNaN = Nd4j.getExecutioner().exec(new MatchConditionTransform(z, Nd4j.create(DataType.BOOL, z.shape()), Conditions.isNan())).castTo(DataType.INT).sumNumber().intValue(); int countInf = Nd4j.getExecutioner().exec(new MatchConditionTransform(z, Nd4j.create(DataType.BOOL, z.shape()), Conditions.isInfinite())).castTo(DataType.INT).sumNumber().intValue(); - assertEquals("NaN - expected 0 values", 0, countNaN); - assertEquals("Infinite - expected 0 values", 0, countInf); + assertEquals(0, countNaN,"NaN - expected 0 values"); + assertEquals( 0, countInf,"Infinite - expected 0 values"); //Check min/max values double min = z.minNumber().doubleValue(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/schedule/TestSchedules.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/schedule/TestSchedules.java index 699831057..761380274 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/schedule/TestSchedules.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/schedule/TestSchedules.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.schedule; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.shade.jackson.databind.DeserializationFeature; @@ -28,7 +28,7 @@ import org.nd4j.shade.jackson.databind.MapperFeature; import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.SerializationFeature; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestSchedules extends BaseNd4jTest { @@ -113,7 +113,7 @@ public class TestSchedules extends BaseNd4jTest { throw new RuntimeException(); } - assertEquals(s.toString() + ", " + st, e, now, 1e-6); + assertEquals(e, now, 1e-6,s.toString() + ", " + st); } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/BasicSerDeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/BasicSerDeTests.java index 12b8da7bd..cdbb14bd3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/BasicSerDeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/BasicSerDeTests.java @@ -22,8 +22,8 @@ package org.nd4j.linalg.serde; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.After; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -47,7 +47,7 @@ public class BasicSerDeTests extends BaseNd4jTest { DataType initialType; - @After + @AfterEach public void after() { Nd4j.setDataType(this.initialType); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/JsonSerdeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/JsonSerdeTests.java index 800b19655..d2e70277c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/JsonSerdeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/JsonSerdeTests.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -39,7 +39,7 @@ import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class JsonSerdeTests extends BaseNd4jTest { @@ -83,7 +83,7 @@ public class JsonSerdeTests extends BaseNd4jTest { // System.out.println("\n\n\n"); TestClass deserialized = om.readValue(s, TestClass.class); - assertEquals(dt.toString(), tc, deserialized); + assertEquals(tc, deserialized,dt.toString()); } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/LargeSerDeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/LargeSerDeTests.java index 44f7b8755..706c727f5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/LargeSerDeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/LargeSerDeTests.java @@ -22,8 +22,8 @@ package org.nd4j.linalg.serde; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -33,12 +33,12 @@ import org.nd4j.linalg.factory.Nd4jBackend; import java.io.*; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @RunWith(Parameterized.class) @Slf4j -@Ignore("AB 2019/05/23 - JVM crash on linux-x86_64-cpu-avx512 - issue #7657") +@Disabled("AB 2019/05/23 - JVM crash on linux-x86_64-cpu-avx512 - issue #7657") public class LargeSerDeTests extends BaseNd4jTest { public LargeSerDeTests(Nd4jBackend backend) { super(backend); @@ -68,7 +68,7 @@ public class LargeSerDeTests extends BaseNd4jTest { @Test - @Ignore // this should be commented out, since it requires approx 10GB ram to run + @Disabled // this should be commented out, since it requires approx 10GB ram to run public void testLargeArraySerDe_2() throws Exception { INDArray arrayA = Nd4j.createUninitialized(100000, 12500); log.info("Shape: {}; Length: {}", arrayA.shape(), arrayA.length()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java index 3b8c9075f..f8452e84a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java @@ -23,10 +23,11 @@ package org.nd4j.linalg.serde; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.io.FileUtils; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -35,24 +36,23 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.common.io.ClassPathResource; import java.io.File; +import java.nio.file.Path; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class NumpyFormatTests extends BaseNd4jTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); public NumpyFormatTests(Nd4jBackend backend) { super(backend); } @Test - public void testToNpyFormat() throws Exception { + public void testToNpyFormat(@TempDir Path testDir) throws Exception { - val dir = testDir.newFolder(); + val dir = testDir.toFile(); new ClassPathResource("numpy_arrays/").copyDirectory(dir); File[] files = dir.listFiles(); @@ -91,7 +91,7 @@ public class NumpyFormatTests extends BaseNd4jTest { System.out.println(); */ - assertArrayEquals("Failed with file [" + f.getName() + "]", expected, bytes); + assertArrayEquals(expected, bytes,"Failed with file [" + f.getName() + "]"); cnt++; } @@ -99,10 +99,10 @@ public class NumpyFormatTests extends BaseNd4jTest { } @Test - public void testToNpyFormatScalars() throws Exception { + public void testToNpyFormatScalars(@TempDir Path testDir) throws Exception { // File dir = new File("C:\\DL4J\\Git\\dl4j-test-resources\\src\\main\\resources\\numpy_arrays\\scalar"); - val dir = testDir.newFolder(); + val dir = testDir.toFile(); new ClassPathResource("numpy_arrays/scalar/").copyDirectory(dir); File[] files = dir.listFiles(); @@ -142,7 +142,7 @@ public class NumpyFormatTests extends BaseNd4jTest { System.out.println(); */ - assertArrayEquals("Failed with file [" + f.getName() + "]", expected, bytes); + assertArrayEquals(expected, bytes,"Failed with file [" + f.getName() + "]"); cnt++; System.out.println(); @@ -153,9 +153,9 @@ public class NumpyFormatTests extends BaseNd4jTest { @Test - public void testNpzReading() throws Exception { + public void testNpzReading(@TempDir Path testDir) throws Exception { - val dir = testDir.newFolder(); + val dir = testDir.toFile(); new ClassPathResource("numpy_arrays/npz/").copyDirectory(dir); File[] files = dir.listFiles(); @@ -212,9 +212,9 @@ public class NumpyFormatTests extends BaseNd4jTest { @Test - public void testNpy() throws Exception { + public void testNpy(@TempDir Path testDir) throws Exception { for(boolean empty : new boolean[]{false, true}) { - val dir = testDir.newFolder(); + val dir = testDir.toFile(); if(!empty) { new ClassPathResource("numpy_arrays/npy/3,4/").copyDirectory(dir); } else { @@ -247,7 +247,7 @@ public class NumpyFormatTests extends BaseNd4jTest { } INDArray act = Nd4j.createFromNpyFile(f); - assertEquals("Failed with file [" + f.getName() + "]", exp, act); + assertEquals( exp, act,"Failed with file [" + f.getName() + "]"); cnt++; } @@ -261,62 +261,74 @@ public class NumpyFormatTests extends BaseNd4jTest { assertEquals(Nd4j.scalar(DataType.INT, 1), out); } - @Test(expected = RuntimeException.class) - public void readNumpyCorruptHeader1() throws Exception { - File f = testDir.newFolder(); + @Test() + public void readNumpyCorruptHeader1(@TempDir Path testDir) throws Exception { + assertThrows(RuntimeException.class,() -> { + File f = testDir.toFile(); - File fValid = new ClassPathResource("numpy_arrays/arange_3,4_float32.npy").getFile(); - byte[] numpyBytes = FileUtils.readFileToByteArray(fValid); - for( int i=0; i<10; i++ ){ - numpyBytes[i] = 0; - } - File fCorrupt = new File(f, "corrupt.npy"); - FileUtils.writeByteArrayToFile(fCorrupt, numpyBytes); + File fValid = new ClassPathResource("numpy_arrays/arange_3,4_float32.npy").getFile(); + byte[] numpyBytes = FileUtils.readFileToByteArray(fValid); + for( int i = 0; i < 10; i++) { + numpyBytes[i] = 0; + } + File fCorrupt = new File(f, "corrupt.npy"); + FileUtils.writeByteArrayToFile(fCorrupt, numpyBytes); - INDArray exp = Nd4j.arange(12).castTo(DataType.FLOAT).reshape(3,4); + INDArray exp = Nd4j.arange(12).castTo(DataType.FLOAT).reshape(3,4); - INDArray act1 = Nd4j.createFromNpyFile(fValid); - assertEquals(exp, act1); + INDArray act1 = Nd4j.createFromNpyFile(fValid); + assertEquals(exp, act1); + + INDArray probablyShouldntLoad = Nd4j.createFromNpyFile(fCorrupt); //Loads fine + boolean eq = exp.equals(probablyShouldntLoad); //And is actually equal content + }); - INDArray probablyShouldntLoad = Nd4j.createFromNpyFile(fCorrupt); //Loads fine - boolean eq = exp.equals(probablyShouldntLoad); //And is actually equal content } - @Test(expected = RuntimeException.class) - public void readNumpyCorruptHeader2() throws Exception { - File f = testDir.newFolder(); + @Test() + public void readNumpyCorruptHeader2(@TempDir Path testDir) throws Exception { + assertThrows(RuntimeException.class,() -> { + File f = testDir.toFile(); - File fValid = new ClassPathResource("numpy_arrays/arange_3,4_float32.npy").getFile(); - byte[] numpyBytes = FileUtils.readFileToByteArray(fValid); - for( int i=1; i<10; i++ ){ - numpyBytes[i] = 0; - } - File fCorrupt = new File(f, "corrupt.npy"); - FileUtils.writeByteArrayToFile(fCorrupt, numpyBytes); + File fValid = new ClassPathResource("numpy_arrays/arange_3,4_float32.npy").getFile(); + byte[] numpyBytes = FileUtils.readFileToByteArray(fValid); + for( int i = 1; i < 10; i++) { + numpyBytes[i] = 0; + } + File fCorrupt = new File(f, "corrupt.npy"); + FileUtils.writeByteArrayToFile(fCorrupt, numpyBytes); - INDArray exp = Nd4j.arange(12).castTo(DataType.FLOAT).reshape(3,4); + INDArray exp = Nd4j.arange(12).castTo(DataType.FLOAT).reshape(3,4); - INDArray act1 = Nd4j.createFromNpyFile(fValid); - assertEquals(exp, act1); + INDArray act1 = Nd4j.createFromNpyFile(fValid); + assertEquals(exp, act1); + + INDArray probablyShouldntLoad = Nd4j.createFromNpyFile(fCorrupt); //Loads fine + boolean eq = exp.equals(probablyShouldntLoad); //And is actually equal content + }); - INDArray probablyShouldntLoad = Nd4j.createFromNpyFile(fCorrupt); //Loads fine - boolean eq = exp.equals(probablyShouldntLoad); //And is actually equal content } - @Test(expected = IllegalArgumentException.class) + @Test() public void testAbsentNumpyFile_1() throws Exception { - val f = new File("pew-pew-zomg.some_extension_that_wont_exist"); - INDArray act1 = Nd4j.createFromNpyFile(f); + assertThrows(IllegalArgumentException.class,() -> { + val f = new File("pew-pew-zomg.some_extension_that_wont_exist"); + INDArray act1 = Nd4j.createFromNpyFile(f); + }); + } - @Test(expected = IllegalArgumentException.class) + @Test() public void testAbsentNumpyFile_2() throws Exception { - val f = new File("c:/develop/batch-x-1.npy"); - INDArray act1 = Nd4j.createFromNpyFile(f); - log.info("Array shape: {}; sum: {};", act1.shape(), act1.sumNumber().doubleValue()); + assertThrows(IllegalArgumentException.class,() -> { + val f = new File("c:/develop/batch-x-1.npy"); + INDArray act1 = Nd4j.createFromNpyFile(f); + log.info("Array shape: {}; sum: {};", act1.shape(), act1.sumNumber().doubleValue()); + }); + } - @Ignore + @Disabled @Test public void testNumpyBoolean() { INDArray out = Nd4j.createFromNpyFile(new File("c:/Users/raver/Downloads/error2.npy")); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java index 8a55ae722..14b0e858d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java @@ -22,7 +22,7 @@ package org.nd4j.linalg.shape; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -34,7 +34,7 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j @RunWith(Parameterized.class) @@ -180,10 +180,13 @@ public class EmptyTests extends BaseNd4jTest { assertEquals(1, array.rank()); } - @Test(expected = IllegalArgumentException.class) + @Test() public void testEmptyWithShape_3() { - val array = Nd4j.create(DataType.FLOAT, 2, 0, 3); - array.tensorAlongDimension(0, 2); + assertThrows(IllegalArgumentException.class,() -> { + val array = Nd4j.create(DataType.FLOAT, 2, 0, 3); + array.tensorAlongDimension(0, 2); + }); + } @Test @@ -239,15 +242,18 @@ public class EmptyTests extends BaseNd4jTest { assertEquals(e, reduced); } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testEmptyReduction_4() { - val x = Nd4j.create(DataType.FLOAT, 2, 0); - val e = Nd4j.create(DataType.FLOAT, 0); + assertThrows(ND4JIllegalStateException.class,() -> { + val x = Nd4j.create(DataType.FLOAT, 2, 0); + val e = Nd4j.create(DataType.FLOAT, 0); - val reduced = x.argMax(1); + val reduced = x.argMax(1); + + assertArrayEquals(e.shape(), reduced.shape()); + assertEquals(e, reduced); + }); - assertArrayEquals(e.shape(), reduced.shape()); - assertEquals(e, reduced); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/LongShapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/LongShapeTests.java index edcb3f125..2db07226d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/LongShapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/LongShapeTests.java @@ -21,7 +21,7 @@ package org.nd4j.linalg.shape; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -29,8 +29,8 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @RunWith(Parameterized.class) public class LongShapeTests extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/NDArrayMathTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/NDArrayMathTests.java index 54263a428..c7ba3e7b0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/NDArrayMathTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/NDArrayMathTests.java @@ -21,7 +21,7 @@ package org.nd4j.linalg.shape; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -31,7 +31,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.util.NDArrayMath; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeBufferTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeBufferTests.java index e8023640d..3e82c9844 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeBufferTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeBufferTests.java @@ -21,7 +21,7 @@ package org.nd4j.linalg.shape; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -31,7 +31,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.common.util.ArrayUtil; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @RunWith(Parameterized.class) public class ShapeBufferTests extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTests.java index 65a1185c6..df1cf9ae3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTests.java @@ -21,7 +21,7 @@ package org.nd4j.linalg.shape; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -36,7 +36,7 @@ import org.nd4j.common.primitives.Triple; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.linalg.indexing.NDArrayIndex.all; /** @@ -76,7 +76,7 @@ public class ShapeTests extends BaseNd4jTest { for (int i = 0; i < baseArr.tensorsAlongDimension(0, 1); i++) { INDArray test = baseArr.tensorAlongDimension(i, 0, 1); - assertEquals("Wrong at index " + i, assertions[i], test); + assertEquals(assertions[i], test,"Wrong at index " + i); } } @@ -106,7 +106,7 @@ public class ShapeTests extends BaseNd4jTest { for (int i = 0; i < baseArr.tensorsAlongDimension(2); i++) { INDArray arr = baseArr.tensorAlongDimension(i, 2); - assertEquals("Failed at index " + i, assertions[i], arr); + assertEquals( assertions[i], arr,"Failed at index " + i); } } @@ -207,7 +207,7 @@ public class ShapeTests extends BaseNd4jTest { for (int i = 0; i < baseArr.tensorsAlongDimension(1); i++) { INDArray arr = baseArr.tensorAlongDimension(i, 1); - assertEquals("Failed at index " + i, assertions[i], arr); + assertEquals( assertions[i], arr,"Failed at index " + i); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java index bd5353352..45dd3b447 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java @@ -21,8 +21,8 @@ package org.nd4j.linalg.shape; import lombok.val; -import org.junit.After; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -33,7 +33,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; /** * @author Adam Gibson @@ -49,7 +49,7 @@ public class ShapeTestsC extends BaseNd4jTest { DataType initialType; - @After + @AfterEach public void after() { Nd4j.setDataType(this.initialType); } @@ -69,7 +69,7 @@ public class ShapeTestsC extends BaseNd4jTest { new INDArray[] {columnVectorFirst, columnVectorSecond, columnVectorThird, columnVectorFourth}; for (int i = 0; i < baseArr.tensorsAlongDimension(0, 1); i++) { INDArray test = baseArr.tensorAlongDimension(i, 0, 1); - assertEquals("Wrong at index " + i, assertions[i], test); + assertEquals( assertions[i], test,"Wrong at index " + i); } } @@ -87,7 +87,7 @@ public class ShapeTestsC extends BaseNd4jTest { for (int i = 0; i < baseArr.tensorsAlongDimension(2); i++) { INDArray arr = baseArr.tensorAlongDimension(i, 2); - assertEquals("Failed at index " + i, assertions[i], arr); + assertEquals( assertions[i], arr,"Failed at index " + i); } } @@ -151,7 +151,7 @@ public class ShapeTestsC extends BaseNd4jTest { for (int i = 0; i < baseArr.tensorsAlongDimension(1); i++) { INDArray arr = baseArr.tensorAlongDimension(i, 1); - assertEquals("Failed at index " + i, assertions[i], arr); + assertEquals(assertions[i], arr,"Failed at index " + i); } } @@ -271,7 +271,7 @@ public class ShapeTestsC extends BaseNd4jTest { INDArray twoByThree = Nd4j.linspace(1, 600, 600, DataType.FLOAT).reshape(150, 4); INDArray columnVar = twoByThree.sum(0); INDArray assertion = Nd4j.create(new float[] {44850.0f, 45000.0f, 45150.0f, 45300.0f}); - assertEquals(getFailureMessage(), assertion, columnVar); + assertEquals(assertion, columnVar,getFailureMessage()); } @@ -280,7 +280,7 @@ public class ShapeTestsC extends BaseNd4jTest { INDArray twoByThree = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray rowMean = twoByThree.mean(1); INDArray assertion = Nd4j.create(new double[] {1.5, 3.5}); - assertEquals(getFailureMessage(), assertion, rowMean); + assertEquals(assertion, rowMean,getFailureMessage()); } @@ -290,7 +290,7 @@ public class ShapeTestsC extends BaseNd4jTest { INDArray twoByThree = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray rowStd = twoByThree.std(1); INDArray assertion = Nd4j.create(new double[] {0.7071067811865476f, 0.7071067811865476f}); - assertEquals(getFailureMessage(), assertion, rowStd); + assertEquals(assertion, rowStd,getFailureMessage()); } @@ -302,7 +302,7 @@ public class ShapeTestsC extends BaseNd4jTest { INDArray twoByThree = Nd4j.linspace(1, 600, 600, DataType.DOUBLE).reshape(150, 4); INDArray columnVar = twoByThree.sum(0); INDArray assertion = Nd4j.create(new double[] {44850.0f, 45000.0f, 45150.0f, 45300.0f}); - assertEquals(getFailureMessage(), assertion, columnVar); + assertEquals(assertion, columnVar,getFailureMessage()); DataTypeUtil.setDTypeForContext(initialType); } @@ -322,14 +322,14 @@ public class ShapeTestsC extends BaseNd4jTest { INDArray n = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {1, 4}); INDArray cumSumAnswer = Nd4j.create(new double[] {1, 3, 6, 10}, new long[] {1, 4}); INDArray cumSumTest = n.cumsum(0); - assertEquals(getFailureMessage(), cumSumAnswer, cumSumTest); + assertEquals( cumSumAnswer, cumSumTest,getFailureMessage()); INDArray n2 = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2); INDArray axis0assertion = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 21.0, 24.0, 27.0, 30.0, 33.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0, 60.0}, n2.shape()); INDArray axis0Test = n2.cumsum(0); - assertEquals(getFailureMessage(), axis0assertion, axis0Test); + assertEquals(axis0assertion, axis0Test,getFailureMessage()); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/StaticShapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/StaticShapeTests.java index 85ff0b072..7a7386f9d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/StaticShapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/StaticShapeTests.java @@ -21,7 +21,7 @@ package org.nd4j.linalg.shape; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -38,8 +38,8 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/TADTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/TADTests.java index 01f592601..c85264965 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/TADTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/TADTests.java @@ -22,7 +22,7 @@ package org.nd4j.linalg.shape; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -37,7 +37,7 @@ import org.nd4j.common.primitives.Pair; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.point; @@ -145,9 +145,9 @@ public class TADTests extends BaseNd4jTest { INDArray get = orig.get(all(), all(), point(i)); String str = String.valueOf(i); - assertEquals(str, get, tad); - assertEquals(str, get.data().offset(), tad.data().offset()); - assertEquals(str, get.elementWiseStride(), tad.elementWiseStride()); + assertEquals(get, tad,str); + assertEquals(get.data().offset(), tad.data().offset(),str); + assertEquals(get.elementWiseStride(), tad.elementWiseStride(),str); char orderTad = Shape.getOrder(tad.shape(), tad.stride(), 1); char orderGet = Shape.getOrder(get.shape(), get.stride(), 1); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTests.java index c3ecfa688..846166367 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTests.java @@ -21,8 +21,8 @@ package org.nd4j.linalg.shape.concat; import lombok.extern.slf4j.Slf4j; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -38,8 +38,8 @@ import org.nd4j.common.primitives.Pair; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; /** * @author Adam Gibson @@ -171,7 +171,7 @@ public class ConcatTests extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testConcat3dv2() { INDArray first = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape('c', 2, 3, 4); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java index 49586d864..391d1fec7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java @@ -22,8 +22,8 @@ package org.nd4j.linalg.shape.concat; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -41,8 +41,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.*; /** @@ -218,13 +217,16 @@ public class ConcatTestsC extends BaseNd4jTest { assertEquals(exp, concat2); } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testConcatVector() { - Nd4j.concat(0, Nd4j.ones(1,1000000), Nd4j.create(1, 1)); + assertThrows(ND4JIllegalStateException.class,() -> { + Nd4j.concat(0, Nd4j.ones(1,1000000), Nd4j.create(1, 1)); + + }); } @Test - @Ignore + @Disabled public void testConcat3dv2() { INDArray first = Nd4j.linspace(1, 24, 24).reshape('c', 2, 3, 4); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java index a5695f2d2..47867eae1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.shape.concat.padding; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -29,8 +29,8 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java index 6a549df65..055185e58 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java @@ -21,7 +21,7 @@ package org.nd4j.linalg.shape.concat.padding; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -31,8 +31,8 @@ import org.nd4j.linalg.convolution.Convolution; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java index d122046f8..522b0fe2c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java @@ -22,9 +22,9 @@ package org.nd4j.linalg.shape.indexing; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; import org.junit.rules.ErrorCollector; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -37,7 +37,7 @@ import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.SpecifiedIndex; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; /** * @author Adam Gibson @@ -46,8 +46,6 @@ import static org.junit.Assert.*; @RunWith(Parameterized.class) public class IndexingTests extends BaseNd4jTest { - @Rule - public ErrorCollector collector = new ErrorCollector(); public IndexingTests(Nd4jBackend backend) { super(backend); @@ -136,9 +134,9 @@ public class IndexingTests extends BaseNd4jTest { expected.putScalar(0, 1, 20); expected.putScalar(1, 0, 14); expected.putScalar(1, 1, 23); - assertEquals("View with two get", expected, viewTwo); - assertEquals("View with one get", expected, viewOne); //FAILS! - assertEquals("Two views should be the same", viewOne, viewTwo); //Obviously fails + assertEquals(expected, viewTwo,"View with two get"); + assertEquals(expected, viewOne,"View with one get"); //FAILS! + assertEquals(viewOne, viewTwo,"Two views should be the same"); //Obviously fails } /* @@ -164,9 +162,9 @@ public class IndexingTests extends BaseNd4jTest { INDArray sameView = A.get(ndi_Slice, ndi_I, ndi_J); String failureMessage = String.format("Fails for (%d , %d - %d, %d - %d)\n", s, i, rows, j, cols); try { - assertEquals(failureMessage, aView, sameView); + assertEquals(aView, sameView,failureMessage); } catch (Throwable t) { - collector.addError(t); + log.error("Error with view",t); } } } @@ -175,7 +173,7 @@ public class IndexingTests extends BaseNd4jTest { @Test - @Ignore //added recently: For some reason this is passing. + @Disabled //added recently: For some reason this is passing. // The test .equals fails on a comparison of row vs column vector. //TODO: possibly figure out what's going on here at some point? // - Adam diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java index 9ae05a57d..d6507d6ce 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java @@ -22,8 +22,8 @@ package org.nd4j.linalg.shape.indexing; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Rule; -import org.junit.Test; + +import org.junit.jupiter.api.Test; import org.junit.rules.ErrorCollector; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -37,7 +37,7 @@ import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.SpecifiedIndex; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; /** * @author Adam Gibson @@ -45,8 +45,7 @@ import static org.junit.Assert.*; @Slf4j @RunWith(Parameterized.class) public class IndexingTestsC extends BaseNd4jTest { - @Rule - public ErrorCollector collector = new ErrorCollector(); + public IndexingTestsC(Nd4jBackend backend) { super(backend); @@ -58,7 +57,7 @@ public class IndexingTestsC extends BaseNd4jTest { INDArray sub = nd.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2)); Nd4j.getExecutioner().exec(new ScalarAdd(sub, 2)); - assertEquals(getFailureMessage(), Nd4j.create(new double[][] {{3, 4}, {6, 7}}), sub); + assertEquals(Nd4j.create(new double[][] {{3, 4}, {6, 7}}), sub,getFailureMessage()); } @@ -287,9 +286,9 @@ public class IndexingTestsC extends BaseNd4jTest { expected.putScalar(0, 1, 12); expected.putScalar(1, 0, 14); expected.putScalar(1, 1, 15); - assertEquals("View with two get", expected, viewTwo); - assertEquals("View with one get", expected, viewOne); //FAILS! - assertEquals("Two views should be the same", viewOne, viewTwo); //obviously fails + assertEquals(expected, viewTwo,"View with two get"); + assertEquals( expected, viewOne,"View with one get"); //FAILS! + assertEquals(viewOne, viewTwo,"Two views should be the same"); //obviously fails } /* @@ -315,9 +314,10 @@ public class IndexingTestsC extends BaseNd4jTest { INDArray sameView = A.get(ndi_Slice, ndi_I, ndi_J); String failureMessage = String.format("Fails for (%d , %d - %d, %d - %d)\n", s, i, rows, j, cols); try { - assertEquals(failureMessage, aView, sameView); + assertEquals(aView, sameView,failureMessage); } catch (Throwable t) { - collector.addError(t); + log.error("Error on view ",t); + //collector.addError(t); } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnes.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnes.java index 18371fe87..eb691afef 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnes.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnes.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.shape.ones; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -32,7 +32,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnesC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnesC.java index 683cc6a4c..cf9a1a9b3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnesC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnesC.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.shape.ones; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -29,7 +29,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.NDArrayIndex; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/reshape/ReshapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/reshape/ReshapeTests.java index 79c5bc36b..144fc146f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/reshape/ReshapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/reshape/ReshapeTests.java @@ -21,7 +21,7 @@ package org.nd4j.linalg.shape.reshape; import lombok.extern.slf4j.Slf4j; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -30,8 +30,8 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.Assume.assumeNotNull; /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTests.java index 4c37e81e2..6f8d80828 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTests.java @@ -21,7 +21,7 @@ package org.nd4j.linalg.slicing; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -30,7 +30,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java index c6dea94df..b627ea3b0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.slicing; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -31,8 +31,8 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.SpecifiedIndex; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/CudaTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/CudaTests.java index c7ac31350..9347addcb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/CudaTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/CudaTests.java @@ -22,9 +22,9 @@ package org.nd4j.linalg.specials; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -33,7 +33,7 @@ import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j @RunWith(Parameterized.class) @@ -46,12 +46,12 @@ public class CudaTests extends BaseNd4jTest { this.initialType = Nd4j.dataType(); } - @Before + @BeforeEach public void setUp() { Nd4j.setDataType(DataType.FLOAT); } - @After + @AfterEach public void setDown() { Nd4j.setDataType(initialType); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/LongTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/LongTests.java index efd292200..a9b1d8da7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/LongTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/LongTests.java @@ -21,8 +21,8 @@ package org.nd4j.linalg.specials; import lombok.extern.slf4j.Slf4j; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -38,11 +38,11 @@ import org.nd4j.common.primitives.Pair; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; @Slf4j -@Ignore +@Disabled @RunWith(Parameterized.class) public class LongTests extends BaseNd4jTest { @@ -123,7 +123,7 @@ public class LongTests extends BaseNd4jTest { INDArray hugeY = Nd4j.create(1, 1000).assign(2.0); for (int x = 0; x < hugeX.rows(); x++) { - assertEquals("Failed at row " + x, 1000, hugeX.getRow(x).sumNumber().intValue()); + assertEquals(1000, hugeX.getRow(x).sumNumber().intValue(),"Failed at row " + x); } INDArray result = Nd4j.getExecutioner().exec(new ManhattanDistance(hugeX, hugeY, 1)); @@ -139,7 +139,7 @@ public class LongTests extends BaseNd4jTest { hugeX.addiRowVector(Nd4j.create(1000).assign(2.0)); for (int x = 0; x < hugeX.rows(); x++) { - assertEquals("Failed at row " + x, 3000, hugeX.getRow(x).sumNumber().intValue()); + assertEquals( hugeX.getRow(x).sumNumber().intValue(),3000,"Failed at row " + x); } } @@ -150,7 +150,7 @@ public class LongTests extends BaseNd4jTest { hugeX.addiRowVector(Nd4j.create(1000).assign(2.0)); for (int x = 0; x < hugeX.rows(); x++) { - assertEquals("Failed at row " + x, 3000, hugeX.getRow(x).sumNumber().intValue()); + assertEquals( 3000, hugeX.getRow(x).sumNumber().intValue(),"Failed at row " + x); } } @@ -161,7 +161,7 @@ public class LongTests extends BaseNd4jTest { INDArray mean = hugeX.mean(1); for (int x = 0; x < hugeX.rows(); x++) { - assertEquals("Failed at row " + x, 1.0, mean.getDouble(x), 1e-5); + assertEquals( 1.0, mean.getDouble(x), 1e-5,"Failed at row " + x); } } @@ -172,7 +172,7 @@ public class LongTests extends BaseNd4jTest { INDArray mean = hugeX.argMax(1); for (int x = 0; x < hugeX.rows(); x++) { - assertEquals("Failed at row " + x, 0.0, mean.getDouble(x), 1e-5); + assertEquals(0.0, mean.getDouble(x), 1e-5,"Failed at row " + x); } } @@ -187,7 +187,7 @@ public class LongTests extends BaseNd4jTest { INDArray hugeX = Nd4j.vstack(list); for (int x = 0; x < hugeX.rows(); x++) { - assertEquals("Failed at row " + x, 2.0, hugeX.getRow(x).meanNumber().doubleValue(), 1e-5); + assertEquals(2.0, hugeX.getRow(x).meanNumber().doubleValue(), 1e-5,"Failed at row " + x); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/RavelIndexTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/RavelIndexTest.java index ab5007fc4..23bdfb376 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/RavelIndexTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/RavelIndexTest.java @@ -22,10 +22,10 @@ package org.nd4j.linalg.specials; import lombok.extern.slf4j.Slf4j; import org.bytedeco.javacpp.LongPointer; -import org.junit.After; +import org.junit.jupiter.api.AfterEach; import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -47,12 +47,12 @@ public class RavelIndexTest extends BaseNd4jTest { this.initialType = Nd4j.dataType(); } - @Before + @BeforeEach public void setUp() { Nd4j.setDataType(DataType.FLOAT); } - @After + @AfterEach public void setDown() { Nd4j.setDataType(initialType); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java index e1a8ebf7c..1941ac47a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java @@ -25,9 +25,9 @@ import com.google.common.primitives.Floats; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.bytedeco.javacpp.LongPointer; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -43,7 +43,7 @@ import org.nd4j.nativeblas.NativeOpsHolder; import java.util.Arrays; import java.util.stream.LongStream; -import static org.junit.Assert.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; @Slf4j @RunWith(Parameterized.class) @@ -58,12 +58,12 @@ public class SortCooTests extends BaseNd4jTest { this.initialDefaultType = Nd4j.defaultFloatingPointType(); } - @Before + @BeforeEach public void setUp() { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); } - @After + @AfterEach public void setDown() { Nd4j.setDefaultDataTypes(initialType, Nd4j.defaultFloatingPointType()); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/DataSetUtilsTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/DataSetUtilsTest.java index 83ab2bb78..b77633083 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/DataSetUtilsTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/DataSetUtilsTest.java @@ -21,10 +21,11 @@ package org.nd4j.linalg.util; import lombok.extern.slf4j.Slf4j; -import org.junit.After; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.AfterEach; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -32,7 +33,9 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.common.tools.SIS; -import static org.junit.Assert.assertTrue; +import java.nio.file.Path; + +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j public class DataSetUtilsTest extends BaseNd4jTest { @@ -47,19 +50,18 @@ public class DataSetUtilsTest extends BaseNd4jTest { } // - @Rule - public TemporaryFolder tmpFld = new TemporaryFolder(); + // private SIS sis; // @Test - public void testAll() { + public void testAll(@TempDir Path tmpFld) { // sis = new SIS(); // int mtLv = 0; // - sis.initValues( mtLv, "TEST", System.out, System.err, tmpFld.getRoot().getAbsolutePath(), "Test", "ABC", true, true ); + sis.initValues( mtLv, "TEST", System.out, System.err, tmpFld.toAbsolutePath().toString(), "Test", "ABC", true, true ); // INDArray in_INDA = Nd4j.zeros( 8, 8 ); INDArray ot_INDA = Nd4j.ones( 8, 1 ); @@ -88,7 +90,7 @@ public class DataSetUtilsTest extends BaseNd4jTest { // } - @After + @AfterEach public void after() { // int mtLv = 0; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/NDArrayUtilTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/NDArrayUtilTest.java index 13e9f90ee..7e784853f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/NDArrayUtilTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/NDArrayUtilTest.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.util; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.common.util.ArrayUtil; @@ -28,8 +28,8 @@ import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Hamdi Douss diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/PreconditionsTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/PreconditionsTest.java index 53bc9bb82..b6f06e668 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/PreconditionsTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/PreconditionsTest.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.util; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; @@ -30,8 +30,8 @@ import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.Arrays; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; public class PreconditionsTest extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTest.java index 5ac3f1613..6f1d088be 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTest.java @@ -20,7 +20,7 @@ package org.nd4j.linalg.util; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -29,7 +29,7 @@ import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java index 9e617ee6b..419b8d015 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java @@ -22,7 +22,7 @@ package org.nd4j.linalg.util; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -34,8 +34,7 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.*; /** * @author Adam Gibson @@ -186,14 +185,17 @@ public class ShapeTestC extends BaseNd4jTest { assertArrayEquals(exp, norm); } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testAxisNormalization_3() { - val axis = new int[] {1, -2, 2}; - val rank = 2; - val exp = new int[] {0, 1}; + assertThrows(ND4JIllegalStateException.class,() -> { + val axis = new int[] {1, -2, 2}; + val rank = 2; + val exp = new int[] {0, 1}; + + val norm = Shape.normalizeAxis(rank, axis); + assertArrayEquals(exp, norm); + }); - val norm = Shape.normalizeAxis(rank, axis); - assertArrayEquals(exp, norm); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/TestArrayUtils.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/TestArrayUtils.java index 8f42be2cb..64e235af5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/TestArrayUtils.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/TestArrayUtils.java @@ -20,14 +20,14 @@ package org.nd4j.linalg.util; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.util.ArrayUtil; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.factory.Nd4jBackend; import java.util.Random; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TestArrayUtils extends BaseNd4jTest { @@ -196,7 +196,7 @@ public class TestArrayUtils extends BaseNd4jTest { fail("Expected exception"); } catch (Exception e){ String msg = e.getMessage(); - assertTrue(msg, msg.contains("Ragged array detected")); + assertTrue(msg.contains("Ragged array detected"),msg); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/TestCollections.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/TestCollections.java index 68c923126..1d153d512 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/TestCollections.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/TestCollections.java @@ -20,15 +20,15 @@ package org.nd4j.linalg.util; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.common.collection.CompactHeapStringList; import org.nd4j.linalg.factory.Nd4jBackend; import java.util.*; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestCollections extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ValidationUtilTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ValidationUtilTests.java index 8621af247..01a363b10 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ValidationUtilTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ValidationUtilTests.java @@ -21,10 +21,11 @@ package org.nd4j.linalg.util; import org.apache.commons.io.FileUtils; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.BaseNd4jTest; @@ -40,29 +41,28 @@ import java.io.BufferedOutputStream; import java.io.File; import java.io.FileOutputStream; import java.nio.charset.StandardCharsets; +import java.nio.file.Path; import java.util.zip.ZipEntry; import java.util.zip.ZipOutputStream; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class ValidationUtilTests extends BaseNd4jTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - + public ValidationUtilTests(Nd4jBackend backend) { super(backend); } @Test - public void testFileValidation() throws Exception { - File f = testDir.newFolder(); + public void testFileValidation(@TempDir Path testDir) throws Exception { + File f = testDir.toFile(); //Test not existent file: File fNonExistent = new File("doesntExist.bin"); ValidationResult vr0 = Nd4jCommonValidator.isValidFile(fNonExistent); assertFalse(vr0.isValid()); - assertTrue(vr0.getIssues().get(0), vr0.getIssues().get(0).contains("exist")); + assertTrue(vr0.getIssues().get(0).contains("exist"),vr0.getIssues().get(0)); // System.out.println(vr0.toString()); //Test empty file: @@ -70,7 +70,7 @@ public class ValidationUtilTests extends BaseNd4jTest { fEmpty.createNewFile(); ValidationResult vr1 = Nd4jCommonValidator.isValidFile(fEmpty); assertFalse(vr1.isValid()); - assertTrue(vr1.getIssues().get(0), vr1.getIssues().get(0).contains("empty")); + assertTrue(vr1.getIssues().get(0).contains("empty"),vr1.getIssues().get(0)); // System.out.println(vr1.toString()); //Test directory @@ -79,7 +79,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertTrue(created); ValidationResult vr2 = Nd4jCommonValidator.isValidFile(directory); assertFalse(vr2.isValid()); - assertTrue(vr2.getIssues().get(0), vr2.getIssues().get(0).contains("directory")); + assertTrue(vr2.getIssues().get(0).contains("directory"),vr2.getIssues().get(0)); // System.out.println(vr2.toString()); //Test valid non-empty file - valid @@ -91,14 +91,14 @@ public class ValidationUtilTests extends BaseNd4jTest { } @Test - public void testZipValidation() throws Exception { - File f = testDir.newFolder(); + public void testZipValidation(@TempDir Path testDir) throws Exception { + File f = testDir.toFile(); //Test not existent file: File fNonExistent = new File("doesntExist.zip"); ValidationResult vr0 = Nd4jCommonValidator.isValidZipFile(fNonExistent, false); assertFalse(vr0.isValid()); - assertTrue(vr0.getIssues().get(0), vr0.getIssues().get(0).contains("exist")); + assertTrue(vr0.getIssues().get(0).contains("exist"),vr0.getIssues().get(0)); // System.out.println(vr0.toString()); //Test empty zip: @@ -106,7 +106,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertTrue(fEmpty.exists()); ValidationResult vr1 = Nd4jCommonValidator.isValidZipFile(fEmpty, false); assertFalse(vr1.isValid()); - assertTrue(vr1.getIssues().get(0), vr1.getIssues().get(0).contains("empty")); + assertTrue(vr1.getIssues().get(0).contains("empty"),vr1.getIssues().get(0)); // System.out.println(vr1.toString()); //Test directory (not zip file) @@ -115,7 +115,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertTrue(created); ValidationResult vr2 = Nd4jCommonValidator.isValidFile(directory); assertFalse(vr2.isValid()); - assertTrue(vr2.getIssues().get(0), vr2.getIssues().get(0).contains("directory")); + assertTrue(vr2.getIssues().get(0).contains("directory"),vr2.getIssues().get(0)); // System.out.println(vr2.toString()); //Test non-empty zip - valid @@ -134,22 +134,22 @@ public class ValidationUtilTests extends BaseNd4jTest { assertFalse(vr4.isValid()); assertEquals(1, vr4.getIssues().size()); String s = vr4.getIssues().get(0); - assertTrue(s, s.contains("someFile1.bin") && s.contains("someFile2.bin")); - assertFalse(s, s.contains("content.txt")); + assertTrue(s.contains("someFile1.bin") && s.contains("someFile2.bin"),s); + assertFalse( s.contains("content.txt"),s); // System.out.println(vr4.toString()); } @Test - public void testINDArrayTextValidation() throws Exception { - File f = testDir.newFolder(); + public void testINDArrayTextValidation(@TempDir Path testDir) throws Exception { + File f = testDir.toFile(); //Test not existent file: File fNonExistent = new File("doesntExist.txt"); ValidationResult vr0 = Nd4jValidator.validateINDArrayTextFile(fNonExistent); assertFalse(vr0.isValid()); assertEquals("INDArray Text File", vr0.getFormatType()); - assertTrue(vr0.getIssues().get(0), vr0.getIssues().get(0).contains("exist")); + assertTrue(vr0.getIssues().get(0).contains("exist"),vr0.getIssues().get(0)); // System.out.println(vr0.toString()); //Test empty file: @@ -159,7 +159,7 @@ public class ValidationUtilTests extends BaseNd4jTest { ValidationResult vr1 = Nd4jValidator.validateINDArrayTextFile(fEmpty); assertEquals("INDArray Text File", vr1.getFormatType()); assertFalse(vr1.isValid()); - assertTrue(vr1.getIssues().get(0), vr1.getIssues().get(0).contains("empty")); + assertTrue(vr1.getIssues().get(0).contains("empty"),vr1.getIssues().get(0)); // System.out.println(vr1.toString()); //Test directory (not zip file) @@ -169,7 +169,7 @@ public class ValidationUtilTests extends BaseNd4jTest { ValidationResult vr2 = Nd4jValidator.validateINDArrayTextFile(directory); assertEquals("INDArray Text File", vr2.getFormatType()); assertFalse(vr2.isValid()); - assertTrue(vr2.getIssues().get(0), vr2.getIssues().get(0).contains("directory")); + assertTrue(vr2.getIssues().get(0).contains("directory"),vr2.getIssues().get(0)); // System.out.println(vr2.toString()); //Test non-INDArray format: @@ -179,7 +179,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("INDArray Text File", vr3.getFormatType()); assertFalse(vr3.isValid()); String s = vr3.getIssues().get(0); - assertTrue(s, s.contains("text") && s.contains("INDArray") && s.contains("corrupt")); + assertTrue(s.contains("text") && s.contains("INDArray") && s.contains("corrupt"),s); // System.out.println(vr3.toString()); //Test corrupted txt format: @@ -197,7 +197,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("INDArray Text File", vr4.getFormatType()); assertFalse(vr4.isValid()); s = vr4.getIssues().get(0); - assertTrue(s, s.contains("text") && s.contains("INDArray") && s.contains("corrupt")); + assertTrue(s.contains("text") && s.contains("INDArray") && s.contains("corrupt"),s); // System.out.println(vr4.toString()); @@ -212,17 +212,17 @@ public class ValidationUtilTests extends BaseNd4jTest { @Test - @Ignore("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") - public void testNpyValidation() throws Exception { + @Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") + public void testNpyValidation(@TempDir Path testDir) throws Exception { - File f = testDir.newFolder(); + File f = testDir.toFile(); //Test not existent file: File fNonExistent = new File("doesntExist.npy"); ValidationResult vr0 = Nd4jValidator.validateNpyFile(fNonExistent); assertFalse(vr0.isValid()); assertEquals("Numpy .npy File", vr0.getFormatType()); - assertTrue(vr0.getIssues().get(0), vr0.getIssues().get(0).contains("exist")); + assertTrue(vr0.getIssues().get(0).contains("exist"),vr0.getIssues().get(0)); // System.out.println(vr0.toString()); //Test empty file: @@ -232,7 +232,7 @@ public class ValidationUtilTests extends BaseNd4jTest { ValidationResult vr1 = Nd4jValidator.validateNpyFile(fEmpty); assertEquals("Numpy .npy File", vr1.getFormatType()); assertFalse(vr1.isValid()); - assertTrue(vr1.getIssues().get(0), vr1.getIssues().get(0).contains("empty")); + assertTrue(vr1.getIssues().get(0).contains("empty"),vr1.getIssues().get(0)); // System.out.println(vr1.toString()); //Test directory (not zip file) @@ -242,7 +242,7 @@ public class ValidationUtilTests extends BaseNd4jTest { ValidationResult vr2 = Nd4jValidator.validateNpyFile(directory); assertEquals("Numpy .npy File", vr2.getFormatType()); assertFalse(vr2.isValid()); - assertTrue(vr2.getIssues().get(0), vr2.getIssues().get(0).contains("directory")); + assertTrue(vr2.getIssues().get(0).contains("directory"),vr2.getIssues().get(0)); // System.out.println(vr2.toString()); //Test non-numpy format: @@ -252,7 +252,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("Numpy .npy File", vr3.getFormatType()); assertFalse(vr3.isValid()); String s = vr3.getIssues().get(0); - assertTrue(s, s.contains("npy") && s.toLowerCase().contains("numpy") && s.contains("corrupt")); + assertTrue(s.contains("npy") && s.toLowerCase().contains("numpy") && s.contains("corrupt"),s); // System.out.println(vr3.toString()); //Test corrupted npy format: @@ -268,7 +268,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("Numpy .npy File", vr4.getFormatType()); assertFalse(vr4.isValid()); s = vr4.getIssues().get(0); - assertTrue(s, s.contains("npy") && s.toLowerCase().contains("numpy") && s.contains("corrupt")); + assertTrue(s.contains("npy") && s.toLowerCase().contains("numpy") && s.contains("corrupt"),s); // System.out.println(vr4.toString()); @@ -282,16 +282,16 @@ public class ValidationUtilTests extends BaseNd4jTest { } @Test - public void testNpzValidation() throws Exception { + public void testNpzValidation(@TempDir Path testDIr) throws Exception { - File f = testDir.newFolder(); + File f = testDIr.toFile(); //Test not existent file: File fNonExistent = new File("doesntExist.npz"); ValidationResult vr0 = Nd4jValidator.validateNpzFile(fNonExistent); assertFalse(vr0.isValid()); assertEquals("Numpy .npz File", vr0.getFormatType()); - assertTrue(vr0.getIssues().get(0), vr0.getIssues().get(0).contains("exist")); + assertTrue(vr0.getIssues().get(0).contains("exist"),vr0.getIssues().get(0)); // System.out.println(vr0.toString()); //Test empty file: @@ -301,7 +301,7 @@ public class ValidationUtilTests extends BaseNd4jTest { ValidationResult vr1 = Nd4jValidator.validateNpzFile(fEmpty); assertEquals("Numpy .npz File", vr1.getFormatType()); assertFalse(vr1.isValid()); - assertTrue(vr1.getIssues().get(0), vr1.getIssues().get(0).contains("empty")); + assertTrue(vr1.getIssues().get(0).contains("empty"),vr1.getIssues().get(0)); // System.out.println(vr1.toString()); //Test directory (not zip file) @@ -311,7 +311,7 @@ public class ValidationUtilTests extends BaseNd4jTest { ValidationResult vr2 = Nd4jValidator.validateNpzFile(directory); assertEquals("Numpy .npz File", vr2.getFormatType()); assertFalse(vr2.isValid()); - assertTrue(vr2.getIssues().get(0), vr2.getIssues().get(0).contains("directory")); + assertTrue(vr2.getIssues().get(0).contains("directory"),vr2.getIssues().get(0)); // System.out.println(vr2.toString()); //Test non-numpy format: @@ -321,7 +321,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("Numpy .npz File", vr3.getFormatType()); assertFalse(vr3.isValid()); String s = vr3.getIssues().get(0); - assertTrue(s, s.contains("npz") && s.toLowerCase().contains("numpy") && s.contains("corrupt")); + assertTrue(s.contains("npz") && s.toLowerCase().contains("numpy") && s.contains("corrupt"),s); // System.out.println(vr3.toString()); //Test corrupted npz format: @@ -337,7 +337,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("Numpy .npz File", vr4.getFormatType()); assertFalse(vr4.isValid()); s = vr4.getIssues().get(0); - assertTrue(s, s.contains("npz") && s.toLowerCase().contains("numpy") && s.contains("corrupt")); + assertTrue( s.contains("npz") && s.toLowerCase().contains("numpy") && s.contains("corrupt"),s); // System.out.println(vr4.toString()); @@ -351,15 +351,15 @@ public class ValidationUtilTests extends BaseNd4jTest { } @Test - public void testNumpyTxtValidation() throws Exception { - File f = testDir.newFolder(); + public void testNumpyTxtValidation(@TempDir Path testDir) throws Exception { + File f = testDir.toFile(); //Test not existent file: File fNonExistent = new File("doesntExist.txt"); ValidationResult vr0 = Nd4jValidator.validateNumpyTxtFile(fNonExistent, " ", StandardCharsets.UTF_8); assertFalse(vr0.isValid()); assertEquals("Numpy text file", vr0.getFormatType()); - assertTrue(vr0.getIssues().get(0), vr0.getIssues().get(0).contains("exist")); + assertTrue(vr0.getIssues().get(0).contains("exist"),vr0.getIssues().get(0)); // System.out.println(vr0.toString()); //Test empty file: @@ -369,7 +369,7 @@ public class ValidationUtilTests extends BaseNd4jTest { ValidationResult vr1 = Nd4jValidator.validateNumpyTxtFile(fEmpty, " ", StandardCharsets.UTF_8); assertEquals("Numpy text file", vr1.getFormatType()); assertFalse(vr1.isValid()); - assertTrue(vr1.getIssues().get(0), vr1.getIssues().get(0).contains("empty")); + assertTrue(vr1.getIssues().get(0).contains("empty"),vr1.getIssues().get(0)); // System.out.println(vr1.toString()); //Test directory (not zip file) @@ -379,7 +379,7 @@ public class ValidationUtilTests extends BaseNd4jTest { ValidationResult vr2 = Nd4jValidator.validateNumpyTxtFile(directory, " ", StandardCharsets.UTF_8); assertEquals("Numpy text file", vr2.getFormatType()); assertFalse(vr2.isValid()); - assertTrue(vr2.getIssues().get(0), vr2.getIssues().get(0).contains("directory")); + assertTrue(vr2.getIssues().get(0).contains("directory"),vr2.getIssues().get(0)); // System.out.println(vr2.toString()); //Test non-numpy format: @@ -389,7 +389,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("Numpy text file", vr3.getFormatType()); assertFalse(vr3.isValid()); String s = vr3.getIssues().get(0); - assertTrue(s, s.contains("text") && s.toLowerCase().contains("numpy") && s.contains("corrupt")); + assertTrue(s.contains("text") && s.toLowerCase().contains("numpy") && s.contains("corrupt"),s); // System.out.println(vr3.toString()); //Test corrupted txt format: @@ -405,7 +405,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("Numpy text file", vr4.getFormatType()); assertFalse(vr4.isValid()); s = vr4.getIssues().get(0); - assertTrue(s, s.contains("text") && s.toLowerCase().contains("numpy") && s.contains("corrupt")); + assertTrue(s.contains("text") && s.toLowerCase().contains("numpy") && s.contains("corrupt"),s); // System.out.println(vr4.toString()); @@ -419,10 +419,10 @@ public class ValidationUtilTests extends BaseNd4jTest { } @Test - public void testValidateSameDiff() throws Exception { + public void testValidateSameDiff(@TempDir Path testDir) throws Exception { Nd4j.setDataType(DataType.FLOAT); - File f = testDir.newFolder(); + File f = testDir.toFile(); SameDiff sd = SameDiff.create(); SDVariable v = sd.placeHolder("x", DataType.FLOAT, 3,4); SDVariable loss = v.std(true); @@ -436,7 +436,7 @@ public class ValidationUtilTests extends BaseNd4jTest { ValidationResult vr0 = Nd4jValidator.validateSameDiffFlatBuffers(fNonExistent); assertFalse(vr0.isValid()); assertEquals("SameDiff FlatBuffers file", vr0.getFormatType()); - assertTrue(vr0.getIssues().get(0), vr0.getIssues().get(0).contains("exist")); + assertTrue(vr0.getIssues().get(0).contains("exist"),vr0.getIssues().get(0)); // System.out.println(vr0.toString()); //Test empty file: @@ -446,7 +446,7 @@ public class ValidationUtilTests extends BaseNd4jTest { ValidationResult vr1 = Nd4jValidator.validateSameDiffFlatBuffers(fEmpty); assertEquals("SameDiff FlatBuffers file", vr1.getFormatType()); assertFalse(vr1.isValid()); - assertTrue(vr1.getIssues().get(0), vr1.getIssues().get(0).contains("empty")); + assertTrue(vr1.getIssues().get(0).contains("empty"),vr1.getIssues().get(0)); // System.out.println(vr1.toString()); //Test directory (not zip file) @@ -456,7 +456,7 @@ public class ValidationUtilTests extends BaseNd4jTest { ValidationResult vr2 = Nd4jValidator.validateSameDiffFlatBuffers(directory); assertEquals("SameDiff FlatBuffers file", vr2.getFormatType()); assertFalse(vr2.isValid()); - assertTrue(vr2.getIssues().get(0), vr2.getIssues().get(0).contains("directory")); + assertTrue(vr2.getIssues().get(0).contains("directory"),vr2.getIssues().get(0)); // System.out.println(vr2.toString()); //Test non-flatbuffers @@ -466,7 +466,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("SameDiff FlatBuffers file", vr3.getFormatType()); assertFalse(vr3.isValid()); String s = vr3.getIssues().get(0); - assertTrue(s, s.contains("FlatBuffers") && s.contains("SameDiff") && s.contains("corrupt")); + assertTrue(s.contains("FlatBuffers") && s.contains("SameDiff") && s.contains("corrupt"),s); // System.out.println(vr3.toString()); //Test corrupted flatbuffers format: @@ -481,7 +481,7 @@ public class ValidationUtilTests extends BaseNd4jTest { assertEquals("SameDiff FlatBuffers file", vr4.getFormatType()); assertFalse(vr4.isValid()); s = vr4.getIssues().get(0); - assertTrue(s, s.contains("FlatBuffers") && s.contains("SameDiff") && s.contains("corrupt")); + assertTrue( s.contains("FlatBuffers") && s.contains("SameDiff") && s.contains("corrupt"),s); // System.out.println(vr4.toString()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java index d3a0e4e26..f70c753fc 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java @@ -22,10 +22,10 @@ package org.nd4j.linalg.workspace; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -48,7 +48,7 @@ import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace; import java.io.File; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.linalg.api.buffer.DataType.DOUBLE; @Slf4j @@ -77,12 +77,12 @@ public class BasicWorkspaceTests extends BaseNd4jTest { this.initialType = Nd4j.dataType(); } - @Before + @BeforeEach public void setUp() { Nd4j.setDataType(DOUBLE); } - @After + @AfterEach public void shutdown() { Nd4j.getMemoryManager().setCurrentWorkspace(null); Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); @@ -712,7 +712,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest { assertEquals(reqMem + reqMem % 16, workspace.getPrimaryOffset()); - assertEquals("Failed on iteration " + x, 10, array.sumNumber().doubleValue(), 0.01); + assertEquals(10, array.sumNumber().doubleValue(), 0.01,"Failed on iteration " + x); workspace.notifyScopeLeft(); @@ -960,7 +960,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest { @Test - @Ignore + @Disabled public void testMmap2() throws Exception { // we don't support MMAP on cuda yet if (Nd4j.getExecutioner().getClass().getName().toLowerCase().contains("cuda")) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CudaWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CudaWorkspaceTests.java index 436d0704d..c10115122 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CudaWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CudaWorkspaceTests.java @@ -22,7 +22,7 @@ package org.nd4j.linalg.workspace; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -34,7 +34,7 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j @RunWith(Parameterized.class) @@ -61,7 +61,7 @@ public class CudaWorkspaceTests extends BaseNd4jTest { try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfig, "test")) { final INDArray zeros = Nd4j.zeros(4, 'f'); //final INDArray zeros = Nd4j.create(4, 'f'); // Also fails, but maybe less of an issue as javadoc does not say that one can expect returned array to be all zeros. - assertEquals("Got non-zero array " + zeros + " after " + cnt + " iterations !", 0d, zeros.sumNumber().doubleValue(), 1e-10); + assertEquals( 0d, zeros.sumNumber().doubleValue(), 1e-10,"Got non-zero array " + zeros + " after " + cnt + " iterations !"); zeros.putScalar(0, 1); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CyclicWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CyclicWorkspaceTests.java index 6a42d9c5b..3aaf5b23b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CyclicWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CyclicWorkspaceTests.java @@ -22,8 +22,8 @@ package org.nd4j.linalg.workspace; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -63,7 +63,7 @@ public class CyclicWorkspaceTests extends BaseNd4jTest { } @Test - @Ignore + @Disabled public void testGc() { val indArray = Nd4j.create(4, 4); indArray.putRow(0, Nd4j.create(new float[]{0, 2, -2, 0})); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/DebugModeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/DebugModeTests.java index 6ca9f872c..2b18ead2d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/DebugModeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/DebugModeTests.java @@ -22,9 +22,9 @@ package org.nd4j.linalg.workspace; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -39,7 +39,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j @RunWith(Parameterized.class) @@ -51,12 +51,12 @@ public class DebugModeTests extends BaseNd4jTest { this.initialType = Nd4j.dataType(); } - @Before + @BeforeEach public void turnMeUp() { Nd4j.getWorkspaceManager().setDebugMode(DebugMode.DISABLED); } - @After + @AfterEach public void turnMeDown() { Nd4j.getWorkspaceManager().setDebugMode(DebugMode.DISABLED); Nd4j.getMemoryManager().setCurrentWorkspace(null); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/EndlessWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/EndlessWorkspaceTests.java index 4807356ca..cce4562ca 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/EndlessWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/EndlessWorkspaceTests.java @@ -23,10 +23,10 @@ package org.nd4j.linalg.workspace; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.RandomUtils; import org.bytedeco.javacpp.Pointer; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -45,9 +45,9 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicLong; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; -@Ignore +@Disabled @Slf4j @RunWith(Parameterized.class) public class EndlessWorkspaceTests extends BaseNd4jTest { @@ -58,12 +58,12 @@ public class EndlessWorkspaceTests extends BaseNd4jTest { this.initialType = Nd4j.dataType(); } - @Before + @BeforeEach public void startUp() { Nd4j.getMemoryManager().togglePeriodicGc(false); } - @After + @AfterEach public void shutUp() { Nd4j.getMemoryManager().setCurrentWorkspace(null); Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java index 80fcfa729..1abb24014 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java @@ -20,15 +20,10 @@ package org.nd4j.linalg.workspace; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotEquals; -import static org.junit.Assert.assertTrue; - import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.After; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -50,6 +45,8 @@ import java.nio.file.Files; import java.util.ArrayList; import java.util.Arrays; +import static org.junit.jupiter.api.Assertions.*; + @Slf4j @RunWith(Parameterized.class) public class SpecialWorkspaceTests extends BaseNd4jTest { @@ -60,7 +57,7 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { this.initialType = Nd4j.dataType(); } - @After + @AfterEach public void shutUp() { Nd4j.getMemoryManager().setCurrentWorkspace(null); Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); @@ -120,14 +117,14 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { Nd4j.create(500); } - assertEquals("Failed on iteration " + i, (i + 1) * workspace.getInitialBlockSize(), - workspace.getDeviceOffset()); + assertEquals((i + 1) * workspace.getInitialBlockSize(), + workspace.getDeviceOffset(),"Failed on iteration " + i); } if (e >= 2) { - assertEquals("Failed on iteration " + e, 0, workspace.getNumberOfPinnedAllocations()); + assertEquals(0, workspace.getNumberOfPinnedAllocations(),"Failed on iteration " + e); } else { - assertEquals("Failed on iteration " + e, 1, workspace.getNumberOfPinnedAllocations()); + assertEquals(1, workspace.getNumberOfPinnedAllocations(),"Failed on iteration " + e); } } @@ -407,25 +404,28 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { Files.delete(tmpFile); } - @Test(expected = IllegalArgumentException.class) + @Test() public void testDeleteMappedFile_2() throws Exception { - if (!Nd4j.getEnvironment().isCPU()) - throw new IllegalArgumentException("Don't try to run on CUDA"); + assertThrows(IllegalArgumentException.class,() -> { + if (!Nd4j.getEnvironment().isCPU()) + throw new IllegalArgumentException("Don't try to run on CUDA"); - val tmpFile = Files.createTempFile("some", "file"); - val mmap = WorkspaceConfiguration.builder() - .initialSize(200 * 1024L * 1024L) // 200mbs - .tempFilePath(tmpFile.toAbsolutePath().toString()) - .policyLocation(LocationPolicy.MMAP) - .build(); + val tmpFile = Files.createTempFile("some", "file"); + val mmap = WorkspaceConfiguration.builder() + .initialSize(200 * 1024L * 1024L) // 200mbs + .tempFilePath(tmpFile.toAbsolutePath().toString()) + .policyLocation(LocationPolicy.MMAP) + .build(); - try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) { - val x = Nd4j.rand(DataType.FLOAT, 1024); - } + try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) { + val x = Nd4j.rand(DataType.FLOAT, 1024); + } - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); + Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); + + Files.delete(tmpFile); + }); - Files.delete(tmpFile); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java index f07681efb..7d6141bfd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java @@ -22,9 +22,9 @@ package org.nd4j.linalg.workspace; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.After; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; @@ -48,7 +48,7 @@ import java.io.DataOutputStream; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j @RunWith(Parameterized.class) @@ -115,7 +115,7 @@ public class WorkspaceProviderTests extends BaseNd4jTest { this.initialType = Nd4j.dataType(); } - @After + @AfterEach public void shutUp() { Nd4j.getMemoryManager().setCurrentWorkspace(null); Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); @@ -152,7 +152,7 @@ public class WorkspaceProviderTests extends BaseNd4jTest { assertEquals(5 * shiftedSize, ws1.getCurrentSize()); } else if (x < 4) { // we're making sure we're not initialize early - assertEquals("Failed on iteration " + x, 0, ws1.getCurrentSize()); + assertEquals(0, ws1.getCurrentSize(),"Failed on iteration " + x); } } @@ -529,7 +529,7 @@ public class WorkspaceProviderTests extends BaseNd4jTest { if (i == 3) { workspace.initializeWorkspace(); - assertEquals("Failed on iteration " + i, 100 * i * Nd4j.sizeOfDataType(), workspace.getCurrentSize()); + assertEquals(100 * i * Nd4j.sizeOfDataType(), workspace.getCurrentSize(),"Failed on iteration " + i); } } @@ -543,7 +543,7 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } workspace.initializeWorkspace(); - assertEquals("Failed on final", 100 * 10 * Nd4j.sizeOfDataType(), workspace.getCurrentSize()); + assertEquals(100 * 10 * Nd4j.sizeOfDataType(), workspace.getCurrentSize(),"Failed on final"); } @Test @@ -558,7 +558,7 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } if (i >= 3) - assertEquals("Failed on iteration " + i, 100 * i * Nd4j.sizeOfDataType(), workspace.getCurrentSize()); + assertEquals(100 * i * Nd4j.sizeOfDataType(), workspace.getCurrentSize(),"Failed on iteration " + i); else assertEquals(0, workspace.getCurrentSize()); } @@ -619,7 +619,7 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test - @Ignore("raver119: This test doesn't make any sense to me these days. We're borrowing from the same workspace. Why?") + @Disabled("raver119: This test doesn't make any sense to me these days. We're borrowing from the same workspace. Why?") public void testNestedWorkspaces11() { for (int x = 1; x < 10; x++) { try (MemoryWorkspace ws1 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) { @@ -990,7 +990,7 @@ public class WorkspaceProviderTests extends BaseNd4jTest { log.info("Done"); } - @Ignore + @Disabled @Test public void testMemcpy1() { INDArray warmUp = Nd4j.create(100000); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/list/NDArrayListTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/list/NDArrayListTest.java index 2f07797a3..db3e84870 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/list/NDArrayListTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/list/NDArrayListTest.java @@ -20,14 +20,14 @@ package org.nd4j.list; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.factory.Nd4jBackend; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class NDArrayListTest extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/base64/Nd4jBase64Test.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/base64/Nd4jBase64Test.java index 8f2f2baa5..aa4fff5dc 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/base64/Nd4jBase64Test.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/base64/Nd4jBase64Test.java @@ -20,13 +20,13 @@ package org.nd4j.serde.base64; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class Nd4jBase64Test extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/binary/BinarySerdeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/binary/BinarySerdeTest.java index dc50a9f8d..78356eb3d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/binary/BinarySerdeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/binary/BinarySerdeTest.java @@ -21,7 +21,7 @@ package org.nd4j.serde.binary; import org.apache.commons.lang3.time.StopWatch; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.OpValidationSuite; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataBuffer; @@ -36,7 +36,7 @@ import java.io.File; import java.nio.ByteBuffer; import java.util.UUID; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class BinarySerdeTest extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/smoketests/SmokeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/smoketests/SmokeTest.java index 74831083d..4f658e4fa 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/smoketests/SmokeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/smoketests/SmokeTest.java @@ -24,7 +24,7 @@ package org.nd4j.smoketests; import lombok.extern.slf4j.Slf4j; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java index 14ee512c4..095818e1c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java @@ -20,7 +20,7 @@ package org.nd4j.systeminfo; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; public class TestSystemInfo extends BaseND4JTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/kotlin/org/nd4j/linalg/custom/CustomOpTensorflowInteropTests.kt b/nd4j/nd4j-backends/nd4j-tests/src/test/kotlin/org/nd4j/linalg/custom/CustomOpTensorflowInteropTests.kt index d4d72571b..6f728f79d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/kotlin/org/nd4j/linalg/custom/CustomOpTensorflowInteropTests.kt +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/kotlin/org/nd4j/linalg/custom/CustomOpTensorflowInteropTests.kt @@ -21,7 +21,7 @@ package org.nd4j.linalg.custom import junit.framework.Assert.assertEquals -import org.junit.Ignore +import org.junit.jupiter.api.Disabled import org.junit.Test import org.nd4j.linalg.api.buffer.DataType import org.nd4j.linalg.api.ops.impl.image.CropAndResize @@ -34,7 +34,7 @@ import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraphRunner class CustomOpTensorflowInteropTests { @Test - @Ignore("Tensorflow expects different shape") + @Disabled("Tensorflow expects different shape") fun testCropAndResize() { val image = Nd4j.createUninitialized(DataType.FLOAT, 1, 2, 2, 1) val boxes = Nd4j.createFromArray(*floatArrayOf(1f, 2f, 3f, 4f)).reshape(1, 4) diff --git a/nd4j/nd4j-common/src/test/java/org/nd4j/common/base/TestPreconditions.java b/nd4j/nd4j-common/src/test/java/org/nd4j/common/base/TestPreconditions.java index 9a4b4aa5a..45d16becf 100644 --- a/nd4j/nd4j-common/src/test/java/org/nd4j/common/base/TestPreconditions.java +++ b/nd4j/nd4j-common/src/test/java/org/nd4j/common/base/TestPreconditions.java @@ -20,10 +20,10 @@ package org.nd4j.common.base; -import org.junit.Test; +import org.junit.jupiter.api.Test; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; public class TestPreconditions { diff --git a/nd4j/nd4j-common/src/test/java/org/nd4j/common/function/FunctionalUtilsTest.java b/nd4j/nd4j-common/src/test/java/org/nd4j/common/function/FunctionalUtilsTest.java index bc77c6440..b4c86b2e9 100644 --- a/nd4j/nd4j-common/src/test/java/org/nd4j/common/function/FunctionalUtilsTest.java +++ b/nd4j/nd4j-common/src/test/java/org/nd4j/common/function/FunctionalUtilsTest.java @@ -20,13 +20,13 @@ package org.nd4j.common.function; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.function.FunctionalUtils; import org.nd4j.common.primitives.Pair; import java.util.*; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class FunctionalUtilsTest { diff --git a/nd4j/nd4j-common/src/test/java/org/nd4j/common/io/ClassPathResourceTest.java b/nd4j/nd4j-common/src/test/java/org/nd4j/common/io/ClassPathResourceTest.java index 30d621b37..b68bfd246 100644 --- a/nd4j/nd4j-common/src/test/java/org/nd4j/common/io/ClassPathResourceTest.java +++ b/nd4j/nd4j-common/src/test/java/org/nd4j/common/io/ClassPathResourceTest.java @@ -20,27 +20,27 @@ package org.nd4j.common.io; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; import java.io.File; +import java.nio.file.Path; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class ClassPathResourceTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); @Test - public void testDirExtractingIntelliJ() throws Exception { + public void testDirExtractingIntelliJ(@TempDir Path testDir) throws Exception { //https://github.com/deeplearning4j/deeplearning4j/issues/6483 ClassPathResource cpr = new ClassPathResource("somedir"); - File f = testDir.newFolder(); + File f = testDir.toFile(); cpr.copyDirectory(f); diff --git a/nd4j/nd4j-common/src/test/java/org/nd4j/common/loader/TestFileBatch.java b/nd4j/nd4j-common/src/test/java/org/nd4j/common/loader/TestFileBatch.java index 1f576c6c2..e3878d8fd 100644 --- a/nd4j/nd4j-common/src/test/java/org/nd4j/common/loader/TestFileBatch.java +++ b/nd4j/nd4j-common/src/test/java/org/nd4j/common/loader/TestFileBatch.java @@ -21,32 +21,34 @@ package org.nd4j.common.loader; import org.apache.commons.io.FileUtils; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.loader.FileBatch; import java.io.*; import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.*; import java.util.zip.ZipEntry; import java.util.zip.ZipFile; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class TestFileBatch { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Test - public void testFileBatch() throws Exception { - File baseDir = testDir.newFolder(); + public void testFileBatch(@TempDir Path testDir) throws Exception { + File baseDir = testDir.toFile(); List fileList = new ArrayList<>(); - for( int i=0; i<10; i++ ){ + for( int i = 0; i < 10; i++) { String s = "File contents - file " + i; File f = new File(baseDir, "origFile" + i + ".txt"); FileUtils.writeStringToFile(f, s, StandardCharsets.UTF_8); @@ -79,12 +81,12 @@ public class TestFileBatch { assertEquals(fb.getOriginalUris(), fb2.getOriginalUris()); assertEquals(10, fb2.getFileBytes().size()); - for( int i=0; i<10; i++ ){ + for( int i = 0; i < 10; i++) { assertArrayEquals(fb.getFileBytes().get(i), fb2.getFileBytes().get(i)); } //Check that it is indeed a valid zip file: - File f = testDir.newFile(); + File f = Files.createTempFile(testDir,"testfile","zip").toFile(); f.delete(); fb.writeAsZip(f); @@ -101,7 +103,7 @@ public class TestFileBatch { assertTrue(names.contains(FileBatch.ORIGINAL_PATHS_FILENAME)); for( int i=0; i<10; i++ ){ String n = "file_" + i + ".txt"; - assertTrue(n, names.contains(n)); + assertTrue(names.contains(n),n); } } diff --git a/nd4j/nd4j-common/src/test/java/org/nd4j/common/primitives/AtomicTest.java b/nd4j/nd4j-common/src/test/java/org/nd4j/common/primitives/AtomicTest.java index 14533fb4b..5743cd063 100644 --- a/nd4j/nd4j-common/src/test/java/org/nd4j/common/primitives/AtomicTest.java +++ b/nd4j/nd4j-common/src/test/java/org/nd4j/common/primitives/AtomicTest.java @@ -21,14 +21,14 @@ package org.nd4j.common.primitives; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.primitives.Atomic; import org.nd4j.common.util.SerializationUtils; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class AtomicTest { diff --git a/nd4j/nd4j-common/src/test/java/org/nd4j/common/primitives/CounterMapTest.java b/nd4j/nd4j-common/src/test/java/org/nd4j/common/primitives/CounterMapTest.java index 233ce03aa..764d03c6c 100644 --- a/nd4j/nd4j-common/src/test/java/org/nd4j/common/primitives/CounterMapTest.java +++ b/nd4j/nd4j-common/src/test/java/org/nd4j/common/primitives/CounterMapTest.java @@ -20,13 +20,13 @@ package org.nd4j.common.primitives; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.primitives.CounterMap; import org.nd4j.common.primitives.Pair; import java.util.Iterator; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class CounterMapTest { diff --git a/nd4j/nd4j-common/src/test/java/org/nd4j/common/primitives/CounterTest.java b/nd4j/nd4j-common/src/test/java/org/nd4j/common/primitives/CounterTest.java index a633514c0..2f25f2446 100644 --- a/nd4j/nd4j-common/src/test/java/org/nd4j/common/primitives/CounterTest.java +++ b/nd4j/nd4j-common/src/test/java/org/nd4j/common/primitives/CounterTest.java @@ -21,12 +21,12 @@ package org.nd4j.common.primitives; import lombok.extern.slf4j.Slf4j; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.primitives.Counter; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class CounterTest { diff --git a/nd4j/nd4j-common/src/test/java/org/nd4j/common/resources/TestArchiveUtils.java b/nd4j/nd4j-common/src/test/java/org/nd4j/common/resources/TestArchiveUtils.java index b0317d46c..7de739aa3 100644 --- a/nd4j/nd4j-common/src/test/java/org/nd4j/common/resources/TestArchiveUtils.java +++ b/nd4j/nd4j-common/src/test/java/org/nd4j/common/resources/TestArchiveUtils.java @@ -21,9 +21,10 @@ package org.nd4j.common.resources; import org.apache.commons.io.FileUtils; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.util.ArchiveUtils; import java.io.File; @@ -31,17 +32,16 @@ import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.nio.file.Path; import java.util.zip.ZipEntry; import java.util.zip.ZipOutputStream; public class TestArchiveUtils { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); @Test - public void testUnzipFileTo() throws IOException { + public void testUnzipFileTo(@TempDir Path testDir) throws IOException { //random txt file - File dir = testDir.newFolder(); + File dir = testDir.toFile(); String content = "test file content"; String path = "myDir/myTestFile.txt"; File testFile = new File(dir, path); diff --git a/nd4j/nd4j-common/src/test/java/org/nd4j/common/resources/TestStrumpf.java b/nd4j/nd4j-common/src/test/java/org/nd4j/common/resources/TestStrumpf.java index 292f0b8b7..9a107be86 100644 --- a/nd4j/nd4j-common/src/test/java/org/nd4j/common/resources/TestStrumpf.java +++ b/nd4j/nd4j-common/src/test/java/org/nd4j/common/resources/TestStrumpf.java @@ -23,10 +23,11 @@ package org.nd4j.common.resources; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.io.LineIterator; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.config.ND4JSystemProperties; import org.nd4j.common.resources.Resources; import org.nd4j.common.resources.strumpf.StrumpfResolver; @@ -36,14 +37,14 @@ import java.io.File; import java.io.FileReader; import java.io.Reader; import java.nio.charset.StandardCharsets; +import java.nio.file.Path; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -@Ignore +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +@Disabled public class TestStrumpf { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + /* @Test public void testResolvingReference() throws Exception { @@ -80,9 +81,9 @@ public class TestStrumpf { } @Test - public void testResolveLocal() throws Exception { + public void testResolveLocal(@TempDir Path testDir) throws Exception { - File dir = testDir.newFolder(); + File dir = testDir.toFile(); String content = "test file content"; String path = "myDir/myTestFile.txt"; diff --git a/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/BToolsTest.java b/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/BToolsTest.java index cf0c18361..6d5821959 100644 --- a/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/BToolsTest.java +++ b/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/BToolsTest.java @@ -20,10 +20,10 @@ package org.nd4j.common.tools; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tools.BTools; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class BToolsTest { // diff --git a/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/InfoLineTest.java b/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/InfoLineTest.java index 321adb1cb..d9bbfdde6 100644 --- a/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/InfoLineTest.java +++ b/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/InfoLineTest.java @@ -20,11 +20,11 @@ package org.nd4j.common.tools; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tools.InfoLine; import org.nd4j.common.tools.InfoValues; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class InfoLineTest { // diff --git a/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/InfoValuesTest.java b/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/InfoValuesTest.java index 6e2f9da8d..ee40a1089 100644 --- a/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/InfoValuesTest.java +++ b/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/InfoValuesTest.java @@ -20,10 +20,10 @@ package org.nd4j.common.tools; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tools.InfoValues; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class InfoValuesTest { // diff --git a/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/PropertyParserTest.java b/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/PropertyParserTest.java index eb46f6535..dc4aabb51 100644 --- a/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/PropertyParserTest.java +++ b/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/PropertyParserTest.java @@ -21,13 +21,15 @@ package org.nd4j.common.tools; import java.util.Properties; -import org.junit.After; +import org.junit.jupiter.api.AfterEach; import org.junit.AfterClass; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; -import org.junit.Before; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; import org.junit.BeforeClass; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tools.PropertyParser; /** @@ -40,7 +42,7 @@ public class PropertyParserTest { public PropertyParserTest() { } - @BeforeClass + @BeforeAll public static void setUpClass() { } @@ -48,11 +50,11 @@ public class PropertyParserTest { public static void tearDownClass() { } - @Before + @BeforeEach public void setUp() { } - @After + @AfterEach public void tearDown() { } @@ -1330,5 +1332,5 @@ public class PropertyParserTest { result = instance.toChar("nonexistent", 't'); assertEquals(expResult, result); } - + } diff --git a/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/SISTest.java b/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/SISTest.java index 16956cdfc..835ef6c1a 100644 --- a/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/SISTest.java +++ b/nd4j/nd4j-common/src/test/java/org/nd4j/common/tools/SISTest.java @@ -20,30 +20,31 @@ package org.nd4j.common.tools; -import org.junit.After; -import org.junit.Test; -import org.junit.Rule; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + + +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tools.SIS; -import static org.junit.Assert.*; +import java.nio.file.Path; + +import static org.junit.jupiter.api.Assertions.*; public class SISTest { // - @Rule - public TemporaryFolder tmpFld = new TemporaryFolder(); // private SIS sis; // @Test - public void testAll() throws Exception { + public void testAll(@TempDir Path tmpFld) throws Exception { // sis = new SIS(); // int mtLv = 0; // - sis.initValues( mtLv, "TEST", System.out, System.err, tmpFld.getRoot().getAbsolutePath(), "Test", "ABC", true, true ); + sis.initValues( mtLv, "TEST", System.out, System.err, tmpFld.getRoot().toAbsolutePath().toString(), "Test", "ABC", true, true ); // String fFName = sis.getfullFileName(); sis.info( fFName ); @@ -57,7 +58,7 @@ public class SISTest { // } - @After + @AfterEach public void after() { // int mtLv = 0; diff --git a/nd4j/nd4j-common/src/test/java/org/nd4j/common/util/ArrayUtilTest.java b/nd4j/nd4j-common/src/test/java/org/nd4j/common/util/ArrayUtilTest.java index 4994db145..6bb2f1173 100644 --- a/nd4j/nd4j-common/src/test/java/org/nd4j/common/util/ArrayUtilTest.java +++ b/nd4j/nd4j-common/src/test/java/org/nd4j/common/util/ArrayUtilTest.java @@ -20,9 +20,9 @@ package org.nd4j.common.util; -import static org.junit.Assert.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.util.ArrayUtil; public class ArrayUtilTest { diff --git a/nd4j/nd4j-common/src/test/java/org/nd4j/common/util/OneTimeLoggerTest.java b/nd4j/nd4j-common/src/test/java/org/nd4j/common/util/OneTimeLoggerTest.java index 4011ca455..2faa01638 100644 --- a/nd4j/nd4j-common/src/test/java/org/nd4j/common/util/OneTimeLoggerTest.java +++ b/nd4j/nd4j-common/src/test/java/org/nd4j/common/util/OneTimeLoggerTest.java @@ -21,11 +21,11 @@ package org.nd4j.common.util; import lombok.extern.slf4j.Slf4j; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.util.OneTimeLogger; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j public class OneTimeLoggerTest { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java index 4f5a15cf6..2829a0709 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java @@ -26,10 +26,7 @@ import io.aeron.driver.ThreadingMode; import lombok.extern.slf4j.Slf4j; import org.agrona.CloseHelper; import org.agrona.concurrent.BusySpinIdleStrategy; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.*; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.aeron.ipc.AeronUtil; import org.nd4j.linalg.api.ndarray.INDArray; @@ -38,7 +35,7 @@ import org.nd4j.parameterserver.client.ParameterServerClient; import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j public class RemoteParameterServerClientTests extends BaseND4JTest { @@ -49,7 +46,7 @@ public class RemoteParameterServerClientTests extends BaseND4JTest { private AtomicInteger slaveStatus = new AtomicInteger(0); private Aeron aeron; - @Before + @BeforeEach public void before() throws Exception { final MediaDriver.Context ctx = new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirDeleteOnStart(true) @@ -86,13 +83,15 @@ public class RemoteParameterServerClientTests extends BaseND4JTest { } - @After + @AfterEach public void after() throws Exception { CloseHelper.close(mediaDriver); CloseHelper.close(aeron); } - @Test(timeout = 60000L) @Ignore //AB 20200425 https://github.com/eclipse/deeplearning4j/issues/8882 + @Test() + @Timeout(60000L) + @Disabled //AB 20200425 https://github.com/eclipse/deeplearning4j/issues/8882 public void remoteTests() throws Exception { if (masterStatus.get() != 0 || slaveStatus.get() != 0) throw new IllegalStateException("Master or slave failed to start. Exiting"); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java index d2faa3982..ded84b68e 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java @@ -26,8 +26,10 @@ import io.aeron.driver.ThreadingMode; import lombok.extern.slf4j.Slf4j; import org.agrona.concurrent.BusySpinIdleStrategy; import org.junit.BeforeClass; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.aeron.ipc.AeronUtil; import org.nd4j.aeron.ipc.NDArrayMessage; @@ -37,8 +39,8 @@ import org.nd4j.parameterserver.ParameterServerListener; import org.nd4j.parameterserver.ParameterServerSubscriber; import static junit.framework.TestCase.assertFalse; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j public class ParameterServerClientPartialTest extends BaseND4JTest { @@ -48,13 +50,13 @@ public class ParameterServerClientPartialTest extends BaseND4JTest { private int[] shape = {2, 2}; private static Aeron aeron; - @BeforeClass + @BeforeAll public static void beforeClass() throws Exception { final MediaDriver.Context ctx = - new MediaDriver.Context().threadingMode(ThreadingMode.SHARED).dirDeleteOnStart(true) - .termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy()) - .receiverIdleStrategy(new BusySpinIdleStrategy()) - .senderIdleStrategy(new BusySpinIdleStrategy()); + new MediaDriver.Context().threadingMode(ThreadingMode.SHARED).dirDeleteOnStart(true) + .termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy()) + .receiverIdleStrategy(new BusySpinIdleStrategy()) + .senderIdleStrategy(new BusySpinIdleStrategy()); mediaDriver = MediaDriver.launchEmbedded(ctx); aeron = Aeron.connect(getContext()); @@ -63,8 +65,8 @@ public class ParameterServerClientPartialTest extends BaseND4JTest { int masterPort = 40223 + new java.util.Random().nextInt(13000); int masterStatusPort = masterPort - 2000; masterNode.run(new String[] {"-m", "true", "-p", String.valueOf(masterPort), "-h", "localhost", "-id", "11", - "-md", mediaDriver.aeronDirectoryName(), "-sp", String.valueOf(masterStatusPort), "-s", "2,2", - "-u", String.valueOf(1) + "-md", mediaDriver.aeronDirectoryName(), "-sp", String.valueOf(masterStatusPort), "-s", "2,2", + "-u", String.valueOf(1) }); @@ -80,8 +82,8 @@ public class ParameterServerClientPartialTest extends BaseND4JTest { int slavePort = masterPort + 100; int slaveStatusPort = slavePort - 2000; slaveNode.run(new String[] {"-p", String.valueOf(slavePort), "-h", "localhost", "-id", "10", "-pm", - masterNode.getSubscriber().connectionUrl(), "-md", mediaDriver.aeronDirectoryName(), "-sp", - String.valueOf(slaveStatusPort), "-u", String.valueOf(1) + masterNode.getSubscriber().connectionUrl(), "-md", mediaDriver.aeronDirectoryName(), "-sp", + String.valueOf(slaveStatusPort), "-u", String.valueOf(1) }); @@ -105,13 +107,14 @@ public class ParameterServerClientPartialTest extends BaseND4JTest { } - @Test(timeout = 60000L) - @Ignore("AB 2019/06/01 - Intermittent failures - see issue 7657") + @Test() + @Timeout(60000L) + @Disabled("AB 2019/06/01 - Intermittent failures - see issue 7657") public void testServer() throws Exception { ParameterServerClient client = ParameterServerClient.builder().aeron(aeron) - .ndarrayRetrieveUrl(masterNode.getResponder().connectionUrl()) - .ndarraySendUrl(slaveNode.getSubscriber().connectionUrl()).subscriberHost("localhost") - .subscriberPort(40325).subscriberStream(12).build(); + .ndarrayRetrieveUrl(masterNode.getResponder().connectionUrl()) + .ndarraySendUrl(slaveNode.getSubscriber().connectionUrl()).subscriberHost("localhost") + .subscriberPort(40325).subscriberStream(12).build(); assertEquals("localhost:40325:12", client.connectionUrl()); //flow 1: /** @@ -137,10 +140,10 @@ public class ParameterServerClientPartialTest extends BaseND4JTest { private static Aeron.Context getContext() { if (ctx == null) ctx = new Aeron.Context().driverTimeoutMs(Long.MAX_VALUE) - .availableImageHandler(AeronUtil::printAvailableImage) - .unavailableImageHandler(AeronUtil::printUnavailableImage) - .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(10000) - .errorHandler(e -> log.error(e.toString(), e)); + .availableImageHandler(AeronUtil::printAvailableImage) + .unavailableImageHandler(AeronUtil::printUnavailableImage) + .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(10000) + .errorHandler(e -> log.error(e.toString(), e)); return ctx; } diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java index 6c8564d1c..8044be0a5 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java @@ -23,8 +23,10 @@ package org.nd4j.parameterserver.client; import io.aeron.Aeron; import io.aeron.driver.MediaDriver; import org.junit.BeforeClass; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.aeron.ipc.AeronUtil; import org.nd4j.linalg.api.ndarray.INDArray; @@ -35,8 +37,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import static junit.framework.TestCase.assertFalse; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class ParameterServerClientTest extends BaseND4JTest { private static MediaDriver mediaDriver; @@ -45,7 +47,7 @@ public class ParameterServerClientTest extends BaseND4JTest { private static ParameterServerSubscriber masterNode, slaveNode; private static int parameterLength = 1000; - @BeforeClass + @BeforeAll public static void beforeClass() throws Exception { mediaDriver = MediaDriver.launchEmbedded(AeronUtil.getMediaDriverContext(parameterLength)); System.setProperty("play.server.dir", "/tmp"); @@ -54,8 +56,8 @@ public class ParameterServerClientTest extends BaseND4JTest { masterNode.setAeron(aeron); int masterPort = 40323 + new java.util.Random().nextInt(3000); masterNode.run(new String[] {"-m", "true", "-s", "1," + String.valueOf(parameterLength), "-p", - String.valueOf(masterPort), "-h", "localhost", "-id", "11", "-md", - mediaDriver.aeronDirectoryName(), "-sp", "33000", "-u", String.valueOf(1)}); + String.valueOf(masterPort), "-h", "localhost", "-id", "11", "-md", + mediaDriver.aeronDirectoryName(), "-sp", "33000", "-u", String.valueOf(1)}); assertTrue(masterNode.isMaster()); assertEquals(masterPort, masterNode.getPort()); @@ -66,8 +68,8 @@ public class ParameterServerClientTest extends BaseND4JTest { slaveNode = new ParameterServerSubscriber(mediaDriver); slaveNode.setAeron(aeron); slaveNode.run(new String[] {"-p", String.valueOf(masterPort + 100), "-h", "localhost", "-id", "10", "-pm", - masterNode.getSubscriber().connectionUrl(), "-md", mediaDriver.aeronDirectoryName(), "-sp", - "31000", "-u", String.valueOf(1)}); + masterNode.getSubscriber().connectionUrl(), "-md", mediaDriver.aeronDirectoryName(), "-sp", + "31000", "-u", String.valueOf(1)}); assertFalse(slaveNode.isMaster()); assertEquals(masterPort + 100, slaveNode.getPort()); @@ -90,14 +92,15 @@ public class ParameterServerClientTest extends BaseND4JTest { - @Test(timeout = 60000L) - @Ignore("AB 2019/05/31 - Intermittent failures on CI - see issue 7657") + @Test() + @Timeout(60000L) + @Disabled("AB 2019/05/31 - Intermittent failures on CI - see issue 7657") public void testServer() throws Exception { int subscriberPort = 40625 + new java.util.Random().nextInt(100); ParameterServerClient client = ParameterServerClient.builder().aeron(aeron) - .ndarrayRetrieveUrl(masterNode.getResponder().connectionUrl()) - .ndarraySendUrl(slaveNode.getSubscriber().connectionUrl()).subscriberHost("localhost") - .subscriberPort(subscriberPort).subscriberStream(12).build(); + .ndarrayRetrieveUrl(masterNode.getResponder().connectionUrl()) + .ndarraySendUrl(slaveNode.getSubscriber().connectionUrl()).subscriberHost("localhost") + .subscriberPort(subscriberPort).subscriberStream(12).build(); assertEquals(String.format("localhost:%d:12", subscriberPort), client.connectionUrl()); //flow 1: /** @@ -120,10 +123,10 @@ public class ParameterServerClientTest extends BaseND4JTest { private static Aeron.Context getContext() { return new Aeron.Context().driverTimeoutMs(Long.MAX_VALUE) - .availableImageHandler(AeronUtil::printAvailableImage) - .unavailableImageHandler(AeronUtil::printUnavailableImage) - .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(100000) - .errorHandler(e -> log.error(e.toString(), e)); + .availableImageHandler(AeronUtil::printAvailableImage) + .unavailableImageHandler(AeronUtil::printUnavailableImage) + .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(100000) + .errorHandler(e -> log.error(e.toString(), e)); } diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerStressTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerStressTest.java index eccba224e..4129c041d 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerStressTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerStressTest.java @@ -22,10 +22,7 @@ package org.nd4j.parameterserver.distributed; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.RandomUtils; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.*; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; @@ -49,20 +46,20 @@ import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicLong; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j -@Ignore +@Disabled @Deprecated public class VoidParameterServerStressTest extends BaseND4JTest { private static final int NUM_WORDS = 100000; - @Before + @BeforeEach public void setUp() throws Exception { } - @After + @AfterEach public void tearDown() throws Exception { } @@ -71,7 +68,7 @@ public class VoidParameterServerStressTest extends BaseND4JTest { * This test measures performance of blocking messages processing, VectorRequestMessage in this case */ @Test - @Ignore + @Disabled public void testPerformanceStandalone1() { VoidConfiguration voidConfiguration = VoidConfiguration.builder().networkMask("192.168.0.0/16").numberOfShards(1).build(); @@ -132,7 +129,7 @@ public class VoidParameterServerStressTest extends BaseND4JTest { * This test measures performance of non-blocking messages processing, SkipGramRequestMessage in this case */ @Test - @Ignore + @Disabled public void testPerformanceStandalone2() { VoidConfiguration voidConfiguration = VoidConfiguration.builder().networkMask("192.168.0.0/16").numberOfShards(1).build(); @@ -193,7 +190,7 @@ public class VoidParameterServerStressTest extends BaseND4JTest { @Test - @Ignore + @Disabled public void testPerformanceMulticast1() throws Exception { VoidConfiguration voidConfiguration = VoidConfiguration.builder().networkMask("192.168.0.0/16").numberOfShards(1).build(); @@ -288,7 +285,8 @@ public class VoidParameterServerStressTest extends BaseND4JTest { /** * This is one of the MOST IMPORTANT tests */ - @Test(timeout = 60000L) + @Test() + @Timeout(60000L) public void testPerformanceUnicast1() { List list = new ArrayList<>(); for (int t = 0; t < 1; t++) { @@ -386,7 +384,7 @@ public class VoidParameterServerStressTest extends BaseND4JTest { * Here we send non-blocking messages */ @Test - @Ignore + @Disabled public void testPerformanceUnicast2() { List list = new ArrayList<>(); for (int t = 0; t < 5; t++) { @@ -488,7 +486,8 @@ public class VoidParameterServerStressTest extends BaseND4JTest { * * @throws Exception */ - @Test(timeout = 60000L) + @Test() + @Timeout(60000L) public void testPerformanceUnicast3() throws Exception { VoidConfiguration voidConfiguration = VoidConfiguration.builder().numberOfShards(1) .shardAddresses(Arrays.asList("127.0.0.1:49823")).build(); @@ -534,7 +533,8 @@ public class VoidParameterServerStressTest extends BaseND4JTest { * * @throws Exception */ - @Test(timeout = 60000L) + @Test() + @Timeout(60000L) public void testPerformanceUnicast4() throws Exception { VoidConfiguration voidConfiguration = VoidConfiguration.builder().numberOfShards(1) .shardAddresses(Arrays.asList("127.0.0.1:49823")).build(); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerTest.java index 7b29c7d6a..964445b82 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerTest.java @@ -21,10 +21,7 @@ package org.nd4j.parameterserver.distributed; import lombok.extern.slf4j.Slf4j; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.*; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -51,17 +48,17 @@ import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j -@Ignore +@Disabled @Deprecated public class VoidParameterServerTest extends BaseND4JTest { private static List localIPs; private static List badIPs; private static final Transport transport = new MulticastTransport(); - @Before + @BeforeEach public void setUp() throws Exception { if (localIPs == null) { localIPs = new ArrayList<>(VoidParameterServer.getLocalAddresses()); @@ -70,12 +67,13 @@ public class VoidParameterServerTest extends BaseND4JTest { } } - @After + @AfterEach public void tearDown() throws Exception { } - @Test(timeout = 30000L) + @Test() + @Timeout(30000L) public void testNodeRole1() throws Exception { final VoidConfiguration conf = VoidConfiguration.builder().multicastPort(45678) .numberOfShards(10).multicastNetwork("224.0.1.1").shardAddresses(localIPs).ttl(4).build(); @@ -88,7 +86,8 @@ public class VoidParameterServerTest extends BaseND4JTest { node.shutdown(); } - @Test(timeout = 30000L) + @Test() + @Timeout(30000L) public void testNodeRole2() throws Exception { final VoidConfiguration conf = VoidConfiguration.builder().multicastPort(45678) .numberOfShards(10).shardAddresses(badIPs).backupAddresses(localIPs) @@ -102,7 +101,8 @@ public class VoidParameterServerTest extends BaseND4JTest { node.shutdown(); } - @Test(timeout = 30000L) + @Test() + @Timeout(30000L) public void testNodeRole3() throws Exception { final VoidConfiguration conf = VoidConfiguration.builder().multicastPort(45678) .numberOfShards(10).shardAddresses(badIPs).backupAddresses(badIPs).multicastNetwork("224.0.1.1") @@ -116,7 +116,8 @@ public class VoidParameterServerTest extends BaseND4JTest { node.shutdown(); } - @Test(timeout = 60000L) + @Test() + @Timeout(60000L) public void testNodeInitialization1() throws Exception { final AtomicInteger failCnt = new AtomicInteger(0); final AtomicInteger passCnt = new AtomicInteger(0); @@ -162,7 +163,8 @@ public class VoidParameterServerTest extends BaseND4JTest { * * @throws Exception */ - @Test(timeout = 60000L) + @Test() + @Timeout(60000L) public void testNodeInitialization2() throws Exception { final AtomicInteger failCnt = new AtomicInteger(0); final AtomicInteger passCnt = new AtomicInteger(0); @@ -251,8 +253,8 @@ public class VoidParameterServerTest extends BaseND4JTest { // now we check message queue within Shards for (int t = 0; t < threads.length; t++) { VoidMessage incMessage = shards[t].getTransport().takeMessage(); - assertNotEquals("Failed for shard " + t, null, incMessage); - assertEquals("Failed for shard " + t, message.getMessageType(), incMessage.getMessageType()); + assertNotEquals( null, incMessage,"Failed for shard " + t); + assertEquals(message.getMessageType(), incMessage.getMessageType(),"Failed for shard " + t); // we should put message back to corresponding shards[t].getTransport().putMessage(incMessage); @@ -269,7 +271,7 @@ public class VoidParameterServerTest extends BaseND4JTest { for (int t = 0; t < threads.length; t++) { VoidMessage incMessage = shards[t].getTransport().takeMessage(); - assertNotEquals("Failed for shard " + t, null, incMessage); + assertNotEquals(null, incMessage,"Failed for shard " + t); shards[t].handleMessage(message); /** @@ -415,7 +417,8 @@ public class VoidParameterServerTest extends BaseND4JTest { * * @throws Exception */ - @Test(timeout = 60000L) + @Test + @Timeout(60000L) public void testNodeInitialization3() throws Exception { final AtomicInteger failCnt = new AtomicInteger(0); final AtomicInteger passCnt = new AtomicInteger(0); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/conf/VoidConfigurationTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/conf/VoidConfigurationTest.java index 4785a586b..fc5479974 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/conf/VoidConfigurationTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/conf/VoidConfigurationTest.java @@ -20,20 +20,19 @@ package org.nd4j.parameterserver.distributed.conf; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.Timeout; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; + import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.exception.ND4JIllegalStateException; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; -@Ignore +@Disabled public class VoidConfigurationTest extends BaseND4JTest { - @Rule - public Timeout globalTimeout = Timeout.seconds(30); + @Test public void testNetworkMask1() throws Exception { @@ -68,20 +67,26 @@ public class VoidConfigurationTest extends BaseND4JTest { assertEquals("192.168.0.0/8", configuration.getNetworkMask()); } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testNetworkMask3() throws Exception { - VoidConfiguration configuration = new VoidConfiguration(); - configuration.setNetworkMask("192.256.1.1/24"); + assertThrows(ND4JIllegalStateException.class,() -> { + VoidConfiguration configuration = new VoidConfiguration(); + configuration.setNetworkMask("192.256.1.1/24"); + + assertEquals("192.168.1.0/24", configuration.getNetworkMask()); + }); - assertEquals("192.168.1.0/24", configuration.getNetworkMask()); } - @Test(expected = ND4JIllegalStateException.class) + @Test() public void testNetworkMask4() throws Exception { - VoidConfiguration configuration = new VoidConfiguration(); - configuration.setNetworkMask("0.0.0.0/8"); + assertThrows(ND4JIllegalStateException.class,() -> { + VoidConfiguration configuration = new VoidConfiguration(); + configuration.setNetworkMask("0.0.0.0/8"); + + assertEquals("192.168.1.0/24", configuration.getNetworkMask()); + }); - assertEquals("192.168.1.0/24", configuration.getNetworkMask()); } @Override diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/ClipboardTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/ClipboardTest.java index 3efe9e290..4f703ca32 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/ClipboardTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/ClipboardTest.java @@ -21,8 +21,7 @@ package org.nd4j.parameterserver.distributed.logic; import lombok.extern.slf4j.Slf4j; -import org.junit.*; -import org.junit.rules.Timeout; +import org.junit.jupiter.api.*; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.parameterserver.distributed.logic.completion.Clipboard; @@ -32,24 +31,23 @@ import org.nd4j.parameterserver.distributed.messages.VoidAggregation; import java.util.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j -@Ignore +@Disabled @Deprecated public class ClipboardTest extends BaseND4JTest { - @Before + @BeforeEach public void setUp() throws Exception { } - @After + @AfterEach public void tearDown() throws Exception { } - @Rule - public Timeout globalTimeout = Timeout.seconds(30); + @Test public void testPin1() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/FrameCompletionHandlerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/FrameCompletionHandlerTest.java index 3a7982a3f..c3e1c6d26 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/FrameCompletionHandlerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/FrameCompletionHandlerTest.java @@ -20,26 +20,25 @@ package org.nd4j.parameterserver.distributed.logic; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.Timeout; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; + import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.parameterserver.distributed.logic.completion.FrameCompletionHandler; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; -@Ignore +@Disabled @Deprecated public class FrameCompletionHandlerTest extends BaseND4JTest { - @Before + @BeforeEach public void setUp() throws Exception { } - @Rule - public Timeout globalTimeout = Timeout.seconds(30); + /** * This test emulates 2 frames being processed at the same time diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/routing/InterleavedRouterTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/routing/InterleavedRouterTest.java index a7ed8d614..cfa7d9afa 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/routing/InterleavedRouterTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/routing/InterleavedRouterTest.java @@ -20,11 +20,11 @@ package org.nd4j.parameterserver.distributed.logic.routing; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.Timeout; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; + +import org.junit.jupiter.api.Test; + import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.util.HashUtil; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; @@ -36,16 +36,16 @@ import org.nd4j.parameterserver.distributed.transport.Transport; import java.util.Arrays; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; -@Ignore +@Disabled @Deprecated public class InterleavedRouterTest extends BaseND4JTest { VoidConfiguration configuration; Transport transport; long originator; - @Before + @BeforeEach public void setUp() { configuration = VoidConfiguration.builder() .shardAddresses(Arrays.asList("1.2.3.4", "2.3.4.5", "3.4.5.6", "4.5.6.7")).numberOfShards(4) // we set it manually here @@ -56,8 +56,7 @@ public class InterleavedRouterTest extends BaseND4JTest { originator = HashUtil.getLongHash(transport.getIp() + ":" + transport.getPort()); } - @Rule - public Timeout globalTimeout = Timeout.seconds(30); + /** * Testing default assignment for everything, but training requests diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/FrameTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/FrameTest.java index 077f38af9..97d4b4de5 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/FrameTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/FrameTest.java @@ -21,9 +21,10 @@ package org.nd4j.parameterserver.distributed.messages; import org.agrona.concurrent.UnsafeBuffer; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; import org.nd4j.parameterserver.distributed.enums.NodeRole; @@ -35,12 +36,12 @@ import org.nd4j.parameterserver.distributed.transport.Transport; import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; -@Ignore +@Disabled @Deprecated public class FrameTest extends BaseND4JTest { - @Before + @BeforeEach public void setUp() throws Exception { } @@ -48,7 +49,8 @@ public class FrameTest extends BaseND4JTest { /** * Simple test for Frame functionality */ - @Test(timeout = 30000L) + @Test() + @Timeout(30000L) public void testFrame1() { final AtomicInteger count = new AtomicInteger(0); @@ -163,7 +165,8 @@ public class FrameTest extends BaseND4JTest { } - @Test(timeout = 30000L) + @Test() + @Timeout(30000L) public void testJoin1() throws Exception { SkipGramRequestMessage sgrm = new SkipGramRequestMessage(0, 1, new int[] {3, 4, 5}, new byte[] {0, 1, 0}, (short) 0, 0.01, 119L); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/VoidMessageTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/VoidMessageTest.java index 9215d8a34..a710e87bf 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/VoidMessageTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/VoidMessageTest.java @@ -20,29 +20,27 @@ package org.nd4j.parameterserver.distributed.messages; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.*; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; -@Ignore +@Disabled @Deprecated public class VoidMessageTest extends BaseND4JTest { - @Before + @BeforeEach public void setUp() throws Exception { } - @After + @AfterEach public void tearDown() throws Exception { } - @Test(timeout = 30000L) + @Test() + @Timeout(30000L) public void testSerDe1() throws Exception { SkipGramRequestMessage message = new SkipGramRequestMessage(10, 12, new int[] {10, 20, 30, 40}, new byte[] {(byte) 0, (byte) 0, (byte) 1, (byte) 0}, (short) 0, 0.0, 117L); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/aggregations/VoidAggregationTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/aggregations/VoidAggregationTest.java index a6d51390d..2e99bddf6 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/aggregations/VoidAggregationTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/aggregations/VoidAggregationTest.java @@ -21,8 +21,7 @@ package org.nd4j.parameterserver.distributed.messages.aggregations; import lombok.extern.slf4j.Slf4j; -import org.junit.*; -import org.junit.rules.Timeout; +import org.junit.jupiter.api.*; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -30,27 +29,26 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j -@Ignore +@Disabled @Deprecated public class VoidAggregationTest extends BaseND4JTest { private static final short NODES = 100; private static final int ELEMENTS_PER_NODE = 3; - @Before + @BeforeEach public void setUp() throws Exception { } - @After + @AfterEach public void tearDown() throws Exception { } - @Rule - public Timeout globalTimeout = Timeout.seconds(30); + /** * In this test we check for aggregation of sample vector. diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/transport/RoutedTransportTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/transport/RoutedTransportTest.java index b11550fe4..068945181 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/transport/RoutedTransportTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/transport/RoutedTransportTest.java @@ -20,10 +20,7 @@ package org.nd4j.parameterserver.distributed.transport; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.*; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; import org.nd4j.parameterserver.distributed.enums.NodeRole; @@ -37,17 +34,17 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.TimeUnit; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; -@Ignore +@Disabled @Deprecated public class RoutedTransportTest extends BaseND4JTest { - @Before + @BeforeEach public void setUp() throws Exception { } - @After + @AfterEach public void tearDown() throws Exception { } @@ -58,7 +55,8 @@ public class RoutedTransportTest extends BaseND4JTest { * * @throws Exception */ - @Test(timeout = 30000) + @Test() + @Timeout(30000) public void testMessaging1() throws Exception { List list = new ArrayList<>(); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/util/NetworkOrganizerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/util/NetworkOrganizerTest.java index 9fd81a118..f20b3634e 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/util/NetworkOrganizerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/util/NetworkOrganizerTest.java @@ -22,29 +22,26 @@ package org.nd4j.parameterserver.distributed.util; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.RandomUtils; -import org.junit.*; -import org.junit.rules.Timeout; +import org.junit.jupiter.api.*; import org.nd4j.common.tests.BaseND4JTest; import java.util.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j -@Ignore +@Disabled public class NetworkOrganizerTest extends BaseND4JTest { - @Before + @BeforeEach public void setUp() throws Exception { } - @After + @AfterEach public void tearDown() throws Exception { } - @Rule - public Timeout globalTimeout = Timeout.seconds(20); // 20 seconds max per method tested @Test @@ -385,7 +382,7 @@ public class NetworkOrganizerTest extends BaseND4JTest { } @Test - @Ignore("AB 2019/05/30 - Intermittent issue or flaky test - see issue #7657") + @Disabled("AB 2019/05/30 - Intermittent issue or flaky test - see issue #7657") public void testNetTree6() throws Exception { List ips = new ArrayList<>(); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/DelayedModelParameterServerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/DelayedModelParameterServerTest.java index 64ddad2db..78ce59bde 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/DelayedModelParameterServerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/DelayedModelParameterServerTest.java @@ -24,10 +24,7 @@ import io.reactivex.functions.Consumer; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.lang3.RandomUtils; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.*; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -49,26 +46,27 @@ import org.nd4j.parameterserver.distributed.v2.util.MessageSplitter; import java.util.ArrayList; import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j -@Ignore +@Disabled public class DelayedModelParameterServerTest extends BaseND4JTest { private static final String rootId = "ROOT_NODE"; - @Before + @BeforeEach public void setUp() throws Exception { MessageSplitter.getInstance().reset(); } - @After + @AfterEach public void setDown() throws Exception { MessageSplitter.getInstance().reset(); } - @Test(timeout = 20000L) + @Test() + @Timeout(20000L) public void testBasicInitialization_1() throws Exception { val connector = new DummyTransport.Connector(); val rootTransport = new DelayedDummyTransport(rootId, connector); @@ -83,7 +81,8 @@ public class DelayedModelParameterServerTest extends BaseND4JTest { rootServer.shutdown(); } - @Test(timeout = 40000L) + @Test() + @Timeout(40000L) public void testBasicInitialization_2() throws Exception { for (int e = 0; e < 100; e++) { val connector = new DummyTransport.Connector(); @@ -107,17 +106,17 @@ public class DelayedModelParameterServerTest extends BaseND4JTest { val meshA = clientTransportA.getMesh(); val meshB = clientTransportB.getMesh(); - assertEquals("Root node failed",3, meshR.totalNodes()); - assertEquals("B node failed", 3, meshB.totalNodes()); - assertEquals("A node failed", 3, meshA.totalNodes()); + assertEquals(3, meshR.totalNodes(),"Root node failed"); + assertEquals(3, meshB.totalNodes(),"B node failed"); + assertEquals(3, meshA.totalNodes(),"A node failed"); assertEquals(meshR, meshA); assertEquals(meshA, meshB); log.info("Iteration [{}] finished", e); } } - - @Test(timeout = 180000L) + @Test() + @Timeout(180000L) public void testUpdatesPropagation_1() throws Exception { val conf = VoidConfiguration.builder().meshBuildMode(MeshBuildMode.PLAIN).build(); val array = Nd4j.ones(10, 10); @@ -171,11 +170,12 @@ public class DelayedModelParameterServerTest extends BaseND4JTest { for (int e = 0; e < servers.size(); e++) { val s = servers.get(e); - assertEquals("Failed at node [" + e + "]", 1, s.getUpdates().size()); + assertEquals(1, s.getUpdates().size(),"Failed at node [" + e + "]"); } } - @Test(timeout = 180000L) + @Test() + @Timeout(180000L) public void testModelAndUpdaterParamsUpdate_1() throws Exception { val config = VoidConfiguration.builder().meshBuildMode(MeshBuildMode.PLAIN).build(); val connector = new DummyTransport.Connector(); @@ -300,7 +300,7 @@ public class DelayedModelParameterServerTest extends BaseND4JTest { // we're skipping node 23 since it was reconnected, and has different MPS instance // and node 96, since it sends update if (e != 23 && e != 96) - assertEquals("Failed at node: [" + e + "]", 1, counters[e].get()); + assertEquals(1, counters[e].get(),"Failed at node: [" + e + "]"); } assertTrue(updatedModel.get()); @@ -308,7 +308,8 @@ public class DelayedModelParameterServerTest extends BaseND4JTest { assertTrue(gotGradients.get()); } - @Test(timeout = 180000L) + @Test() + @Timeout(180000L) public void testMeshConsistency_1() throws Exception { Nd4j.create(1); final int numMessages = 500; @@ -383,12 +384,13 @@ public class DelayedModelParameterServerTest extends BaseND4JTest { // now we're checking all nodes, they should get numMessages - messages that were sent through them for (int e = 0; e < servers.size(); e++) { val server = servers.get(e); - assertEquals("Failed at node: [" + e + "]", numMessages - deductions[e], counters[e].get()); + assertEquals(numMessages - deductions[e], counters[e].get(),"Failed at node: [" + e + "]"); } } - @Test(timeout = 180000L) + @Test() + @Timeout(180000L) public void testMeshConsistency_2() throws Exception { Nd4j.create(1); final int numMessages = 100; @@ -468,7 +470,7 @@ public class DelayedModelParameterServerTest extends BaseND4JTest { // now we're checking all nodes, they should get numMessages - messages that were sent through them for (int e = 0; e < servers.size(); e++) { val server = servers.get(e); - assertEquals("Failed at node: [" + e + "]", numMessages - deductions[e], counters[e].get()); + assertEquals( numMessages - deductions[e], counters[e].get(),"Failed at node: [" + e + "]"); } } } diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/ModelParameterServerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/ModelParameterServerTest.java index 1dc399684..413830d6d 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/ModelParameterServerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/ModelParameterServerTest.java @@ -23,7 +23,8 @@ package org.nd4j.parameterserver.distributed.v2; import io.reactivex.functions.Consumer; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -44,13 +45,14 @@ import java.util.ArrayList; import java.util.concurrent.LinkedTransferQueue; import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class ModelParameterServerTest extends BaseND4JTest { private static final String rootId = "ROOT_NODE"; - @Test(timeout = 20000L) + @Test() + @Timeout(20000L) public void testBasicInitialization_1() throws Exception { val connector = new DummyTransport.Connector(); val rootTransport = new DummyTransport(rootId, connector); @@ -65,7 +67,8 @@ public class ModelParameterServerTest extends BaseND4JTest { rootServer.shutdown(); } - @Test(timeout = 20000L) + @Test() + @Timeout(20000L) public void testBasicInitialization_2() throws Exception { val connector = new DummyTransport.Connector(); val rootTransport = new DummyTransport(rootId, connector); @@ -436,7 +439,7 @@ public class ModelParameterServerTest extends BaseND4JTest { failedCnt++; } - assertEquals("Some nodes got no updates:", 0, failedCnt); + assertEquals(0, failedCnt,"Some nodes got no updates:"); assertTrue(updatedModel.get()); assertTrue(gotGradients.get()); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTrackerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTrackerTest.java index 245043fa5..32a2cca99 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTrackerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTrackerTest.java @@ -22,8 +22,8 @@ package org.nd4j.parameterserver.distributed.v2.chunks.impl; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.parameterserver.distributed.v2.chunks.VoidChunk; @@ -32,10 +32,10 @@ import org.nd4j.parameterserver.distributed.v2.util.MessageSplitter; import java.util.ArrayList; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j -@Ignore +@Disabled public class FileChunksTrackerTest extends BaseND4JTest { @Test public void testTracker_1() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/InmemoryChunksTrackerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/InmemoryChunksTrackerTest.java index b152f00eb..ad2de54ba 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/InmemoryChunksTrackerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/InmemoryChunksTrackerTest.java @@ -21,8 +21,8 @@ package org.nd4j.parameterserver.distributed.v2.chunks.impl; import lombok.val; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.parameterserver.distributed.v2.chunks.VoidChunk; @@ -31,11 +31,11 @@ import org.nd4j.parameterserver.distributed.v2.util.MessageSplitter; import java.util.ArrayList; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class InmemoryChunksTrackerTest extends BaseND4JTest { @Test - @Ignore + @Disabled public void testTracker_1() throws Exception { val array = Nd4j.linspace(1, 100000, 10000).reshape(-1, 1000); val splitter = MessageSplitter.getInstance(); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/VoidMessageTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/VoidMessageTest.java index ce9ec9ff6..fc7b5d496 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/VoidMessageTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/VoidMessageTest.java @@ -22,12 +22,12 @@ package org.nd4j.parameterserver.distributed.v2.messages; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.util.SerializationUtils; import org.nd4j.parameterserver.distributed.v2.messages.pairs.handshake.HandshakeRequest; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class VoidMessageTest extends BaseND4JTest { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/history/HashHistoryHolderTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/history/HashHistoryHolderTest.java index 602c1361c..b9b93e299 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/history/HashHistoryHolderTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/history/HashHistoryHolderTest.java @@ -22,10 +22,10 @@ package org.nd4j.parameterserver.distributed.v2.messages.history; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class HashHistoryHolderTest extends BaseND4JTest { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransportTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransportTest.java index 4b2d69dcd..1aa250d87 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransportTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransportTest.java @@ -22,12 +22,12 @@ package org.nd4j.parameterserver.distributed.v2.transport.impl; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class AeronUdpTransportTest extends BaseND4JTest { @@ -40,7 +40,7 @@ public class AeronUdpTransportTest extends BaseND4JTest { } @Test - @Ignore + @Disabled public void testBasic_Connection_1() throws Exception { // we definitely want to shutdown all transports after test, to avoid issues with shmem try(val transportA = new AeronUdpTransport(IP, ROOT_PORT, IP, ROOT_PORT, VoidConfiguration.builder().build()); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/DummyTransportTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/DummyTransportTest.java index 45a6a5f04..4fed8645c 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/DummyTransportTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/DummyTransportTest.java @@ -22,7 +22,7 @@ package org.nd4j.parameterserver.distributed.v2.transport.impl; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.parameterserver.distributed.v2.enums.PropagationMode; @@ -34,7 +34,7 @@ import org.nd4j.parameterserver.distributed.v2.transport.MessageCallable; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class DummyTransportTest extends BaseND4JTest { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizerTest.java index 1542aca78..33703356d 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizerTest.java @@ -22,7 +22,8 @@ package org.nd4j.parameterserver.distributed.v2.util; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.util.SerializationUtils; import org.nd4j.parameterserver.distributed.v2.enums.MeshBuildMode; @@ -31,12 +32,13 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.util.ArrayList; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class MeshOrganizerTest extends BaseND4JTest { - @Test(timeout = 1000L) + @Test() + @Timeout(1000L) public void testDescendantsCount_1() { val node = MeshOrganizer.Node.builder().build(); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MessageSplitterTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MessageSplitterTest.java index aa58a8cf2..17e9dd7aa 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MessageSplitterTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MessageSplitterTest.java @@ -22,7 +22,7 @@ package org.nd4j.parameterserver.distributed.v2.util; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.primitives.Atomic; @@ -31,7 +31,7 @@ import org.nd4j.parameterserver.distributed.v2.messages.impl.GradientsUpdateMess import java.util.ArrayList; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class MessageSplitterTest extends BaseND4JTest { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java index d5aed6eae..b16e10d31 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java @@ -24,8 +24,9 @@ import io.aeron.Aeron; import io.aeron.driver.MediaDriver; import lombok.extern.slf4j.Slf4j; import org.junit.BeforeClass; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.aeron.ipc.AeronUtil; import org.nd4j.aeron.ipc.NDArrayMessage; @@ -35,10 +36,10 @@ import org.nd4j.parameterserver.client.ParameterServerClient; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @Slf4j -@Ignore +@Disabled @Deprecated public class ParameterServerNodeTest extends BaseND4JTest { private static MediaDriver mediaDriver; @@ -48,16 +49,16 @@ public class ParameterServerNodeTest extends BaseND4JTest { private static int masterStatusPort = 40323 + new java.util.Random().nextInt(15999); private static int statusPort = masterStatusPort - 1299; - @BeforeClass + @BeforeAll public static void before() throws Exception { mediaDriver = MediaDriver.launchEmbedded(AeronUtil.getMediaDriverContext(parameterLength)); System.setProperty("play.server.dir", "/tmp"); aeron = Aeron.connect(getContext()); parameterServerNode = new ParameterServerNode(mediaDriver, statusPort); parameterServerNode.runMain(new String[] {"-m", "true", "-s", "1," + String.valueOf(parameterLength), "-p", - String.valueOf(masterStatusPort), "-h", "localhost", "-id", "11", "-md", - mediaDriver.aeronDirectoryName(), "-sp", String.valueOf(statusPort), "-sh", "localhost", "-u", - String.valueOf(Runtime.getRuntime().availableProcessors())}); + String.valueOf(masterStatusPort), "-h", "localhost", "-id", "11", "-md", + mediaDriver.aeronDirectoryName(), "-sp", String.valueOf(statusPort), "-sh", "localhost", "-u", + String.valueOf(Runtime.getRuntime().availableProcessors())}); while (!parameterServerNode.subscriberLaunched()) { Thread.sleep(10000); @@ -73,11 +74,11 @@ public class ParameterServerNodeTest extends BaseND4JTest { String host = "localhost"; for (int i = 0; i < numCores; i++) { clients[i] = ParameterServerClient.builder().aeron(aeron).masterStatusHost(host) - .masterStatusPort(statusPort).subscriberHost(host).subscriberPort(40325 + i) - .subscriberStream(10 + i) - .ndarrayRetrieveUrl(parameterServerNode.getSubscriber()[i].getResponder().connectionUrl()) - .ndarraySendUrl(parameterServerNode.getSubscriber()[i].getSubscriber().connectionUrl()) - .build(); + .masterStatusPort(statusPort).subscriberHost(host).subscriberPort(40325 + i) + .subscriberStream(10 + i) + .ndarrayRetrieveUrl(parameterServerNode.getSubscriber()[i].getResponder().connectionUrl()) + .ndarraySendUrl(parameterServerNode.getSubscriber()[i].getSubscriber().connectionUrl()) + .build(); } Thread.sleep(60000); @@ -119,10 +120,10 @@ public class ParameterServerNodeTest extends BaseND4JTest { private static Aeron.Context getContext() { return new Aeron.Context().driverTimeoutMs(10000) - .availableImageHandler(AeronUtil::printAvailableImage) - .unavailableImageHandler(AeronUtil::printUnavailableImage) - .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(100000) - .errorHandler(e -> log.error(e.toString(), e)); + .availableImageHandler(AeronUtil::printAvailableImage) + .unavailableImageHandler(AeronUtil::printUnavailableImage) + .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(100000) + .errorHandler(e -> log.error(e.toString(), e)); } diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java index 9005be0e6..4e3c688d8 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java @@ -20,7 +20,8 @@ package org.nd4j.parameterserver.updater.storage; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.linalg.factory.Nd4j; @@ -29,7 +30,8 @@ import static junit.framework.TestCase.assertEquals; public class UpdaterStorageTests extends BaseND4JTest { - @Test(timeout = 30000L) + @Test() + @Timeout(30000L) public void testInMemory() { UpdateStorage updateStorage = new RocksDbStorage("/tmp/rocksdb"); NDArrayMessage message = NDArrayMessage.wholeArrayUpdate(Nd4j.scalar(1.0)); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StatusServerTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StatusServerTests.java index 419a1a871..4d65842c0 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StatusServerTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StatusServerTests.java @@ -20,13 +20,15 @@ package org.nd4j.parameterserver.status.play; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.tests.BaseND4JTest; import play.server.Server; public class StatusServerTests extends BaseND4JTest { - @Test(timeout = 20000L) + @Test() + @Timeout(20000L) public void runStatusServer() { Server server = StatusServer.startServer(new InMemoryStatusStorage(), 65236); server.stop(); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StorageTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StorageTests.java index f0ed8a206..7d0ac67c8 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StorageTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StorageTests.java @@ -20,16 +20,18 @@ package org.nd4j.parameterserver.status.play; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.parameterserver.model.SubscriberState; import static junit.framework.TestCase.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; public class StorageTests extends BaseND4JTest { - @Test(timeout = 20000L) + @Test() + @Timeout(20000L) public void testMapStorage() throws Exception { StatusStorage mapDb = new MapDbStatusStorage(); assertEquals(SubscriberState.empty(), mapDb.getState(-1)); @@ -44,7 +46,8 @@ public class StorageTests extends BaseND4JTest { } - @Test(timeout = 20000L) + @Test() + @Timeout(20000L) public void testStorage() throws Exception { StatusStorage statusStorage = new InMemoryStatusStorage(); assertEquals(SubscriberState.empty(), statusStorage.getState(-1)); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/ParameterServerUpdaterTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/ParameterServerUpdaterTests.java index 1d60f9c20..37b995b16 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/ParameterServerUpdaterTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/ParameterServerUpdaterTests.java @@ -20,20 +20,22 @@ package org.nd4j.parameterserver.updater; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.aeron.ndarrayholder.InMemoryNDArrayHolder; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.parameterserver.updater.storage.NoUpdateStorage; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.Assume.assumeNotNull; public class ParameterServerUpdaterTests extends BaseND4JTest { - @Test(timeout = 30000L) + @Test() + @Timeout(30000L) public void synchronousTest() { int cores = Runtime.getRuntime().availableProcessors(); ParameterServerUpdater updater = new SynchronousParameterUpdater(new NoUpdateStorage(), diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java index df12a9b41..1efbc3e09 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java @@ -20,27 +20,33 @@ package org.nd4j.parameterserver.updater.storage; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.linalg.factory.Nd4j; import static junit.framework.TestCase.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; public class UpdaterStorageTests extends BaseND4JTest { - @Test(expected = UnsupportedOperationException.class) + @Test() public void testNone() { - UpdateStorage updateStorage = new NoUpdateStorage(); - NDArrayMessage message = NDArrayMessage.wholeArrayUpdate(Nd4j.scalar(1.0)); - updateStorage.addUpdate(message); - assertEquals(1, updateStorage.numUpdates()); - assertEquals(message, updateStorage.getUpdate(0)); - updateStorage.close(); + assertThrows(UnsupportedOperationException.class,() -> { + UpdateStorage updateStorage = new NoUpdateStorage(); + NDArrayMessage message = NDArrayMessage.wholeArrayUpdate(Nd4j.scalar(1.0)); + updateStorage.addUpdate(message); + assertEquals(1, updateStorage.numUpdates()); + assertEquals(message, updateStorage.getUpdate(0)); + updateStorage.close(); + }); + } - @Test(timeout = 30000L) + @Test() + @Timeout(30000L) public void testInMemory() { UpdateStorage updateStorage = new InMemoryUpdateStorage(); NDArrayMessage message = NDArrayMessage.wholeArrayUpdate(Nd4j.scalar(1.0)); diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java index d02bfce2a..0fc74a500 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java @@ -22,8 +22,8 @@ package org.nd4j.aeron.ipc; import org.agrona.concurrent.UnsafeBuffer; import org.apache.commons.lang3.time.StopWatch; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -33,10 +33,10 @@ import java.io.BufferedOutputStream; import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @NotThreadSafe -@Ignore("Tests are too flaky") +@Disabled("Tests are too flaky") public class AeronNDArraySerdeTest extends BaseND4JTest { @@ -62,7 +62,7 @@ public class AeronNDArraySerdeTest extends BaseND4JTest { @Test - @Ignore // timeout, skip step ignored + @Disabled // timeout, skip step ignored public void testToAndFromCompressedLarge() { skipUnlessIntegrationTests(); diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java index 71c25662c..3862da096 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java @@ -24,10 +24,10 @@ import io.aeron.Aeron; import io.aeron.driver.MediaDriver; import lombok.extern.slf4j.Slf4j; import org.agrona.CloseHelper; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -35,11 +35,11 @@ import org.nd4j.linalg.factory.Nd4j; import javax.annotation.concurrent.NotThreadSafe; import java.util.concurrent.atomic.AtomicBoolean; -import static org.junit.Assert.assertFalse; +import static org.junit.jupiter.api.Assertions.assertFalse; @Slf4j @NotThreadSafe -@Ignore("Tests are too flaky") +@Disabled("Tests are too flaky") public class LargeNdArrayIpcTest extends BaseND4JTest { private MediaDriver mediaDriver; private Aeron.Context ctx; @@ -52,7 +52,7 @@ public class LargeNdArrayIpcTest extends BaseND4JTest { return 180000L; } - @Before + @BeforeEach public void before() { if(isIntegrationTests()) { //MediaDriver.loadPropertiesFile("aeron.properties"); @@ -63,7 +63,7 @@ public class LargeNdArrayIpcTest extends BaseND4JTest { } } - @After + @AfterEach public void after() { if(isIntegrationTests()) { CloseHelper.quietClose(mediaDriver); @@ -71,7 +71,7 @@ public class LargeNdArrayIpcTest extends BaseND4JTest { } @Test - @Ignore + @Disabled public void testMultiThreadedIpcBig() throws Exception { skipUnlessIntegrationTests(); //Long-running test - don't run as part of unit tests by default diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NDArrayMessageTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NDArrayMessageTest.java index fc0d76cc0..1bab3cc6b 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NDArrayMessageTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NDArrayMessageTest.java @@ -21,18 +21,18 @@ package org.nd4j.aeron.ipc; import org.agrona.DirectBuffer; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import javax.annotation.concurrent.NotThreadSafe; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @NotThreadSafe -@Ignore("Tests are too flaky") +@Disabled("Tests are too flaky") public class NDArrayMessageTest extends BaseND4JTest { diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java index cf1631f00..20620b32c 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java @@ -23,10 +23,10 @@ package org.nd4j.aeron.ipc; import io.aeron.Aeron; import io.aeron.driver.MediaDriver; import org.agrona.CloseHelper; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -38,10 +38,10 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; -import static org.junit.Assert.assertFalse; +import static org.junit.jupiter.api.Assertions.assertFalse; @NotThreadSafe -@Ignore("Tests are too flaky") +@Disabled("Tests are too flaky") public class NdArrayIpcTest extends BaseND4JTest { private MediaDriver mediaDriver; @@ -56,7 +56,7 @@ public class NdArrayIpcTest extends BaseND4JTest { return 120000L; } - @Before + @BeforeEach public void before() { if(isIntegrationTests()) { MediaDriver.Context ctx = AeronUtil.getMediaDriverContext(length); @@ -66,7 +66,7 @@ public class NdArrayIpcTest extends BaseND4JTest { } } - @After + @AfterEach public void after() { if(isIntegrationTests()) { CloseHelper.quietClose(mediaDriver); diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/ChunkAccumulatorTests.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/ChunkAccumulatorTests.java index 3c2114f9e..fe273a729 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/ChunkAccumulatorTests.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/ChunkAccumulatorTests.java @@ -20,18 +20,18 @@ package org.nd4j.aeron.ipc.chunk; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.linalg.factory.Nd4j; import javax.annotation.concurrent.NotThreadSafe; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @NotThreadSafe -@Ignore("Tests are too flaky") +@Disabled("Tests are too flaky") public class ChunkAccumulatorTests extends BaseND4JTest { @Test diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/NDArrayMessageChunkTests.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/NDArrayMessageChunkTests.java index 54e2f5773..daef2beb1 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/NDArrayMessageChunkTests.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/NDArrayMessageChunkTests.java @@ -21,8 +21,8 @@ package org.nd4j.aeron.ipc.chunk; import org.agrona.DirectBuffer; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.aeron.util.BufferUtil; @@ -31,11 +31,11 @@ import org.nd4j.linalg.factory.Nd4j; import javax.annotation.concurrent.NotThreadSafe; import java.nio.ByteBuffer; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @NotThreadSafe -@Ignore("Tests are too flaky") +@Disabled("Tests are too flaky") public class NDArrayMessageChunkTests extends BaseND4JTest { @Test diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponseTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponseTest.java index 9c279544e..c98fcc01f 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponseTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponseTest.java @@ -26,9 +26,9 @@ import io.aeron.driver.ThreadingMode; import lombok.extern.slf4j.Slf4j; import org.agrona.CloseHelper; import org.agrona.concurrent.BusySpinIdleStrategy; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.aeron.ipc.*; import org.nd4j.linalg.api.ndarray.INDArray; @@ -38,11 +38,11 @@ import javax.annotation.concurrent.NotThreadSafe; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j @NotThreadSafe -@Ignore("Tests are too flaky") +@Disabled("Tests are too flaky") public class AeronNDArrayResponseTest extends BaseND4JTest { private MediaDriver mediaDriver; @@ -51,7 +51,7 @@ public class AeronNDArrayResponseTest extends BaseND4JTest { return 180000L; } - @Before + @BeforeEach public void before() { if(isIntegrationTests()) { final MediaDriver.Context ctx = diff --git a/nd4j/nd4j-serde/nd4j-arrow/src/test/java/org/nd4j/arrow/ArrowSerdeTest.java b/nd4j/nd4j-serde/nd4j-arrow/src/test/java/org/nd4j/arrow/ArrowSerdeTest.java index ae794fa65..e8d4362aa 100644 --- a/nd4j/nd4j-serde/nd4j-arrow/src/test/java/org/nd4j/arrow/ArrowSerdeTest.java +++ b/nd4j/nd4j-serde/nd4j-arrow/src/test/java/org/nd4j/arrow/ArrowSerdeTest.java @@ -21,12 +21,12 @@ package org.nd4j.arrow; import org.apache.arrow.flatbuf.Tensor; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class ArrowSerdeTest extends BaseND4JTest { diff --git a/nd4j/nd4j-serde/nd4j-kryo/src/test/java/org/nd4j/TestNd4jKryoSerialization.java b/nd4j/nd4j-serde/nd4j-kryo/src/test/java/org/nd4j/TestNd4jKryoSerialization.java index 457b462ea..195f1beee 100644 --- a/nd4j/nd4j-serde/nd4j-kryo/src/test/java/org/nd4j/TestNd4jKryoSerialization.java +++ b/nd4j/nd4j-serde/nd4j-kryo/src/test/java/org/nd4j/TestNd4jKryoSerialization.java @@ -27,10 +27,10 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.serializer.SerializerInstance; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.common.primitives.*; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataType; @@ -42,14 +42,14 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -@Ignore("Ignoring due to flaky nature of tests") +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +@Disabled("Ignoring due to flaky nature of tests") public class TestNd4jKryoSerialization extends BaseND4JTest { private JavaSparkContext sc; - @Before + @BeforeEach public void before() { SparkConf sparkConf = new SparkConf(); sparkConf.setMaster("local[*]"); @@ -117,11 +117,11 @@ public class TestNd4jKryoSerialization extends BaseND4JTest { // assertEquals(in, deserialized); boolean equals = in.equals(deserialized); - assertTrue(in.getClass() + "\t" + in.toString(), equals); + assertTrue(equals,in.getClass() + "\t" + in.toString()); } - @After + @AfterEach public void after() { if (sc != null) sc.close(); diff --git a/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/TestOnnxIR.kt b/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/TestOnnxIR.kt index a75c39237..6ce83d1b6 100644 --- a/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/TestOnnxIR.kt +++ b/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/TestOnnxIR.kt @@ -20,10 +20,10 @@ package org.nd4j.samediff.frameworkimport.onnx -import junit.framework.Assert -import junit.framework.Assert.* + import onnx.Onnx -import org.junit.Ignore +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Test import org.nd4j.ir.OpNamespace import org.nd4j.linalg.api.buffer.DataType @@ -102,7 +102,7 @@ class TestOnnxIR { @Test - @Ignore + @Disabled fun testOpsMapped() { val onnxOpRegistry = registry() @@ -164,11 +164,11 @@ class TestOnnxIR { onnxOpDef.inputList.forEach { inputName -> - Assert.assertTrue(onnxAssertionNames.contains(inputName)) + assertTrue(onnxAssertionNames.contains(inputName)) } onnxOpDef.attributeList.map { attrDef -> attrDef.name }.forEach { attrName -> - Assert.assertTrue(onnxAssertionNames.contains(attrName)) + assertTrue(onnxAssertionNames.contains(attrName)) } @@ -184,19 +184,19 @@ class TestOnnxIR { * we just log a warning for unmapped inputs. Otherwise we can do an assertion. */ if(numRequiredInputs == nd4jInputs) - assertTrue("Nd4j op name ${opDef.name} with onnx mapping ${onnxOpDef.name} has missing mapping ${argDef.name}", nd4jNamesMapped.contains(argDef.name)) + assertTrue(nd4jNamesMapped.contains(argDef.name),"Nd4j op name ${opDef.name} with onnx mapping ${onnxOpDef.name} has missing mapping ${argDef.name}") else if(!nd4jNamesMapped.contains(argDef.name)) { println("Warning: Nd4j op name ${opDef.name} with onnx mapping ${onnxOpDef.name} has missing mapping ${argDef.name}") } } OpNamespace.ArgDescriptor.ArgType.INT32,OpNamespace.ArgDescriptor.ArgType.INT64 -> { - assertTrue("Nd4j op name ${opDef.name} with onnx mapping ${onnxOpDef.name} has missing mapping ${argDef.name}", nd4jNamesMapped.contains(argDef.name)) + assertTrue(nd4jNamesMapped.contains(argDef.name),"Nd4j op name ${opDef.name} with onnx mapping ${onnxOpDef.name} has missing mapping ${argDef.name}") } OpNamespace.ArgDescriptor.ArgType.DOUBLE, OpNamespace.ArgDescriptor.ArgType.FLOAT -> { - assertTrue("Nd4j op name ${opDef.name} with onnx mapping ${onnxOpDef.name} has missing mapping ${argDef.name}", nd4jNamesMapped.contains(argDef.name)) + assertTrue(nd4jNamesMapped.contains(argDef.name),"Nd4j op name ${opDef.name} with onnx mapping ${onnxOpDef.name} has missing mapping ${argDef.name}") } OpNamespace.ArgDescriptor.ArgType.BOOL -> { - assertTrue("Nd4j op name ${opDef.name} with onnx mapping ${onnxOpDef.name} has missing mapping ${argDef.name}", nd4jNamesMapped.contains(argDef.name)) + assertTrue(nd4jNamesMapped.contains(argDef.name),"Nd4j op name ${opDef.name} with onnx mapping ${onnxOpDef.name} has missing mapping ${argDef.name}") } } @@ -343,7 +343,7 @@ class TestOnnxIR { val inputs = mapOf("input" to input) val assertion = onnxGraphRunner.run(inputs) val result = importedGraph.output(inputs,"output") - assertEquals("Function ${nd4jOpDef.name} failed with input $input",assertion["output"]!!.reshape(1,1),result["output"]!!.reshape(1,1)) + assertEquals(assertion["output"]!!.reshape(1,1),result["output"]!!.reshape(1,1),"Function ${nd4jOpDef.name} failed with input $input") finishedOps.add(nd4jOpDef.name) } else if(scalarFloatOps.containsKey(nd4jOpDef.name)) { @@ -371,7 +371,7 @@ class TestOnnxIR { val inputs = mapOf("input" to input) val assertion = onnxGraphRunner.run(inputs) val result = importedGraph.output(inputs,"output") - assertEquals("Function ${nd4jOpDef.name} failed with input $input",assertion["output"]!!.reshape(1,1),result["output"]!!.reshape(1,1)) + assertEquals(assertion["output"]!!.reshape(1,1),result["output"]!!.reshape(1,1),"Function ${nd4jOpDef.name} failed with input $input") finishedOps.add(nd4jOpDef.name) } @@ -403,7 +403,7 @@ class TestOnnxIR { val inputs = mapOf("input" to input) val assertion = onnxGraphRunner.run(inputs) val result = importedGraph.output(inputs,"output") - assertEquals("Function ${nd4jOpDef.name} failed with input $input",assertion["output"]!!.reshape(1,1),result["output"]!!.reshape(1,1)) + assertEquals(assertion["output"]!!.reshape(1,1),result["output"]!!.reshape(1,1),"Function ${nd4jOpDef.name} failed with input $input") finishedOps.add(nd4jOpDef.name) } @@ -438,7 +438,7 @@ class TestOnnxIR { val inputs = mapOf("x" to x,"y" to y) val result = importedGraph.output(inputs,"output") val assertion = onnxGraphRunner.run(inputs) - assertEquals("Function ${nd4jOpDef.name} failed with input $x $y",assertion["output"]!!.getDouble(0),result["output"]!!.getDouble(0)) + assertEquals(assertion["output"]!!.getDouble(0),result["output"]!!.getDouble(0),"Function ${nd4jOpDef.name} failed with input $x $y") finishedOps.add(nd4jOpDef.name) } else if(pairWiseBooleanInputs.containsKey(nd4jOpDef.name)) { @@ -469,7 +469,7 @@ class TestOnnxIR { val inputs = mapOf("x" to x,"y" to y) val assertion = onnxGraphRunner.run(inputs) val result = importedGraph.output(inputs,"output") - assertEquals("Function ${nd4jOpDef.name} failed with input $x $y",assertion["output"]!!.getDouble(0),result["output"]!!.getDouble(0)) + assertEquals(assertion["output"]!!.getDouble(0),result["output"]!!.getDouble(0),"Function ${nd4jOpDef.name} failed with input $x $y") finishedOps.add(nd4jOpDef.name) } else if(pairWiseBooleanOps.containsKey(nd4jOpDef.name)) { @@ -502,7 +502,7 @@ class TestOnnxIR { val inputs = mapOf("x" to x,"y" to y) val assertion = onnxGraphRunner.run(inputs) val result = importedGraph.output(inputs,"output") - assertEquals("Function ${nd4jOpDef.name} failed with input $x $y",assertion["output"]!!.getDouble(0),result["output"]!!.getDouble(0)) + assertEquals(assertion["output"]!!.getDouble(0),result["output"]!!.getDouble(0),"Function ${nd4jOpDef.name} failed with input $x $y") finishedOps.add(nd4jOpDef.name) } @@ -573,7 +573,7 @@ class TestOnnxIR { val result = importedGraph.output(inputs,"output") val onnxGraphRunner = OnnxIRGraphRunner(onnxIRGraph,listOf("x"),listOf("output")) val assertion = onnxGraphRunner.run(inputs) - assertEquals("Function ${nd4jOpDef.name} failed with input $x",assertion["output"]!!.reshape(1,2),result["output"]!!.reshape(1,2)) + assertEquals(assertion["output"]!!.reshape(1,2),result["output"]!!.reshape(1,2),"Function ${nd4jOpDef.name} failed with input $x") finishedOps.add(nd4jOpDef.name) } else if(mappedOps.contains(nd4jOpDef.name)){ @@ -592,10 +592,10 @@ class TestOnnxIR { assertEquals(assertion.keys,result.keys) result.forEach { name,arr -> if(arr.length().toInt() == 1) { - assertEquals("Function ${nd4jOpDef.name} failed with input ${graph.inputNames}",assertion[name]!!.getDouble(0),arr.getDouble(0),1e-3) + assertEquals(assertion[name]!!.getDouble(0),arr.getDouble(0),1e-3,"Function ${nd4jOpDef.name} failed with input ${graph.inputNames}") } else { - assertEquals("Function ${nd4jOpDef.name} failed with input ${graph.inputNames}",assertion[name],arr) + assertEquals(assertion[name],arr,"Function ${nd4jOpDef.name} failed with input ${graph.inputNames}") } } @@ -630,9 +630,9 @@ class TestOnnxIR { val assertion = onnxGraphRunner.run(inputs) val result = importedGraph.output(inputs,"output") if(assertion["output"]!!.length() == 1L) - assertEquals("Function ${nd4jOpDef.name} failed with input $input",assertion["output"]!!.reshape(1,1),result["output"]!!.reshape(1,1)) + assertEquals(assertion["output"]!!.reshape(1,1),result["output"]!!.reshape(1,1),"Function ${nd4jOpDef.name} failed with input $input") else - assertEquals("Function ${nd4jOpDef.name} failed with input $input",assertion["output"]!!.ravel(),result["output"]!!.ravel()) + assertEquals(assertion["output"]!!.ravel(),result["output"]!!.ravel(),"Function ${nd4jOpDef.name} failed with input $input") finishedOps.add(nd4jOpDef.name) } diff --git a/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/importer/TestOnnxFrameworkImporter.kt b/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/importer/TestOnnxFrameworkImporter.kt index 26282fadd..4c6ea41f0 100644 --- a/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/importer/TestOnnxFrameworkImporter.kt +++ b/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/importer/TestOnnxFrameworkImporter.kt @@ -20,13 +20,13 @@ package org.nd4j.samediff.frameworkimport.onnx.importer import junit.framework.Assert.assertNotNull -import org.junit.Ignore +import org.junit.jupiter.api.Disabled import org.junit.Test import org.nd4j.common.io.ClassPathResource class TestOnnxFrameworkImporter { @Test - @Ignore + @Disabled fun testOnnxImporter() { val onnxImport = OnnxFrameworkImporter() val onnxFile = ClassPathResource("lenet.onnx").file diff --git a/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/modelzoo/TestPretrainedModels.kt b/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/modelzoo/TestPretrainedModels.kt index ee71246af..0849d4ea2 100644 --- a/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/modelzoo/TestPretrainedModels.kt +++ b/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/modelzoo/TestPretrainedModels.kt @@ -37,7 +37,7 @@ package org.nd4j.samediff.frameworkimport.onnx.modelzoo import onnx.Onnx import org.apache.commons.io.FileUtils -import org.junit.Ignore +import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Test import org.nd4j.common.resources.Downloader import org.nd4j.common.util.ArchiveUtils @@ -50,7 +50,7 @@ import java.io.File import java.net.URI data class InputDataset(val dataSetIndex: Int,val inputPaths: List,val outputPaths: List) -@Ignore +@Disabled class TestPretrainedModels { val modelBaseUrl = "https://media.githubusercontent.com/media/onnx/models/master" @@ -201,7 +201,7 @@ class TestPretrainedModels { @Test - @Ignore + @Disabled fun test() { modelPaths.forEach { pullModel(it) diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TestTensorflowIR.kt b/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TestTensorflowIR.kt index 409b611cb..c554c6505 100644 --- a/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TestTensorflowIR.kt +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TestTensorflowIR.kt @@ -20,12 +20,11 @@ package org.nd4j.samediff.frameworkimport.tensorflow -import junit.framework.Assert.assertEquals -import junit.framework.Assert.assertTrue + import org.apache.commons.io.FileUtils import org.apache.commons.io.IOUtils -import org.junit.Assert -import org.junit.Ignore +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Test import org.nd4j.common.io.ClassPathResource import org.nd4j.imports.graphmapper.tf.TFGraphMapper @@ -50,6 +49,9 @@ import java.nio.charset.StandardCharsets import java.util.* import kotlin.collections.HashMap import kotlin.collections.HashSet +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertTrue + data class GraphInput(val graphDef: GraphDef,val inputNames: List,val outputNames: List, val inputArrays: Map,val dynamicArrays: Map) @@ -68,7 +70,7 @@ class TestTensorflowIR { @Test - @Ignore + @Disabled fun manualTest() { val manualGraph = FileUtils.readFileToString(File("test.pbtxt"),Charset.defaultCharset()) val parsedGraph = GraphDef.newBuilder() @@ -152,7 +154,7 @@ class TestTensorflowIR { @Test - @Ignore + @Disabled fun manualTestBinary() { val path = "C:\\Users\\agibs\\.nd4jtests\\resnetv2_imagenet_frozen_graph\\resnetv2_imagenet_frozen_graph.pb" val bytes = FileUtils.readFileToByteArray(File(path)) @@ -201,7 +203,7 @@ class TestTensorflowIR { //Perform inference val inputs: List = importedGraph.inputs() - Assert.assertEquals(1, inputs.size.toLong()) + assertEquals(1, inputs.size.toLong()) val out = "softmax_tensor" val m: Map = importedGraph.output(Collections.singletonMap(inputs[0], img), out) @@ -222,7 +224,7 @@ class TestTensorflowIR { val prob = outArr!!.getDouble(classIdx.toLong()) println("Predicted class: $classIdx - \"$className\" - probability = $prob") - Assert.assertEquals(expClass, className) + assertEquals(expClass, className) val inputMap = Collections.singletonMap(inputs[0], img) val tensorflowIRGraph = TensorflowIRGraph(parsedGraph,tensorflowOps,tfImporter.registry) @@ -295,7 +297,7 @@ class TestTensorflowIR { @Test - @Ignore + @Disabled fun manualTest2() { val manualGraph = FileUtils.readFileToString(File("test.pbtxt"),Charset.defaultCharset()) val parsedGraph = GraphDef.newBuilder() @@ -338,7 +340,7 @@ class TestTensorflowIR { @Test - @Ignore + @Disabled fun testTensorflowMappingContext() { val tensorflowOpRegistry = registry() @@ -423,7 +425,6 @@ class TestTensorflowIR { @Test - @Ignore @org.junit.jupiter.api.Disabled fun testOpExecution() { Nd4j.getRandom().setSeed(12345) @@ -778,7 +779,7 @@ class TestTensorflowIR { assertTrue(tfOutput.isScalar) val nd4jOutput = results["output"]!! assertTrue(nd4jOutput.isScalar) - assertEquals("Function ${nd4jOpDef.name} failed with input $xVal",nd4jOutput.getDouble(0), tfOutput.getDouble(0),1e-3) + assertEquals(nd4jOutput.getDouble(0), tfOutput.getDouble(0),1e-3,"Function ${nd4jOpDef.name} failed with input $xVal") testedOps.add(nd4jOpDef.name) } else if(singularReduceNames.contains(nd4jOpDef.name)) { @@ -841,9 +842,9 @@ class TestTensorflowIR { val tfResults = tensorflowRunner.run(inputs) //2 dimensions means sum the whole array, sometimes there are subtle differences in the shape like 1,1 vs a zero length array which is effectively the same thing if(dimensions.size < 2) - assertEquals("Function ${nd4jOpDef.name} failed with input $xVal and dimension ${dimensions}",tfResults["output"]!!, results["output"]!!) + assertEquals(tfResults["output"]!!, results["output"]!!,"Function ${nd4jOpDef.name} failed with input $xVal and dimension ${dimensions}") else - assertEquals("Function ${nd4jOpDef.name} failed with input $xVal and dimension ${dimensions}",tfResults["output"]!!.reshape(1,1), results["output"]!!.reshape(1,1)) + assertEquals(tfResults["output"]!!.reshape(1,1), results["output"]!!.reshape(1,1),"Function ${nd4jOpDef.name} failed with input $xVal and dimension ${dimensions}") } @@ -909,9 +910,9 @@ class TestTensorflowIR { val tfResults = tensorflowRunner.run(inputs) //2 dimensions means sum the whole array, sometimes there are subtle differences in the shape like 1,1 vs a zero length array which is effectively the same thing if(dimensions.size < 2) - assertEquals("Function ${nd4jOpDef.name} failed with input $xVal and dimension ${dimensions}",tfResults["output"]!!, results["output"]!!) + assertEquals(tfResults["output"]!!, results["output"]!!,"Function ${nd4jOpDef.name} failed with input $xVal and dimension ${dimensions}") else - assertEquals("Function ${nd4jOpDef.name} failed with input $xVal and dimension ${dimensions}",tfResults["output"]!!.reshape(1,1), results["output"]!!.reshape(1,1)) + assertEquals(tfResults["output"]!!.reshape(1,1), results["output"]!!.reshape(1,1),"Function ${nd4jOpDef.name} failed with input $xVal and dimension ${dimensions}") } @@ -972,7 +973,7 @@ class TestTensorflowIR { val inputs = mapOf("x" to xVal,"y" to yVal) val results = mappedGraph.output(inputs,"output") val tfResults = tensorflowRunner.run(inputs) - assertEquals("Function ${nd4jOpDef.name} failed with input $xVal",tfResults["output"]!!.reshape(1,1), results["output"]!!.reshape(1,1)) + assertEquals(tfResults["output"]!!.reshape(1,1), results["output"]!!.reshape(1,1),"Function ${nd4jOpDef.name} failed with input $xVal") testedOps.add(nd4jOpDef.name) } else if(pairWiseIntOps.contains(nd4jOpDef.name)) { @@ -1023,7 +1024,7 @@ class TestTensorflowIR { val inputs = mapOf("x" to xVal,"y" to yVal) val results = mappedGraph.output(inputs,"output") val tfResults = tensorflowRunner.run(inputs) - assertEquals("Function ${nd4jOpDef.name} failed with input $xVal",tfResults["output"]!!.reshape(1,1), results["output"]!!.reshape(1,1)) + assertEquals(tfResults["output"]!!.reshape(1,1), results["output"]!!.reshape(1,1),"Function ${nd4jOpDef.name} failed with input $xVal") testedOps.add(nd4jOpDef.name) } else if(mappedOps.contains(mappingProcess.opName())) { @@ -1047,7 +1048,7 @@ class TestTensorflowIR { val mappedGraph = importGraph.importGraph(tensorflowGraph,null,null,dynamicOpsMap,OpRegistryHolder.tensorflow()) - assertEquals("Input name mismatch with input array elements",graphInput.inputArrays.keys,graphInput.inputNames.toSet()) + assertEquals(graphInput.inputArrays.keys,graphInput.inputNames.toSet(),"Input name mismatch with input array elements") val tfResults = tensorflowRunner.run(graphInput.inputArrays) val results = mappedGraph!!.output(graphInput.inputArrays,graphInput.outputNames) @@ -1062,20 +1063,20 @@ class TestTensorflowIR { println(Nd4j.getExecutioner().exec(DynamicCustomOp.builder("bincount").addInputs(inputVal,weightVal).addIntegerArguments(0,3).build())[0]) println() } - assertEquals("Function ${nd4jOpDef.name} failed with input ${graphInput.inputNames} " + + assertEquals(tfResults.values.first(), results.values.first(),"Function ${nd4jOpDef.name} failed with input ${graphInput.inputNames} " + "with tfValue of shape ${tfResults.values.first().shapeInfoToString()} and nd4j ${results.values.first().shapeInfoToString()} and ${graphInput}" - ,tfResults.values.first(), results.values.first()) + ) } else if(mappingProcess.opName() == "unique_with_counts" || mappingProcess.opName() == "unique") { //note: this is a separate case since the results are equal, minus dimensions val tensorflowRunner = TensorflowIRGraphRunner(irGraph = tensorflowGraph,inputNames = graphInput.inputNames,outputNames = graphInput.outputNames) val mappedGraph = importGraph.importGraph(tensorflowGraph,null,null,dynamicOpsMap,OpRegistryHolder.tensorflow()) - assertEquals("Input name mismatch with input array elements",graphInput.inputArrays.keys,graphInput.inputNames.toSet()) + assertEquals(graphInput.inputArrays.keys,graphInput.inputNames.toSet(),"Input name mismatch with input array elements") val tfResults = tensorflowRunner.run(graphInput.inputArrays) val results = mappedGraph!!.output(graphInput.inputArrays,graphInput.outputNames) - assertEquals("Function ${nd4jOpDef.name} failed with input ${graphInput.inputNames}",tfResults.values.first().ravel(), results.values.first().ravel()) + assertEquals(tfResults.values.first().ravel(), results.values.first().ravel(),"Function ${nd4jOpDef.name} failed with input ${graphInput.inputNames}") }//slight difference in scalar result, doesn't matter in practice else if(mappingProcess.opName() == "matrix_determinant" || mappingProcess.opName() == "log_matrix_determinant") { //note: this is a separate case since the results are equal, minus dimensions @@ -1083,12 +1084,12 @@ class TestTensorflowIR { val mappedGraph = importGraph.importGraph(tensorflowGraph,null,null,dynamicOpsMap,OpRegistryHolder.tensorflow()) - assertEquals("Input name mismatch with input array elements",graphInput.inputArrays.keys,graphInput.inputNames.toSet()) + assertEquals(graphInput.inputArrays.keys,graphInput.inputNames.toSet(),"Input name mismatch with input array elements") if(mappingProcess.opName() == "matrix_determinant") { val tfResults = tensorflowRunner.run(graphInput.inputArrays) val results = mappedGraph!!.output(graphInput.inputArrays,graphInput.outputNames) - assertEquals("Function ${nd4jOpDef.name} failed with input ${graphInput.inputNames}",tfResults["output"]!!.ravel().getDouble(0), results["output"]!!.ravel().getDouble(0),1e-3) + assertEquals(tfResults["output"]!!.ravel().getDouble(0), results["output"]!!.ravel().getDouble(0),1e-3,"Function ${nd4jOpDef.name} failed with input ${graphInput.inputNames}") } } @@ -1097,19 +1098,19 @@ class TestTensorflowIR { val mappedGraph = importGraph.importGraph(tensorflowGraph,null,null,dynamicOpsMap,OpRegistryHolder.tensorflow()) - assertEquals("Input name mismatch with input array elements",graphInput.inputArrays.keys,graphInput.inputNames.toSet()) + assertEquals(graphInput.inputArrays.keys,graphInput.inputNames.toSet(),"Input name mismatch with input array elements") val tfResults = tensorflowRunner.run(graphInput.inputArrays) val results = mappedGraph!!.output(graphInput.inputArrays,graphInput.outputNames) - assertEquals("Function ${nd4jOpDef.name} failed with input ${graphInput.inputNames}",tfResults, results) + assertEquals(tfResults, results,"Function ${nd4jOpDef.name} failed with input ${graphInput.inputNames}") } else if(mappingProcess.opName() == "draw_bounding_boxes") { val tensorflowRunner = TensorflowIRGraphRunner(irGraph = tensorflowGraph,inputNames = graphInput.inputNames,outputNames = graphInput.outputNames) val mappedGraph = importGraph.importGraph(tensorflowGraph,null,null,dynamicOpsMap,OpRegistryHolder.tensorflow()) - assertEquals("Input name mismatch with input array elements",graphInput.inputArrays.keys,graphInput.inputNames.toSet()) + assertEquals(graphInput.inputArrays.keys,graphInput.inputNames.toSet(),"Input name mismatch with input array elements") val tfResults = tensorflowRunner.run(graphInput.inputArrays) val results = mappedGraph!!.output(graphInput.inputArrays,graphInput.outputNames) - assertEquals("Function ${nd4jOpDef.name} failed with input ${graphInput.inputNames}",tfResults, results) + assertEquals(tfResults, results,"Function ${nd4jOpDef.name} failed with input ${graphInput.inputNames}") } else if(mappingProcess.opName() == "fused_batch_norm" && !tf2Ops.contains(mappingProcess.inputFrameworkOpName())) { @@ -1117,11 +1118,11 @@ class TestTensorflowIR { val mappedGraph = importGraph.importGraph(tensorflowGraph,null,null,dynamicOpsMap,OpRegistryHolder.tensorflow()) - assertEquals("Input name mismatch with input array elements",graphInput.inputArrays.keys,graphInput.inputNames.toSet()) + assertEquals(graphInput.inputArrays.keys,graphInput.inputNames.toSet(),"Input name mismatch with input array elements") val tfResults = tensorflowRunner.run(graphInput.inputArrays) val results = mappedGraph!!.output(graphInput.inputArrays,graphInput.outputNames) - assertEquals("Function ${nd4jOpDef.name} failed with input ${graphInput.inputNames}",tfResults["y"], results["y"]) + assertEquals(tfResults["y"], results["y"],"Function ${nd4jOpDef.name} failed with input ${graphInput.inputNames}") } @@ -1131,11 +1132,11 @@ class TestTensorflowIR { val mappedGraph = importGraph.importGraph(tensorflowGraph,null,null,dynamicOpsMap,OpRegistryHolder.tensorflow()) - assertEquals("Input name mismatch with input array elements",graphInput.inputArrays.keys,graphInput.inputNames.toSet()) + assertEquals(graphInput.inputArrays.keys,graphInput.inputNames.toSet(),"Input name mismatch with input array elements") val tfResults = tensorflowRunner.run(graphInput.inputArrays) val results = mappedGraph!!.output(graphInput.inputArrays,graphInput.outputNames) - assertEquals("Function ${nd4jOpDef.name} failed with input ${graphInput.inputNames}",tfResults["finalResult"]!!.ravel().getDouble(0), results["finalResult"]!!.ravel().getDouble(0),1e-3) + assertEquals(tfResults["finalResult"]!!.ravel().getDouble(0), results["finalResult"]!!.ravel().getDouble(0),1e-3,"Function ${nd4jOpDef.name} failed with input ${graphInput.inputNames}") } diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/importer/TestTensorflowImporter.kt b/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/importer/TestTensorflowImporter.kt index 854a377bb..4e944929f 100644 --- a/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/importer/TestTensorflowImporter.kt +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/importer/TestTensorflowImporter.kt @@ -20,14 +20,14 @@ package org.nd4j.samediff.frameworkimport.tensorflow.importer import junit.framework.Assert -import org.junit.Ignore -import org.junit.Test +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.Test import org.nd4j.common.io.ClassPathResource class TestTensorflowImporter { @Test - @Ignore + @Disabled fun testImporter() { val tfFrameworkImport = TensorflowFrameworkImporter() val tfFile = ClassPathResource("lenet_frozen.pb").file diff --git a/python4j/python4j-core/src/test/java/PythonBasicExecutionTest.java b/python4j/python4j-core/src/test/java/PythonBasicExecutionTest.java index eaa044d0e..9e859651c 100644 --- a/python4j/python4j-core/src/test/java/PythonBasicExecutionTest.java +++ b/python4j/python4j-core/src/test/java/PythonBasicExecutionTest.java @@ -20,19 +20,24 @@ import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.python4j.*; import javax.annotation.concurrent.NotThreadSafe; import java.util.*; +import static org.junit.jupiter.api.Assertions.assertThrows; + @NotThreadSafe public class PythonBasicExecutionTest { - @Test(expected = IllegalStateException.class) + @Test() public void testSimpleExecIllegal() { - String code = "print('Hello World')"; - PythonExecutioner.exec(code); + assertThrows(IllegalStateException.class,() -> { + String code = "print('Hello World')"; + PythonExecutioner.exec(code); + }); + } diff --git a/python4j/python4j-core/src/test/java/PythonCollectionsTest.java b/python4j/python4j-core/src/test/java/PythonCollectionsTest.java index d2307a530..395582d8a 100644 --- a/python4j/python4j-core/src/test/java/PythonCollectionsTest.java +++ b/python4j/python4j-core/src/test/java/PythonCollectionsTest.java @@ -21,7 +21,7 @@ import org.nd4j.python4j.*; import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.util.*; diff --git a/python4j/python4j-core/src/test/java/PythonContextManagerTest.java b/python4j/python4j-core/src/test/java/PythonContextManagerTest.java index d362c78ef..ef06d5095 100644 --- a/python4j/python4j-core/src/test/java/PythonContextManagerTest.java +++ b/python4j/python4j-core/src/test/java/PythonContextManagerTest.java @@ -24,7 +24,7 @@ import org.nd4j.python4j.Python; import org.nd4j.python4j.PythonContextManager; import org.nd4j.python4j.PythonExecutioner; import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.python4j.PythonGIL; import javax.annotation.concurrent.NotThreadSafe; diff --git a/python4j/python4j-core/src/test/java/PythonGCTest.java b/python4j/python4j-core/src/test/java/PythonGCTest.java index b66dc1807..57dcc02ac 100644 --- a/python4j/python4j-core/src/test/java/PythonGCTest.java +++ b/python4j/python4j-core/src/test/java/PythonGCTest.java @@ -23,7 +23,7 @@ import org.nd4j.python4j.PythonGC; import org.nd4j.python4j.PythonGIL; import org.nd4j.python4j.PythonObject; import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import javax.annotation.concurrent.NotThreadSafe; diff --git a/python4j/python4j-core/src/test/java/PythonMultiThreadTest.java b/python4j/python4j-core/src/test/java/PythonMultiThreadTest.java index da595b382..67e107b3a 100644 --- a/python4j/python4j-core/src/test/java/PythonMultiThreadTest.java +++ b/python4j/python4j-core/src/test/java/PythonMultiThreadTest.java @@ -18,13 +18,11 @@ * ***************************************************************************** */ -import org.bytedeco.cpython.PyThreadState; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.python4j.*; import javax.annotation.concurrent.NotThreadSafe; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.concurrent.ExecutorService; @@ -32,9 +30,9 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; -import static org.bytedeco.cpython.global.python.*; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.bytedeco.cpython.global.python.PyGILState_Check; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; @NotThreadSafe @@ -145,7 +143,7 @@ public class PythonMultiThreadTest { public void run() { try(PythonGIL pythonGIL = PythonGIL.lock()) { System.out.println("Using thread " + Thread.currentThread().getId() + " to invoke python"); - assertTrue("Thread " + Thread.currentThread().getId() + " does not hold the gil.", PyGILState_Check() > 0); + assertTrue(PyGILState_Check() > 0,"Thread " + Thread.currentThread().getId() + " does not hold the gil."); PythonExecutioner.exec("import time; time.sleep(10)"); System.out.println("Finished execution on thread " + Thread.currentThread().getId()); finishedExecutionCount.incrementAndGet(); diff --git a/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java b/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java index ca0727f0b..980d2f72f 100644 --- a/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java +++ b/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java @@ -21,7 +21,7 @@ import org.nd4j.python4j.*; import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.List; diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java index 080b91858..0332c6d94 100644 --- a/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java @@ -21,7 +21,7 @@ import org.nd4j.python4j.*; import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.buffer.DataType; diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java index 33a3f09dc..2dbe8305c 100644 --- a/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java @@ -24,7 +24,7 @@ import org.nd4j.python4j.PythonGIL; import org.nd4j.python4j.PythonObject; import org.nd4j.python4j.PythonTypes; import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.buffer.DataType; diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyGCTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyGCTest.java index d74746593..997eda5e8 100644 --- a/python4j/python4j-numpy/src/test/java/PythonNumpyGCTest.java +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyGCTest.java @@ -23,7 +23,7 @@ import org.nd4j.python4j.PythonGC; import org.nd4j.python4j.PythonGIL; import org.nd4j.python4j.PythonObject; import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.factory.Nd4j; import javax.annotation.concurrent.NotThreadSafe; diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyImportTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyImportTest.java index c3124fcdc..70a6ac7c6 100644 --- a/python4j/python4j-numpy/src/test/java/PythonNumpyImportTest.java +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyImportTest.java @@ -20,7 +20,7 @@ import org.nd4j.python4j.*; import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java index dae0486d9..17a794015 100644 --- a/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java @@ -20,7 +20,7 @@ import org.nd4j.python4j.*; import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.buffer.DataType; diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java index 05a5dee6e..5d39d27ae 100644 --- a/python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java @@ -20,7 +20,7 @@ import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.python4j.NumpyArray; diff --git a/rl4j/rl4j-core/pom.xml b/rl4j/rl4j-core/pom.xml index 0aef69056..eb63be1c8 100644 --- a/rl4j/rl4j-core/pom.xml +++ b/rl4j/rl4j-core/pom.xml @@ -118,6 +118,22 @@ 3.3.3 test + + org.mockito + mockito-junit-jupiter + 2.23.0 + test + + + org.junit.platform + junit-platform-runner + 1.2.0 + test + + + org.junit.vintage + junit-vintage-engine + diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentLearnerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentLearnerTest.java index 5a6c4c62c..92f64ab45 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentLearnerTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentLearnerTest.java @@ -28,7 +28,7 @@ import org.deeplearning4j.rl4j.environment.StepResult; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.transform.TransformProcess; import org.deeplearning4j.rl4j.policy.IPolicy; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -43,7 +43,7 @@ import java.util.Map; import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.*; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @RunWith(MockitoJUnitRunner.class) public class AgentLearnerTest { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java index 3c58f0a07..89c4ee824 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java @@ -25,13 +25,17 @@ import org.deeplearning4j.rl4j.environment.*; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.transform.TransformProcess; import org.deeplearning4j.rl4j.policy.IPolicy; -import org.junit.Rule; -import org.junit.Test; -import static org.junit.Assert.*; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.platform.runner.JUnitPlatform; import org.junit.runner.RunWith; import org.mockito.*; +import org.mockito.exceptions.base.MockitoException; import org.mockito.junit.*; +import org.mockito.junit.jupiter.MockitoExtension; import org.nd4j.linalg.factory.Nd4j; import java.util.HashMap; @@ -40,15 +44,15 @@ import java.util.Map; import static org.mockito.Mockito.*; -@RunWith(MockitoJUnitRunner.class) +@RunWith(JUnitPlatform.class) +@ExtendWith(MockitoExtension.class) public class AgentTest { @Mock Environment environmentMock; @Mock TransformProcess transformProcessMock; @Mock IPolicy policyMock; @Mock AgentListener listenerMock; - @Rule - public MockitoRule mockitoRule = MockitoJUnit.rule(); + @Test public void when_buildingWithNullEnvironment_expect_exception() { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentActorCriticHelperTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentActorCriticHelperTest.java index 986c1d9e5..bac000d73 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentActorCriticHelperTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentActorCriticHelperTest.java @@ -20,12 +20,12 @@ package org.deeplearning4j.rl4j.agent.learning.algorithm.actorcritic; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class NonRecurrentActorCriticHelperTest { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentAdvantageActorCriticTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentAdvantageActorCriticTest.java index 766ad60e6..9609949ca 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentAdvantageActorCriticTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentAdvantageActorCriticTest.java @@ -27,8 +27,8 @@ import org.deeplearning4j.rl4j.network.CommonOutputNames; import org.deeplearning4j.rl4j.network.ITrainableNeuralNet; import org.deeplearning4j.rl4j.network.NeuralNetOutput; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -39,7 +39,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.*; @RunWith(MockitoJUnitRunner.class) @@ -58,7 +58,7 @@ public class NonRecurrentAdvantageActorCriticTest { private AdvantageActorCritic sut; - @Before + @BeforeEach public void init() { when(neuralNetOutputMock.get(CommonOutputNames.ActorCritic.Value)).thenReturn(Nd4j.create(new double[] { 123.0 })); when(configurationMock.getGamma()).thenReturn(GAMMA); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentActorCriticHelperTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentActorCriticHelperTest.java index ce6437b72..855cc7ddf 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentActorCriticHelperTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentActorCriticHelperTest.java @@ -20,12 +20,12 @@ package org.deeplearning4j.rl4j.agent.learning.algorithm.actorcritic; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class RecurrentActorCriticHelperTest { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentAdvantageActorCriticTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentAdvantageActorCriticTest.java index 1773b76c3..802be9a84 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentAdvantageActorCriticTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentAdvantageActorCriticTest.java @@ -27,8 +27,8 @@ import org.deeplearning4j.rl4j.network.CommonOutputNames; import org.deeplearning4j.rl4j.network.ITrainableNeuralNet; import org.deeplearning4j.rl4j.network.NeuralNetOutput; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -40,7 +40,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.*; @RunWith(MockitoJUnitRunner.class) @@ -59,7 +59,7 @@ public class RecurrentAdvantageActorCriticTest { private AdvantageActorCritic sut; - @Before + @BeforeEach public void init() { when(neuralNetOutputMock.get(CommonOutputNames.ActorCritic.Value)).thenReturn(Nd4j.create(new double[] { 123.0 })); when(configurationMock.getGamma()).thenReturn(GAMMA); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/DoubleDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/DoubleDQNTest.java index e9e496ca2..871d026aa 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/DoubleDQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/DoubleDQNTest.java @@ -28,8 +28,8 @@ import org.deeplearning4j.rl4j.network.CommonOutputNames; import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.deeplearning4j.rl4j.network.NeuralNetOutput; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; @@ -39,7 +39,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; @@ -56,7 +56,7 @@ public class DoubleDQNTest { .gamma(0.5) .build(); - @Before + @BeforeEach public void setup() { when(qNetworkMock.output(any(Features.class))).thenAnswer(i -> { NeuralNetOutput result = new NeuralNetOutput(); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/StandardDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/StandardDQNTest.java index 68ad77948..1e760b8fe 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/StandardDQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/StandardDQNTest.java @@ -28,8 +28,8 @@ import org.deeplearning4j.rl4j.network.CommonOutputNames; import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.deeplearning4j.rl4j.network.NeuralNetOutput; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; @@ -39,7 +39,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; @@ -56,7 +56,7 @@ public class StandardDQNTest { .gamma(0.5) .build(); - @Before + @BeforeEach public void setup() { when(qNetworkMock.output(any(Features.class))).thenAnswer(i -> { NeuralNetOutput result = new NeuralNetOutput(); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningHelperTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningHelperTest.java index 09693e4c8..97ec2c44f 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningHelperTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningHelperTest.java @@ -25,7 +25,7 @@ import org.deeplearning4j.rl4j.network.CommonOutputNames; import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.deeplearning4j.rl4j.network.NeuralNetOutput; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -33,7 +33,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.*; public class NonRecurrentNStepQLearningHelperTest { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningTest.java index 3f928b1fd..a2c4d54c8 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningTest.java @@ -25,7 +25,7 @@ import org.deeplearning4j.rl4j.agent.learning.update.Gradients; import org.deeplearning4j.rl4j.experience.StateActionReward; import org.deeplearning4j.rl4j.network.*; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -36,7 +36,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.*; @RunWith(MockitoJUnitRunner.class) diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningHelperTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningHelperTest.java index b53836679..e364d272f 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningHelperTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningHelperTest.java @@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.network.CommonOutputNames; import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.deeplearning4j.rl4j.network.NeuralNetOutput; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -34,8 +34,8 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.*; public class RecurrentNStepQLearningHelperTest { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningTest.java index e27c25ecb..003f1667c 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningTest.java @@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.agent.learning.update.Gradients; import org.deeplearning4j.rl4j.experience.StateActionReward; import org.deeplearning4j.rl4j.network.*; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -37,7 +37,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.*; @RunWith(MockitoJUnitRunner.class) diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/behavior/LearningBehaviorTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/behavior/LearningBehaviorTest.java index 37551c2bc..16201f4c5 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/behavior/LearningBehaviorTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/behavior/LearningBehaviorTest.java @@ -24,8 +24,8 @@ import org.deeplearning4j.rl4j.agent.learning.behavior.LearningBehavior; import org.deeplearning4j.rl4j.agent.learning.update.IUpdateRule; import org.deeplearning4j.rl4j.experience.ExperienceHandler; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -36,8 +36,8 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.mockito.Mockito.*; @RunWith(MockitoJUnitRunner.class) @@ -51,7 +51,7 @@ public class LearningBehaviorTest { LearningBehavior sut; - @Before + @BeforeEach public void setup() { sut = LearningBehavior.builder() .experienceHandler(experienceHandlerMock) diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesBuilderTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesBuilderTest.java index 2dc36ab9c..e788ffb32 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesBuilderTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesBuilderTest.java @@ -24,7 +24,7 @@ import org.deeplearning4j.rl4j.experience.StateActionReward; import org.deeplearning4j.rl4j.experience.StateActionRewardState; import org.deeplearning4j.rl4j.observation.IObservationSource; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.api.ndarray.INDArray; @@ -33,8 +33,8 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; @RunWith(MockitoJUnitRunner.class) public class FeaturesBuilderTest { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesLabelsTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesLabelsTest.java index 4c1ea7211..ca3f2f0a2 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesLabelsTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesLabelsTest.java @@ -20,13 +20,13 @@ package org.deeplearning4j.rl4j.agent.learning.update; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesTest.java index bc020a39c..43f4d3c31 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesTest.java @@ -20,12 +20,12 @@ package org.deeplearning4j.rl4j.agent.learning.update; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertSame; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; public class FeaturesTest { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/GradientsTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/GradientsTest.java index c741d4a3f..43372713f 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/GradientsTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/GradientsTest.java @@ -21,12 +21,12 @@ package org.deeplearning4j.rl4j.agent.learning.update; import org.deeplearning4j.nn.gradient.Gradient; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertSame; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; import static org.mockito.Mockito.mock; @RunWith(MockitoJUnitRunner.class) diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/UpdateRuleTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/UpdateRuleTest.java index d91d1e51a..07837bbdb 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/UpdateRuleTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/UpdateRuleTest.java @@ -22,8 +22,8 @@ package org.deeplearning4j.rl4j.agent.learning.update; import org.deeplearning4j.rl4j.agent.learning.algorithm.IUpdateAlgorithm; import org.deeplearning4j.rl4j.agent.learning.update.updater.INeuralNetUpdater; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; @@ -31,7 +31,7 @@ import org.mockito.junit.MockitoJUnitRunner; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.*; @RunWith(MockitoJUnitRunner.class) @@ -45,7 +45,7 @@ public class UpdateRuleTest { private UpdateRule sut; - @Before + @BeforeEach public void init() { sut = new UpdateRule(updateAlgorithm, updater); } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncGradientsNeuralNetUpdaterTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncGradientsNeuralNetUpdaterTest.java index 2d80aa9eb..8f02fa800 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncGradientsNeuralNetUpdaterTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncGradientsNeuralNetUpdaterTest.java @@ -23,7 +23,7 @@ package org.deeplearning4j.rl4j.agent.learning.update.updater.async; import org.deeplearning4j.rl4j.agent.learning.update.Gradients; import org.deeplearning4j.rl4j.agent.learning.update.updater.NeuralNetUpdaterConfiguration; import org.deeplearning4j.rl4j.network.ITrainableNeuralNet; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncLabelsNeuralNetUpdaterTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncLabelsNeuralNetUpdaterTest.java index ac1da9461..d953202d7 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncLabelsNeuralNetUpdaterTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncLabelsNeuralNetUpdaterTest.java @@ -24,7 +24,7 @@ import org.deeplearning4j.rl4j.agent.learning.update.FeaturesLabels; import org.deeplearning4j.rl4j.agent.learning.update.Gradients; import org.deeplearning4j.rl4j.agent.learning.update.updater.NeuralNetUpdaterConfiguration; import org.deeplearning4j.rl4j.network.ITrainableNeuralNet; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncSharedNetworksUpdateHandlerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncSharedNetworksUpdateHandlerTest.java index 575b89e33..56694cfd9 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncSharedNetworksUpdateHandlerTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncSharedNetworksUpdateHandlerTest.java @@ -23,13 +23,13 @@ package org.deeplearning4j.rl4j.agent.learning.update.updater.async; import org.deeplearning4j.rl4j.agent.learning.update.Gradients; import org.deeplearning4j.rl4j.agent.learning.update.updater.NeuralNetUpdaterConfiguration; import org.deeplearning4j.rl4j.network.ITrainableNeuralNet; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncGradientsNeuralNetUpdaterTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncGradientsNeuralNetUpdaterTest.java index 35d8564b0..f11e6e45d 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncGradientsNeuralNetUpdaterTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncGradientsNeuralNetUpdaterTest.java @@ -23,7 +23,7 @@ package org.deeplearning4j.rl4j.agent.learning.update.updater.sync; import org.deeplearning4j.rl4j.agent.learning.update.Gradients; import org.deeplearning4j.rl4j.agent.learning.update.updater.NeuralNetUpdaterConfiguration; import org.deeplearning4j.rl4j.network.ITrainableNeuralNet; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncLabelsNeuralNetUpdaterTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncLabelsNeuralNetUpdaterTest.java index 4578ea01b..dd11d8a75 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncLabelsNeuralNetUpdaterTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncLabelsNeuralNetUpdaterTest.java @@ -23,13 +23,13 @@ package org.deeplearning4j.rl4j.agent.learning.update.updater.sync; import org.deeplearning4j.rl4j.agent.learning.update.FeaturesLabels; import org.deeplearning4j.rl4j.agent.learning.update.updater.NeuralNetUpdaterConfiguration; import org.deeplearning4j.rl4j.network.ITrainableNeuralNet; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.Mockito.*; @RunWith(MockitoJUnitRunner.class) diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/builder/BaseAgentLearnerBuilderTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/builder/BaseAgentLearnerBuilderTest.java index 73a958f88..962c09469 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/builder/BaseAgentLearnerBuilderTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/builder/BaseAgentLearnerBuilderTest.java @@ -29,8 +29,8 @@ import org.deeplearning4j.rl4j.experience.ExperienceHandler; import org.deeplearning4j.rl4j.network.ITrainableNeuralNet; import org.deeplearning4j.rl4j.observation.transform.TransformProcess; import org.deeplearning4j.rl4j.policy.IPolicy; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.Mockito; @@ -72,7 +72,7 @@ public class BaseAgentLearnerBuilderTest { BaseAgentLearnerBuilder sut; - @Before + @BeforeEach public void setup() { sut = mock( BaseAgentLearnerBuilder.class, diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java index c737ce8f0..31adba1d5 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java @@ -22,7 +22,7 @@ package org.deeplearning4j.rl4j.experience; import org.deeplearning4j.rl4j.learning.sync.IExpReplay; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -31,7 +31,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.*; @RunWith(MockitoJUnitRunner.class) diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java index 46350a8fe..c2c79a363 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java @@ -21,12 +21,12 @@ package org.deeplearning4j.rl4j.experience; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.factory.Nd4j; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class StateActionExperienceHandlerTest { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java index 5b1bd75a3..b82c9a219 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java @@ -20,11 +20,11 @@ package org.deeplearning4j.rl4j.helper; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class INDArrayHelperTest { @Test diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java index 2c5753caa..9ab2f5c3e 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java @@ -20,11 +20,11 @@ package org.deeplearning4j.rl4j.learning; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java index a2acb32fd..75d02d483 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java @@ -25,8 +25,8 @@ import org.deeplearning4j.rl4j.learning.listener.TrainingListener; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Box; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.Mockito; @@ -54,7 +54,7 @@ public class AsyncLearningTest { @Mock IAsyncLearningConfiguration mockConfiguration; - @Before + @BeforeEach public void setup() { asyncLearning = mock(AsyncLearning.class, Mockito.withSettings() .useConstructor() diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java index f550cd940..8ab512dc2 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java @@ -31,8 +31,8 @@ import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.ObservationSpace; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.Mockito; @@ -41,9 +41,9 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.eq; @@ -107,7 +107,7 @@ public class AsyncThreadDiscreteTest { when(mockGlobalTargetNetwork.clone()).thenReturn(mockGlobalTargetNetwork); } - @Before + @BeforeEach public void setup() { setupMDPMocks(); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java index 514578c49..af55d76af 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java @@ -29,8 +29,8 @@ import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Box; import org.deeplearning4j.rl4j.space.ObservationSpace; import org.deeplearning4j.rl4j.util.IDataManager; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.Mockito; @@ -39,7 +39,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.shade.guava.base.Preconditions; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.eq; @@ -78,7 +78,7 @@ public class AsyncThreadTest { AsyncThread, NeuralNet> thread; - @Before + @BeforeEach public void setup() { setupMDPMocks(); setupThreadMocks(); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithmTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithmTest.java index 831a53b06..d0a09deff 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithmTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithmTest.java @@ -25,8 +25,8 @@ import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -37,7 +37,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -53,7 +53,7 @@ public class AdvantageActorCriticUpdateAlgorithmTest { IActorCritic mockActorCritic; @Test - @Ignore + @Disabled public void refac_calcGradient_non_terminal() { // Arrange int[] observationShape = new int[]{5}; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/listener/AsyncTrainingListenerListTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/listener/AsyncTrainingListenerListTest.java index b3b354d2f..d6d865666 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/listener/AsyncTrainingListenerListTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/listener/AsyncTrainingListenerListTest.java @@ -24,10 +24,10 @@ import org.deeplearning4j.rl4j.learning.IEpochTrainer; import org.deeplearning4j.rl4j.learning.ILearning; import org.deeplearning4j.rl4j.learning.listener.*; import org.deeplearning4j.rl4j.util.IDataManager; -import org.junit.Test; +import org.junit.jupiter.api.Test; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class AsyncTrainingListenerListTest { @Test diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java index 937c70a52..afe122206 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java @@ -25,7 +25,7 @@ import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -36,7 +36,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.*; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerListTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerListTest.java index 9d2b3fd7a..7eb2db655 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerListTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerListTest.java @@ -23,10 +23,10 @@ package org.deeplearning4j.rl4j.learning.listener; import org.deeplearning4j.rl4j.learning.IEpochTrainer; import org.deeplearning4j.rl4j.learning.ILearning; import org.deeplearning4j.rl4j.util.IDataManager; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.mockito.Mock; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java index 7d4f6ffbd..304c67694 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java @@ -23,12 +23,12 @@ package org.deeplearning4j.rl4j.learning.sync; import org.deeplearning4j.rl4j.experience.StateActionRewardState; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.support.MockRandom; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.factory.Nd4j; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class ExpReplayTest { @Test diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/StateActionRewardStateTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/StateActionRewardStateTest.java index ef48d629c..d4c35e016 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/StateActionRewardStateTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/StateActionRewardStateTest.java @@ -22,11 +22,11 @@ package org.deeplearning4j.rl4j.learning.sync; import org.deeplearning4j.rl4j.experience.StateActionRewardState; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class StateActionRewardStateTest { @Test diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java index c96e68a30..32cdb8fa0 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java @@ -27,8 +27,8 @@ import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Box; import org.deeplearning4j.rl4j.util.IDataManager; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.Mockito; @@ -53,7 +53,7 @@ public class SyncLearningTest { @Mock ILearningConfiguration mockLearningConfiguration; - @Before + @BeforeEach public void setup() { syncLearning = mock(SyncLearning.class, Mockito.withSettings() diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningConfigurationTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningConfigurationTest.java index aec26be38..96ae8ef54 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningConfigurationTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningConfigurationTest.java @@ -22,13 +22,10 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning; import com.fasterxml.jackson.databind.ObjectMapper; import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; + +import org.junit.jupiter.api.Test; public class QLearningConfigurationTest { - @Rule - public ExpectedException thrown = ExpectedException.none(); @Test public void serialize() throws Exception { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java index 68a3ec85b..eddc51a3a 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java @@ -33,8 +33,8 @@ import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.ObservationSpace; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.Mockito; @@ -42,8 +42,8 @@ import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.eq; @@ -141,7 +141,7 @@ public class QLearningDiscreteTest { qLearningDiscrete.setHistoryProcessor(mockHistoryProcessor); } - @Before + @BeforeEach public void setup() { setupMDPMocks(); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ActorCriticNetworkTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ActorCriticNetworkTest.java index 2cdd4b570..8cc49d3d6 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ActorCriticNetworkTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ActorCriticNetworkTest.java @@ -26,14 +26,14 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.rl4j.agent.learning.update.Features; import org.deeplearning4j.rl4j.agent.learning.update.FeaturesLabels; import org.deeplearning4j.rl4j.agent.learning.update.Gradients; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertSame; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; import static org.mockito.Mockito.*; import static org.mockito.Mockito.times; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/BaseNetworkTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/BaseNetworkTest.java index 0e47b7eeb..667a04447 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/BaseNetworkTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/BaseNetworkTest.java @@ -24,7 +24,7 @@ import org.deeplearning4j.rl4j.agent.learning.update.Features; import org.deeplearning4j.rl4j.agent.learning.update.FeaturesLabels; import org.deeplearning4j.rl4j.agent.learning.update.Gradients; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -33,8 +33,8 @@ import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertSame; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; import static org.mockito.Mockito.*; @RunWith(MockitoJUnitRunner.class) diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ChannelToNetworkInputMapperTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ChannelToNetworkInputMapperTest.java index 69e77e097..7318fe1c6 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ChannelToNetworkInputMapperTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ChannelToNetworkInputMapperTest.java @@ -22,13 +22,13 @@ package org.deeplearning4j.rl4j.network; import org.deeplearning4j.rl4j.agent.learning.update.Features; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; @RunWith(MockitoJUnitRunner.class) public class ChannelToNetworkInputMapperTest { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/CompoundNetworkHandlerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/CompoundNetworkHandlerTest.java index 4fe8fc04c..156e0fc75 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/CompoundNetworkHandlerTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/CompoundNetworkHandlerTest.java @@ -24,14 +24,14 @@ import org.deeplearning4j.rl4j.agent.learning.update.Features; import org.deeplearning4j.rl4j.agent.learning.update.FeaturesLabels; import org.deeplearning4j.rl4j.agent.learning.update.Gradients; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.*; @RunWith(MockitoJUnitRunner.class) diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ComputationGraphHandlerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ComputationGraphHandlerTest.java index 43203210d..604617a83 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ComputationGraphHandlerTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ComputationGraphHandlerTest.java @@ -30,7 +30,7 @@ import org.deeplearning4j.rl4j.agent.learning.update.Features; import org.deeplearning4j.rl4j.agent.learning.update.FeaturesLabels; import org.deeplearning4j.rl4j.agent.learning.update.Gradients; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.junit.MockitoJUnitRunner; @@ -41,7 +41,7 @@ import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Collection; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.*; @RunWith(MockitoJUnitRunner.class) diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandlerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandlerTest.java index 2b2af6e26..e6ba2a857 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandlerTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandlerTest.java @@ -30,7 +30,7 @@ import org.deeplearning4j.rl4j.agent.learning.update.Features; import org.deeplearning4j.rl4j.agent.learning.update.FeaturesLabels; import org.deeplearning4j.rl4j.agent.learning.update.Gradients; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.junit.MockitoJUnitRunner; @@ -41,7 +41,7 @@ import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Collection; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.*; @RunWith(MockitoJUnitRunner.class) diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/NetworkHelperTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/NetworkHelperTest.java index 8692b467e..03d17e4f6 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/NetworkHelperTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/NetworkHelperTest.java @@ -24,7 +24,7 @@ import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.api.ndarray.INDArray; @@ -33,8 +33,8 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.Mockito.*; @RunWith(MockitoJUnitRunner.class) diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/QNetworkTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/QNetworkTest.java index c77a35e73..3564cffc6 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/QNetworkTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/QNetworkTest.java @@ -26,14 +26,14 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.rl4j.agent.learning.update.Features; import org.deeplearning4j.rl4j.agent.learning.update.FeaturesLabels; import org.deeplearning4j.rl4j.agent.learning.update.Gradients; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertSame; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; import static org.mockito.Mockito.*; @RunWith(MockitoJUnitRunner.class) diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java index 50e015f3d..091bce4ef 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java @@ -21,7 +21,7 @@ package org.deeplearning4j.rl4j.network.ac; import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.impl.ActivationSoftmax; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -30,8 +30,8 @@ import org.nd4j.linalg.learning.config.RmsProp; import java.io.File; import java.io.IOException; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; /** * @author saudet diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java index 9cfd4ee3f..f1737d65f 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java @@ -21,13 +21,13 @@ package org.deeplearning4j.rl4j.network.dqn; import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.learning.config.RmsProp; import java.io.File; import java.io.IOException; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author saudet diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java index 5230f9829..b0bbcbce8 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java @@ -21,7 +21,7 @@ package org.deeplearning4j.rl4j.observation.transform; import org.deeplearning4j.rl4j.observation.Observation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.DataSetPreProcessor; import org.nd4j.linalg.factory.Nd4j; @@ -30,40 +30,52 @@ import org.datavec.api.transform.Operation; import java.util.HashMap; import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class TransformProcessTest { - @Test(expected = IllegalArgumentException.class) + @Test() public void when_noChannelNameIsSuppliedToBuild_expect_exception() { // Arrange - TransformProcess.builder().build(); + assertThrows(IllegalArgumentException.class,() -> { + TransformProcess.builder().build(); + + }); } - @Test(expected = IllegalArgumentException.class) + @Test() public void when_callingTransformWithNullArg_expect_exception() { - // Arrange - TransformProcess sut = TransformProcess.builder() - .build("test"); + assertThrows(IllegalArgumentException.class,() -> { + // Arrange + TransformProcess sut = TransformProcess.builder() + .build("test"); + + // Act + sut.transform(null, 0, false); + }); - // Act - sut.transform(null, 0, false); } - @Test(expected = IllegalArgumentException.class) + @Test() public void when_callingTransformWithEmptyChannelData_expect_exception() { - // Arrange - TransformProcess sut = TransformProcess.builder() - .build("test"); - Map channelsData = new HashMap(); + assertThrows(IllegalArgumentException.class,() -> { + // Arrange + TransformProcess sut = TransformProcess.builder() + .build("test"); + Map channelsData = new HashMap(); + + // Act + sut.transform(channelsData, 0, false); + }); - // Act - sut.transform(channelsData, 0, false); } - @Test(expected = NullPointerException.class) + @Test() public void when_addingNullFilter_expect_nullException() { - // Act - TransformProcess.builder().filter(null); + assertThrows(NullPointerException.class,() -> { + // Act + TransformProcess.builder().filter(null); + }); + } @Test @@ -86,16 +98,22 @@ public class TransformProcessTest { assertFalse(transformOperationMock.isCalled); } - @Test(expected = NullPointerException.class) + @Test() public void when_addingTransformOnNullChannel_expect_nullException() { - // Act - TransformProcess.builder().transform(null, new IntegerTransformOperationMock()); + assertThrows(NullPointerException.class,() -> { + // Act + TransformProcess.builder().transform(null, new IntegerTransformOperationMock()); + }); + } - @Test(expected = NullPointerException.class) + @Test() public void when_addingTransformWithNullTransform_expect_nullException() { - // Act - TransformProcess.builder().transform("test", null); + assertThrows(NullPointerException.class,() -> { + // Act + TransformProcess.builder().transform("test", null); + }); + } @Test @@ -118,16 +136,21 @@ public class TransformProcessTest { assertEquals(-1.0, result.getData().getDouble(0), 0.00001); } - @Test(expected = NullPointerException.class) + @Test() public void when_addingPreProcessOnNullChannel_expect_nullException() { - // Act - TransformProcess.builder().preProcess(null, new DataSetPreProcessorMock()); + assertThrows(NullPointerException.class,() -> { + // Act + TransformProcess.builder().preProcess(null, new DataSetPreProcessorMock()); + }); } - @Test(expected = NullPointerException.class) + @Test() public void when_addingPreProcessWithNullTransform_expect_nullException() { - // Act - TransformProcess.builder().transform("test", null); + assertThrows(NullPointerException.class,() -> { + // Act + TransformProcess.builder().transform("test", null); + }); + } @Test @@ -153,54 +176,69 @@ public class TransformProcessTest { assertEquals(-10.0, result.getData().getDouble(0), 0.00001); } - @Test(expected = IllegalStateException.class) + @Test() public void when_transformingNullData_expect_exception() { - // Arrange - TransformProcess sut = TransformProcess.builder() - .transform("test", new IntegerTransformOperationMock()) - .build("test"); - Map channelsData = new HashMap() {{ - put("test", 1); - }}; + assertThrows(IllegalStateException.class,() -> { + // Arrange + TransformProcess sut = TransformProcess.builder() + .transform("test", new IntegerTransformOperationMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + Observation result = sut.transform(channelsData, 0, false); + }); - // Act - Observation result = sut.transform(channelsData, 0, false); } - @Test(expected = IllegalArgumentException.class) + @Test() public void when_transformingAndChannelsNotDataSet_expect_exception() { - // Arrange - TransformProcess sut = TransformProcess.builder() - .preProcess("test", new DataSetPreProcessorMock()) - .build("test"); + assertThrows(IllegalArgumentException.class,() -> { + // Arrange + TransformProcess sut = TransformProcess.builder() + .preProcess("test", new DataSetPreProcessorMock()) + .build("test"); + + // Act + Observation result = sut.transform(null, 0, false); + }); - // Act - Observation result = sut.transform(null, 0, false); } - @Test(expected = IllegalArgumentException.class) + @Test() public void when_transformingAndChannelsEmptyDataSet_expect_exception() { - // Arrange - TransformProcess sut = TransformProcess.builder() - .preProcess("test", new DataSetPreProcessorMock()) - .build("test"); - Map channelsData = new HashMap(); + assertThrows(IllegalArgumentException.class,() -> { + // Arrange + TransformProcess sut = TransformProcess.builder() + .preProcess("test", new DataSetPreProcessorMock()) + .build("test"); + Map channelsData = new HashMap(); + + // Act + Observation result = sut.transform(channelsData, 0, false); + }); - // Act - Observation result = sut.transform(channelsData, 0, false); } - @Test(expected = IllegalArgumentException.class) + @Test() public void when_buildIsCalledWithoutChannelNames_expect_exception() { - // Act - TransformProcess.builder().build(); + assertThrows(IllegalArgumentException.class,() -> { + // Act + TransformProcess.builder().build(); + }); + } - @Test(expected = NullPointerException.class) + @Test() public void when_buildIsCalledWithNullChannelName_expect_exception() { - // Act - TransformProcess.builder().build(null); + assertThrows(NullPointerException.class,() -> { + // Act + TransformProcess.builder().build(null); + }); + } @Test @@ -257,87 +295,105 @@ public class TransformProcessTest { assertEquals(1.0, result.getData().getDouble(0), 0.00001); } - @Test(expected = IllegalStateException.class) + @Test() public void when_buildIsCalledAndChannelsNotDataSetsOrINDArrays_expect_exception() { - // Arrange - TransformProcess sut = TransformProcess.builder() - .build("test"); - Map channelsData = new HashMap() {{ - put("test", 1); - }}; + assertThrows(IllegalStateException.class,() -> { + // Arrange + TransformProcess sut = TransformProcess.builder() + .build("test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + Observation result = sut.transform(channelsData, 123, true); + }); - // Act - Observation result = sut.transform(channelsData, 123, true); } - @Test(expected = NullPointerException.class) + @Test() public void when_channelDataIsNull_expect_exception() { - // Arrange - TransformProcess sut = TransformProcess.builder() - .transform("test", new IntegerTransformOperationMock()) - .build("test"); - Map channelsData = new HashMap() {{ - put("test", null); - }}; + assertThrows(NullPointerException.class,() -> { + // Arrange + TransformProcess sut = TransformProcess.builder() + .transform("test", new IntegerTransformOperationMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("test", null); + }}; + + // Act + sut.transform(channelsData, 0, false); + }); - // Act - sut.transform(channelsData, 0, false); } - @Test(expected = IllegalArgumentException.class) + @Test() public void when_transformAppliedOnChannelNotInMap_expect_exception() { - // Arrange - TransformProcess sut = TransformProcess.builder() - .transform("test", new IntegerTransformOperationMock()) - .build("test"); - Map channelsData = new HashMap() {{ - put("not-test", 1); - }}; + assertThrows(IllegalArgumentException.class,() -> { + // Arrange + TransformProcess sut = TransformProcess.builder() + .transform("test", new IntegerTransformOperationMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("not-test", 1); + }}; + + // Act + sut.transform(channelsData, 0, false); + }); - // Act - sut.transform(channelsData, 0, false); } - @Test(expected = IllegalArgumentException.class) + @Test() public void when_preProcessAppliedOnChannelNotInMap_expect_exception() { - // Arrange - TransformProcess sut = TransformProcess.builder() - .preProcess("test", new DataSetPreProcessorMock()) - .build("test"); - Map channelsData = new HashMap() {{ - put("not-test", 1); - }}; + assertThrows(IllegalArgumentException.class,() -> { + // Arrange + TransformProcess sut = TransformProcess.builder() + .preProcess("test", new DataSetPreProcessorMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("not-test", 1); + }}; + + // Act + sut.transform(channelsData, 0, false); + }); - // Act - sut.transform(channelsData, 0, false); } - @Test(expected = IllegalArgumentException.class) + @Test() public void when_buildContainsChannelNotInMap_expect_exception() { - // Arrange - TransformProcess sut = TransformProcess.builder() - .transform("test", new IntegerTransformOperationMock()) - .build("not-test"); - Map channelsData = new HashMap() {{ - put("test", 1); - }}; + assertThrows(IllegalArgumentException.class,() -> { + // Arrange + TransformProcess sut = TransformProcess.builder() + .transform("test", new IntegerTransformOperationMock()) + .build("not-test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + sut.transform(channelsData, 0, false); + }); - // Act - sut.transform(channelsData, 0, false); } - @Test(expected = IllegalArgumentException.class) + @Test() public void when_preProcessNotAppliedOnDataSet_expect_exception() { - // Arrange - TransformProcess sut = TransformProcess.builder() - .preProcess("test", new DataSetPreProcessorMock()) - .build("test"); - Map channelsData = new HashMap() {{ - put("test", 1); - }}; + assertThrows(IllegalArgumentException.class,() -> { + // Arrange + TransformProcess sut = TransformProcess.builder() + .preProcess("test", new DataSetPreProcessorMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + sut.transform(channelsData, 0, false); + }); - // Act - sut.transform(channelsData, 0, false); } @Test diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilterTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilterTest.java index 5af9463e0..6e004b3cd 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilterTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilterTest.java @@ -21,17 +21,19 @@ package org.deeplearning4j.rl4j.observation.transform.filter; import org.deeplearning4j.rl4j.observation.transform.FilterOperation; -import org.junit.Test; +import org.junit.jupiter.api.Test; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.*; public class UniformSkippingFilterTest { - @Test(expected = IllegalArgumentException.class) + @Test public void when_negativeSkipFrame_expect_exception() { - // Act - new UniformSkippingFilter(-1); + assertThrows(IllegalArgumentException.class,() -> { + // Act + new UniformSkippingFilter(-1); + }); + } @Test diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/ArrayToINDArrayTransformTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/ArrayToINDArrayTransformTest.java index f33aa8be9..b6a0d091f 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/ArrayToINDArrayTransformTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/ArrayToINDArrayTransformTest.java @@ -20,11 +20,11 @@ package org.deeplearning4j.rl4j.observation.transform.operation; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; public class ArrayToINDArrayTransformTest { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java index c1053f3bb..172f58c2d 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java @@ -22,11 +22,11 @@ package org.deeplearning4j.rl4j.observation.transform.operation; import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryMergeAssembler; import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryMergeElementStore; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class HistoryMergeTransformTest { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransformTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransformTest.java index deb0cde63..9c2f4eebc 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransformTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransformTest.java @@ -20,18 +20,21 @@ package org.deeplearning4j.rl4j.observation.transform.operation; -import org.deeplearning4j.rl4j.helper.INDArrayHelper; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; public class SimpleNormalizationTransformTest { - @Test(expected = IllegalArgumentException.class) + @Test() public void when_maxIsLessThanMin_expect_exception() { - // Arrange - SimpleNormalizationTransform sut = new SimpleNormalizationTransform(10.0, 1.0); + assertThrows(IllegalArgumentException.class,() -> { + // Arrange + SimpleNormalizationTransform sut = new SimpleNormalizationTransform(10.0, 1.0); + }); + } @Test diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStoreTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStoreTest.java index 9bcfcadf0..14d0a117e 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStoreTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStoreTest.java @@ -20,18 +20,21 @@ package org.deeplearning4j.rl4j.observation.transform.operation.historymerge; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class CircularFifoStoreTest { - @Test(expected = IllegalArgumentException.class) + @Test() public void when_fifoSizeIsLessThan1_expect_exception() { - // Arrange - CircularFifoStore sut = new CircularFifoStore(0); + assertThrows(IllegalArgumentException.class,() -> { + // Arrange + CircularFifoStore sut = new CircularFifoStore(0); + }); + } @Test diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssemblerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssemblerTest.java index 9fdaf5437..6adfdbb19 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssemblerTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssemblerTest.java @@ -20,11 +20,11 @@ package org.deeplearning4j.rl4j.observation.transform.operation.historymerge; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; public class HistoryStackAssemblerTest { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java index b0aeb4c6b..f74713466 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java @@ -40,7 +40,7 @@ import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.support.*; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -49,7 +49,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.IOException; import java.io.OutputStream; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; /** * diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/trainer/AsyncTrainerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/trainer/AsyncTrainerTest.java index 355efe01c..54dc3bd32 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/trainer/AsyncTrainerTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/trainer/AsyncTrainerTest.java @@ -22,8 +22,8 @@ package org.deeplearning4j.rl4j.trainer; import org.apache.commons.lang3.builder.Builder; import org.deeplearning4j.rl4j.agent.IAgentLearner; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; @@ -31,7 +31,7 @@ import org.mockito.junit.MockitoJUnitRunner; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Predicate; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.*; @RunWith(MockitoJUnitRunner.class) @@ -46,7 +46,7 @@ public class AsyncTrainerTest { @Mock IAgentLearner agentLearnerMock; - @Before + @BeforeEach public void setup() { when(agentLearnerBuilderMock.build()).thenReturn(agentLearnerMock); when(agentLearnerMock.getEpisodeStepCount()).thenReturn(100); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/trainer/SyncTrainerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/trainer/SyncTrainerTest.java index 58dfe7e7a..8c920be95 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/trainer/SyncTrainerTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/trainer/SyncTrainerTest.java @@ -22,15 +22,15 @@ package org.deeplearning4j.rl4j.trainer; import org.apache.commons.lang3.builder.Builder; import org.deeplearning4j.rl4j.agent.IAgentLearner; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import java.util.function.Predicate; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.*; @RunWith(MockitoJUnitRunner.class) diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java index 97795b503..8fc55bffe 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java @@ -34,10 +34,10 @@ import org.deeplearning4j.rl4j.support.MockDataManager; import org.deeplearning4j.rl4j.support.MockHistoryProcessor; import org.deeplearning4j.rl4j.support.MockMDP; import org.deeplearning4j.rl4j.support.MockObservationSpace; -import org.junit.Test; +import org.junit.jupiter.api.Test; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertSame; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; public class DataManagerTrainingListenerTest { diff --git a/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java b/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java index ce97d0196..49ea6de44 100644 --- a/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java +++ b/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java @@ -24,11 +24,11 @@ import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.space.ArrayObservationSpace; import org.deeplearning4j.rl4j.space.Box; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.junit.Test; +import org.junit.jupiter.api.Test; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; /** *