| | ''' |
| | dutils.py |
| | A utility library for customized data loading functions |
| | ''' |
| | import os |
| | import gzip |
| | import numpy as np |
| | import pandas as pd |
| |
|
| | import os |
| | import cv2 |
| | from typing import List, Union, Dict, Sequence |
| | import numpy as np |
| | import numpy.random as nprand |
| | import datetime |
| | import pandas as pd |
| | import h5py |
| | import torch |
| | import torch.nn.functional as F |
| | from torch.nn.functional import avg_pool2d |
| | import random |
| | from torchvision import transforms as T |
| | from torchvision import datasets |
| | from torch.utils.data import Dataset, DataLoader |
| | from PIL import Image |
| |
|
# Root directories (relative paths) of the two datasets this module loads.
SEVIR_ROOT_DIR = "data/SEVIR"
METEO_FILE_DIR = "data/meteonet"
| |
|
def resize(seq, size):
    """
    Bilinearly resize a 5-D sequence to the given spatial `size`, clamping
    the result to [0, 1].

    NOTE(review): assumes dim 2 of `seq` is a singleton channel axis
    (shape roughly (N, T, 1, H, W)) — confirm with callers.
    """
    frames = seq.squeeze(dim=2)
    frames = F.interpolate(frames, size=size, mode='bilinear', align_corners=False)
    frames = frames.clamp(0, 1)
    return frames.unsqueeze(2)
| |
|
| | |
| | |
| | |
def pixel_to_dBZ_nonlinear(img):
    '''
    [0, 255] OR [0, 1] pixel => [0, 80] dBZ

    Inverts the arctan-based nonlinear pixel encoding. Whether the input is
    uint8-range or unit-range is decided by a heuristic on the mean value.
    '''
    if img.mean() > 1.0:  # heuristic: treat as [0, 255] input
        img = img / 255.0
    a_shift, a_fact = 31.0, 4.0
    atan_lo, atan_hi = -1.482, 1.412
    arc = img * (atan_hi - atan_lo) + atan_lo
    return np.tan(arc) * a_fact + a_shift
| |
|
def dbZ_to_pixel_nonlinear(dbZ):
    '''
    [0, 80] dBZ => [0, 255] OR [0, 1] pixel

    Forward arctan-based nonlinear encoding (inverse of
    `pixel_to_dBZ_nonlinear`); output lies in the unit range.
    '''
    a_shift, a_fact = 31.0, 4.0
    atan_lo, atan_hi = -1.482, 1.412
    centered = (dbZ - a_shift) / a_fact
    return (np.arctan(centered) - atan_lo) / (atan_hi - atan_lo)
| |
|
def dbZ_to_pixel(dbZ):
    '''
    [0, 80] dbZ => [0, 1] pixel

    Linear encoding: -10 dBZ maps to 0.0 and 60 dBZ maps to 1.0, quantized
    to 256 levels via round-half-up.
    '''
    quantized = np.floor((dbZ + 10) * 255 / 70 + 0.5)
    return quantized / 255.0
| |
|
def pixel_to_dBZ(pixel):
    '''
    [0, 255] (or [0, 1]) pixel => [0, 80] dBZ

    Linear decoding: 0.0 maps to -10 dBZ and 1.0 maps to 60 dBZ. Uint8-range
    input is detected heuristically via the mean value.
    '''
    unit = pixel / 255.0 if pixel.mean() > 1.0 else pixel
    return 70 * unit - 10
| |
|
def nonlinear_to_linear(im):
    """Re-encode a nonlinearly (arctan) encoded image with the linear encoding."""
    dbz = pixel_to_dBZ_nonlinear(im)
    return dbZ_to_pixel(dbz)
| |
|
def nonlinear_to_linear_batched(seq, datetime):
    """
    Re-encode each batch element of `seq` from the nonlinear to the linear
    pixel encoding when its first timestamp falls in 2016 or later; earlier
    elements pass through unchanged. Output is clipped to [0, 1].

    NOTE(review): the `datetime` parameter shadows the stdlib module inside
    this function; name kept for keyword-call compatibility.
    """
    converted = np.zeros_like(seq)
    for i, (frames, stamps) in enumerate(zip(seq, datetime)):
        converted[i] = nonlinear_to_linear(frames) if stamps[0].year >= 2016 else frames
    return np.clip(converted, 0.0, 1.0)
| |
|
def linear_to_nonlinear(im):
    """Re-encode a linearly encoded image with the nonlinear (arctan) encoding."""
    dbz = pixel_to_dBZ(im)
    return dbZ_to_pixel_nonlinear(dbz)
| |
|
def linear_to_nonlinear_batched(seq, datetime):
    """
    Re-encode each batch element of `seq` from the linear to the nonlinear
    pixel encoding when its first timestamp falls before 2016; later elements
    pass through unchanged. Output is clipped to [0, 1].

    NOTE(review): the `datetime` parameter shadows the stdlib module inside
    this function; name kept for keyword-call compatibility.
    """
    converted = np.zeros_like(seq)
    for i, (frames, stamps) in enumerate(zip(seq, datetime)):
        converted[i] = linear_to_nonlinear(frames) if stamps[0].year < 2016 else frames
    return np.clip(converted, 0.0, 1.0)
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | |
# SEVIR modalities present in the catalog.
SEVIR_DATA_TYPES = ['vis', 'ir069', 'ir107', 'vil', 'lght']
# On-disk dtype of each modality's raw HDF5 arrays.
SEVIR_RAW_DTYPES = {'vis': np.int16,
                    'ir069': np.int16,
                    'ir107': np.int16,
                    'vil': np.uint8,
                    'lght': np.int16}
# 49 bin edges at 5-minute spacing (values in seconds), used by
# `_lght_to_grid` to digitize lightning event times into frames.
# NOTE(review): presumably offsets relative to the event reference time.
LIGHTING_FRAME_TIMES = np.arange(- 120.0, 125.0, 5) * 60
# Spatial grid size; only lightning needs one explicitly, since it is
# rasterized from an event list rather than stored as dense frames.
SEVIR_DATA_SHAPE = {'lght': (48, 48), }
# Per-modality normalization applied as value' = scale * (value + offset)
# in `preprocess_data_dict` — 'sevir' rescale option.
PREPROCESS_SCALE_SEVIR = {'vis': 1,
                          'ir069': 1 / 1174.68,
                          'ir107': 1 / 2562.43,
                          'vil': 1 / 47.54,
                          'lght': 1 / 0.60517}
PREPROCESS_OFFSET_SEVIR = {'vis': 0,
                           'ir069': 3683.58,
                           'ir107': 1552.80,
                           'vil': - 33.44,
                           'lght': - 0.02990}
# '01' rescale option: only 'vil' (uint8) is mapped into [0, 1]; the other
# modalities are passed through unchanged.
PREPROCESS_SCALE_01 = {'vis': 1,
                       'ir069': 1,
                       'ir107': 1,
                       'vil': 1 / 255,
                       'lght': 1}
PREPROCESS_OFFSET_01 = {'vis': 0,
                        'ir069': 0,
                        'ir107': 0,
                        'vil': 0,
                        'lght': 0}

# Default catalog CSV and data directory beneath the SEVIR root.
SEVIR_CATALOG = os.path.join(SEVIR_ROOT_DIR, "CATALOG.csv")
SEVIR_DATA_DIR = os.path.join(SEVIR_ROOT_DIR, "data")
# Each raw SEVIR event holds 49 frames (matches len(LIGHTING_FRAME_TIMES)).
SEVIR_RAW_SEQ_LEN = 49

# Date boundaries separating train / val / test events.
SEVIR_TRAIN_VAL_SPLIT_DATE = datetime.datetime(2019, 1, 1)
SEVIR_TRAIN_TEST_SPLIT_DATE = datetime.datetime(2019, 6, 1)
| |
|
def change_layout_np(data,
                     in_layout='NHWT', out_layout='NHWT',
                     ret_contiguous=False):
    """
    Convert a numpy array between SEVIR tensor layouts.

    The array is first normalized to the canonical 'NHWT' layout, then
    converted to `out_layout`. 5-D input layouts carrying a channel axis
    ('NTCHW', 'NTHWC', 'NTWHC', 'TNCHW') keep only channel 0; the matching
    output layouts re-insert a singleton channel axis.

    Parameters
    ----------
    data : np.ndarray
    in_layout, out_layout : str
        One of 'NHWT', 'NTHW', 'NWHT', 'NTCHW', 'NTHWC', 'NTWHC', 'TNHW',
        'TNCHW' (N batch, T time, C channel, H height, W width).
    ret_contiguous : bool
        If True, return a C-contiguous copy.

    Returns
    -------
    np.ndarray

    Raises
    ------
    NotImplementedError
        If a layout string is not supported.
    """
    # --- normalize the input to the canonical 'NHWT' layout ---
    if in_layout == 'NHWT':
        pass
    elif in_layout == 'NTHW':
        data = np.transpose(data, axes=(0, 2, 3, 1))
    elif in_layout == 'NWHT':
        data = np.transpose(data, axes=(0, 2, 1, 3))
    elif in_layout == 'NTCHW':
        data = data[:, :, 0, :, :]  # drop all channels but the first
        data = np.transpose(data, axes=(0, 2, 3, 1))
    elif in_layout == 'NTHWC':
        data = data[:, :, :, :, 0]
        data = np.transpose(data, axes=(0, 2, 3, 1))
    elif in_layout == 'NTWHC':
        data = data[:, :, :, :, 0]
        data = np.transpose(data, axes=(0, 3, 2, 1))
    elif in_layout == 'TNHW':
        data = np.transpose(data, axes=(1, 2, 3, 0))
    elif in_layout == 'TNCHW':
        data = data[:, :, 0, :, :]
        data = np.transpose(data, axes=(1, 2, 3, 0))
    else:
        raise NotImplementedError

    # --- convert from 'NHWT' to the requested output layout ---
    if out_layout == 'NHWT':
        pass
    elif out_layout == 'NTHW':
        data = np.transpose(data, axes=(0, 3, 1, 2))
    elif out_layout == 'NWHT':
        data = np.transpose(data, axes=(0, 2, 1, 3))
    elif out_layout == 'NTCHW':
        data = np.transpose(data, axes=(0, 3, 1, 2))
        data = np.expand_dims(data, axis=2)  # restore singleton channel axis
    elif out_layout == 'NTHWC':
        data = np.transpose(data, axes=(0, 3, 1, 2))
        data = np.expand_dims(data, axis=-1)
    elif out_layout == 'NTWHC':
        data = np.transpose(data, axes=(0, 3, 2, 1))
        data = np.expand_dims(data, axis=-1)
    elif out_layout == 'TNHW':
        data = np.transpose(data, axes=(3, 0, 1, 2))
    elif out_layout == 'TNCHW':
        data = np.transpose(data, axes=(3, 0, 1, 2))
        data = np.expand_dims(data, axis=2)
    else:
        raise NotImplementedError
    if ret_contiguous:
        # BUGFIX: np.ndarray has no `.ascontiguousarray()` method — the
        # original `data.ascontiguousarray()` raised AttributeError whenever
        # ret_contiguous=True; the module-level function must be used.
        data = np.ascontiguousarray(data)
    return data
| |
|
def change_layout_torch(data,
                        in_layout='NHWT', out_layout='NHWT',
                        ret_contiguous=False):
    """
    Convert a torch tensor between SEVIR tensor layouts.

    Torch twin of `change_layout_np`: the tensor is first normalized to the
    canonical 'NHWT' layout, then permuted to `out_layout`. 5-D input
    layouts carrying a channel axis keep only channel 0.

    Parameters
    ----------
    data : torch.Tensor
    in_layout, out_layout : str
        One of 'NHWT', 'NTHW', 'NWHT', 'NTCHW', 'NTHWC', 'NTWHC', 'TNHW',
        'TNCHW'. ('NWHT' and 'NTWHC' added for parity with
        `change_layout_np`, which already supported them.)
    ret_contiguous : bool
        If True, return a contiguous tensor.

    Returns
    -------
    torch.Tensor

    Raises
    ------
    NotImplementedError
        If a layout string is not supported.
    """
    # --- normalize the input to the canonical 'NHWT' layout ---
    if in_layout == 'NHWT':
        pass
    elif in_layout == 'NTHW':
        data = data.permute(0, 2, 3, 1)
    elif in_layout == 'NWHT':
        data = data.permute(0, 2, 1, 3)
    elif in_layout == 'NTCHW':
        data = data[:, :, 0, :, :]  # drop all channels but the first
        data = data.permute(0, 2, 3, 1)
    elif in_layout == 'NTHWC':
        data = data[:, :, :, :, 0]
        data = data.permute(0, 2, 3, 1)
    elif in_layout == 'NTWHC':
        data = data[:, :, :, :, 0]
        data = data.permute(0, 3, 2, 1)
    elif in_layout == 'TNHW':
        data = data.permute(1, 2, 3, 0)
    elif in_layout == 'TNCHW':
        data = data[:, :, 0, :, :]
        data = data.permute(1, 2, 3, 0)
    else:
        raise NotImplementedError

    # --- convert from 'NHWT' to the requested output layout ---
    if out_layout == 'NHWT':
        pass
    elif out_layout == 'NTHW':
        data = data.permute(0, 3, 1, 2)
    elif out_layout == 'NWHT':
        data = data.permute(0, 2, 1, 3)
    elif out_layout == 'NTCHW':
        data = data.permute(0, 3, 1, 2)
        data = torch.unsqueeze(data, dim=2)  # restore singleton channel axis
    elif out_layout == 'NTHWC':
        data = data.permute(0, 3, 1, 2)
        data = torch.unsqueeze(data, dim=-1)
    elif out_layout == 'NTWHC':
        data = data.permute(0, 3, 2, 1)
        data = torch.unsqueeze(data, dim=-1)
    elif out_layout == 'TNHW':
        data = data.permute(3, 0, 1, 2)
    elif out_layout == 'TNCHW':
        data = data.permute(3, 0, 1, 2)
        data = torch.unsqueeze(data, dim=2)
    else:
        raise NotImplementedError
    if ret_contiguous:
        data = data.contiguous()
    return data
| |
|
| | class SEVIRDataLoader: |
| | r""" |
| | DataLoader that loads SEVIR sequences, and spilts each event |
| | into segments according to specified sequence length. |
| | |
| | Event Frames: |
| | [-----------------------raw_seq_len----------------------] |
| | [-----seq_len-----] |
| | <--stride-->[-----seq_len-----] |
| | <--stride-->[-----seq_len-----] |
| | ... |
| | """ |
    def __init__(self,
                 data_types: Sequence[str] = None,
                 seq_len: int = 49,
                 raw_seq_len: int = 49,
                 sample_mode: str = 'sequent',
                 stride: int = 12,
                 batch_size: int = 1,
                 layout: str = 'NHWT',
                 num_shard: int = 1,
                 rank: int = 0,
                 split_mode: str = "uneven",
                 sevir_catalog: Union[str, pd.DataFrame] = None,
                 sevir_data_dir: str = None,
                 start_date: datetime.datetime = None,
                 end_date: datetime.datetime = None,
                 datetime_filter=None,
                 catalog_filter='default',
                 shuffle: bool = False,
                 shuffle_seed: int = 1,
                 output_type=np.float32,
                 preprocess: bool = True,
                 rescale_method: str = '01',
                 downsample_dict: Dict[str, Sequence[int]] = None,
                 verbose: bool = False):
        r"""
        Parameters
        ----------
        data_types
            A subset of SEVIR_DATA_TYPES.
        seq_len
            The length of the data sequences. Should be smaller than the max length raw_seq_len.
        raw_seq_len
            The length of the raw data sequences.
        sample_mode
            'random' or 'sequent'
        stride
            Useful when sample_mode == 'sequent'
            stride must not be smaller than out_len to prevent data leakage in testing.
        batch_size
            Number of sequences in one batch.
        layout
            str: consists of batch_size 'N', seq_len 'T', channel 'C', height 'H', width 'W'
            The layout of sampled data. Raw data layout is 'NHWT'.
            valid layout: 'NHWT', 'NTHW', 'NTCHW', 'TNHW', 'TNCHW'.
        num_shard
            Split the whole dataset into num_shard parts for distributed training.
        rank
            Rank of the current process within num_shard.
        split_mode: str
            if 'ceil', all `num_shard` dataloaders have the same length = ceil(total_len / num_shard).
            Different dataloaders may have some duplicated data batches, if the total size of datasets is not divided by num_shard.
            if 'floor', all `num_shard` dataloaders have the same length = floor(total_len / num_shard).
            The last several data batches may be wasted, if the total size of datasets is not divided by num_shard.
            if 'uneven', the last datasets has larger length when the total length is not divided by num_shard.
            The uneven split leads to synchronization error in dist.all_reduce() or dist.barrier().
            See related issue: https://github.com/pytorch/pytorch/issues/33148
            Notice: this also affects the behavior of `self.use_up`.
        sevir_catalog
            Name of SEVIR catalog CSV file.
        sevir_data_dir
            Directory path to SEVIR data.
        start_date
            Start time of SEVIR samples to generate.
        end_date
            End time of SEVIR samples to generate.
        datetime_filter
            function
            Mask function applied to time_utc column of catalog (return true to keep the row).
            Pass function of the form lambda t : COND(t)
            Example: lambda t: np.logical_and(t.dt.hour>=13,t.dt.hour<=21) # Generate only day-time events
        catalog_filter
            function or None or 'default'
            Mask function applied to entire catalog dataframe (return true to keep row).
            Pass function of the form lambda catalog: COND(catalog)
            Example: lambda c: [s[0]=='S' for s in c.id] # Generate only the 'S' events
        shuffle
            bool, If True, data samples are shuffled before each epoch.
        shuffle_seed
            int, Seed to use for shuffling.
        output_type
            np.dtype, dtype of generated tensors
        preprocess
            bool, If True, self.preprocess_data_dict(data_dict) is called before each sample generated
        rescale_method
            str, '01' or 'sevir'; selects the scale/offset tables used by
            `self.preprocess_data_dict`.
        downsample_dict:
            dict, downsample_dict.keys() == data_types. downsample_dict[key] is a Sequence of (t_factor, h_factor, w_factor),
            representing the downsampling factors of all dimensions.
        verbose
            bool, verbose when opening raw data files
        """
        super(SEVIRDataLoader, self).__init__()
        # Fall back to module-level defaults when paths/types are omitted.
        if sevir_catalog is None:
            sevir_catalog = SEVIR_CATALOG
        if sevir_data_dir is None:
            sevir_data_dir = SEVIR_DATA_DIR
        if data_types is None:
            data_types = SEVIR_DATA_TYPES
        else:
            assert set(data_types).issubset(SEVIR_DATA_TYPES)

        # Modality metadata shared by the reading/rasterizing helpers.
        self._dtypes = SEVIR_RAW_DTYPES
        self.lght_frame_times = LIGHTING_FRAME_TIMES
        self.data_shape = SEVIR_DATA_SHAPE

        self.raw_seq_len = raw_seq_len
        assert seq_len <= self.raw_seq_len, f'seq_len must not be larger than raw_seq_len = {raw_seq_len}, got {seq_len}.'
        self.seq_len = seq_len
        assert sample_mode in ['random', 'sequent'], f'Invalid sample_mode = {sample_mode}, must be \'random\' or \'sequent\'.'
        self.sample_mode = sample_mode
        self.stride = stride
        self.batch_size = batch_size
        valid_layout = ('NHWT', 'NTHW', 'NTCHW', 'NTHWC', 'TNHW', 'TNCHW')
        if layout not in valid_layout:
            raise ValueError(f'Invalid layout = {layout}! Must be one of {valid_layout}.')
        self.layout = layout
        self.num_shard = num_shard
        self.rank = rank
        valid_split_mode = ('ceil', 'floor', 'uneven')
        if split_mode not in valid_split_mode:
            raise ValueError(f'Invalid split_mode: {split_mode}! Must be one of {valid_split_mode}.')
        self.split_mode = split_mode
        self._samples = None    # set by _compute_samples()
        self._hdf_files = {}    # filename -> open h5py.File, set by _open_files()
        self.data_types = data_types
        # A catalog may be passed either as a CSV path or a ready DataFrame.
        if isinstance(sevir_catalog, str):
            self.catalog = pd.read_csv(sevir_catalog, parse_dates=['time_utc'], low_memory=False)
        else:
            self.catalog = sevir_catalog
        self.sevir_data_dir = sevir_data_dir
        self.datetime_filter = datetime_filter
        self.catalog_filter = catalog_filter
        self.start_date = start_date
        self.end_date = end_date
        self.shuffle = shuffle
        self.shuffle_seed = int(shuffle_seed)
        self.output_type = output_type
        self.preprocess = preprocess
        self.downsample_dict = downsample_dict
        self.rescale_method = rescale_method
        self.verbose = verbose

        # Date window is (start_date, end_date]: strictly after start,
        # up to and including end.
        if self.start_date is not None:
            self.catalog = self.catalog[self.catalog.time_utc > self.start_date]
        if self.end_date is not None:
            self.catalog = self.catalog[self.catalog.time_utc <= self.end_date]
        if self.datetime_filter:
            self.catalog = self.catalog[self.datetime_filter(self.catalog.time_utc)]

        # 'default' keeps only rows with no missing data.
        if self.catalog_filter is not None:
            if self.catalog_filter == 'default':
                self.catalog_filter = lambda c: c.pct_missing == 0
            self.catalog = self.catalog[self.catalog_filter(self.catalog)]

        # Build the sample table, open the backing HDF5 files, and rewind
        # the iteration cursors.
        self._compute_samples()
        self._open_files(verbose=self.verbose)
        self.reset()
| |
|
    def _compute_samples(self):
        """
        Computes the list of samples in catalog to be used. This sets self._samples

        Keeps only events that have exactly one catalog row for every
        requested data type, then flattens each event group into a one-row
        frame via `_df_to_series`.
        """
        # keep only rows whose img_type is one of the requested data types
        imgt = self.data_types
        imgts = set(imgt)
        filtcat = self.catalog[ np.logical_or.reduce([self.catalog.img_type==i for i in imgt]) ]
        # drop events that are missing any of the requested data types
        filtcat = filtcat.groupby('id').filter(lambda x: imgts.issubset(set(x['img_type'])))
        # drop events with duplicated rows (more rows than requested types)
        filtcat = filtcat.groupby('id').filter(lambda x: x.shape[0]==len(imgt))
        self._samples = filtcat.groupby('id').apply(lambda df: self._df_to_series(df,imgt) )
        if self.shuffle:
            self.shuffle_samples()
| |
|
| | def shuffle_samples(self): |
| | self._samples = self._samples.sample(frac=1, random_state=self.shuffle_seed) |
| |
|
| | def _df_to_series(self, df, imgt): |
| | d = {} |
| | df = df.set_index('img_type') |
| | for i in imgt: |
| | s = df.loc[i] |
| | idx = s.file_index if i != 'lght' else s.id |
| | d.update({f'{i}_filename': [s.file_name], |
| | f'{i}_index': [idx]}) |
| |
|
| | return pd.DataFrame(d) |
| |
|
| | def _open_files(self, verbose=True): |
| | """ |
| | Opens HDF files |
| | """ |
| | imgt = self.data_types |
| | hdf_filenames = [] |
| | for t in imgt: |
| | hdf_filenames += list(np.unique( self._samples[f'{t}_filename'].values )) |
| | self._hdf_files = {} |
| | for f in hdf_filenames: |
| | if verbose: |
| | print('Opening HDF5 file for reading', f) |
| | self._hdf_files[f] = h5py.File(self.sevir_data_dir + '/' + f, 'r') |
| |
|
| | def close(self): |
| | """ |
| | Closes all open file handles |
| | """ |
| | for f in self._hdf_files: |
| | self._hdf_files[f].close() |
| | self._hdf_files = {} |
| |
|
| | @property |
| | def num_seq_per_event(self): |
| | return 1 + (self.raw_seq_len - self.seq_len) // self.stride |
| |
|
| | @property |
| | def total_num_seq(self): |
| | """ |
| | The total number of sequences within each shard. |
| | Notice that it is not the product of `self.num_seq_per_event` and `self.total_num_event`. |
| | """ |
| | return int(self.num_seq_per_event * self.num_event) |
| |
|
| | @property |
| | def total_num_event(self): |
| | """ |
| | The total number of events in the whole dataset, before split into different shards. |
| | """ |
| | return int(self._samples.shape[0]) |
| |
|
| | @property |
| | def start_event_idx(self): |
| | """ |
| | The event idx used in certain rank should satisfy event_idx >= start_event_idx |
| | """ |
| | return self.total_num_event // self.num_shard * self.rank |
| |
|
| | @property |
| | def end_event_idx(self): |
| | """ |
| | The event idx used in certain rank should satisfy event_idx < end_event_idx |
| | |
| | """ |
| | if self.split_mode == 'ceil': |
| | _last_start_event_idx = self.total_num_event // self.num_shard * (self.num_shard - 1) |
| | _num_event = self.total_num_event - _last_start_event_idx |
| | return self.start_event_idx + _num_event |
| | elif self.split_mode == 'floor': |
| | return self.total_num_event // self.num_shard * (self.rank + 1) |
| | else: |
| | if self.rank == self.num_shard - 1: |
| | return self.total_num_event |
| | else: |
| | return self.total_num_event // self.num_shard * (self.rank + 1) |
| |
|
| | @property |
| | def num_event(self): |
| | """ |
| | The number of events split into each rank |
| | """ |
| | return self.end_event_idx - self.start_event_idx |
| |
|
    def _read_data(self, row, data):
        """
        Iteratively read data into data dict. Finally data[imgt] gets shape (batch_size, height, width, raw_seq_len).

        Parameters
        ----------
        row
            A series with fields IMGTYPE_filename, IMGTYPE_index, IMGTYPE_time_index.
        data
            Dict, data[imgt] is a data tensor with shape = (tmp_batch_size, height, width, raw_seq_len).

        Returns
        -------
        data
            Updated data. Updated shape = (tmp_batch_size + 1, height, width, raw_seq_len).
        """
        # Recover the data types present in this row from its column names
        # ('<imgtype>_filename' / '<imgtype>_index').
        imgtyps = np.unique([x.split('_')[0] for x in list(row.keys())])
        for t in imgtyps:
            fname = row[f'{t}_filename']
            idx = row[f'{t}_index']
            t_slice = slice(0, None)  # full raw time range
            if t == 'lght':
                # Lightning is stored as an event list keyed by event id;
                # rasterize it into a (1, H, W, T) grid of counts.
                lght_data = self._hdf_files[fname][idx][:]
                data_i = self._lght_to_grid(lght_data, t_slice)
            else:
                # Dense modalities: read one event, keeping the batch axis.
                data_i = self._hdf_files[fname][t][idx:idx + 1, :, :, t_slice]
            # Append along the batch axis (or start a new entry).
            data[t] = np.concatenate((data[t], data_i), axis=0) if (t in data) else data_i

        return data
| |
|
    def _lght_to_grid(self, data, t_slice=slice(0, None)):
        """
        Converts Nx5 lightning data matrix into a 2D grid of pixel counts

        Column 0 is the event time (matched against `self.lght_frame_times`)
        and columns 3 / 4 are the pixel x / y coordinates.
        NOTE(review): columns 1-2 are unused here — confirm their meaning
        against the SEVIR spec.

        Returns an array of shape (1, H, W, T) — T frames when
        `t_slice.stop is None`, otherwise a single frame.
        """
        # full frame stack vs a single selected frame
        out_size = (*self.data_shape['lght'], len(self.lght_frame_times)) if t_slice.stop is None else (*self.data_shape['lght'], 1)
        if data.shape[0] == 0:
            # no lightning events at all -> empty grid
            return np.zeros((1,) + out_size, dtype=np.float32)

        # filter out events that fall outside the spatial grid
        x, y = data[:, 3], data[:, 4]
        m = np.logical_and.reduce([x >= 0, x < out_size[0], y >= 0, y < out_size[1]])
        data = data[m, :]
        if data.shape[0] == 0:
            return np.zeros((1,) + out_size, dtype=np.float32)

        # assign each event to a frame along the time axis
        t = data[:, 0]
        if t_slice.stop is not None:
            # single-frame mode: keep only events in the requested time bin
            if t_slice.stop > 0:
                if t_slice.stop < len(self.lght_frame_times):
                    tm = np.logical_and(t >= self.lght_frame_times[t_slice.stop - 1],
                                        t < self.lght_frame_times[t_slice.stop])
                else:
                    # last bin: everything at or after the final frame time
                    tm = t >= self.lght_frame_times[-1]
            else:
                tm = np.logical_and(t >= self.lght_frame_times[0], t < self.lght_frame_times[1])

            data = data[tm, :]
            z = np.zeros(data.shape[0], dtype=np.int64)  # all land in frame 0
        else:
            # all-frames mode: digitize times into frame indices; events
            # before the first frame time are clamped into bin 0
            z = np.digitize(t, self.lght_frame_times) - 1
            z[z == -1] = 0

        x = data[:, 3].astype(np.int64)
        y = data[:, 4].astype(np.int64)

        # count events per (y, x, frame) cell
        k = np.ravel_multi_index(np.array([y, x, z]), out_size)
        n = np.bincount(k, minlength=np.prod(out_size))
        return np.reshape(n, out_size).astype(np.int16)[np.newaxis, :]
| |
|
    def _old_save_downsampled_dataset(self, save_dir, downsample_dict, verbose=True):
        """
        This method does not save .h5 dataset correctly. There are some batches missed due to unknown error.
        E.g., the first converted .h5 file `SEVIR_VIL_RANDOMEVENTS_2017_0501_0831.h5` only has batch_dim = 1414,
        while it should be 1440 in the original .h5 file.

        Deprecated: kept for reference only — use `save_downsampled_dataset`.
        """
        import os
        from skimage.measure import block_reduce
        assert not os.path.exists(save_dir), f"save_dir {save_dir} already exists!"
        os.makedirs(save_dir)
        sample_counter = 0
        for index, row in self._samples.iterrows():
            if verbose:
                print(f"Downsampling {sample_counter}-th data item.", end='\r')
            for data_type in self.data_types:
                # read one event for this data type (lightning is rasterized)
                fname = row[f'{data_type}_filename']
                idx = row[f'{data_type}_index']
                t_slice = slice(0, None)
                if data_type == 'lght':
                    lght_data = self._hdf_files[fname][idx][:]
                    data_i = self._lght_to_grid(lght_data, t_slice)
                else:
                    data_i = self._hdf_files[fname][data_type][idx:idx + 1, :, :, t_slice]
                # downsample the time axis by striding
                t_slice = [slice(None, None), ] * 4
                t_slice[-1] = slice(None, None, downsample_dict[data_type][0])
                data_i = data_i[tuple(t_slice)]
                # downsample the spatial axes via max pooling
                data_i = block_reduce(data_i,
                                      block_size=(1, *downsample_dict[data_type][1:], 1),
                                      func=np.max)
                # create a resizable dataset on first touch, otherwise append
                # along the batch axis
                new_file_path = os.path.join(save_dir, fname)
                if not os.path.exists(new_file_path):
                    if not os.path.exists(os.path.dirname(new_file_path)):
                        os.makedirs(os.path.dirname(new_file_path))
                    with h5py.File(new_file_path, 'w') as hf:
                        hf.create_dataset(
                            data_type, data=data_i,
                            maxshape=(None, *data_i.shape[1:]))
                else:
                    with h5py.File(new_file_path, 'a') as hf:
                        hf[data_type].resize((hf[data_type].shape[0] + data_i.shape[0]), axis=0)
                        hf[data_type][-data_i.shape[0]:] = data_i

            sample_counter += 1
| |
|
    def save_downsampled_dataset(self, save_dir, downsample_dict, verbose=True):
        """
        Downsample every opened HDF5 file as a whole and write the results
        under `save_dir`.

        Parameters
        ----------
        save_dir
            Target directory; must not already exist.
        downsample_dict: Dict[Sequence[int]]
            Notice that this is different from `self.downsample_dict`, which is used during runtime.
            Each value is (t_factor, h_factor, w_factor).
        verbose
            bool, print each file name while converting.

        Raises
        ------
        NotImplementedError
            For lightning data, which is not stored as dense frames.
        """
        import os
        from skimage.measure import block_reduce
        from ...utils.utils import path_splitall
        assert not os.path.exists(save_dir), f"save_dir {save_dir} already exists!"
        os.makedirs(save_dir)
        for fname, hdf_file in self._hdf_files.items():
            if verbose:
                print(f"Downsampling data in {fname}.")
            # NOTE(review): assumes the first path component of `fname`
            # encodes the data type — confirm against the on-disk layout.
            data_type = path_splitall(fname)[0]
            if data_type == 'lght':
                raise NotImplementedError
            else:
                data_i = self._hdf_files[fname][data_type]
                # downsample the time axis by striding
                t_slice = [slice(None, None), ] * 4
                t_slice[-1] = slice(None, None, downsample_dict[data_type][0])
                data_i = data_i[tuple(t_slice)]
                # downsample the spatial axes via max pooling
                data_i = block_reduce(data_i,
                                      block_size=(1, *downsample_dict[data_type][1:], 1),
                                      func=np.max)
                # write the whole downsampled array in one shot
                new_file_path = os.path.join(save_dir, fname)
                if not os.path.exists(os.path.dirname(new_file_path)):
                    os.makedirs(os.path.dirname(new_file_path))
                with h5py.File(new_file_path, 'w') as hf:
                    hf.create_dataset(
                        data_type, data=data_i,
                        maxshape=(None, *data_i.shape[1:]))
| |
|
| | @property |
| | def sample_count(self): |
| | """ |
| | Record how many times self.__next__() is called. |
| | """ |
| | return self._sample_count |
| |
|
| | def inc_sample_count(self): |
| | self._sample_count += 1 |
| |
|
| | @property |
| | def curr_event_idx(self): |
| | return self._curr_event_idx |
| |
|
| | @property |
| | def curr_seq_idx(self): |
| | """ |
| | Used only when self.sample_mode == 'sequent' |
| | """ |
| | return self._curr_seq_idx |
| |
|
| | def set_curr_event_idx(self, val): |
| | self._curr_event_idx = val |
| |
|
| | def set_curr_seq_idx(self, val): |
| | """ |
| | Used only when self.sample_mode == 'sequent' |
| | """ |
| | self._curr_seq_idx = val |
| |
|
| | def reset(self, shuffle: bool = None): |
| | self.set_curr_event_idx(val=self.start_event_idx) |
| | self.set_curr_seq_idx(0) |
| | self._sample_count = 0 |
| | if shuffle is None: |
| | shuffle = self.shuffle |
| | if shuffle: |
| | self.shuffle_samples() |
| |
|
| | def __len__(self): |
| | """ |
| | Used only when self.sample_mode == 'sequent' |
| | """ |
| | return self.total_num_seq // self.batch_size |
| |
|
| | @property |
| | def use_up(self): |
| | """ |
| | Check if dataset is used up in 'sequent' mode. |
| | """ |
| | if self.sample_mode == 'random': |
| | return False |
| | else: |
| | |
| | curr_event_remain_seq = self.num_seq_per_event - self.curr_seq_idx |
| | all_remain_seq = curr_event_remain_seq + ( |
| | self.end_event_idx - self.curr_event_idx - 1) * self.num_seq_per_event |
| | if self.split_mode == "floor": |
| | |
| | return all_remain_seq < self.batch_size |
| | else: |
| | return all_remain_seq <= 0 |
| |
|
    def _load_event_batch(self, event_idx, event_batch_size):
        """
        Loads a selected batch of events (not batch of sequences) into memory.

        Parameters
        ----------
        event_idx
            Start index into `self._samples`.
        event_batch_size
            event_batch[i] = all_type_i_available_events[idx:idx + event_batch_size]

        Returns
        -------
        event_batch
            list of event batches.
            event_batch[i] is the event batch of the i-th data type.
            Each event_batch[i] is a np.ndarray with shape = (event_batch_size, height, width, raw_seq_len)
        """
        event_idx_slice_end = event_idx + event_batch_size
        pad_size = 0
        # clip the slice at the shard boundary; the shortfall is padded below
        if event_idx_slice_end > self.end_event_idx:
            pad_size = event_idx_slice_end - self.end_event_idx
            event_idx_slice_end = self.end_event_idx
        pd_batch = self._samples.iloc[event_idx:event_idx_slice_end]
        data = {}
        for index, row in pd_batch.iterrows():
            data = self._read_data(row, data)
        if pad_size > 0:
            # pad with all-zero events so the batch keeps `event_batch_size` entries
            event_batch = []
            for t in self.data_types:
                pad_shape = [pad_size, ] + list(data[t].shape[1:])
                data_pad = np.concatenate((data[t].astype(self.output_type),
                                           np.zeros(pad_shape, dtype=self.output_type)),
                                          axis=0)
                event_batch.append(data_pad)
        else:
            event_batch = [data[t].astype(self.output_type) for t in self.data_types]
        return event_batch
| |
|
| | def __iter__(self): |
| | return self |
| |
|
    def __next__(self):
        """
        Produce the next batch as a dict of tensors.

        'random' mode never raises StopIteration; 'sequent' mode stops once
        the shard is used up. Rescaling/layout conversion and optional
        downsampling are applied when configured.
        """
        if self.sample_mode == 'random':
            self.inc_sample_count()
            ret_dict = self._random_sample()
        else:
            if self.use_up:
                raise StopIteration
            else:
                self.inc_sample_count()
                ret_dict = self._sequent_sample()
        # convert numpy arrays to detached torch tensors
        ret_dict = self.data_dict_to_tensor(data_dict=ret_dict,
                                            data_types=self.data_types)
        if self.preprocess:
            ret_dict = self.preprocess_data_dict(data_dict=ret_dict,
                                                 data_types=self.data_types,
                                                 layout=self.layout,
                                                 rescale=self.rescale_method)
        if self.downsample_dict is not None:
            ret_dict = self.downsample_data_dict(data_dict=ret_dict,
                                                 data_types=self.data_types,
                                                 factors_dict=self.downsample_dict,
                                                 layout=self.layout)
        return ret_dict
| |
|
| | def __getitem__(self, index): |
| | data_dict = self._idx_sample(index=index) |
| | return data_dict |
| |
|
| | @staticmethod |
| | def preprocess_data_dict(data_dict, data_types=None, layout='NHWT', rescale='01'): |
| | """ |
| | Parameters |
| | ---------- |
| | data_dict: Dict[str, Union[np.ndarray, torch.Tensor]] |
| | data_types: Sequence[str] |
| | The data types that we want to rescale. This mainly excludes "mask" from preprocessing. |
| | layout: str |
| | consists of batch_size 'N', seq_len 'T', channel 'C', height 'H', width 'W' |
| | rescale: str |
| | 'sevir': use the offsets and scale factors in original implementation. |
| | '01': scale all values to range 0 to 1, currently only supports 'vil' |
| | Returns |
| | ------- |
| | data_dict: Dict[str, Union[np.ndarray, torch.Tensor]] |
| | preprocessed data |
| | """ |
| | if rescale == 'sevir': |
| | scale_dict = PREPROCESS_SCALE_SEVIR |
| | offset_dict = PREPROCESS_OFFSET_SEVIR |
| | elif rescale == '01': |
| | scale_dict = PREPROCESS_SCALE_01 |
| | offset_dict = PREPROCESS_OFFSET_01 |
| | else: |
| | raise ValueError(f'Invalid rescale option: {rescale}.') |
| | if data_types is None: |
| | data_types = data_dict.keys() |
| | for key, data in data_dict.items(): |
| | if key in data_types: |
| | if isinstance(data, np.ndarray): |
| | data = scale_dict[key] * ( |
| | data.astype(np.float32) + |
| | offset_dict[key]) |
| | data = change_layout_np(data=data, |
| | in_layout='NHWT', |
| | out_layout=layout) |
| | elif isinstance(data, torch.Tensor): |
| | data = scale_dict[key] * ( |
| | data.float() + |
| | offset_dict[key]) |
| | data = change_layout_torch(data=data, |
| | in_layout='NHWT', |
| | out_layout=layout) |
| | data_dict[key] = data |
| | return data_dict |
| |
|
| | @staticmethod |
| | def process_data_dict_back(data_dict, data_types=None, rescale='01'): |
| | """ |
| | Parameters |
| | ---------- |
| | data_dict |
| | each data_dict[key] is a torch.Tensor. |
| | rescale |
| | str: |
| | 'sevir': data are scaled using the offsets and scale factors in original implementation. |
| | '01': data are all scaled to range 0 to 1, currently only supports 'vil' |
| | Returns |
| | ------- |
| | data_dict |
| | each data_dict[key] is the data processed back in torch.Tensor. |
| | """ |
| | if rescale == 'sevir': |
| | scale_dict = PREPROCESS_SCALE_SEVIR |
| | offset_dict = PREPROCESS_OFFSET_SEVIR |
| | elif rescale == '01': |
| | scale_dict = PREPROCESS_SCALE_01 |
| | offset_dict = PREPROCESS_OFFSET_01 |
| | else: |
| | raise ValueError(f'Invalid rescale option: {rescale}.') |
| | if data_types is None: |
| | data_types = data_dict.keys() |
| | for key in data_types: |
| | data = data_dict[key] |
| | data = data.float() / scale_dict[key] - offset_dict[key] |
| | data_dict[key] = data |
| | return data_dict |
| |
|
| | @staticmethod |
| | def data_dict_to_tensor(data_dict, data_types=None): |
| | """ |
| | Convert each element in data_dict to torch.Tensor (copy without grad). |
| | """ |
| | ret_dict = {} |
| | if data_types is None: |
| | data_types = data_dict.keys() |
| | for key, data in data_dict.items(): |
| | if key in data_types: |
| | if isinstance(data, torch.Tensor): |
| | ret_dict[key] = data.detach().clone() |
| | elif isinstance(data, np.ndarray): |
| | ret_dict[key] = torch.from_numpy(data) |
| | else: |
| | raise ValueError(f"Invalid data type: {type(data)}. Should be torch.Tensor or np.ndarray") |
| | else: |
| | ret_dict[key] = data |
| | return ret_dict |
| |
|
@staticmethod
def downsample_data_dict(data_dict, data_types=None, factors_dict=None, layout='NHWT'):
    """
    Downsample selected entries of `data_dict` in time and space.

    Parameters
    ----------
    data_dict: Dict[str, Union[np.ndarray, torch.Tensor]]
    data_types: Optional[Sequence[str]]
        Keys to convert to tensors before downsampling; defaults to all keys.
    factors_dict: Optional[Dict[str, Sequence[int]]]
        each element `factors` is a Sequence of int, representing
        (t_factor, h_factor, w_factor). Keys without an entry are left as-is.
    layout: str
        Axis layout of the stored data, e.g. 'NHWT'.

    Returns
    -------
    downsampled_data_dict: Dict[str, torch.Tensor]
        Modify on a deep copy of data_dict instead of directly modifying the
        original data_dict.
    """
    if factors_dict is None:
        factors_dict = {}
    if data_types is None:
        data_types = data_dict.keys()
    # Tensor copies so the caller's dict is never mutated in place.
    downsampled_data_dict = SEVIRDataLoader.data_dict_to_tensor(
        data_dict=data_dict,
        data_types=data_types)
    # Only the keys are needed here; values are taken from the copy above
    # (original code iterated .items() and never used the value).
    for key in data_dict:
        factors = factors_dict.get(key, None)
        if factors is None:
            continue
        # Normalize to 'NTHW' so time is axis 1 and (H, W) are the trailing axes.
        data = change_layout_torch(
            data=downsampled_data_dict[key],
            in_layout=layout,
            out_layout='NTHW')
        # Temporal downsampling: keep every t_factor-th frame.
        t_slice = [slice(None, None), ] * 4
        t_slice[1] = slice(None, None, factors[0])
        data = data[tuple(t_slice)]
        # Spatial downsampling: average pooling over (h_factor, w_factor) windows.
        data = avg_pool2d(
            input=data,
            kernel_size=(factors[1], factors[2]))
        downsampled_data_dict[key] = change_layout_torch(
            data=data,
            in_layout='NTHW',
            out_layout=layout)
    return downsampled_data_dict
| |
|
def _random_sample(self):
    """
    Draw `self.batch_size` sequences from random event/sequence positions.

    Returns
    -------
    ret_dict
        dict. ret_dict.keys() == self.data_types.
        If self.preprocess == False:
            ret_dict[imgt].shape == (batch_size, height, width, seq_len)
    """
    num_sampled = 0
    event_idx_list = nprand.randint(low=self.start_event_idx,
                                    high=self.end_event_idx,
                                    size=self.batch_size)
    seq_idx_list = nprand.randint(low=0,
                                  high=self.num_seq_per_event,
                                  size=self.batch_size)
    seq_slice_list = [slice(seq_idx * self.stride,
                            seq_idx * self.stride + self.seq_len)
                      for seq_idx in seq_idx_list]
    ret_dict = {}
    while num_sampled < self.batch_size:
        event = self._load_event_batch(event_idx=event_idx_list[num_sampled],
                                       event_batch_size=1)
        for imgt_idx, imgt in enumerate(self.data_types):
            # [0, ] keeps the batch axis so results can be concatenated.
            sampled_seq = event[imgt_idx][[0, ], :, :, seq_slice_list[num_sampled]]
            if imgt in ret_dict:
                ret_dict[imgt] = np.concatenate((ret_dict[imgt], sampled_seq),
                                                axis=0)
            else:
                ret_dict.update({imgt: sampled_seq})
        # BUGFIX: advance the sample counter. Without this, the while loop
        # never terminates and keeps re-loading the same event (compare the
        # increments in _sequent_sample and _idx_sample).
        num_sampled += 1
    return ret_dict
| |
|
def _sequent_sample(self):
    """
    Sample the next `self.batch_size` sequences in order, advancing the
    loader's cursor.

    Returns
    -------
    ret_dict: Dict
        `ret_dict.keys()` contains `self.data_types`.
        `ret_dict["mask"]` is a list of bool, indicating if the data entry is
        real or padded; it is None when every entry in the batch is real.
        If self.preprocess == False:
            ret_dict[imgt].shape == (batch_size, height, width, seq_len)
    """
    assert not self.use_up, 'Data loader used up! Reset it to reuse.'
    event_idx = self.curr_event_idx
    seq_idx = self.curr_seq_idx

    # Enumerate the (event, sequence) positions this batch will cover.
    positions = []
    for _ in range(self.batch_size):
        positions.append({'event_idx': event_idx, 'seq_idx': seq_idx})
        seq_idx += 1
        if seq_idx >= self.num_seq_per_event:
            event_idx += 1
            seq_idx = 0

    first_event_idx = positions[0]['event_idx']
    n_events = positions[-1]['event_idx'] - first_event_idx + 1

    # One contiguous load covering every event touched by this batch.
    event_batch = self._load_event_batch(event_idx=first_event_idx,
                                         event_batch_size=n_events)
    ret_dict = {"mask": []}
    any_padded = False
    for pos in positions:
        row = [pos['event_idx'] - first_event_idx, ]
        frame_window = slice(pos['seq_idx'] * self.stride,
                             pos['seq_idx'] * self.stride + self.seq_len)
        for imgt_idx, imgt in enumerate(self.data_types):
            sampled_seq = event_batch[imgt_idx][row, :, :, frame_window]
            if imgt not in ret_dict:
                ret_dict[imgt] = sampled_seq
            else:
                ret_dict[imgt] = np.concatenate((ret_dict[imgt], sampled_seq),
                                                axis=0)
        # Entries at or past end_event_idx are padding, not real data.
        is_real = pos['event_idx'] < self.end_event_idx
        if not is_real:
            any_padded = True
        ret_dict["mask"].append(is_real)

    if not any_padded:
        ret_dict["mask"] = None

    # Persist the cursor so the next call resumes where this one stopped.
    self.set_curr_event_idx(event_idx)
    self.set_curr_seq_idx(seq_idx)
    return ret_dict
| |
|
def _idx_sample(self, index):
    """
    Deterministically sample the batch at position `index`.

    Parameters
    ----------
    index
        The index of the batch to sample.

    Returns
    -------
    ret_dict
        dict. ret_dict.keys() == self.data_types.
        If self.preprocess == False:
            ret_dict[imgt].shape == (batch_size, height, width, seq_len)
    """
    # Translate the flat batch index into an (event, sequence) cursor.
    event_idx = (index * self.batch_size) // self.num_seq_per_event
    seq_idx = (index * self.batch_size) % self.num_seq_per_event

    positions = []
    for _ in range(self.batch_size):
        positions.append({'event_idx': event_idx, 'seq_idx': seq_idx})
        seq_idx += 1
        if seq_idx >= self.num_seq_per_event:
            event_idx += 1
            seq_idx = 0

    first_event_idx = positions[0]['event_idx']
    n_events = positions[-1]['event_idx'] - first_event_idx + 1

    # Load every event touched by this batch in one call.
    event_batch = self._load_event_batch(event_idx=first_event_idx,
                                         event_batch_size=n_events)
    ret_dict = {}
    for pos in positions:
        row = [pos['event_idx'] - first_event_idx, ]
        frame_window = slice(pos['seq_idx'] * self.stride,
                             pos['seq_idx'] * self.stride + self.seq_len)
        for imgt_idx, imgt in enumerate(self.data_types):
            sampled_seq = event_batch[imgt_idx][row, :, :, frame_window]
            if imgt not in ret_dict:
                ret_dict[imgt] = sampled_seq
            else:
                ret_dict[imgt] = np.concatenate((ret_dict[imgt], sampled_seq),
                                                axis=0)

    ret_dict = self.data_dict_to_tensor(data_dict=ret_dict,
                                        data_types=self.data_types)
    if self.preprocess:
        ret_dict = self.preprocess_data_dict(data_dict=ret_dict,
                                             data_types=self.data_types,
                                             layout=self.layout,
                                             rescale=self.rescale_method)
    if self.downsample_dict is not None:
        ret_dict = self.downsample_data_dict(data_dict=ret_dict,
                                             data_types=self.data_types,
                                             factors_dict=self.downsample_dict,
                                             layout=self.layout)
    return ret_dict
| |
|
| |
|
class SEVIRDataIterator():
    '''
    A wrapper s.t. it implements the function sample().
    Every argument in this class will be redirected to the inner SEVIRDataLoader object.
    If you expect a pythonic iterator, use SEVIRDataLoader instead.
    '''
    def __init__(self, **kwargs):
        self.loader = SEVIRDataLoader(**kwargs)
        # dict.get replaces the verbose `x if k in kwargs else default` form.
        self.sample_mode = kwargs.get('sample_mode', 'random')

    def reset(self):
        """Rewind the wrapped loader to its initial state."""
        self.loader.reset()

    def sample(self, batch_size=None):
        '''
        Return the next batch from the wrapped loader.
        The input param batch_size here is not used.
        In 'random' mode an exhausted loader is reset and retried once, so
        sampling can continue indefinitely; in other modes None signals
        exhaustion.
        '''
        out = next(self.loader, None)
        if out is None and self.sample_mode == 'random':
            self.loader.reset()
            out = next(self.loader, None)
        return out

    def __len__(self):
        """
        Used only when self.sample_mode == 'sequent'
        """
        return len(self.loader)
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
class Meteo(Dataset):
    """
    MeteoNet radar dataset backed by a single HDF5 file.

    Each item is a pair (input_frames, target_frames) split at `in_len`
    along the time axis, with pixel values scaled to [0, 1].
    """

    def __init__(self, data_path, img_size, type='train', trans=None, in_len=-1):
        super().__init__()
        # Raw pixel intensities span [0, 70]; divide by this to normalize.
        self.pixel_scale = 70.0
        self.data_path = data_path
        self.img_size = img_size
        self.in_len = in_len

        assert type in ['train', 'test', 'val']
        # 'val' items are stored under the 'test' group in the HDF5 file.
        self.type = 'test' if type == 'val' else type
        with h5py.File(data_path, 'r') as f:
            self.all_len = int(f[f'{self.type}_len'][()])
        if trans is None:
            self.transform = T.Compose([T.Resize((img_size, img_size))])
        else:
            self.transform = trans

    def __len__(self):
        return self.all_len

    def sample(self):
        """Return a uniformly random item from the dataset."""
        rand_index = np.random.randint(0, self.all_len)
        return self.__getitem__(rand_index)

    def __getitem__(self, index):
        # Re-open the file per access so the dataset stays safe to use
        # from multiple DataLoader worker processes.
        with h5py.File(self.data_path, 'r') as f:
            imgs = f[self.type][str(index)][()]

        frames = torch.from_numpy(imgs).float().squeeze()
        frames = frames / self.pixel_scale
        # Resize, then restore a channel axis: (T, 1, H, W).
        frames = self.transform(frames).unsqueeze(1)
        return frames[:self.in_len], frames[self.in_len:]
| | |
| |
|
def load_meteonet(batch_size, val_batch_size, in_len, train=False, num_workers=0, img_size=128):
    """
    Build MeteoNet dataloaders from METEO_FILE_DIR/meteo.h5.

    Returns (train_loader, valid_loader) when `train` is True, otherwise
    (None, test_loader).
    """
    h5_path = os.path.join(METEO_FILE_DIR, "meteo.h5")
    if not train:
        test_set = Meteo(h5_path, img_size, 'test', in_len=in_len)
        test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        return None, test_loader
    train_set = Meteo(h5_path, img_size, 'train', in_len=in_len)
    valid_set = Meteo(h5_path, img_size, 'val', in_len=in_len)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers)
    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=val_batch_size, shuffle=False, drop_last=True, num_workers=num_workers)
    return train_loader, valid_loader