import sys
from typing import List, TYPE_CHECKING

import torch


try:
    # torch.compile requires PyTorch >= 2.0 and, in early 2.x releases,
    # did not support Python 3.11.
    if (
        torch.__version__
        and tuple(map(int, torch.__version__.split(".")[:2])) >= (2, 0)
        and sys.version_info < (3, 11)
    ):
        compile_fn = torch.compile
    else:
        raise RuntimeError
except Exception:

    def compile_fn(fn=None, **kwargs):
        """No-op fallback that mirrors torch.compile's decorator interface."""
        if fn is None:
            return lambda f: f  # called with kwargs only: act as a decorator
        return fn


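# Illustrative only: compile_fn can be used both as a bare decorator and with
# keyword arguments; the fallback simply returns the function unchanged
# ("mode" below is a real torch.compile option, ignored by the fallback):
#
#     @compile_fn
#     def f(x):
#         return x
#
#     @compile_fn(mode="reduce-overhead")
#     def g(x):
#         return x

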
if TYPE_CHECKING:
    from .model import BitTransformerLM


@compile_fn
def bytes_to_bits(data: bytes) -> List[int]:
    """Convert bytes to bits with a per-byte parity bit."""
    result: List[int] = []
    for b in data:
        # Most-significant bit first, then the payload's even-parity bit.
        bits = [(b >> i) & 1 for i in reversed(range(8))]
        parity = sum(bits) % 2
        result.extend(bits + [parity])
    return result


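# Illustrative only: a worked example of the 9-bit framing. b"A" is
# 0x41 = 0b01000001, so the payload bits (MSB first) are 0,1,0,0,0,0,0,1
# and the parity bit is 0:
#
#     >>> bytes_to_bits(b"A")
#     [0, 1, 0, 0, 0, 0, 0, 1, 0]

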
@compile_fn
def bits_to_bytes(bits: List[int]) -> bytes:
    """Convert parity-protected bits back to bytes."""
    if len(bits) % 9 != 0:
        raise ValueError("Bit stream length must be a multiple of 9")
    out = bytearray()
    for i in range(0, len(bits), 9):
        chunk = bits[i : i + 9]
        payload = chunk[:8]
        parity = chunk[8]
        if parity != sum(payload) % 2:
            raise ValueError("Parity check failed")
        value = 0
        for bit in payload:
            value = (value << 1) | bit
        out.append(value)
    return bytes(out)


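# Illustrative only: bits_to_bytes inverts bytes_to_bits, so a round trip
# recovers the payload, while a corrupted frame fails the parity check:
#
#     >>> bits_to_bytes(bytes_to_bits(b"hi")) == b"hi"
#     True
#     >>> bits_to_bytes([1, 0, 0, 0, 0, 0, 0, 0, 0])
#     Traceback (most recent call last):
#         ...
#     ValueError: Parity check failed

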
def text_to_bits(text: str) -> List[int]:
    """Encode UTF-8 text as parity-protected bits."""
    return bytes_to_bits(text.encode("utf-8"))


def bits_to_text(bits: List[int]) -> str:
    """Decode parity-protected bits to text, replacing invalid UTF-8."""
    return bits_to_bytes(bits).decode("utf-8", errors="replace")


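# Illustrative only: the text helpers round-trip UTF-8, including
# multi-byte characters:
#
#     >>> bits_to_text(text_to_bits("héllo"))
#     'héllo'

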
def infer_text(
    model: "BitTransformerLM",
    text: str,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
) -> str:
    """Run text through the model using the safety gate."""
    from .safety import hil_safe_inference

    bits = text_to_bits(text)
    tensor = torch.tensor(bits, dtype=torch.long).unsqueeze(0)
    out_bits, _ = hil_safe_inference(model, tensor, c_floor=c_floor, s_floor=s_floor)
    return bits_to_text(out_bits.squeeze(0).tolist())


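# Illustrative only (model construction elided; see .model for the
# BitTransformerLM constructor):
#
#     model = ...  # an initialized BitTransformerLM
#     reply = infer_text(model, "hello", c_floor=0.3, s_floor=0.5)

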
def sample_text(
    model: "BitTransformerLM",
    prompt: str,
    max_new_tokens: int = 16,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> str:
    """Generate text from the model using simple temperature/top-p sampling."""
    model.eval()
    bits = text_to_bits(prompt)
    tensor = torch.tensor(bits, dtype=torch.long).unsqueeze(0)
    with torch.no_grad():
        # Each output byte is a 9-bit frame, so draw 9 bits per new token.
        for _ in range(max_new_tokens * 9):
            if tensor.size(1) >= model.pos_enc.pe.size(0):
                break  # stop at the positional-encoding capacity
            logits, _ = model(tensor, causal=True)
            # Temperature scales the logits *before* the softmax; dividing
            # the probabilities afterwards would have no effect.
            prob = (logits[0, -1] / temperature).softmax(-1)
            sorted_prob, sorted_idx = prob.sort(descending=True)
            cumulative = sorted_prob.cumsum(0)
            # Drop tokens outside the top-p nucleus while always keeping
            # the most probable one.
            mask = cumulative - sorted_prob > top_p
            sorted_prob[mask] = 0
            sorted_prob = sorted_prob / sorted_prob.sum()
            next_bit = sorted_idx[torch.multinomial(sorted_prob, 1)]
            tensor = torch.cat([tensor, next_bit.view(1, 1)], dim=1)
    return bits_to_text(tensor.squeeze(0).tolist())


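# Illustrative only (assumes an initialized model as above):
#
#     continuation = sample_text(
#         model, "The", max_new_tokens=8, temperature=0.8, top_p=0.9
#     )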