| | import random |
| |
|
| |
|
| | def sample_future_length( |
| | range: tuple[int, int] | str = "gift_eval", |
| | total_length: int | None = None, |
| | ) -> int: |
| | """ |
| | Sample a forecast length. |
| | |
| | - If `range` is a tuple, uniformly sample in [min, max]. When `total_length` is |
| | provided, enforce a cap so the result is at most floor(0.45 * total_length). |
| | - If `range` is "gift_eval", sample from a pre-defined weighted set. When |
| | `total_length` is provided, filter out candidates greater than |
| | floor(0.45 * total_length) before sampling. |
| | """ |
| | |
| | cap: int | None = None |
| | if total_length is not None: |
| | cap = max(1, int(0.45 * int(total_length))) |
| |
|
| | if isinstance(range, tuple): |
| | min_len, max_len = range |
| | if cap is not None: |
| | effective_max_len = min(max_len, cap) |
| | |
| | if min_len > effective_max_len: |
| | return effective_max_len |
| | return random.randint(min_len, effective_max_len) |
| | return random.randint(min_len, max_len) |
| | elif range == "gift_eval": |
| | |
| | GIFT_EVAL_FORECAST_LENGTHS = { |
| | 48: 5, |
| | 720: 38, |
| | 480: 38, |
| | 30: 3, |
| | 300: 16, |
| | 8: 2, |
| | 120: 3, |
| | 450: 8, |
| | 80: 8, |
| | 12: 2, |
| | 900: 10, |
| | 180: 3, |
| | 600: 10, |
| | 60: 3, |
| | 210: 3, |
| | 195: 3, |
| | 140: 3, |
| | 130: 3, |
| | 14: 1, |
| | 18: 1, |
| | 13: 1, |
| | 6: 1, |
| | } |
| |
|
| | lengths = list(GIFT_EVAL_FORECAST_LENGTHS.keys()) |
| | weights = list(GIFT_EVAL_FORECAST_LENGTHS.values()) |
| |
|
| | if cap is not None: |
| | filtered = [ |
| | (length_candidate, weight) |
| | for length_candidate, weight in zip(lengths, weights, strict=True) |
| | if length_candidate <= cap |
| | ] |
| | if filtered: |
| | lengths, weights = zip(*filtered, strict=True) |
| | lengths = list(lengths) |
| | weights = list(weights) |
| |
|
| | return random.choices(lengths, weights=weights)[0] |
| | else: |
| | raise ValueError(f"Invalid range: {range}") |
| |
|