Fixes for codegen generated classes and build improvements (#367)

* Input format extended

* Deleted redundant code

* Added weights format to conv2d config

* Refactoring

* dl4j base test functionality

* Different tests base class per module

* Check base class for dl4j-graph subproject tests

* Check if test classes extend BaseDL4JTest

* Use nd4j-common-tests as transient dependency

* Enums and tests added

* Added codegenerated methods

* Use namespace methods

* Replace DifferentialFunctionFactory  with codegenerated classes

* Fixed linspace

* Namespaces regenerated

* Namespaces used instead of factory

* Regenerated base classes

* Input format extended

* Added weights format to conv2d config

* Refactoring

* dl4j base test functionality

* Different tests base class per module

* Check base class for dl4j-graph subproject tests

* Check if test classes extend BaseDL4JTest

* Use nd4j-common-tests as transient dependency

* Enums and tests added

* Added codegenerated methods

* Use namespace methods

* Replace DifferentialFunctionFactory  with codegenerated classes

* Fixed linspace

* Namespaces regenerated

* Regenerated base classes

* Regenerated namespaces

* Generate nd4j namespaces

* INDArrays accepting constructors

* Generated some ops

* Some fixes

* SameDiff ops regenerated

* Regenerated nd4j ops

* externalErrors moved

* Compilation fixes

* SquaredDifference - strict number of args

* Deprecated code cleanup. Proper base class for tests.

* Extend test classes with BaseND4JTest

* Extend test classes with BaseDL4JTest

* Legacy code

* DL4J cleanup

* Exclude test utils from base class check

* Tests fixed

* Arbiter tests fix

* Test dependency scope fix + pom.xml formatting

Signed-off-by: Alex Black <blacka101@gmail.com>

* Significant number of fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Another round of fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Another round of fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Few additional fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* DataVec missing test scope dependencies

Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: Alex Black <blacka101@gmail.com>
master
Alexander Stoyakin 2020-04-20 03:27:13 +03:00 committed by GitHub
parent 73aa760c0f
commit 455a5d112d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
358 changed files with 4531 additions and 3919 deletions

View File

@ -14,7 +14,8 @@
~ SPDX-License-Identifier: Apache-2.0 ~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent> <parent>
<artifactId>arbiter</artifactId> <artifactId>arbiter</artifactId>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
@ -33,10 +34,10 @@
<artifactId>nd4j-api</artifactId> <artifactId>nd4j-api</artifactId>
<version>${nd4j.version}</version> <version>${nd4j.version}</version>
<exclusions> <exclusions>
<exclusion> <exclusion>
<groupId>com.google.code.findbugs</groupId> <groupId>com.google.code.findbugs</groupId>
<artifactId>*</artifactId> <artifactId>*</artifactId>
</exclusion> </exclusion>
</exclusions> </exclusions>
</dependency> </dependency>

View File

@ -0,0 +1,49 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.nd4j.AbstractAssertTestsClass;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.deeplearning4j.arbiter.optimize";
}
@Override
protected Class<?> getBaseClass() {
return BaseDL4JTest.class;
}
}

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.nd4j.AbstractAssertTestsClass;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.deeplearning4j.arbiter";
}
@Override
protected Class<?> getBaseClass() {
return BaseDL4JTest.class;
}
}

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.server;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.nd4j.AbstractAssertTestsClass;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.deeplearning4j.arbiter.server";
}
@Override
protected Class<?> getBaseClass() {
return BaseDL4JTest.class;
}
}

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.server; package org.deeplearning4j.arbiter.server;
import lombok.Data; import lombok.Data;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory; import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory;
@ -27,7 +28,7 @@ import java.io.IOException;
* Created by agibsonccc on 3/13/17. * Created by agibsonccc on 3/13/17.
*/ */
@Data @Data
public class MnistDataSetIteratorFactory implements DataSetIteratorFactory { public class MnistDataSetIteratorFactory extends BaseDL4JTest implements DataSetIteratorFactory {
/** /**
* @return * @return
*/ */

View File

@ -17,13 +17,14 @@
package org.deeplearning4j.arbiter.server; package org.deeplearning4j.arbiter.server;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory; import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory;
@AllArgsConstructor @AllArgsConstructor
public class TestDataFactoryProviderMnist implements DataSetIteratorFactory { public class TestDataFactoryProviderMnist extends BaseDL4JTest implements DataSetIteratorFactory {
private int batchSize; private int batchSize;
private int terminationIter; private int terminationIter;

View File

@ -54,6 +54,13 @@
<version>${dl4j.version}</version> <version>${dl4j.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-common-tests</artifactId>
<version>${dl4j.version}</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>ch.qos.logback</groupId> <groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId> <artifactId>logback-classic</artifactId>

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.deeplearning4j.BaseDL4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.deeplearning4j.arbiter.optimize";
}
@Override
protected Class<?> getBaseClass() {
return BaseDL4JTest.class;
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize; package org.deeplearning4j.arbiter.optimize;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.arbiter.ComputationGraphSpace; import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.MultiLayerSpace; import org.deeplearning4j.arbiter.MultiLayerSpace;
@ -70,7 +71,7 @@ import java.util.concurrent.TimeUnit;
/** /**
* Created by Alex on 19/07/2017. * Created by Alex on 19/07/2017.
*/ */
public class TestBasic { public class TestBasic extends BaseDL4JTest {
@Test @Test
@Ignore @Ignore

View File

@ -15,7 +15,8 @@
~ SPDX-License-Identifier: Apache-2.0 ~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent> <parent>
<artifactId>datavec-parent</artifactId> <artifactId>datavec-parent</artifactId>
<groupId>org.datavec</groupId> <groupId>org.datavec</groupId>
@ -79,6 +80,14 @@
<version>${nd4j.version}</version> <version>${nd4j.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>ch.qos.logback</groupId> <groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId> <artifactId>logback-classic</artifactId>

View File

@ -47,9 +47,8 @@ public class RecordConverter {
* *
* @return the array * @return the array
*/ */
@Deprecated public static INDArray toArray(DataType dataType, Collection<Writable> record, int size) {
public static INDArray toArray(Collection<Writable> record, int size) { return toArray(dataType, record);
return toArray(record);
} }
/** /**
@ -78,13 +77,23 @@ public class RecordConverter {
/** /**
* Convert a set of records in to a matrix * Convert a set of records in to a matrix
* As per {@link #toMatrix(DataType, List)} but hardcoded to Float datatype
* @param records the records ot convert * @param records the records ot convert
* @return the matrix for the records * @return the matrix for the records
*/ */
public static INDArray toMatrix(List<List<Writable>> records) { public static INDArray toMatrix(List<List<Writable>> records) {
return toMatrix(DataType.FLOAT, records);
}
/**
* Convert a set of records in to a matrix
* @param records the records ot convert
* @return the matrix for the records
*/
public static INDArray toMatrix(DataType dataType, List<List<Writable>> records) {
List<INDArray> toStack = new ArrayList<>(); List<INDArray> toStack = new ArrayList<>();
for(List<Writable> l : records){ for(List<Writable> l : records){
toStack.add(toArray(l)); toStack.add(toArray(dataType, l));
} }
return Nd4j.vstack(toStack); return Nd4j.vstack(toStack);
@ -92,10 +101,20 @@ public class RecordConverter {
/** /**
* Convert a record to an INDArray. May contain a mix of Writables and row vector NDArrayWritables. * Convert a record to an INDArray. May contain a mix of Writables and row vector NDArrayWritables.
* As per {@link #toArray(DataType, Collection)} but hardcoded to Float datatype
* @param record the record to convert * @param record the record to convert
* @return the array * @return the array
*/ */
public static INDArray toArray(Collection<? extends Writable> record) { public static INDArray toArray(Collection<? extends Writable> record){
return toArray(DataType.FLOAT, record);
}
/**
* Convert a record to an INDArray. May contain a mix of Writables and row vector NDArrayWritables.
* @param record the record to convert
* @return the array
*/
public static INDArray toArray(DataType dataType, Collection<? extends Writable> record) {
List<Writable> l; List<Writable> l;
if(record instanceof List){ if(record instanceof List){
l = (List<Writable>)record; l = (List<Writable>)record;
@ -124,7 +143,7 @@ public class RecordConverter {
} }
} }
INDArray arr = Nd4j.create(1, length); INDArray arr = Nd4j.create(dataType, 1, length);
int k = 0; int k = 0;
for (Writable w : record ) { for (Writable w : record ) {

View File

@ -0,0 +1,57 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.api;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.transform.serde.testClasses.CustomCondition;
import org.datavec.api.transform.serde.testClasses.CustomFilter;
import org.datavec.api.transform.serde.testClasses.CustomTransform;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseND4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
Set<Class<?>> res = new HashSet<>();
res.add(CustomCondition.class);
res.add(CustomFilter.class);
res.add(CustomTransform.class);
return res;
}
@Override
protected String getPackageName() {
return "org.datavec.api";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -25,6 +25,7 @@ import org.datavec.api.writable.Writable;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import java.io.File; import java.io.File;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@ -34,7 +35,7 @@ import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
public class CSVLineSequenceRecordReaderTest { public class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
@Rule @Rule
public TemporaryFolder testDir = new TemporaryFolder(); public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -26,6 +26,8 @@ import org.datavec.api.writable.Writable;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.BaseCompatOp;
import java.io.File; import java.io.File;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@ -37,7 +39,7 @@ import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
public class CSVMultiSequenceRecordReaderTest { public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
@Rule @Rule
public TemporaryFolder testDir = new TemporaryFolder(); public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -24,6 +24,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import java.util.ArrayList; import java.util.ArrayList;
@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals;
/** /**
* Created by Alex on 19/09/2016. * Created by Alex on 19/09/2016.
*/ */
public class CSVNLinesSequenceRecordReaderTest { public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
@Test @Test
public void testCSVNLinesSequenceRecordReader() throws Exception { public void testCSVNLinesSequenceRecordReader() throws Exception {

View File

@ -31,6 +31,7 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import java.io.File; import java.io.File;
@ -44,7 +45,7 @@ import java.util.NoSuchElementException;
import static org.junit.Assert.*; import static org.junit.Assert.*;
public class CSVRecordReaderTest { public class CSVRecordReaderTest extends BaseND4JTest {
@Test @Test
public void testNext() throws Exception { public void testNext() throws Exception {
CSVRecordReader reader = new CSVRecordReader(); CSVRecordReader reader = new CSVRecordReader();

View File

@ -26,6 +26,7 @@ import org.datavec.api.writable.Writable;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import java.io.File; import java.io.File;
@ -39,7 +40,7 @@ import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
public class CSVSequenceRecordReaderTest { public class CSVSequenceRecordReaderTest extends BaseND4JTest {
@Rule @Rule
public TemporaryFolder tempDir = new TemporaryFolder(); public TemporaryFolder tempDir = new TemporaryFolder();

View File

@ -22,6 +22,7 @@ import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordRea
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import java.util.LinkedList; import java.util.LinkedList;
@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals;
* *
* @author Justin Long (crockpotveggies) * @author Justin Long (crockpotveggies)
*/ */
public class CSVVariableSlidingWindowRecordReaderTest { public class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
@Test @Test
public void testCSVVariableSlidingWindowRecordReader() throws Exception { public void testCSVVariableSlidingWindowRecordReader() throws Exception {

View File

@ -27,6 +27,7 @@ import org.datavec.api.writable.Writable;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.nd4j.api.loader.FileBatch; import org.nd4j.api.loader.FileBatch;
import java.io.File; import java.io.File;
@ -36,7 +37,7 @@ import java.util.List;
import static org.junit.Assert.*; import static org.junit.Assert.*;
public class FileBatchRecordReaderTest { public class FileBatchRecordReaderTest extends BaseND4JTest {
@Rule @Rule
public TemporaryFolder testDir = new TemporaryFolder(); public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -23,6 +23,7 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit; import org.datavec.api.split.InputSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import java.net.URI; import java.net.URI;
@ -36,7 +37,7 @@ import static org.junit.Assert.assertFalse;
/** /**
* Created by nyghtowl on 11/14/15. * Created by nyghtowl on 11/14/15.
*/ */
public class FileRecordReaderTest { public class FileRecordReaderTest extends BaseND4JTest {
@Test @Test
public void testReset() throws Exception { public void testReset() throws Exception {

View File

@ -27,6 +27,7 @@ import org.datavec.api.writable.Writable;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
@ -39,7 +40,7 @@ import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
public class JacksonLineRecordReaderTest { public class JacksonLineRecordReaderTest extends BaseND4JTest {
@Rule @Rule
public TemporaryFolder testDir = new TemporaryFolder(); public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -30,6 +30,7 @@ import org.datavec.api.writable.Writable;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
@ -48,7 +49,7 @@ import static org.junit.Assert.assertFalse;
/** /**
* Created by Alex on 11/04/2016. * Created by Alex on 11/04/2016.
*/ */
public class JacksonRecordReaderTest { public class JacksonRecordReaderTest extends BaseND4JTest {
@Rule @Rule
public TemporaryFolder testDir = new TemporaryFolder(); public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -24,6 +24,7 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import java.io.IOException; import java.io.IOException;
@ -44,7 +45,7 @@ import static org.junit.Assert.assertEquals;
* *
* @author dave@skymind.io * @author dave@skymind.io
*/ */
public class LibSvmRecordReaderTest { public class LibSvmRecordReaderTest extends BaseND4JTest {
@Test @Test
public void testBasicRecord() throws IOException, InterruptedException { public void testBasicRecord() throws IOException, InterruptedException {

View File

@ -29,6 +29,7 @@ import org.datavec.api.writable.Writable;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -48,7 +49,7 @@ import static org.junit.Assert.assertEquals;
/** /**
* Created by agibsonccc on 11/17/14. * Created by agibsonccc on 11/17/14.
*/ */
public class LineReaderTest { public class LineReaderTest extends BaseND4JTest {
@Rule @Rule
public TemporaryFolder testDir = new TemporaryFolder(); public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -32,6 +32,7 @@ import org.datavec.api.writable.Writable;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import java.io.File; import java.io.File;
@ -45,7 +46,7 @@ import static org.junit.Assert.assertFalse;
/** /**
* Created by Alex on 12/04/2016. * Created by Alex on 12/04/2016.
*/ */
public class RegexRecordReaderTest { public class RegexRecordReaderTest extends BaseND4JTest {
@Rule @Rule
public TemporaryFolder testDir = new TemporaryFolder(); public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -24,6 +24,7 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import java.io.IOException; import java.io.IOException;
@ -42,7 +43,7 @@ import static org.junit.Assert.assertEquals;
* *
* @author dave@skymind.io * @author dave@skymind.io
*/ */
public class SVMLightRecordReaderTest { public class SVMLightRecordReaderTest extends BaseND4JTest {
@Test @Test
public void testBasicRecord() throws IOException, InterruptedException { public void testBasicRecord() throws IOException, InterruptedException {

View File

@ -23,6 +23,7 @@ import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordRe
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -33,7 +34,7 @@ import static org.junit.Assert.*;
/** /**
* Created by Alex on 21/05/2016. * Created by Alex on 21/05/2016.
*/ */
public class TestCollectionRecordReaders { public class TestCollectionRecordReaders extends BaseND4JTest {
@Test @Test
public void testCollectionSequenceRecordReader() throws Exception { public void testCollectionSequenceRecordReader() throws Exception {

View File

@ -20,11 +20,12 @@ import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
public class TestConcatenatingRecordReader { public class TestConcatenatingRecordReader extends BaseND4JTest {
@Test @Test
public void test() throws Exception { public void test() throws Exception {

View File

@ -34,6 +34,7 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
@ -49,7 +50,7 @@ import static org.junit.Assert.assertEquals;
* Note however that not all are used/usable with spark (such as Collection[Sequence]RecordReader * Note however that not all are used/usable with spark (such as Collection[Sequence]RecordReader
* and the rest are generally used without being initialized on a particular dataset * and the rest are generally used without being initialized on a particular dataset
*/ */
public class TestSerialization { public class TestSerialization extends BaseND4JTest {
@Test @Test
public void testRR() throws Exception { public void testRR() throws Exception {

View File

@ -27,6 +27,7 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import java.util.ArrayList; import java.util.ArrayList;
@ -39,7 +40,7 @@ import static org.junit.Assert.assertTrue;
/** /**
* Created by agibsonccc on 3/21/17. * Created by agibsonccc on 3/21/17.
*/ */
public class TransformProcessRecordReaderTests { public class TransformProcessRecordReaderTests extends BaseND4JTest {
@Test @Test
public void simpleTransformTest() throws Exception { public void simpleTransformTest() throws Exception {

View File

@ -24,6 +24,7 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.io.File; import java.io.File;
import java.util.ArrayList; import java.util.ArrayList;
@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals;
/** /**
* @author raver119@gmail.com * @author raver119@gmail.com
*/ */
public class CSVRecordWriterTest { public class CSVRecordWriterTest extends BaseND4JTest {
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {

View File

@ -27,6 +27,7 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
@ -49,7 +50,7 @@ import static org.junit.Assert.assertEquals;
* *
* @author dave@skymind.io * @author dave@skymind.io
*/ */
public class LibSvmRecordWriterTest { public class LibSvmRecordWriterTest extends BaseND4JTest {
@Test @Test
public void testBasic() throws Exception { public void testBasic() throws Exception {

View File

@ -25,6 +25,7 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.NDArrayWritable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
@ -47,7 +48,7 @@ import static org.junit.Assert.assertEquals;
* *
* @author dave@skymind.io * @author dave@skymind.io
*/ */
public class SVMLightRecordWriterTest { public class SVMLightRecordWriterTest extends BaseND4JTest {
@Test @Test
public void testBasic() throws Exception { public void testBasic() throws Exception {

View File

@ -16,6 +16,7 @@
package org.datavec.api.split; package org.datavec.api.split;
import org.nd4j.BaseND4JTest;
import org.nd4j.shade.guava.io.Files; import org.nd4j.shade.guava.io.Files;
import org.datavec.api.io.filters.BalancedPathFilter; import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.filters.RandomPathFilter; import org.datavec.api.io.filters.RandomPathFilter;
@ -36,7 +37,7 @@ import static org.junit.Assert.assertEquals;
* *
* @author saudet * @author saudet
*/ */
public class InputSplitTests { public class InputSplitTests extends BaseND4JTest {
@Test @Test
public void testSample() throws URISyntaxException { public void testSample() throws URISyntaxException {

View File

@ -17,13 +17,14 @@
package org.datavec.api.split; package org.datavec.api.split;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.net.URI; import java.net.URI;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
public class NumberedFileInputSplitTests { public class NumberedFileInputSplitTests extends BaseND4JTest {
@Test @Test
public void testNumberedFileInputSplitBasic() { public void testNumberedFileInputSplitBasic() {
String baseString = "/path/to/files/prefix%d.suffix"; String baseString = "/path/to/files/prefix%d.suffix";

View File

@ -24,6 +24,7 @@ import org.datavec.api.writable.Writable;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.function.Function; import org.nd4j.linalg.function.Function;
import java.io.File; import java.io.File;
@ -40,7 +41,7 @@ import java.util.Random;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotEquals;
public class TestStreamInputSplit { public class TestStreamInputSplit extends BaseND4JTest {
@Rule @Rule
public TemporaryFolder testDir = new TemporaryFolder(); public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -17,6 +17,7 @@
package org.datavec.api.split; package org.datavec.api.split;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
@ -28,7 +29,7 @@ import static org.junit.Assert.assertArrayEquals;
/** /**
* @author Ede Meijer * @author Ede Meijer
*/ */
public class TransformSplitTest { public class TransformSplitTest extends BaseND4JTest {
@Test @Test
public void testTransform() throws URISyntaxException { public void testTransform() throws URISyntaxException {
Collection<URI> inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv")); Collection<URI> inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv"));

View File

@ -16,6 +16,7 @@
package org.datavec.api.split.parittion; package org.datavec.api.split.parittion;
import org.nd4j.BaseND4JTest;
import org.nd4j.shade.guava.io.Files; import org.nd4j.shade.guava.io.Files;
import org.datavec.api.conf.Configuration; import org.datavec.api.conf.Configuration;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
@ -31,7 +32,7 @@ import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
public class PartitionerTests { public class PartitionerTests extends BaseND4JTest {
@Test @Test
public void testRecordsPerFilePartition() { public void testRecordsPerFilePartition() {
Partitioner partitioner = new NumberOfRecordsPartitioner(); Partitioner partitioner = new NumberOfRecordsPartitioner();

View File

@ -26,12 +26,13 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.*; import java.util.*;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
public class TestTransformProcess { public class TestTransformProcess extends BaseND4JTest {
@Test @Test
public void testExecution(){ public void testExecution(){

View File

@ -24,6 +24,7 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.TestTransforms; import org.datavec.api.transform.transform.TestTransforms;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.*; import java.util.*;
@ -33,7 +34,7 @@ import static org.junit.Assert.assertTrue;
/** /**
* Created by Alex on 24/03/2016. * Created by Alex on 24/03/2016.
*/ */
public class TestConditions { public class TestConditions extends BaseND4JTest {
@Test @Test
public void testIntegerCondition() { public void testIntegerCondition() {

View File

@ -24,6 +24,7 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -37,7 +38,7 @@ import static org.junit.Assert.assertTrue;
/** /**
* Created by Alex on 21/03/2016. * Created by Alex on 21/03/2016.
*/ */
public class TestFilters { public class TestFilters extends BaseND4JTest {
@Test @Test

View File

@ -23,6 +23,7 @@ import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -33,7 +34,7 @@ import static org.junit.Assert.assertEquals;
/** /**
* Created by Alex on 18/04/2016. * Created by Alex on 18/04/2016.
*/ */
public class TestJoin { public class TestJoin extends BaseND4JTest {
@Test @Test
public void testJoin() { public void testJoin() {

View File

@ -18,6 +18,7 @@ package org.datavec.api.transform.ops;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.io.Serializable; import java.io.Serializable;
import java.util.*; import java.util.*;
@ -27,7 +28,7 @@ import static org.junit.Assert.assertTrue;
/** /**
* Created by huitseeker on 5/14/17. * Created by huitseeker on 5/14/17.
*/ */
public class AggregableMultiOpTest { public class AggregableMultiOpTest extends BaseND4JTest {
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));

View File

@ -19,6 +19,7 @@ package org.datavec.api.transform.ops;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -30,7 +31,7 @@ import static org.junit.Assert.assertTrue;
/** /**
* Created by huitseeker on 5/14/17. * Created by huitseeker on 5/14/17.
*/ */
public class AggregatorImplsTest { public class AggregatorImplsTest extends BaseND4JTest {
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance")); private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));

View File

@ -18,6 +18,7 @@ package org.datavec.api.transform.ops;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -29,7 +30,7 @@ import static org.junit.Assert.assertTrue;
/** /**
* Created by huitseeker on 5/14/17. * Created by huitseeker on 5/14/17.
*/ */
public class DispatchOpTest { public class DispatchOpTest extends BaseND4JTest {
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance")); private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));

View File

@ -29,6 +29,7 @@ import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.*; import java.util.*;
@ -38,7 +39,7 @@ import static org.junit.Assert.fail;
/** /**
* Created by Alex on 21/03/2016. * Created by Alex on 21/03/2016.
*/ */
public class TestMultiOpReduce { public class TestMultiOpReduce extends BaseND4JTest {
@Test @Test
public void testMultiOpReducerDouble() { public void testMultiOpReducerDouble() {

View File

@ -21,13 +21,14 @@ import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
public class TestReductions { public class TestReductions extends BaseND4JTest {
@Test @Test
public void testGeographicMidPointReduction(){ public void testGeographicMidPointReduction(){

View File

@ -19,13 +19,14 @@ package org.datavec.api.transform.schema;
import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.ColumnMetaData;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
/** /**
* Created by Alex on 18/07/2016. * Created by Alex on 18/07/2016.
*/ */
public class TestJsonYaml { public class TestJsonYaml extends BaseND4JTest {
@Test @Test
public void testToFromJsonYaml() { public void testToFromJsonYaml() {

View File

@ -18,13 +18,14 @@ package org.datavec.api.transform.schema;
import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.ColumnType;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
/** /**
* Created by Alex on 04/09/2016. * Created by Alex on 04/09/2016.
*/ */
public class TestSchemaMethods { public class TestSchemaMethods extends BaseND4JTest {
@Test @Test
public void testNumberedColumnAdding() { public void testNumberedColumnAdding() {

View File

@ -30,6 +30,7 @@ import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -41,7 +42,7 @@ import static org.junit.Assert.assertEquals;
/** /**
* Created by Alex on 16/04/2016. * Created by Alex on 16/04/2016.
*/ */
public class TestReduceSequenceByWindowFunction { public class TestReduceSequenceByWindowFunction extends BaseND4JTest {
@Test @Test
public void testReduceSequenceByWindowFunction() { public void testReduceSequenceByWindowFunction() {

View File

@ -24,6 +24,7 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -35,7 +36,7 @@ import static org.junit.Assert.assertEquals;
/** /**
* Created by Alex on 19/04/2016. * Created by Alex on 19/04/2016.
*/ */
public class TestSequenceSplit { public class TestSequenceSplit extends BaseND4JTest {
@Test @Test
public void testSequenceSplitTimeSeparation() { public void testSequenceSplitTimeSeparation() {

View File

@ -26,6 +26,7 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -37,7 +38,7 @@ import static org.junit.Assert.assertEquals;
/** /**
* Created by Alex on 16/04/2016. * Created by Alex on 16/04/2016.
*/ */
public class TestWindowFunctions { public class TestWindowFunctions extends BaseND4JTest {
@Test @Test
public void testTimeWindowFunction() { public void testTimeWindowFunction() {

View File

@ -23,13 +23,14 @@ import org.datavec.api.transform.serde.testClasses.CustomCondition;
import org.datavec.api.transform.serde.testClasses.CustomFilter; import org.datavec.api.transform.serde.testClasses.CustomFilter;
import org.datavec.api.transform.serde.testClasses.CustomTransform; import org.datavec.api.transform.serde.testClasses.CustomTransform;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
/** /**
* Created by Alex on 11/01/2017. * Created by Alex on 11/01/2017.
*/ */
public class TestCustomTransformJsonYaml { public class TestCustomTransformJsonYaml extends BaseND4JTest {
@Test @Test
public void testCustomTransform() { public void testCustomTransform() {

View File

@ -61,6 +61,7 @@ import org.datavec.api.writable.comparator.DoubleWritableComparator;
import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.*; import java.util.*;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -70,7 +71,7 @@ import static org.junit.Assert.assertEquals;
/** /**
* Created by Alex on 20/07/2016. * Created by Alex on 20/07/2016.
*/ */
public class TestYamlJsonSerde { public class TestYamlJsonSerde extends BaseND4JTest {
public static YamlSerializer y = new YamlSerializer(); public static YamlSerializer y = new YamlSerializer();
public static JsonSerializer j = new JsonSerializer(); public static JsonSerializer j = new JsonSerializer();

View File

@ -21,6 +21,7 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.*; import java.util.*;
@ -29,7 +30,7 @@ import static org.junit.Assert.assertEquals;
/** /**
* Created by Alex on 21/03/2016. * Created by Alex on 21/03/2016.
*/ */
public class TestReduce { public class TestReduce extends BaseND4JTest {
@Test @Test
public void testReducerDouble() { public void testReducerDouble() {

View File

@ -47,6 +47,7 @@ import org.datavec.api.writable.comparator.LongWritableComparator;
import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import java.io.File; import java.io.File;
@ -58,7 +59,7 @@ import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
public class RegressionTestJson { public class RegressionTestJson extends BaseND4JTest {
@Test @Test
public void regressionTestJson100a() throws Exception { public void regressionTestJson100a() throws Exception {

View File

@ -47,6 +47,7 @@ import org.datavec.api.writable.comparator.LongWritableComparator;
import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.*; import java.util.*;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -56,7 +57,7 @@ import static org.junit.Assert.assertEquals;
/** /**
* Created by Alex on 18/07/2016. * Created by Alex on 18/07/2016.
*/ */
public class TestJsonYaml { public class TestJsonYaml extends BaseND4JTest {
@Test @Test
public void testToFromJsonYaml() { public void testToFromJsonYaml() {

View File

@ -56,6 +56,7 @@ import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -72,7 +73,7 @@ import static org.junit.Assert.*;
/** /**
* Created by Alex on 21/03/2016. * Created by Alex on 21/03/2016.
*/ */
public class TestTransforms { public class TestTransforms extends BaseND4JTest {
public static Schema getSchema(ColumnType type, String... colNames) { public static Schema getSchema(ColumnType type, String... colNames) {

View File

@ -26,6 +26,7 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -39,7 +40,7 @@ import static org.junit.Assert.assertEquals;
/** /**
* Created by Alex on 02/06/2017. * Created by Alex on 02/06/2017.
*/ */
public class TestNDArrayWritableTransforms { public class TestNDArrayWritableTransforms extends BaseND4JTest {
@Test @Test
public void testNDArrayWritableBasic() { public void testNDArrayWritableBasic() {

View File

@ -27,6 +27,7 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.JsonSerializer; import org.datavec.api.transform.serde.JsonSerializer;
import org.datavec.api.transform.serde.YamlSerializer; import org.datavec.api.transform.serde.YamlSerializer;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
@ -36,7 +37,7 @@ import static org.junit.Assert.assertEquals;
/** /**
* Created by Alex on 20/07/2016. * Created by Alex on 20/07/2016.
*/ */
public class TestYamlJsonSerde { public class TestYamlJsonSerde extends BaseND4JTest {
public static YamlSerializer y = new YamlSerializer(); public static YamlSerializer y = new YamlSerializer();
public static JsonSerializer j = new JsonSerializer(); public static JsonSerializer j = new JsonSerializer();

View File

@ -20,6 +20,7 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -30,7 +31,7 @@ import static org.junit.Assert.assertEquals;
/** /**
* Created by agibsonccc on 10/22/16. * Created by agibsonccc on 10/22/16.
*/ */
public class ParseDoubleTransformTest { public class ParseDoubleTransformTest extends BaseND4JTest {
@Test @Test
public void testDoubleTransform() { public void testDoubleTransform() {
List<Writable> record = new ArrayList<>(); List<Writable> record = new ArrayList<>();

View File

@ -35,6 +35,7 @@ import org.junit.Ignore;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import java.io.File; import java.io.File;
import java.util.ArrayList; import java.util.ArrayList;
@ -46,7 +47,7 @@ import static org.junit.Assert.assertEquals;
/** /**
* Created by Alex on 25/03/2016. * Created by Alex on 25/03/2016.
*/ */
public class TestUI { public class TestUI extends BaseND4JTest {
@Rule @Rule
public TemporaryFolder testDir = new TemporaryFolder(); public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -18,6 +18,7 @@ package org.datavec.api.util;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.File; import java.io.File;
@ -33,7 +34,7 @@ import static org.hamcrest.core.IsEqual.equalTo;
/** /**
* @author raver119@gmail.com * @author raver119@gmail.com
*/ */
public class ClassPathResourceTest { public class ClassPathResourceTest extends BaseND4JTest {
private boolean isWindows = false; //File sizes are reported slightly different on Linux vs. Windows private boolean isWindows = false; //File sizes are reported slightly different on Linux vs. Windows

View File

@ -20,6 +20,7 @@ import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList; import java.util.ArrayList;
@ -27,7 +28,7 @@ import java.util.List;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
public class TimeSeriesUtilsTest { public class TimeSeriesUtilsTest extends BaseND4JTest {
@Test @Test
public void testTimeSeriesCreation() { public void testTimeSeriesCreation() {

View File

@ -16,6 +16,7 @@
package org.datavec.api.writable; package org.datavec.api.writable;
import org.nd4j.BaseND4JTest;
import org.nd4j.shade.guava.collect.Lists; import org.nd4j.shade.guava.collect.Lists;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.util.ndarray.RecordConverter;
@ -31,7 +32,7 @@ import java.util.TimeZone;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
public class RecordConverterTest { public class RecordConverterTest extends BaseND4JTest {
@Test @Test
public void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() { public void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() {
INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT); INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT);
@ -86,7 +87,7 @@ public class RecordConverterTest {
new IntWritable(1)); new IntWritable(1));
INDArray exp = Nd4j.create(new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 1}, new long[]{1, 10}, DataType.FLOAT); INDArray exp = Nd4j.create(new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 1}, new long[]{1, 10}, DataType.FLOAT);
INDArray act = RecordConverter.toArray(l); INDArray act = RecordConverter.toArray(DataType.FLOAT, l);
assertEquals(exp, act); assertEquals(exp, act);
} }
@ -101,7 +102,7 @@ public class RecordConverterTest {
{1,2,3,4,5}, {1,2,3,4,5},
{6,7,8,9,10}}).castTo(DataType.FLOAT); {6,7,8,9,10}}).castTo(DataType.FLOAT);
INDArray act = RecordConverter.toMatrix(Arrays.asList(l1,l2)); INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1,l2));
assertEquals(exp, act); assertEquals(exp, act);
} }

View File

@ -18,6 +18,7 @@ package org.datavec.api.writable;
import org.datavec.api.transform.metadata.NDArrayMetaData; import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -28,7 +29,7 @@ import static org.junit.Assert.*;
/** /**
* Created by Alex on 02/06/2017. * Created by Alex on 02/06/2017.
*/ */
public class TestNDArrayWritableAndSerialization { public class TestNDArrayWritableAndSerialization extends BaseND4JTest {
@Test @Test
public void testIsValid() { public void testIsValid() {

View File

@ -18,6 +18,7 @@ package org.datavec.api.writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch; import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -31,9 +32,7 @@ import java.util.List;
import static org.junit.Assert.*; import static org.junit.Assert.*;
public class WritableTest { public class WritableTest extends BaseND4JTest {
@Test @Test
public void testWritableEqualityReflexive() { public void testWritableEqualityReflexive() {

View File

@ -49,6 +49,12 @@
<artifactId>arrow-format</artifactId> <artifactId>arrow-format</artifactId>
<version>${arrow.version}</version> <version>${arrow.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<profiles> <profiles>

View File

@ -40,6 +40,7 @@ import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
@ -56,7 +57,7 @@ import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
public class ArrowConverterTest { public class ArrowConverterTest extends BaseND4JTest {
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.arrow;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.BaseND4JTest;
import org.nd4j.AbstractAssertTestsClass;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseND4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.arrow";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -31,6 +31,7 @@ import org.datavec.api.writable.Writable;
import org.datavec.arrow.recordreader.ArrowRecordReader; import org.datavec.arrow.recordreader.ArrowRecordReader;
import org.datavec.arrow.recordreader.ArrowRecordWriter; import org.datavec.arrow.recordreader.ArrowRecordWriter;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.primitives.Triple; import org.nd4j.linalg.primitives.Triple;
import java.io.File; import java.io.File;
@ -41,7 +42,7 @@ import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
public class RecordMapperTest { public class RecordMapperTest extends BaseND4JTest {
@Test @Test
public void testMultiWrite() throws Exception { public void testMultiWrite() throws Exception {

View File

@ -27,6 +27,7 @@ import org.datavec.api.writable.Writable;
import org.datavec.arrow.ArrowConverter; import org.datavec.arrow.ArrowConverter;
import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -35,7 +36,7 @@ import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
public class ArrowWritableRecordTimeSeriesBatchTests { public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);

View File

@ -57,6 +57,13 @@
<classifier>with-dependencies</classifier> <classifier>with-dependencies</classifier>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<!-- Do not depend on FFmpeg by default due to licensing concerns. --> <!-- Do not depend on FFmpeg by default due to licensing concerns. -->
<!-- <!--
<dependency> <dependency>

View File

@ -0,0 +1,55 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.audio;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
public long getTimeoutMilliseconds() {
return 60000;
}
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.audio";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -24,6 +24,7 @@ import org.datavec.api.writable.Writable;
import org.datavec.audio.recordreader.NativeAudioRecordReader; import org.datavec.audio.recordreader.NativeAudioRecordReader;
import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.io.File; import java.io.File;
import java.nio.ShortBuffer; import java.nio.ShortBuffer;
@ -36,7 +37,7 @@ import static org.junit.Assert.assertTrue;
/** /**
* @author saudet * @author saudet
*/ */
public class AudioReaderTest { public class AudioReaderTest extends BaseND4JTest {
@Ignore @Ignore
@Test @Test
public void testNativeAudioReader() throws Exception { public void testNativeAudioReader() throws Exception {

View File

@ -19,8 +19,9 @@ package org.datavec.audio;
import org.datavec.audio.dsp.FastFourierTransform; import org.datavec.audio.dsp.FastFourierTransform;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.nd4j.BaseND4JTest;
public class TestFastFourierTransform { public class TestFastFourierTransform extends BaseND4JTest {
@Test @Test
public void testFastFourierTransformComplex() { public void testFastFourierTransformComplex() {

View File

@ -44,6 +44,13 @@
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<!-- Do not depend on FFmpeg by default due to licensing concerns. --> <!-- Do not depend on FFmpeg by default due to licensing concerns. -->
<!-- <!--
<dependency> <dependency>

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.codec.reader;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.codec.reader";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -99,6 +99,13 @@
<artifactId>hdf5-platform</artifactId> <artifactId>hdf5-platform</artifactId>
<version>${hdf5.version}-${javacpp-presets.version}</version> <version>${hdf5.version}-${javacpp-presets.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<build> <build>

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.image;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.image";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -60,6 +60,13 @@
<version>${logback.version}</version> <version>${logback.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<profiles> <profiles>

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.nlp;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.nlp";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -15,7 +15,8 @@
~ SPDX-License-Identifier: Apache-2.0 ~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent> <parent>
<artifactId>datavec-parent</artifactId> <artifactId>datavec-parent</artifactId>
<groupId>org.datavec</groupId> <groupId>org.datavec</groupId>
@ -31,6 +32,12 @@
<artifactId>datavec-api</artifactId> <artifactId>datavec-api</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>com.maxmind.geoip2</groupId> <groupId>com.maxmind.geoip2</groupId>
<artifactId>geoip2</artifactId> <artifactId>geoip2</artifactId>

View File

@ -0,0 +1,49 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.api.transform;
import lombok.extern.slf4j.Slf4j;
import java.util.*;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseND4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.api.transform";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -60,6 +60,13 @@
</exclusion> </exclusion>
</exclusions> </exclusions>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<profiles> <profiles>

View File

@ -0,0 +1,48 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.hadoop;
import lombok.extern.slf4j.Slf4j;
import java.util.*;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseND4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.hadoop";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -51,6 +51,13 @@
<artifactId>poi-ooxml</artifactId> <artifactId>poi-ooxml</artifactId>
<version>${poi.version}</version> <version>${poi.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<profiles> <profiles>

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.poi.excel;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseND4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.poi.excel";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -58,6 +58,13 @@
<version>${derby.version}</version> <version>${derby.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<profiles> <profiles>

View File

@ -0,0 +1,49 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.api.records.reader;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseND4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.api.records.reader";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -81,6 +81,13 @@
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<profiles> <profiles>

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.local.transforms;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.local.transforms";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -28,6 +28,7 @@ import org.datavec.local.transforms.AnalyzeLocal;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
@ -63,7 +64,7 @@ public class TestAnalyzeLocal {
list.add(rr.next()); list.add(rr.next());
} }
INDArray arr = RecordConverter.toMatrix(list); INDArray arr = RecordConverter.toMatrix(DataType.DOUBLE, list);
INDArray mean = arr.mean(0); INDArray mean = arr.mean(0);
INDArray std = arr.std(0); INDArray std = arr.std(0);

View File

@ -64,6 +64,13 @@
<artifactId>nd4j-native-api</artifactId> <artifactId>nd4j-native-api</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<profiles> <profiles>

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseND4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.python";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -51,6 +51,13 @@
<artifactId>datavec-spark-inference-model</artifactId> <artifactId>datavec-spark-inference-model</artifactId>
<version>${project.parent.version}</version> <version>${project.parent.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<profiles> <profiles>

View File

@ -0,0 +1,49 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.transform.client;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.transform.client";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -45,6 +45,13 @@
<artifactId>datavec-local</artifactId> <artifactId>datavec-local</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<profiles> <profiles>

View File

@ -33,6 +33,7 @@ import org.datavec.spark.transform.model.Base64NDArrayBody;
import org.datavec.spark.transform.model.BatchCSVRecord; import org.datavec.spark.transform.model.BatchCSVRecord;
import org.datavec.spark.transform.model.SequenceBatchCSVRecord; import org.datavec.spark.transform.model.SequenceBatchCSVRecord;
import org.datavec.spark.transform.model.SingleCSVRecord; import org.datavec.spark.transform.model.SingleCSVRecord;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64; import org.nd4j.serde.base64.Nd4jBase64;
@ -91,7 +92,7 @@ public class CSVSparkTransform {
transformProcess.getInitialSchema(),record.getValues()), transformProcess.getInitialSchema(),record.getValues()),
transformProcess.getInitialSchema()); transformProcess.getInitialSchema());
List<Writable> finalRecord = execute(Arrays.asList(record2),transformProcess).get(0); List<Writable> finalRecord = execute(Arrays.asList(record2),transformProcess).get(0);
INDArray convert = RecordConverter.toArray(finalRecord); INDArray convert = RecordConverter.toArray(DataType.DOUBLE, finalRecord);
return new Base64NDArrayBody(Nd4jBase64.base64String(convert)); return new Base64NDArrayBody(Nd4jBase64.base64String(convert));
} }

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.spark.transform;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.spark.transform";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

Some files were not shown because too many files have changed in this diff Show More