Harry00 commited on
Commit
909b197
·
verified ·
1 Parent(s): 8983e24

Upload mle/energy.py

Browse files
Files changed (1) hide show
  1. mle/energy.py +347 -0
mle/energy.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Paysage d'Énergie Apprenant
3
+
4
+ La fonction d'énergie évalue la cohérence d'un état du système.
5
+ Elle doit être :
6
+ - Locale : les mises à jour ne dépendent que des voisins
7
+ - Apprenante : s'ajuste avec l'expérience
8
+ - Discriminative : basse énergie = cohérent, haute = incohérent
9
+ """
10
+
11
+ import numpy as np
12
+ from numba import njit, prange
13
+ from typing import Dict, List, Tuple, Optional, Set
14
+ import logging
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ VECTOR_SIZE = 4096
19
+
20
+
21
+ def compute_local_hamming_energy_fast(
22
+ state: np.ndarray,
23
+ neighbors: np.ndarray,
24
+ weights: np.ndarray,
25
+ ) -> float:
26
+ """
27
+ Énergie locale de Hamming : somme pondérée des distances aux voisins.
28
+ Version numpy vectorisée rapide.
29
+ """
30
+ N = neighbors.shape[0]
31
+ if N == 0:
32
+ return 0.0
33
+ # XOR vectorisé : state ^ neighbors[i]
34
+ xor_result = state.astype(np.uint8) ^ neighbors # (N, 4096)
35
+ # Somme par ligne
36
+ dists = np.sum(xor_result, axis=1)
37
+ return float(np.dot(weights, dists)) / N
38
+
39
+
40
+ def compute_bit_flip_delta_fast(
41
+ state: np.ndarray,
42
+ neighbors: np.ndarray,
43
+ weights: np.ndarray,
44
+ biases: np.ndarray,
45
+ ) -> np.ndarray:
46
+ """
47
+ Calcule le delta d'énergie pour chaque flip de bit possible.
48
+ Version numpy vectorisée.
49
+ """
50
+ N = neighbors.shape[0]
51
+ if N == 0:
52
+ return biases.copy()
53
+
54
+ # Pour chaque bit, calculer le changement si on le flippe
55
+ # dist_actuel = sum(weights[i] * (state[j] ^ neighbors[i,j]))
56
+ # dist_nouveau = sum(weights[i] * ((1-state[j]) ^ neighbors[i,j]))
57
+
58
+ state_expanded = state[np.newaxis, :] # (1, 4096)
59
+ # XOR actuel
60
+ xor_current = state_expanded ^ neighbors # (N, 4096)
61
+ # XOR après flip
62
+ flipped = 1 - state_expanded
63
+ xor_flipped = flipped ^ neighbors # (N, 4096)
64
+
65
+ # Delta = new_dist - old_dist pour chaque bit
66
+ delta_xor = xor_flipped.astype(np.float64) - xor_current.astype(np.float64) # (N, 4096)
67
+ # Pondération
68
+ weights_col = weights[:, np.newaxis] # (N, 1)
69
+ weighted_delta = delta_xor * weights_col # (N, 4096)
70
+
71
+ deltas = np.sum(weighted_delta, axis=0) # (4096,)
72
+ deltas += biases * (2.0 * (1 - state.astype(np.float64)) - 1.0)
73
+
74
+ return deltas
75
+
76
+
77
+ class EnergyLandscape:
78
+ """
79
+ Paysage d'énergie apprenant avec mises à jour locales.
80
+
81
+ Components:
82
+ - hamming_weight: poids de la distance de Hamming
83
+ - association_weight: poids des associations Hebbian
84
+ - structure_weight: poids de la cohérence structurelle (binding)
85
+ - biases: biais par bit (apprenant)
86
+ - local_weights: poids par voisin (apprenant)
87
+
88
+ Apprentissage:
89
+ - Renforcement des associations dans les états de basse énergie
90
+ - Affaiblissement dans les états instables
91
+ - Ajustement des biais locaux
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ hamming_weight: float = 1.0,
97
+ association_weight: float = 0.5,
98
+ structure_weight: float = 0.3,
99
+ bias_learning_rate: float = 0.01,
100
+ association_decay: float = 0.99,
101
+ energy_threshold_low: float = 500.0,
102
+ energy_threshold_high: float = 1500.0,
103
+ ):
104
+ self.hamming_weight = hamming_weight
105
+ self.association_weight = association_weight
106
+ self.structure_weight = structure_weight
107
+ self.bias_lr = bias_learning_rate
108
+ self.assoc_decay = association_decay
109
+ self.low_threshold = energy_threshold_low
110
+ self.high_threshold = energy_threshold_high
111
+
112
+ # Biases par bit (apprenant)
113
+ self.biases = np.zeros(VECTOR_SIZE, dtype=np.float64)
114
+
115
+ # Associations apprises : (id1, id2) -> strength
116
+ self.associations: Dict[Tuple[int, int], float] = {}
117
+
118
+ # Poids structurels pour binding
119
+ self.structure_weights: Dict[Tuple[int, int], float] = {}
120
+
121
+ # Historique pour apprentissage
122
+ self.energy_history: List[float] = []
123
+ self.min_energy_history: List[float] = []
124
+ self.stable_states: List[np.ndarray] = []
125
+
126
+ # Stats
127
+ self.n_updates = 0
128
+ self.total_energy = 0.0
129
+
130
+ def compute_energy(
131
+ self,
132
+ state: np.ndarray,
133
+ neighbor_vectors: np.ndarray,
134
+ neighbor_ids: List[int],
135
+ neighbor_metadata: Optional[List] = None,
136
+ ) -> float:
137
+ """
138
+ Calcule l'énergie totale d'un état.
139
+
140
+ Args:
141
+ state: vecteur courant (4096,) uint8
142
+ neighbor_vectors: (N, 4096) voisins actifs
143
+ neighbor_ids: IDs des voisins
144
+ neighbor_metadata: métadonnées optionnelles
145
+
146
+ Returns:
147
+ énergie scalaire (plus bas = plus cohérent)
148
+ """
149
+ N = len(neighbor_ids) if neighbor_ids else 0
150
+ if N == 0:
151
+ return float(np.sum(self.biases * (2.0 * state - 1.0)))
152
+
153
+ # Poids des voisins (apprenant)
154
+ weights = np.array([
155
+ self.associations.get(
156
+ tuple(sorted((-1, nid))), 0.5 # Default weight
157
+ ) + 0.5
158
+ for nid in neighbor_ids
159
+ ], dtype=np.float64)
160
+
161
+ # 1. Énergie de Hamming (vectorisée rapide)
162
+ hamming_energy = compute_local_hamming_energy_fast(
163
+ state, neighbor_vectors, weights
164
+ )
165
+ hamming_energy *= self.hamming_weight
166
+
167
+ # 2. Énergie d'association (Hebbian-like)
168
+ association_energy = 0.0
169
+ if len(neighbor_ids) > 1:
170
+ for i in range(len(neighbor_ids)):
171
+ for j in range(i+1, len(neighbor_ids)):
172
+ pair = tuple(sorted((neighbor_ids[i], neighbor_ids[j])))
173
+ strength = self.associations.get(pair, 0.0)
174
+ # Distance entre voisins : proches = faible énergie
175
+ dist = np.sum(neighbor_vectors[i] != neighbor_vectors[j])
176
+ association_energy += strength * dist
177
+
178
+ association_energy *= self.association_weight
179
+
180
+ # 3. Énergie de structure (binding)
181
+ structure_energy = 0.0
182
+ if len(neighbor_ids) > 1:
183
+ for i in range(len(neighbor_ids)):
184
+ for j in range(i+1, len(neighbor_ids)):
185
+ pair = tuple(sorted((neighbor_ids[i], neighbor_ids[j])))
186
+ struct_weight = self.structure_weights.get(pair, 0.0)
187
+ # Corrélation comme proxy de structure
188
+ corr = np.mean(neighbor_vectors[i] == neighbor_vectors[j])
189
+ structure_energy += struct_weight * (1.0 - corr)
190
+
191
+ structure_energy *= self.structure_weight
192
+
193
+ # 4. Biais local
194
+ bias_energy = float(np.sum(self.biases * (2.0 * state.astype(np.float64) - 1.0)))
195
+
196
+ total = hamming_energy + association_energy + structure_energy + bias_energy
197
+
198
+ self.energy_history.append(total)
199
+ if len(self.energy_history) > 1000:
200
+ self.energy_history = self.energy_history[-1000:]
201
+
202
+ return total
203
+
204
+ def get_bit_flip_deltas(
205
+ self,
206
+ state: np.ndarray,
207
+ neighbor_vectors: np.ndarray,
208
+ neighbor_ids: List[int],
209
+ ) -> np.ndarray:
210
+ """
211
+ Calcule le delta d'énergie pour chaque bit flip possible.
212
+ Utilisé par l'inférence pour trouver les meilleurs flips.
213
+
214
+ Returns:
215
+ deltas: (4096,) négatif = flip réduit l'énergie
216
+ """
217
+ N = len(neighbor_ids)
218
+ if N == 0:
219
+ # Juste les biais
220
+ return self.biases.copy()
221
+
222
+ weights = np.array([
223
+ self.associations.get(tuple(sorted((-1, nid))), 0.5) + 0.5
224
+ for nid in neighbor_ids
225
+ ], dtype=np.float64)
226
+
227
+ return compute_bit_flip_delta_fast(state, neighbor_vectors, weights, self.biases)
228
+
229
+ def update_from_state(
230
+ self,
231
+ state: np.ndarray,
232
+ neighbor_ids: List[int],
233
+ energy: float,
234
+ is_stable: bool = False,
235
+ ):
236
+ """
237
+ Met à jour le paysage d'énergie à partir d'un état observé.
238
+ Appelé après chaque itération d'inférence.
239
+
240
+ Règles:
241
+ - Basse énergie stable : renforce les associations
242
+ - Haute énergie instable : affaiblit ou réorganise
243
+ """
244
+ self.n_updates += 1
245
+ self.total_energy += energy
246
+
247
+ if len(neighbor_ids) < 2:
248
+ return
249
+
250
+ # Normalise l'énergie relative
251
+ if len(self.energy_history) > 20:
252
+ recent_mean = np.mean(self.energy_history[-20:])
253
+ recent_std = np.std(self.energy_history[-20:]) + 1e-8
254
+ normalized_energy = (energy - recent_mean) / recent_std
255
+ else:
256
+ normalized_energy = 0.0
257
+
258
+ # Détermine si l'état est cohérent ou incohérent
259
+ is_coherent = normalized_energy < -0.5 or (energy < self.low_threshold and is_stable)
260
+ is_incoherent = normalized_energy > 0.5 or energy > self.high_threshold
261
+
262
+ # Met à jour les associations
263
+ for i in range(len(neighbor_ids)):
264
+ for j in range(i+1, len(neighbor_ids)):
265
+ pair = tuple(sorted((neighbor_ids[i], neighbor_ids[j])))
266
+
267
+ if is_coherent:
268
+ # Renforce l'association (Hebbian-like)
269
+ current = self.associations.get(pair, 0.0)
270
+ self.associations[pair] = min(1.0, current + self.bias_lr * (1.0 - current))
271
+ elif is_incoherent:
272
+ # Affaiblit l'association (anti-Hebbian)
273
+ current = self.associations.get(pair, 0.0)
274
+ self.associations[pair] = max(-0.5, current - self.bias_lr * 2.0)
275
+
276
+ # Met à jour les biais locaux
277
+ state_float = state.astype(np.float64)
278
+ if is_coherent:
279
+ # Renforce les bits qui sont cohérents avec le voisinage
280
+ self.biases += self.bias_lr * (2.0 * state_float - 1.0) * 0.1
281
+ elif is_incoherent:
282
+ # Affaiblit les bits qui sont incohérents
283
+ self.biases -= self.bias_lr * (2.0 * state_float - 1.0) * 0.2
284
+
285
+ # Normalise les biais pour éviter la dérive
286
+ self.biases = np.clip(self.biases, -2.0, 2.0)
287
+
288
+ # Si stable et basse énergie, mémorise l'état
289
+ if is_stable and is_coherent:
290
+ self.stable_states.append(state.copy())
291
+ if len(self.stable_states) > 100:
292
+ self.stable_states = self.stable_states[-100:]
293
+
294
+ # Déduit les associations avec le temps
295
+ if self.n_updates % 100 == 0:
296
+ self._decay_associations()
297
+
298
+ def update_structure_weights(
299
+ self,
300
+ bound_state: np.ndarray,
301
+ component_ids: List[int],
302
+ coherence: float,
303
+ ):
304
+ """
305
+ Met à jour les poids structurels pour le binding.
306
+ """
307
+ for i in range(len(component_ids)):
308
+ for j in range(i+1, len(component_ids)):
309
+ pair = tuple(sorted((component_ids[i], component_ids[j])))
310
+ current = self.structure_weights.get(pair, 0.0)
311
+ # Mise à jour basée sur la cohérence
312
+ if coherence > 0.7:
313
+ self.structure_weights[pair] = min(1.0, current + self.bias_lr)
314
+ elif coherence < 0.3:
315
+ self.structure_weights[pair] = max(-0.5, current - self.bias_lr * 1.5)
316
+
317
+ def _decay_associations(self):
318
+ """Décroissance périodique des associations peu utilisées."""
319
+ to_remove = []
320
+ for pair, strength in self.associations.items():
321
+ decayed = strength * self.assoc_decay
322
+ if abs(decayed) < 0.01:
323
+ to_remove.append(pair)
324
+ else:
325
+ self.associations[pair] = decayed
326
+
327
+ for pair in to_remove:
328
+ del self.associations[pair]
329
+
330
+ def get_association_strength(self, id1: int, id2: int) -> float:
331
+ """Retourne la force d'association entre deux vecteurs."""
332
+ pair = tuple(sorted((id1, id2)))
333
+ return self.associations.get(pair, 0.0)
334
+
335
+ def get_stats(self) -> Dict:
336
+ n_assoc = len(self.associations)
337
+ n_struct = len(self.structure_weights)
338
+ recent_energy = np.mean(self.energy_history[-100:]) if self.energy_history else 0.0
339
+
340
+ return {
341
+ 'n_associations': n_assoc,
342
+ 'n_structure_weights': n_struct,
343
+ 'mean_bias': float(np.mean(np.abs(self.biases))),
344
+ 'recent_mean_energy': float(recent_energy),
345
+ 'n_updates': self.n_updates,
346
+ 'n_stable_states': len(self.stable_states),
347
+ }