File size: 3,736 Bytes
b0e88cf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 | # EVOLVE-BLOCK-START
import jax
import jax.numpy as jnp
from dataclasses import dataclass
import numpy as np
import tqdm
@dataclass
class Hyperparameters:
    """Settings that control the simulated-annealing search.

    The field order below is also the positional order accepted by the
    generated constructor.
    """

    # Largest integer a candidate set may contain; masks built from this
    # value have ``max_integer + 1`` entries.
    max_integer: int = 250
    # How many independent annealing runs are performed.
    num_restarts: int = 5
    # How many mutation proposals are made within one run.
    num_search_steps: int = 1_000
    # Temperature used at the first step of a run.
    initial_temperature: float = 1e-2
class C6Searcher:
    """Simulated-annealing searcher over integer sets encoded as masks.

    A candidate set U is represented by a 1-D mask whose truthy positions are
    the members of U. Nothing here is JIT-compiled: the methods branch in
    plain Python on concrete array values.
    """

    def __init__(self, hypers: "Hyperparameters"):
        """Store the search settings.

        Args:
            hypers: Search settings; kept as ``self.hypers``.
        """
        self.hypers = hypers
        # Not read by any method of this class; kept so that code accessing
        # the attribute continues to work.
        self.allowed_values = jnp.array((-1, 0, 1), dtype=jnp.int32)

    @staticmethod
    def _objective_fn(u_mask: jnp.ndarray) -> jnp.ndarray:
        """Return the negated C6 lower bound of the set encoded by ``u_mask``.

        With U the indices of the truthy entries, the bound is
        ``1 + log(|U - U| / |U + U|) / log(2 * max(U) + 1)``. The value is
        negated so that minimising the loss maximises the bound.

        Args:
            u_mask: 1-D mask; position ``k`` is truthy when ``k`` is in U.

        Returns:
            A scalar loss. Sets for which the formula is undefined (the empty
            set, or ``max(U) == 0``) give ``-1.0``, i.e. a bound of exactly 1.
        """
        members = jnp.where(u_mask)[0]

        # An empty set has no maximum (jnp.max would raise on it).
        if members.size == 0:
            return jnp.asarray(-1.0)

        largest = jnp.max(members)
        # max(U) == 0 makes the denominator log(1) == 0.
        if largest == 0:
            return jnp.asarray(-1.0)

        pairwise_sums = members[:, None] + members[None, :]
        pairwise_diffs = members[:, None] - members[None, :]
        num_sums = jnp.unique(pairwise_sums).shape[0]
        num_diffs = jnp.unique(pairwise_diffs).shape[0]

        ratio = num_diffs / num_sums
        c6_bound = 1 + jnp.log(ratio) / jnp.log(2 * largest + 1)
        return -c6_bound

    def anneal_step(self, key, temp, current_mask, current_loss):
        """Perform one simulated-annealing step (not JIT-compiled).

        One position with index >= 1 is toggled (index 0 is never proposed),
        and the move is accepted by the Metropolis rule: always when the loss
        decreases, otherwise with probability ``exp(-delta / temp)``.

        Args:
            key: JAX PRNG key for this step. It is split internally so the
                proposal and the acceptance draw use independent randomness.
            temp: Current temperature; must be non-zero.
            current_mask: Mask of the current set.
            current_loss: Loss of ``current_mask``.

        Returns:
            ``(mask, loss)`` of the state after the step: the neighbour when
            the move is accepted, otherwise the inputs unchanged.
        """
        # With fewer than two entries there is no index >= 1 to toggle.
        if len(current_mask) < 2:
            return current_mask, current_loss

        # Reusing one key for both draws would couple the proposed index to
        # the acceptance sample, so derive a separate key for each.
        flip_key, accept_key = jax.random.split(key)

        idx_to_flip = jax.random.randint(flip_key, (), 1, len(current_mask))
        # logical_not toggles the entry without writing an integer
        # (the result of ``1 - bool``) back into a boolean mask.
        toggled = jnp.logical_not(current_mask[idx_to_flip])
        neighbor_mask = current_mask.at[idx_to_flip].set(toggled)
        neighbor_loss = self._objective_fn(neighbor_mask)
        delta_loss = neighbor_loss - current_loss

        # Metropolis acceptance criterion.
        if delta_loss < 0:
            return neighbor_mask, neighbor_loss
        accept_prob = jnp.exp(-delta_loss / temp)
        if jax.random.uniform(accept_key) < accept_prob:
            return neighbor_mask, neighbor_loss
        return current_mask, current_loss
def run():
    """Run every restart and report the best result.

    Each restart gets its own PRNG key split off a generator seeded with 42
    and is executed by ``run_single_trial``. The restart with the lowest loss
    wins.

    Returns:
        ``(best_set, c6_bound)``: the set returned by the winning restart and
        the negation of its loss. ``best_set`` is ``None`` when no restart
        produced a loss below infinity.
    """
    hypers = Hyperparameters()
    rng = jax.random.PRNGKey(42)
    lowest_loss, winning_set = float("inf"), None

    for i in range(hypers.num_restarts):
        print(f"\n{'='*20} Restart {i+1}/{hypers.num_restarts} {'='*20}")
        trial_key, rng = jax.random.split(rng)
        loss, candidate_set = run_single_trial(hypers, trial_key)

        # Only a strictly lower loss replaces the incumbent.
        if not (loss < lowest_loss):
            continue
        print(f"New best C6 bound found: {-loss:.8f}")
        lowest_loss, winning_set = loss, candidate_set

    c6_bound = -lowest_loss
    print(f"\nSearch complete. Best C6 lower bound found: {c6_bound:.8f}")
    return winning_set, c6_bound
def run_single_trial(hypers, key):
    """Run one simulated-annealing trial from a random sparse starting set.

    The starting mask has ``hypers.max_integer + 1`` entries, each set with
    probability 0.05, and index 0 is always included. The temperature falls
    linearly from ``hypers.initial_temperature`` towards zero over
    ``hypers.num_search_steps`` steps and is floored at 1e-6.

    Args:
        hypers: Search settings (``max_integer``, ``num_search_steps`` and
            ``initial_temperature`` are read here).
        key: JAX PRNG key for this trial.

    Returns:
        ``(loss, members)``: the lowest loss visited during the trial and a
        NumPy array holding the indices set in the corresponding mask.
    """
    # Initialize a random sparse set, ensuring 0 is included.
    key, subkey = jax.random.split(key)
    sparsity = 0.95
    u_mask = jax.random.bernoulli(
        subkey, p=(1 - sparsity), shape=(hypers.max_integer + 1,)
    )
    u_mask = u_mask.at[0].set(True)

    searcher = C6Searcher(hypers)
    current_loss = searcher._objective_fn(u_mask)
    print(f"Starting SA search. Initial C6 bound: {-current_loss:.6f}")
    current_mask = u_mask

    # Annealing may accept uphill moves, so the state it ends in is not
    # necessarily the best one it visited; remember the best separately.
    best_mask, best_loss = current_mask, current_loss

    for step in tqdm.tqdm(range(hypers.num_search_steps), desc="Annealing Progress"):
        key, subkey = jax.random.split(key)
        current_temp = hypers.initial_temperature * (1 - step / hypers.num_search_steps)
        current_mask, current_loss = searcher.anneal_step(
            subkey, jnp.maximum(current_temp, 1e-6), current_mask, current_loss
        )
        if current_loss < best_loss:
            best_mask, best_loss = current_mask, current_loss

    final_set = np.where(best_mask)[0]
    return best_loss, final_set
# EVOLVE-BLOCK-END
|