Disallow creating bad classification iterator (#8042)
When doing classification we need to know the `numPossibleLabels`. If it's set to -1, then we get obscure and confusing null-pointers when accessing labels when calling `ComputationGraph.fit` on the iterator. This PR blocks the user from shooting themselves in the foot.master
parent
57efc2bf95
commit
22993f853f
|
@ -160,6 +160,11 @@ public class RecordReaderDataSetIterator implements DataSetIterator {
|
||||||
*/
|
*/
|
||||||
public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndexFrom, int labelIndexTo,
|
public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndexFrom, int labelIndexTo,
|
||||||
boolean regression) {
|
boolean regression) {
|
||||||
|
if (!regression) {
|
||||||
|
throw new IllegalArgumentException("This constructor is only for creating regression iterators. " +
|
||||||
|
"If you're doing classification you need to use another constructor that " +
|
||||||
|
"(implicitly) specifies numPossibleLabels");
|
||||||
|
}
|
||||||
this(recordReader, new SelfWritableConverter(), batchSize, labelIndexFrom, labelIndexTo, -1, -1, regression);
|
this(recordReader, new SelfWritableConverter(), batchSize, labelIndexFrom, labelIndexTo, -1, -1, regression);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue