2021-02-01 14:31:20 +09:00
|
|
|
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
|
2019-06-06 15:21:15 +03:00
|
|
|
|
|
|
|
package org.nd4j.linalg.dataset;
|
|
|
|
|
2022-09-20 15:40:53 +02:00
|
|
|
import com.google.common.collect.Lists;
|
|
|
|
import com.google.common.collect.Maps;
|
2019-06-06 15:21:15 +03:00
|
|
|
import lombok.AllArgsConstructor;
|
|
|
|
import lombok.Builder;
|
|
|
|
import lombok.Data;
|
|
|
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|
|
|
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
|
|
|
|
|
|
|
|
import java.io.File;
|
|
|
|
import java.util.ArrayList;
|
|
|
|
import java.util.List;
|
|
|
|
import java.util.Map;
|
|
|
|
|
|
|
|
@AllArgsConstructor
|
|
|
|
@Builder
|
|
|
|
@Data
|
|
|
|
public class BalanceMinibatches {
|
|
|
|
private DataSetIterator dataSetIterator;
|
|
|
|
private int numLabels;
|
2022-10-06 13:22:06 +02:00
|
|
|
@Builder.Default
|
2019-06-06 15:21:15 +03:00
|
|
|
private Map<Integer, List<File>> paths = Maps.newHashMap();
|
2022-10-06 13:22:06 +02:00
|
|
|
@Builder.Default
|
2019-06-06 15:21:15 +03:00
|
|
|
private int miniBatchSize = -1;
|
2022-10-06 13:22:06 +02:00
|
|
|
@Builder.Default
|
2019-06-06 15:21:15 +03:00
|
|
|
private File rootDir = new File("minibatches");
|
2022-10-06 13:22:06 +02:00
|
|
|
@Builder.Default
|
2019-06-06 15:21:15 +03:00
|
|
|
private File rootSaveDir = new File("minibatchessave");
|
2022-10-06 13:22:06 +02:00
|
|
|
@Builder.Default
|
2019-06-06 15:21:15 +03:00
|
|
|
private List<File> labelRootDirs = new ArrayList<>();
|
|
|
|
private DataNormalization dataNormalization;
|
|
|
|
|
|
|
|
/**
|
|
|
|
* Generate a balanced
|
|
|
|
* dataset minibatch fileset.
|
|
|
|
*/
|
|
|
|
public void balance() {
|
|
|
|
if (!rootDir.exists())
|
|
|
|
rootDir.mkdirs();
|
|
|
|
if (!rootSaveDir.exists())
|
|
|
|
rootSaveDir.mkdirs();
|
|
|
|
|
|
|
|
if (paths == null)
|
|
|
|
paths = Maps.newHashMap();
|
|
|
|
if (labelRootDirs == null)
|
|
|
|
labelRootDirs = Lists.newArrayList();
|
|
|
|
|
|
|
|
for (int i = 0; i < numLabels; i++) {
|
|
|
|
paths.put(i, new ArrayList<File>());
|
|
|
|
labelRootDirs.add(new File(rootDir, String.valueOf(i)));
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
//lay out each example in their respective label directories tracking the paths along the way
|
|
|
|
while (dataSetIterator.hasNext()) {
|
|
|
|
DataSet next = dataSetIterator.next();
|
|
|
|
//infer minibatch size from iterator
|
|
|
|
if (miniBatchSize < 0)
|
|
|
|
miniBatchSize = next.numExamples();
|
|
|
|
for (int i = 0; i < next.numExamples(); i++) {
|
|
|
|
DataSet currExample = next.get(i);
|
|
|
|
if (!labelRootDirs.get(currExample.outcome()).exists())
|
|
|
|
labelRootDirs.get(currExample.outcome()).mkdirs();
|
|
|
|
|
|
|
|
//individual example will be saved to: labelrootdir/examples.size()
|
|
|
|
File example = new File(labelRootDirs.get(currExample.outcome()),
|
|
|
|
String.valueOf(paths.get(currExample.outcome()).size()));
|
|
|
|
currExample.save(example);
|
|
|
|
paths.get(currExample.outcome()).add(example);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
int numsSaved = 0;
|
|
|
|
//loop till all file paths have been removed
|
|
|
|
while (!paths.isEmpty()) {
|
|
|
|
List<DataSet> miniBatch = new ArrayList<>();
|
|
|
|
while (miniBatch.size() < miniBatchSize && !paths.isEmpty()) {
|
|
|
|
for (int i = 0; i < numLabels; i++) {
|
|
|
|
if (paths.get(i) != null && !paths.get(i).isEmpty()) {
|
|
|
|
DataSet d = new DataSet();
|
|
|
|
d.load(paths.get(i).remove(0));
|
|
|
|
miniBatch.add(d);
|
|
|
|
} else
|
|
|
|
paths.remove(i);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!rootSaveDir.exists())
|
|
|
|
rootSaveDir.mkdirs();
|
|
|
|
//save with an incremental count of the number of minibatches saved
|
|
|
|
if (!miniBatch.isEmpty()) {
|
|
|
|
DataSet merge = DataSet.merge(miniBatch);
|
|
|
|
if (dataNormalization != null)
|
|
|
|
dataNormalization.transform(merge);
|
|
|
|
merge.save(new File(rootSaveDir, String.format("dataset-%d.bin", numsSaved++)));
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|