Alexander Stoyakin 455a5d112d
Fixes for codegen generated classes and build improvements (#367)
* Input format extended

* Deleted redundant code

* Added weights format to conv2d config

* Refactoring

* dl4j base test functionality

* Different tests base class per module

* Check base class for dl4j-graph subproject tests

* Check if test classes extend BaseDL4JTest

* Use nd4j-common-tests as transitive dependency

* Enums and tests added

* Added codegenerated methods

* Use namespace methods

* Replace DifferentialFunctionFactory with codegenerated classes

* Fixed linspace

* Namespaces regenerated

* Namespaces used instead of factory

* Regenerated base classes

* Regenerated namespaces

* Generate nd4j namespaces

* INDArrays accepting constructors

* Generated some ops

* Some fixes

* SameDiff ops regenerated

* Regenerated nd4j ops

* externalErrors moved

* Compilation fixes

* SquaredDifference - strict number of args

* Deprecated code cleanup. Proper base class for tests.

* Extend test classes with BaseND4JTest

* Extend test classes with BaseDL4JTest

* Legacy code

* DL4J cleanup

* Exclude test utils from base class check

* Tests fixed

* Arbiter tests fix

* Test dependency scope fix + pom.xml formatting

Signed-off-by: Alex Black <blacka101@gmail.com>

* Significant number of fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Another round of fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Another round of fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Few additional fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* DataVec missing test scope dependencies

Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: Alex Black <blacka101@gmail.com>
2020-04-20 10:27:13 +10:00
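
For context on the base-test-class bullets above ("Check if test classes extend BaseDL4JTest", "Extend test classes with BaseND4JTest"): the convention being enforced is that every test class in a module extends that module's base test class. A minimal hypothetical sketch, with invented class and method names:

// Hypothetical illustration only - not part of this commit.
// DL4J module tests extend BaseDL4JTest; ND4J module tests extend BaseND4JTest.
public class MyLayerTest extends BaseDL4JTest {
    @Test
    public void testSomething() {
        // assertions on the component under test
    }
}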

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
package org.datavec.spark.transform;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;
import java.util.List;

import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals;

/**
 * Tests comparing DataFrame-based normalization ({@link Normalization}) with the equivalent
 * ND4J dataset normalizers.
 *
 * Created by agibsonccc on 10/22/16.
 */
public class NormalizationTests extends BaseSparkTest {

    @Test
    public void testMeanStdZeros() {
        List<List<Writable>> data = new ArrayList<>();
        Schema.Builder builder = new Schema.Builder();
        int numColumns = 6;
        for (int i = 0; i < numColumns; i++)
            builder.addColumnDouble(String.valueOf(i));

        //Random 5 x numColumns dataset, plus the matching writable records
        Nd4j.getRandom().setSeed(12345);
        INDArray arr = Nd4j.rand(DataType.FLOAT, 5, numColumns);
        for (int i = 0; i < 5; i++) {
            List<Writable> record = new ArrayList<>(numColumns);
            data.add(record);
            for (int j = 0; j < numColumns; j++) {
                record.add(new DoubleWritable(arr.getDouble(i, j)));
            }
        }

        Schema schema = builder.build();
        JavaRDD<List<Writable>> rdd = sc.parallelize(data);
        Dataset<Row> dataFrame = DataFrames.toDataFrame(schema, rdd);

        //assert equivalent to the ndarray pre processing
        DataNormalization zeroToOne = new NormalizerMinMaxScaler();
        zeroToOne.fit(new DataSet(arr.dup(), arr.dup()));
        INDArray zeroToOnes = arr.dup();
        zeroToOne.transform(new DataSet(zeroToOnes, zeroToOnes));

        //stdDevMeanColumns returns two rows: row 0 = per-column std deviation, row 1 = per-column mean
        List<Row> rows = Normalization.stdDevMeanColumns(dataFrame, dataFrame.columns());
        INDArray assertion = DataFrames.toMatrix(rows);

        //compare standard deviation
        INDArray expStd = arr.std(true, true, 0);
        INDArray std = assertion.getRow(0, true);
        assertTrue(expStd.equalsWithEps(std, 1e-3));

        //compare mean
        INDArray expMean = arr.mean(true, 0);
        assertTrue(expMean.equalsWithEps(assertion.getRow(1, true), 1e-3));
    }
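
    //Constant-valued dataset: checks the Schema/record round trip through DataFrames, then verifies
    //that the DataFrame-based Normalization helpers match NormalizerStandardize and
    //NormalizerMinMaxScaler applied directly to the equivalent INDArray.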
    @Test
    public void normalizationTests() {
        List<List<Writable>> data = new ArrayList<>();
        Schema.Builder builder = new Schema.Builder();
        int numColumns = 6;
        for (int i = 0; i < numColumns; i++)
            builder.addColumnDouble(String.valueOf(i));

        for (int i = 0; i < 5; i++) {
            List<Writable> record = new ArrayList<>(numColumns);
            data.add(record);
            for (int j = 0; j < numColumns; j++) {
                record.add(new DoubleWritable(1.0));
            }
        }

        INDArray arr = RecordConverter.toMatrix(DataType.DOUBLE, data);
        Schema schema = builder.build();
        JavaRDD<List<Writable>> rdd = sc.parallelize(data);

        //schema and records should survive the round trip to a DataFrame and back
        assertEquals(schema, DataFrames.fromStructType(DataFrames.fromSchema(schema)));
        assertEquals(rdd.collect(), DataFrames.toRecords(DataFrames.toDataFrame(schema, rdd)).getSecond().collect());

        Dataset<Row> dataFrame = DataFrames.toDataFrame(schema, rdd);
        dataFrame.show();
        Normalization.zeromeanUnitVariance(dataFrame).show();
        Normalization.normalize(dataFrame).show();

        //assert equivalent to the ndarray pre processing
        NormalizerStandardize standardScaler = new NormalizerStandardize();
        standardScaler.fit(new DataSet(arr.dup(), arr.dup()));
        INDArray standardScalered = arr.dup();
        standardScaler.transform(new DataSet(standardScalered, standardScalered));

        DataNormalization zeroToOne = new NormalizerMinMaxScaler();
        zeroToOne.fit(new DataSet(arr.dup(), arr.dup()));
        INDArray zeroToOnes = arr.dup();
        zeroToOne.transform(new DataSet(zeroToOnes, zeroToOnes));

        INDArray zeroMeanUnitVarianceDataFrame =
                        RecordConverter.toMatrix(DataType.DOUBLE, Normalization.zeromeanUnitVariance(schema, rdd).collect());
        INDArray zeroMeanUnitVarianceDataFrameZeroToOne =
                        RecordConverter.toMatrix(DataType.DOUBLE, Normalization.normalize(schema, rdd).collect());

        assertEquals(standardScalered, zeroMeanUnitVarianceDataFrame);
        assertTrue(zeroToOnes.equalsWithEps(zeroMeanUnitVarianceDataFrameZeroToOne, 1e-1));
    }
}