Merge pull request #8457 from chentao106/Fix#8383
Fix#8383,getMostPopulatedClusters should sort by cluster's points siz…master
commit
9592072cef
|
@ -234,7 +234,7 @@ public class ClusterSet implements Serializable {
|
|||
List<Cluster> mostPopulated = new ArrayList<>(clusters);
|
||||
Collections.sort(mostPopulated, new Comparator<Cluster>() {
|
||||
public int compare(Cluster o1, Cluster o2) {
|
||||
return new Integer(o1.getPoints().size()).compareTo(new Integer(o2.getPoints().size()));
|
||||
return Integer.compare(o2.getPoints().size(), o1.getPoints().size());
|
||||
}
|
||||
});
|
||||
return mostPopulated.subList(0, count);
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
package org.deeplearning4j.clustering.cluster;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class ClusterSetTest {
|
||||
@Test
|
||||
public void testGetMostPopulatedClusters() {
|
||||
ClusterSet clusterSet = new ClusterSet(false);
|
||||
List<Cluster> clusters = new ArrayList<>();
|
||||
for (int i = 0; i < 5; i++) {
|
||||
Cluster cluster = new Cluster();
|
||||
cluster.setPoints(Point.toPoints(Nd4j.randn(i + 1, 5)));
|
||||
clusters.add(cluster);
|
||||
}
|
||||
clusterSet.setClusters(clusters);
|
||||
List<Cluster> mostPopulatedClusters = clusterSet.getMostPopulatedClusters(5);
|
||||
for (int i = 0; i < 5; i++) {
|
||||
Assert.assertEquals(5 - i, mostPopulatedClusters.get(i).getPoints().size());
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue