Fix checkstyle warnings in netconf-cli
[controller.git] / third-party / net.sf.jung2 / src / main / java / edu / uci / ics / jung / algorithms / util / KMeansClusterer.java
1 /*
2  * Copyright (c) 2003, the JUNG Project and the Regents of the University
3  * of California
4  * All rights reserved.
5  *
6  * This software is open-source under the BSD license; see either
7  * "license.txt" or
8  * http://jung.sourceforge.net/license.txt for a description.
9  */
10 /*
11  * Created on Aug 9, 2004
12  *
13  */
14 package edu.uci.ics.jung.algorithms.util;
15
16 import java.util.ArrayList;
17 import java.util.Arrays;
18 import java.util.Collection;
19 import java.util.HashMap;
20 import java.util.HashSet;
21 import java.util.Iterator;
22 import java.util.Map;
23 import java.util.Random;
24 import java.util.Set;
25
26
27
28 /**
29  * Groups items into a specified number of clusters, based on their proximity in
30  * d-dimensional space, using the k-means algorithm. Calls to
31  * <code>cluster</code> will terminate when either of the two following
32  * conditions is true:
33  * <ul>
34  * <li/>the number of iterations is &gt; <code>max_iterations</code> 
35  * <li/>none of the centroids has moved as much as <code>convergence_threshold</code>
36  * since the previous iteration
37  * </ul>
38  * 
39  * @author Joshua O'Madadhain
40  */
41 public class KMeansClusterer<T>
42 {
43     protected int max_iterations;
44     protected double convergence_threshold;
45     protected Random rand;
46
47     /**
48      * Creates an instance whose termination conditions are set according
49      * to the parameters.  
50      */
51     public KMeansClusterer(int max_iterations, double convergence_threshold)
52     {
53         this.max_iterations = max_iterations;
54         this.convergence_threshold = convergence_threshold;
55         this.rand = new Random();
56     }
57
58     /**
59      * Creates an instance with max iterations of 100 and convergence threshold
60      * of 0.001.
61      */
62     public KMeansClusterer()
63     {
64         this(100, 0.001);
65     }
66
67     /**
68      * Returns the maximum number of iterations.
69      */
70     public int getMaxIterations()
71     {
72         return max_iterations;
73     }
74
75     /**
76      * Sets the maximum number of iterations.
77      */
78     public void setMaxIterations(int max_iterations)
79     {
80         if (max_iterations < 0)
81             throw new IllegalArgumentException("max iterations must be >= 0");
82
83         this.max_iterations = max_iterations;
84     }
85
86     /**
87      * Returns the convergence threshold.
88      */
89     public double getConvergenceThreshold()
90     {
91         return convergence_threshold;
92     }
93
94     /**
95      * Sets the convergence threshold.
96      * @param convergence_threshold
97      */
98     public void setConvergenceThreshold(double convergence_threshold)
99     {
100         if (convergence_threshold <= 0)
101             throw new IllegalArgumentException("convergence threshold " +
102                 "must be > 0");
103
104         this.convergence_threshold = convergence_threshold;
105     }
106
107     /**
108      * Returns a <code>Collection</code> of clusters, where each cluster is
109      * represented as a <code>Map</code> of <code>Objects</code> to locations
110      * in d-dimensional space.
111      * @param object_locations  a map of the Objects to cluster, to
112      * <code>double</code> arrays that specify their locations in d-dimensional space.
113      * @param num_clusters  the number of clusters to create
114      * @throws NotEnoughClustersException
115      */
116     @SuppressWarnings("unchecked")
117     public Collection<Map<T, double[]>> cluster(Map<T, double[]> object_locations, int num_clusters)
118     {
119         if (object_locations == null || object_locations.isEmpty())
120             throw new IllegalArgumentException("'objects' must be non-empty");
121
122         if (num_clusters < 2 || num_clusters > object_locations.size())
123             throw new IllegalArgumentException("number of clusters " +
124                 "must be >= 2 and <= number of objects (" +
125                 object_locations.size() + ")");
126
127
128         Set<double[]> centroids = new HashSet<double[]>();
129
130         Object[] obj_array = object_locations.keySet().toArray();
131         Set<T> tried = new HashSet<T>();
132
133         // create the specified number of clusters
134         while (centroids.size() < num_clusters && tried.size() < object_locations.size())
135         {
136             T o = (T)obj_array[(int)(rand.nextDouble() * obj_array.length)];
137             tried.add(o);
138             double[] mean_value = object_locations.get(o);
139             boolean duplicate = false;
140             for (double[] cur : centroids)
141             {
142                 if (Arrays.equals(mean_value, cur))
143                     duplicate = true;
144             }
145             if (!duplicate)
146                 centroids.add(mean_value);
147         }
148
149         if (tried.size() >= object_locations.size())
150             throw new NotEnoughClustersException();
151
152         // put items in their initial clusters
153         Map<double[], Map<T, double[]>> clusterMap = assignToClusters(object_locations, centroids);
154
155         // keep reconstituting clusters until either
156         // (a) membership is stable, or
157         // (b) number of iterations passes max_iterations, or
158         // (c) max movement of any centroid is <= convergence_threshold
159         int iterations = 0;
160         double max_movement = Double.POSITIVE_INFINITY;
161         while (iterations++ < max_iterations && max_movement > convergence_threshold)
162         {
163             max_movement = 0;
164             Set<double[]> new_centroids = new HashSet<double[]>();
165             // calculate new mean for each cluster
166             for (Map.Entry<double[], Map<T, double[]>> entry : clusterMap.entrySet())
167             {
168                 double[] centroid = entry.getKey();
169                 Map<T, double[]> elements = entry.getValue();
170                 ArrayList<double[]> locations = new ArrayList<double[]>(elements.values());
171
172                 double[] mean = DiscreteDistribution.mean(locations);
173                 max_movement = Math.max(max_movement,
174                     Math.sqrt(DiscreteDistribution.squaredError(centroid, mean)));
175                 new_centroids.add(mean);
176             }
177
178             // TODO: check membership of clusters: have they changed?
179
180             // regenerate cluster membership based on means
181             clusterMap = assignToClusters(object_locations, new_centroids);
182         }
183         return clusterMap.values();
184     }
185
186     /**
187      * Assigns each object to the cluster whose centroid is closest to the
188      * object.
189      * @param object_locations  a map of objects to locations
190      * @param centroids         the centroids of the clusters to be formed
191      * @return a map of objects to assigned clusters
192      */
193     protected Map<double[], Map<T, double[]>> assignToClusters(Map<T, double[]> object_locations, Set<double[]> centroids)
194     {
195         Map<double[], Map<T, double[]>> clusterMap = new HashMap<double[], Map<T, double[]>>();
196         for (double[] centroid : centroids)
197             clusterMap.put(centroid, new HashMap<T, double[]>());
198
199         for (Map.Entry<T, double[]> object_location : object_locations.entrySet())
200         {
201             T object = object_location.getKey();
202             double[] location = object_location.getValue();
203
204             // find the cluster with the closest centroid
205             Iterator<double[]> c_iter = centroids.iterator();
206             double[] closest = c_iter.next();
207             double distance = DiscreteDistribution.squaredError(location, closest);
208
209             while (c_iter.hasNext())
210             {
211                 double[] centroid = c_iter.next();
212                 double dist_cur = DiscreteDistribution.squaredError(location, centroid);
213                 if (dist_cur < distance)
214                 {
215                     distance = dist_cur;
216                     closest = centroid;
217                 }
218             }
219             clusterMap.get(closest).put(object, location);
220         }
221
222         return clusterMap;
223     }
224
225     /**
226      * Sets the seed used by the internal random number generator.
227      * Enables consistent outputs.
228      */
229     public void setSeed(int random_seed)
230     {
231         this.rand = new Random(random_seed);
232     }
233
234     /**
235      * An exception that indicates that the specified data points cannot be
236      * clustered into the number of clusters requested by the user.
237      * This will happen if and only if there are fewer distinct points than
238      * requested clusters.  (If there are fewer total data points than
239      * requested clusters, <code>IllegalArgumentException</code> will be thrown.)
240      *
241      * @author Joshua O'Madadhain
242      */
243     @SuppressWarnings("serial")
244     public static class NotEnoughClustersException extends RuntimeException
245     {
246         @Override
247         public String getMessage()
248         {
249             return "Not enough distinct points in the input data set to form " +
250                     "the requested number of clusters";
251         }
252     }
253 }