81 lines
3.4 KiB
Java
81 lines
3.4 KiB
Java
|
|
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
|
||
|
|
|
||
|
|
package org.deeplearning4j.datasets.iterator;
|
||
|
|
|
||
|
|
import org.deeplearning4j.BaseDL4JTest;
|
||
|
|
import org.junit.Test;
|
||
|
|
import org.nd4j.linalg.dataset.DataSet;
|
||
|
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||
|
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||
|
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||
|
|
import org.nd4j.linalg.factory.Nd4j;
|
||
|
|
|
||
|
|
import static org.junit.Assert.assertArrayEquals;
|
||
|
|
import static org.junit.Assert.assertEquals;
|
||
|
|
import static org.junit.Assert.assertTrue;
|
||
|
|
|
||
|
|
public class RandomDataSetIteratorTest extends BaseDL4JTest {
|
||
|
|
|
||
|
|
@Test
|
||
|
|
public void testDSI(){
|
||
|
|
DataSetIterator iter = new RandomDataSetIterator(5, new long[]{3,4}, new long[]{3,5}, RandomDataSetIterator.Values.RANDOM_UNIFORM,
|
||
|
|
RandomDataSetIterator.Values.ONE_HOT);
|
||
|
|
|
||
|
|
int count = 0;
|
||
|
|
while(iter.hasNext()){
|
||
|
|
count++;
|
||
|
|
DataSet ds = iter.next();
|
||
|
|
|
||
|
|
assertArrayEquals(new long[]{3,4}, ds.getFeatures().shape());
|
||
|
|
assertArrayEquals(new long[]{3,5}, ds.getLabels().shape());
|
||
|
|
|
||
|
|
assertTrue(ds.getFeatures().minNumber().doubleValue() >= 0.0 && ds.getFeatures().maxNumber().doubleValue() <= 1.0);
|
||
|
|
assertEquals(Nd4j.ones(3), ds.getLabels().sum(1));
|
||
|
|
}
|
||
|
|
assertEquals(5, count);
|
||
|
|
}
|
||
|
|
|
||
|
|
@Test
|
||
|
|
public void testMDSI(){
|
||
|
|
Nd4j.getRandom().setSeed(12345);
|
||
|
|
MultiDataSetIterator iter = new RandomMultiDataSetIterator.Builder(5)
|
||
|
|
.addFeatures(new long[]{3,4}, RandomMultiDataSetIterator.Values.INTEGER_0_100)
|
||
|
|
.addFeatures(new long[]{3,5}, RandomMultiDataSetIterator.Values.BINARY)
|
||
|
|
.addLabels(new long[]{3,6}, RandomMultiDataSetIterator.Values.ZEROS)
|
||
|
|
.build();
|
||
|
|
|
||
|
|
int count = 0;
|
||
|
|
while(iter.hasNext()){
|
||
|
|
count++;
|
||
|
|
MultiDataSet mds = iter.next();
|
||
|
|
|
||
|
|
assertEquals(2, mds.numFeatureArrays());
|
||
|
|
assertEquals(1, mds.numLabelsArrays());
|
||
|
|
assertArrayEquals(new long[]{3,4}, mds.getFeatures(0).shape());
|
||
|
|
assertArrayEquals(new long[]{3,5}, mds.getFeatures(1).shape());
|
||
|
|
assertArrayEquals(new long[]{3,6}, mds.getLabels(0).shape());
|
||
|
|
|
||
|
|
assertTrue(mds.getFeatures(0).minNumber().doubleValue() >= 0 && mds.getFeatures(0).maxNumber().doubleValue() <= 100.0
|
||
|
|
&& mds.getFeatures(0).maxNumber().doubleValue() > 2.0);
|
||
|
|
assertTrue(mds.getFeatures(1).minNumber().doubleValue() == 0.0 && mds.getFeatures(1).maxNumber().doubleValue() == 1.0);
|
||
|
|
assertEquals(0.0, mds.getLabels(0).sumNumber().doubleValue(), 0.0);
|
||
|
|
}
|
||
|
|
assertEquals(5, count);
|
||
|
|
}
|
||
|
|
|
||
|
|
}
|