diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterSet.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterSet.java index 1275b8875..f39516a7b 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterSet.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterSet.java @@ -234,7 +234,7 @@ public class ClusterSet implements Serializable { List mostPopulated = new ArrayList<>(clusters); Collections.sort(mostPopulated, new Comparator() { public int compare(Cluster o1, Cluster o2) { - return new Integer(o1.getPoints().size()).compareTo(new Integer(o2.getPoints().size())); + return Integer.compare(o2.getPoints().size(), o1.getPoints().size()); } }); return mostPopulated.subList(0, count); diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/cluster/ClusterSetTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/cluster/ClusterSetTest.java new file mode 100644 index 000000000..381502b93 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/cluster/ClusterSetTest.java @@ -0,0 +1,26 @@ +package org.deeplearning4j.clustering.cluster; + +import org.junit.Assert; +import org.junit.Test; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.ArrayList; +import java.util.List; + +public class ClusterSetTest { + @Test + public void testGetMostPopulatedClusters() { + ClusterSet clusterSet = new ClusterSet(false); + List clusters = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + Cluster cluster = new Cluster(); + cluster.setPoints(Point.toPoints(Nd4j.randn(i + 1, 5))); + clusters.add(cluster); + } + clusterSet.setClusters(clusters); + List mostPopulatedClusters = clusterSet.getMostPopulatedClusters(5); + for (int i = 0; i < 5; i++) { + Assert.assertEquals(5 - i, mostPopulatedClusters.get(i).getPoints().size()); + } + } +}