Changeset 1505

Show
Ignore:
Timestamp:
09/28/08 13:28:45 (2 months ago)
Author:
pierregm
Message:

tstables
* fixed tabulate to allow the processing of standard ndarrays
* add support for readCoordinates
* add support for append

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • trunk/timeseries/scikits/timeseries/lib/tests/test_tstables.py

    r1396 r1505  
    99import numpy.ma as ma 
    1010import numpy.ma.mrecords as mr 
    11 from numpy.ma import masked_array, masked 
     11from numpy.ma import MaskedArray, masked_array, masked 
    1212 
    1313import scikits.timeseries as ts 
     14from scikits.timeseries import TimeSeries 
    1415 
    1516from numpy.testing import * 
     
    339340        assert_equal(test._hardmask, data._hardmask) 
    340341 
     342 
     343class TestTableRead(TestCase): 
     344    # 
     345    def __init__(self, *args, **kwds): 
     346        TestCase.__init__(self, *args, **kwds) 
     347        series = ts.time_series(zip(np.random.rand(10), 
     348                                    np.arange(10)), 
     349                                start_date=ts.now('M'), 
     350                                dtype=[('a',float),('b',int)]) 
     351        series.mask[0] = (0,1) 
     352        series.mask[-1] = (1,0) 
     353        self.tseries = series 
     354        self.marray = series._series 
     355        self.file = tempfile.mktemp(".hdf5") 
     356        self.h5file = tables.openFile(self.file,'a') 
     357        self.populate() 
     358    # 
     359    def tearDown(self): 
     360        if self.h5file.isopen: 
     361            self.h5file.close() 
     362        os.remove(self.file) 
     363    # 
     364    def populate(self): 
     365        h5file = self.h5file 
     366        table = h5file.createMaskedTable('/', 'marray', self.marray, "") 
     367        h5file.flush() 
     368        table = h5file.createTimeSeriesTable('/', 'tseries', self.tseries, "") 
     369        h5file.flush() 
     370    # 
     371    def test_tseries_read(self): 
     372        "Test reading specific elements of a TimeSeriesTable" 
     373        table = self.h5file.root.tseries 
     374        series = self.tseries 
     375        # 
     376        test = table.read() 
     377        assert(isinstance(test, TimeSeries)) 
     378        assert_equal_records(test, series) 
     379        # 
     380        test = table.read(field='a') 
     381        assert(isinstance(test, TimeSeries)) 
     382        assert_equal(test, series['a']) 
     383        # 
     384        test = table.read(step=2) 
     385        assert(isinstance(test, TimeSeries)) 
     386        assert_equal(test, series[::2]) 
     387        # 
     388        test = table.readCoordinates([1,2,3]) 
     389        assert(isinstance(test, TimeSeries)) 
     390        assert_equal(test, series[[1,2,3]]) 
     391        # 
     392        test = table.readCoordinates([1,2,3], field='a') 
     393        assert(isinstance(test, TimeSeries)) 
     394        assert_equal(test, series['a'][[1,2,3]]) 
     395    # 
     396    def test_marray_read(self): 
     397        "Test reading specific elements of a MaskedTable" 
     398        table = self.h5file.root.marray 
     399        data = self.marray 
     400        # 
     401        test = table.read() 
     402        assert(isinstance(test, MaskedArray)) 
     403        assert_equal_records(test, data) 
     404        # 
     405        test = table.read(field='a') 
     406        assert(isinstance(test, MaskedArray)) 
     407        assert_equal(test, data['a']) 
     408        # 
     409        test = table.read(step=2) 
     410        assert(isinstance(test, MaskedArray)) 
     411        assert_equal(test, data[::2]) 
     412        # 
     413        test = table.readCoordinates([1,2,3]) 
     414        assert(isinstance(test, MaskedArray)) 
     415        assert_equal(test, data[[1,2,3]]) 
     416        # 
     417        test = table.readCoordinates([1,2,3], field='a') 
     418        assert(isinstance(test, MaskedArray)) 
     419        assert_equal(test, data['a'][[1,2,3]]) 
     420 
    341421############################################################################### 
    342422#------------------------------------------------------------------------------ 
  • trunk/timeseries/scikits/timeseries/lib/tstables.py

    r1478 r1505  
    144144 
    145145 
     146_doc_parameters = dict( 
     147mareturn=""" 
     148    Depending on the value of the ``field`` parameter, the method returns either 
     149    * a ndarray, if ``field=='_data'`` or if ``field=='_mask'``; 
     150    * a :class:`~numpy.ma.MaskedArray`, if ``field`` is None or a valid field. 
     151""", 
     152tsreturn=""" 
     153    Depending on the value of the ``field`` parameter, the method returns: 
     154     
     155    * a :class:`~scikits.timeseries.TimeSeries`, if ``field`` is None or a valid 
     156      field; 
     157    * a :class:`~scikits.timeseries.DateArray`, if ``field=='_dates'``; 
     158    * a ndarray, if ``field=='_data'`` or if ``field=='_mask'``; 
     159    * a :class:`~numpy.ma.MaskedArray`, if ``field=='_series'``. 
     160""", 
     161readinput=""" 
     162    start : {None, int}, optional 
     163        Index of the first record to read. 
     164        If None, records will be read starting from the very first one. 
     165    stop : {None, int}, optional 
     166        Index of the last record to read. 
     167        If None, records will be read until the very last one. 
     168    step : {None, int}, optional 
     169        Increment between succesive records to read. 
     170        If None, all the records between ``start`` and ``stop`` will be read. 
     171    field : {None, str}, optional 
     172        Name of the field to read. 
     173        If None, all the fields from each record are read. 
     174""", 
     175readcoordinateinput=""" 
     176    coords : sequence 
     177        A sequence of integers, corresponding to the indices of the rows to 
     178        retrieve 
     179    field : {None, str}, optional 
     180        Name of the field to read. 
     181        If None, all the fields from each record are read. 
     182""", 
     183) 
     184 
    146185 
    147186 
     
    160199        else: 
    161200            pseudodtype = [('_data',basedtype),('_mask',bool)] 
    162         pseudo = itertools.izip(a.filled(), ma.getmaskarray(a)) 
     201        pseudo = itertools.izip(ma.filled(a), ma.getmaskarray(a)) 
    163202    else: 
    164203        pseudodtype = [(fname,[('_data',ftype), ('_mask',bool)]) 
    165204                       for (fname,ftype) in basedtype.descr] 
    166205        fields = [a[f] for f in basenames] 
    167         pseudo = itertools.izip(*[zip(f.filled().flat,ma.getmaskarray(f).flat) 
     206        pseudo = itertools.izip(*[zip(ma.filled(f).flat,ma.getmaskarray(f).flat) 
    168207                                 for f in fields]) 
    169208    return np.fromiter(pseudo, dtype=pseudodtype) 
     
    223262        A ndarray with flexible dtype. 
    224263    """ 
     264    a = np.asanyarray(a) 
    225265    if isinstance(a, TimeSeries): 
    226266        return _tabulate_time_series(a) 
     
    279319 
    280320 
    281     def read(self, start=None, stop=None, step=None, field=None): 
    282         """ 
    283     Reads the current :class:`MaskedTable`. 
    284      
    285     Returns 
    286     ------- 
    287      
    288     Depending on the value of the ``field`` parameter, the method returns either 
    289     * a ndarray, if ``field=='_data'`` or if ``field=='_mask'``; 
    290     * a :class:`~numpy.ma.MaskedArray`, if ``field`` is None or a valid field. 
    291      
    292     Parameters 
    293     ---------- 
    294     start : {None, int}, optional 
    295         Index of the first record to read. 
    296         If None, records will be read starting from the very first one. 
    297     stop : {None, int}, optional 
    298         Index of the last record to read. 
    299         If None, records will be read until the very last one. 
    300     step : {None, int}, optional 
    301         Increment between succesive records to read. 
    302         If None, all the records between ``start`` and ``stop`` will be read. 
    303     field : {None, str}, optional 
    304         Name of the field to read. 
    305         If None, all the fields from each record are read. 
    306         The argument should be one of the field of the series, or one of the 
    307         following: ``'_data','_mask'``. 
    308         """ 
    309         data = Table.read(self, start=start, stop=stop, step=step, 
    310                           field=field) 
     321    def _reader(self, meth, *args, **kwargs): 
     322        """ 
     323    Private function that retransforms the output of Table.read and equivalent 
     324    to the proper type. 
     325        """ 
     326        data = meth(self, *args, **kwargs) 
    311327        special_attrs = getattr(self.attrs, 'special_attrs', {}) 
    312328        fill_value = special_attrs.get('_fill_value', None) 
     
    314330        ndtype = self._get_dtype() 
    315331        field_names = ndtype.names 
     332        field = kwargs.get('field', None) 
    316333        # 
    317334        if field in ['_data','_mask']: 
     
    341358 
    342359 
     360    def read(self, start=None, stop=None, step=None, field=None): 
     361        """ 
     362    Reads the current :class:`MaskedTable`. 
     363     
     364    Returns 
     365    ------- 
     366    %(mareturn)s 
     367     
     368    Parameters 
     369    ---------- 
     370    %(readinput)s 
     371        """ 
     372        args = () 
     373        kwargs = dict(field=field, start=start, stop=stop, step=step) 
     374        return self._reader(Table.read, *args, **kwargs) 
     375    read.__doc__ = ((read.__doc__ or '') % _doc_parameters) or None 
     376 
     377 
     378    def readCoordinates(self, coords, field=None): 
     379        """ 
     380    Reads a set of rows given their coordinates. 
     381 
     382    Parameters 
     383    ---------- 
     384    %(readcoordinateinput)s 
     385     
     386    Returns 
     387    ------- 
     388    %(mareturn)s 
     389        """ 
     390        args = (coords,) 
     391        kwargs = dict(field=field) 
     392        return self._reader(Table.readCoordinates, *args, **kwargs) 
     393    readCoordinates.__doc__ = ((readCoordinates.__doc__ or '') % _doc_parameters) or None 
     394 
     395    def append(self, rows): 
     396        """ 
     397        """ 
     398        rows = tabulate(rows) 
     399        Table.append(self, rows) 
     400 
     401 
    343402class TimeSeriesTable(MaskedTable): 
    344403    """ 
     
    385444 
    386445 
    387     def read(self, start=None, stop=None, step=None, field=None): 
    388         """ 
    389     Reads the current :class:`TimeSeriesTable`. 
    390      
    391     Returns 
    392     ------- 
    393      
    394     Depending on the value of the ``field`` parameter, the method returns: 
    395      
    396     * a :class:`~scikits.timeseries.TimeSeries`, if ``field`` is None or a valid 
    397       field; 
    398     * a :class:`~scikits.timeseries.DateArray`, if ``field=='_dates'``; 
    399     * a ndarray, if ``field=='_data'`` or if ``field=='_mask'``; 
    400     * a :class:`~numpy.ma.MaskedArray`, if ``field=='_series'``. 
    401      
    402      
    403     Parameters 
    404     ---------- 
    405     start : {None, int}, optional 
    406         Index of the first record to read. 
    407         If None, records will be read starting from the very first one. 
    408     stop : {None, int}, optional 
    409         Index of the last record to read. 
    410         If None, records will be read until the very last one. 
    411     step : {None, int}, optional 
    412         Increment between succesive records to read. 
    413         If None, all the records between ``start`` and ``stop`` will be read. 
    414     field : {None, str}, optional 
    415         Name of the field to read. 
    416         If None, all the fields from each record are read. 
    417         The argument should be one of the field of the series, or one of the 
    418         following: ``'_data','_mask','_dates','_series'``. 
    419         """ 
    420 #        data = Table.read(self, start=start, stop=stop, step=step, 
    421 #                          field=field) 
     446#    def read(self, start=None, stop=None, step=None, field=None): 
     447#        """ 
     448#    Reads the current :class:`TimeSeriesTable`. 
     449#     
     450#    Returns 
     451#    ------- 
     452#     
     453#    Depending on the value of the ``field`` parameter, the method returns: 
     454#     
     455#    * a :class:`~scikits.timeseries.TimeSeries`, if ``field`` is None or a valid 
     456#      field; 
     457#    * a :class:`~scikits.timeseries.DateArray`, if ``field=='_dates'``; 
     458#    * a ndarray, if ``field=='_data'`` or if ``field=='_mask'``; 
     459#    * a :class:`~numpy.ma.MaskedArray`, if ``field=='_series'``. 
     460#     
     461#     
     462#    Parameters 
     463#    ---------- 
     464#    start : {None, int}, optional 
     465#        Index of the first record to read. 
     466#        If None, records will be read starting from the very first one. 
     467#    stop : {None, int}, optional 
     468#        Index of the last record to read. 
     469#        If None, records will be read until the very last one. 
     470#    step : {None, int}, optional 
     471#        Increment between succesive records to read. 
     472#        If None, all the records between ``start`` and ``stop`` will be read. 
     473#    field : {None, str}, optional 
     474#        Name of the field to read. 
     475#        If None, all the fields from each record are read. 
     476#        The argument should be one of the field of the series, or one of the 
     477#        following: ``'_data','_mask','_dates','_series'``. 
     478#        """ 
     479##        data = Table.read(self, start=start, stop=stop, step=step, 
     480##                          field=field) 
     481#        special_attrs = getattr(self.attrs, 'special_attrs', {}) 
     482#        fill_value = special_attrs.get('_fill_value', None) 
     483#        baseclass = special_attrs.get('_baseclass', np.ndarray) 
     484#        # 
     485#        position_keywords = dict(start=start, stop=stop, step=step) 
     486#        # 
     487#        ndtype = self._get_dtype() 
     488#        field_names = ndtype.names 
     489#         
     490#        # Case 1. : Global read ................. 
     491#        if field is None: 
     492#            data = Table.read(self, **position_keywords) 
     493#            dates = DateArray(data['_dates'], 
     494#                              freq=special_attrs.get('freq','U')) 
     495#            if field_names is None: 
     496#                output = time_series(data['_data'], 
     497#                                     dates = dates, 
     498#                                     mask=data['_mask']) 
     499#            else: 
     500#                output = ma.empty(data.shape, dtype=ndtype).view(TimeSeries) 
     501#                for name in field_names: 
     502#                    current = data[name] 
     503#                    output[name] = ma.array(current['_data'], 
     504#                                            mask=current['_mask']) 
     505#                output._dates = dates 
     506#            # Reset some attributes.................. 
     507#            output._baseclass = baseclass 
     508#            output.fill_value = fill_value 
     509#            output._hardmask = special_attrs.get('_hardmask', False) 
     510#            output._optinfo = special_attrs.get('_optinfo', {}) 
     511#        # Case 2. Partial reads.................. 
     512#        elif field in ['_dates','_data','_mask']: 
     513#            output = Table.read(self, field=field, **position_keywords) 
     514#        # Case 3. The series as a masked array 
     515#        elif field == '_series': 
     516#            # Special case: read the table, but keep it as MaskedArray 
     517#            data = Table.read(self, field=None, **position_keywords) 
     518#            if field_names is None: 
     519#                output = ma.array(data['_data'], mask=data['_mask']) 
     520#            else: 
     521#                output = ma.empty(data.shape, dtype=ndtype) 
     522#                for name in field_names: 
     523#                    current = data[name] 
     524#                    output[name] = ma.array(current['_data'], 
     525#                                            mask=current['_mask']) 
     526#            output.fill_value = fill_value 
     527#            output._baseclass = baseclass 
     528#            output._hardmask = special_attrs.get('_hardmask', False) 
     529#            output._optinfo = special_attrs.get('_optinfo', {}) 
     530#        # Case 4. Field read .................... 
     531#        elif field in field_names: 
     532#            data = Table.read(self, field=field, **position_keywords) 
     533#            dates = Table.read(self, field='_dates', **position_keywords) 
     534#            dates = DateArray(dates, freq=special_attrs.get('freq','U')) 
     535#            # Get the data part 
     536#            output = time_series(data['_data'], 
     537#                                 dates=dates, 
     538#                                 mask=data['_mask'],) 
     539#            output._baseclass = baseclass 
     540#            if fill_value is not None: 
     541#                output.fill_value = fill_value[field] 
     542#            output._hardmask = special_attrs.get('_hardmask', False) 
     543#            output._optinfo = special_attrs.get('_optinfo', {}) 
     544#        else: 
     545#            raise KeyError("Unable to process field '%s'" % field) 
     546#        return output 
     547 
     548 
     549    def _reader(self, meth, *args, **kwargs): 
     550        """ 
     551    Private function that retransforms the output of Table.read and equivalent 
     552    to the proper type. 
     553        """ 
     554        # Get some of the special attributes 
    422555        special_attrs = getattr(self.attrs, 'special_attrs', {}) 
    423556        fill_value = special_attrs.get('_fill_value', None) 
    424557        baseclass = special_attrs.get('_baseclass', np.ndarray) 
    425         # 
    426         position_keywords = dict(start=start, stop=stop, step=step) 
    427         # 
     558        # Get the dtype, in particular, the names of the fields. 
    428559        ndtype = self._get_dtype() 
    429560        field_names = ndtype.names 
    430561         
     562        field = kwargs.get('field', None) 
    431563        # Case 1. : Global read ................. 
    432564        if field is None: 
    433             data = Table.read(self, **position_keywords) 
     565            data = meth(self, *args, **kwargs) 
    434566            dates = DateArray(data['_dates'], 
    435567                              freq=special_attrs.get('freq','U')) 
     
    452584        # Case 2. Partial reads.................. 
    453585        elif field in ['_dates','_data','_mask']: 
    454             output = Table.read(self, field=field, **position_keywords) 
     586            output = meth(self, *args, **kwargs) 
    455587        # Case 3. The series as a masked array 
    456588        elif field == '_series': 
    457589            # Special case: read the table, but keep it as MaskedArray 
    458             data = Table.read(self, field=None, **position_keywords) 
     590            kwargs['field'] = None 
     591            data = meth(self, *args, **kwargs) 
    459592            if field_names is None: 
    460593                output = ma.array(data['_data'], mask=data['_mask']) 
     
    471604        # Case 4. Field read .................... 
    472605        elif field in field_names: 
    473             data = Table.read(self, field=field, **position_keywords) 
    474             dates = Table.read(self, field='_dates', **position_keywords) 
     606            data = meth(self, *args, **kwargs) 
     607            kwargs['field'] = '_dates' 
     608            dates = meth(self, *args, **kwargs) 
    475609            dates = DateArray(dates, freq=special_attrs.get('freq','U')) 
    476610            # Get the data part 
     
    488622 
    489623 
     624    def readCoordinates(self, coords, field=None): 
     625        """ 
     626    Reads a set of rows given their coordinates. 
     627 
     628    Parameters 
     629    ---------- 
     630    %(readcoordinateinput)s 
     631     
     632    Returns 
     633    ------- 
     634    %(tsreturn)s 
     635        """ 
     636        args = (coords,) 
     637        kwargs = dict(field=field) 
     638        return self._reader(Table.readCoordinates, *args, **kwargs) 
     639    readCoordinates.__doc__ = ((readCoordinates.__doc__ or '') % _doc_parameters) or None 
     640 
     641 
     642    def read(self, start=None, stop=None, step=None, field=None): 
     643        """ 
     644    Reads the current :class:`TimeSeriesTable`. 
     645     
     646    Returns 
     647    ------- 
     648    %(tsreturn)s 
     649     
     650    Parameters 
     651    ---------- 
     652    %(readcoordinateinput)s 
     653        """ 
     654        args = () 
     655        kwargs = dict(field=field, start=start, stop=stop, step=step) 
     656        return self._reader(Table.read, *args, **kwargs) 
     657    read.__doc__ = ((read.__doc__ or '') % _doc_parameters) or None 
     658 
     659 
    490660#-- File extensions -----------------------------------------------------------                        
    491661