Visualizing K-means clustering in 1D with Python

These first few posts will focus on K-means clustering, beginning with a brief introduction to the technique and a simplified implementation in one dimension to demonstrate the concept. In the next post, the concept will be extended to N dimensions. These posts and, for that matter, probably most future posts, will focus not only on the technique in question, but also on the code–Python, in this case. In other words, this post is at least as much about Python–or, perhaps, programming in general–as it is about K-means clustering. To that end, we will go through the accompanying code, line by line, to understand not only what the code is doing, but how and why.

What is K-means clustering?

Without delving into too much detail (there are, after all, numerous resources out there that discuss K-means clustering), K-means clustering is an unsupervised machine learning technique whose purpose is to segment a data set into K clusters. In other words, given a bunch of data, K-means clustering seeks to divide it into distinct groups–usually without prior knowledge of where to start and without feedback on whether the resulting grouping is correct. The K-means clustering algorithm achieves this via several major steps:

1) Initialize K centroids, one for each cluster.
2) Assign each point in the data set to its nearest centroid.
3) After each point has been assigned to a cluster (based on its proximity to the cluster centroids), recalculate the centroid of each cluster.
4) Repeat steps 2-3 until the centroids no longer change, or until a certain number of iterations is reached.
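The four steps above fit in a few lines of NumPy. The following is a minimal sketch for 1D data in Python 3, not the script discussed later. The function name kmeans_1d and its max_iter and seed parameters are hypothetical, and the sketch uses NumPy's newer Generator API (np.random.default_rng) rather than the legacy np.random functions.

```python
import numpy as np


def kmeans_1d(points, K, max_iter=100, seed=None):
    """Hypothetical helper: cluster 1D `points` into K groups.

    Returns (centroids, classifications), where classifications[i] is the
    index (0..K-1) of the centroid nearest to points[i].
    """
    rng = np.random.default_rng(seed)
    points = np.asarray(points, dtype=float)

    def assign(centroids):
        # Step 2: index of the nearest centroid for every point (|x - c| in 1D).
        return np.argmin(np.abs(points[:, None] - centroids[None, :]), axis=1)

    # Step 1: initialize the centroids at K distinct data points.
    centroids = points[rng.choice(points.size, size=K, replace=False)]

    for _ in range(max_iter):
        classifications = assign(centroids)

        # Step 3: move each centroid to the mean of its points; a centroid
        # with no points assigned to it stays where it is.
        newCentroids = centroids.copy()
        for k in range(K):
            members = points[classifications == k]
            if members.size > 0:
                newCentroids[k] = members.mean()

        # Step 4: stop once the centroids no longer change (or max_iter is hit).
        if np.array_equal(newCentroids, centroids):
            break
        centroids = newCentroids

    return centroids, assign(centroids)
```

For example, kmeans_1d(np.array([-1.1, -1.0, -0.9, 0.9, 1.0, 1.1]), K=2) settles on centroids near -1.0 and 1.0. The max_iter cap plays the role of step 4's "certain number of iterations".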

K-means clustering isn’t usually used for one-dimensional data, but the one-dimensional case makes for a relatively simple example that demonstrates how the algorithm works. As the title suggests, the aim of this post is to visualize K-means clustering in one dimension with Python, as in the video accompanying this post.

We’ll take a look at the code used to create the animations in this video in the following section.

The code

The Python script used to create the animations in the video above can be found on GitHub. Some familiarity with Python, or at least with programming, is assumed, but most of the content will be explained in detail below; this particular script strives for clarity over brevity. The script should work in both Python 2.x and Python 3.x, and requires the numpy and matplotlib packages to be installed.

OK, let’s get started. Open up your favorite IDE or text editor (I prefer vim) and create a new Python file, or download the script from the GitHub link above and follow along.

import numpy as np
import matplotlib.pyplot as plt
import colorsys
import sys

K = 3   # number of centroids to compute
numClusters = 3 # actual number of clusters to generate
ptsPerCluster = 40  # number of points per actual cluster
xCenterBounds = (-2, 2) # limits within which to place actual cluster centers

numClusters is the number of clusters we’ll actually generate, each of which will contain ptsPerCluster points. We’ll place the centers of these clusters within the (min, max) values given by xCenterBounds (“x” is our one dimension in this one-dimensional example). K is the number of centroids, i.e., clusters, we’d like the algorithm to look for–it’s the “K” in “K-means”. Note that the number of clusters the algorithm searches for is independent of the number of clusters we actually generate. In fact, this is an important point–in this example, we are explicitly generating clusters of data before letting our K-means algorithm have a go at it. In the real world, however, we won’t usually be generating the data. We’ll be collecting or receiving it, we won’t necessarily know how many clusters exist, if any, and it’ll be up to us and/or our K-means algorithm to determine how many clusters to divide the data into, i.e., what value of K to choose. There are methods for determining the optimal value of K, but we’ll get into that later.

# Randomly place cluster centers within the span of xCenterBounds.
centers = np.random.random_sample((numClusters,))
centers = centers * (xCenterBounds[1] - xCenterBounds[0]) + xCenterBounds[0]

# Initialize array of data points.
points = np.zeros((numClusters * ptsPerCluster,))

# Normally distribute ptsPerCluster points around each center.
stDev = 0.15
for i in range(numClusters):
    points[i*ptsPerCluster:(i+1)*ptsPerCluster] = (
        stDev * np.random.randn(ptsPerCluster) + centers[i])

Line 12 uses the random_sample function of numpy’s random module to generate an array of random floats in the interval [0.0, 1.0), sampled from a continuous uniform distribution (as opposed to, say, a normal distribution). Line 13 maps these values from the range [0.0, 1.0) to the range given by xCenterBounds by scaling them by the width of that range and shifting them by its lower bound. These values constitute the centers of the clusters of points we’ll generate, which is achieved by lines 19-22 (each iteration of the loop generates one cluster of points and stores it in the corresponding slice of points). To do this, we define a standard deviation stDev and use the numpy function randn, which draws samples from the standard normal distribution (mean 0, standard deviation 1); multiplying those samples by stDev and adding centers[i] yields points normally distributed around centers[i] with standard deviation stDev. Increasing the value of stDev will cause the points to spread out, whereas decreasing it results in more tightly packed clusters. Play around with this to see how it affects the algorithm’s ability to differentiate clusters from one another.
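The rescaling described above is the usual affine map from [0, 1) to [lo, hi): multiply by the width hi - lo, then add lo. The sketch below checks that the mapped centers land inside xCenterBounds and that scaling standard-normal draws by stDev and shifting them by a center produces points with roughly that spread around it. The helper name make_clusters is hypothetical, and the sketch uses NumPy's Generator API instead of the legacy random_sample and randn calls in the script.

```python
import numpy as np


def make_clusters(numClusters, ptsPerCluster, xCenterBounds, stDev, seed=None):
    """Hypothetical helper mirroring the script's 1D data generation."""
    rng = np.random.default_rng(seed)
    lo, hi = xCenterBounds

    # Uniform samples in [0, 1), mapped affinely onto [lo, hi).
    centers = rng.random(numClusters) * (hi - lo) + lo

    # Standard-normal draws scaled by stDev and shifted onto each center,
    # concatenated cluster by cluster (same layout as the script's points array).
    points = np.concatenate(
        [stDev * rng.standard_normal(ptsPerCluster) + c for c in centers])
    return centers, points
```

np.random.uniform(lo, hi, size) performs the same uniform mapping in a single call, if you prefer.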

# Randomly select K points as the initial centroid locations.
centroids = np.zeros((K,))
indices = []
while len(indices) < K:
    index = np.random.randint(0, numClusters * ptsPerCluster)
    if index not in indices:
        indices.append(index)
centroids = points[indices]

# Assign each point to its nearest centroid. Store this in classifications,
# where each element will be an int from 0 to K-1.
classifications = np.zeros((points.shape[0],),
    dtype=int)
def assignPointsToCentroids():
    for i in range(points.shape[0]):
        smallestDistance = 0
        for k in range(K):
            distance = abs(points[i] - centroids[k])
            if k == 0:
                smallestDistance = distance
                classifications[i] = k
            elif distance < smallestDistance:
                smallestDistance = distance
                classifications[i] = k

assignPointsToCentroids()
Several methods exist for choosing the initial locations for our K centroids. This is important, since our choice of initial locations affects the outcome. There are many ways in which we can cluster any given set of data. The K-means algorithm does not guarantee that we will arrive at the best solution (the “global optimum”), only that we will arrive at a solution (“local optimum”). Regardless, as a rule of thumb, a good way to initialize the centroid locations is to randomly choose K unique points from our data set and use those as the initial centroid locations. This is what lines 25-31 accomplish by choosing K unique indices from the points array and storing the values at those indices in the centroids array. During each iteration of the algorithm, we’ll update centroids with the new centroid locations.
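The while loop above keeps drawing indices until it has K distinct ones. NumPy can do the same thing in one call by sampling without replacement. A minimal sketch, using the legacy np.random.choice to match the script's other np.random calls (the helper name init_centroids is hypothetical):

```python
import numpy as np


def init_centroids(points, K):
    """Hypothetical helper: pick K distinct data points as starting centroids."""
    # replace=False guarantees the K indices are unique, like the while loop.
    indices = np.random.choice(points.shape[0], size=K, replace=False)
    return points[indices], indices
```

As with the while loop, unique indices only guarantee unique centroid values if the data set itself contains no duplicate points.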

Finally, we will assign each point in points to the cluster whose centroid the point is closest to, with the function assignPointsToCentroids(). In other words, we’ll iterate through points and, for each point, we will determine which of the K centroids in centroids that point is closest to, then store the index corresponding to that centroid/cluster in the array classifications (since we’re “classifying” the point as belonging to one of the K clusters). Each element in classifications corresponds to a point in points. So, if classifications[3] = 0, that would signify that the fourth point, i.e., points[3], is closest to centroid 0, whose location is given by centroids[0]. Since we’re only dealing with one dimension in this example, the distance between a given point and a given centroid can be obtained with the built-in absolute value function abs(). After defining our function assignPointsToCentroids(), we run it once to group the points into clusters based on the initial centroid locations currently in centroids.
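The double loop in assignPointsToCentroids() is easy to follow. For comparison, here is a sketch that computes the same classifications with NumPy broadcasting: subtracting a (K,) array from an (N, 1) array yields an (N, K) array of distances, and np.argmin along axis 1 picks the nearest centroid for each point. Like the loop's strict < comparison, np.argmin keeps the first index on ties. Both function names are hypothetical, and the loop version returns its result instead of writing to a global array.

```python
import numpy as np


def assign_loop(points, centroids):
    """Loop version, equivalent to assignPointsToCentroids()."""
    classifications = np.zeros(points.shape[0], dtype=int)
    for i in range(points.shape[0]):
        smallestDistance = None
        for k in range(centroids.shape[0]):
            distance = abs(points[i] - centroids[k])
            if smallestDistance is None or distance < smallestDistance:
                smallestDistance = distance
                classifications[i] = k
    return classifications


def assign_vectorized(points, centroids):
    """Broadcasting version: distances has shape (len(points), len(centroids))."""
    distances = np.abs(points[:, np.newaxis] - centroids[np.newaxis, :])
    return np.argmin(distances, axis=1)
```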

# Define a function to recalculate the centroid of a cluster.
def recalcCentroids():
    for k in range(K):
        if sum(classifications == k) > 0:
            centroids[k] = sum(points[classifications == k]) / sum(classifications == k)

Lines 51-54 define a function, recalcCentroids(), that carries out the other major part of the algorithm: recalculating each cluster’s centroid location. For each cluster, we first check whether any points are actually assigned to that cluster with if sum(classifications == k) > 0. The comparison classifications == k returns a boolean array that is True at the indices where classifications equals k and False elsewhere. When such an array is summed, True counts as 1 and False as 0, so the sum is the number of points assigned to cluster k; if it is at least 1, the cluster is not empty. The statement on line 54, which actually computes the centroid, uses the same mask to select the cluster’s points, sums them, and divides by the number of points in the cluster, i.e., it computes their mean. Ensuring there’s at least one point in the cluster preempts division by zero.
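To make the mask arithmetic concrete, here is a small worked example on made-up values, followed by a sketch of the same centroid update using np.bincount, which sums the points per cluster and counts the members of each cluster in one pass each. The function name recalc_vectorized is hypothetical; like recalcCentroids(), it leaves the centroid of an empty cluster where it was.

```python
import numpy as np

points = np.array([0.0, 1.0, 2.0, 10.0, 12.0])
classifications = np.array([0, 0, 0, 1, 1])

mask = classifications == 1       # array([False, False, False,  True,  True])
count = sum(mask)                 # True counts as 1, so count == 2
mean = sum(points[mask]) / count  # (10.0 + 12.0) / 2 == 11.0


def recalc_vectorized(points, classifications, centroids):
    """Hypothetical helper: per-cluster means, leaving empty clusters unchanged."""
    K = centroids.shape[0]
    sums = np.bincount(classifications, weights=points, minlength=K)
    counts = np.bincount(classifications, minlength=K)
    newCentroids = centroids.astype(float)   # copy, so the input is untouched
    nonEmpty = counts > 0
    newCentroids[nonEmpty] = sums[nonEmpty] / counts[nonEmpty]
    return newCentroids
```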

# Generate a unique color for each of the K clusters using the HSV color scheme.
# Simultaneously, initialize matplotlib line objects for each centroid and cluster.
hues = np.linspace(0, 1, K+1)[:-1]

fig, ax = plt.subplots()
clusterPointsList = []
centroidPointsList = []
for k in range(K):
    clusterColor = tuple(colorsys.hsv_to_rgb(hues[k], 0.8, 0.8))

    clusterLineObj, = ax.plot([], [], ls='None', marker='x', color=clusterColor)
    clusterPointsList.append(clusterLineObj)

    centroidLineObj, = ax.plot([], [], ls='None', marker='o',
        markeredgecolor='k', color=clusterColor)
    centroidPointsList.append(centroidLineObj)
iterText = ax.annotate('', xy=(0.01, 0.01), xycoords='axes fraction')

Our goal in this exercise is to visualize the algorithm, which means we’ll want a different color to represent each cluster. A nifty way to achieve this for any arbitrary number of clusters is to take advantage of the HSV color space, which describes colors by hue, saturation, and value (brightness). These parameters can be represented by a cone or cylinder, with the hue changing as we travel around it: 0° represents red, 120° represents green, 240° represents blue, and 360° marks a return to red. In the Python colorsys module, the range 0°-360° is represented by a float from 0 to 1. On line 58, to get K distinct colors, we generate K+1 equally spaced numbers spanning [0, 1] and use all except the last one. This is because the first number will be 0 and the last will be 1, which, in the HSV color space, both represent red.
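A quick check of the hue trick: for K = 3 the hues are 0, 1/3, and 2/3, which colorsys.hsv_to_rgb turns into red-, green-, and blue-dominant colors, and hue 1 produces exactly the same color as hue 0, which is why the last value from np.linspace is dropped. The helper name cluster_colors is hypothetical; the saturation and value of 0.8 match the script.

```python
import colorsys

import numpy as np


def cluster_colors(K, saturation=0.8, value=0.8):
    """Hypothetical helper: K evenly spaced hues converted to RGB tuples."""
    # K+1 points from 0 to 1 inclusive; drop the last (hue 1 is red again).
    hues = np.linspace(0, 1, K + 1)[:-1]
    return [colorsys.hsv_to_rgb(h, saturation, value) for h in hues]


colors = cluster_colors(3)
# Index of the dominant channel (0 = R, 1 = G, 2 = B) for each color.
dominant = [int(np.argmax(c)) for c in colors]   # [0, 1, 2]
```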

Line 60 creates a figure with an axis using the pyplot function subplots(), and returns handles to the figure and axis objects, which we can use to get or set figure and axis properties later. To visualize the algorithm, we want to plot each cluster, as well as the centroid for that cluster, in a unique color. We also need to be able to update each cluster after each iteration of the algorithm. To do this, we’ll utilize a for-loop to initialize an empty matplotlib line object for each cluster and for each centroid (lines 63-71) by calling the plot() function of our axis. Basically, for each of the K clusters, we add a “line object” to the axis ax using plot(). The properties of our line object are determined by the arguments we pass to plot().

The first two arguments to plot() are the x and y coordinates of the points we’d like to plot. By passing in empty lists with ax.plot([], [], ...), we’re initializing a line that doesn’t have any x or y data. The keyword argument ls sets the linestyle; ls='None' tells matplotlib that we want to plot the points without a line connecting them. The keyword argument marker sets the marker style. The keyword argument color takes any matplotlib color specification. In this case, we’re feeding it an RGB tuple crafted from our HSV colors using the colorsys function hsv_to_rgb() (note that matplotlib also has a colors module, matplotlib.colors, which can convert HSV to RGB, but I used colorsys to show that Python’s standard library can handle this on its own). plot() returns a list of the line objects it created; since each call here creates exactly one line, the trailing comma in clusterLineObj, = ax.plot(...) unpacks that one-element list, leaving us with a handle to the line object itself, through which we can modify any of the aforementioned properties later. Of course, we need to be able to access these line object handles each time we want to update our plot. We do this by appending the cluster and centroid line object handles to the lists clusterPointsList and centroidPointsList, respectively, which keeps references to the objects after the for-loop has completed.
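A minimal check of the points above, run with matplotlib's non-interactive Agg backend so it works without a display (that backend choice and the colors used are only for this illustration): ax.plot() returns a list of Line2D objects, the trailing comma unpacks its single element, and appending each handle to a list keeps it usable after the loop.

```python
import matplotlib

matplotlib.use("Agg")  # render off-screen; no window is needed for this check
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

fig, ax = plt.subplots()
returned = ax.plot([], [], ls='None', marker='x', color=(0.8, 0.16, 0.16))
print(type(returned).__name__, len(returned))   # list 1

clusterPointsList = []
for k in range(3):
    # "line," unpacks the one-element list that ax.plot() returns.
    line, = ax.plot([], [], ls='None', marker='x')
    clusterPointsList.append(line)   # keep a handle for later updates

# The handles remain usable after the loop has finished.
clusterPointsList[1].set_data([0.5, 1.5], [0, 0])
```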

Line 72 adds a blank text annotation to the lower left corner of the plot. We’ll use this text to display the number of iterations the algorithm has performed, and we’ll update it at every iteration using its handle, which I’ve named iterText.

# Define a function to update the plot.
def updatePlot(iteration):
    for k in range(K):
        xDataNew = points[classifications == k]
        clusterPointsList[k].set_data(xDataNew, np.zeros((len(xDataNew),)))
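        centroidPointsList[k].set_data([centroids[k]], [0])  # one-element sequences
    iterText.set_text('i = {:d}'.format(iteration))
    fig.savefig('{:d}.png'.format(iteration))  # saves frames 0.png, 1.png, ...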

dataRange = np.amax(points) - np.amin(points)
ax.set_xlim(np.amin(points) - 0.05*dataRange, np.amax(points) + 0.05*dataRange)
ax.set_ylim(-1, 1)
iteration = 0
updatePlot(iteration)
plt.ion()
plt.show()

We’re almost done. We just need a function to update the plot during each iteration of the algorithm, which is defined on lines 75-82. Our updatePlot() function takes one argument–the current iteration–and uses it to set the value of the iterText annotation. For each cluster k, the function determines which points belong to the cluster (line 77), then uses that to set the x data of the line object corresponding to the cluster (which it pulls from clusterPointsList) using the line object’s set_data() method. The set_data() method takes two arguments: an array of x coordinates and an array of y coordinates. Since this is a one-dimensional example, we don’t care about the y values, so we set them all to 0. Note that we don’t have to worry about setting the line object colors or marker styles, because we already did that when we created the line objects (line 66). On line 79, we do the same thing for the cluster centroid. Since each cluster only has one centroid value, we pass a single x value and a single y value to the centroid line object’s set_data() method.
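One caveat on the centroid update: recent matplotlib releases deprecate, and newer ones reject, passing a bare scalar to set_data(), which expects sequences for x and y. Wrapping the single centroid coordinates in one-element lists, as in [centroids[k]] and [0], is therefore the safer form. The sketch below exercises that update pattern on made-up data with the Agg backend; the helper name update_lines is hypothetical.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

points = np.array([-1.2, -1.0, 0.9, 1.1])
classifications = np.array([0, 0, 1, 1])
centroids = np.array([-1.1, 1.0])
K = 2

fig, ax = plt.subplots()
clusterPointsList = [ax.plot([], [], ls='None', marker='x')[0] for _ in range(K)]
centroidPointsList = [ax.plot([], [], ls='None', marker='o')[0] for _ in range(K)]


def update_lines():
    for k in range(K):
        xDataNew = points[classifications == k]
        # All y values are 0 in this 1D example.
        clusterPointsList[k].set_data(xDataNew, np.zeros(len(xDataNew)))
        # One centroid per cluster: pass one-element sequences, not scalars.
        centroidPointsList[k].set_data([centroids[k]], [0])


update_lines()
```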

On line 81, we use the savefig() method to save the plot in its current state as an image, which will occur on every iteration of the while loop that we’ll use to animate the algorithm. Afterward, we’ll use the images to create a video of the animation. You can comment this line out if you don’t want to save images of the animation.
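The saved frames need plain integer file names so that ffmpeg's %d.png input pattern can match them. A minimal sketch of saving one numbered PNG per iteration; writing into a temporary directory, the figure contents, and the frame count are assumptions for this illustration, not details of the script.

```python
import os
import tempfile

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

outDir = tempfile.mkdtemp()
fig, ax = plt.subplots()
iterText = ax.annotate('', xy=(0.01, 0.01), xycoords='axes fraction')

for iteration in range(3):
    iterText.set_text('i = {:d}'.format(iteration))
    # 0.png, 1.png, 2.png, ...: integer names match ffmpeg's %d.png pattern.
    fig.savefig(os.path.join(outDir, '{:d}.png'.format(iteration)))

plt.close(fig)
savedFiles = sorted(os.listdir(outDir))
```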

IMPORTANT NOTE: A WHILE LOOP IS NOT THE BEST WAY TO ANIMATE IN MATPLOTLIB OR TO CREATE VIDEOS OF THE ANIMATION! Using a while loop to animate things in matplotlib works, but it’s not a good way to animate things. There’s an animation module in matplotlib that does a better job of this. Saving individual frames and manually creating a video is also unnecessary–the matplotlib.animation module does that, too. For this example, though, I’d like to keep things simple by avoiding the animation module. Furthermore, I’d like to demonstrate the “hard way” before demonstrating the correct way. Plus, we’ll get to see an example of how to create videos from the terminal with ffmpeg, which, really, is what the matplotlib animation module does behind the scenes, anyway. We will, however, use the animation module in the next post.
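For comparison, here is a rough sketch of the same kind of frame-by-frame update driven by matplotlib.animation.FuncAnimation instead of a while loop, saved as a GIF with matplotlib's PillowWriter so it does not depend on ffmpeg. The centroid positions, the output path, and the frame rate are made up for illustration; this is only a demonstration of the module, not the code from the next post.

```python
import os
import tempfile

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import animation

# Made-up centroid positions for three successive iterations.
frames = [[-1.5, 0.2], [-1.1, 0.7], [-1.0, 1.0]]

fig, ax = plt.subplots()
ax.set_xlim(-2, 2)
ax.set_ylim(-1, 1)
centroidLine, = ax.plot([], [], ls='None', marker='o', markeredgecolor='k')
iterText = ax.annotate('', xy=(0.01, 0.01), xycoords='axes fraction')


def update(i):
    # FuncAnimation calls this once per frame with the frame index i.
    centroidLine.set_data(frames[i], [0] * len(frames[i]))
    iterText.set_text('i = {:d}'.format(i))
    return centroidLine, iterText


anim = animation.FuncAnimation(fig, update, frames=len(frames))
outPath = os.path.join(tempfile.mkdtemp(), 'kmeans.gif')
anim.save(outPath, writer=animation.PillowWriter(fps=1))
plt.close(fig)
```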

Getting back to the current exercise: lines 84-86 set the axis limits of our plot, with a margin of 5% of the data range on either side, to ensure all our data points will be visible. On line 88, we run the updatePlot() function defined above to initialize the plot. On line 89, we turn on interactive plotting with plt.ion(). This, ironically, allows us to continuously update the plot without user interaction, by permitting the rest of the code to continue executing while the plot is open. Finally, a call to plt.show() is necessary to actually show the plot; in interactive mode, it returns immediately instead of blocking until the window is closed.

# Execute and animate the algorithm with a while loop. Note that this is not the
# best way to animate a matplotlib plot--the matplotlib animation module should be
# used instead, but we will use a while loop here for simplicity.
lastCentroids = centroids + 1
while not np.array_equal(centroids, lastCentroids):
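    lastCentroids = np.copy(centroids)
    recalcCentroids()
    assignPointsToCentroids()
    iteration += 1
    updatePlot(iteration)
    plt.pause(0.5)  # let the GUI redraw; the pause length is arbitrary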

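# Keep the plot window open until Enter is pressed (raw_input in Python 2, input in Python 3).
pythonMajorVersion = sys.version_info[0]
if pythonMajorVersion < 3:
    raw_input("Press Enter to continue.")
else:
    input("Press Enter to continue.")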

At last, we execute the algorithm. The end condition is met when the centroid locations no longer change. To accomplish this, we create the array lastCentroids and initialize it with an arbitrary set of values that’s different from centroids to ensure the while loop executes at least once. Note that we use np.copy() on line 97 because the statement lastCentroids = centroids wouldn’t create a copy of centroids; it would simply bind a second name to the same array. Since recalcCentroids() modifies centroids in place, lastCentroids would change along with it, the two would always be equal, and the loop would exit after a single iteration. This is how assignment works for mutable objects in Python generally, not something peculiar to numpy, but it’s easy to trip over with numpy arrays and important to keep in mind.
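A short demonstration of the difference on made-up values: plain assignment binds a second name to the same array, while np.copy() creates an independent array, so only the copy preserves the old values after an in-place update like the one in recalcCentroids().

```python
import numpy as np

centroids = np.array([0.0, 1.0, 2.0])

alias = centroids               # same object, two names
snapshot = np.copy(centroids)   # independent copy of the data

centroids[0] = 5.0              # modify in place, as recalcCentroids() does

print(alias is centroids)       # True
print(alias[0], snapshot[0])    # 5.0 0.0
print(np.array_equal(alias, centroids), np.array_equal(snapshot, centroids))  # True False
```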

That’s it. Run the script and see what happens. Try running it multiple times to see how different centroid initializations impact the results. Play around with the parameters: try different values for K, different numbers of clusters, and different cluster standard deviations.

Create a video using ffmpeg

Optionally, we can now combine the saved images (from line 81 in our updatePlot() function) to create a video. There are many ways to do this, but I’m going to use ffmpeg. REMINDER: THIS IS NOT THE BEST WAY TO CREATE A VIDEO OF THE ANIMATION! As I mentioned above, using the matplotlib.animation module is a better method. Saving images and manually creating a video with ffmpeg is a purely instructional exercise.

Open a terminal or command prompt in the same directory as the images and type the following command:

ffmpeg -r 1 -i %d.png -vcodec h264 -pix_fmt yuv420p output.mp4

The -r option, placed before -i as it is here, sets the frame rate of the input, so each image is shown for one second. -i sets the input file name or pattern; in this case, we use the pattern %d.png to specify that the input file names consist of an integer followed by a .png extension. -vcodec tells ffmpeg which video codec to use for encoding the file. You can view the ffmpeg codecs available on your system by running the command ffmpeg -codecs. -pix_fmt sets the pixel format. In this example, we use a format called yuv420p. YUV is simply a color space like RGB or HSV, and our selection of a YUV pixel format determines how colors are represented in the encoded video; yuv420p, which stores color at a lower resolution than brightness, is supported by most video players. The command ffmpeg -pix_fmts will list all available pixel formats. The last argument is the name of the output file, output.mp4; ffmpeg infers the output container format from the file extension. Note that most of the arguments to ffmpeg are optional, with the exception of your input file(s) and output file, but the optional arguments provide more control over how your input is processed and how your output is encoded. There are many other options you can specify, as well. I’m by no means an expert on ffmpeg, so I’ll refer you to the official ffmpeg documentation if you’d like to learn more.
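If you would rather drive the same command from Python, here is a sketch using the standard subprocess module. The argument list mirrors the command above; the helper names ffmpeg_command and make_video are hypothetical, and make_video only runs if an ffmpeg executable is found on the PATH (checked with shutil.which). Passing a list rather than a single string means no shell parses the %d pattern.

```python
import shutil
import subprocess


def ffmpeg_command(fps=1, pattern='%d.png', output='output.mp4'):
    """Hypothetical helper: build the ffmpeg argument list used above."""
    return ['ffmpeg',
            '-r', str(fps),            # input frame rate (placed before -i)
            '-i', pattern,             # numbered frames: 0.png, 1.png, ...
            '-vcodec', 'h264',         # H.264 video codec
            '-pix_fmt', 'yuv420p',     # widely supported pixel format
            output]                    # container inferred from the extension


def make_video(frameDir, **kwargs):
    """Hypothetical helper: run ffmpeg inside frameDir, where the PNGs live."""
    if shutil.which('ffmpeg') is None:
        raise RuntimeError('ffmpeg not found on PATH')
    subprocess.run(ffmpeg_command(**kwargs), cwd=frameDir, check=True)
```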

In the next post, we’ll generalize the K-means clustering algorithm to any arbitrary number of dimensions, and we’ll animate the result using the matplotlib.animation module, as well as tackle several other Python concepts.