Initial opendaylight infrastructure commit!!
[controller.git] / third-party / net.sf.jung2 / src / main / java / edu / uci / ics / jung / algorithms / util / KMeansClusterer.java
diff --git a/third-party/net.sf.jung2/src/main/java/edu/uci/ics/jung/algorithms/util/KMeansClusterer.java b/third-party/net.sf.jung2/src/main/java/edu/uci/ics/jung/algorithms/util/KMeansClusterer.java
new file mode 100644 (file)
index 0000000..dce550f
--- /dev/null
@@ -0,0 +1,253 @@
+/*
+ * Copyright (c) 2003, the JUNG Project and the Regents of the University
+ * of California
+ * All rights reserved.
+ *
+ * This software is open-source under the BSD license; see either
+ * "license.txt" or
+ * http://jung.sourceforge.net/license.txt for a description.
+ */
+/*
+ * Created on Aug 9, 2004
+ *
+ */
+package edu.uci.ics.jung.algorithms.util;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+
+
+
+/**
+ * Groups items into a specified number of clusters, based on their proximity in
+ * d-dimensional space, using the k-means algorithm. Calls to
+ * <code>cluster</code> will terminate when either of the two following
+ * conditions is true:
+ * <ul>
+ * <li/>the number of iterations is &gt; <code>max_iterations</code> 
+ * <li/>none of the centroids has moved as much as <code>convergence_threshold</code>
+ * since the previous iteration
+ * </ul>
+ * 
+ * @author Joshua O'Madadhain
+ */
+public class KMeansClusterer<T>
+{
+    protected int max_iterations;
+    protected double convergence_threshold;
+    protected Random rand;
+
+    /**
+     * Creates an instance whose termination conditions are set according
+     * to the parameters.  
+     */
+    public KMeansClusterer(int max_iterations, double convergence_threshold)
+    {
+        this.max_iterations = max_iterations;
+        this.convergence_threshold = convergence_threshold;
+        this.rand = new Random();
+    }
+
+    /**
+     * Creates an instance with max iterations of 100 and convergence threshold
+     * of 0.001.
+     */
+    public KMeansClusterer()
+    {
+        this(100, 0.001);
+    }
+
+    /**
+     * Returns the maximum number of iterations.
+     */
+    public int getMaxIterations()
+    {
+        return max_iterations;
+    }
+
+    /**
+     * Sets the maximum number of iterations.
+     */
+    public void setMaxIterations(int max_iterations)
+    {
+        if (max_iterations < 0)
+            throw new IllegalArgumentException("max iterations must be >= 0");
+
+        this.max_iterations = max_iterations;
+    }
+
+    /**
+     * Returns the convergence threshold.
+     */
+    public double getConvergenceThreshold()
+    {
+        return convergence_threshold;
+    }
+
+    /**
+     * Sets the convergence threshold.
+     * @param convergence_threshold
+     */
+    public void setConvergenceThreshold(double convergence_threshold)
+    {
+        if (convergence_threshold <= 0)
+            throw new IllegalArgumentException("convergence threshold " +
+                "must be > 0");
+
+        this.convergence_threshold = convergence_threshold;
+    }
+
+    /**
+     * Returns a <code>Collection</code> of clusters, where each cluster is
+     * represented as a <code>Map</code> of <code>Objects</code> to locations
+     * in d-dimensional space.
+     * @param object_locations  a map of the Objects to cluster, to
+     * <code>double</code> arrays that specify their locations in d-dimensional space.
+     * @param num_clusters  the number of clusters to create
+     * @throws NotEnoughClustersException
+     */
+    @SuppressWarnings("unchecked")
+    public Collection<Map<T, double[]>> cluster(Map<T, double[]> object_locations, int num_clusters)
+    {
+        if (object_locations == null || object_locations.isEmpty())
+            throw new IllegalArgumentException("'objects' must be non-empty");
+
+        if (num_clusters < 2 || num_clusters > object_locations.size())
+            throw new IllegalArgumentException("number of clusters " +
+                "must be >= 2 and <= number of objects (" +
+                object_locations.size() + ")");
+
+
+        Set<double[]> centroids = new HashSet<double[]>();
+
+        Object[] obj_array = object_locations.keySet().toArray();
+        Set<T> tried = new HashSet<T>();
+
+        // create the specified number of clusters
+        while (centroids.size() < num_clusters && tried.size() < object_locations.size())
+        {
+            T o = (T)obj_array[(int)(rand.nextDouble() * obj_array.length)];
+            tried.add(o);
+            double[] mean_value = object_locations.get(o);
+            boolean duplicate = false;
+            for (double[] cur : centroids)
+            {
+                if (Arrays.equals(mean_value, cur))
+                    duplicate = true;
+            }
+            if (!duplicate)
+                centroids.add(mean_value);
+        }
+
+        if (tried.size() >= object_locations.size())
+            throw new NotEnoughClustersException();
+
+        // put items in their initial clusters
+        Map<double[], Map<T, double[]>> clusterMap = assignToClusters(object_locations, centroids);
+
+        // keep reconstituting clusters until either
+        // (a) membership is stable, or
+        // (b) number of iterations passes max_iterations, or
+        // (c) max movement of any centroid is <= convergence_threshold
+        int iterations = 0;
+        double max_movement = Double.POSITIVE_INFINITY;
+        while (iterations++ < max_iterations && max_movement > convergence_threshold)
+        {
+            max_movement = 0;
+            Set<double[]> new_centroids = new HashSet<double[]>();
+            // calculate new mean for each cluster
+            for (Map.Entry<double[], Map<T, double[]>> entry : clusterMap.entrySet())
+            {
+                double[] centroid = entry.getKey();
+                Map<T, double[]> elements = entry.getValue();
+                ArrayList<double[]> locations = new ArrayList<double[]>(elements.values());
+
+                double[] mean = DiscreteDistribution.mean(locations);
+                max_movement = Math.max(max_movement,
+                    Math.sqrt(DiscreteDistribution.squaredError(centroid, mean)));
+                new_centroids.add(mean);
+            }
+
+            // TODO: check membership of clusters: have they changed?
+
+            // regenerate cluster membership based on means
+            clusterMap = assignToClusters(object_locations, new_centroids);
+        }
+        return clusterMap.values();
+    }
+
+    /**
+     * Assigns each object to the cluster whose centroid is closest to the
+     * object.
+     * @param object_locations  a map of objects to locations
+     * @param centroids         the centroids of the clusters to be formed
+     * @return a map of objects to assigned clusters
+     */
+    protected Map<double[], Map<T, double[]>> assignToClusters(Map<T, double[]> object_locations, Set<double[]> centroids)
+    {
+        Map<double[], Map<T, double[]>> clusterMap = new HashMap<double[], Map<T, double[]>>();
+        for (double[] centroid : centroids)
+            clusterMap.put(centroid, new HashMap<T, double[]>());
+
+        for (Map.Entry<T, double[]> object_location : object_locations.entrySet())
+        {
+            T object = object_location.getKey();
+            double[] location = object_location.getValue();
+
+            // find the cluster with the closest centroid
+            Iterator<double[]> c_iter = centroids.iterator();
+            double[] closest = c_iter.next();
+            double distance = DiscreteDistribution.squaredError(location, closest);
+
+            while (c_iter.hasNext())
+            {
+                double[] centroid = c_iter.next();
+                double dist_cur = DiscreteDistribution.squaredError(location, centroid);
+                if (dist_cur < distance)
+                {
+                    distance = dist_cur;
+                    closest = centroid;
+                }
+            }
+            clusterMap.get(closest).put(object, location);
+        }
+
+        return clusterMap;
+    }
+
+    /**
+     * Sets the seed used by the internal random number generator.
+     * Enables consistent outputs.
+     */
+    public void setSeed(int random_seed)
+    {
+        this.rand = new Random(random_seed);
+    }
+
+    /**
+     * An exception that indicates that the specified data points cannot be
+     * clustered into the number of clusters requested by the user.
+     * This will happen if and only if there are fewer distinct points than
+     * requested clusters.  (If there are fewer total data points than
+     * requested clusters, <code>IllegalArgumentException</code> will be thrown.)
+     *
+     * @author Joshua O'Madadhain
+     */
+    @SuppressWarnings("serial")
+    public static class NotEnoughClustersException extends RuntimeException
+    {
+        @Override
+        public String getMessage()
+        {
+            return "Not enough distinct points in the input data set to form " +
+                    "the requested number of clusters";
+        }
+    }
+}