2 * Copyright (c) 2003, the JUNG Project and the Regents of the University
6 * This software is open-source under the BSD license; see either
8 * http://jung.sourceforge.net/license.txt for a description.
11 * Created on Aug 9, 2004
14 package edu.uci.ics.jung.algorithms.util;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
29 * Groups items into a specified number of clusters, based on their proximity in
30 * d-dimensional space, using the k-means algorithm. Calls to
31 * <code>cluster</code> will terminate when either of the two following
34 * <li/>the number of iterations is > <code>max_iterations</code>
35 * <li/>none of the centroids has moved as much as <code>convergence_threshold</code>
36 * since the previous iteration
39 * @author Joshua O'Madadhain
41 public class KMeansClusterer<T>
43 protected int max_iterations;
44 protected double convergence_threshold;
45 protected Random rand;
48 * Creates an instance whose termination conditions are set according
51 public KMeansClusterer(int max_iterations, double convergence_threshold)
53 this.max_iterations = max_iterations;
54 this.convergence_threshold = convergence_threshold;
55 this.rand = new Random();
59 * Creates an instance with max iterations of 100 and convergence threshold
62 public KMeansClusterer()
68 * Returns the maximum number of iterations.
70 public int getMaxIterations()
72 return max_iterations;
76 * Sets the maximum number of iterations.
78 public void setMaxIterations(int max_iterations)
80 if (max_iterations < 0)
81 throw new IllegalArgumentException("max iterations must be >= 0");
83 this.max_iterations = max_iterations;
87 * Returns the convergence threshold.
89 public double getConvergenceThreshold()
91 return convergence_threshold;
95 * Sets the convergence threshold.
96 * @param convergence_threshold
98 public void setConvergenceThreshold(double convergence_threshold)
100 if (convergence_threshold <= 0)
101 throw new IllegalArgumentException("convergence threshold " +
104 this.convergence_threshold = convergence_threshold;
108 * Returns a <code>Collection</code> of clusters, where each cluster is
109 * represented as a <code>Map</code> of <code>Objects</code> to locations
110 * in d-dimensional space.
111 * @param object_locations a map of the Objects to cluster, to
112 * <code>double</code> arrays that specify their locations in d-dimensional space.
113 * @param num_clusters the number of clusters to create
114 * @throws NotEnoughClustersException
116 @SuppressWarnings("unchecked")
117 public Collection<Map<T, double[]>> cluster(Map<T, double[]> object_locations, int num_clusters)
119 if (object_locations == null || object_locations.isEmpty())
120 throw new IllegalArgumentException("'objects' must be non-empty");
122 if (num_clusters < 2 || num_clusters > object_locations.size())
123 throw new IllegalArgumentException("number of clusters " +
124 "must be >= 2 and <= number of objects (" +
125 object_locations.size() + ")");
128 Set<double[]> centroids = new HashSet<double[]>();
130 Object[] obj_array = object_locations.keySet().toArray();
131 Set<T> tried = new HashSet<T>();
133 // create the specified number of clusters
134 while (centroids.size() < num_clusters && tried.size() < object_locations.size())
136 T o = (T)obj_array[(int)(rand.nextDouble() * obj_array.length)];
138 double[] mean_value = object_locations.get(o);
139 boolean duplicate = false;
140 for (double[] cur : centroids)
142 if (Arrays.equals(mean_value, cur))
146 centroids.add(mean_value);
149 if (tried.size() >= object_locations.size())
150 throw new NotEnoughClustersException();
152 // put items in their initial clusters
153 Map<double[], Map<T, double[]>> clusterMap = assignToClusters(object_locations, centroids);
155 // keep reconstituting clusters until either
156 // (a) membership is stable, or
157 // (b) number of iterations passes max_iterations, or
158 // (c) max movement of any centroid is <= convergence_threshold
160 double max_movement = Double.POSITIVE_INFINITY;
161 while (iterations++ < max_iterations && max_movement > convergence_threshold)
164 Set<double[]> new_centroids = new HashSet<double[]>();
165 // calculate new mean for each cluster
166 for (Map.Entry<double[], Map<T, double[]>> entry : clusterMap.entrySet())
168 double[] centroid = entry.getKey();
169 Map<T, double[]> elements = entry.getValue();
170 ArrayList<double[]> locations = new ArrayList<double[]>(elements.values());
172 double[] mean = DiscreteDistribution.mean(locations);
173 max_movement = Math.max(max_movement,
174 Math.sqrt(DiscreteDistribution.squaredError(centroid, mean)));
175 new_centroids.add(mean);
178 // TODO: check membership of clusters: have they changed?
180 // regenerate cluster membership based on means
181 clusterMap = assignToClusters(object_locations, new_centroids);
183 return clusterMap.values();
187 * Assigns each object to the cluster whose centroid is closest to the
189 * @param object_locations a map of objects to locations
190 * @param centroids the centroids of the clusters to be formed
191 * @return a map of objects to assigned clusters
193 protected Map<double[], Map<T, double[]>> assignToClusters(Map<T, double[]> object_locations, Set<double[]> centroids)
195 Map<double[], Map<T, double[]>> clusterMap = new HashMap<double[], Map<T, double[]>>();
196 for (double[] centroid : centroids)
197 clusterMap.put(centroid, new HashMap<T, double[]>());
199 for (Map.Entry<T, double[]> object_location : object_locations.entrySet())
201 T object = object_location.getKey();
202 double[] location = object_location.getValue();
204 // find the cluster with the closest centroid
205 Iterator<double[]> c_iter = centroids.iterator();
206 double[] closest = c_iter.next();
207 double distance = DiscreteDistribution.squaredError(location, closest);
209 while (c_iter.hasNext())
211 double[] centroid = c_iter.next();
212 double dist_cur = DiscreteDistribution.squaredError(location, centroid);
213 if (dist_cur < distance)
219 clusterMap.get(closest).put(object, location);
226 * Sets the seed used by the internal random number generator.
227 * Enables consistent outputs.
229 public void setSeed(int random_seed)
231 this.rand = new Random(random_seed);
235 * An exception that indicates that the specified data points cannot be
236 * clustered into the number of clusters requested by the user.
237 * This will happen if and only if there are fewer distinct points than
238 * requested clusters. (If there are fewer total data points than
239 * requested clusters, <code>IllegalArgumentException</code> will be thrown.)
241 * @author Joshua O'Madadhain
243 @SuppressWarnings("serial")
244 public static class NotEnoughClustersException extends RuntimeException
247 public String getMessage()
249 return "Not enough distinct points in the input data set to form " +
250 "the requested number of clusters";