Disallow creating bad classification iterator (#8042)

When doing classification we need to know the `numPossibleLabels`. If it's set to -1, then we get obscure and confusing null-pointers when accessing labels when calling `ComputationGraph.fit` on the iterator. This PR blocks the user from shooting themselves in the foot.
master
jxtps 2019-07-24 21:51:57 -07:00 committed by Alex Black
parent 57efc2bf95
commit 22993f853f
1 changed files with 5 additions and 0 deletions

View File

@ -160,6 +160,11 @@ public class RecordReaderDataSetIterator implements DataSetIterator {
*/ */
public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndexFrom, int labelIndexTo, public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndexFrom, int labelIndexTo,
boolean regression) { boolean regression) {
if (!regression) {
throw new IllegalArgumentException("This constructor is only for creating regression iterators. " +
"If you're doing classification you need to use another constructor that " +
"(implicitly) specifies numPossibleLabels");
}
this(recordReader, new SelfWritableConverter(), batchSize, labelIndexFrom, labelIndexTo, -1, -1, regression); this(recordReader, new SelfWritableConverter(), batchSize, labelIndexFrom, labelIndexTo, -1, -1, regression);
} }