[Scipy-svn] r3143 - in trunk/Lib/cluster: . tests
scipy-svn@scip...
scipy-svn@scip...
Tue Jul 3 06:29:07 CDT 2007
Author: cdavid
Date: 2007-07-03 06:29:01 -0500 (Tue, 03 Jul 2007)
New Revision: 3143
Modified:
trunk/Lib/cluster/tests/test_vq.py
trunk/Lib/cluster/vq.py
Log:
Add a `missing` option to kmeans2 that controls what happens when a cluster becomes empty: 'warn' (default) issues a warning, 'raise' raises ClusterError.
Modified: trunk/Lib/cluster/tests/test_vq.py
===================================================================
--- trunk/Lib/cluster/tests/test_vq.py 2007-07-02 15:25:56 UTC (rev 3142)
+++ trunk/Lib/cluster/tests/test_vq.py 2007-07-03 11:29:01 UTC (rev 3143)
@@ -1,7 +1,7 @@
#! /usr/bin/env python
# David Cournapeau
-# Last Change: Tue Jun 19 10:00 PM 2007 J
+# Last Change: Tue Jul 03 08:00 PM 2007 J
# For now, just copy the tests from sandbox.pyem, so we can check that
# kmeans works OK for trivial examples.
@@ -12,7 +12,7 @@
import numpy as N
set_package_path()
-from cluster.vq import kmeans, kmeans2, py_vq, py_vq2, _py_vq_1d, vq
+from cluster.vq import kmeans, kmeans2, py_vq, py_vq2, _py_vq_1d, vq, ClusterError
try:
from cluster import _vq
TESTC=True
@@ -21,10 +21,10 @@
TESTC=False
restore_path()
+import os.path
#Optional:
set_local_path()
# import modules that are located in the same directory as this file.
-import os.path
DATAFILE1 = os.path.join(sys.path[0], "data.txt")
restore_path()
@@ -106,6 +106,12 @@
[-2.31149087,-0.05160469]])
res = kmeans(data, initk)
+ res = kmeans2(data, initk, missing = 'warn')
+ try :
+ res = kmeans2(data, initk, missing = 'raise')
+ raise AssertionError("Exception not raised ! Should not happen")
+ except ClusterError, e:
+ print "exception raised as expected: " + str(e)
def check_kmeans2_simple(self, level=1):
"""Testing simple call to kmeans2 and its results."""
Modified: trunk/Lib/cluster/vq.py
===================================================================
--- trunk/Lib/cluster/vq.py 2007-07-02 15:25:56 UTC (rev 3142)
+++ trunk/Lib/cluster/vq.py 2007-07-03 11:29:01 UTC (rev 3143)
@@ -35,6 +35,9 @@
std, mean
import numpy as N
+class ClusterError(Exception):
+ pass
+
def whiten(obs):
""" Normalize a group of observations on a per feature basis.
@@ -188,7 +191,8 @@
else:
(n, d) = shape(obs)
- # code books and observations should have same number of features and same shape
+ # code books and observations should have same number of features and same
+ # shape
if not N.ndim(obs) == N.ndim(code_book):
raise ValueError("Observation and code_book should have the same rank")
elif not d == code_book.shape[1]:
@@ -228,7 +232,7 @@
nc = code_book.size
dist = N.zeros((n, nc))
for i in range(nc):
- dist[:,i] = N.sum(obs - code_book[i])
+ dist[:, i] = N.sum(obs - code_book[i])
print dist
code = argmin(dist)
min_dist = dist[code]
@@ -270,7 +274,7 @@
code book(%d) and obs(%d) should have the same
number of features (eg columns)""" % (code_book.shape[1], d))
- diff = obs[newaxis,:,:] - code_book[:,newaxis,:]
+ diff = obs[newaxis, :, :] - code_book[:,newaxis,:]
dist = sqrt(N.sum(diff * diff, -1))
code = argmin(dist, 0)
min_dist = minimum.reduce(dist, 0) #the next line I think is equivalent
@@ -314,7 +318,7 @@
"""
code_book = array(guess, copy = True)
- Nc = code_book.shape[0]
+ nc = code_book.shape[0]
avg_dist = []
diff = thresh+1.
while diff > thresh:
@@ -324,7 +328,7 @@
#recalc code_book as centroids of associated obs
if(diff > thresh):
has_members = []
- for i in arange(Nc):
+ for i in arange(nc):
cell_members = compress(equal(obs_code, i), obs, 0)
if cell_members.shape[0] > 0:
code_book[i] = mean(cell_members, 0)
@@ -468,7 +472,20 @@
_valid_init_meth = {'random': _krandinit, 'points': _kpoints}
-def kmeans2(data, k, iter = 10, thresh = 1e-5, minit='random'):
+def _missing_warn():
+ """Print a warning when called."""
+ warnings.warn("One of the clusters is empty. "
+ "Re-run kmean with a different initialization.")
+
+def _missing_raise():
+ """raise a ClusterError when called."""
+ raise ClusterError, "One of the clusters is empty. "\
+ "Re-run kmean with a different initialization."
+
+_valid_miss_meth = {'warn': _missing_warn, 'raise': _missing_raise}
+
+def kmeans2(data, k, iter = 10, thresh = 1e-5, minit = 'random',
+ missing = 'warn'):
"""Classify a set of points into k clusters using kmean algorithm.
The algorithm works by minimizing the euclidian distance between data points
@@ -510,6 +527,8 @@
cluster[label[i]].
"""
+ if missing not in _valid_miss_meth.keys():
+ raise ValueError("Unkown missing method: %s" % str(missing))
# If data is rank 1, then we have 1 dimension problem.
nd = N.ndim(data)
if nd == 1:
@@ -544,9 +563,9 @@
clusters = init(data, k)
assert not iter == 0
- return _kmeans2(data, clusters, iter, nc)
+ return _kmeans2(data, clusters, iter, nc, _valid_miss_meth[missing])
-def _kmeans2(data, code, niter, nc):
+def _kmeans2(data, code, niter, nc, missing):
""" "raw" version of kmeans2. Do not use directly.
Run kmeans with a given initial codebook. """
@@ -560,8 +579,7 @@
if mbs[0].size > 0:
code[j] = N.mean(data[mbs], axis=0)
else:
- warnings.warn("One of the clusters are empty. " \
- "Re-run kmean with a different initialization.")
+ missing()
return code, label
More information about the Scipy-svn mailing list