-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkmeans.py
185 lines (156 loc) · 6.61 KB
/
kmeans.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
from unsupervisedLearning import *
import matplotlib.pyplot as plt
import numpy as np
import math
import random
from time import sleep
def main():
    """Run k-means with three clusters on the universities data files."""
    clusterer = KMeans(3, "universities.points", "universities.labels")
    clusterer.findCentroids()
class KMeans(UnsupervisedLearning):
    """K-means clustering on top of UnsupervisedLearning.

    This class uses ``self.pointList``, ``self.labelList`` and
    ``self.dist()`` without defining them; they are expected to be supplied
    by UnsupervisedLearning.

    All output is written with single-argument ``print(...)`` calls so the
    text is the same whether ``print`` is the Python 2 statement or the
    Python 3 function.
    """

    def __init__(self, k, pointFile, labelFile):
        """Set up an empty clustering.

        k is the number of centroids to create when clustering.
        pointFile and labelFile are passed unchanged to
        UnsupervisedLearning.__init__.

        Attributes created here:
        centroids  -- dict mapping each cluster label to its centroid point.
        members    -- dict mapping each cluster label to the list of points
                      currently assigned to that cluster.
        labels     -- dict mapping tuple(point) to the label of the cluster
                      the point is currently assigned to.
        error      -- sum of squared point-to-centroid distances from the
                      most recent assignment step.
        updateFlag -- True when the most recent assignment step moved at
                      least one point to a different cluster.
        """
        UnsupervisedLearning.__init__(self, pointFile, labelFile)
        self.k = k
        self.centroids = {}
        self.members = {}
        self.labels = {}
        self.error = 0
        self.updateFlag = False

    def showClusters(self, verbose=False):
        """Print the current error followed by a summary of every cluster.

        For each cluster the label, the member count and the centroid are
        printed. When verbose is True every member point is printed too.
        """
        print("Current error: %s" % (self.error,))
        for key in self.centroids:
            print("-" * 20)
            print("Cluster: %s Length: %s" % (key, len(self.members[key])))
            # The centroid shares a line with its caption, so the line is
            # assembled from tokens instead of relying on a trailing-comma
            # print.
            tokens = ["Cluster point:"] + self._pointTokens(self.centroids[key])
            print(" ".join(tokens))
            if verbose:
                for point in self.members[key]:
                    self.showPoint(point)

    def _pointTokens(self, point):
        """Return the printable pieces of a point as a list of strings.

        Each coordinate is formatted with 3 decimal places. If tuple(point)
        is a key of self.labels, that label is appended as the last token.
        """
        tokens = ["%.3f" % value for value in point]
        key = tuple(point)
        if key in self.labels:
            tokens.append("%s" % (self.labels[key],))
        return tokens

    def showPoint(self, point):
        """Compactly print a point on one line, 3 decimal places per dimension.

        If the point has an entry in the labels dictionary, the label is
        printed after the coordinates.
        """
        print(" ".join(self._pointTokens(point)))

    def plotClusters(self):
        """Plot the first two dimensions of every cluster, then show the figure.

        Members are drawn as 'o' markers and the centroid as an 'x' marker
        in the same color. Eight colors are available; with more than eight
        clusters the colors repeat. Both axes are limited to the range 0..1.
        """
        colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w']
        for i, cluster in enumerate(self.centroids):
            color = colors[i % len(colors)]
            plt.xlim(0, 1)
            plt.ylim(0, 1)
            x = [p[0] for p in self.members[cluster]]
            y = [p[1] for p in self.members[cluster]]
            plt.plot(x, y, color + 'o')
            centroid = self.centroids[cluster]
            plt.plot(centroid[0], centroid[1], color + 'x')
        plt.show()

    def initClusters(self):
        """Choose random data points as the initial cluster centroids.

        Any previous clustering state (centroids, members, labels) is
        discarded. Then up to k data points with distinct labels are drawn
        at random; each chosen point becomes a centroid keyed by its entry
        in self.labelList, and the members dictionary gets an empty list
        for that key.

        Raises ValueError when fewer than k distinct labels are available
        (the earlier rejection-sampling loop never terminated in that case).
        """
        self.centroids = {}
        self.members = {}
        # Assignments left over from an earlier run could make the first
        # assignment step look unchanged and end the search immediately.
        self.labels = {}

        # Scanning a shuffled index list and skipping labels already in use
        # picks uniformly among the remaining candidates at each step, and
        # always terminates.
        order = list(range(len(self.labelList)))
        random.shuffle(order)
        for ind in order:
            if len(self.centroids) >= self.k:
                break
            label = self.labelList[ind]
            if label not in self.centroids:
                self.centroids[label] = self.pointList[ind]
                self.members[label] = []

        if len(self.centroids) < self.k:
            raise ValueError(
                "cannot choose %s centroids: only %s distinct labels available"
                % (self.k, len(self.centroids)))

    def assignPoints(self):
        """E step: assign every point to its closest centroid.

        Resets self.error, self.updateFlag and the member lists, then for
        each point in self.pointList:
          * finds the closest centroid with findClosestCentroid(),
          * adds the squared distance to self.error,
          * appends the point to that cluster's member list,
          * records the cluster under tuple(point) in self.labels, setting
            self.updateFlag when the point is new or changed cluster.

        Returns self.updateFlag: True when any point changed cluster.
        """
        self.error = 0.0
        self.updateFlag = False
        self.members = dict((cl, []) for cl in self.centroids)

        for p in self.pointList:
            dist, cl = self.findClosestCentroid(p)
            self.error += dist ** 2
            self.members[cl].append(p)
            # Use the same tuple key that showPoint() looks up, so the
            # cluster label can actually be displayed next to the point.
            key = tuple(p)
            if key not in self.labels or self.labels[key] != cl:
                self.labels[key] = cl
                self.updateFlag = True
        return self.updateFlag

    def findClosestCentroid(self, point):
        """Return (distance, label) for the centroid closest to point.

        Distances come from self.dist(centroid, point). When there are no
        centroids, or no distance compares below infinity, the result is
        (float("inf"), None). Ties keep the centroid encountered first.
        """
        best = (float("inf"), None)
        for cl in self.centroids:
            d = self.dist(self.centroids[cl], point)
            if d < best[0]:
                best = (d, cl)
        return best

    def updateCentroids(self):
        """M step: move each centroid to the mean of its member points.

        The mean is taken along each dimension by summing the member points
        and dividing by their count; this relies on points supporting
        elementwise + and division by a constant (e.g. numpy arrays).
        A cluster with no members keeps its current centroid, since the
        mean of zero points is undefined.
        """
        for cl in self.centroids:
            assigned = self.members[cl]
            if not assigned:
                # Nothing to average; dividing by zero here would abort
                # the whole run.
                continue
            total = 0.0
            for p in assigned:
                total = total + p
            self.centroids[cl] = total / float(len(assigned))

    def findCentroids(self):
        """Run k-means to find a centroid for each of the k clusters.

        Centroids start as random data points (initClusters). Each
        iteration reassigns all points (assignPoints), prints the iteration
        number and the verbose cluster listing, and plots the clusters. The
        loop stops after the first iteration in which no point changed
        cluster; otherwise the centroids are recomputed (updateCentroids)
        and the next iteration begins.
        """
        self.initClusters()
        ct = 0
        while True:
            ct += 1
            self.assignPoints()
            print("\n\nITERATION %d\n" % ct)
            self.showClusters(True)
            self.plotClusters()
            if not self.updateFlag:
                break
            self.updateCentroids()
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()