/*
 * Copyright (c) 2003, the JUNG Project and the Regents of the University
 * of California
 * All rights reserved.
 *
 * This software is open-source under the BSD license; see either
 * "license.txt" or
 * http://jung.sourceforge.net/license.txt for a description.
 */
/*
 * Created on Aug 9, 2004
 *
 */
package edu.uci.ics.jung.algorithms.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

/**
 * Groups items into a specified number of clusters, based on their proximity in
 * d-dimensional space, using the k-means algorithm. Calls to
 * {@code cluster} will terminate when any of the following conditions is true:
 * <ul>
 * <li>the number of refinement iterations reaches the maximum number of iterations
 * <li>no centroid moves farther than the convergence threshold during an iteration
 * <li>cluster membership does not change from one iteration to the next
 * </ul>
 *
 * @param <T> the type of the objects being clustered
 * @author Joshua O'Madadhain
 */
public class KMeansClusterer<T> {
    protected int max_iterations;
    protected double convergence_threshold;
    protected Random rand;

    /**
     * Creates an instance whose termination conditions are set according
     * to the parameters. Unlike {@link #setMaxIterations(int)} and
     * {@link #setConvergenceThreshold(double)}, this constructor does not
     * validate its arguments.
     *
     * @param max_iterations the maximum number of refinement iterations
     * @param convergence_threshold the centroid movement at or below which
     *        clustering is considered to have converged
     */
    public KMeansClusterer(int max_iterations, double convergence_threshold) {
        this.max_iterations = max_iterations;
        this.convergence_threshold = convergence_threshold;
        this.rand = new Random();
    }

    /**
     * Creates an instance with max iterations of 100 and convergence threshold
     * of 0.001.
     */
    public KMeansClusterer() {
        this(100, 0.001);
    }

    /**
     * Returns the maximum number of iterations.
     */
    public int getMaxIterations() {
        return max_iterations;
    }

    /**
     * Sets the maximum number of iterations.
     *
     * @param max_iterations the new maximum; must be non-negative
     * @throws IllegalArgumentException if {@code max_iterations} is negative
     */
    public void setMaxIterations(int max_iterations) {
        if (max_iterations < 0) {
            throw new IllegalArgumentException("max iterations must be >= 0");
        }
        this.max_iterations = max_iterations;
    }

    /**
     * Returns the convergence threshold.
     */
    public double getConvergenceThreshold() {
        return convergence_threshold;
    }

    /**
     * Sets the convergence threshold.
     *
     * @param convergence_threshold the new threshold; must be positive
     * @throws IllegalArgumentException if {@code convergence_threshold} is not positive
     */
    public void setConvergenceThreshold(double convergence_threshold) {
        if (convergence_threshold <= 0) {
            throw new IllegalArgumentException("convergence threshold " +
                "must be > 0");
        }
        this.convergence_threshold = convergence_threshold;
    }

    /**
     * Returns a {@code Collection} of clusters, where each cluster is
     * represented as a {@code Map} of objects to locations in d-dimensional
     * space.
     *
     * @param object_locations a map of the objects to cluster, to
     *        {@code double} arrays that specify their locations in
     *        d-dimensional space
     * @param num_clusters the number of clusters to create
     * @return the clusters, one map (object to location) per final centroid;
     *         a cluster may be empty if no object is closest to its centroid
     * @throws IllegalArgumentException if {@code object_locations} is null or
     *         empty, or if {@code num_clusters} is less than 2 or greater than
     *         the number of objects
     * @throws NotEnoughClustersException if there are fewer distinct locations
     *         than {@code num_clusters}
     */
    public Collection<Map<T, double[]>> cluster(Map<T, double[]> object_locations, int num_clusters) {
        if (object_locations == null || object_locations.isEmpty()) {
            throw new IllegalArgumentException("'objects' must be non-empty");
        }

        if (num_clusters < 2 || num_clusters > object_locations.size()) {
            throw new IllegalArgumentException("number of clusters " +
                "must be >= 2 and <= number of objects (" +
                object_locations.size() + ")");
        }

        Set<double[]> centroids = chooseInitialCentroids(object_locations, num_clusters);

        // put items in their initial clusters
        Map<double[], Map<T, double[]>> clusterMap = assignToClusters(object_locations, centroids);
        Set<Set<T>> membership = membershipOf(clusterMap);

        // keep reconstituting clusters until either
        // (a) membership is stable, or
        // (b) number of iterations passes max_iterations, or
        // (c) max movement of any centroid is <= convergence_threshold
        int iterations = 0;
        double max_movement = Double.POSITIVE_INFINITY;
        while (iterations++ < max_iterations && max_movement > convergence_threshold) {
            max_movement = 0;
            Set<double[]> new_centroids = new HashSet<double[]>();

            // calculate new mean for each cluster
            for (Map.Entry<double[], Map<T, double[]>> entry : clusterMap.entrySet()) {
                double[] centroid = entry.getKey();
                Map<T, double[]> elements = entry.getValue();
                if (elements.isEmpty()) {
                    // An empty cluster has no mean; keep its previous centroid
                    // rather than handing an empty collection to mean().
                    new_centroids.add(centroid);
                    continue;
                }
                ArrayList<double[]> locations = new ArrayList<double[]>(elements.values());
                double[] mean = DiscreteDistribution.mean(locations);
                max_movement = Math.max(max_movement,
                        Math.sqrt(DiscreteDistribution.squaredError(centroid, mean)));
                new_centroids.add(mean);
            }

            // regenerate cluster membership based on means
            clusterMap = assignToClusters(object_locations, new_centroids);

            // Identical membership would produce identical means next time
            // round, so further iterations cannot change the result.
            Set<Set<T>> new_membership = membershipOf(clusterMap);
            if (new_membership.equals(membership)) {
                break;
            }
            membership = new_membership;
        }
        return clusterMap.values();
    }

    /**
     * Randomly samples objects (with replacement) until {@code num_clusters}
     * distinct locations have been collected or every object has been tried.
     *
     * @throws NotEnoughClustersException if fewer than {@code num_clusters}
     *         distinct locations exist
     */
    private Set<double[]> chooseInitialCentroids(Map<T, double[]> object_locations, int num_clusters) {
        // Same order as keySet().toArray(), so seeded runs sample identically.
        List<T> objects = new ArrayList<T>(object_locations.keySet());
        Set<double[]> centroids = new HashSet<double[]>();
        Set<T> tried = new HashSet<T>();

        while (centroids.size() < num_clusters && tried.size() < objects.size()) {
            T o = objects.get((int) (rand.nextDouble() * objects.size()));
            tried.add(o);
            double[] mean_value = object_locations.get(o);
            // double[] uses identity hashing, so duplicates must be found by content.
            if (!containsEqualArray(centroids, mean_value)) {
                centroids.add(mean_value);
            }
        }

        // Decide based on what was found, not on how many objects were tried:
        // the last object tried may have supplied the final centroid.
        if (centroids.size() < num_clusters) {
            throw new NotEnoughClustersException();
        }
        return centroids;
    }

    /** Returns true if {@code arrays} contains an array element-wise equal to {@code candidate}. */
    private static boolean containsEqualArray(Collection<double[]> arrays, double[] candidate) {
        for (double[] existing : arrays) {
            if (Arrays.equals(candidate, existing)) {
                return true;
            }
        }
        return false;
    }

    /** Returns a snapshot of the member sets of the non-empty clusters in {@code clusterMap}. */
    private Set<Set<T>> membershipOf(Map<double[], Map<T, double[]>> clusterMap) {
        Set<Set<T>> membership = new HashSet<Set<T>>();
        for (Map<T, double[]> cluster : clusterMap.values()) {
            if (!cluster.isEmpty()) {
                membership.add(new HashSet<T>(cluster.keySet()));
            }
        }
        return membership;
    }

    /**
     * Assigns each object to the cluster whose centroid is closest to the
     * object. Ties are resolved in favor of the centroid encountered first
     * when iterating {@code centroids}.
     *
     * @param object_locations a map of objects to locations
     * @param centroids the centroids of the clusters to be formed; must be
     *        non-empty if {@code object_locations} is non-empty
     * @return a map from each centroid to its cluster (a map of the objects
     *         assigned to that centroid to their locations); every centroid
     *         gets an entry, possibly an empty one
     */
    protected Map<double[], Map<T, double[]>> assignToClusters(Map<T, double[]> object_locations, Set<double[]> centroids) {
        Map<double[], Map<T, double[]>> clusterMap = new HashMap<double[], Map<T, double[]>>();
        for (double[] centroid : centroids) {
            clusterMap.put(centroid, new HashMap<T, double[]>());
        }

        for (Map.Entry<T, double[]> object_location : object_locations.entrySet()) {
            T object = object_location.getKey();
            double[] location = object_location.getValue();

            // find the cluster with the closest centroid
            Iterator<double[]> c_iter = centroids.iterator();
            double[] closest = c_iter.next();
            double distance = DiscreteDistribution.squaredError(location, closest);

            while (c_iter.hasNext()) {
                double[] centroid = c_iter.next();
                double dist_cur = DiscreteDistribution.squaredError(location, centroid);
                if (dist_cur < distance) {
                    distance = dist_cur;
                    closest = centroid;
                }
            }
            clusterMap.get(closest).put(object, location);
        }

        return clusterMap;
    }

    /**
     * Sets the seed used by the internal random number generator.
     * Enables consistent outputs.
     *
     * @param random_seed the seed for a new {@link Random} instance
     */
    public void setSeed(int random_seed) {
        this.rand = new Random(random_seed);
    }

    /**
     * An exception that indicates that the specified data points cannot be
     * clustered into the number of clusters requested by the user.
     * This will happen if and only if there are fewer distinct points than
     * requested clusters. (If there are fewer total data points than
     * requested clusters, {@code IllegalArgumentException} will be thrown.)
     *
     * @author Joshua O'Madadhain
     */
    @SuppressWarnings("serial")
    public static class NotEnoughClustersException extends RuntimeException {
        @Override
        public String getMessage() {
            return "Not enough distinct points in the input data set to form " +
                "the requested number of clusters";
        }
    }
}