[Scipy-svn] r3156 - trunk/Lib/sparse
scipy-svn@scip...
scipy-svn@scip...
Mon Jul 9 20:53:12 CDT 2007
Author: wnbell
Date: 2007-07-09 20:53:10 -0500 (Mon, 09 Jul 2007)
New Revision: 3156
Modified:
trunk/Lib/sparse/sparse.py
Log:
fixed rmatvec bug, resolves ticket #462
Modified: trunk/Lib/sparse/sparse.py
===================================================================
--- trunk/Lib/sparse/sparse.py 2007-07-09 16:32:28 UTC (rev 3155)
+++ trunk/Lib/sparse/sparse.py 2007-07-10 01:53:10 UTC (rev 3156)
@@ -588,41 +588,24 @@
oth = numpy.ravel(other)
y = fn(self.shape[0], self.shape[1], \
self.indptr, self.indices, self.data, oth)
- if isinstance(other, matrix):
+ if isinstance(other, matrix):
y = asmatrix(y)
+ if other.ndim == 2 and other.shape[1] == 1:
# If 'other' was an (nx1) column vector, transpose the result
# to obtain an (mx1) column vector.
- if other.ndim == 2 and other.shape[1] == 1:
- y = y.T
+ y = y.T
return y
+
elif isspmatrix(other):
raise TypeError, "use matmat() for sparse * sparse"
else:
raise TypeError, "need a dense vector"
- def _rmatvec(self, other, shape0, shape1, fn, conjugate=True):
- if isdense(other):
- # This check is too harsh -- it prevents a column vector from
- # being created on-the-fly like dense matrix objects can.
- # if len(other) != self.shape[0]:
- # raise ValueError, "dimension mismatch"
- if conjugate:
- cd = conj(self.data)
- else:
- cd = self.data
- oth = numpy.ravel(other)
- y = fn(shape0, shape1, self.indptr, self.indices, cd, oth)
- if isinstance(other, matrix):
- y = asmatrix(y)
- # In the (unlikely) event that this matrix is 1x1 and 'other'
- # was an (mx1) column vector, transpose the result.
- if other.ndim == 2 and other.shape[1] == 1:
- y = y.T
- return y
- elif isspmatrix(other):
- raise TypeError, "use matmat() for sparse * sparse"
+ def rmatvec(self, other, conjugate=True):
+ if conjugate:
+ return self.transpose().conj() * other
else:
- raise TypeError, "need a dense vector"
+ return self.transpose() * other
def getdata(self, ind):
return self.data[ind]
@@ -942,13 +925,9 @@
def matvec(self, other):
return _cs_matrix._matvec(self, other, cscmux)
- def rmatvec(self, other, conjugate=True):
- return _cs_matrix._rmatvec(self, other, self.shape[1], self.shape[0], cscmux, conjugate=conjugate)
-
def matmat(self, other):
return _cs_matrix._matmat(self, other, cscmucsc)
-
def __getitem__(self, key):
if isinstance(key, tuple):
row = key[0]
@@ -1307,9 +1286,6 @@
def matvec(self, other):
return _cs_matrix._matvec(self, other, csrmux)
-
- def rmatvec(self, other, conjugate=True):
- return _cs_matrix._rmatvec(self, other, self.shape[0], self.shape[1], csrmux, conjugate=conjugate)
def matmat(self, other):
return _cs_matrix._matmat(self, other, csrmucsr)
More information about the Scipy-svn mailing list