Fixes for codegen generated classes and build improvements (#367)

* Input format extended

* Deleted redundant code

* Added weights format to conv2d config

* Refactoring

* dl4j base test functionality

* Different tests base class per module

* Check base class for dl4j-graph subproject tests

* Check if test classes extend BaseDL4JTest

* Use nd4j-common-tests as transient dependency

* Enums and tests added

* Added codegenerated methods

* Use namespace methods

* Replace DifferentialFunctionFactory  with codegenerated classes

* Fixed linspace

* Namespaces regenerated

* Namespaces used instead of factory

* Regenerated base classes

* Input format extended

* Added weights format to conv2d config

* Refactoring

* dl4j base test functionality

* Different tests base class per module

* Check base class for dl4j-graph subproject tests

* Check if test classes extend BaseDL4JTest

* Use nd4j-common-tests as transient dependency

* Enums and tests added

* Added codegenerated methods

* Use namespace methods

* Replace DifferentialFunctionFactory  with codegenerated classes

* Fixed linspace

* Namespaces regenerated

* Regenerated base classes

* Regenerated namespaces

* Generate nd4j namespaces

* INDArrays accepting constructors

* Generated some ops

* Some fixes

* SameDiff ops regenerated

* Regenerated nd4j ops

* externalErrors moved

* Compilation fixes

* SquaredDifference - strict number of args

* Deprecated code cleanup. Proper base class for tests.

* Extend test classes with BaseND4JTest

* Extend test classes with BaseDL4JTest

* Legacy code

* DL4J cleanup

* Exclude test utils from base class check

* Tests fixed

* Arbiter tests fix

* Test dependency scope fix + pom.xml formatting

Signed-off-by: Alex Black <blacka101@gmail.com>

* Significant number of fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Another round of fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Another round of fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Few additional fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* DataVec missing test scope dependencies

Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: Alex Black <blacka101@gmail.com>
master
Alexander Stoyakin 2020-04-20 03:27:13 +03:00 committed by GitHub
parent 73aa760c0f
commit 455a5d112d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
358 changed files with 4531 additions and 3919 deletions

View File

@ -14,7 +14,8 @@
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>arbiter</artifactId>
<groupId>org.deeplearning4j</groupId>
@ -33,10 +34,10 @@
<artifactId>nd4j-api</artifactId>
<version>${nd4j.version}</version>
<exclusions>
<exclusion>
<groupId>com.google.code.findbugs</groupId>
<artifactId>*</artifactId>
</exclusion>
<exclusion>
<groupId>com.google.code.findbugs</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>

View File

@ -0,0 +1,49 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.nd4j.AbstractAssertTestsClass;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.deeplearning4j.arbiter.optimize";
}
@Override
protected Class<?> getBaseClass() {
return BaseDL4JTest.class;
}
}

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.nd4j.AbstractAssertTestsClass;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.deeplearning4j.arbiter";
}
@Override
protected Class<?> getBaseClass() {
return BaseDL4JTest.class;
}
}

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.server;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.nd4j.AbstractAssertTestsClass;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.deeplearning4j.arbiter.server";
}
@Override
protected Class<?> getBaseClass() {
return BaseDL4JTest.class;
}
}

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.server;
import lombok.Data;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory;
@ -27,7 +28,7 @@ import java.io.IOException;
* Created by agibsonccc on 3/13/17.
*/
@Data
public class MnistDataSetIteratorFactory implements DataSetIteratorFactory {
public class MnistDataSetIteratorFactory extends BaseDL4JTest implements DataSetIteratorFactory {
/**
* @return
*/

View File

@ -17,13 +17,14 @@
package org.deeplearning4j.arbiter.server;
import lombok.AllArgsConstructor;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory;
@AllArgsConstructor
public class TestDataFactoryProviderMnist implements DataSetIteratorFactory {
public class TestDataFactoryProviderMnist extends BaseDL4JTest implements DataSetIteratorFactory {
private int batchSize;
private int terminationIter;

View File

@ -54,6 +54,13 @@
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-common-tests</artifactId>
<version>${dl4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.deeplearning4j.BaseDL4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.deeplearning4j.arbiter.optimize";
}
@Override
protected Class<?> getBaseClass() {
return BaseDL4JTest.class;
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.MultiLayerSpace;
@ -70,7 +71,7 @@ import java.util.concurrent.TimeUnit;
/**
* Created by Alex on 19/07/2017.
*/
public class TestBasic {
public class TestBasic extends BaseDL4JTest {
@Test
@Ignore

View File

@ -15,7 +15,8 @@
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>datavec-parent</artifactId>
<groupId>org.datavec</groupId>
@ -79,6 +80,14 @@
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>

View File

@ -47,9 +47,8 @@ public class RecordConverter {
*
* @return the array
*/
@Deprecated
public static INDArray toArray(Collection<Writable> record, int size) {
return toArray(record);
public static INDArray toArray(DataType dataType, Collection<Writable> record, int size) {
return toArray(dataType, record);
}
/**
@ -78,13 +77,23 @@ public class RecordConverter {
/**
* Convert a set of records in to a matrix
* As per {@link #toMatrix(DataType, List)} but hardcoded to Float datatype
* @param records the records ot convert
* @return the matrix for the records
*/
public static INDArray toMatrix(List<List<Writable>> records) {
return toMatrix(DataType.FLOAT, records);
}
/**
* Convert a set of records in to a matrix
* @param records the records ot convert
* @return the matrix for the records
*/
public static INDArray toMatrix(DataType dataType, List<List<Writable>> records) {
List<INDArray> toStack = new ArrayList<>();
for(List<Writable> l : records){
toStack.add(toArray(l));
toStack.add(toArray(dataType, l));
}
return Nd4j.vstack(toStack);
@ -92,10 +101,20 @@ public class RecordConverter {
/**
* Convert a record to an INDArray. May contain a mix of Writables and row vector NDArrayWritables.
* As per {@link #toArray(DataType, Collection)} but hardcoded to Float datatype
* @param record the record to convert
* @return the array
*/
public static INDArray toArray(Collection<? extends Writable> record) {
public static INDArray toArray(Collection<? extends Writable> record){
return toArray(DataType.FLOAT, record);
}
/**
* Convert a record to an INDArray. May contain a mix of Writables and row vector NDArrayWritables.
* @param record the record to convert
* @return the array
*/
public static INDArray toArray(DataType dataType, Collection<? extends Writable> record) {
List<Writable> l;
if(record instanceof List){
l = (List<Writable>)record;
@ -124,7 +143,7 @@ public class RecordConverter {
}
}
INDArray arr = Nd4j.create(1, length);
INDArray arr = Nd4j.create(dataType, 1, length);
int k = 0;
for (Writable w : record ) {

View File

@ -0,0 +1,57 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.api;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.transform.serde.testClasses.CustomCondition;
import org.datavec.api.transform.serde.testClasses.CustomFilter;
import org.datavec.api.transform.serde.testClasses.CustomTransform;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseND4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
Set<Class<?>> res = new HashSet<>();
res.add(CustomCondition.class);
res.add(CustomFilter.class);
res.add(CustomTransform.class);
return res;
}
@Override
protected String getPackageName() {
return "org.datavec.api";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -25,6 +25,7 @@ import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import java.io.File;
import java.nio.charset.StandardCharsets;
@ -34,7 +35,7 @@ import java.util.List;
import static org.junit.Assert.assertEquals;
public class CSVLineSequenceRecordReaderTest {
public class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -26,6 +26,8 @@ import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.BaseCompatOp;
import java.io.File;
import java.nio.charset.StandardCharsets;
@ -37,7 +39,7 @@ import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
public class CSVMultiSequenceRecordReaderTest {
public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -24,6 +24,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource;
import java.util.ArrayList;
@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 19/09/2016.
*/
public class CSVNLinesSequenceRecordReaderTest {
public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
@Test
public void testCSVNLinesSequenceRecordReader() throws Exception {

View File

@ -31,6 +31,7 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource;
import java.io.File;
@ -44,7 +45,7 @@ import java.util.NoSuchElementException;
import static org.junit.Assert.*;
public class CSVRecordReaderTest {
public class CSVRecordReaderTest extends BaseND4JTest {
@Test
public void testNext() throws Exception {
CSVRecordReader reader = new CSVRecordReader();

View File

@ -26,6 +26,7 @@ import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource;
import java.io.File;
@ -39,7 +40,7 @@ import java.util.List;
import static org.junit.Assert.assertEquals;
public class CSVSequenceRecordReaderTest {
public class CSVSequenceRecordReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder tempDir = new TemporaryFolder();

View File

@ -22,6 +22,7 @@ import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordRea
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource;
import java.util.LinkedList;
@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals;
*
* @author Justin Long (crockpotveggies)
*/
public class CSVVariableSlidingWindowRecordReaderTest {
public class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
@Test
public void testCSVVariableSlidingWindowRecordReader() throws Exception {

View File

@ -27,6 +27,7 @@ import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.nd4j.api.loader.FileBatch;
import java.io.File;
@ -36,7 +37,7 @@ import java.util.List;
import static org.junit.Assert.*;
public class FileBatchRecordReaderTest {
public class FileBatchRecordReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -23,6 +23,7 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource;
import java.net.URI;
@ -36,7 +37,7 @@ import static org.junit.Assert.assertFalse;
/**
* Created by nyghtowl on 11/14/15.
*/
public class FileRecordReaderTest {
public class FileRecordReaderTest extends BaseND4JTest {
@Test
public void testReset() throws Exception {

View File

@ -27,6 +27,7 @@ import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper;
@ -39,7 +40,7 @@ import java.util.List;
import static org.junit.Assert.assertEquals;
public class JacksonLineRecordReaderTest {
public class JacksonLineRecordReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -30,6 +30,7 @@ import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper;
@ -48,7 +49,7 @@ import static org.junit.Assert.assertFalse;
/**
* Created by Alex on 11/04/2016.
*/
public class JacksonRecordReaderTest {
public class JacksonRecordReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -24,6 +24,7 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource;
import java.io.IOException;
@ -44,7 +45,7 @@ import static org.junit.Assert.assertEquals;
*
* @author dave@skymind.io
*/
public class LibSvmRecordReaderTest {
public class LibSvmRecordReaderTest extends BaseND4JTest {
@Test
public void testBasicRecord() throws IOException, InterruptedException {

View File

@ -29,6 +29,7 @@ import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -48,7 +49,7 @@ import static org.junit.Assert.assertEquals;
/**
* Created by agibsonccc on 11/17/14.
*/
public class LineReaderTest {
public class LineReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -32,6 +32,7 @@ import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource;
import java.io.File;
@ -45,7 +46,7 @@ import static org.junit.Assert.assertFalse;
/**
* Created by Alex on 12/04/2016.
*/
public class RegexRecordReaderTest {
public class RegexRecordReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -24,6 +24,7 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource;
import java.io.IOException;
@ -42,7 +43,7 @@ import static org.junit.Assert.assertEquals;
*
* @author dave@skymind.io
*/
public class SVMLightRecordReaderTest {
public class SVMLightRecordReaderTest extends BaseND4JTest {
@Test
public void testBasicRecord() throws IOException, InterruptedException {

View File

@ -23,6 +23,7 @@ import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordRe
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
@ -33,7 +34,7 @@ import static org.junit.Assert.*;
/**
* Created by Alex on 21/05/2016.
*/
public class TestCollectionRecordReaders {
public class TestCollectionRecordReaders extends BaseND4JTest {
@Test
public void testCollectionSequenceRecordReader() throws Exception {

View File

@ -20,11 +20,12 @@ import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource;
import static org.junit.Assert.assertEquals;
public class TestConcatenatingRecordReader {
public class TestConcatenatingRecordReader extends BaseND4JTest {
@Test
public void test() throws Exception {

View File

@ -34,6 +34,7 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper;
@ -49,7 +50,7 @@ import static org.junit.Assert.assertEquals;
* Note however that not all are used/usable with spark (such as Collection[Sequence]RecordReader
* and the rest are generally used without being initialized on a particular dataset
*/
public class TestSerialization {
public class TestSerialization extends BaseND4JTest {
@Test
public void testRR() throws Exception {

View File

@ -27,6 +27,7 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource;
import java.util.ArrayList;
@ -39,7 +40,7 @@ import static org.junit.Assert.assertTrue;
/**
* Created by agibsonccc on 3/21/17.
*/
public class TransformProcessRecordReaderTests {
public class TransformProcessRecordReaderTests extends BaseND4JTest {
@Test
public void simpleTransformTest() throws Exception {

View File

@ -24,6 +24,7 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.io.File;
import java.util.ArrayList;
@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals;
/**
* @author raver119@gmail.com
*/
public class CSVRecordWriterTest {
public class CSVRecordWriterTest extends BaseND4JTest {
@Before
public void setUp() throws Exception {

View File

@ -27,6 +27,7 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
@ -49,7 +50,7 @@ import static org.junit.Assert.assertEquals;
*
* @author dave@skymind.io
*/
public class LibSvmRecordWriterTest {
public class LibSvmRecordWriterTest extends BaseND4JTest {
@Test
public void testBasic() throws Exception {

View File

@ -25,6 +25,7 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.*;
import org.datavec.api.writable.NDArrayWritable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
@ -47,7 +48,7 @@ import static org.junit.Assert.assertEquals;
*
* @author dave@skymind.io
*/
public class SVMLightRecordWriterTest {
public class SVMLightRecordWriterTest extends BaseND4JTest {
@Test
public void testBasic() throws Exception {

View File

@ -16,6 +16,7 @@
package org.datavec.api.split;
import org.nd4j.BaseND4JTest;
import org.nd4j.shade.guava.io.Files;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.filters.RandomPathFilter;
@ -36,7 +37,7 @@ import static org.junit.Assert.assertEquals;
*
* @author saudet
*/
public class InputSplitTests {
public class InputSplitTests extends BaseND4JTest {
@Test
public void testSample() throws URISyntaxException {

View File

@ -17,13 +17,14 @@
package org.datavec.api.split;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.net.URI;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class NumberedFileInputSplitTests {
public class NumberedFileInputSplitTests extends BaseND4JTest {
@Test
public void testNumberedFileInputSplitBasic() {
String baseString = "/path/to/files/prefix%d.suffix";

View File

@ -24,6 +24,7 @@ import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.function.Function;
import java.io.File;
@ -40,7 +41,7 @@ import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
public class TestStreamInputSplit {
public class TestStreamInputSplit extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -17,6 +17,7 @@
package org.datavec.api.split;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.net.URI;
import java.net.URISyntaxException;
@ -28,7 +29,7 @@ import static org.junit.Assert.assertArrayEquals;
/**
* @author Ede Meijer
*/
public class TransformSplitTest {
public class TransformSplitTest extends BaseND4JTest {
@Test
public void testTransform() throws URISyntaxException {
Collection<URI> inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv"));

View File

@ -16,6 +16,7 @@
package org.datavec.api.split.parittion;
import org.nd4j.BaseND4JTest;
import org.nd4j.shade.guava.io.Files;
import org.datavec.api.conf.Configuration;
import org.datavec.api.split.FileSplit;
@ -31,7 +32,7 @@ import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
public class PartitionerTests {
public class PartitionerTests extends BaseND4JTest {
@Test
public void testRecordsPerFilePartition() {
Partitioner partitioner = new NumberOfRecordsPartitioner();

View File

@ -26,12 +26,13 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.*;
import static org.junit.Assert.assertEquals;
public class TestTransformProcess {
public class TestTransformProcess extends BaseND4JTest {
@Test
public void testExecution(){

View File

@ -24,6 +24,7 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.TestTransforms;
import org.datavec.api.writable.*;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.*;
@ -33,7 +34,7 @@ import static org.junit.Assert.assertTrue;
/**
* Created by Alex on 24/03/2016.
*/
public class TestConditions {
public class TestConditions extends BaseND4JTest {
@Test
public void testIntegerCondition() {

View File

@ -24,6 +24,7 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
@ -37,7 +38,7 @@ import static org.junit.Assert.assertTrue;
/**
* Created by Alex on 21/03/2016.
*/
public class TestFilters {
public class TestFilters extends BaseND4JTest {
@Test

View File

@ -23,6 +23,7 @@ import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
@ -33,7 +34,7 @@ import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 18/04/2016.
*/
public class TestJoin {
public class TestJoin extends BaseND4JTest {
@Test
public void testJoin() {

View File

@ -18,6 +18,7 @@ package org.datavec.api.transform.ops;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.io.Serializable;
import java.util.*;
@ -27,7 +28,7 @@ import static org.junit.Assert.assertTrue;
/**
* Created by huitseeker on 5/14/17.
*/
public class AggregableMultiOpTest {
public class AggregableMultiOpTest extends BaseND4JTest {
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));

View File

@ -19,6 +19,7 @@ package org.datavec.api.transform.ops;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
@ -30,7 +31,7 @@ import static org.junit.Assert.assertTrue;
/**
* Created by huitseeker on 5/14/17.
*/
public class AggregatorImplsTest {
public class AggregatorImplsTest extends BaseND4JTest {
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));

View File

@ -18,6 +18,7 @@ package org.datavec.api.transform.ops;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
@ -29,7 +30,7 @@ import static org.junit.Assert.assertTrue;
/**
* Created by huitseeker on 5/14/17.
*/
public class DispatchOpTest {
public class DispatchOpTest extends BaseND4JTest {
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));

View File

@ -29,6 +29,7 @@ import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.*;
@ -38,7 +39,7 @@ import static org.junit.Assert.fail;
/**
* Created by Alex on 21/03/2016.
*/
public class TestMultiOpReduce {
public class TestMultiOpReduce extends BaseND4JTest {
@Test
public void testMultiOpReducerDouble() {

View File

@ -21,13 +21,14 @@ import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
public class TestReductions {
public class TestReductions extends BaseND4JTest {
@Test
public void testGeographicMidPointReduction(){

View File

@ -19,13 +19,14 @@ package org.datavec.api.transform.schema;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 18/07/2016.
*/
public class TestJsonYaml {
public class TestJsonYaml extends BaseND4JTest {
@Test
public void testToFromJsonYaml() {

View File

@ -18,13 +18,14 @@ package org.datavec.api.transform.schema;
import org.datavec.api.transform.ColumnType;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 04/09/2016.
*/
public class TestSchemaMethods {
public class TestSchemaMethods extends BaseND4JTest {
@Test
public void testNumberedColumnAdding() {

View File

@ -30,6 +30,7 @@ import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
@ -41,7 +42,7 @@ import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 16/04/2016.
*/
public class TestReduceSequenceByWindowFunction {
public class TestReduceSequenceByWindowFunction extends BaseND4JTest {
@Test
public void testReduceSequenceByWindowFunction() {

View File

@ -24,6 +24,7 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
@ -35,7 +36,7 @@ import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 19/04/2016.
*/
public class TestSequenceSplit {
public class TestSequenceSplit extends BaseND4JTest {
@Test
public void testSequenceSplitTimeSeparation() {

View File

@ -26,6 +26,7 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
@ -37,7 +38,7 @@ import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 16/04/2016.
*/
public class TestWindowFunctions {
public class TestWindowFunctions extends BaseND4JTest {
@Test
public void testTimeWindowFunction() {

View File

@ -23,13 +23,14 @@ import org.datavec.api.transform.serde.testClasses.CustomCondition;
import org.datavec.api.transform.serde.testClasses.CustomFilter;
import org.datavec.api.transform.serde.testClasses.CustomTransform;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 11/01/2017.
*/
public class TestCustomTransformJsonYaml {
public class TestCustomTransformJsonYaml extends BaseND4JTest {
@Test
public void testCustomTransform() {

View File

@ -61,6 +61,7 @@ import org.datavec.api.writable.comparator.DoubleWritableComparator;
import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.*;
import java.util.concurrent.TimeUnit;
@ -70,7 +71,7 @@ import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 20/07/2016.
*/
public class TestYamlJsonSerde {
public class TestYamlJsonSerde extends BaseND4JTest {
public static YamlSerializer y = new YamlSerializer();
public static JsonSerializer j = new JsonSerializer();

View File

@ -21,6 +21,7 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.*;
@ -29,7 +30,7 @@ import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 21/03/2016.
*/
public class TestReduce {
public class TestReduce extends BaseND4JTest {
@Test
public void testReducerDouble() {

View File

@ -47,6 +47,7 @@ import org.datavec.api.writable.comparator.LongWritableComparator;
import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.io.ClassPathResource;
import java.io.File;
@ -58,7 +59,7 @@ import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
public class RegressionTestJson {
public class RegressionTestJson extends BaseND4JTest {
@Test
public void regressionTestJson100a() throws Exception {

View File

@ -47,6 +47,7 @@ import org.datavec.api.writable.comparator.LongWritableComparator;
import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.*;
import java.util.concurrent.TimeUnit;
@ -56,7 +57,7 @@ import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 18/07/2016.
*/
public class TestJsonYaml {
public class TestJsonYaml extends BaseND4JTest {
@Test
public void testToFromJsonYaml() {

View File

@ -56,6 +56,7 @@ import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -72,7 +73,7 @@ import static org.junit.Assert.*;
/**
* Created by Alex on 21/03/2016.
*/
public class TestTransforms {
public class TestTransforms extends BaseND4JTest {
public static Schema getSchema(ColumnType type, String... colNames) {

View File

@ -26,6 +26,7 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -39,7 +40,7 @@ import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 02/06/2017.
*/
public class TestNDArrayWritableTransforms {
public class TestNDArrayWritableTransforms extends BaseND4JTest {
@Test
public void testNDArrayWritableBasic() {

View File

@ -27,6 +27,7 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.JsonSerializer;
import org.datavec.api.transform.serde.YamlSerializer;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.Arrays;
import java.util.List;
@ -36,7 +37,7 @@ import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 20/07/2016.
*/
public class TestYamlJsonSerde {
public class TestYamlJsonSerde extends BaseND4JTest {
public static YamlSerializer y = new YamlSerializer();
public static JsonSerializer j = new JsonSerializer();

View File

@ -20,6 +20,7 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
@ -30,7 +31,7 @@ import static org.junit.Assert.assertEquals;
/**
* Created by agibsonccc on 10/22/16.
*/
public class ParseDoubleTransformTest {
public class ParseDoubleTransformTest extends BaseND4JTest {
@Test
public void testDoubleTransform() {
List<Writable> record = new ArrayList<>();

View File

@ -35,6 +35,7 @@ import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import java.io.File;
import java.util.ArrayList;
@ -46,7 +47,7 @@ import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 25/03/2016.
*/
public class TestUI {
public class TestUI extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -18,6 +18,7 @@ package org.datavec.api.util;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.io.BufferedReader;
import java.io.File;
@ -33,7 +34,7 @@ import static org.hamcrest.core.IsEqual.equalTo;
/**
* @author raver119@gmail.com
*/
public class ClassPathResourceTest {
public class ClassPathResourceTest extends BaseND4JTest {
private boolean isWindows = false; //File sizes are reported slightly different on Linux vs. Windows

View File

@ -20,6 +20,7 @@ import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList;
@ -27,7 +28,7 @@ import java.util.List;
import static org.junit.Assert.assertArrayEquals;
public class TimeSeriesUtilsTest {
public class TimeSeriesUtilsTest extends BaseND4JTest {
@Test
public void testTimeSeriesCreation() {

View File

@ -16,6 +16,7 @@
package org.datavec.api.writable;
import org.nd4j.BaseND4JTest;
import org.nd4j.shade.guava.collect.Lists;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter;
@ -31,7 +32,7 @@ import java.util.TimeZone;
import static org.junit.Assert.assertEquals;
public class RecordConverterTest {
public class RecordConverterTest extends BaseND4JTest {
@Test
public void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() {
INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT);
@ -86,7 +87,7 @@ public class RecordConverterTest {
new IntWritable(1));
INDArray exp = Nd4j.create(new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 1}, new long[]{1, 10}, DataType.FLOAT);
INDArray act = RecordConverter.toArray(l);
INDArray act = RecordConverter.toArray(DataType.FLOAT, l);
assertEquals(exp, act);
}
@ -101,7 +102,7 @@ public class RecordConverterTest {
{1,2,3,4,5},
{6,7,8,9,10}}).castTo(DataType.FLOAT);
INDArray act = RecordConverter.toMatrix(Arrays.asList(l1,l2));
INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1,l2));
assertEquals(exp, act);
}

View File

@ -18,6 +18,7 @@ package org.datavec.api.writable;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -28,7 +29,7 @@ import static org.junit.Assert.*;
/**
* Created by Alex on 02/06/2017.
*/
public class TestNDArrayWritableAndSerialization {
public class TestNDArrayWritableAndSerialization extends BaseND4JTest {
@Test
public void testIsValid() {

View File

@ -18,6 +18,7 @@ package org.datavec.api.writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -31,9 +32,7 @@ import java.util.List;
import static org.junit.Assert.*;
public class WritableTest {
public class WritableTest extends BaseND4JTest {
@Test
public void testWritableEqualityReflexive() {

View File

@ -49,6 +49,12 @@
<artifactId>arrow-format</artifactId>
<version>${arrow.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -40,6 +40,7 @@ import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
@ -56,7 +57,7 @@ import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
public class ArrowConverterTest {
public class ArrowConverterTest extends BaseND4JTest {
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.arrow;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.BaseND4JTest;
import org.nd4j.AbstractAssertTestsClass;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseND4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.arrow";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -31,6 +31,7 @@ import org.datavec.api.writable.Writable;
import org.datavec.arrow.recordreader.ArrowRecordReader;
import org.datavec.arrow.recordreader.ArrowRecordWriter;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.primitives.Triple;
import java.io.File;
@ -41,7 +42,7 @@ import java.util.List;
import static org.junit.Assert.assertEquals;
public class RecordMapperTest {
public class RecordMapperTest extends BaseND4JTest {
@Test
public void testMultiWrite() throws Exception {

View File

@ -27,6 +27,7 @@ import org.datavec.api.writable.Writable;
import org.datavec.arrow.ArrowConverter;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
@ -35,7 +36,7 @@ import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
public class ArrowWritableRecordTimeSeriesBatchTests {
public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);

View File

@ -57,6 +57,13 @@
<classifier>with-dependencies</classifier>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<!-- Do not depend on FFmpeg by default due to licensing concerns. -->
<!--
<dependency>

View File

@ -0,0 +1,55 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.audio;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
public long getTimeoutMilliseconds() {
return 60000;
}
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.audio";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -24,6 +24,7 @@ import org.datavec.api.writable.Writable;
import org.datavec.audio.recordreader.NativeAudioRecordReader;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import java.io.File;
import java.nio.ShortBuffer;
@ -36,7 +37,7 @@ import static org.junit.Assert.assertTrue;
/**
* @author saudet
*/
public class AudioReaderTest {
public class AudioReaderTest extends BaseND4JTest {
@Ignore
@Test
public void testNativeAudioReader() throws Exception {

View File

@ -19,8 +19,9 @@ package org.datavec.audio;
import org.datavec.audio.dsp.FastFourierTransform;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
public class TestFastFourierTransform {
public class TestFastFourierTransform extends BaseND4JTest {
@Test
public void testFastFourierTransformComplex() {

View File

@ -44,6 +44,13 @@
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<!-- Do not depend on FFmpeg by default due to licensing concerns. -->
<!--
<dependency>

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.codec.reader;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.codec.reader";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -99,6 +99,13 @@
<artifactId>hdf5-platform</artifactId>
<version>${hdf5.version}-${javacpp-presets.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.image;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.image";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -60,6 +60,13 @@
<version>${logback.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.nlp;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.nlp";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -15,7 +15,8 @@
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>datavec-parent</artifactId>
<groupId>org.datavec</groupId>
@ -31,6 +32,12 @@
<artifactId>datavec-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.maxmind.geoip2</groupId>
<artifactId>geoip2</artifactId>

View File

@ -0,0 +1,49 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.api.transform;
import lombok.extern.slf4j.Slf4j;
import java.util.*;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseND4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.api.transform";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -60,6 +60,13 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -0,0 +1,48 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.hadoop;
import lombok.extern.slf4j.Slf4j;
import java.util.*;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseND4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.hadoop";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -51,6 +51,13 @@
<artifactId>poi-ooxml</artifactId>
<version>${poi.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.poi.excel;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseND4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.poi.excel";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -58,6 +58,13 @@
<version>${derby.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -0,0 +1,49 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.api.records.reader;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseND4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.api.records.reader";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -81,6 +81,13 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.local.transforms;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.local.transforms";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -28,6 +28,7 @@ import org.datavec.local.transforms.AnalyzeLocal;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.io.ClassPathResource;
@ -63,7 +64,7 @@ public class TestAnalyzeLocal {
list.add(rr.next());
}
INDArray arr = RecordConverter.toMatrix(list);
INDArray arr = RecordConverter.toMatrix(DataType.DOUBLE, list);
INDArray mean = arr.mean(0);
INDArray std = arr.std(0);

View File

@ -64,6 +64,13 @@
<artifactId>nd4j-native-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseND4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.python";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -51,6 +51,13 @@
<artifactId>datavec-spark-inference-model</artifactId>
<version>${project.parent.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -0,0 +1,49 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.transform.client;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.transform.client";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -45,6 +45,13 @@
<artifactId>datavec-local</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -33,6 +33,7 @@ import org.datavec.spark.transform.model.Base64NDArrayBody;
import org.datavec.spark.transform.model.BatchCSVRecord;
import org.datavec.spark.transform.model.SequenceBatchCSVRecord;
import org.datavec.spark.transform.model.SingleCSVRecord;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;
@ -91,7 +92,7 @@ public class CSVSparkTransform {
transformProcess.getInitialSchema(),record.getValues()),
transformProcess.getInitialSchema());
List<Writable> finalRecord = execute(Arrays.asList(record2),transformProcess).get(0);
INDArray convert = RecordConverter.toArray(finalRecord);
INDArray convert = RecordConverter.toArray(DataType.DOUBLE, finalRecord);
return new Base64NDArrayBody(Nd4jBase64.base64String(convert));
}

View File

@ -0,0 +1,50 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.spark.transform;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.AbstractAssertTestsClass;
import org.nd4j.BaseND4JTest;
import java.util.*;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.spark.transform";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

Some files were not shown because too many files have changed in this diff Show More