import bisect
import json
import math
import os
import pickle

import numpy as np
import pandas as pd

from nowcasting import image
from nowcasting.config import cfg
from nowcasting.mask import *
from nowcasting.utils import *
| |
|
def encode_month(month):
    """Encode months as points on the unit circle.

    The cyclic (cos, sin) encoding keeps December and January close to
    each other, unlike a plain integer encoding.

    Parameters
    ----------
    month : np.ndarray
        (...,) int, between 1 and 12

    Returns
    -------
    ret : np.ndarray
        (..., 2) float32, ``ret[..., 0] = cos(angle)``, ``ret[..., 1] = sin(angle)``
    """
    angle = 2 * np.pi * month / 12.0
    return np.stack((np.cos(angle), np.sin(angle)), axis=-1).astype(np.float32)
| |
|
| |
|
def decode_month(code):
    """Recover the integer month from its circular (cos, sin) encoding.

    Parameters
    ----------
    code : np.ndarray
        (..., 2) float, as produced by :func:`encode_month`

    Returns
    -------
    month : np.ndarray
        (...,) int, between 1 and 12
    """
    assert code.shape[-1] == 2
    # arccos alone only covers [0, pi]; the sign of the sine component
    # disambiguates which half of the circle the angle lies in.
    base_angle = np.arccos(code[..., 0])
    angle = np.where(code[..., 1] >= 0, base_angle, 2 * np.pi - base_angle)
    return np.round(angle / (2.0 * np.pi) * 12.0).astype(int)
| |
|
| |
|
def get_valid_datetime_set():
    """Load the pickled set of valid datetimes of the HKO-7 dataset.

    Returns
    -------
    valid_datetime_set : set
        Set-like object of datetimes loaded from ``cfg.HKO_VALID_DATETIME_PATH``.
    """
    # BUGFIX: `pickle` was used here without ever being imported, so this
    # function raised NameError when called (now imported at module level).
    # Also use a context manager so the file handle is closed promptly
    # instead of being left to the garbage collector.
    with open(cfg.HKO_VALID_DATETIME_PATH, 'rb') as f:
        # NOTE: pickle is only safe because this is a trusted local data file.
        return pickle.load(f)
| |
|
| |
|
def get_exclude_mask():
    """Load the static exclude mask shipped with the HKO-7 data.

    Returns
    -------
    exclude_mask : np.ndarray
        The ``exclude_mask`` array stored in ``mask_dat.npz`` under the
        dataset base path.
    """
    mask_path = os.path.join(cfg.HKO_DATA_BASE_PATH, 'mask_dat.npz')
    with np.load(mask_path) as dat:
        # Slice-copy so the array survives closing the npz file handle.
        return dat['exclude_mask'][:]
| |
|
| |
|
def convert_datetime_to_filepath(date_time):
    """Convert a datetime to the path of the corresponding radar PNG.

    Layout: ``<HKO_PNG_PATH>/YYYY/MM/DD/RADyymmddHHMM00.png``.

    Parameters
    ----------
    date_time : datetime.datetime

    Returns
    -------
    ret : str
    """
    filename = 'RAD%02d%02d%02d%02d%02d00.png' % (date_time.year - 2000,
                                                  date_time.month,
                                                  date_time.day,
                                                  date_time.hour,
                                                  date_time.minute)
    return os.path.join(cfg.HKO_PNG_PATH,
                        "%04d" % date_time.year,
                        "%02d" % date_time.month,
                        "%02d" % date_time.day,
                        filename)
| |
|
| |
|
def convert_datetime_to_maskpath(date_time):
    """Convert a datetime to the path of the corresponding noise-mask file.

    Layout: ``<HKO_MASK_PATH>/YYYY/MM/DD/RADyymmddHHMM00.mask``.

    Parameters
    ----------
    date_time : datetime.datetime

    Returns
    -------
    ret : str
    """
    filename = 'RAD%02d%02d%02d%02d%02d00.mask' % (date_time.year - 2000,
                                                   date_time.month,
                                                   date_time.day,
                                                   date_time.hour,
                                                   date_time.minute)
    return os.path.join(cfg.HKO_MASK_PATH,
                        "%04d" % date_time.year,
                        "%02d" % date_time.month,
                        "%02d" % date_time.day,
                        filename)
| |
|
| |
|
class HKOSimpleBuffer(object):
    """A read-ahead buffer holding a window of consecutive HKO frames.

    Frames are loaded `max_buffer_length` at a time; :meth:`get` transparently
    reloads the window when the requested timestamps fall outside of it.

    Parameters
    ----------
    df : pandas.Series or pandas.DataFrame
        Indexed by datetime; the index enumerates the available frames.
    max_buffer_length : int
        Number of consecutive frames kept in memory at once.
    width : int
    height : int
    """
    def __init__(self, df, max_buffer_length, width, height):
        self._df = df
        self._max_buffer_length = max_buffer_length
        assert self._df.size > self._max_buffer_length
        self._width = width
        self._height = height

    def reset(self):
        """Point the buffer at the first `max_buffer_length` frames and load them."""
        self._datetime_keys = self._df.index[:self._max_buffer_length]
        self._load()

    def _load(self):
        """Read the frames for ``self._datetime_keys`` into memory."""
        paths = [convert_datetime_to_filepath(ts) for ts in self._datetime_keys]
        self._frame_dat = image.quick_read_frames(path_list=paths,
                                                  im_h=self._height,
                                                  im_w=self._width,
                                                  grayscale=True)
        # BUGFIX: reshape by the number of frames actually loaded.  The old
        # code reshaped by self._max_buffer_length, which raises when the
        # window is truncated at the end of the dataframe.
        self._frame_dat = self._frame_dat.reshape((self._datetime_keys.size, 1,
                                                   self._height, self._width))
        self._noise_mask_dat = np.zeros((self._datetime_keys.size, 1,
                                         self._height, self._width),
                                        dtype=np.uint8)

    def get(self, timestamps):
        """Return (frames, noise masks) covering ``timestamps[0]..timestamps[-1]``.

        Parameters
        ----------
        timestamps : sequence of datetimes
            Must be sorted ascending and present in ``df.index``.

        Returns
        -------
        frame_dat : np.ndarray
            Shape (clip_len, 1, height, width).
        noise_mask_dat : np.ndarray
            Shape (clip_len, 1, height, width).
        """
        if not (timestamps[0] in self._datetime_keys and timestamps[-1] in self._datetime_keys):
            # BUGFIX: get_loc already returns the integer position.  The old
            # code indexed back into df.index, producing a Timestamp, which
            # made the `+ max_buffer_length` arithmetic and the slice below
            # invalid.
            read_begin_ind = self._df.index.get_loc(timestamps[0])
            read_end_ind = min(read_begin_ind + self._max_buffer_length, self._df.size)
            assert self._df.index[read_end_ind - 1] >= timestamps[-1]
            self._datetime_keys = self._df.index[read_begin_ind:read_end_ind]
            self._load()
        begin_ind = self._datetime_keys.get_loc(timestamps[0])
        end_ind = self._datetime_keys.get_loc(timestamps[-1]) + 1
        return self._frame_dat[begin_ind:end_ind, :, :, :],\
               self._noise_mask_dat[begin_ind:end_ind, :, :, :]
| |
|
| |
|
def pad_hko_dat(frame_dat, mask_dat, batch_size):
    """Zero-pad a minibatch along the batch axis (axis 1) to `batch_size`.

    Parameters
    ----------
    frame_dat : np.ndarray
        Shape (seq_len, valid_batch_size, C, H, W).
    mask_dat : np.ndarray
        Shape (seq_len, valid_batch_size, C, H, W).
    batch_size : int
        The desired batch size.

    Returns
    -------
    ret_frame_dat : np.ndarray
        Shape (seq_len, batch_size, C, H, W); padded entries are zero.
    ret_mask_dat : np.ndarray
        Shape (seq_len, batch_size, C, H, W); padded entries are zero/False.
    valid_batch_size : int
        Number of real (non-padding) samples.
    """
    valid_batch_size = frame_dat.shape[1]
    if valid_batch_size >= batch_size:
        # Nothing to pad.
        return frame_dat, mask_dat, batch_size
    ret_frame_dat = np.zeros(shape=frame_dat.shape[:1] + (batch_size,) + frame_dat.shape[2:],
                             dtype=frame_dat.dtype)
    ret_mask_dat = np.zeros(shape=mask_dat.shape[:1] + (batch_size,) + mask_dat.shape[2:],
                            dtype=mask_dat.dtype)
    ret_frame_dat[:, :valid_batch_size, ...] = frame_dat
    # BUGFIX: slice the mask by the mask's own batch dimension; the old code
    # used frame_dat.shape[1], silently assuming the two always match.
    ret_mask_dat[:, :mask_dat.shape[1], ...] = mask_dat
    return ret_frame_dat, ret_mask_dat, valid_batch_size
| |
|
| |
|
# Loaded once at import time: static pixels to exclude (radar artifacts etc.).
_exclude_mask = get_exclude_mask()


def precompute_mask(img):
    """Compute the validity mask of a radar image.

    A pixel is valid when it is not in the static exclude mask and its
    rainfall intensity is not in the low-noise band ``(0, threshold)``.

    Parameters
    ----------
    img : np.ndarray
        Radar frame(s); uint8 images are compared against the threshold
        rescaled to the 0-255 range.

    Returns
    -------
    mask : np.ndarray
        bool, same shape as `img`; True marks valid pixels.
    """
    if img.dtype == np.uint8:
        threshold = round(cfg.HKO.ITERATOR.FILTER_RAINFALL_THRESHOLD * 255.0)
    else:
        threshold = cfg.HKO.ITERATOR.FILTER_RAINFALL_THRESHOLD
    # Start from the complement of the exclude mask, broadcast to img's shape.
    keep = (1 - _exclude_mask).astype(bool)
    mask = np.broadcast_to(keep, img.shape).copy()
    # Drop pixels with small but non-zero intensity (likely noise).
    mask[(img > 0) & (img < threshold)] = False
    return mask
| |
|
| |
|
class HKOIterator(object):
    """The iterator for the HKO-7 dataset.

    Yields clips of `seq_len` consecutive base-frequency radar frames plus
    their noise masks, either at random positions ("random" mode) or
    sequentially with a fixed stride ("sequent" mode).
    """
    def __init__(self, pd_path, sample_mode, seq_len=30,
                 max_consecutive_missing=2, begin_ind=None, end_ind=None,
                 stride=None, width=None, height=None, base_freq='6min'):
        """Random sample: sample a random clip that will not violate the max_missing frame_num criteria
        Sequent sample: sample a clip from the beginning of the time.
        Everytime, the clips from {T_begin, T_begin + 6min, ..., T_begin + (seq_len-1) * 6min} will be used
        The begin datetime will move forward by adding stride: T_begin += 6min * stride
        Once the clips violates the maximum missing number criteria, the starting
        point will be moved to the next datetime that does not violate the missing_frame criteria

        Parameters
        ----------
        pd_path : str
            path of the saved pandas dataframe
        sample_mode : str
            Can be "random" or "sequent"
        sequence_len : int
        seq_len : int
            Length of each sampled clip
        max_consecutive_missing : int
            The maximum consecutive missing frames
        begin_ind : int
            Index of the begin frame
        end_ind : int
            Index of the end frame
        stride : int or None, optional
            Required in "sequent" mode: number of base periods to advance
            after each successfully sampled clip
        width : int or None, optional
        height : int or None, optional
        base_freq : str, optional
        """
        if width is None:
            width = cfg.HKO.ITERATOR.WIDTH
        if height is None:
            height = cfg.HKO.ITERATOR.HEIGHT
        self._df = pd.read_pickle(pd_path)
        self.set_begin_end(begin_ind=begin_ind, end_ind=end_ind)
        # Frozen set of every timestamp for O(1) membership tests.
        # IMPROVED: build it directly from the index instead of indexing
        # element-by-element in a Python loop.
        self._df_index_set = frozenset(self._df.index)
        self._exclude_mask = get_exclude_mask()
        self._seq_len = seq_len
        self._width = width
        self._height = height
        self._stride = stride
        self._max_consecutive_missing = max_consecutive_missing
        self._base_freq = base_freq
        self._base_time_delta = pd.Timedelta(base_freq)
        assert sample_mode in ["random", "sequent"], "Sample mode=%s is not supported" %sample_mode
        self.sample_mode = sample_mode
        if sample_mode == "sequent":
            assert self._stride is not None
            self._current_datetime = self.begin_time
            # The sequent-mode buffer is sized at _buffer_mult times the span
            # of the requested clips to amortize disk reads.
            self._buffer_mult = 6
            self._buffer_datetime_keys = None
            self._buffer_frame_dat = None
            self._buffer_mask_dat = None
        else:
            self._max_buffer_length = None

    def set_begin_end(self, begin_ind=None, end_ind=None):
        """Set the inclusive [begin, end] frame index range of the iterator."""
        self._begin_ind = 0 if begin_ind is None else begin_ind
        self._end_ind = self.total_frame_num - 1 if end_ind is None else end_ind

    @property
    def total_frame_num(self):
        # Total number of frames recorded in the dataframe.
        return self._df.size

    @property
    def begin_time(self):
        # Timestamp of the first frame in the configured range.
        return self._df.index[self._begin_ind]

    @property
    def end_time(self):
        # Timestamp of the last frame in the configured range.
        return self._df.index[self._end_ind]

    @property
    def use_up(self):
        """Whether the iterator is exhausted (never True in "random" mode)."""
        if self.sample_mode == "random":
            return False
        else:
            return self._current_datetime > self.end_time

    def _next_exist_timestamp(self, timestamp):
        """Return the first dataframe timestamp strictly after `timestamp`,
        or None if no such timestamp exists."""
        next_ind = bisect.bisect_right(self._df.index, timestamp)
        if next_ind >= self._df.size:
            return None
        else:
            # IMPROVED: reuse the position computed above; the old code
            # redundantly ran bisect_right a second time.
            return self._df.index[next_ind]

    def _is_valid_clip(self, datetime_clip):
        """Check if the given datetime_clip is valid

        A clip is valid when no run of consecutive missing frames exceeds
        `max_consecutive_missing` and at least one frame exists.

        Parameters
        ----------
        datetime_clip : sequence of datetimes

        Returns
        -------
        ret : bool
        """
        missing_count = 0
        for i in range(len(datetime_clip)):
            if datetime_clip[i] not in self._df_index_set:
                missing_count += 1
                if missing_count > self._max_consecutive_missing or\
                        missing_count >= len(datetime_clip):
                    return False
            else:
                # An existing frame resets the consecutive-missing counter.
                missing_count = 0
        return True

    def _load_frames(self, datetime_clips):
        """Load the frames and noise masks of the given clips.

        Parameters
        ----------
        datetime_clips : list of DatetimeIndex
            Each element has length `seq_len`.

        Returns
        -------
        frame_dat : np.ndarray
            (seq_len, batch_size, 1, height, width) uint8; missing frames stay 0.
        mask_dat : np.ndarray
            (seq_len, batch_size, 1, height, width) bool; missing frames stay False.
        """
        assert isinstance(datetime_clips, list)
        for clip in datetime_clips:
            assert len(clip) == self._seq_len
        batch_size = len(datetime_clips)
        frame_dat = np.zeros((self._seq_len, batch_size, 1, self._height, self._width),
                             dtype=np.uint8)
        mask_dat = np.zeros((self._seq_len, batch_size, 1, self._height, self._width),
                            dtype=bool)
        if self.sample_mode == "random":
            # Random mode: read only the frames that exist, then scatter them
            # into place with fancy indexing.
            paths = []
            mask_paths = []
            hit_inds = []
            miss_inds = []
            for i in range(self._seq_len):
                for j in range(batch_size):
                    timestamp = datetime_clips[j][i]
                    if timestamp in self._df_index_set:
                        paths.append(convert_datetime_to_filepath(datetime_clips[j][i]))
                        mask_paths.append(convert_datetime_to_maskpath(datetime_clips[j][i]))
                        hit_inds.append([i, j])
                    else:
                        miss_inds.append([i, j])
            hit_inds = np.array(hit_inds, dtype=int)
            all_frame_dat = image.quick_read_frames(path_list=paths,
                                                    im_h=self._height,
                                                    im_w=self._width,
                                                    grayscale=True)
            all_mask_dat = quick_read_masks(mask_paths)
            frame_dat[hit_inds[:, 0], hit_inds[:, 1], :, :, :] = all_frame_dat
            mask_dat[hit_inds[:, 0], hit_inds[:, 1], :, :, :] = all_mask_dat
        else:
            # Sequent mode: keep a read-ahead buffer and reload it only when a
            # requested timestamp falls outside the buffered window.
            # Initialize the extrema to values guaranteed to be replaced.
            first_timestamp = datetime_clips[-1][-1]
            last_timestamp = datetime_clips[0][0]
            for i in range(self._seq_len):
                for j in range(batch_size):
                    timestamp = datetime_clips[j][i]
                    if timestamp in self._df_index_set:
                        first_timestamp = min(first_timestamp, timestamp)
                        last_timestamp = max(last_timestamp, timestamp)
            if self._buffer_datetime_keys is None or\
                    not (first_timestamp in self._buffer_datetime_keys
                         and last_timestamp in self._buffer_datetime_keys):
                read_begin_ind = self._df.index.get_loc(first_timestamp)
                read_end_ind = self._df.index.get_loc(last_timestamp) + 1
                # Read _buffer_mult times the requested span (clamped to the
                # end of the dataframe) to amortize future calls.
                read_end_ind = min(read_begin_ind +
                                   self._buffer_mult * (read_end_ind - read_begin_ind),
                                   self._df.size)
                self._buffer_datetime_keys = self._df.index[read_begin_ind:read_end_ind]
                # Load all the frames of the buffer in one batched read.
                paths = []
                mask_paths = []
                for i in range(self._buffer_datetime_keys.size):
                    paths.append(convert_datetime_to_filepath(self._buffer_datetime_keys[i]))
                    mask_paths.append(convert_datetime_to_maskpath(self._buffer_datetime_keys[i]))
                self._buffer_frame_dat = image.quick_read_frames(path_list=paths,
                                                                 im_h=self._height,
                                                                 im_w=self._width,
                                                                 grayscale=True)
                self._buffer_mask_dat = quick_read_masks(mask_paths)
            for i in range(self._seq_len):
                for j in range(batch_size):
                    timestamp = datetime_clips[j][i]
                    if timestamp in self._df_index_set:
                        assert timestamp in self._buffer_datetime_keys
                        ind = self._buffer_datetime_keys.get_loc(timestamp)
                        frame_dat[i, j, :, :, :] = self._buffer_frame_dat[ind, :, :, :]
                        mask_dat[i, j, :, :, :] = self._buffer_mask_dat[ind, :, :, :]
        return frame_dat, mask_dat

    def reset(self, begin_ind=None, end_ind=None):
        """Restart sequent iteration from the given (or default) range."""
        assert self.sample_mode == "sequent"
        self.set_begin_end(begin_ind=begin_ind, end_ind=end_ind)
        self._current_datetime = self.begin_time

    def random_reset(self):
        """Restart sequent iteration from a random begin index.

        The begin index leaves at least ``5 * seq_len`` frames before the end
        of the dataframe.
        """
        assert self.sample_mode == "sequent"
        self.set_begin_end(begin_ind=np.random.randint(0,
                                                       self.total_frame_num -
                                                       5 * self._seq_len),
                           end_ind=None)
        self._current_datetime = self.begin_time

    def check_new_start(self):
        """Whether the next sampled clip starts a new contiguous sequence."""
        assert self.sample_mode == "sequent"
        datetime_clip = pd.date_range(start=self._current_datetime,
                                      periods=self._seq_len,
                                      freq=self._base_freq)
        if self._is_valid_clip(datetime_clip):
            # Valid clip: only a new start if we are at the very beginning.
            return self._current_datetime == self.begin_time
        else:
            # Invalid clip forces a jump, which breaks continuity.
            return True

    def sample(self, batch_size, only_return_datetime=False):
        """Sample a minibatch from the hko7 dataset based on the given type and pd_file

        Parameters
        ----------
        batch_size : int
            Batch size
        only_return_datetime : bool
            Whether to only return the datetimes (sequent mode only)
        Returns
        -------
        frame_dat : np.ndarray
            Shape: (seq_len, valid_batch_size, 1, height, width)
        mask_dat : np.ndarray
            Shape: (seq_len, valid_batch_size, 1, height, width)
        datetime_clips : list
            length should be valid_batch_size
        new_start : bool
            Only meaningful when batch_size == 1 in sequent mode; None otherwise
        """
        if self.sample_mode == 'sequent':
            if self.use_up:
                raise ValueError("The HKOIterator has been used up!")
            datetime_clips = []
            new_start = False
            for i in range(batch_size):
                while not self.use_up:
                    datetime_clip = pd.date_range(start=self._current_datetime,
                                                  periods=self._seq_len,
                                                  freq=self._base_freq)
                    if self._is_valid_clip(datetime_clip):
                        new_start = new_start or (self._current_datetime == self.begin_time)
                        datetime_clips.append(datetime_clip)
                        self._current_datetime += self._stride * self._base_time_delta
                        break
                    else:
                        # Skip ahead to the next timestamp that actually exists.
                        new_start = True
                        self._current_datetime =\
                            self._next_exist_timestamp(timestamp=self._current_datetime)
                        if self._current_datetime is None:
                            # This iterator is exhausted: move past end_time so
                            # `use_up` becomes True.
                            self._current_datetime = self.end_time + self._base_time_delta
                            break
                        continue
            # new_start is only well-defined for single-sample batches.
            new_start = None if batch_size != 1 else new_start
            if only_return_datetime:
                return datetime_clips, new_start
        else:
            assert only_return_datetime is False
            datetime_clips = []
            new_start = None
            for i in range(batch_size):
                # Rejection-sample random start positions until the clip is valid.
                while True:
                    rand_ind = np.random.randint(0, self._df.size, 1)[0]
                    random_datetime = self._df.index[rand_ind]
                    datetime_clip = pd.date_range(start=random_datetime,
                                                  periods=self._seq_len,
                                                  freq=self._base_freq)
                    if self._is_valid_clip(datetime_clip):
                        datetime_clips.append(datetime_clip)
                        break
        frame_dat, mask_dat = self._load_frames(datetime_clips=datetime_clips)
        return frame_dat, mask_dat, datetime_clips, new_start
| |
|
| | |
if __name__ == '__main__':
    # Smoke test and benchmark for HKOIterator:
    #   1) profile and time random/sequent sampling throughput (FPS),
    #   2) round-trip the month encoding,
    #   3) render sampled clips to .mp4 movies for visual inspection.
    np.random.seed(123)
    import time
    import cProfile, pstats
    from nowcasting.config import cfg
    from nowcasting.helpers.visualization import save_hko_gif, save_hko_movie

    minibatch_size = 32
    seq_len = 30
    # Random sampling over the rainy training split.
    train_hko_iter = HKOIterator(pd_path=cfg.HKO_PD.RAINY_TRAIN,
                                 sample_mode="random",
                                 seq_len=seq_len)
    # Sequential sampling (stride 5 base periods) over validation/test splits.
    valid_hko_iter = HKOIterator(pd_path=cfg.HKO_PD.RAINY_VALID,
                                 sample_mode="sequent",
                                 seq_len=seq_len,
                                 stride=5)
    test_hko_iter = HKOIterator(pd_path=cfg.HKO_PD.RAINY_TEST,
                                sample_mode="sequent",
                                seq_len=seq_len,
                                stride=5)

    # Profile `repeat_time` random-sample calls with cProfile and report the
    # top-20 cumulative-time entries plus the frames-per-second throughput.
    repeat_time = 3
    pr = cProfile.Profile()
    pr.enable()
    begin = time.time()
    for i in range(repeat_time):
        sample_sequence, sample_mask, sample_datetime_clips, new_start =\
            train_hko_iter.sample(batch_size=minibatch_size)
    end = time.time()
    pr.disable()
    ps = pstats.Stats(pr).sort_stats('cumulative')
    ps.print_stats(20)
    print("Train Data Sample FPS: %f" % (minibatch_size * seq_len
                                         * repeat_time / float(end - begin)))

    # Time (without profiling) the sequent-mode iterators.
    begin = time.time()
    for i in range(repeat_time):
        sample_sequence, sample_mask, sample_datetimes, new_start =\
            valid_hko_iter.sample(batch_size=minibatch_size)
    end = time.time()
    print("Valid Data Sample FPS: %f" % (minibatch_size * seq_len
                                         * repeat_time / float(end - begin)))
    begin = time.time()
    for i in range(repeat_time):
        sample_sequence, sample_mask, sample_datetimes, new_start =\
            test_hko_iter.sample(batch_size=minibatch_size)
    end = time.time()
    print("Test Data Sample FPS: %f" %(minibatch_size * seq_len
                                       * repeat_time / float(end-begin)))
    # Sanity-check the cyclic month encoding: encode 1..12 and decode back.
    code = encode_month(np.arange(1, 13))
    month = decode_month(code)
    print(code)
    print(month.T)

    # Render 30 random training clips (first sample of each batch) as movies:
    # raw, rainfall-filtered, and the raw mask itself.
    # NOTE(review): train_time is assigned but never used afterwards —
    # presumably leftover from an earlier timing loop.
    train_time = 0
    for i in range(30):
        train_batch, train_mask, sample_datetimes, new_start = \
            train_hko_iter.sample(batch_size=minibatch_size)
        name_str = 'train_' + str(i) + '_' + sample_datetimes[0][0].strftime('%Y%m%d%H%M')
        save_hko_movie(train_batch[:, 0, 0, :, :],
                       sample_datetimes[0],
                       train_mask[:, 0, 0, :, :],
                       masked=False,
                       save_path=name_str + '.mp4')
        tic = time.time()
        save_hko_movie(train_batch[:, 0, 0, :, :],
                       sample_datetimes[0],
                       train_mask[:, 0, 0, :, :],
                       masked=True,
                       save_path=name_str + '_filtered.mp4')
        toc = time.time()
        # Visualize the boolean mask as a 0/255 grayscale movie.
        save_hko_movie(train_mask[:, 0, 0, :, :].astype(np.uint8) * 255,
                       sample_datetimes[0],
                       None,
                       masked=False,
                       save_path=name_str + '_mask.mp4')
        print('train, time:', toc - tic)

    # Walk the whole validation split sequentially, rendering each batch's
    # first sample until the iterator is used up.
    valid_time = 0
    while not valid_hko_iter.use_up:
        valid_batch, valid_mask, sample_datetimes, new_start =\
            valid_hko_iter.sample(batch_size=minibatch_size)
        if valid_batch.shape[1] == 0:
            # Empty batch: the iterator ran out mid-batch.
            break
        name_str = 'valid_' + str(valid_time) + '_' + sample_datetimes[0][0].strftime('%Y%m%d%H%M')
        save_hko_movie(valid_batch[:, 0, 0, :, :],
                       sample_datetimes[0],
                       valid_mask[:, 0, 0, :, :],
                       masked=False,
                       save_path=name_str + '.mp4')
        tic = time.time()
        save_hko_movie(valid_batch[:, 0, 0, :, :],
                       sample_datetimes[0],
                       valid_mask[:, 0, 0, :, :],
                       masked=True,
                       save_path=name_str + '_filtered.mp4')
        toc = time.time()
        save_hko_movie(valid_mask[:, 0, 0, :, :].astype(np.uint8) * 255,
                       sample_datetimes[0],
                       None,
                       masked=False,
                       save_path=name_str + '_mask.mp4')
        print('valid, time:', toc - tic)
        print(valid_batch.shape[1])
        valid_time += 1
    print(valid_time)
| |
|