81 lines
3.4 KiB
Java
Raw Normal View History

2019-06-06 15:21:15 +03:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
 * Sanity checks for {@link RandomDataSetIterator} and {@link RandomMultiDataSetIterator}.
 *
 * <p>Each test drains an iterator configured for five minibatches and verifies, per batch,
 * the array shapes and the value ranges expected for the requested {@code Values} mode,
 * then confirms that exactly five batches were produced.
 */
public class RandomDataSetIteratorTest extends BaseDL4JTest {

    /**
     * Single-input/single-output iterator: features requested as {@code RANDOM_UNIFORM}
     * must lie in [0, 1], and labels requested as {@code ONE_HOT} must sum to one along
     * dimension 1.
     */
    @Test
    public void testDSI() {
        DataSetIterator iter = new RandomDataSetIterator(
                5,
                new long[]{3, 4},
                new long[]{3, 5},
                RandomDataSetIterator.Values.RANDOM_UNIFORM,
                RandomDataSetIterator.Values.ONE_HOT);

        int batches = 0;
        for (; iter.hasNext(); batches++) {
            DataSet batch = iter.next();

            // Shapes must match what was passed to the constructor.
            assertArrayEquals(new long[]{3, 4}, batch.getFeatures().shape());
            assertArrayEquals(new long[]{3, 5}, batch.getLabels().shape());

            // Feature values bounded to [0, 1]; lower bound is checked first.
            double featureMin = batch.getFeatures().minNumber().doubleValue();
            assertTrue(featureMin >= 0.0);
            double featureMax = batch.getFeatures().maxNumber().doubleValue();
            assertTrue(featureMax <= 1.0);

            // Summing the labels along dimension 1 must give a vector of ones.
            assertEquals(Nd4j.ones(3), batch.getLabels().sum(1));
        }

        assertEquals(5, batches);
    }

    /**
     * Multi-input iterator built via the builder: two feature arrays
     * ({@code INTEGER_0_100} and {@code BINARY}) and one label array ({@code ZEROS}).
     * The RNG is seeded so the observed min/max values are reproducible.
     */
    @Test
    public void testMDSI() {
        Nd4j.getRandom().setSeed(12345);

        MultiDataSetIterator iter = new RandomMultiDataSetIterator.Builder(5)
                .addFeatures(new long[]{3, 4}, RandomMultiDataSetIterator.Values.INTEGER_0_100)
                .addFeatures(new long[]{3, 5}, RandomMultiDataSetIterator.Values.BINARY)
                .addLabels(new long[]{3, 6}, RandomMultiDataSetIterator.Values.ZEROS)
                .build();

        int batches = 0;
        for (; iter.hasNext(); batches++) {
            MultiDataSet batch = iter.next();

            // Array counts and shapes must mirror the builder configuration.
            assertEquals(2, batch.numFeatureArrays());
            assertEquals(1, batch.numLabelsArrays());
            assertArrayEquals(new long[]{3, 4}, batch.getFeatures(0).shape());
            assertArrayEquals(new long[]{3, 5}, batch.getFeatures(1).shape());
            assertArrayEquals(new long[]{3, 6}, batch.getLabels(0).shape());

            // First feature array: values within [0, 100], and the maximum must exceed 2
            // (guards against an array that is degenerate, e.g. all zeros/ones).
            double intMin = batch.getFeatures(0).minNumber().doubleValue();
            assertTrue(intMin >= 0);
            double intMax = batch.getFeatures(0).maxNumber().doubleValue();
            assertTrue(intMax <= 100.0);
            assertTrue(intMax > 2.0);

            // Second feature array: both 0 and 1 must be present as the extremes.
            double binaryMin = batch.getFeatures(1).minNumber().doubleValue();
            assertTrue(binaryMin == 0.0);
            double binaryMax = batch.getFeatures(1).maxNumber().doubleValue();
            assertTrue(binaryMax == 1.0);

            // Labels: sum must be exactly zero.
            double labelSum = batch.getLabels(0).sumNumber().doubleValue();
            assertEquals(0.0, labelSum, 0.0);
        }

        assertEquals(5, batches);
    }
}