[Scipy-svn] r3127 - in trunk/Lib/sandbox/pyem: . tests
scipy-svn@scip...
scipy-svn@scip...
Sun Jul 1 04:52:13 CDT 2007
Author: cdavid
Date: 2007-07-01 04:52:06 -0500 (Sun, 01 Jul 2007)
New Revision: 3127
Modified:
trunk/Lib/sandbox/pyem/gmm_em.py
trunk/Lib/sandbox/pyem/tests/test_gmm_em.py
Log:
Add support for EM in log domain + tests
Modified: trunk/Lib/sandbox/pyem/gmm_em.py
===================================================================
--- trunk/Lib/sandbox/pyem/gmm_em.py 2007-07-01 09:32:00 UTC (rev 3126)
+++ trunk/Lib/sandbox/pyem/gmm_em.py 2007-07-01 09:52:06 UTC (rev 3127)
@@ -1,5 +1,5 @@
# /usr/bin/python
-# Last Change: Sun Jul 01 05:00 PM 2007 J
+# Last Change: Sun Jul 01 06:00 PM 2007 J
"""Module implementing GMM, a class to estimate Gaussian mixture models using
EM, and EM, a class which uses GMM instances to estimate model parameters using
@@ -331,7 +331,7 @@
def __init__(self):
pass
- def train(self, data, model, maxiter = 10, thresh = 1e-5):
+ def train(self, data, model, maxiter = 10, thresh = 1e-5, log = False):
"""Train a model using EM.
Train a model using data, and stops when the likelihood increase
@@ -366,7 +366,10 @@
model.init(data)
# Actual training
- like = self._train_simple_em(data, model, maxiter, thresh)
+ if log:
+ like = self._train_simple_em_log(data, model, maxiter, thresh)
+ else:
+ like = self._train_simple_em(data, model, maxiter, thresh)
return like
def _train_simple_em(self, data, model, maxiter, thresh):
@@ -385,6 +388,21 @@
if has_em_converged(like[i], like[i-1], thresh):
return like[0:i]
+ def _train_simple_em_log(self, data, model, maxiter, thresh):
+ # Likelihood is kept
+ like = N.zeros(maxiter)
+
+ # Em computation, with computation of the likelihood
+ g, tgd = model.compute_log_responsabilities(data)
+ like[0] = N.sum(densities.logsumexp(tgd), axis = 0)
+ model.update_em(data, N.exp(g))
+ for i in range(1, maxiter):
+ g, tgd = model.compute_log_responsabilities(data)
+ like[i] = N.sum(densities.logsumexp(tgd), axis = 0)
+ model.update_em(data, N.exp(g))
+ if has_em_converged(like[i], like[i-1], thresh):
+ return like[0:i]
+
class RegularizedEM:
# TODO: separate regularizer from EM class ?
def __init__(self, pcnt = _PRIOR_COUNT, pval = _COV_PRIOR):
Modified: trunk/Lib/sandbox/pyem/tests/test_gmm_em.py
===================================================================
--- trunk/Lib/sandbox/pyem/tests/test_gmm_em.py 2007-07-01 09:32:00 UTC (rev 3126)
+++ trunk/Lib/sandbox/pyem/tests/test_gmm_em.py 2007-07-01 09:52:06 UTC (rev 3127)
@@ -1,5 +1,5 @@
#! /usr/bin/env python
-# Last Change: Wed Jun 13 07:00 PM 2007 J
+# Last Change: Sun Jul 01 06:00 PM 2007 J
# For now, just test that all mode/dim execute correctly
@@ -110,65 +110,55 @@
class test_datasets(EmTest):
"""This class tests whether the EM algorithms works using pre-computed
datasets."""
- def test_1d_full(self, level = 1):
- d = 1
- k = 4
- mode = 'full'
- # Data are exactly the same as in diagonal mode, just test that
- # calling full mode works even in 1d, even if it is kind of stupid to
- # do so
- dic = load_dataset('diag_1d_4k.mat')
+ def _test(self, dataset, log):
+ dic = load_dataset(dataset)
gm = GM.fromvalues(dic['w0'], dic['mu0'], dic['va0'])
gmm = GMM(gm, 'test')
- EM().train(dic['data'], gmm)
+ EM().train(dic['data'], gmm, log = log)
assert_array_almost_equal(gmm.gm.w, dic['w'], DEF_DEC)
assert_array_almost_equal(gmm.gm.mu, dic['mu'], DEF_DEC)
assert_array_almost_equal(gmm.gm.va, dic['va'], DEF_DEC)
- def test_1d_diag(self, level = 1):
+ def test_1d_full(self, level = 1):
d = 1
k = 4
- mode = 'diag'
- dic = load_dataset('diag_1d_4k.mat')
+ mode = 'full'
+ # Data are exactly the same as in diagonal mode, just test that
+ # calling full mode works even in 1d, even if it is kind of stupid to
+ # do so
+ filename = 'diag_1d_4k.mat'
+ self._test(filename, log = False)
- gm = GM.fromvalues(dic['w0'], dic['mu0'], dic['va0'])
- gmm = GMM(gm, 'test')
- EM().train(dic['data'], gmm)
-
- assert_array_almost_equal(gmm.gm.w, dic['w'], DEF_DEC)
- assert_array_almost_equal(gmm.gm.mu, dic['mu'], DEF_DEC)
- assert_array_almost_equal(gmm.gm.va, dic['va'], DEF_DEC)
-
def test_2d_full(self, level = 1):
d = 2
k = 3
mode = 'full'
- dic = load_dataset('full_2d_3k.mat')
+ filename = 'full_2d_3k.mat'
+ self._test(filename, log = False)
- gm = GM.fromvalues(dic['w0'], dic['mu0'], dic['va0'])
- gmm = GMM(gm, 'test')
- EM().train(dic['data'], gmm)
+ def test_2d_full_log(self, level = 1):
+ d = 2
+ k = 3
+ mode = 'full'
+ filename = 'full_2d_3k.mat'
+ self._test(filename, log = True)
- assert_array_almost_equal(gmm.gm.w, dic['w'], DEF_DEC)
- assert_array_almost_equal(gmm.gm.mu, dic['mu'], DEF_DEC)
- assert_array_almost_equal(gmm.gm.va, dic['va'], DEF_DEC)
-
def test_2d_diag(self, level = 1):
d = 2
k = 3
mode = 'diag'
- dic = load_dataset('diag_2d_3k.mat')
+ filename = 'diag_2d_3k.mat'
+ self._test(filename, log = False)
- gm = GM.fromvalues(dic['w0'], dic['mu0'], dic['va0'])
- gmm = GMM(gm, 'test')
- EM().train(dic['data'], gmm)
+ def test_2d_diag_log(self, level = 1):
+ d = 2
+ k = 3
+ mode = 'diag'
+ filename = 'diag_2d_3k.mat'
+ self._test(filename, log = True)
- assert_array_almost_equal(gmm.gm.w, dic['w'], DEF_DEC)
- assert_array_almost_equal(gmm.gm.mu, dic['mu'], DEF_DEC)
- assert_array_almost_equal(gmm.gm.va, dic['va'], DEF_DEC)
-
class test_log_domain(EmTest):
"""This class tests whether the GMM works in log domain."""
def _test_common(self, d, k, mode):
More information about the Scipy-svn
mailing list