[Scipy-svn] r3207 - in trunk/Lib/linalg: . tests
scipy-svn@scip...
scipy-svn@scip...
Mon Jul 30 09:58:04 CDT 2007
Author: cdavid
Date: 2007-07-30 09:57:57 -0500 (Mon, 30 Jul 2007)
New Revision: 3207
Modified:
trunk/Lib/linalg/iterative.py
trunk/Lib/linalg/tests/test_iterative.py
Log:
Copy initial values in iterative solvers to avoid overwriting input arguments. See ticket #470
Modified: trunk/Lib/linalg/iterative.py
===================================================================
--- trunk/Lib/linalg/iterative.py 2007-07-29 00:54:10 UTC (rev 3206)
+++ trunk/Lib/linalg/iterative.py 2007-07-30 14:57:57 UTC (rev 3207)
@@ -12,6 +12,7 @@
__all__ = ['bicg','bicgstab','cg','cgs','gmres','qmr']
from scipy.linalg import _iterative
import numpy as sb
+import copy
try:
False, True
@@ -148,9 +149,10 @@
if maxiter is None:
maxiter = n*10
- x = x0
- if x is None:
+ if x0 is None:
x = sb.zeros(n)
+ else:
+ x = copy.copy(x0)
if xtype is None:
try:
@@ -266,10 +268,12 @@
if maxiter is None:
maxiter = n*10
- x = x0
- if x is None:
+ if x0 is None:
x = sb.zeros(n)
+ else:
+ x = copy.copy(x0)
+
if xtype is None:
try:
atyp = A.dtype.char
@@ -376,10 +380,12 @@
if maxiter is None:
maxiter = n*10
- x = x0
- if x is None:
+ if x0 is None:
x = sb.zeros(n)
+ else:
+ x = copy.copy(x0)
+
if xtype is None:
try:
atyp = A.dtype.char
@@ -486,9 +492,10 @@
if maxiter is None:
maxiter = n*10
- x = x0
- if x is None:
+ if x0 is None:
x = sb.zeros(n)
+ else:
+ x = copy.copy(x0)
if xtype is None:
try:
@@ -598,9 +605,10 @@
if maxiter is None:
maxiter = n*10
- x = x0
- if x is None:
+ if x0 is None:
x = sb.zeros(n)
+ else:
+ x = copy.copy(x0)
if xtype is None:
try:
@@ -710,9 +718,10 @@
if maxiter is None:
maxiter = n*10
- x = x0
- if x is None:
+ if x0 is None:
x = sb.zeros(n)
+ else:
+ x = copy.copy(x0)
if xtype is None:
try:
Modified: trunk/Lib/linalg/tests/test_iterative.py
===================================================================
--- trunk/Lib/linalg/tests/test_iterative.py 2007-07-29 00:54:10 UTC (rev 3206)
+++ trunk/Lib/linalg/tests/test_iterative.py 2007-07-30 14:57:57 UTC (rev 3207)
@@ -45,27 +45,39 @@
b = self.b
def check_cg(self):
+ bx0 = self.x0.copy()
x, info = cg(self.A, self.b, self.x0, callback=callback)
+ assert_array_equal(bx0, self.x0)
assert norm(dot(self.A, x) - self.b) < 5*self.tol
def check_bicg(self):
+ bx0 = self.x0.copy()
x, info = bicg(self.A, self.b, self.x0, callback=callback)
+ assert_array_equal(bx0, self.x0)
assert norm(dot(self.A, x) - self.b) < 5*self.tol
def check_cgs(self):
+ bx0 = self.x0.copy()
x, info = cgs(self.A, self.b, self.x0, callback=callback)
+ assert_array_equal(bx0, self.x0)
assert norm(dot(self.A, x) - self.b) < 5*self.tol
def check_bicgstab(self):
+ bx0 = self.x0.copy()
x, info = bicgstab(self.A, self.b, self.x0, callback=callback)
+ assert_array_equal(bx0, self.x0)
assert norm(dot(self.A, x) - self.b) < 5*self.tol
def check_gmres(self):
+ bx0 = self.x0.copy()
x, info = gmres(self.A, self.b, self.x0, callback=callback)
+ assert_array_equal(bx0, self.x0)
assert norm(dot(self.A, x) - self.b) < 5*self.tol
def check_qmr(self):
+ bx0 = self.x0.copy()
x, info = qmr(self.A, self.b, self.x0, callback=callback)
+ assert_array_equal(bx0, self.x0)
assert norm(dot(self.A, x) - self.b) < 5*self.tol
if __name__ == "__main__":
More information about the Scipy-svn mailing list