Fix#8383,getMostPopulatedClusters should sort by cluster's points size desc.
parent
47c58cf69d
commit
c01496ce5d
|
@ -234,7 +234,7 @@ public class ClusterSet implements Serializable {
|
||||||
List<Cluster> mostPopulated = new ArrayList<>(clusters);
|
List<Cluster> mostPopulated = new ArrayList<>(clusters);
|
||||||
Collections.sort(mostPopulated, new Comparator<Cluster>() {
|
Collections.sort(mostPopulated, new Comparator<Cluster>() {
|
||||||
public int compare(Cluster o1, Cluster o2) {
|
public int compare(Cluster o1, Cluster o2) {
|
||||||
return new Integer(o1.getPoints().size()).compareTo(new Integer(o2.getPoints().size()));
|
return Integer.compare(o2.getPoints().size(), o1.getPoints().size());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
return mostPopulated.subList(0, count);
|
return mostPopulated.subList(0, count);
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
package org.deeplearning4j.clustering.cluster;
|
||||||
|
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class ClusterSetTest {
|
||||||
|
@Test
|
||||||
|
public void testGetMostPopulatedClusters() {
|
||||||
|
ClusterSet clusterSet = new ClusterSet(false);
|
||||||
|
List<Cluster> clusters = new ArrayList<>();
|
||||||
|
for (int i = 0; i < 5; i++) {
|
||||||
|
Cluster cluster = new Cluster();
|
||||||
|
cluster.setPoints(Point.toPoints(Nd4j.randn(i + 1, 5)));
|
||||||
|
clusters.add(cluster);
|
||||||
|
}
|
||||||
|
clusterSet.setClusters(clusters);
|
||||||
|
List<Cluster> mostPopulatedClusters = clusterSet.getMostPopulatedClusters(5);
|
||||||
|
for (int i = 0; i < 5; i++) {
|
||||||
|
Assert.assertEquals(5 - i, mostPopulatedClusters.get(i).getPoints().size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue