From c01496ce5d4124632a4628eaa569c8d8364f958e Mon Sep 17 00:00:00 2001 From: CT Date: Wed, 27 Nov 2019 22:35:55 +0800 Subject: [PATCH] =?UTF-8?q?Fix#8383=EF=BC=8CgetMostPopulatedClusters=20sho?= =?UTF-8?q?uld=20sort=20by=20cluster's=20points=20size=20desc.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../clustering/cluster/ClusterSet.java | 2 +- .../clustering/cluster/ClusterSetTest.java | 26 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/cluster/ClusterSetTest.java diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterSet.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterSet.java index 1275b8875..f39516a7b 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterSet.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterSet.java @@ -234,7 +234,7 @@ public class ClusterSet implements Serializable { List mostPopulated = new ArrayList<>(clusters); Collections.sort(mostPopulated, new Comparator() { public int compare(Cluster o1, Cluster o2) { - return new Integer(o1.getPoints().size()).compareTo(new Integer(o2.getPoints().size())); + return Integer.compare(o2.getPoints().size(), o1.getPoints().size()); } }); return mostPopulated.subList(0, count); diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/cluster/ClusterSetTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/cluster/ClusterSetTest.java new file mode 100644 index 000000000..381502b93 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/cluster/ClusterSetTest.java @@ -0,0 +1,26 @@ +package org.deeplearning4j.clustering.cluster; + +import org.junit.Assert; +import org.junit.Test; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.ArrayList; +import java.util.List; + +public class ClusterSetTest { + @Test + public void testGetMostPopulatedClusters() { + ClusterSet clusterSet = new ClusterSet(false); + List clusters = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + Cluster cluster = new Cluster(); + cluster.setPoints(Point.toPoints(Nd4j.randn(i + 1, 5))); + clusters.add(cluster); + } + clusterSet.setClusters(clusters); + List mostPopulatedClusters = clusterSet.getMostPopulatedClusters(5); + for (int i = 0; i < 5; i++) { + Assert.assertEquals(5 - i, mostPopulatedClusters.get(i).getPoints().size()); + } + } +}