[Scipy-svn] r2690 - in trunk/Lib/sandbox: . rbf rbf/tests spline/tests
scipy-svn@scip...
scipy-svn@scip...
Thu Feb 8 06:05:31 CST 2007
Author: jtravs
Date: 2007-02-08 06:05:18 -0600 (Thu, 08 Feb 2007)
New Revision: 2690
Added:
trunk/Lib/sandbox/rbf/
trunk/Lib/sandbox/rbf/README.txt
trunk/Lib/sandbox/rbf/__init__.py
trunk/Lib/sandbox/rbf/info.py
trunk/Lib/sandbox/rbf/rbf.py
trunk/Lib/sandbox/rbf/setup.py
trunk/Lib/sandbox/rbf/tests/
trunk/Lib/sandbox/rbf/tests/example.py
trunk/Lib/sandbox/rbf/tests/test_rbf.py
Modified:
trunk/Lib/sandbox/setup.py
trunk/Lib/sandbox/spline/tests/test_fitpack.py
Log:
Added new rbf package to sandbox
Added: trunk/Lib/sandbox/rbf/README.txt
===================================================================
--- trunk/Lib/sandbox/rbf/README.txt 2007-02-07 14:32:55 UTC (rev 2689)
+++ trunk/Lib/sandbox/rbf/README.txt 2007-02-08 12:05:18 UTC (rev 2690)
@@ -0,0 +1,8 @@
+This package uses radial basis functions for n-dimensional
+smoothing/interpolation of scattered data.
+
+It is closely based on the MATLAB code by Alex Chirokov, found at:
+
+http://www.mathworks.com/matlabcentral/fileexchange/loadFile.do?objectId=10056&objectType=FILE
+
+John Travers
\ No newline at end of file
Added: trunk/Lib/sandbox/rbf/__init__.py
===================================================================
--- trunk/Lib/sandbox/rbf/__init__.py 2007-02-07 14:32:55 UTC (rev 2689)
+++ trunk/Lib/sandbox/rbf/__init__.py 2007-02-08 12:05:18 UTC (rev 2690)
@@ -0,0 +1,11 @@
+#
+# rbf - Radial Basis Functions
+#
+
+# Package initialisation for rbf.
+# NOTE(review): Python 2 code -- relies on implicit relative imports and on
+# filter() returning a list; would need changes to run on Python 3.
+
+# Use the docstring from info.py as the package docstring.
+from info import __doc__
+
+# Re-export the public names of the rbf.rbf module (e.g. Rbf).
+from rbf import *
+
+# Public API: every name defined so far that does not start with '_'.
+__all__ = filter(lambda s:not s.startswith('_'),dir())
+from numpy.testing import NumpyTest
+# Package-level test entry point built from numpy.testing.NumpyTest.
+test = NumpyTest().test
\ No newline at end of file
Added: trunk/Lib/sandbox/rbf/info.py
===================================================================
--- trunk/Lib/sandbox/rbf/info.py 2007-02-07 14:32:55 UTC (rev 2689)
+++ trunk/Lib/sandbox/rbf/info.py 2007-02-08 12:05:18 UTC (rev 2690)
@@ -0,0 +1,9 @@
+"""
+Radial Basis Functions
+======================
+
+rbf - Radial basis functions for interpolation/smoothing.
+
+"""
+
+# Package-loader flag. NOTE(review): presumably tells numpy's package
+# loader to postpone importing this subpackage -- confirm.
+postpone_import = 1
Added: trunk/Lib/sandbox/rbf/rbf.py
===================================================================
--- trunk/Lib/sandbox/rbf/rbf.py 2007-02-07 14:32:55 UTC (rev 2689)
+++ trunk/Lib/sandbox/rbf/rbf.py 2007-02-08 12:05:18 UTC (rev 2690)
@@ -0,0 +1,126 @@
+#!/usr/bin/env python
+"""
+rbf - Radial basis functions for interpolation/smoothing of scattered N-d data.
+
+Written by John Travers <jtravs@gmail.com>, February 2007
+Based closely on Matlab code by Alex Chirokov
+
+Permission to use, modify, and distribute this software is given under the
+terms of the SciPy (BSD style) license. See LICENSE.txt that came with
+this distribution for specifics.
+
+NO WARRANTY IS EXPRESSED OR IMPLIED. USE AT YOUR OWN RISK.
+
+"""
+
import numpy as np
import scipy as s
import scipy.linalg
+
class Rbf(object):
    """ A class for radial basis function approximation/interpolation of
    n-dimensional scattered data.

    The approximant is a weighted sum of radial basis functions centred on
    the nodes plus a linear polynomial term; the weights are obtained by
    solving one dense linear system when the object is constructed.
    """

    def __init__(self, x, y, function='multiquadrics', constant=None, smooth=0):
        """ Constructor for Rbf class.

        Inputs:
            x           (dim, n) array of coordinates for the nodes; a 1-D
                        array of length n is treated as dim == 1
            y           (n,) array of values at the nodes ((1, n) is also
                        accepted)
            function    the radial basis function
                        'linear', 'cubic' 'thinplate', 'multiquadrics'
                        or 'gaussian', default is 'multiquadrics'
            constant    adjustable constant for gaussian or multiquadrics
                        functions - defaults to approximate average distance
                        between nodes (which is a good start)
            smooth      values greater than zero increase the smoothness
                        of the approximation.
                        0 is for interpolation (default), the function will
                        always go through the nodal points in this case.

        Outputs: None

        Raises:
            ValueError   if x and y have different numbers of points, if y
                         is not a single row of values, or if `function` is
                         not one of the names listed above.
            scipy.linalg.LinAlgError  if the assembled system is singular.

        The caller's arrays are not reshaped; 2-D float versions (views or
        copies) are stored as self.x and self.y.
        """
        # Build 2-D float versions instead of assigning to .shape, which
        # used to reshape the caller's arrays in place.
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        if x.ndim == 1:
            x = x.reshape(1, -1)
        if y.ndim == 1:
            y = y.reshape(1, -1)
        nxdim, nx = x.shape
        nydim, ny = y.shape
        if nx != ny:
            raise ValueError('x and y should have the same number of points')
        if nydim != 1:
            raise ValueError('y should be a length n vector')
        self.x = x
        self.y = y
        self.function = function
        if constant is None and function in ('multiquadrics', 'gaussian'):
            # approx. average distance between the nodes: the dim-th root of
            # (bounding-box volume / number of nodes).  1.0 / nxdim avoids
            # integer division, which made the exponent 0 for dim > 1.
            extent = x.max(axis=1) - x.min(axis=1)
            constant = (np.prod(extent) / nx) ** (1.0 / nxdim)
        self.constant = constant
        self.smooth = smooth
        # The lambdas read self.constant when called, so they always use
        # the instance's current value.
        if function == 'linear':
            self.phi = lambda r: r
        elif function == 'cubic':
            self.phi = lambda r: r * r * r
        elif function == 'multiquadrics':
            self.phi = lambda r: np.sqrt(
                1.0 + r * r / (self.constant * self.constant))
        elif function == 'thinplate':
            self.phi = lambda r: r * r * np.log(r + 1)
        elif function == 'gaussian':
            # Previously referenced the undefined self.rbfconst.
            self.phi = lambda r: np.exp(
                -0.5 * r * r / (self.constant * self.constant))
        else:
            raise ValueError('unknown function')
        A = self._rbf_assemble()
        # Right-hand side: nodal values followed by zeros for the
        # nxdim + 1 polynomial side conditions.
        b = np.r_[y.T, np.zeros((nxdim + 1, 1), float)]
        self.coeff = s.linalg.solve(A, b)

    def __call__(self, xi):
        """ Evaluate the radial basis function approximation at points xi.

        Inputs:
            xi  (dim, n) array of coordinates for the points to evaluate at;
                a 1-D array is treated as dim == 1.  It is not modified.

        Outputs:
            y   (n,) array of values at the points xi

        Raises:
            ValueError  if xi's dimension differs from that of the nodes.
        """
        xi = np.asarray(xi, dtype=float)
        if xi.ndim == 1:
            xi = xi.reshape(1, -1)
        nxidim, nxi = xi.shape
        nxdim, nx = self.x.shape
        if nxdim != nxidim:
            raise ValueError('xi should have the same number of rows as an'
                             ' array used to create RBF interpolation')
        # Coefficient layout (from the assembled system): nx RBF weights,
        # then the constant term, then one linear coefficient per dimension.
        coeff = self.coeff[:, 0]
        weights = coeff[:nx]
        offset = coeff[nx]
        slope = coeff[nx + 1:]
        f = np.zeros(nxi, float)
        # One point at a time keeps memory at O(dim * nx) per evaluation.
        for i in range(nxi):
            d = self.x - xi[:, i:i + 1]
            r = np.sqrt(np.sum(d * d, axis=0))
            f[i] = offset + np.dot(weights, self.phi(r)) + np.dot(slope, xi[:, i])
        return f

    def _rbf_assemble(self):
        """ Build the (nx + dim + 1) square interpolation matrix.

        The top-left block is phi(|x_i - x_j|) with `smooth` subtracted on
        the diagonal; it is bordered by P = [1, x^T] and P^T, with a zero
        block in the bottom-right corner.
        """
        nxdim, nx = self.x.shape
        diff = self.x[:, :, np.newaxis] - self.x[:, np.newaxis, :]
        A = self.phi(np.sqrt(np.sum(diff * diff, axis=0)))
        A = A - self.smooth * np.eye(nx)
        P = np.c_[np.ones((nx, 1), float), self.x.T]
        return np.r_[np.c_[A, P],
                     np.c_[P.T, np.zeros((nxdim + 1, nxdim + 1), float)]]
Added: trunk/Lib/sandbox/rbf/setup.py
===================================================================
--- trunk/Lib/sandbox/rbf/setup.py 2007-02-07 14:32:55 UTC (rev 2689)
+++ trunk/Lib/sandbox/rbf/setup.py 2007-02-08 12:05:18 UTC (rev 2690)
@@ -0,0 +1,16 @@
+#!/usr/bin/env python
+
+import os
+
+def configuration(parent_package='',top_path=None):
+ from numpy.distutils.misc_util import Configuration
+
+ config = Configuration('rbf', parent_package, top_path)
+
+ config.add_data_dir('tests')
+
+ return config
+
+if __name__ == '__main__':
+ from numpy.distutils.core import setup
+ setup(**configuration(top_path='').todict())
\ No newline at end of file
Added: trunk/Lib/sandbox/rbf/tests/example.py
===================================================================
--- trunk/Lib/sandbox/rbf/tests/example.py 2007-02-07 14:32:55 UTC (rev 2689)
+++ trunk/Lib/sandbox/rbf/tests/example.py 2007-02-08 12:05:18 UTC (rev 2690)
@@ -0,0 +1,52 @@
+# Demonstration script: compares spline and RBF interpolation of sin(x) in
+# 1-D (saved to rbf1dtest.png) and fits an RBF surface to 2-D scattered data.
+# NOTE(review): relies on old SciPy re-exporting NumPy functions (s.linspace,
+# s.rand, s.c_, s.meshgrid, ...) and on enthought.tvtk's mlab module; neither
+# is available in current releases -- confirm before running.
+import scipy as s
+import scipy.interpolate
+
+from scipy.sandbox.rbf import Rbf
+
+import matplotlib
+# Select the non-interactive Agg backend; this must precede importing pylab.
+matplotlib.use('Agg')
+import pylab as p
+
+# 1d tests - setup data
+x = s.linspace(0,10,9)
+y = s.sin(x)
+xi = s.linspace(0,10,101)
+
+# use interpolate methods
+ius = s.interpolate.InterpolatedUnivariateSpline(x,y)
+yi = ius(xi)
+p.subplot(2,1,1)
+p.plot(x,y,'o',xi,yi, xi, s.sin(xi),'r')
+p.title('Interpolation using current scipy fitpack2')
+
+# use RBF method (default function is 'multiquadrics')
+rbf = Rbf(x, y)
+fi = rbf(xi)
+p.subplot(2,1,2)
+p.plot(x,y,'bo',xi.flatten(),fi.flatten(),'g',xi.flatten(),
+ s.sin(xi.flatten()),'r')
+p.title('RBF interpolation - multiquadrics')
+p.savefig('rbf1dtest.png')
+p.close()
+
+# 2-d tests - setup scattered data
+# 50 random nodes in [-2, 2) x [-2, 2), stored as (50, 1) column vectors.
+x = s.rand(50,1)*4-2
+y = s.rand(50,1)*4-2
+z = x*s.exp(-x**2-y**2)
+ti = s.linspace(-2.0,2.0,81)
+(XI,YI) = s.meshgrid(ti,ti)
+
+# use RBF
+# Nodes are passed as a (2, 50) array and values as a (1, 50) row.
+rbf = Rbf(s.c_[x.flatten(),y.flatten()].T,z.T,constant=2)
+# Evaluate on the flattened 81 x 81 grid, then restore the grid shape.
+ZI = rbf(s.c_[XI.flatten(), YI.flatten()].T)
+ZI.shape = XI.shape
+
+# plot the result
+# 3-D view of the fitted surface, with the nodes drawn as small spheres.
+from enthought.tvtk.tools import mlab
+f=mlab.figure(browser=False)
+su=mlab.Surf(XI,YI,ZI,ZI,scalar_visibility=True)
+f.add(su)
+su.lut_type='blue-red'
+f.objects[0].axis.z_label='value'
+pp = mlab.Spheres(s.c_[x.flatten(), y.flatten(), z.flatten()],radius=0.03)
+f.add(pp)
\ No newline at end of file
Property changes on: trunk/Lib/sandbox/rbf/tests/example.py
___________________________________________________________________
Name: svn:executable
+ *
Added: trunk/Lib/sandbox/rbf/tests/test_rbf.py
===================================================================
--- trunk/Lib/sandbox/rbf/tests/test_rbf.py 2007-02-07 14:32:55 UTC (rev 2689)
+++ trunk/Lib/sandbox/rbf/tests/test_rbf.py 2007-02-08 12:05:18 UTC (rev 2690)
@@ -0,0 +1,31 @@
+#!/usr/bin/env python
+# Created by John Travers, February 2007
+""" Test functions for rbf module """
+
+from numpy.testing import *
+import numpy as n
+
+set_package_path()
+from rbf.rbf import Rbf
+restore_path()
+
class test_Rbf1D(NumpyTestCase):
    """1-D multiquadric interpolation must reproduce the nodal values."""

    def check_multiquadrics(self):
        nodes = n.linspace(0, 10, 9)
        values = n.sin(nodes)
        interpolant = Rbf(nodes, values)
        assert_array_almost_equal(values.flatten(), interpolant(nodes))
+
class test_Rbf2D(NumpyTestCase):
    """2-D multiquadric interpolation must reproduce the nodal values."""

    def check_multiquadrics(self):
        # A private, fixed-seed generator makes the scattered nodes (and so
        # the conditioning of the interpolation matrix) identical on every
        # run, instead of depending on the global random state.
        rng = n.random.RandomState(2690)
        x = rng.rand(50, 1) * 4 - 2
        y = rng.rand(50, 1) * 4 - 2
        z = x * n.exp(-x**2 - y**2)
        # Nodes as a (2, 50) array, values as a (1, 50) row.
        rbf = Rbf(n.c_[x.flatten(), y.flatten()].T, z.T, constant=2)
        zi = rbf(n.c_[x.flatten(), y.flatten()].T)
        zi.shape = x.shape
        assert_array_almost_equal(z, zi)


if __name__ == "__main__":
    NumpyTest().run()
\ No newline at end of file
Modified: trunk/Lib/sandbox/setup.py
===================================================================
--- trunk/Lib/sandbox/setup.py 2007-02-07 14:32:55 UTC (rev 2689)
+++ trunk/Lib/sandbox/setup.py 2007-02-08 12:05:18 UTC (rev 2690)
@@ -81,6 +81,9 @@
# New spline package (based on scipy.interpolate)
#config.add_subpackage('spline')
+
+ # Radial basis functions package
+ #config.add_subpackage('rbf')
return config
Modified: trunk/Lib/sandbox/spline/tests/test_fitpack.py
===================================================================
--- trunk/Lib/sandbox/spline/tests/test_fitpack.py 2007-02-07 14:32:55 UTC (rev 2689)
+++ trunk/Lib/sandbox/spline/tests/test_fitpack.py 2007-02-08 12:05:18 UTC (rev 2690)
@@ -142,3 +142,6 @@
decimal=1)
assert_almost_equal(0.0,
around(abs(splev(uv[0],tck)-f(uv[0])),2),decimal=1)
+
+if __name__ == "__main__":
+ NumpyTest().run()
\ No newline at end of file
More information about the Scipy-svn
mailing list