Fix#8383,getMostPopulatedClusters should sort by cluster's points size desc.

master
CT 2019-11-27 22:35:55 +08:00
parent 47c58cf69d
commit c01496ce5d
2 changed files with 27 additions and 1 deletions

View File

@ -234,7 +234,7 @@ public class ClusterSet implements Serializable {
List<Cluster> mostPopulated = new ArrayList<>(clusters); List<Cluster> mostPopulated = new ArrayList<>(clusters);
Collections.sort(mostPopulated, new Comparator<Cluster>() { Collections.sort(mostPopulated, new Comparator<Cluster>() {
public int compare(Cluster o1, Cluster o2) { public int compare(Cluster o1, Cluster o2) {
return new Integer(o1.getPoints().size()).compareTo(new Integer(o2.getPoints().size())); return Integer.compare(o2.getPoints().size(), o1.getPoints().size());
} }
}); });
return mostPopulated.subList(0, count); return mostPopulated.subList(0, count);

View File

@ -0,0 +1,26 @@
package org.deeplearning4j.clustering.cluster;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
public class ClusterSetTest {
@Test
public void testGetMostPopulatedClusters() {
ClusterSet clusterSet = new ClusterSet(false);
List<Cluster> clusters = new ArrayList<>();
for (int i = 0; i < 5; i++) {
Cluster cluster = new Cluster();
cluster.setPoints(Point.toPoints(Nd4j.randn(i + 1, 5)));
clusters.add(cluster);
}
clusterSet.setClusters(clusters);
List<Cluster> mostPopulatedClusters = clusterSet.getMostPopulatedClusters(5);
for (int i = 0; i < 5; i++) {
Assert.assertEquals(5 - i, mostPopulatedClusters.get(i).getPoints().size());
}
}
}