These first few posts will focus on K-means clustering, beginning with a brief introduction to the technique and a simplified implementation in one dimension to demonstrate the concept. In the next post, the concept will be extended to N dimensions. These posts and, for that matter, probably most future posts, will focus not only on the technique in question, but also on the code–Python, in this case. In other words, this post is at least as much about Python–or, perhaps, programming in general–as it is about K-means clustering. To that end, we will go through the accompanying code, line by line, to understand not only what the code is doing, but how and why.
What is K-means clustering?
Without delving into too much detail (there are, after all, numerous resources out there that discuss K-means clustering), K-means clustering is an unsupervised machine learning technique whose purpose is to segment a data set into K clusters. In other words, given a bunch of data, K-means clustering seeks to divide it into distinct groups–usually without prior knowledge of where to start and without feedback on whether the resulting grouping is correct. The K-means clustering algorithm achieves this via several major steps:
- Initialize K centroids, one for each cluster.
- Assign each point in the data set to its nearest centroid.
- After each point has been assigned to a cluster (based on its proximity to the cluster centroids), recalculate the centroid of each cluster.
- Repeat steps 2-3 until the centroids no longer change, or until a certain number of iterations is reached.
K-means clustering isn’t usually used for one-dimensional data, but the one-dimensional case makes for a relatively simple example that demonstrates how the algorithm works. As the title suggests, the aim of this post is to visualize K-means clustering in one dimension with Python, like so:
We’ll take a look at the code used to create the animations in this video in the following section.
The code
The Python script used to create the animations in the video above can be found on GitHub. Some familiarity with Python, or at least with programming, is assumed, but most of the content will be explained in detail below–this particular script strives for clarity over brevity. The script should work in both Python 2.x and Python 3.x, and requires you to have the numpy and matplotlib packages installed.
OK, let’s get started. Open up your favorite IDE or text editor (I prefer vim) and create a file named `kmeans1d_demo.py`, or download the script from the GitHub link above and follow along.
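First, we import the packages we need and define a few parameters. The listing below is a representative sketch–the parameter values here are just examples, and the line numbers referenced in the discussion refer to the full script on GitHub:

```python
import colorsys

import numpy as np
import matplotlib.pyplot as plt

numClusters = 4            # number of clusters we'll actually generate
ptsPerCluster = 40         # number of points in each generated cluster
xCenterBounds = (-10, 10)  # interval containing the cluster centers
K = 4                      # number of clusters the algorithm looks for
```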
`numClusters` is the number of clusters we’ll actually generate, each of which will contain `ptsPerCluster` points. We’ll place the centers of these clusters within the (min, max) values given by `xCenterBounds` (“x” is our one dimension in this one-dimensional example). `K` is the number of centroids, i.e., clusters, we’d like the algorithm to look for–it’s the “K” in “K-means”. Note that the number of clusters the algorithm searches for is independent of the number of clusters we actually generate. In fact, this is an important point–in this example, we are explicitly generating clusters of data before letting our K-means algorithm have a go at it. In the real world, however, we won’t usually be generating the data. We’ll be collecting or receiving it, we won’t necessarily know how many clusters exist, if any, and it’ll be up to us and/or our K-means algorithm to determine how many clusters to divide the data into, i.e., what value of K to choose. There are methods for determining the optimal value of K, but we’ll get into that later.
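Next, we generate the cluster centers and the points themselves. Here’s a sketch along the lines of the script (the value of `stDev` is illustrative):

```python
# Place the cluster centers randomly within xCenterBounds by mapping
# uniform random values from [0.0, 1.0) onto that interval.
centers = np.random.random_sample((numClusters,))
centers = centers * (xCenterBounds[1] - xCenterBounds[0]) + xCenterBounds[0]

# Generate normally distributed points around each center.
stDev = 0.5
points = np.zeros(numClusters * ptsPerCluster)
for i in range(numClusters):
    points[i * ptsPerCluster:(i + 1) * ptsPerCluster] = (
        stDev * np.random.randn(ptsPerCluster) + centers[i])
```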
Line 12 utilizes the `random_sample` function of numpy’s `random` module to generate an array of random floats in the interval [0.0, 1.0), sampled from a continuous uniform distribution (as opposed to, say, a normal distribution). Line 13 maps these values from the range [0.0, 1.0) to the range given by `xCenterBounds`. These values constitute the centers of the clusters of points we’ll generate, which is achieved by lines 19-22 (each iteration of the loop generates a cluster of points and stores them in `points`). To do this, we define a standard deviation `stDev` and utilize the numpy function `randn`, which draws from a normal distribution. Increasing the value of `stDev` will cause the points to spread out, whereas decreasing its value results in more tightly packed clusters. Play around with this to see how it affects the algorithm’s ability to differentiate clusters from one another.
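Here’s a sketch of the centroid initialization and the assignment step described below (using `np.random.choice` is one compact way to pick K unique indices; the script may do it differently):

```python
# Initialize the K centroids by choosing K unique points from the data set.
indices = np.random.choice(points.shape[0], K, replace=False)
centroids = points[indices]

# classifications[i] holds the index of the centroid nearest to points[i].
classifications = np.zeros(points.shape[0], dtype=int)

def assignPointsToCentroids():
    # Assign each point to the centroid it's closest to.
    for i in range(points.shape[0]):
        smallestDistance = float('inf')
        for k in range(K):
            distance = abs(points[i] - centroids[k])
            if distance < smallestDistance:
                smallestDistance = distance
                classifications[i] = k

assignPointsToCentroids()
```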
Several methods exist for choosing the initial locations for our K centroids. This is important, since our choice of initial locations affects the outcome. There are many ways in which we can cluster any given set of data, and the K-means algorithm does not guarantee that we will arrive at the best solution (the “global optimum”), only that we will arrive at a solution (a “local optimum”). Regardless, as a rule of thumb, a good way to initialize the centroid locations is to randomly choose K unique points from our data set and use those as the initial centroid locations. This is what lines 25-31 accomplish by choosing K unique indices from the `points` array and storing the values at those indices in the `centroids` array. During each iteration of the algorithm, we’ll update `centroids` with the new centroid locations.

Finally, we assign each point in `points` to the cluster whose centroid the point is closest to, with the function `assignPointsToCentroids()`. In other words, we’ll iterate through `points` and, for each point, determine which of the K centroids in `centroids` that point is closest to, then store the index corresponding to that centroid/cluster in the array `classifications` (since we’re “classifying” the point as belonging to one of the K clusters). Each element in `classifications` corresponds to a point in `points`. So, if `classifications[3] == 0`, that would signify that the fourth point, i.e., `points[3]`, is closest to centroid 0, whose location is given by `centroids[0]`. Since we’re only dealing with one dimension in this example, the distance between a given point and a given centroid can be obtained with the built-in absolute value function `abs()`. After defining our function `assignPointsToCentroids()`, we run it once to group the points into clusters based on the initial centroid locations currently in `centroids`.
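A sketch of the centroid-update function:

```python
def recalcCentroids():
    # Move each centroid to the mean of the points assigned to its cluster.
    for k in range(K):
        if sum(classifications == k) > 0:
            centroids[k] = (sum(points[classifications == k]) /
                            sum(classifications == k))
```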
Lines 51-54 define a function, `recalcCentroids()`, that carries out the other major part of the algorithm: recalculating each cluster’s centroid location. For each cluster, we first check whether or not any points are actually assigned to that cluster with `if sum(classifications == k) > 0`. The comparison `classifications == k` returns a boolean array that is True at the indices where `classifications` equals `k` and False where it doesn’t; since True counts as 1 when summed, a sum of at least 1 means at least one point is assigned to that cluster. The next statement on line 54, which actually computes the centroid, involves dividing by the number of points in the cluster. Ensuring there’s at least one point in the cluster preempts division by zero.
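The plotting setup described below might look something like this (the marker styles and the saturation/value arguments to `hsv_to_rgb()` are illustrative):

```python
# One hue per cluster; drop the last value, since hues 0 and 1 are both red.
hues = np.linspace(0, 1, K + 1)[:-1]

fig, ax = plt.subplots()

# Create an empty line object for each cluster and each centroid so we can
# update their data (and thus the plot) at every iteration.
clusterPointsList = []
centroidPointsList = []
for k in range(K):
    clusterColor = colorsys.hsv_to_rgb(hues[k], 0.8, 0.8)
    clusterLine, = ax.plot([], [], ls='None', marker='x', color=clusterColor)
    centroidLine, = ax.plot([], [], ls='None', marker='o', color=clusterColor)
    clusterPointsList.append(clusterLine)
    centroidPointsList.append(centroidLine)

# Blank text annotation in the lower left corner for the iteration count.
iterText = ax.annotate('', xy=(0.01, 0.01), xycoords='axes fraction')
```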
Our goal in this exercise is to visualize the algorithm, which means we’ll want a different color to represent each cluster. A nifty way to achieve this for any arbitrary number of clusters is to take advantage of the HSV color space, which describes colors by hue, saturation, and value (brightness). These parameters can be represented by a cone or cylinder. The hue changes as we travel around the cylinder: 0° represents red, 120° represents green, 240° represents blue, and 360° marks a return to red. In the Python `colorsys` module, the range 0°-360° is represented by a float from 0 to 1. On line 58, to get K distinct colors, we divide the hue range [0, 1] into K+1 equally spaced numbers and use all except the last one. This is because the first number will be 0 and the last will be 1, which, in the HSV color space, both represent red.
Line 60 creates a figure with an axis using the pyplot function `subplots()`, and returns handles to the figure and axis objects, which we can use to get or set figure and axis properties later. To visualize the algorithm, we want to plot each cluster, as well as the centroid for that cluster, in a unique color. We also need to be able to update each cluster after each iteration of the algorithm. To do this, we’ll utilize a for-loop to initialize an empty matplotlib line object for each cluster and for each centroid (lines 63-71) by calling the `plot()` function of our axis. Basically, for each of the K clusters, we add a “line object” to the axis `ax` using `plot()`. The properties of our line object are determined by the arguments we pass to `plot()`.

The first two arguments to `plot()` are the x and y coordinates of the points we’d like to plot. By passing in empty arrays with `ax.plot([], [], ...)`, we’re initializing a line that doesn’t have any x or y data. The keyword argument `ls` sets the linestyle; `ls='None'` tells matplotlib that we want to plot the points without a line connecting them. The keyword argument `marker` sets the marker style. The keyword argument `color` takes any matplotlib color specification. In this case, we’re feeding it an RGB tuple crafted from our HSV colors using the `colorsys` function `hsv_to_rgb()` (note that matplotlib also has a colors module, `matplotlib.colors`, which can convert HSV to RGB, but I used `colorsys` to show that Python has a built-in package to handle this functionality). `plot()` returns a handle to the line object it just created, so we can modify any of the aforementioned properties later. Of course, we need to be able to access these line object handles later each time we want to update our plot. We do this by adding the cluster and centroid line object handles to the lists `clusterPointsList` and `centroidPointsList`, respectively. This maintains a reference to the objects from our for-loop after the loop has completed.

Line 72 adds a blank text annotation to the lower left corner of the plot. We’ll use this text to display the number of iterations the algorithm has performed, and we’ll update it at every iteration using its handle, which I’ve named `iterText`.
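Here’s a sketch of the plot-update function and the remaining setup (the pause duration and file-naming scheme are illustrative, though the image names do need to match the ffmpeg pattern used later):

```python
def updatePlot(iteration):
    iterText.set_text('iteration: {}'.format(iteration))
    for k in range(K):
        # Points currently assigned to cluster k; all y values are zero,
        # since we're working in one dimension.
        xDataNew = points[classifications == k]
        clusterPointsList[k].set_data(xDataNew, np.zeros(xDataNew.shape))
        centroidPointsList[k].set_data([centroids[k]], [0])
    fig.savefig('{}.png'.format(iteration))
    plt.pause(0.5)  # give the plot window a moment to redraw

# Set the axis limits so all of the data points are visible.
ax.set_xlim(points.min() - 1, points.max() + 1)
ax.set_ylim(-1, 1)

updatePlot(0)
plt.ion()
plt.show()
```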
We’re almost done. We just need a function to update the plot during each iteration of the algorithm, which is defined on lines 75-82. Our `updatePlot()` function takes one argument–the current iteration–and uses it to set the value of the `iterText` annotation. For each cluster `k`, the function determines which points belong to the cluster (line 77), then uses that to set the x data of the line object corresponding to the cluster (which it pulls from `clusterPointsList`) using the line object’s `set_data()` method. The `set_data()` method takes two arguments: an array of x coordinates and an array of y coordinates. Since this is a one-dimensional example, we don’t care about the y values, so we set them all to 0. Note that we don’t have to worry about setting the line object colors or marker styles, because we already did that when we created the line objects (line 66). On line 79, we do the same thing for the cluster centroid. Since each cluster only has one centroid value, we pass a single x value and a single y value to the centroid line object’s `set_data()` method.

On line 81, we use the `savefig()` method to save the plot in its current state as an image, which will occur on every iteration of the while loop that we’ll use to animate the algorithm. Afterward, we’ll use the images to create a video of the animation. You can comment this line out if you don’t want to save images of the animation.

IMPORTANT NOTE: A WHILE LOOP IS NOT THE BEST WAY TO ANIMATE IN MATPLOTLIB OR TO CREATE VIDEOS OF THE ANIMATION! Using a while loop to animate things in matplotlib works, but it’s not a good way to animate things. There’s an `animation` module in matplotlib that does a better job of this. Saving individual frames and manually creating a video is also unnecessary–the `matplotlib.animation` module does that, too. For this example, though, I’d like to keep things simple by avoiding the animation module. Furthermore, I’d like to demonstrate the “hard way” before demonstrating the correct way. Plus, we’ll get to see an example of how to create videos from the terminal with ffmpeg, which, really, is what the matplotlib animation module does behind the scenes, anyway. We will, however, use the animation module in the next post.

Getting back to the current exercise: lines 84-86 set the axis limits of our plot to ensure all our data points will be visible. On line 88, we run the `updatePlot()` function defined above to initialize the plot. On line 89, we turn on interactive plotting with `plt.ion()`. This, ironically, allows us to continuously update the plot without user interaction, by permitting the rest of the code to continue executing while the plot is open. Finally, a call to `plt.show()` is necessary to actually show the plot.
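The main loop might look something like this:

```python
# Seed lastCentroids with values guaranteed to differ from centroids so the
# while loop executes at least once.
lastCentroids = centroids + 1

iteration = 0
while not np.array_equal(centroids, lastCentroids):
    # np.copy() matters here; see the note below on why plain assignment
    # wouldn't work.
    lastCentroids = np.copy(centroids)
    recalcCentroids()
    assignPointsToCentroids()
    iteration += 1
    updatePlot(iteration)
```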
At last, we execute the algorithm. The end condition is met when the centroid locations no longer change. To accomplish this, we create the array `lastCentroids` and initialize it with an arbitrary set of values that’s different from `centroids` to ensure the while loop executes at least once. Note that we use `np.copy()` on line 97 because the statement `lastCentroids = centroids` wouldn’t create a copy of `centroids`–instead, it would point to it (i.e., if we used `lastCentroids = centroids`, `lastCentroids` and `centroids` would refer to the same object; modifying one would modify the other and, since they’d both point to the same object, they would always be equal, meaning the loop would only execute once). This is standard Python assignment behavior–names are bound to objects, not copies of them–and it’s especially important to keep in mind when working with mutable objects like numpy arrays.
That’s it. Run the script and see what happens. Try running it multiple times to see how different centroid initializations impact the results. Play around with the parameters: try different values for K, different numbers of clusters, and different cluster standard deviations.
Create a video using ffmpeg
Optionally, we can now combine the saved images (from line 81 in our `updatePlot()` function) to create a video. There are many ways to do this, but I’m going to use ffmpeg. REMINDER: THIS IS NOT THE BEST WAY TO CREATE A VIDEO OF THE ANIMATION! As I mentioned above, using the `matplotlib.animation` module is a better method. Saving images and manually creating a video with ffmpeg is a purely instructional exercise.
Open a terminal or command prompt in the same directory as the images and type the following command:
```
ffmpeg -r 1 -i %d.png -vcodec h264 -pix_fmt yuv420p output.mp4
```
The `-r` option sets the framerate. `-i` sets the input file name or pattern–in this case, we use the pattern `%d.png` to specify that the input file names contain an integer with a .png extension. `-vcodec` tells ffmpeg which video codec to utilize for encoding the file. You can view the ffmpeg codecs available on your system by running the command `ffmpeg -codecs`. `-pix_fmt` sets the pixel format to use. In this example, we use a format called yuv420p. YUV is simply a color space like RGB or HSV, and our selection of a YUV pixel format determines how colors are mapped to produce the video. The command `ffmpeg -pix_fmts` will list all available pixel formats. The last argument is the name of the output file, `output.mp4`. ffmpeg uses the file extension of the output file to properly encode the video. Note that most of the arguments to ffmpeg are optional, with the exception of your input file(s) and output file, but the optional arguments provide more control over how your input is processed and how your output is encoded. There are many other options you can specify, as well. I’m by no means an expert on ffmpeg, so I’ll refer you to the official ffmpeg documentation if you’d like to learn more.
In the next post, we’ll generalize the K-means clustering algorithm to any arbitrary number of dimensions, and we’ll animate the result using the `matplotlib.animation` module, as well as tackle several other Python concepts.