dreamlessx commited on
Commit
e013072
·
verified ·
1 Parent(s): 9475836

Upload landmarkdiff/curriculum.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/curriculum.py +194 -0
landmarkdiff/curriculum.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Curriculum learning support for progressive training difficulty.
2
+
3
+ Implements a schedule that controls which training samples are used
4
+ at different stages of training, starting with easy examples (small
5
+ displacements) and gradually introducing harder ones.
6
+
7
+ Usage in training loop::
8
+
9
+ curriculum = TrainingCurriculum(
10
+ total_steps=100000,
11
+ warmup_fraction=0.1, # first 10% easy only
12
+ full_difficulty_at=0.5, # full dataset by 50%
13
+ )
14
+
15
+ # In training loop:
16
+ difficulty = curriculum.get_difficulty(global_step)
17
+ # Use difficulty to filter/weight samples
18
+
19
+ Or as a dataset wrapper::
20
+
21
+ dataset = CurriculumDataset(
22
+ base_dataset=SyntheticPairDataset(data_dir),
23
+ metadata_path=Path(data_dir) / "metadata.json",
24
+ total_steps=100000,
25
+ )
26
+ # Call dataset.set_step(global_step) each iteration
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import json
32
+ import math
33
+ from pathlib import Path
34
+
35
+ import numpy as np
36
+
37
+
38
+ class TrainingCurriculum:
39
+ """Schedule that maps training step to difficulty level [0, 1].
40
+
41
+ Difficulty 0 = easiest (smallest displacements, lowest intensity).
42
+ Difficulty 1 = full dataset (all difficulties).
43
+
44
+ The schedule uses a cosine ramp:
45
+ - During warmup: difficulty = 0 (easy only)
46
+ - warmup → full_difficulty: cosine ramp from 0 → 1
47
+ - After full_difficulty: difficulty = 1 (full dataset)
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ total_steps: int,
53
+ warmup_fraction: float = 0.1,
54
+ full_difficulty_at: float = 0.5,
55
+ ):
56
+ self.total_steps = total_steps
57
+ self.warmup_steps = int(total_steps * warmup_fraction)
58
+ self.full_steps = int(total_steps * full_difficulty_at)
59
+
60
+ def get_difficulty(self, step: int) -> float:
61
+ """Get difficulty level [0, 1] for the given training step."""
62
+ if step < self.warmup_steps:
63
+ return 0.0
64
+ if step >= self.full_steps:
65
+ return 1.0
66
+ progress = (step - self.warmup_steps) / max(1, self.full_steps - self.warmup_steps)
67
+ return 0.5 * (1 - math.cos(math.pi * progress))
68
+
69
+ def should_include(
70
+ self,
71
+ step: int,
72
+ sample_difficulty: float,
73
+ rng: np.random.Generator | None = None,
74
+ ) -> bool:
75
+ """Whether to include a sample of the given difficulty at this step.
76
+
77
+ Uses probabilistic inclusion so harder samples gradually appear.
78
+
79
+ Args:
80
+ step: Current training step.
81
+ sample_difficulty: Difficulty of the sample [0, 1].
82
+ rng: Random number generator for stochastic inclusion.
83
+
84
+ Returns:
85
+ True if sample should be used.
86
+ """
87
+ curr_difficulty = self.get_difficulty(step)
88
+ if sample_difficulty <= curr_difficulty:
89
+ return True
90
+ # Stochastic inclusion for samples slightly above threshold
91
+ if rng is None:
92
+ rng = np.random.default_rng()
93
+ overshoot = sample_difficulty - curr_difficulty
94
+ include_prob = max(0, 1.0 - overshoot * 5) # drops off quickly
95
+ return rng.random() < include_prob
96
+
97
+
98
+ class ProcedureCurriculum:
99
+ """Procedure-aware curriculum that adjusts per-procedure weights.
100
+
101
+ Some procedures are inherently harder (e.g., orthognathic with large
102
+ deformations). This curriculum increases their weight over training.
103
+ """
104
+
105
+ # Difficulty ranking (0=easiest, 1=hardest)
106
+ DEFAULT_PROCEDURE_DIFFICULTY = {
107
+ "blepharoplasty": 0.3, # small, localized changes
108
+ "rhinoplasty": 0.5, # moderate, central face
109
+ "rhytidectomy": 0.7, # large, affects face shape
110
+ "orthognathic": 0.9, # largest deformations
111
+ }
112
+
113
+ def __init__(
114
+ self,
115
+ total_steps: int,
116
+ procedure_difficulty: dict[str, float] | None = None,
117
+ warmup_fraction: float = 0.1,
118
+ ):
119
+ self.curriculum = TrainingCurriculum(total_steps, warmup_fraction)
120
+ self.proc_difficulty = procedure_difficulty or self.DEFAULT_PROCEDURE_DIFFICULTY
121
+
122
+ def get_weight(self, step: int, procedure: str) -> float:
123
+ """Get sampling weight for a procedure at the given step.
124
+
125
+ Returns a value in [0.1, 1.0] — never fully excludes any procedure.
126
+ """
127
+ difficulty = self.get_difficulty(step)
128
+ proc_diff = self.proc_difficulty.get(procedure, 0.5)
129
+
130
+ if proc_diff <= difficulty:
131
+ return 1.0
132
+ # Reduce weight for too-hard procedures
133
+ return max(0.1, 1.0 - (proc_diff - difficulty) * 2)
134
+
135
+ def get_difficulty(self, step: int) -> float:
136
+ return self.curriculum.get_difficulty(step)
137
+
138
+ def get_procedure_weights(self, step: int) -> dict[str, float]:
139
+ """Get all procedure weights at the given step."""
140
+ return {
141
+ proc: self.get_weight(step, proc)
142
+ for proc in self.proc_difficulty
143
+ }
144
+
145
+
146
+ def compute_sample_difficulty(
147
+ metadata_path: str | Path,
148
+ displacement_model_path: str | Path | None = None,
149
+ ) -> dict[str, float]:
150
+ """Compute difficulty scores for each sample in the dataset.
151
+
152
+ Difficulty is based on:
153
+ 1. Displacement intensity (from metadata)
154
+ 2. Procedure difficulty
155
+ 3. Source type (real > synthetic)
156
+
157
+ Returns:
158
+ Dict mapping sample prefix to difficulty score [0, 1].
159
+ """
160
+ with open(metadata_path) as f:
161
+ meta = json.load(f)
162
+
163
+ pairs = meta.get("pairs", {})
164
+ difficulties = {}
165
+
166
+ proc_base = {
167
+ "blepharoplasty": 0.2,
168
+ "rhinoplasty": 0.4,
169
+ "rhytidectomy": 0.6,
170
+ "orthognathic": 0.8,
171
+ "unknown": 0.5,
172
+ }
173
+
174
+ source_bonus = {
175
+ "synthetic": 0.0,
176
+ "synthetic_v3": 0.1, # realistic displacements slightly harder
177
+ "real": 0.2, # real data hardest
178
+ "augmented": 0.0,
179
+ }
180
+
181
+ for prefix, info in pairs.items():
182
+ proc = info.get("procedure", "unknown")
183
+ source = info.get("source", "synthetic")
184
+ intensity = info.get("intensity", 1.0)
185
+
186
+ # Combine factors
187
+ base = proc_base.get(proc, 0.5)
188
+ src = source_bonus.get(source, 0.0)
189
+ # Intensity scaling (higher intensity = harder)
190
+ int_factor = min(1.0, intensity / 1.5) * 0.2
191
+
192
+ difficulties[prefix] = min(1.0, base + src + int_factor)
193
+
194
+ return difficulties