OkeyMeta's picture
Release Reframr-RFM-v1-Base public checkpoint
2147ce8 verified
Raw
History Blame Contribute Delete
171 kB
import json
import hashlib
import random
import site
import string
import sys
import unicodedata
from dataclasses import dataclass
from pathlib import Path
# Make vendored dependencies importable first: `.vendor/python` and
# `.vendor/sitepkgs` (two directories above this file) are inserted at the
# front of sys.path when they exist, ahead of any system installation.
_VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
    if _vendor_path.exists():
        vendor_text = str(_vendor_path)
        if vendor_text not in sys.path:
            sys.path.insert(0, vendor_text)
# NumPy is optional: later code checks `np is not None` before taking
# NumPy-only paths (e.g. generate_text's fast path, RUNTIME_ARRAY_DTYPE).
try:
    import numpy as np
except ImportError:
    # ImportError (not just ModuleNotFoundError, its subclass) so that a
    # present-but-broken NumPy -- e.g. a vendored build whose C extensions do
    # not match this interpreter -- degrades to the fallback instead of
    # making this whole module unimportable. Retry once with the user
    # site-packages appended, then run without NumPy.
    user_site = site.getusersitepackages()
    if user_site and user_site not in sys.path:
        sys.path.append(user_site)
    try:
        import numpy as np
    except ImportError:
        np = None
# Treat an importable `numpy` that lacks the array API (e.g. a bare namespace
# package picked up from some directory) as absent.
if np is not None and not hasattr(np, "asarray"):
    np = None
from .checkpoint import read_safetensor_file, write_safetensor_file
from .config import ReframrConfig
from .embeddings import EmbeddingModel, fit_ppmi_embedding_from_tokens
from .hippo import AnalyticalMemoryUnit, analytical_embedding_drive, analytical_embedding_drive_fast
from .linalg import Vector, dot, mean, norm, softmax, zeros_vector
from .reservoir import apply_readout, ridge_regression_readout
from .reasoning import reasoning_prefix
from .ternary import apply_ternary_mask, derive_ternary_mask_from_states
from .tokenizer import NativeTokenizer
# --- Score blend weights -------------------------------------------------------
# Mixing weights for the score sources named in each constant. The FAST_*
# group presumably applies to the NumPy path (_generate_text_fast) and the
# unprefixed group to the pure-Python path; the use sites are outside this
# excerpt, so confirm there. Neither group sums to 1.0.
ASSOCIATIVE_BLEND = 0.42
TRANSITION_BLEND = 0.08
COPY_BLEND = 0.04
BASE_BLEND = 0.34
FAST_ASSOCIATIVE_BLEND = 0.06
FAST_TRANSITION_BLEND = 0.14
FAST_COPY_BLEND = 0.04
FAST_BASE_BLEND = 0.58
FAST_PREFERENCE_BLEND = 0.15
FAST_ANSWER_BLEND = 0.30
# Presumably a scale for z-scored prompt-readout logits (compare
# READOUT_LOGIT_ZSCORE_SCALE below) -- TODO confirm at the use site.
PROMPT_READOUT_LOGIT_ZSCORE_SCALE = 0.48
# --- Memory lookup limits ------------------------------------------------------
# ASSOCIATIVE_TOP_K: use site not visible in this excerpt.
# ANSWER_TOP_K: number of answer-memory matches scored by
#   _score_next_token_from_state.
# ANSWER_START_TOP_K: number of answer-start AND answer-sequence matches scored
#   there. When tracing, both limits are raised to the caller's top_k if larger.
ASSOCIATIVE_TOP_K = 12
ANSWER_TOP_K = 48
ANSWER_START_TOP_K = 32
# --- Answer-sequence confidence thresholds ------------------------------------
# Compared against the score of the first (presumably best) answer-sequence
# match:
# * >= ANSWER_SEQUENCE_LOCK_FLOOR: _answer_sequence_should_lock returns True and
#   _answer_start_blend_weights switches to its sequence-heavy weights;
# * >= ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR: also locks, but only while the
#   answer-sequence prior's peak stays <= ANSWER_SEQUENCE_SPIKE_CONFIDENCE.
# ANSWER_SEQUENCE_MATCH_FLOOR is not referenced in the visible code.
ANSWER_SEQUENCE_MATCH_FLOOR = 0.30
ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR = 0.45
ANSWER_SEQUENCE_LOCK_FLOOR = 0.55
ANSWER_SEQUENCE_SPIKE_CONFIDENCE = 0.80
# Presumably a scale for z-scored readout logits -- TODO confirm at the use site.
READOUT_LOGIT_ZSCORE_SCALE = 0.22
# --- Trace identity hashing ----------------------------------------------------
# Ten 4-tuples of integer constants; several rows begin with well-known LCG
# multiplier/increment pairs (e.g. glibc 1103515245/12345, Numerical Recipes
# 1664525/1013904223, MINSTD 48271). Presumably used to derive deterministic
# pseudo-random identity vectors for trace tokens, scaled by
# TRACE_IDENTITY_SCALE -- TODO confirm at the use site.
TRACE_IDENTITY_SCALE = 0.78
TRACE_IDENTITY_HASHES = (
    (1103515245, 12345, 214013, 2531011),
    (1664525, 1013904223, 22695477, 1),
    (69069, 362437, 134775813, 17),
    (134775813, 97, 1103515245, 31),
    (22695477, 911, 1664525, 73),
    (214013, 2531011, 69069, 19),
    (48271, 0, 69621, 11),
    (16807, 37, 40692, 101),
    (279470273, 173, 1299709, 53),
    (39916801, 29, 2147483629, 7),
)
# Joins n-gram tokens into a single string key (_encode_ngram_key /
# _decode_ngram_key). U+0001 is a control character, presumably chosen because
# it never occurs inside a token.
NGRAM_KEY_SEPARATOR = "\u0001"
# N-gram orders, longest first; presumably the orders of the transition tables
# built by _build_transition_tables -- TODO confirm.
TRANSITION_ORDERS = (10, 8, 6, 5, 4, 3, 2, 1)
# --- Sampling defaults ---------------------------------------------------------
# TOP_K / TOP_P / REPETITION_PENALTY are generate_text's defaults; trace_generation
# shares TOP_P / REPETITION_PENALTY and samples with at least DEFAULT_GENERATION_TOP_K
# candidates. DEFAULT_GENERATION_TEMPERATURE is not used as a default in the
# visible code: both methods default to temperature=0.0.
DEFAULT_GENERATION_TEMPERATURE = 0.82
DEFAULT_GENERATION_TOP_K = 24
DEFAULT_GENERATION_TOP_P = 0.92
DEFAULT_REPETITION_PENALTY = 1.18
# Maximum number of prompt / answer token ids stored per answer-sequence entry;
# shorter entries are right-padded with -1 up to exactly this length.
ANSWER_SEQUENCE_MAX_TOKENS = 192
# float32 when NumPy is available, otherwise None (pure-Python fallback).
RUNTIME_ARRAY_DTYPE = np.float32 if np is not None else None
@dataclass(frozen=True, slots=True)
class CharacterCountFact:
    """A character-counting question parsed from a prompt, plus its answer.

    Produced by ReframrModel._character_count_fact and turned into a sentence by
    ReframrModel._render_character_count_fact.
    """

    # The single character being counted, as it appears in the NFKC-normalised prompt.
    character: str
    # The word searched for the character, as it appears in the normalised prompt.
    word: str
    # Case-insensitive (casefolded) number of occurrences of `character` in `word`.
    count: int
    # Value in 0..3 derived from token positions; selects the response template.
    surface_seed: int
def _normalize_vector(values: Vector) -> Vector:
    """Scale ``values`` so that they sum to one.

    When the total is not positive (empty input, all zeros, or negative
    entries outweighing the positive ones) the result is an all-zero vector
    of the same length instead of a division by that total.
    """
    denominator = sum(values)
    return (
        [0.0 for _ in values]
        if denominator <= 0.0
        else [entry / denominator for entry in values]
    )
def _encode_ngram_key(tokens: tuple[str, ...]) -> str:
    """Serialise an n-gram into one string key.

    Tokens are joined with NGRAM_KEY_SEPARATOR; ``_decode_ngram_key`` splits the
    key back apart (dropping any empty parts).
    """
    separator: str = NGRAM_KEY_SEPARATOR
    return separator.join(tokens)
def _decode_ngram_key(key: str) -> tuple[str, ...]:
    """Split a key built by ``_encode_ngram_key`` back into its tokens.

    Empty fragments are discarded, so the empty key decodes to ``()``.
    """
    pieces = key.split(NGRAM_KEY_SEPARATOR)
    return tuple(filter(None, pieces))
def _last_index(values: list[str], target: str) -> int | None:
for index in range(len(values) - 1, -1, -1):
if values[index] == target:
return index
return None
@dataclass(slots=True)
class DecodeState:
    """Mutable per-prompt decoding state threaded through the scoring helpers.

    Created by ReframrModel._build_decode_state and passed to
    ReframrModel._advance_decode_state after each chosen token (both outside
    this excerpt).
    """

    hidden_states: list[Vector]
    context_traces: list[Vector]
    combined_state: Vector
    # Token sequence being decoded; initially the prompt tokens (trace_generation
    # snapshots it right after construction). Presumably extended by
    # _advance_decode_state -- confirm there.
    context_tokens: list[str]
    # State passed to the answer / answer-start / answer-sequence match scorers.
    answer_anchor_state: Vector | None = None
    # Lazily filled by _score_next_token_from_state on first use (None means "not
    # computed yet") and reused on later steps. The first element of each tuple
    # is used as a match score; the meaning of the two ints is not visible here.
    answer_matches: list[tuple[float, int, int]] | None = None
    answer_start_matches: list[tuple[float, int, int]] | None = None
    answer_sequence_matches: list[tuple[float, int, int]] | None = None
    # Cached prompt-conditioned priors; their use is outside this excerpt.
    prompt_answer_prior: object | None = None
    prompt_answer_start_prior: object | None = None
@dataclass(slots=True)
class ReframrModel:
config: ReframrConfig
tokenizer: NativeTokenizer | None = None
embedding_model: EmbeddingModel | None = None
memory_units: list[AnalyticalMemoryUnit] | None = None
ternary_scale: float = 1.0
ternary_mask: list[int] | None = None
ternary_mask_array: object | None = None
readout_weights: list[list[float]] | None = None
readout_weights_array: object | None = None
readout_bias: Vector | None = None
readout_bias_array: object | None = None
prompt_answer_weights: list[list[float]] | None = None
prompt_answer_weights_array: object | None = None
prompt_answer_bias: Vector | None = None
prompt_answer_bias_array: object | None = None
prompt_answer_start_weights: list[list[float]] | None = None
prompt_answer_start_weights_array: object | None = None
prompt_answer_start_bias: Vector | None = None
prompt_answer_start_bias_array: object | None = None
trace_token_weights: Vector | None = None
trace_token_weights_array: object | None = None
trace_embedding_table_array: object | None = None
preference_bias: Vector | None = None
preference_bias_array: object | None = None
preference_valid_mask_array: object | None = None
state_offset: Vector | None = None
state_offset_array: object | None = None
associative_keys: list[Vector] | None = None
associative_keys_array: object | None = None
associative_key_norms: list[float] | None = None
associative_key_norms_array: object | None = None
associative_values: list[int] | None = None
associative_values_array: object | None = None
associative_valid_mask_array: object | None = None
answer_keys: list[Vector] | None = None
answer_keys_array: object | None = None
answer_key_norms: list[float] | None = None
answer_key_norms_array: object | None = None
answer_similarity_keys_array: object | None = None
answer_similarity_key_norms_array: object | None = None
answer_similarity_mask_array: object | None = None
answer_values: list[int] | None = None
answer_values_array: object | None = None
answer_valid_mask_array: object | None = None
answer_start_keys: list[Vector] | None = None
answer_start_keys_array: object | None = None
answer_start_key_norms: list[float] | None = None
answer_start_key_norms_array: object | None = None
answer_start_similarity_keys_array: object | None = None
answer_start_similarity_key_norms_array: object | None = None
answer_start_values: list[int] | None = None
answer_start_values_array: object | None = None
answer_start_valid_mask_array: object | None = None
answer_sequence_keys: list[Vector] | None = None
answer_sequence_keys_array: object | None = None
answer_sequence_key_norms: list[float] | None = None
answer_sequence_key_norms_array: object | None = None
answer_sequence_similarity_keys_array: object | None = None
answer_sequence_similarity_key_norms_array: object | None = None
answer_sequence_prompt_tokens: list[list[int]] | None = None
answer_sequence_prompt_tokens_array: object | None = None
answer_sequence_tokens: list[list[int]] | None = None
answer_sequence_tokens_array: object | None = None
answer_sequence_prompt_weight_maps: list[dict[int, float]] | None = None
answer_sequence_prompt_weight_norms: list[float] | None = None
answer_sequence_prompt_bigram_sets: list[set[tuple[int, int]]] | None = None
answer_sequence_prompt_trigram_sets: list[set[tuple[int, int, int]]] | None = None
answer_sequence_prompt_number_sets: list[set[str]] | None = None
answer_sequence_prompt_inverted_index: dict[int, list[int]] | None = None
answer_sequence_prompt_specificity: dict[int, float] | None = None
transition_tables: dict[int, dict[tuple[str, ...], dict[str, float]]] | None = None
def fit(self, text: str) -> "ReframrModel":
    """Derive every model component analytically from ``text`` and return ``self``.

    Steps, in order: train the tokenizer on ``text``; encode the corpus; fit
    PPMI embeddings; build one AnalyticalMemoryUnit per configured timescale;
    derive per-token trace weights from raw token counts; collect training
    states/targets; derive and apply a ternary mask; store the masked states
    as associative memory; build transition tables; fill the answer memory
    from ``prompt <answer> answer`` lines; solve the ridge-regression readout;
    zero-initialise the biases and state offset; refresh the numeric caches.

    Raises:
        ValueError: if the corpus encodes to fewer than two tokens.
    """
    self.tokenizer = NativeTokenizer.train(
        text,
        vocab_size=self.config.tokenizer_vocab_size,
        min_pair_frequency=self.config.tokenizer_min_pair_frequency,
        lowercase=self.config.lowercase,
    )
    tokens = self.tokenizer.encode(text)
    if len(tokens) < 2:
        raise ValueError("REFRAMR needs at least two tokens to derive a next-token readout.")
    self.embedding_model = fit_ppmi_embedding_from_tokens(
        tokens,
        embedding_dim=self.config.embedding_dim,
        window_size=self.config.window_size,
        min_frequency=self.config.min_frequency,
        max_vocab=self.config.max_vocab,
    )
    self.memory_units = [
        AnalyticalMemoryUnit(self.config.state_dim, timescale)
        for timescale in self.config.timescales
    ]
    # Raw (float) frequency of every token in the encoded corpus.
    token_counts: dict[str, float] = {}
    for token in tokens:
        token_counts[token] = token_counts.get(token, 0.0) + 1.0
    self.trace_token_weights = self._derive_trace_token_weights_from_counts(token_counts)
    raw_states, targets, target_ids = self._collect_training_examples(tokens)
    self.ternary_scale, self.ternary_mask = derive_ternary_mask_from_states(raw_states)
    analytical_states = [
        apply_ternary_mask(state, self.ternary_mask, self.ternary_scale)
        for state in raw_states
    ]
    # Associative memory: one (masked state -> target token id) pair per
    # training example, with precomputed key norms.
    self.associative_keys = [state[:] for state in analytical_states]
    self.associative_key_norms = [norm(state) for state in analytical_states]
    self.associative_values = target_ids[:]
    # These lists must be reset to [] here: _fit_answer_memory_from_text
    # returns without doing anything if any of them is None.
    self.answer_keys = []
    self.answer_key_norms = []
    self.answer_values = []
    self.answer_start_keys = []
    self.answer_start_key_norms = []
    self.answer_start_values = []
    self.answer_sequence_keys = []
    self.answer_sequence_key_norms = []
    self.answer_sequence_prompt_tokens = []
    self.answer_sequence_tokens = []
    self.prompt_answer_weights = []
    self.prompt_answer_bias = [0.0 for _ in self.embedding_model.id_to_token]
    self.prompt_answer_start_weights = []
    self.prompt_answer_start_bias = [0.0 for _ in self.embedding_model.id_to_token]
    self.transition_tables = self._build_transition_tables(tokens)
    self._fit_answer_memory_from_text(text)
    self.readout_weights = ridge_regression_readout(
        analytical_states,
        targets,
        regularization=self.config.regularization,
    )
    # Biases start at zero, one entry per vocabulary token.
    self.readout_bias = [0.0 for _ in self.embedding_model.id_to_token]
    self.preference_bias = [0.0 for _ in self.embedding_model.id_to_token]
    self.state_offset = [0.0 for _ in analytical_states[0]] if analytical_states else []
    # Rebuild the *_array caches from the freshly fitted list-based parameters.
    self._refresh_numeric_caches()
    return self
def _fit_answer_memory_from_text(self, text: str) -> None:
    """Populate the answer-memory tables from ``prompt <answer> answer`` lines.

    For each line of ``text`` that contains ``<answer>``, the text before the
    first marker is the prompt and the remainder is the answer. The prompt
    tokens plus a trailing ``<answer>`` token are encoded into a context key
    with ``self._encode_context``. A line is skipped when:

    * the stripped prompt or answer is empty;
    * no answer token survives filtering (in vocabulary, not special);
    * the key's norm is not positive.

    Each kept line appends the same key and norm to the answer, answer-start
    and answer-sequence tables, and stores the first answer token id as the
    value for the answer and answer-start tables. For the answer-sequence
    table it stores the prompt and answer token ids, truncated to
    ``ANSWER_SEQUENCE_MAX_TOKENS`` and right-padded with ``-1`` to exactly
    that length.

    Does nothing when any of the target lists is ``None``.
    """
    assert self.tokenizer is not None
    assert self.embedding_model is not None
    tables = (
        self.answer_keys,
        self.answer_key_norms,
        self.answer_values,
        self.answer_start_keys,
        self.answer_start_key_norms,
        self.answer_start_values,
        self.answer_sequence_keys,
        self.answer_sequence_key_norms,
        self.answer_sequence_prompt_tokens,
        self.answer_sequence_tokens,
    )
    if any(table is None for table in tables):
        return
    vocabulary = self.embedding_model.token_to_id
    width = ANSWER_SEQUENCE_MAX_TOKENS

    def keep(token: str) -> bool:
        # Only ordinary in-vocabulary tokens are stored as ids.
        return token in vocabulary and token not in self.tokenizer.special_tokens

    def padded(ids: list[int]) -> list[int]:
        # Right-pad with -1 so every stored row has exactly `width` entries.
        return ids + [-1] * (width - len(ids))

    for raw_line in text.splitlines():
        if "<answer>" not in raw_line:
            continue
        head, tail = raw_line.split("<answer>", 1)
        prompt_text, answer_text = head.strip(), tail.strip()
        if not (prompt_text and answer_text):
            continue
        prompt_tokens = self.tokenizer.encode(prompt_text) + ["<answer>"]
        answer_tokens = [piece for piece in self.tokenizer.encode(answer_text) if keep(piece)]
        if not prompt_tokens or not answer_tokens:
            continue
        key = self._encode_context(prompt_tokens)
        key_norm = norm(key)
        if key_norm <= 0.0:
            continue
        answer_ids = [vocabulary[piece] for piece in answer_tokens[:width]]
        prompt_ids = [vocabulary[piece] for piece in prompt_tokens[:width] if keep(piece)]
        if not answer_ids:
            continue
        first_answer_id = answer_ids[0]
        self.answer_keys.append(key[:])
        self.answer_key_norms.append(key_norm)
        self.answer_values.append(first_answer_id)
        self.answer_start_keys.append(key[:])
        self.answer_start_key_norms.append(key_norm)
        self.answer_start_values.append(first_answer_id)
        self.answer_sequence_keys.append(key[:])
        self.answer_sequence_key_norms.append(key_norm)
        self.answer_sequence_prompt_tokens.append(padded(prompt_ids))
        self.answer_sequence_tokens.append(padded(answer_ids))
def predict_next_distribution(
self,
context: str,
*,
reasoning_mode: str | None = None,
) -> dict[str, float]:
self._require_fit()
assert self.tokenizer is not None
assert self.embedding_model is not None
probabilities = self.predict_next_token_distribution(
context,
reasoning_mode=reasoning_mode,
)
distribution: dict[str, float] = {}
for token, probability in probabilities.items():
rendered = self._render_token(token)
distribution[rendered] = distribution.get(rendered, 0.0) + probability
return distribution
def predict_next_token_distribution(
    self,
    context: str,
    *,
    reasoning_mode: str | None = None,
) -> dict[str, float]:
    """Return the next-token probability distribution keyed by raw token.

    ``context`` is encoded and placed after the reasoning-prefix tokens of
    ``reasoning_mode`` (or of ``config.default_reasoning_profile`` when it
    is not given).
    """
    self._require_fit()
    assert self.tokenizer is not None
    assert self.embedding_model is not None
    assert self.readout_weights is not None
    profile = reasoning_mode or self.config.default_reasoning_profile
    prefix_tokens = reasoning_prefix(profile)
    encoded_context = self.tokenizer.encode(context)
    return self._predict_next_token_distribution_from_tokens(prefix_tokens + encoded_context)
def generate_text(
    self,
    context: str,
    *,
    max_tokens: int = 64,
    reasoning_mode: str | None = None,
    temperature: float = 0.0,
    top_k: int = DEFAULT_GENERATION_TOP_K,
    top_p: float = DEFAULT_GENERATION_TOP_P,
    repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
) -> str:
    """Generate a continuation of ``context`` and return it as decoded text.

    Character-counting questions are answered directly by
    ``_character_count_response``, before the fitted model is even checked.
    Otherwise, when NumPy is available, the readout array cache exists and
    the vocabulary has at least 1024 tokens, decoding is delegated to
    ``_generate_text_fast``. In all other cases the pure-Python decode loop
    below is used.

    Args:
        context: Prompt text; `` <answer>`` is appended if it is absent
            (see ``_generation_prompt_tokens``).
        max_tokens: Maximum main-loop steps. The word-completion overflow
            may add up to six more tokens.
        reasoning_mode: Reasoning profile; defaults to
            ``config.default_reasoning_profile``.
        temperature: Passed to the sampler. Also, when > 0, randomises the
            character-count answer template.
        top_k, top_p, repetition_penalty: Passed to the token sampler.

    Returns:
        The decoded generated tokens only; the prompt is not included.
    """
    # Answered from the prompt text itself; no fitted model is required.
    character_count_response = self._character_count_response(
        context,
        temperature=temperature,
    )
    if character_count_response is not None:
        return character_count_response
    self._require_fit()
    self._ensure_numeric_caches()
    assert self.tokenizer is not None
    # Large vocabularies use the NumPy-vectorised decoder when it is available.
    if (
        np is not None
        and self.readout_weights_array is not None
        and self.embedding_model is not None
        and len(self.embedding_model.id_to_token) >= 1024
    ):
        return self._generate_text_fast(
            context,
            max_tokens=max_tokens,
            reasoning_mode=reasoning_mode,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )
    active_mode = reasoning_mode or self.config.default_reasoning_profile
    _, context_tokens = self._generation_prompt_tokens(context, active_mode)
    decode_state = self._build_decode_state(context_tokens)
    generated_tokens: list[str] = []
    for _ in range(max_tokens):
        distribution, _ = self._score_next_token_from_state(
            decode_state,
            include_trace=False,
            generated_tokens=generated_tokens,
        )
        # While a memorised answer still has a continuation, the sampler is
        # told to preserve its dominant candidates.
        next_token = self._select_generation_token(
            distribution,
            context_tokens=decode_state.context_tokens,
            generated_tokens=generated_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            preserve_dominant_candidates=self._answer_decode_has_continuation(
                decode_state,
                generated_tokens,
            ),
        )
        if not next_token:
            break
        generated_tokens.append(next_token)
        self._advance_decode_state(decode_state, next_token)
        if self._should_stop_answer_sequence(decode_state, generated_tokens):
            break
        # The generic stop rule is ignored while an answer continuation is pending.
        if self._should_stop_generation(
            generated_tokens
        ) and not self._answer_decode_has_continuation(decode_state, generated_tokens):
            break
    # Word-completion overflow: while the last token does not start a new word
    # (presumably a mid-word piece), append up to 6 more tokens that also do not
    # start a new word. This can exceed max_tokens by up to 6.
    overflow_budget = 6
    while (
        generated_tokens
        and not self._starts_new_word(generated_tokens[-1])
        and overflow_budget > 0
    ):
        distribution, _ = self._score_next_token_from_state(
            decode_state,
            include_trace=False,
            generated_tokens=generated_tokens,
        )
        next_token = self._select_generation_token(
            distribution,
            context_tokens=decode_state.context_tokens,
            generated_tokens=generated_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            preserve_dominant_candidates=self._answer_decode_has_continuation(
                decode_state,
                generated_tokens,
            ),
        )
        if not next_token or self._starts_new_word(next_token):
            break
        generated_tokens.append(next_token)
        self._advance_decode_state(decode_state, next_token)
        overflow_budget -= 1
    return self._decode_tokens(generated_tokens)
@staticmethod
def _character_count_fact(context: str) -> CharacterCountFact | None:
    """Parse a "how many <char> in <word>" style prompt into a CharacterCountFact.

    The prompt is NFKC-normalised and split into alphanumeric word tokens.
    It is recognised only when it contains a count term ("count", "counts",
    "counting", "many") and also either a unit word ("character(s)",
    "letter(s)") or the exact word "count". The target character and the
    searched word are then located by the two index helpers.

    The count is case-insensitive (``str.casefold``). ``surface_seed`` is a
    value in 0..3 derived from the token positions and picks the response
    template. Returns ``None`` when the prompt is not recognised.
    """
    text = unicodedata.normalize("NFKC", context).strip()
    words = ReframrModel._character_count_word_tokens(text)
    if not words:
        return None
    folded = [piece.casefold() for piece in words]
    count_terms = {"count", "counts", "counting", "many"}
    unit_terms = {"character", "characters", "letter", "letters"}
    asks_for_count = any(term in count_terms for term in folded)
    names_unit = any(term in unit_terms for term in folded) or "count" in folded
    if not (asks_for_count and names_unit):
        return None
    filler_terms = {"a", "an", "the", "single", "one", "please"}
    word_markers = {"in", "inside"}
    char_index = ReframrModel._character_count_target_index(
        folded,
        unit_terms=unit_terms,
        filler_terms=filler_terms,
    )
    word_index = ReframrModel._character_count_word_index(
        folded,
        char_index=char_index,
        filler_terms=filler_terms,
        word_markers=word_markers,
    )
    if char_index is None or word_index is None:
        return None
    character, word = words[char_index], words[word_index]
    if len(character) != 1 or not word:
        return None
    # 0 when the character is mentioned before the word, 1 otherwise.
    order_offset = int(char_index >= word_index)
    seed = ((char_index + 1) * 7 + (word_index + 1) * 3 + len(words) + order_offset) % 4
    occurrences = word.casefold().count(character.casefold())
    return CharacterCountFact(
        character=character,
        word=word,
        count=occurrences,
        surface_seed=seed,
    )
@staticmethod
def _character_count_word_tokens(text: str) -> list[str]:
tokens: list[str] = []
current: list[str] = []
for character in text:
if character != "_" and character.isalnum():
current.append(character)
continue
if current:
tokens.append("".join(current))
current = []
if current:
tokens.append("".join(current))
return tokens
@staticmethod
def _character_count_target_index(
    tokens: list[str],
    *,
    unit_terms: set[str],
    filler_terms: set[str],
) -> int | None:
    """Locate the token that names the character to be counted.

    ``tokens`` are the casefolded word tokens of the prompt. Two strategies
    are tried in order, and a candidate is accepted only if it is a single
    character:

    1. Around each unit word ("letter", "characters", ...): first the
       directly adjacent tokens, then the nearest non-filler token on
       either side.
    2. After each count verb ("count", "counts", "counting"): the nearest
       non-filler token, stepping over one unit word if present.

    Single-character filler words (the article "a") are never skipped while
    scanning, because the character being asked about may itself be "a".
    Otherwise "count the a's in banana" would skip "a" and target the
    trailing "s". This matches the adjacency check, which already accepts
    "a" as a candidate.

    Returns the index of the target token, or ``None`` if there is none.
    """
    # Only multi-character fillers ("the", "please", ...) are transparent
    # while scanning; a one-letter filler could be the requested character.
    scan_skip = {term for term in filler_terms if len(term) != 1}
    for index, token in enumerate(tokens):
        if token not in unit_terms:
            continue
        # Prefer a single character immediately before or after the unit word.
        for adjacent in (index - 1, index + 1):
            if 0 <= adjacent < len(tokens) and len(tokens[adjacent]) == 1:
                return adjacent
        # Otherwise look past filler words on either side.
        before = ReframrModel._nearest_content_index(tokens, index - 1, -1, scan_skip)
        after = ReframrModel._nearest_content_index(tokens, index + 1, 1, scan_skip)
        for candidate in (before, after):
            if candidate is not None and len(tokens[candidate]) == 1:
                return candidate
    # Fallback: "count (the) (letter) r ..." -- the first content token after a
    # count verb, stepping over an optional unit word.
    for index, token in enumerate(tokens):
        if token not in {"count", "counts", "counting"}:
            continue
        candidate = ReframrModel._nearest_content_index(tokens, index + 1, 1, scan_skip)
        if candidate is not None and tokens[candidate] in unit_terms:
            candidate = ReframrModel._nearest_content_index(tokens, candidate + 1, 1, scan_skip)
        if candidate is not None and len(tokens[candidate]) == 1:
            return candidate
    return None
@staticmethod
def _character_count_word_index(
    tokens: list[str],
    *,
    char_index: int | None,
    filler_terms: set[str],
    word_markers: set[str],
) -> int | None:
    """Locate the token naming the word in which the character is counted.

    ``tokens`` are the casefolded word tokens of the prompt. A candidate is
    usable only if it is not the target character's index and is longer
    than one character. The strategies are tried in order:

    1. The first content token after an explicit "word" ("... the word banana").
    2. The first content token after "in"/"inside", stepping over one "word".
    3. Fallback: the last token that is neither the target character, a
       single character, nor a question/filler/marker word.

    Returns the token index, or ``None`` if nothing qualifies.
    """
    nearest = ReframrModel._nearest_content_index

    def usable(position: int | None) -> bool:
        return (
            position is not None
            and position != char_index
            and len(tokens[position]) > 1
        )

    for position, token in enumerate(tokens):
        if token == "word":
            candidate = nearest(tokens, position + 1, 1, filler_terms)
            if usable(candidate):
                return candidate
    for position, token in enumerate(tokens):
        if token in word_markers:
            candidate = nearest(tokens, position + 1, 1, filler_terms)
            if candidate is not None and tokens[candidate] == "word":
                candidate = nearest(tokens, candidate + 1, 1, filler_terms)
            if usable(candidate):
                return candidate
    ignored = {
        "how",
        "many",
        "count",
        "counts",
        "counting",
        "letter",
        "letters",
        "character",
        "characters",
        "word",
        "there",
        "are",
        "is",
        "appear",
        "appears",
        "times",
    } | filler_terms | word_markers
    for position in reversed(range(len(tokens))):
        if position == char_index:
            continue
        token = tokens[position]
        if len(token) > 1 and token not in ignored:
            return position
    return None
@staticmethod
def _nearest_content_index(
tokens: list[str],
start: int,
direction: int,
skipped_terms: set[str],
) -> int | None:
index = start
while 0 <= index < len(tokens):
if tokens[index] not in skipped_terms:
return index
index += direction
return None
@classmethod
def _character_count_response(cls, context: str, *, temperature: float = 0.0) -> str | None:
    """Answer a character-counting prompt directly, or return ``None``.

    ``None`` means ``context`` was not recognised as a character-counting
    question, and normal generation should proceed.
    """
    fact = cls._character_count_fact(context)
    if fact is not None:
        return cls._render_character_count_fact(fact, temperature=temperature)
    return None
@staticmethod
def _render_character_count_fact(fact: CharacterCountFact, *, temperature: float = 0.0) -> str:
    """Phrase a CharacterCountFact as one of four sentence templates.

    With ``temperature <= 0`` the template index is ``surface_seed % 4``,
    which is deterministic. With a positive temperature a random offset
    from the module-level ``random`` generator is added first.
    """
    count = fact.count
    singular = count == 1
    quoted_char = f"'{fact.character}'"
    quoted_word = f"'{fact.word}'"
    noun = "character" if singular else "characters"
    verb = "is" if singular else "are"
    times = "time" if singular else "times"
    templates = (
        f"There {verb} {count} {quoted_char} {noun} in {quoted_word}.",
        f"{quoted_word} contains {count} {quoted_char} {noun}.",
        f"In {quoted_word}, {quoted_char} appears {count} {times}.",
        f"The count is {count} for {quoted_char} in {quoted_word}.",
    )
    offset = random.randrange(len(templates)) if temperature > 0.0 else 0
    return templates[(offset + fact.surface_seed) % len(templates)]
def _generate_text_fast(
    self,
    context: str,
    *,
    max_tokens: int,
    reasoning_mode: str | None,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
) -> str:
    """NumPy-backed twin of ``generate_text``'s pure-Python decode loop.

    The control flow is identical: the main loop, the answer-sequence and
    generic stop rules, and the word-completion overflow of up to 6 tokens.
    The differences are that this method scores with
    ``_score_next_token_array_from_state(include_associative=True)`` and
    samples with ``_select_generation_token_from_array``. ``generate_text``
    calls it after the character-count shortcut, ``_require_fit`` and
    ``_ensure_numeric_caches`` have run. Arguments are the same as
    ``generate_text``'s.

    Returns:
        The decoded generated tokens only (no prompt).
    """
    assert self.tokenizer is not None
    active_mode = reasoning_mode or self.config.default_reasoning_profile
    _, context_tokens = self._generation_prompt_tokens(context, active_mode)
    decode_state = self._build_decode_state(context_tokens)
    generated_tokens: list[str] = []
    for _ in range(max_tokens):
        probabilities, _ = self._score_next_token_array_from_state(
            decode_state,
            include_associative=True,
            generated_tokens=generated_tokens,
        )
        # While a memorised answer still has a continuation, the sampler is
        # told to preserve its dominant candidates.
        next_token = self._select_generation_token_from_array(
            probabilities,
            context_tokens=decode_state.context_tokens,
            generated_tokens=generated_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            preserve_dominant_candidates=self._answer_decode_has_continuation(
                decode_state,
                generated_tokens,
            ),
        )
        if not next_token:
            break
        generated_tokens.append(next_token)
        self._advance_decode_state(decode_state, next_token)
        if self._should_stop_answer_sequence(decode_state, generated_tokens):
            break
        # The generic stop rule is ignored while an answer continuation is pending.
        if self._should_stop_generation(
            generated_tokens
        ) and not self._answer_decode_has_continuation(decode_state, generated_tokens):
            break
    # Word-completion overflow: while the last token does not start a new word
    # (presumably a mid-word piece), append up to 6 more tokens that also do not
    # start a new word. This can exceed max_tokens by up to 6.
    overflow_budget = 6
    while (
        generated_tokens
        and not self._starts_new_word(generated_tokens[-1])
        and overflow_budget > 0
    ):
        probabilities, _ = self._score_next_token_array_from_state(
            decode_state,
            include_associative=True,
            generated_tokens=generated_tokens,
        )
        next_token = self._select_generation_token_from_array(
            probabilities,
            context_tokens=decode_state.context_tokens,
            generated_tokens=generated_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            preserve_dominant_candidates=self._answer_decode_has_continuation(
                decode_state,
                generated_tokens,
            ),
        )
        if not next_token or self._starts_new_word(next_token):
            break
        generated_tokens.append(next_token)
        self._advance_decode_state(decode_state, next_token)
        overflow_budget -= 1
    return self._decode_tokens(generated_tokens)
def trace_next_token(
    self,
    context: str,
    *,
    reasoning_mode: str | None = None,
    top_k: int = 5,
) -> dict[str, object]:
    """Score the next token after ``context`` and return the scorer's trace.

    The trace is annotated with the raw context, the reasoning mode and its
    prefix tokens, and the full token sequence that was scored (prefix plus
    encoded context).
    """
    self._require_fit()
    assert self.tokenizer is not None
    mode = reasoning_mode or self.config.default_reasoning_profile
    scored_tokens = reasoning_prefix(mode) + self.tokenizer.encode(context)
    _distribution, trace = self._score_next_token_from_tokens(
        scored_tokens,
        top_k=top_k,
        include_trace=True,
    )
    trace.update(
        context=context,
        reasoning_mode=mode,
        reasoning_tokens=reasoning_prefix(mode),
        context_tokens=scored_tokens,
    )
    return trace
def trace_generation(
    self,
    context: str,
    *,
    max_tokens: int = 16,
    reasoning_mode: str | None = None,
    top_k: int = 5,
    temperature: float = 0.0,
    top_p: float = DEFAULT_GENERATION_TOP_P,
    repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
) -> dict[str, object]:
    """Generate from ``context`` like ``generate_text`` and record every step.

    Mirrors the pure-Python decode loop of ``generate_text``:

    * the same character-count shortcut;
    * the same answer-sequence stop rule;
    * the same ``preserve_dominant_candidates`` policy while a memorised
      answer still has a continuation.

    The traced text therefore matches what that loop produces for the same
    arguments. NOTE(review): ``generate_text`` routes large vocabularies to
    ``_generate_text_fast`` (NumPy scorer), whereas this trace always uses
    the pure-Python scorer.

    Args:
        context: Prompt text; `` <answer>`` is appended if it is absent.
        max_tokens: Maximum main-loop steps. The word-completion overflow
            may add up to six more.
        reasoning_mode: Reasoning profile; defaults to
            ``config.default_reasoning_profile``.
        top_k: Passed to the scorer for each step's trace. The sampler
            itself uses ``max(DEFAULT_GENERATION_TOP_K, top_k)``.
        temperature, top_p, repetition_penalty: Sampling policy.

    Returns:
        A dict with the prompt, reasoning info, the effective generation
        policy, prompt/generated tokens, the decoded text and one trace per
        step. Each step trace is augmented with ``step``, ``chosen_token``,
        ``chosen_text`` and ``chosen_probability``.
    """
    # The sampler never considers fewer than DEFAULT_GENERATION_TOP_K candidates.
    selection_top_k = max(DEFAULT_GENERATION_TOP_K, top_k)
    generation_policy = {
        "temperature": temperature,
        "top_k": selection_top_k,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
    }
    character_count_response = self._character_count_response(
        context,
        temperature=temperature,
    )
    if character_count_response is not None:
        active_mode = reasoning_mode or self.config.default_reasoning_profile
        prompt = context if "<answer>" in context else f"{context} <answer>"
        return {
            "context": context,
            "prompt": prompt,
            "reasoning_mode": active_mode,
            "reasoning_tokens": reasoning_prefix(active_mode),
            "generation_policy": generation_policy,
            "prompt_tokens": [],
            "generated_tokens": [],
            "generated_text": character_count_response,
            # No model tokens exist on this path; report whitespace-separated words.
            "generated_token_count": len(character_count_response.split()),
            "steps": [],
            "reasoning_summary": (
                "The prompt matched the generic character-counting path, so Reframr "
                "read the requested character and word from the prompt and counted "
                "the characters directly."
            ),
        }
    self._require_fit()
    assert self.tokenizer is not None
    active_mode = reasoning_mode or self.config.default_reasoning_profile
    prompt, context_tokens = self._generation_prompt_tokens(context, active_mode)
    decode_state = self._build_decode_state(context_tokens)
    prompt_tokens = decode_state.context_tokens[:]
    generated_tokens: list[str] = []
    steps: list[dict[str, object]] = []

    def propose() -> tuple[str, dict[str, float], dict[str, object]]:
        # Score and sample one token with generate_text's policy, including
        # the dominant-candidate guard while an answer continuation is pending.
        distribution, trace = self._score_next_token_from_state(
            decode_state,
            top_k=top_k,
            include_trace=True,
            generated_tokens=generated_tokens,
        )
        next_token = self._select_generation_token(
            distribution,
            context_tokens=decode_state.context_tokens,
            generated_tokens=generated_tokens,
            temperature=temperature,
            top_k=selection_top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            preserve_dominant_candidates=self._answer_decode_has_continuation(
                decode_state,
                generated_tokens,
            ),
        )
        return next_token, distribution, trace

    def accept(
        next_token: str,
        distribution: dict[str, float],
        trace: dict[str, object],
    ) -> None:
        # Commit the token to the decode state and record its annotated trace.
        generated_tokens.append(next_token)
        self._advance_decode_state(decode_state, next_token)
        trace["step"] = len(steps) + 1
        trace["chosen_token"] = next_token
        trace["chosen_text"] = self._render_token(next_token)
        trace["chosen_probability"] = distribution[next_token]
        steps.append(trace)

    for _ in range(max_tokens):
        next_token, distribution, trace = propose()
        if not next_token:
            break
        accept(next_token, distribution, trace)
        if self._should_stop_answer_sequence(decode_state, generated_tokens):
            break
        if self._should_stop_generation(
            generated_tokens
        ) and not self._answer_decode_has_continuation(decode_state, generated_tokens):
            break
    # Word-completion overflow, as in generate_text: up to 6 extra tokens that
    # do not start a new word.
    overflow_budget = 6
    while (
        generated_tokens
        and not self._starts_new_word(generated_tokens[-1])
        and overflow_budget > 0
    ):
        next_token, distribution, trace = propose()
        if not next_token or self._starts_new_word(next_token):
            break
        accept(next_token, distribution, trace)
        overflow_budget -= 1
    return {
        "context": context,
        "prompt": prompt,
        "reasoning_mode": active_mode,
        "reasoning_tokens": reasoning_prefix(active_mode),
        "generation_policy": generation_policy,
        "prompt_tokens": prompt_tokens,
        "generated_tokens": generated_tokens,
        "generated_text": self._decode_tokens(generated_tokens),
        "generated_token_count": len(generated_tokens),
        "steps": steps,
    }
def _generation_prompt_tokens(self, context: str, active_mode: str) -> tuple[str, list[str]]:
    """Build the generation prompt text and its full token sequence.

    `` <answer>`` is appended to ``context`` unless it already contains it.
    A ``<reason>`` token is prepended to the encoded prompt only when the
    prompt encodes an ``<answer>`` token and neither the prompt tokens nor
    the reasoning prefix already contain ``<reason>``.

    Returns:
        ``(prompt_text, reasoning_prefix + prompt_tokens)``.
    """
    assert self.tokenizer is not None
    if "<answer>" in context:
        prompt = context
    else:
        prompt = f"{context} <answer>"
    prefix = reasoning_prefix(active_mode)
    encoded = self.tokenizer.encode(prompt)
    needs_reason_marker = (
        "<answer>" in encoded
        and "<reason>" not in encoded
        and "<reason>" not in prefix
    )
    if needs_reason_marker:
        encoded = ["<reason>"] + encoded
    return prompt, prefix + encoded
def _predict_next_token_distribution_from_tokens(
    self,
    context_tokens: list[str],
) -> dict[str, float]:
    """Build a fresh decode state for ``context_tokens`` and return its next-token distribution."""
    return self._predict_next_token_distribution_from_state(
        self._build_decode_state(context_tokens)
    )
def _predict_next_token_distribution_from_state(
    self,
    decode_state: DecodeState,
) -> dict[str, float]:
    """Return only the probability mapping from scoring ``decode_state`` (no trace)."""
    return self._score_next_token_from_state(
        decode_state,
        include_trace=False,
    )[0]
@staticmethod
def _answer_sequence_should_lock(
    *,
    answer_sequence_confidence: float,
    answer_sequence_match_confidence: float,
    has_answer_sequence_prior: bool,
) -> bool:
    """Decide whether decoding should lock onto the memorised answer sequence.

    Never locks without a positive answer-sequence prior. Otherwise it locks
    in either of two cases:

    * the match score reaches ``ANSWER_SEQUENCE_LOCK_FLOOR``;
    * the match score reaches ``ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR``
      while the prior's peak stays at or below
      ``ANSWER_SEQUENCE_SPIKE_CONFIDENCE``.
    """
    if not has_answer_sequence_prior or answer_sequence_confidence <= 0.0:
        return False
    if answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR:
        return True
    within_distributed_band = (
        answer_sequence_match_confidence >= ANSWER_SEQUENCE_DISTRIBUTED_LOCK_FLOOR
        and answer_sequence_confidence <= ANSWER_SEQUENCE_SPIKE_CONFIDENCE
    )
    return within_distributed_band
@staticmethod
def _answer_start_blend_weights(
    *,
    answer_sequence_match_confidence: float,
) -> dict[str, float]:
    """Return the mixing weights for the answer-start score sources.

    A match score at or above ``ANSWER_SEQUENCE_LOCK_FLOOR`` shifts weight
    toward the answer-sequence source; below it, the prompt-conditioned
    answer-start source dominates. Both weight sets sum to 1.0, and
    ``answer_start`` is 0.10 in both.
    """
    if answer_sequence_match_confidence >= ANSWER_SEQUENCE_LOCK_FLOOR:
        start_share, prompt_share, sequence_share = 0.35, 0.10, 0.45
    else:
        start_share, prompt_share, sequence_share = 0.55, 0.20, 0.15
    return {
        "prompt_answer_start": start_share,
        "prompt_answer": prompt_share,
        "answer_sequence": sequence_share,
        "answer_start": 0.10,
    }
def _score_next_token_from_tokens(
    self,
    context_tokens: list[str],
    *,
    top_k: int = 5,
    include_trace: bool = True,
) -> tuple[dict[str, float], dict[str, object]]:
    """Score the next token for ``context_tokens`` from a freshly built decode state.

    Returns the ``(distribution, trace)`` pair produced by
    ``_score_next_token_from_state``.
    """
    fresh_state = self._build_decode_state(context_tokens)
    return self._score_next_token_from_state(
        fresh_state,
        top_k=top_k,
        include_trace=include_trace,
    )
def _score_next_token_from_state(
    self,
    decode_state: DecodeState,
    *,
    top_k: int = 5,
    include_trace: bool = True,
    generated_tokens: list[str] | None = None,
) -> tuple[dict[str, float], dict[str, object]]:
    """Score every vocabulary token as the next token for ``decode_state``.

    The calibrated readout distribution of the masked state is combined with
    answer-memory priors (answer, answer-start, answer-sequence and prompt
    readout priors) and associative, transition, copy and preference priors
    through ``_blend_probabilities``.

    Retrieved matches and prompt readout priors are cached on
    ``decode_state``; each cache field is filled only while it is ``None``.

    Args:
        decode_state: State to score; its cache fields may be filled in.
        top_k: Entries reported per trace list; also widens the retrieval
            limits when ``include_trace`` is true.
        include_trace: When false, an empty dict is returned as the trace.
        generated_tokens: Answer tokens emitted so far; ``None`` means none.

    Returns:
        ``(distribution, trace)``: ``distribution`` maps each token of
        ``embedding_model.id_to_token`` to its blended probability.
    """
    assert self.embedding_model is not None
    assert self.readout_weights is not None
    generated_tokens = generated_tokens or []
    # Base distribution: calibrated softmax over the readout of the masked state.
    state = self._masked_decode_state(decode_state)
    logits = self._apply_readout_fast(state)
    base_probabilities = self._calibrated_softmax(logits)
    # Retrieval results are cached on decode_state and reused across calls.
    # NOTE(review): the limit comes from whichever call fills the cache first,
    # so a later call with a larger top_k reuses the shorter cached list.
    if decode_state.answer_matches is None:
        decode_state.answer_matches = self._score_answer_matches(
            decode_state.answer_anchor_state,
            limit=max(ANSWER_TOP_K, top_k) if include_trace else ANSWER_TOP_K,
        )
    answer_matches = decode_state.answer_matches
    if decode_state.answer_start_matches is None:
        decode_state.answer_start_matches = self._score_answer_start_matches(
            decode_state.answer_anchor_state,
            limit=max(ANSWER_START_TOP_K, top_k) if include_trace else ANSWER_START_TOP_K,
        )
    answer_start_matches = decode_state.answer_start_matches
    if decode_state.answer_sequence_matches is None:
        decode_state.answer_sequence_matches = self._score_answer_sequence_matches(
            decode_state.answer_anchor_state,
            decode_state.context_tokens,
            limit=max(ANSWER_START_TOP_K, top_k) if include_trace else ANSWER_START_TOP_K,
        )
    answer_sequence_matches = decode_state.answer_sequence_matches
    answer_prior = self._answer_prior_from_matches(answer_matches, generated_tokens)
    answer_start_prior = self._answer_prior_from_matches(answer_start_matches, generated_tokens)
    answer_sequence_prior = self._answer_sequence_prior_from_matches(
        answer_sequence_matches,
        generated_tokens,
    )
    # Peak of the sequence prior, and the similarity of the first sequence
    # match (presumably matches are ordered best-first — TODO confirm).
    answer_sequence_confidence = max(answer_sequence_prior) if answer_sequence_prior else 0.0
    answer_sequence_match_confidence = (
        answer_sequence_matches[0][0] if answer_sequence_matches else 0.0
    )
    has_answer_sequence_prior = any(value > 0.0 for value in answer_sequence_prior)
    answer_locked = self._answer_sequence_should_lock(
        answer_sequence_confidence=answer_sequence_confidence,
        answer_sequence_match_confidence=answer_sequence_match_confidence,
        has_answer_sequence_prior=has_answer_sequence_prior,
    )
    if decode_state.prompt_answer_prior is None:
        decode_state.prompt_answer_prior = self._prompt_answer_readout_prior(
            decode_state.answer_anchor_state,
            start=False,
        )
    prompt_answer_prior = decode_state.prompt_answer_prior
    # The answer-start readout prior applies only before any answer token has
    # been generated; afterwards it is replaced by an all-zero vector.
    prompt_answer_start_prior = (
        decode_state.prompt_answer_start_prior
        if not generated_tokens
        else [0.0 for _ in self.embedding_model.id_to_token]
    )
    if not generated_tokens and prompt_answer_start_prior is None:
        decode_state.prompt_answer_start_prior = self._prompt_answer_readout_prior(
            decode_state.answer_anchor_state,
            start=True,
        )
        prompt_answer_start_prior = decode_state.prompt_answer_start_prior
    # NOTE(review): use_answer_start is computed regardless of answer_locked,
    # whereas _score_next_token_array_from_state leaves it False when locked.
    # On a locked first step, this path zeroes the associative prior and
    # passes answer_guided_start=True, and the array path does neither.
    # Confirm which behavior is intended.
    use_answer_start = (
        not generated_tokens
        and (
            any(value > 0.0 for value in answer_start_prior)
            or any(value > 0.0 for value in prompt_answer_start_prior)
        )
    )
    # Pick the answer prior. A locked sequence wins outright, then the
    # first-token blend, then blends with the sequence or prompt readout
    # priors. Otherwise the raw answer-match prior is kept.
    if answer_locked:
        answer_prior = answer_sequence_prior
    elif use_answer_start:
        start_blend = self._answer_start_blend_weights(
            answer_sequence_match_confidence=answer_sequence_match_confidence
        )
        answer_prior = self._weighted_prior_sum(
            [
                (start_blend["prompt_answer_start"], prompt_answer_start_prior),
                (start_blend["prompt_answer"], prompt_answer_prior),
                (start_blend["answer_sequence"], answer_sequence_prior),
                (start_blend["answer_start"], answer_start_prior),
            ],
        )
    elif any(value > 0.0 for value in answer_sequence_prior):
        answer_prior = self._weighted_prior_sum(
            [
                (0.50, prompt_answer_prior),
                (0.30, answer_sequence_prior),
                (0.20, answer_prior),
            ],
        )
    elif any(value > 0.0 for value in prompt_answer_prior):
        answer_prior = self._weighted_prior_sum(
            [
                (0.65, prompt_answer_prior),
                (0.35, answer_prior),
            ],
        )
    # Associative retrieval is skipped entirely on an answer-guided first step.
    associative_matches = (
        []
        if use_answer_start
        else self._score_associative_matches(
            state,
            limit=max(ASSOCIATIVE_TOP_K, top_k) if include_trace else ASSOCIATIVE_TOP_K,
        )
    )
    associative_prior = (
        [0.0 for _ in self.embedding_model.id_to_token]
        if use_answer_start
        else self._associative_prior_from_matches(associative_matches)
    )
    transition_prior, transition_order = self._transition_prior_with_order(decode_state.context_tokens)
    copy_prior = self._copy_prior(decode_state.context_tokens)
    preference_prior = self._preference_prior()
    probabilities, blend_weights = self._blend_probabilities(
        base_probabilities,
        answer_prior,
        associative_prior,
        transition_prior,
        copy_prior,
        preference_prior,
        transition_order=transition_order,
        generated_count=len(generated_tokens),
        answer_locked=answer_locked,
        answer_guided_start=use_answer_start,
    )
    distribution = {
        token: probabilities[index]
        for index, token in enumerate(self.embedding_model.id_to_token)
    }
    if not include_trace:
        return distribution, {}
    # Diagnostic trace: top-k view of every component prior plus raw matches.
    trace = {
        "state_norm": norm(state),
        "blend_weights": blend_weights,
        "transition_order": transition_order,
        "base_top_predictions": self._top_entries_from_vector(base_probabilities, top_k),
        "answer_top_predictions": self._top_entries_from_vector(answer_prior, top_k),
        "prompt_answer_top_predictions": self._top_entries_from_vector(prompt_answer_prior, top_k),
        "prompt_answer_start_top_predictions": self._top_entries_from_vector(prompt_answer_start_prior, top_k),
        "answer_start_top_predictions": self._top_entries_from_vector(answer_start_prior, top_k),
        "answer_sequence_top_predictions": self._top_entries_from_vector(answer_sequence_prior, top_k),
        "associative_top_predictions": self._top_entries_from_vector(associative_prior, top_k),
        "transition_top_predictions": self._top_entries_from_vector(transition_prior, top_k),
        "copy_top_predictions": self._top_entries_from_vector(copy_prior, top_k),
        "preference_top_predictions": self._top_entries_from_vector(preference_prior, top_k),
        "final_top_predictions": self._top_entries_from_vector(probabilities, top_k),
        "associative_matches": [
            {
                "example_index": example_index,
                "similarity": similarity,
                **self._token_entry(token_id, similarity),
            }
            for similarity, token_id, example_index in associative_matches[:top_k]
        ],
        "answer_matches": [
            {
                "example_index": example_index,
                "similarity": similarity,
                **self._token_entry(token_id, similarity),
            }
            for similarity, token_id, example_index in answer_matches[:top_k]
        ],
        "answer_start_matches": [
            {
                "example_index": example_index,
                "similarity": similarity,
                **self._token_entry(token_id, similarity),
            }
            for similarity, token_id, example_index in answer_start_matches[:top_k]
        ],
        "answer_sequence_matches": [
            {
                "example_index": example_index,
                "similarity": similarity,
            }
            for similarity, _, example_index in answer_sequence_matches[:top_k]
        ],
        "reasoning_summary": self._build_reasoning_summary(
            transition_order,
            blend_weights,
        ),
    }
    return distribution, trace
def _score_next_token_array_from_state(
    self,
    decode_state: DecodeState,
    *,
    include_associative: bool,
    generated_tokens: list[str] | None = None,
) -> tuple[object, dict[str, float]]:
    """NumPy counterpart of ``_score_next_token_from_state`` with no trace output.

    Builds the same family of priors as float64 arrays and returns the
    result of ``_blend_probability_arrays`` (annotated as
    ``tuple[object, dict[str, float]]``). Retrieval uses the helpers' default
    limits. Matches and prompt readout priors are cached on ``decode_state``
    while their fields are ``None``.

    Args:
        decode_state: State to score; its cache fields may be filled in.
        include_associative: Whether to compute the associative prior. It is
            also skipped when the answer-start blend is used.
        generated_tokens: Answer tokens emitted so far; ``None`` means none.
    """
    assert np is not None
    assert self.embedding_model is not None
    generated_tokens = generated_tokens or []
    # Base distribution: calibrated softmax over the readout of the masked state.
    state = self._masked_decode_state_array(decode_state)
    logits = self._apply_readout_array(state)
    base_probabilities = self._calibrated_softmax_array(logits)
    if decode_state.answer_matches is None:
        decode_state.answer_matches = self._score_answer_matches(decode_state.answer_anchor_state)
    answer_prior = np.asarray(
        self._answer_prior_from_matches(
            decode_state.answer_matches,
            generated_tokens,
        ),
        dtype=np.float64,
    )
    if decode_state.answer_sequence_matches is None:
        decode_state.answer_sequence_matches = self._score_answer_sequence_matches(
            decode_state.answer_anchor_state,
            decode_state.context_tokens,
        )
    answer_sequence_matches = decode_state.answer_sequence_matches
    answer_sequence_prior = np.asarray(
        self._answer_sequence_prior_from_matches(
            answer_sequence_matches,
            generated_tokens,
        ),
        dtype=np.float64,
    )
    # Peak of the sequence prior, and the similarity of the first sequence
    # match (presumably matches are ordered best-first — TODO confirm).
    answer_sequence_confidence = (
        float(answer_sequence_prior.max()) if answer_sequence_prior.size else 0.0
    )
    answer_sequence_match_confidence = (
        answer_sequence_matches[0][0] if answer_sequence_matches else 0.0
    )
    has_answer_sequence_prior = bool(np.any(answer_sequence_prior > 0.0))
    answer_locked = self._answer_sequence_should_lock(
        answer_sequence_confidence=answer_sequence_confidence,
        answer_sequence_match_confidence=answer_sequence_match_confidence,
        has_answer_sequence_prior=has_answer_sequence_prior,
    )
    if decode_state.prompt_answer_prior is None:
        decode_state.prompt_answer_prior = self._prompt_answer_readout_prior_array(
            decode_state.answer_anchor_state,
            start=False,
        )
    prompt_answer_prior = decode_state.prompt_answer_prior
    prompt_answer_start_prior = np.zeros_like(base_probabilities)
    use_answer_start = False
    # Start-of-answer priors are only built when the sequence is not locked
    # and no answer token has been generated yet.
    if answer_locked:
        answer_prior = answer_sequence_prior
    elif not generated_tokens:
        if decode_state.prompt_answer_start_prior is None:
            decode_state.prompt_answer_start_prior = self._prompt_answer_readout_prior_array(
                decode_state.answer_anchor_state,
                start=True,
            )
        prompt_answer_start_prior = decode_state.prompt_answer_start_prior
        if decode_state.answer_start_matches is None:
            decode_state.answer_start_matches = self._score_answer_start_matches(
                decode_state.answer_anchor_state
            )
        answer_start_prior = np.asarray(
            self._answer_prior_from_matches(
                decode_state.answer_start_matches,
                generated_tokens,
            ),
            dtype=np.float64,
        )
        if np.any(answer_start_prior > 0.0) or np.any(prompt_answer_start_prior > 0.0):
            start_blend = self._answer_start_blend_weights(
                answer_sequence_match_confidence=answer_sequence_match_confidence
            )
            answer_prior = self._weighted_prior_sum_array(
                [
                    (start_blend["prompt_answer_start"], prompt_answer_start_prior),
                    (start_blend["prompt_answer"], prompt_answer_prior),
                    (start_blend["answer_sequence"], answer_sequence_prior),
                    (start_blend["answer_start"], answer_start_prior),
                ],
            )
            use_answer_start = True
    # Re-applies the locked choice made above; it has no further effect.
    if answer_locked:
        answer_prior = answer_sequence_prior
    elif not use_answer_start and np.any(answer_sequence_prior > 0.0):
        answer_prior = self._weighted_prior_sum_array(
            [
                (0.50, prompt_answer_prior),
                (0.30, answer_sequence_prior),
                (0.20, answer_prior),
            ],
        )
    elif not use_answer_start and np.any(prompt_answer_prior > 0.0):
        answer_prior = self._weighted_prior_sum_array(
            [
                (0.65, prompt_answer_prior),
                (0.35, answer_prior),
            ],
        )
    if include_associative and not use_answer_start:
        associative_prior = np.asarray(
            self._associative_prior_from_matches(
                self._score_associative_matches(state)
            ),
            dtype=np.float64,
        )
    else:
        associative_prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
    transition_prior, transition_order = self._transition_prior_array_with_order(
        decode_state.context_tokens
    )
    copy_prior = self._copy_prior_array(decode_state.context_tokens)
    preference_prior = self._preference_prior_array()
    return self._blend_probability_arrays(
        base_probabilities,
        answer_prior,
        associative_prior,
        transition_prior,
        copy_prior,
        preference_prior,
        transition_order=transition_order,
        generated_count=len(generated_tokens),
        answer_locked=answer_locked,
        answer_guided_start=use_answer_start,
    )
def _calibrated_softmax(
    self,
    logits: Vector,
    *,
    scale: float = READOUT_LOGIT_ZSCORE_SCALE,
) -> Vector:
    """Softmax over z-scored logits multiplied by ``scale``.

    Scaled z-scores are clipped to [-20, 20]. Logits whose standard deviation
    is at most 1e-12 fall back to a plain softmax. With NumPy available the
    array implementation is used, and either way a Python list is returned.
    """
    if np is not None:
        logit_array = np.asarray(logits, dtype=np.float64)
        return self._calibrated_softmax_array(logit_array, scale=scale).tolist()
    if not logits:
        return []
    center = mean(logits)
    deviations = [value - center for value in logits]
    spread = mean([delta * delta for delta in deviations]) ** 0.5
    if spread <= 1e-12:
        return softmax(logits)
    clipped: list[float] = []
    for delta in deviations:
        z_score = delta / spread * scale
        clipped.append(max(-20.0, min(20.0, z_score)))
    return softmax(clipped)
def _calibrated_softmax_array(
    self,
    logits: object,
    *,
    scale: float = READOUT_LOGIT_ZSCORE_SCALE,
) -> object:
    """NumPy softmax over z-scored logits multiplied by ``scale``.

    Scaled z-scores are clipped to [-20, 20]. When the standard deviation is
    at most 1e-12, the raw logits are only max-shifted. An empty input is
    returned as-is. If the exponentials sum to a non-positive total, a
    uniform distribution is returned.
    """
    assert np is not None
    scores = np.asarray(logits, dtype=np.float64)
    if not scores.size:
        return scores
    spread = float(scores.std())
    if spread > 1e-12:
        centered = scores - float(scores.mean())
        scores = np.clip(centered / spread * scale, -20.0, 20.0)
    else:
        scores = scores - float(scores.max())
    shifted = scores - float(scores.max())
    weights = np.exp(shifted)
    weight_total = float(weights.sum())
    if weight_total <= 0.0:
        return np.full(shifted.shape, 1.0 / max(1, shifted.size), dtype=np.float64)
    return weights / weight_total
def _weighted_prior_sum(self, sources: list[tuple[float, Vector]]) -> Vector:
    """Blend list-valued priors with their weights and normalize the result.

    Sources with a non-positive weight, or with no positive entry, are
    ignored. The remaining weights are rescaled to sum to 1 and the merged
    vector is passed through ``_normalize_vector``. If nothing contributes,
    an all-zero vocabulary-sized list is returned.
    """
    assert self.embedding_model is not None
    vocab_size = len(self.embedding_model.id_to_token)
    contributing = [
        (weight, vector)
        for weight, vector in sources
        if weight > 0.0 and any(value > 0.0 for value in vector)
    ]
    if not contributing:
        return [0.0] * vocab_size
    weight_total = sum(weight for weight, _ in contributing)
    blended = [0.0] * vocab_size
    for weight, vector in contributing:
        share = weight / weight_total
        for position, value in enumerate(vector):
            blended[position] += share * value
    return _normalize_vector(blended)
def _weighted_prior_sum_array(self, sources: list[tuple[float, object]]) -> object:
assert np is not None
assert self.embedding_model is not None
active_sources = [
(weight, np.asarray(vector, dtype=np.float64))
for weight, vector in sources
if weight > 0.0 and np.any(np.asarray(vector, dtype=np.float64) > 0.0)
]
if not active_sources:
return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
total_weight = sum(weight for weight, _ in active_sources)
merged = np.zeros_like(active_sources[0][1], dtype=np.float64)
for weight, vector in active_sources:
merged += (weight / total_weight) * vector
total = float(merged.sum())
if total > 0.0:
merged /= total
return merged
def _prompt_answer_readout_prior(
    self,
    answer_anchor_state: Vector | None,
    *,
    start: bool,
) -> Vector:
    """Distribution from the prompt-answer readout applied to the anchor state.

    ``start`` selects the answer-start weights and bias instead of the
    general prompt-answer ones. Without an anchor state, or without weights
    on the pure-Python path, an all-zero vocabulary-sized list is returned.
    With NumPy available, the array implementation is used and converted
    to a list.
    """
    assert self.embedding_model is not None
    vocabulary = self.embedding_model.id_to_token
    if answer_anchor_state is None:
        return [0.0 for _ in vocabulary]
    if start:
        weights = self.prompt_answer_start_weights
        bias = self.prompt_answer_start_bias
    else:
        weights = self.prompt_answer_weights
        bias = self.prompt_answer_bias
    if np is not None:
        prior_array = self._prompt_answer_readout_prior_array(
            answer_anchor_state,
            start=start,
        )
        return prior_array.tolist()
    if not weights:
        return [0.0 for _ in vocabulary]
    centered = self._center_state_vector(self._masked_combined_state(answer_anchor_state))
    logits = apply_readout(weights, centered)
    if bias:
        logits = [logit + bias[position] for position, logit in enumerate(logits)]
    return self._calibrated_softmax(logits, scale=PROMPT_READOUT_LOGIT_ZSCORE_SCALE)
def _prompt_answer_readout_prior_array(
    self,
    answer_anchor_state: Vector | None,
    *,
    start: bool,
) -> object:
    """NumPy version of ``_prompt_answer_readout_prior``.

    Returns a float64 zero vector when there is no anchor state or the
    selected weight array is missing. The bias is added only when its shape
    matches the logits.
    """
    assert np is not None
    assert self.embedding_model is not None
    if answer_anchor_state is None:
        return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
    if start:
        weights = self.prompt_answer_start_weights_array
        bias = self.prompt_answer_start_bias_array
    else:
        weights = self.prompt_answer_weights_array
        bias = self.prompt_answer_bias_array
    if weights is None:
        return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
    masked_anchor = self._masked_combined_state_array(answer_anchor_state)
    logits = weights @ self._center_state_array(masked_anchor)
    bias_fits = bias is not None and bias.shape == logits.shape
    if bias_fits:
        logits = logits + bias
    return self._calibrated_softmax_array(logits, scale=PROMPT_READOUT_LOGIT_ZSCORE_SCALE)
def save(self, path: str | Path) -> None:
    """Write the fitted model to ``path`` with ``write_safetensor_file``.

    Structured data (config, tokenizer, vocabulary, transition tables) goes
    into string metadata as compact JSON. Numeric state goes into tensors.
    Optional vectors that are unset are written as filled defaults (zeros,
    or ones for ``trace_token_weights``). Optional matrices that are
    ``None`` are written as empty lists.
    """
    self._require_fit()
    assert self.tokenizer is not None
    assert self.embedding_model is not None
    assert self.ternary_mask is not None
    assert self.readout_weights is not None
    assert self.associative_keys is not None
    assert self.associative_values is not None
    assert self.transition_tables is not None
    vocabulary = self.embedding_model.id_to_token

    def compact_json(payload: object) -> str:
        return json.dumps(payload, separators=(",", ":"))

    def filled(fill_value: float) -> list[float]:
        # A fresh list per call so no two tensors share one list object.
        return [fill_value for _ in vocabulary]

    def or_empty(value: object) -> object:
        return value if value is not None else []

    metadata = {
        "schema_version": "1",
        "checkpoint_kind": "reframr-analytical",
        "tokenizer_name": self.tokenizer.name,
        "config": compact_json(self.config.to_dict()),
        "tokenizer": compact_json(self.tokenizer.to_dict()),
        "embedding_id_to_token": compact_json(vocabulary),
        "tokenizer_vocab_size": str(self.tokenizer.vocab_size),
        "transition_tables": compact_json(self._serialize_transition_tables()),
    }
    tensors = {
        "embedding_table": self.embedding_model.embeddings,
        "ternary_scale": [self.ternary_scale],
        "ternary_mask": self.ternary_mask,
        "readout_weights": self.readout_weights,
        "readout_bias": self.readout_bias or filled(0.0),
        "prompt_answer_weights": or_empty(self.prompt_answer_weights),
        "prompt_answer_bias": self.prompt_answer_bias or filled(0.0),
        "prompt_answer_start_weights": or_empty(self.prompt_answer_start_weights),
        "prompt_answer_start_bias": self.prompt_answer_start_bias or filled(0.0),
        "trace_token_weights": self.trace_token_weights or filled(1.0),
        "preference_bias": self.preference_bias or filled(0.0),
        "state_offset": self.state_offset
        or [0.0 for _ in range(self._combined_state_width())],
        "associative_keys": self.associative_keys,
        "associative_values": self.associative_values,
        "answer_keys": or_empty(self.answer_keys),
        "answer_values": or_empty(self.answer_values),
        "answer_start_keys": or_empty(self.answer_start_keys),
        "answer_start_values": or_empty(self.answer_start_values),
        "answer_sequence_keys": or_empty(self.answer_sequence_keys),
        "answer_sequence_prompt_tokens": or_empty(self.answer_sequence_prompt_tokens),
        "answer_sequence_tokens": or_empty(self.answer_sequence_tokens),
    }
    write_safetensor_file(path, tensors, metadata=metadata)
@classmethod
def load(cls, path: str | Path) -> "ReframrModel":
    """Rebuild a model from a checkpoint written by ``save``.

    Array-backed tensors are requested (``arrays=True``) when NumPy is
    available and the file is larger than 10,000,000 bytes. Otherwise
    tensors are treated as nested Python sequences. Optional tensors missing
    from the checkpoint fall back to empty or filled defaults.

    Args:
        path: Checkpoint file path.

    Returns:
        The reconstructed model with numeric caches refreshed.
    """
    checkpoint_path = Path(path)
    checkpoint = read_safetensor_file(
        checkpoint_path,
        arrays=np is not None and checkpoint_path.stat().st_size > 10_000_000,
    )
    metadata = checkpoint.metadata
    tensors = checkpoint.tensors
    config = ReframrConfig.from_dict(json.loads(metadata["config"]))
    model = cls(config)
    model.tokenizer = NativeTokenizer.from_dict(json.loads(metadata["tokenizer"]))
    id_to_token = [str(token) for token in json.loads(metadata["embedding_id_to_token"])]

    def vocab_zeros() -> list[float]:
        # Fresh list per call so per-attribute defaults never share storage.
        return [0.0 for _ in id_to_token]

    model.embedding_model = EmbeddingModel(
        token_to_id={token: index for index, token in enumerate(id_to_token)},
        id_to_token=id_to_token,
        embeddings=cls._checkpoint_float_matrix(tensors["embedding_table"]),
        ppmi_matrix=[],
    )
    model.memory_units = [
        AnalyticalMemoryUnit(model.config.state_dim, timescale)
        for timescale in model.config.timescales
    ]
    model.ternary_scale = float(tensors["ternary_scale"][0])
    model.ternary_mask = [int(value) for value in tensors["ternary_mask"]]
    model.readout_weights = cls._checkpoint_float_matrix(tensors["readout_weights"])
    model.readout_bias = (
        cls._checkpoint_float_list(tensors.get("readout_bias", [])) or vocab_zeros()
    )
    model.prompt_answer_weights = cls._checkpoint_float_matrix(
        tensors.get("prompt_answer_weights", []),
        require_2d=True,
    )
    model.prompt_answer_bias = (
        cls._checkpoint_float_list(tensors.get("prompt_answer_bias", [])) or vocab_zeros()
    )
    model.prompt_answer_start_weights = cls._checkpoint_float_matrix(
        tensors.get("prompt_answer_start_weights", []),
        require_2d=True,
    )
    model.prompt_answer_start_bias = (
        cls._checkpoint_float_list(tensors.get("prompt_answer_start_bias", []))
        or vocab_zeros()
    )
    model.trace_token_weights = cls._checkpoint_float_list(
        tensors.get("trace_token_weights", [])
    ) or [
        0.0 if token in model.tokenizer.special_tokens else 1.0
        for token in id_to_token
    ]
    model.preference_bias = (
        cls._checkpoint_float_list(tensors.get("preference_bias", [])) or vocab_zeros()
    )
    model.state_offset = cls._checkpoint_float_list(tensors.get("state_offset", [])) or [
        0.0 for _ in range(model._combined_state_width())
    ]
    model.associative_keys = cls._checkpoint_float_matrix(tensors.get("associative_keys", []))
    # Uses the same rank-2 guard as the answer-key tables: a 1-D tensor (e.g.
    # an empty key list) previously reached np.linalg.norm(axis=1) and raised.
    model.associative_key_norms = cls._checkpoint_row_norms(model.associative_keys)
    model.associative_values = cls._checkpoint_int_list(tensors.get("associative_values", []))
    model.answer_keys = cls._checkpoint_key_matrix(tensors.get("answer_keys", []))
    model.answer_key_norms = cls._checkpoint_row_norms(model.answer_keys)
    model.answer_values = cls._checkpoint_int_list(tensors.get("answer_values", []))
    model.answer_start_keys = cls._checkpoint_key_matrix(tensors.get("answer_start_keys", []))
    model.answer_start_key_norms = cls._checkpoint_row_norms(model.answer_start_keys)
    model.answer_start_values = cls._checkpoint_int_list(
        tensors.get("answer_start_values", [])
    )
    model.answer_sequence_keys = cls._checkpoint_key_matrix(
        tensors.get("answer_sequence_keys", [])
    )
    model.answer_sequence_key_norms = cls._checkpoint_row_norms(model.answer_sequence_keys)
    model.answer_sequence_prompt_tokens = cls._checkpoint_int_matrix(
        tensors.get("answer_sequence_prompt_tokens", [])
    )
    model.answer_sequence_tokens = cls._checkpoint_int_matrix(
        tensors.get("answer_sequence_tokens", [])
    )
    model.transition_tables = model._deserialize_transition_tables(
        json.loads(metadata.get("transition_tables", "{}"))
    )
    model._refresh_numeric_caches()
    return model

@staticmethod
def _checkpoint_float_list(tensor: object) -> list[float]:
    """Flatten a 1-D checkpoint tensor (array via ``tolist`` or sequence) to floats."""
    values = tensor.tolist() if hasattr(tensor, "tolist") else tensor
    return [float(value) for value in values]

@staticmethod
def _checkpoint_int_list(tensor: object) -> list[int]:
    """Flatten a 1-D checkpoint tensor (array via ``tolist`` or sequence) to ints."""
    values = tensor.tolist() if hasattr(tensor, "tolist") else tensor
    return [int(value) for value in values]

@staticmethod
def _checkpoint_float_matrix(tensor: object, *, require_2d: bool = False) -> object:
    """Return an array-backed tensor cast to float, else nested float lists.

    With ``require_2d``, array-backed tensors that are not rank 2 take the
    nested-list path instead of being cast.
    """
    if (
        np is not None
        and hasattr(tensor, "shape")
        and (not require_2d or len(tensor.shape) == 2)
    ):
        return tensor.astype(float, copy=False)
    return [[float(value) for value in row] for row in tensor]

@staticmethod
def _checkpoint_key_matrix(tensor: object) -> object:
    """Load a retrieval key table; array-backed tensors that are not rank 2 become ``[]``."""
    if np is not None and hasattr(tensor, "shape"):
        return tensor.astype(float, copy=False) if len(tensor.shape) == 2 else []
    return [[float(value) for value in row] for row in tensor]

@staticmethod
def _checkpoint_int_matrix(tensor: object) -> object:
    """Return an array-backed tensor cast to int, else nested int lists."""
    if np is not None and hasattr(tensor, "shape"):
        return tensor.astype(int, copy=False)
    return [[int(value) for value in row] for row in tensor]

@staticmethod
def _checkpoint_row_norms(keys: object) -> list[float]:
    """Per-row norms: vectorized for rank-2 arrays, ``norm`` per row otherwise."""
    if np is not None and hasattr(keys, "shape") and len(keys.shape) == 2:
        return np.linalg.norm(keys, axis=1).tolist()
    return [norm(key) for key in keys]
def _collect_training_examples(
    self,
    tokens: list[str],
) -> tuple[list[Vector], list[Vector], list[int]]:
    """Run ``tokens`` through the recurrent state and collect readout examples.

    Every position except the last yields a triple: the combined state after
    that token, the one-hot label of the following token, and the following
    token's id. Tokens missing from the vocabulary get id -1 and a zero
    embedding.

    When ``config.max_training_examples`` is set and smaller than the number
    of positions, positions are subsampled with a ceil-division stride (the
    final position is also kept), and the lists are then truncated to the cap.

    NOTE(review): with a stride the sampled count can exceed the cap by one
    (10 positions, cap 3 -> stride 4 keeps 0, 4, 8, 9). The truncation then
    drops the final position that was explicitly kept. Confirm whether that
    is intended.
    """
    assert self.embedding_model is not None
    config = self.config
    if np is not None:
        hidden_states = [np.zeros(config.state_dim, dtype=np.float64) for _ in config.timescales]
        context_traces = [
            np.zeros(config.embedding_dim, dtype=np.float64) for _ in config.timescales
        ]
        zero_embedding: Vector | object = np.zeros(config.embedding_dim, dtype=np.float64)
    else:
        hidden_states = [zeros_vector(config.state_dim) for _ in config.timescales]
        context_traces = [zeros_vector(config.embedding_dim) for _ in config.timescales]
        zero_embedding = zeros_vector(config.embedding_dim)
    vocabulary = self.embedding_model.token_to_id
    embedding_rows = self.embedding_model.embeddings
    token_ids = [vocabulary.get(token, -1) for token in tokens]
    cap = config.max_training_examples
    example_count = max(0, len(tokens) - 1)
    if cap and example_count > cap:
        stride = max(1, (example_count + cap - 1) // cap)
    else:
        stride = 1
    final_position = len(tokens) - 2
    states: list[Vector] = []
    labels: list[Vector] = []
    label_ids: list[int] = []
    for position, (current_id, next_id) in enumerate(zip(token_ids, token_ids[1:])):
        embedding = embedding_rows[current_id] if current_id >= 0 else zero_embedding
        trace_embedding = self._trace_embedding_from_token_id(embedding, current_id)
        hidden_states, context_traces, combined_state = self._step_hidden_states_from_embedding(
            hidden_states,
            context_traces,
            embedding,
            trace_embedding=trace_embedding,
        )
        # The state is always advanced; only the recording is subsampled.
        keep = stride <= 1 or position % stride == 0 or position == final_position
        if not keep:
            continue
        states.append(combined_state)
        labels.append(self._one_hot_from_id(next_id))
        label_ids.append(next_id)
    if cap and len(states) > cap:
        del states[cap:]
        del labels[cap:]
        del label_ids[cap:]
    return states, labels, label_ids
def _is_punctuation_piece(self, piece: str) -> bool:
return bool(piece) and all(character in string.punctuation for character in piece)
def _encode_context(self, tokens: list[str]) -> Vector:
    """Replay ``tokens`` and return the ternary-masked combined state."""
    replayed = self._build_decode_state(tokens)
    return self._masked_decode_state(replayed)
def _build_decode_state(self, tokens: list[str]) -> DecodeState:
    """Create a zeroed decode state and advance it through ``tokens`` in order.

    One hidden state (``state_dim`` wide) and one context trace
    (``embedding_dim`` wide) are allocated per configured timescale, as NumPy
    float64 arrays when NumPy is available and as Python lists otherwise.
    """
    assert self.memory_units is not None
    timescales = self.config.timescales
    if np is not None:
        hidden_states = [np.zeros(self.config.state_dim, dtype=np.float64) for _ in timescales]
        context_traces = [
            np.zeros(self.config.embedding_dim, dtype=np.float64) for _ in timescales
        ]
    else:
        hidden_states = [zeros_vector(self.config.state_dim) for _ in timescales]
        context_traces = [zeros_vector(self.config.embedding_dim) for _ in timescales]
    decode_state = DecodeState(
        hidden_states=hidden_states,
        context_traces=context_traces,
        combined_state=self._zero_combined_state(),
        context_tokens=[],
    )
    for token in tokens:
        self._advance_decode_state(decode_state, token)
    return decode_state
def _advance_decode_state(self, state: DecodeState, token: str) -> DecodeState:
    """Feed one token into ``state`` in place and return the same object.

    On ``<answer>``, a copy of the new combined state becomes the answer
    anchor. Every cached answer match and prompt prior is reset to ``None``
    so the scorers recompute it against that anchor.
    """
    (
        state.hidden_states,
        state.context_traces,
        state.combined_state,
    ) = self._step_hidden_states(state.hidden_states, state.context_traces, token)
    state.context_tokens.append(token)
    if token != "<answer>":
        return state
    combined = state.combined_state
    state.answer_anchor_state = combined.copy() if hasattr(combined, "copy") else combined[:]
    for cache_field in (
        "answer_matches",
        "answer_start_matches",
        "answer_sequence_matches",
        "prompt_answer_prior",
        "prompt_answer_start_prior",
    ):
        setattr(state, cache_field, None)
    return state
def _masked_decode_state(self, state: DecodeState) -> Vector:
    """Return ``state.combined_state`` with the ternary mask and scale applied."""
    assert self.ternary_mask is not None
    return self._masked_combined_state(state.combined_state)
def _masked_combined_state(self, combined_state: Vector) -> Vector:
    """Apply the ternary mask and ``ternary_scale`` to a combined-state vector."""
    mask = self.ternary_mask
    assert mask is not None
    return apply_ternary_mask(combined_state, mask, self.ternary_scale)
def _masked_decode_state_array(self, state: DecodeState) -> object:
    """NumPy version of ``_masked_decode_state``.

    Multiplies by the cached mask array when one exists; otherwise masks
    through the list path and converts the result.
    """
    assert np is not None
    return self._masked_combined_state_array(state.combined_state)
def _masked_combined_state_array(self, combined_state: Vector) -> object:
    """NumPy version of ``_masked_combined_state``.

    When ``ternary_mask_array`` is cached, the state is scaled and multiplied
    by it directly. Otherwise the list implementation is used and its result
    is converted with ``RUNTIME_ARRAY_DTYPE``.
    """
    assert np is not None
    mask_array = self.ternary_mask_array
    if mask_array is not None:
        state_array = np.asarray(combined_state, dtype=RUNTIME_ARRAY_DTYPE)
        return state_array * self.ternary_scale * mask_array
    masked = self._masked_combined_state(combined_state)
    return np.asarray(masked, dtype=RUNTIME_ARRAY_DTYPE)
def _center_state_vector(self, state: Vector) -> Vector:
    """Subtract ``state_offset`` element-wise.

    ``state`` is returned unchanged when no offset is set or its length
    differs from the state's.
    """
    offset = self.state_offset
    if offset and len(offset) == len(state):
        return [value - shift for value, shift in zip(state, offset)]
    return state
def _center_state_array(self, state: object) -> object:
    """NumPy version of ``_center_state_vector``; the offset applies only on an exact shape match."""
    assert np is not None
    state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE)
    offset = self.state_offset_array
    if offset is not None and offset.shape == state_array.shape:
        return state_array - offset
    return state_array
def _zero_combined_state(self) -> Vector:
    """Return a fresh all-zero list as wide as ``_combined_state_width()``."""
    return [0.0] * self._combined_state_width()
def _combined_state_width(self) -> int:
return (self.config.state_dim + self.config.embedding_dim) * len(self.config.timescales)
def _derive_trace_token_weights_from_counts(self, token_counts: dict[str, float]) -> Vector:
    """Derive per-token trace weights that shrink as token frequency grows.

    The reference count is the upper median of the positive counts (1.0 if
    there are none). Special tokens get 0.0 and unseen tokens get 1.0.
    Every other token gets ``(reference / count) ** 0.75`` clamped to
    [0.08, 4.8].
    """
    assert self.embedding_model is not None
    assert self.tokenizer is not None
    vocabulary = self.embedding_model.id_to_token
    counts = [float(token_counts.get(token, 0.0)) for token in vocabulary]
    positive = sorted(count for count in counts if count > 0.0)
    reference = positive[len(positive) // 2] if positive else 1.0
    special_tokens = self.tokenizer.special_tokens

    def weight_for(token: str, count: float) -> float:
        if token in special_tokens:
            return 0.0
        if count <= 0.0:
            return 1.0
        return max(0.08, min(4.8, (reference / count) ** 0.75))

    return [weight_for(token, count) for token, count in zip(vocabulary, counts)]
def _token_id_for_token(self, token: str) -> int:
assert self.embedding_model is not None
token_id = self.embedding_model.token_to_id.get(token)
if token_id is None and token.lower() != token:
token_id = self.embedding_model.token_to_id.get(token.lower())
return int(token_id) if token_id is not None else -1
def _trace_embedding_from_token_id(
self,
embedding: Vector | object,
token_id: int,
) -> Vector | object:
if token_id < 0:
return embedding
if self.trace_embedding_table_array is not None:
return self.trace_embedding_table_array[token_id]
weight = self.trace_token_weights[token_id] if self.trace_token_weights is not None else 1.0
dimension = self.config.embedding_dim
if hasattr(embedding, "shape"):
trace_embedding = embedding * weight
for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES:
bucket = (token_id * bucket_multiplier + bucket_offset) % dimension
sign = 1.0 if ((token_id * sign_multiplier + sign_offset) & 1) == 0 else -1.0
trace_embedding[bucket] += weight * TRACE_IDENTITY_SCALE * sign
return trace_embedding
trace_values = [float(value) * weight for value in embedding]
for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES:
bucket = (token_id * bucket_multiplier + bucket_offset) % dimension
sign = 1.0 if ((token_id * sign_multiplier + sign_offset) & 1) == 0 else -1.0
trace_values[bucket] += weight * TRACE_IDENTITY_SCALE * sign
return trace_values
def _build_trace_embedding_table_array(self, embedding_array: object) -> object | None:
if np is None or self.trace_token_weights is None:
return None
values = np.asarray(embedding_array, dtype=np.float64)
if values.size == 0 or len(values.shape) != 2:
return None
weights = np.asarray(self.trace_token_weights, dtype=np.float64)
if weights.shape[0] != values.shape[0]:
return None
trace_values = values * weights[:, None]
if values.shape[1] <= 0:
return trace_values
token_ids = np.arange(values.shape[0], dtype=np.int64)
for bucket_multiplier, bucket_offset, sign_multiplier, sign_offset in TRACE_IDENTITY_HASHES:
buckets = ((token_ids * bucket_multiplier + bucket_offset) % values.shape[1]).astype(
np.int64,
copy=False,
)
signs = np.where(
((token_ids * sign_multiplier + sign_offset) & 1) == 0,
1.0,
-1.0,
)
np.add.at(trace_values, (token_ids, buckets), weights * TRACE_IDENTITY_SCALE * signs)
return trace_values
    def _refresh_numeric_caches(self) -> None:
        """Rebuild every numpy-side cache from the list-based model attributes.

        Without numpy, every ``*_array`` cache and prompt-overlap cache is set to
        ``None``. Only the pure-Python prompt-overlap cache is then rebuilt.

        With numpy:

        * each list attribute is converted with ``RUNTIME_ARRAY_DTYPE``; token-id
          tables use ``np.int64``;
        * most key/value tables become ``None`` when empty;
        * validity masks mark stored ids that are ``>= 0``;
        * masked "similarity" copies of the answer key tables are derived.

        The prompt-overlap cache is rebuilt at the end in both cases.
        """
        if np is None:
            # Pure-Python mode: drop every array cache.
            self.ternary_mask_array = None
            self.readout_weights_array = None
            self.readout_bias_array = None
            self.prompt_answer_weights_array = None
            self.prompt_answer_bias_array = None
            self.prompt_answer_start_weights_array = None
            self.prompt_answer_start_bias_array = None
            self.trace_token_weights_array = None
            self.trace_embedding_table_array = None
            self.preference_bias_array = None
            self.preference_valid_mask_array = None
            self.state_offset_array = None
            self.associative_keys_array = None
            self.associative_key_norms_array = None
            self.associative_values_array = None
            self.associative_valid_mask_array = None
            self.answer_keys_array = None
            self.answer_key_norms_array = None
            self.answer_similarity_keys_array = None
            self.answer_similarity_key_norms_array = None
            self.answer_similarity_mask_array = None
            self.answer_values_array = None
            self.answer_valid_mask_array = None
            self.answer_start_keys_array = None
            self.answer_start_key_norms_array = None
            self.answer_start_similarity_keys_array = None
            self.answer_start_similarity_key_norms_array = None
            self.answer_start_values_array = None
            self.answer_start_valid_mask_array = None
            self.answer_sequence_keys_array = None
            self.answer_sequence_key_norms_array = None
            self.answer_sequence_similarity_keys_array = None
            self.answer_sequence_similarity_key_norms_array = None
            self.answer_sequence_prompt_tokens_array = None
            self.answer_sequence_tokens_array = None
            # The overlap caches below are reset here and then immediately rebuilt
            # (in pure Python) by _refresh_answer_sequence_prompt_overlap_cache.
            self.answer_sequence_prompt_weight_maps = None
            self.answer_sequence_prompt_weight_norms = None
            self.answer_sequence_prompt_bigram_sets = None
            self.answer_sequence_prompt_trigram_sets = None
            self.answer_sequence_prompt_number_sets = None
            self.answer_sequence_prompt_inverted_index = None
            self._refresh_answer_sequence_prompt_overlap_cache()
            return
        self.ternary_mask_array = (
            np.asarray(self.ternary_mask, dtype=RUNTIME_ARRAY_DTYPE)
            if self.ternary_mask is not None
            else None
        )
        self.readout_weights_array = (
            np.asarray(self.readout_weights, dtype=RUNTIME_ARRAY_DTYPE)
            if self.readout_weights is not None
            else None
        )
        self.readout_bias_array = (
            np.asarray(self.readout_bias, dtype=RUNTIME_ARRAY_DTYPE)
            if self.readout_bias is not None
            else None
        )
        self.prompt_answer_weights_array = (
            np.asarray(self.prompt_answer_weights, dtype=RUNTIME_ARRAY_DTYPE)
            if self.prompt_answer_weights is not None
            and len(self.prompt_answer_weights) > 0
            else None
        )
        self.prompt_answer_bias_array = (
            np.asarray(self.prompt_answer_bias, dtype=RUNTIME_ARRAY_DTYPE)
            if self.prompt_answer_bias is not None
            else None
        )
        self.prompt_answer_start_weights_array = (
            np.asarray(self.prompt_answer_start_weights, dtype=RUNTIME_ARRAY_DTYPE)
            if self.prompt_answer_start_weights is not None
            and len(self.prompt_answer_start_weights) > 0
            else None
        )
        self.prompt_answer_start_bias_array = (
            np.asarray(self.prompt_answer_start_bias, dtype=RUNTIME_ARRAY_DTYPE)
            if self.prompt_answer_start_bias is not None
            else None
        )
        self.trace_token_weights_array = (
            np.asarray(self.trace_token_weights, dtype=RUNTIME_ARRAY_DTYPE)
            if self.trace_token_weights is not None
            else None
        )
        # Per-token trace embeddings are built in float64 and then cast to the
        # runtime dtype; the builder returns None when shapes do not line up.
        trace_embedding_table = (
            self._build_trace_embedding_table_array(self.embedding_model.embeddings)
            if self.embedding_model is not None and self.trace_token_weights is not None
            else None
        )
        self.trace_embedding_table_array = (
            trace_embedding_table.astype(RUNTIME_ARRAY_DTYPE, copy=False)
            if trace_embedding_table is not None
            else None
        )
        self.preference_bias_array = (
            np.asarray(self.preference_bias, dtype=RUNTIME_ARRAY_DTYPE)
            if self.preference_bias is not None
            else None
        )
        # Boolean mask over the vocabulary: True where
        # _eligible_preference_token(token) is truthy.
        self.preference_valid_mask_array = (
            np.asarray(
                [
                    self._eligible_preference_token(token)
                    for token in self.embedding_model.id_to_token
                ],
                dtype=bool,
            )
            if self.embedding_model is not None and self.tokenizer is not None
            else None
        )
        self.state_offset_array = (
            np.asarray(self.state_offset, dtype=RUNTIME_ARRAY_DTYPE)
            if self.state_offset is not None
            else None
        )
        self.associative_keys_array = (
            np.asarray(self.associative_keys, dtype=RUNTIME_ARRAY_DTYPE)
            if self.associative_keys is not None and len(self.associative_keys) > 0
            else None
        )
        self.associative_key_norms_array = (
            np.asarray(self.associative_key_norms, dtype=RUNTIME_ARRAY_DTYPE)
            if self.associative_key_norms is not None and len(self.associative_key_norms) > 0
            else None
        )
        self.associative_values_array = (
            np.asarray(self.associative_values, dtype=np.int64)
            if self.associative_values is not None and len(self.associative_values) > 0
            else None
        )
        # True where the stored token id is >= 0; negative ids are excluded from matching.
        self.associative_valid_mask_array = (
            self.associative_values_array >= 0
            if self.associative_values_array is not None
            else None
        )
        self.answer_keys_array = (
            np.asarray(self.answer_keys, dtype=RUNTIME_ARRAY_DTYPE)
            if self.answer_keys is not None and len(self.answer_keys) > 0
            else None
        )
        self.answer_key_norms_array = (
            np.asarray(self.answer_key_norms, dtype=RUNTIME_ARRAY_DTYPE)
            if self.answer_key_norms is not None and len(self.answer_key_norms) > 0
            else None
        )
        # Similarity mask: a combined state is one (state_dim + embedding_dim) block per
        # timescale. The mask keeps only the embedding_dim slice after each state_dim
        # slice and zeroes the rest; _step_hidden_states_from_embedding lays each block
        # out as unit state followed by context trace. It is built only when the answer
        # key width matches that layout.
        self.answer_similarity_keys_array = None
        self.answer_similarity_key_norms_array = None
        self.answer_similarity_mask_array = None
        if self.answer_keys_array is not None and len(self.answer_keys_array.shape) == 2:
            width = int(self.answer_keys_array.shape[1])
            block_width = self.config.state_dim + self.config.embedding_dim
            expected_width = block_width * len(self.config.timescales)
            if block_width > 0 and width == expected_width:
                mask = np.zeros(width, dtype=RUNTIME_ARRAY_DTYPE)
                for scale_index in range(len(self.config.timescales)):
                    start = scale_index * block_width + self.config.state_dim
                    end = start + self.config.embedding_dim
                    mask[start:end] = 1.0
                self.answer_similarity_mask_array = mask
                self.answer_similarity_keys_array = self.answer_keys_array * mask[None, :]
                self.answer_similarity_key_norms_array = np.linalg.norm(
                    self.answer_similarity_keys_array,
                    axis=1,
                ).astype(RUNTIME_ARRAY_DTYPE, copy=False)
        self.answer_values_array = (
            np.asarray(self.answer_values, dtype=np.int64)
            if self.answer_values is not None and len(self.answer_values) > 0
            else None
        )
        self.answer_valid_mask_array = (
            self.answer_values_array >= 0
            if self.answer_values_array is not None
            else None
        )
        self.answer_start_keys_array = (
            np.asarray(self.answer_start_keys, dtype=RUNTIME_ARRAY_DTYPE)
            if self.answer_start_keys is not None and len(self.answer_start_keys) > 0
            else None
        )
        self.answer_start_key_norms_array = (
            np.asarray(self.answer_start_key_norms, dtype=RUNTIME_ARRAY_DTYPE)
            if self.answer_start_key_norms is not None and len(self.answer_start_key_norms) > 0
            else None
        )
        # Reuses the answer similarity mask built above, so these stay None whenever
        # answer_keys produced no mask.
        # NOTE(review): confirm that this dependency on answer_keys is intended.
        self.answer_start_similarity_keys_array = None
        self.answer_start_similarity_key_norms_array = None
        if (
            self.answer_start_keys_array is not None
            and len(self.answer_start_keys_array.shape) == 2
            and self.answer_similarity_mask_array is not None
            and int(self.answer_start_keys_array.shape[1]) == int(self.answer_similarity_mask_array.shape[0])
        ):
            self.answer_start_similarity_keys_array = (
                self.answer_start_keys_array * self.answer_similarity_mask_array[None, :]
            )
            self.answer_start_similarity_key_norms_array = np.linalg.norm(
                self.answer_start_similarity_keys_array,
                axis=1,
            ).astype(RUNTIME_ARRAY_DTYPE, copy=False)
        self.answer_start_values_array = (
            np.asarray(self.answer_start_values, dtype=np.int64)
            if self.answer_start_values is not None and len(self.answer_start_values) > 0
            else None
        )
        self.answer_start_valid_mask_array = (
            self.answer_start_values_array >= 0
            if self.answer_start_values_array is not None
            else None
        )
        self.answer_sequence_keys_array = (
            np.asarray(self.answer_sequence_keys, dtype=RUNTIME_ARRAY_DTYPE)
            if self.answer_sequence_keys is not None and len(self.answer_sequence_keys) > 0
            else None
        )
        self.answer_sequence_key_norms_array = (
            np.asarray(self.answer_sequence_key_norms, dtype=RUNTIME_ARRAY_DTYPE)
            if self.answer_sequence_key_norms is not None and len(self.answer_sequence_key_norms) > 0
            else None
        )
        # Same masking as the answer-start keys, and the same dependency on the
        # answer similarity mask.
        self.answer_sequence_similarity_keys_array = None
        self.answer_sequence_similarity_key_norms_array = None
        if (
            self.answer_sequence_keys_array is not None
            and len(self.answer_sequence_keys_array.shape) == 2
            and self.answer_similarity_mask_array is not None
            and int(self.answer_sequence_keys_array.shape[1]) == int(self.answer_similarity_mask_array.shape[0])
        ):
            self.answer_sequence_similarity_keys_array = (
                self.answer_sequence_keys_array * self.answer_similarity_mask_array[None, :]
            )
            self.answer_sequence_similarity_key_norms_array = np.linalg.norm(
                self.answer_sequence_similarity_keys_array,
                axis=1,
            ).astype(RUNTIME_ARRAY_DTYPE, copy=False)
        self.answer_sequence_tokens_array = (
            np.asarray(self.answer_sequence_tokens, dtype=np.int64)
            if self.answer_sequence_tokens is not None and len(self.answer_sequence_tokens) > 0
            else None
        )
        self.answer_sequence_prompt_tokens_array = (
            np.asarray(self.answer_sequence_prompt_tokens, dtype=np.int64)
            if self.answer_sequence_prompt_tokens is not None
            and len(self.answer_sequence_prompt_tokens) > 0
            else None
        )
        self._refresh_answer_sequence_prompt_overlap_cache()
def _refresh_answer_sequence_prompt_overlap_cache(self) -> None:
self.answer_sequence_prompt_weight_maps = None
self.answer_sequence_prompt_weight_norms = None
self.answer_sequence_prompt_bigram_sets = None
self.answer_sequence_prompt_trigram_sets = None
self.answer_sequence_prompt_number_sets = None
self.answer_sequence_prompt_inverted_index = None
self.answer_sequence_prompt_specificity = None
if self.answer_sequence_prompt_tokens is None or self.trace_token_weights is None:
return
inverted: dict[int, list[int]] = {}
row_id_lists: list[list[int]] = []
for row in self.answer_sequence_prompt_tokens:
row_values = row.tolist() if hasattr(row, "tolist") else row
row_ids: list[int] = []
for raw_token_id in row_values:
token_id = int(raw_token_id)
if token_id < 0 or token_id >= len(self.trace_token_weights):
continue
row_ids.append(token_id)
sequence_index = len(row_id_lists)
for token_id in set(row_ids):
inverted.setdefault(token_id, []).append(sequence_index)
row_id_lists.append(row_ids)
total_rows = len(row_id_lists)
specificity = {
token_id: self._prompt_overlap_token_specificity(len(indices), total_rows)
for token_id, indices in inverted.items()
}
self.answer_sequence_prompt_inverted_index = inverted
self.answer_sequence_prompt_specificity = specificity
weight_maps: list[dict[int, float]] = []
weight_norms: list[float] = []
bigram_sets: list[set[tuple[int, int]]] = []
trigram_sets: list[set[tuple[int, int, int]]] = []
number_sets: list[set[str]] = []
for row_ids in row_id_lists:
row_weights: dict[int, float] = {}
for token_id in row_ids:
row_weights[token_id] = max(
row_weights.get(token_id, 0.0),
float(self.trace_token_weights[token_id]) * specificity.get(token_id, 1.0),
)
weight_maps.append(row_weights)
weight_norms.append(sum(value * value for value in row_weights.values()) ** 0.5)
bigram_sets.append(
{
(row_ids[index], row_ids[index + 1])
for index in range(len(row_ids) - 1)
}
)
trigram_sets.append(
{
(row_ids[index], row_ids[index + 1], row_ids[index + 2])
for index in range(len(row_ids) - 2)
}
)
number_sets.append(self._number_strings_from_token_ids(row_ids))
self.answer_sequence_prompt_weight_maps = weight_maps
self.answer_sequence_prompt_weight_norms = weight_norms
self.answer_sequence_prompt_bigram_sets = bigram_sets
self.answer_sequence_prompt_trigram_sets = trigram_sets
self.answer_sequence_prompt_number_sets = number_sets
@staticmethod
def _prompt_overlap_token_specificity(document_frequency: int, total_documents: int) -> float:
if document_frequency <= 0 or total_documents <= 0:
return 1.0
coverage = min(1.0, document_frequency / total_documents)
return max(0.02, 1.0 - (coverage ** 0.5))
def _number_strings_from_token_ids(self, token_ids: list[int]) -> set[str]:
assert self.embedding_model is not None
tokens = [
self.embedding_model.id_to_token[token_id]
for token_id in token_ids
if 0 <= token_id < len(self.embedding_model.id_to_token)
]
return self._number_strings_from_tokens(tokens)
def _number_strings_from_tokens(self, tokens: list[str]) -> set[str]:
numbers: set[str] = set()
current = ""
for token in tokens:
if self.tokenizer is not None and token in self.tokenizer.special_tokens:
if current:
numbers.add(current)
current = ""
continue
rendered = self._render_token(token)
digits = "".join(character for character in rendered if character.isdigit())
starts_number = self._starts_new_word(token) if self.tokenizer is not None else True
if digits and starts_number:
if current:
numbers.add(current)
current = digits
elif digits and current:
current += digits
else:
if current:
numbers.add(current)
current = ""
if current:
numbers.add(current)
return numbers
@staticmethod
def _numeric_prompt_can_match(query_numbers: set[str], row_numbers: set[str]) -> bool:
if not query_numbers:
return True
if not row_numbers:
return False
return query_numbers.issubset(row_numbers)
def _apply_readout_fast(self, state: Vector) -> Vector:
if self.readout_weights_array is None or np is None:
assert self.readout_weights is not None
centered_state = self._center_state_vector(state)
logits = apply_readout(self.readout_weights, centered_state)
if self.readout_bias:
logits = [
value + self.readout_bias[index]
for index, value in enumerate(logits)
]
return logits
state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE)
if self.state_offset_array is not None and self.state_offset_array.shape == state_array.shape:
state_array = state_array - self.state_offset_array
logits = self.readout_weights_array @ state_array
if self.readout_bias_array is not None and self.readout_bias_array.shape == logits.shape:
logits = logits + self.readout_bias_array
return logits.tolist()
def _apply_readout_array(self, state: object) -> object:
assert np is not None
assert self.readout_weights_array is not None
state_array = np.asarray(state, dtype=RUNTIME_ARRAY_DTYPE)
if self.state_offset_array is not None and self.state_offset_array.shape == state_array.shape:
state_array = state_array - self.state_offset_array
logits = self.readout_weights_array @ state_array
if self.readout_bias_array is not None and self.readout_bias_array.shape == logits.shape:
logits = logits + self.readout_bias_array
return logits
def _step_hidden_states(
self,
hidden_states: list[Vector],
context_traces: list[Vector],
token: str,
) -> tuple[list[Vector], list[Vector], Vector]:
assert self.embedding_model is not None
assert self.tokenizer is not None
token_id = self._token_id_for_token(token)
embedding = self.embedding_model.vector(token)
trace_embedding = self._trace_embedding_from_token_id(embedding, token_id)
return self._step_hidden_states_from_embedding(
hidden_states,
context_traces,
embedding,
trace_embedding=trace_embedding,
)
def _step_hidden_states_from_embedding(
self,
hidden_states: list[Vector],
context_traces: list[Vector],
embedding: Vector | object,
*,
trace_embedding: Vector | object | None = None,
) -> tuple[list[Vector], list[Vector], Vector]:
assert self.memory_units is not None
if trace_embedding is None:
trace_embedding = embedding
if np is not None and hidden_states and hasattr(hidden_states[0], "shape"):
embedding_array = (
embedding
if hasattr(embedding, "shape")
else np.asarray(embedding, dtype=np.float64)
)
trace_embedding_array = (
trace_embedding
if hasattr(trace_embedding, "shape")
else np.asarray(trace_embedding, dtype=np.float64)
)
drive = analytical_embedding_drive_fast(embedding_array, self.config.state_dim)
next_states: list[Vector] = []
next_traces: list[Vector] = []
combined_state: Vector = []
for unit, state, trace in zip(self.memory_units, hidden_states, context_traces):
next_state = unit.step_vector_fast(state, drive)
decay = 1.0 / (1.0 + unit.timescale)
next_trace = trace + ((1.0 - decay) * trace_embedding_array)
next_states.append(next_state)
next_traces.append(next_trace)
combined_state.extend(next_state.tolist())
combined_state.extend(next_trace.tolist())
return next_states, next_traces, combined_state
embedding_vector = embedding.tolist() if hasattr(embedding, "tolist") else embedding
trace_embedding_vector = (
trace_embedding.tolist()
if hasattr(trace_embedding, "tolist")
else trace_embedding
)
drive = analytical_embedding_drive(embedding_vector, self.config.state_dim)
next_states: list[Vector] = []
next_traces: list[Vector] = []
combined_state: Vector = []
for unit, state, trace in zip(self.memory_units, hidden_states, context_traces):
next_state = unit.step_vector(state, drive)
decay = 1.0 / (1.0 + unit.timescale)
next_trace = [
previous + ((1.0 - decay) * value)
for previous, value in zip(trace, trace_embedding_vector)
]
next_states.append(next_state)
next_traces.append(next_trace)
combined_state.extend(next_state)
combined_state.extend(next_trace)
return next_states, next_traces, combined_state
def _one_hot(self, token: str) -> Vector:
assert self.embedding_model is not None
return self._one_hot_from_id(self.embedding_model.token_to_id.get(token, -1))
def _one_hot_from_id(self, token_id: int) -> Vector:
assert self.embedding_model is not None
vector = [0.0 for _ in self.embedding_model.id_to_token]
if token_id >= 0:
vector[token_id] = 1.0
return vector
def _blend_probabilities(
self,
base: Vector,
answer: Vector,
associative: Vector,
transition: Vector,
copy: Vector,
preference: Vector,
*,
transition_order: int | None,
generated_count: int = 0,
answer_locked: bool = False,
answer_guided_start: bool = False,
) -> tuple[Vector, dict[str, float]]:
base_weight = FAST_BASE_BLEND
answer_weight = FAST_ANSWER_BLEND
associative_weight = FAST_ASSOCIATIVE_BLEND
transition_weight = FAST_TRANSITION_BLEND
copy_weight = FAST_COPY_BLEND
preference_weight = FAST_PREFERENCE_BLEND
if answer_locked:
base_weight *= 0.18
answer_weight *= 5.0
associative_weight *= 0.2
transition_weight *= 0.2
copy_weight *= 0.2
preference_weight *= 0.2
elif answer_guided_start:
base_weight *= 0.35
answer_weight *= 3.5
associative_weight *= 0.2
transition_weight *= 0.35
copy_weight *= 0.2
preference_weight *= 0.2
elif generated_count > 0:
answer_weight *= 0.32
transition_weight *= 2.0
copy_weight *= 0.75
if transition_order is None:
answer_weight *= 1.1
associative_weight *= 0.75
copy_weight += 0.02
elif transition_order <= 2:
answer_weight *= 1.15
associative_weight *= 0.65
transition_weight *= 0.55
copy_weight += 0.01
elif transition_order >= 5:
transition_weight *= 1.25
sources: list[tuple[str, float, Vector]] = [("base", base_weight, base)]
if any(value > 0.0 for value in answer):
sources.append(("answer", answer_weight, answer))
if any(value > 0.0 for value in associative):
sources.append(("associative", associative_weight, associative))
if any(value > 0.0 for value in transition):
sources.append(("transition", transition_weight, transition))
if any(value > 0.0 for value in copy):
sources.append(("copy", copy_weight, copy))
if any(value > 0.0 for value in preference):
sources.append(("preference", preference_weight, preference))
total_weight = sum(weight for _, weight, _ in sources)
blended = [0.0 for _ in base]
blend_weights: dict[str, float] = {}
for name, weight, source in sources:
normalized_weight = weight / total_weight if total_weight else 0.0
blend_weights[name] = normalized_weight
for index, value in enumerate(source):
blended[index] += normalized_weight * value
return _normalize_vector(blended), blend_weights
def _blend_probability_arrays(
self,
base: object,
answer: object,
associative: object,
transition: object,
copy: object,
preference: object,
*,
transition_order: int | None,
generated_count: int = 0,
answer_locked: bool = False,
answer_guided_start: bool = False,
) -> tuple[object, dict[str, float]]:
assert np is not None
base_weight = FAST_BASE_BLEND
answer_weight = FAST_ANSWER_BLEND
associative_weight = FAST_ASSOCIATIVE_BLEND
transition_weight = FAST_TRANSITION_BLEND
copy_weight = FAST_COPY_BLEND
preference_weight = FAST_PREFERENCE_BLEND
if answer_locked:
base_weight *= 0.18
answer_weight *= 5.0
associative_weight *= 0.2
transition_weight *= 0.2
copy_weight *= 0.2
preference_weight *= 0.2
elif answer_guided_start:
base_weight *= 0.35
answer_weight *= 3.5
associative_weight *= 0.2
transition_weight *= 0.35
copy_weight *= 0.2
preference_weight *= 0.2
elif generated_count > 0:
answer_weight *= 0.32
transition_weight *= 2.0
copy_weight *= 0.75
if transition_order is None:
answer_weight *= 1.1
associative_weight *= 0.75
copy_weight += 0.02
elif transition_order <= 2:
answer_weight *= 1.15
associative_weight *= 0.65
transition_weight *= 0.55
copy_weight += 0.01
elif transition_order >= 5:
transition_weight *= 1.25
sources: list[tuple[str, float, object]] = [("base", base_weight, base)]
if np.any(answer > 0.0):
sources.append(("answer", answer_weight, answer))
if np.any(associative > 0.0):
sources.append(("associative", associative_weight, associative))
if np.any(transition > 0.0):
sources.append(("transition", transition_weight, transition))
if np.any(copy > 0.0):
sources.append(("copy", copy_weight, copy))
if np.any(preference > 0.0):
sources.append(("preference", preference_weight, preference))
total_weight = sum(weight for _, weight, _ in sources)
blended = np.zeros_like(base, dtype=np.float64)
blend_weights: dict[str, float] = {}
for name, weight, source in sources:
normalized_weight = weight / total_weight if total_weight else 0.0
blend_weights[name] = normalized_weight
blended += normalized_weight * source
total = float(blended.sum())
if total <= 0.0:
return base, blend_weights
return blended / total, blend_weights
def _score_associative_matches(
self,
state: Vector,
*,
limit: int = ASSOCIATIVE_TOP_K,
) -> list[tuple[float, int, int]]:
if (
self.associative_keys is None
or self.associative_values is None
or self.associative_key_norms is None
or len(self.associative_keys) == 0
or len(self.associative_values) == 0
or len(self.associative_key_norms) == 0
):
return []
if (
np is not None
and
self.associative_keys_array is not None
and self.associative_key_norms_array is not None
and self.associative_values_array is not None
and self.associative_valid_mask_array is not None
and limit > 0
):
state_array = self._center_state_array(state).astype(self.associative_keys_array.dtype, copy=False)
state_norm = float(np.linalg.norm(state_array))
if state_norm == 0.0:
return []
numerators = self.associative_keys_array @ state_array
denominators = self.associative_key_norms_array * state_norm
valid_mask = self.associative_valid_mask_array & (denominators > 0.0)
if np.any(valid_mask):
scores = np.zeros_like(numerators, dtype=self.associative_keys_array.dtype)
np.divide(numerators, denominators, out=scores, where=valid_mask)
positive_positions = np.flatnonzero(valid_mask & (scores > 0.0))
if positive_positions.size:
selected_positions = positive_positions
if positive_positions.size > limit:
partition = np.argpartition(scores[positive_positions], -limit)[-limit:]
selected_positions = positive_positions[partition]
ordered_positions = selected_positions[np.argsort(scores[selected_positions])[::-1]]
return [
(
float(scores[position]),
int(self.associative_values_array[position]),
int(position),
)
for position in ordered_positions
]
state = self._center_state_vector(state)
state_norm = norm(state)
if state_norm == 0.0:
return []
scored: list[tuple[float, int, int]] = []
for example_index, (key, key_norm, token_id) in enumerate(
zip(self.associative_keys, self.associative_key_norms, self.associative_values)
):
if token_id < 0:
continue
denominator = state_norm * key_norm
if denominator == 0.0:
continue
similarity = dot(state, key) / denominator
if similarity > 0.0:
scored.append((similarity, token_id, example_index))
scored.sort(key=lambda item: item[0], reverse=True)
return scored[:limit]
def _associative_prior_from_matches(
self,
matches: list[tuple[float, int, int]],
) -> Vector:
assert self.embedding_model is not None
if not matches:
return [0.0 for _ in self.embedding_model.id_to_token]
prior = [0.0 for _ in self.embedding_model.id_to_token]
for similarity, token_id, _ in matches[:ASSOCIATIVE_TOP_K]:
prior[token_id] += similarity
return _normalize_vector(prior)
def _associative_prior(self, state: Vector) -> Vector:
return self._associative_prior_from_matches(self._score_associative_matches(state))
def _score_answer_matches(
self,
answer_anchor_state: Vector | None,
*,
limit: int = ANSWER_TOP_K,
) -> list[tuple[float, int, int]]:
return self._score_prompt_anchor_matches(
answer_anchor_state,
self.answer_keys,
self.answer_key_norms,
self.answer_values,
self.answer_keys_array,
self.answer_key_norms_array,
self.answer_values_array,
self.answer_valid_mask_array,
self.answer_similarity_keys_array,
self.answer_similarity_key_norms_array,
self.answer_similarity_mask_array,
limit=limit,
)
def _score_answer_start_matches(
self,
answer_anchor_state: Vector | None,
*,
limit: int = ANSWER_START_TOP_K,
) -> list[tuple[float, int, int]]:
return self._score_prompt_anchor_matches(
answer_anchor_state,
self.answer_start_keys,
self.answer_start_key_norms,
self.answer_start_values,
self.answer_start_keys_array,
self.answer_start_key_norms_array,
self.answer_start_values_array,
self.answer_start_valid_mask_array,
self.answer_start_similarity_keys_array,
self.answer_start_similarity_key_norms_array,
self.answer_similarity_mask_array,
limit=limit,
)
def _score_answer_sequence_matches(
self,
answer_anchor_state: Vector | None,
context_tokens: list[str],
*,
limit: int = ANSWER_START_TOP_K,
) -> list[tuple[float, int, int]]:
if (
answer_anchor_state is None
or self.answer_sequence_keys is None
or self.answer_sequence_key_norms is None
or self.answer_sequence_tokens is None
):
return []
values = list(range(len(self.answer_sequence_tokens)))
values_array = np.arange(len(values), dtype=np.int64) if np is not None else None
anchor_matches = self._score_prompt_anchor_matches(
answer_anchor_state,
self.answer_sequence_keys,
self.answer_sequence_key_norms,
values,
self.answer_sequence_keys_array,
self.answer_sequence_key_norms_array,
values_array,
values_array >= 0 if values_array is not None else None,
self.answer_sequence_similarity_keys_array,
self.answer_sequence_similarity_key_norms_array,
self.answer_similarity_mask_array,
limit=max(limit * 4, limit),
)
overlap_scores = self._answer_sequence_prompt_overlap_scores(context_tokens)
if overlap_scores is None:
return anchor_matches[:limit]
if not overlap_scores:
return []
best_overlap = max(overlap_scores.values()) if overlap_scores else 0.0
overlap_floor = max(0.16, best_overlap * 0.90)
focused_overlap_scores = {
sequence_index: overlap
for sequence_index, overlap in overlap_scores.items()
if overlap >= overlap_floor
}
if not focused_overlap_scores:
focused_overlap_scores = overlap_scores
focused_indices = set(focused_overlap_scores)
merged: dict[int, float] = {}
for similarity, sequence_index, _ in anchor_matches:
if sequence_index not in focused_indices:
continue
merged[sequence_index] = max(merged.get(sequence_index, 0.0), 0.20 * similarity)
for sequence_index, overlap in focused_overlap_scores.items():
merged[sequence_index] = merged.get(sequence_index, 0.0) + (0.80 * overlap)
ranked = [
(score, sequence_index, sequence_index)
for sequence_index, score in merged.items()
if score > 0.0
]
ranked.sort(key=lambda item: item[0], reverse=True)
return ranked[:limit]
def _answer_sequence_prompt_overlap_scores(
    self,
    context_tokens: list[str],
) -> dict[int, float] | None:
    """Score stored answer-sequence prompts by token overlap with the current prompt.

    The query is ``context_tokens`` cut at ``_last_index(context_tokens, "<answer>")``.
    That helper presumably returns the position of the last ``<answer>`` marker
    or None; confirm. Tokenizer special tokens and out-of-vocabulary tokens are
    dropped from the query. Each query token is weighted by its specificity,
    taken from ``answer_sequence_prompt_specificity`` with a default of 1.0.
    Tokens whose specificity is >= 0.20 count as "content" tokens for the
    coverage terms.

    Each surviving stored prompt row is scored as
    ``0.35 * cosine + 0.35 * content coverage + 0.15 * bigram + 0.15 * trigram``.
    The n-gram terms compare ordered id bigrams and trigrams. A row is skipped
    when any of these holds:

    - ``_numeric_prompt_can_match`` rejects it;
    - the weighted dot product is not positive;
    - content coverage is < 0.40 and row-token coverage is also < 0.75.

    Two code paths compute the same formula. The cached path is used when every
    per-row cache is populated and sized like ``answer_sequence_prompt_tokens``.
    The slow path rebuilds row weights and n-grams on the fly. When an inverted
    index exists, both paths narrow the rows to those sharing a query token.
    They then re-derive query specificity over that candidate set with
    ``_prompt_overlap_token_specificity``.

    Returns:
        ``None`` when a required model or table is missing, or when the query
        has no usable tokens. Otherwise a ``{sequence_index: score}`` dict,
        which may be empty.
    """
    if (
        self.embedding_model is None
        or self.answer_sequence_prompt_tokens is None
        or self.trace_token_weights is None
    ):
        return None
    answer_boundary = _last_index(context_tokens, "<answer>")
    prompt_tokens = (
        context_tokens[:answer_boundary]
        if answer_boundary is not None
        else context_tokens
    )
    # Lazily build the specificity table (and other overlap caches) on first use.
    if self.answer_sequence_prompt_specificity is None:
        self._refresh_answer_sequence_prompt_overlap_cache()
    specificity_map = self.answer_sequence_prompt_specificity or {}
    query_weights: dict[int, float] = {}
    query_specificity: dict[int, float] = {}
    query_content_weight = 0.0
    query_ids: list[int] = []
    for token in prompt_tokens:
        if self.tokenizer is not None and token in self.tokenizer.special_tokens:
            continue
        token_id = self.embedding_model.token_to_id.get(token)
        if token_id is None:
            continue
        query_ids.append(token_id)
        specificity = specificity_map.get(token_id, 1.0)
        weight = specificity
        query_weights[token_id] = max(
            query_weights.get(token_id, 0.0),
            weight,
        )
        query_specificity[token_id] = max(
            query_specificity.get(token_id, 0.0),
            specificity,
        )
        # NOTE(review): this accumulates once per occurrence, so a repeated prompt
        # token is counted several times. matched_content_weight below sums each
        # unique token only once. Unless the local re-weighting replaces this total,
        # repeated tokens can push query coverage below 1.0 even on a perfect match.
        # Confirm whether that is intended.
        if specificity >= 0.20:
            query_content_weight += weight
    if not query_weights:
        return None
    query_norm = sum(value * value for value in query_weights.values()) ** 0.5
    if query_norm <= 0.0:
        return None
    query_bigrams = {
        (query_ids[index], query_ids[index + 1])
        for index in range(len(query_ids) - 1)
    }
    query_trigrams = {
        (query_ids[index], query_ids[index + 1], query_ids[index + 2])
        for index in range(len(query_ids) - 2)
    }
    query_numbers = self._number_strings_from_tokens(prompt_tokens)
    def ordered_ngram_score(
        query_grams: set[tuple[int, ...]],
        row_grams: set[tuple[int, ...]],
    ) -> float:
        # Shared-gram count normalised by sqrt(|query| * |row|), i.e. the cosine
        # of the two gram sets viewed as binary vectors.
        if not query_grams or not row_grams:
            return 0.0
        overlap = len(query_grams & row_grams)
        if overlap <= 0:
            return 0.0
        return overlap / ((len(query_grams) * len(row_grams)) ** 0.5)
    cached_maps = self.answer_sequence_prompt_weight_maps
    cached_norms = self.answer_sequence_prompt_weight_norms
    cached_bigrams = self.answer_sequence_prompt_bigram_sets
    cached_trigrams = self.answer_sequence_prompt_trigram_sets
    cached_numbers = self.answer_sequence_prompt_number_sets
    cached_index = self.answer_sequence_prompt_inverted_index
    # Fast path: use precomputed per-row weights, norms, n-grams and numbers.
    if (
        cached_maps is not None
        and cached_norms is not None
        and cached_bigrams is not None
        and cached_trigrams is not None
        and cached_numbers is not None
        and len(cached_maps) == len(self.answer_sequence_prompt_tokens)
    ):
        candidate_indices: set[int] | range
        if cached_index is not None:
            candidates: set[int] = set()
            for token_id in query_weights:
                candidates.update(cached_index.get(token_id, ()))
            # Unlike the slow path (which returns {}), an empty candidate set here
            # falls back to scanning every cached row.
            candidate_indices = candidates if candidates else range(len(cached_maps))
        else:
            candidate_indices = range(len(cached_maps))
        candidate_indices = list(candidate_indices)
        if cached_index is not None and candidate_indices:
            # Re-estimate query specificity relative to the candidate subset only.
            candidate_set = set(candidate_indices)
            local_query_weights: dict[int, float] = {}
            local_query_specificity: dict[int, float] = {}
            local_query_content_weight = 0.0
            for token_id in query_weights:
                # NOTE(review): builds a fresh set per query token; may be costly for
                # long posting lists.
                local_frequency = len(candidate_set & set(cached_index.get(token_id, ())))
                if local_frequency <= 0:
                    continue
                specificity = self._prompt_overlap_token_specificity(
                    local_frequency,
                    len(candidate_indices),
                )
                weight = specificity
                local_query_weights[token_id] = weight
                local_query_specificity[token_id] = specificity
                if specificity >= 0.20:
                    local_query_content_weight += weight
            local_query_norm = sum(value * value for value in local_query_weights.values()) ** 0.5
            if local_query_norm > 0.0:
                query_weights = local_query_weights
                query_specificity = local_query_specificity
                if local_query_content_weight > 0.0:
                    query_content_weight = local_query_content_weight
                query_norm = local_query_norm
        scores: dict[int, float] = {}
        for sequence_index in candidate_indices:
            row_weights = cached_maps[sequence_index]
            if not row_weights:
                continue
            if not self._numeric_prompt_can_match(query_numbers, cached_numbers[sequence_index]):
                continue
            matched_content_weight = sum(
                query_weights[token_id]
                for token_id in query_weights.keys() & row_weights.keys()
                if query_specificity.get(token_id, 0.0) >= 0.20
            )
            row_token_coverage = len(query_weights.keys() & row_weights.keys()) / max(
                1,
                len(row_weights),
            )
            # Reject rows that cover too little of the query's content AND whose own
            # tokens are mostly absent from the query.
            if (
                query_content_weight > 0.0
                and matched_content_weight / query_content_weight < 0.40
                and row_token_coverage < 0.75
            ):
                continue
            query_coverage = (
                matched_content_weight / query_content_weight
                if query_content_weight > 0.0
                else row_token_coverage
            )
            numerator = sum(
                query_weights[token_id] * row_weights[token_id]
                for token_id in query_weights.keys() & row_weights.keys()
            )
            if numerator <= 0.0:
                continue
            row_norm = cached_norms[sequence_index]
            if row_norm <= 0.0:
                continue
            token_score = numerator / (query_norm * row_norm)
            bigram_score = ordered_ngram_score(
                query_bigrams,
                cached_bigrams[sequence_index],
            )
            trigram_score = ordered_ngram_score(
                query_trigrams,
                cached_trigrams[sequence_index],
            )
            scores[sequence_index] = (
                (0.35 * token_score)
                + (0.35 * query_coverage)
                + (0.15 * bigram_score)
                + (0.15 * trigram_score)
            )
        return scores
    # Slow path: caches are missing or stale, so rebuild each row's features here.
    if cached_index is not None:
        candidate_set: set[int] = set()
        for token_id in query_weights:
            candidate_set.update(cached_index.get(token_id, ()))
        if not candidate_set:
            return {}
        candidate_indices: list[int] | range = sorted(candidate_set)
        local_query_weights: dict[int, float] = {}
        local_query_specificity: dict[int, float] = {}
        local_query_content_weight = 0.0
        candidate_count = len(candidate_indices)
        for token_id in query_weights:
            local_frequency = len(candidate_set & set(cached_index.get(token_id, ())))
            if local_frequency <= 0:
                continue
            specificity = self._prompt_overlap_token_specificity(
                local_frequency,
                candidate_count,
            )
            local_query_weights[token_id] = specificity
            local_query_specificity[token_id] = specificity
            if specificity >= 0.20:
                local_query_content_weight += specificity
        local_query_norm = sum(value * value for value in local_query_weights.values()) ** 0.5
        if local_query_norm > 0.0:
            query_weights = local_query_weights
            query_specificity = local_query_specificity
            if local_query_content_weight > 0.0:
                query_content_weight = local_query_content_weight
            query_norm = local_query_norm
    else:
        candidate_indices = range(len(self.answer_sequence_prompt_tokens))
    scores: dict[int, float] = {}
    for sequence_index in candidate_indices:
        row = self.answer_sequence_prompt_tokens[sequence_index]
        row_values = row.tolist() if hasattr(row, "tolist") else row
        row_weights: dict[int, float] = {}
        row_ids: list[int] = []
        for raw_token_id in row_values:
            token_id = int(raw_token_id)
            # Negative ids and ids outside trace_token_weights are ignored (padding
            # or foreign ids, presumably -- confirm).
            if token_id < 0 or token_id >= len(self.trace_token_weights):
                continue
            row_ids.append(token_id)
            row_weights[token_id] = max(
                row_weights.get(token_id, 0.0),
                specificity_map.get(token_id, 1.0),
            )
        if not row_weights:
            continue
        if not self._numeric_prompt_can_match(
            query_numbers,
            self._number_strings_from_token_ids(row_ids),
        ):
            continue
        matched_content_weight = sum(
            query_weights[token_id]
            for token_id in query_weights.keys() & row_weights.keys()
            if query_specificity.get(token_id, 0.0) >= 0.20
        )
        row_token_coverage = len(query_weights.keys() & row_weights.keys()) / max(
            1,
            len(row_weights),
        )
        if (
            query_content_weight > 0.0
            and matched_content_weight / query_content_weight < 0.40
            and row_token_coverage < 0.75
        ):
            continue
        query_coverage = (
            matched_content_weight / query_content_weight
            if query_content_weight > 0.0
            else row_token_coverage
        )
        numerator = sum(
            query_weights[token_id] * row_weights[token_id]
            for token_id in query_weights.keys() & row_weights.keys()
        )
        if numerator <= 0.0:
            continue
        row_norm = sum(value * value for value in row_weights.values()) ** 0.5
        # Rows with a zero norm are skipped silently (no score entry).
        if row_norm > 0.0:
            token_score = numerator / (query_norm * row_norm)
            row_bigrams = {
                (row_ids[index], row_ids[index + 1])
                for index in range(len(row_ids) - 1)
            }
            row_trigrams = {
                (row_ids[index], row_ids[index + 1], row_ids[index + 2])
                for index in range(len(row_ids) - 2)
            }
            bigram_score = ordered_ngram_score(query_bigrams, row_bigrams)
            trigram_score = ordered_ngram_score(query_trigrams, row_trigrams)
            scores[sequence_index] = (
                (0.35 * token_score)
                + (0.35 * query_coverage)
                + (0.15 * bigram_score)
                + (0.15 * trigram_score)
            )
    return scores
def _score_prompt_anchor_matches(
    self,
    answer_anchor_state: Vector | None,
    keys: object | None,
    key_norms_list: object | None,
    values: object | None,
    keys_array: object | None,
    key_norms_array: object | None,
    values_array: object | None,
    valid_mask_array: object | None,
    similarity_keys_array: object | None,
    similarity_key_norms_array: object | None,
    similarity_mask_array: object | None,
    *,
    limit: int,
) -> list[tuple[float, int, int]]:
    """Rank stored prompt-anchor keys by cosine similarity to ``answer_anchor_state``.

    A vectorised NumPy path runs when NumPy and all the array-form inputs are
    available. When the three ``similarity_*`` arrays are all given, that path
    multiplies the state by ``similarity_mask_array`` and scores it against
    ``similarity_keys_array`` and ``similarity_key_norms_array`` instead of
    ``keys_array`` and ``key_norms_array``. The pure-Python path scores the
    centred, masked state against ``keys`` / ``key_norms_list`` / ``values`` and
    skips negative token ids.

    Args:
        answer_anchor_state: state to compare against; ``None`` yields ``[]``.
        keys, key_norms_list, values: list-form keys, their norms, and the token
            id stored for each key.
        keys_array, key_norms_array, values_array, valid_mask_array: array-form
            counterparts used by the NumPy path.
        similarity_keys_array, similarity_key_norms_array, similarity_mask_array:
            optional alternative key space for the NumPy path.
        limit: maximum number of matches to return. ``limit <= 0`` yields ``[]``.

    Returns:
        Up to ``limit`` ``(similarity, token_id, example_index)`` tuples with
        strictly positive similarity, best first.
    """
    if limit <= 0:
        # No match can be returned. This guard also keeps the list path below from
        # slicing with a negative bound, where ``scored[:-1]`` would silently
        # return every match except the last.
        return []
    if (
        answer_anchor_state is None
        or keys is None
        or key_norms_list is None
        or values is None
    ):
        return []
    if (
        np is not None
        and keys_array is not None
        and key_norms_array is not None
        and values_array is not None
        and valid_mask_array is not None
    ):
        state_array = self._center_state_array(
            self._masked_combined_state_array(answer_anchor_state)
        ).astype(keys_array.dtype, copy=False)
        key_array = keys_array
        key_norms = key_norms_array
        if (
            similarity_keys_array is not None
            and similarity_key_norms_array is not None
            and similarity_mask_array is not None
        ):
            state_array = state_array * similarity_mask_array
            key_array = similarity_keys_array
            key_norms = similarity_key_norms_array
        state_norm = float(np.linalg.norm(state_array))
        if state_norm == 0.0:
            return []
        numerators = key_array @ state_array
        denominators = key_norms * state_norm
        # Only divide where the key is valid and the denominator is non-zero.
        valid_mask = valid_mask_array & (denominators > 0.0)
        if np.any(valid_mask):
            scores = np.zeros_like(numerators, dtype=key_array.dtype)
            np.divide(numerators, denominators, out=scores, where=valid_mask)
            positive_positions = np.flatnonzero(valid_mask & (scores > 0.0))
            if positive_positions.size:
                selected_positions = positive_positions
                if positive_positions.size > limit:
                    # Take the top `limit` in O(n), then sort just those.
                    partition = np.argpartition(scores[positive_positions], -limit)[-limit:]
                    selected_positions = positive_positions[partition]
                ordered_positions = selected_positions[np.argsort(scores[selected_positions])[::-1]]
                return [
                    (
                        float(scores[position]),
                        int(values_array[position]),
                        int(position),
                    )
                    for position in ordered_positions
                ]
        # NOTE(review): when the vectorised path finds no positive score, execution
        # falls through to the list-based scoring below. That path uses the
        # un-masked `keys` rather than the similarity arrays. Confirm this fallback
        # is intended.
    state = self._center_state_vector(self._masked_combined_state(answer_anchor_state))
    state_norm = norm(state)
    if state_norm == 0.0:
        return []
    scored: list[tuple[float, int, int]] = []
    for example_index, (key, key_norm, token_id) in enumerate(
        zip(keys, key_norms_list, values)
    ):
        if token_id < 0:
            continue
        denominator = state_norm * key_norm
        if denominator == 0.0:
            continue
        similarity = dot(state, key) / denominator
        if similarity > 0.0:
            scored.append((similarity, token_id, example_index))
    scored.sort(key=lambda item: item[0], reverse=True)
    return scored[:limit]
def _answer_prior_from_matches(
    self,
    matches: list[tuple[float, int, int]],
    generated_tokens: list[str],
) -> Vector:
    """Turn prompt-anchor matches into a normalised next-token prior.

    Only the first ``ANSWER_TOP_K`` matches are used. Each match adds its
    similarity to the prior slot of its token id. The contribution is damped to
    35% when that token was already generated, and it is skipped entirely when
    ``_allowed_generation_token`` rejects the token. With no matches an all-zero
    vector is returned without normalisation.
    """
    model = self.embedding_model
    assert model is not None
    vocab_size = len(model.id_to_token)
    if not matches:
        return [0.0] * vocab_size
    weights = [0.0] * vocab_size
    already_emitted = {
        model.token_to_id[piece]
        for piece in generated_tokens
        if piece in model.token_to_id
    }
    for score, candidate_id, _example in matches[:ANSWER_TOP_K]:
        candidate = model.id_to_token[candidate_id]
        if not self._allowed_generation_token(candidate, generated_tokens):
            continue
        weights[candidate_id] += score * 0.35 if candidate_id in already_emitted else score
    return _normalize_vector(weights)
def _answer_sequence_prior_from_matches(
    self,
    matches: list[tuple[float, int, int]],
    generated_tokens: list[str],
) -> Vector:
    """Build a next-token prior by continuing the best-matching stored answers.

    ``matches`` holds ``(similarity, sequence_index, _)`` tuples. ``matches[0]``
    is treated as the best match; callers presumably pass matches sorted by
    descending similarity (confirm). When that best similarity is at least 0.9,
    only matches within 0.02 of it contribute. Otherwise every match among the
    first ``ANSWER_START_TOP_K`` contributes.

    Each contributing answer row proposes the token that follows the
    already-generated prefix. That token gets a weight equal to how far the match
    clears the floor, with a minimum of ``1e-9``. The contribution is kept only
    if ``_allowed_generation_token`` accepts the token.

    Returns:
        The accumulated prior passed through ``_normalize_vector``. When there are
        no matches or no stored sequences, an all-zero vector is returned instead.
    """
    assert self.embedding_model is not None
    if not matches or self.answer_sequence_tokens is None:
        return [0.0 for _ in self.embedding_model.id_to_token]
    generated_ids = [
        self.embedding_model.token_to_id[token]
        for token in generated_tokens
        if token in self.embedding_model.token_to_id
    ]
    prior = [0.0 for _ in self.embedding_model.id_to_token]
    best_similarity = matches[0][0]
    # A near-exact best match narrows the pool to its close competitors.
    match_floor = best_similarity - 0.02 if best_similarity >= 0.9 else 0.0
    sequence_count = len(self.answer_sequence_tokens)
    for similarity, sequence_index, _ in matches[:ANSWER_START_TOP_K]:
        if similarity < match_floor:
            continue
        if sequence_index >= sequence_count:
            # A stale or foreign index is skipped instead of raising IndexError.
            # _answer_sequence_is_complete and _answer_sequence_has_continuation
            # apply the same guard.
            continue
        row = self.answer_sequence_tokens[sequence_index]
        row_values = row.tolist() if hasattr(row, "tolist") else row
        # Negative ids are dropped from the stored row.
        token_ids = [token_id for token_id in map(int, row_values) if token_id >= 0]
        if not token_ids:
            continue
        next_token_id = self._next_sequence_token_id(token_ids, generated_ids)
        if next_token_id is None:
            continue
        token = self.embedding_model.id_to_token[next_token_id]
        if self._allowed_generation_token(token, generated_tokens):
            prior[next_token_id] += max(1e-9, similarity - match_floor)
    return _normalize_vector(prior)
def _should_stop_answer_sequence(
    self,
    decode_state: DecodeState,
    generated_tokens: list[str],
) -> bool:
    """Report whether ``generated_tokens`` already reproduces a full matched answer.

    The matches cached on ``decode_state`` are reused when present. Otherwise they
    are re-scored from the state's answer anchor and its context tokens.
    """
    cached_matches = decode_state.answer_sequence_matches
    if cached_matches is not None:
        return self._answer_sequence_is_complete(generated_tokens, cached_matches)
    rescored = self._score_answer_sequence_matches(
        decode_state.answer_anchor_state,
        decode_state.context_tokens,
    )
    return self._answer_sequence_is_complete(generated_tokens, rescored)
def _answer_decode_has_continuation(
    self,
    decode_state: DecodeState,
    generated_tokens: list[str],
) -> bool:
    """Report whether a matched answer can still extend ``generated_tokens``.

    The matches cached on ``decode_state`` are reused when present. Otherwise they
    are re-scored from the state's answer anchor and its context tokens.
    """
    cached_matches = decode_state.answer_sequence_matches
    if cached_matches is not None:
        return self._answer_sequence_has_continuation(generated_tokens, cached_matches)
    rescored = self._score_answer_sequence_matches(
        decode_state.answer_anchor_state,
        decode_state.context_tokens,
    )
    return self._answer_sequence_has_continuation(generated_tokens, rescored)
def _answer_sequence_is_complete(
    self,
    generated_tokens: list[str],
    matches: list[tuple[float, int, int]],
) -> bool:
    """Report whether the generated ids start with a whole stored answer row.

    Only the first ``ANSWER_START_TOP_K`` matches are considered. A match is also
    ignored when its similarity is below ``ANSWER_SEQUENCE_MATCH_FLOOR`` or its
    index is out of range. A row counts as complete when its non-negative ids are
    a prefix of the generated ids.
    """
    if (
        self.embedding_model is None
        or self.answer_sequence_tokens is None
        or not generated_tokens
        or not matches
    ):
        return False
    model = self.embedding_model
    sequences = self.answer_sequence_tokens
    emitted = [model.token_to_id[piece] for piece in generated_tokens if piece in model.token_to_id]
    if not emitted:
        return False
    for score, row_index, _ in matches[:ANSWER_START_TOP_K]:
        if score < ANSWER_SEQUENCE_MATCH_FLOOR or row_index >= len(sequences):
            continue
        raw_row = sequences[row_index]
        row_values = raw_row.tolist() if hasattr(raw_row, "tolist") else raw_row
        expected = [value for value in map(int, row_values) if value >= 0]
        if expected and len(emitted) >= len(expected) and emitted[: len(expected)] == expected:
            return True
    return False
def _answer_sequence_has_continuation(
    self,
    generated_tokens: list[str],
    matches: list[tuple[float, int, int]],
) -> bool:
    """Report whether some strong match can still extend ``generated_tokens``.

    Only the first ``ANSWER_START_TOP_K`` matches are considered. A match is also
    ignored when its similarity is below ``ANSWER_SEQUENCE_MATCH_FLOOR`` or its
    index is out of range. For each remaining match, ``_next_sequence_token_id``
    gives the token that follows the generated prefix in the stored row. The
    result is ``True`` as soon as one such token passes
    ``_allowed_generation_token``.
    """
    if (
        self.embedding_model is None
        or self.answer_sequence_tokens is None
        or not generated_tokens
        or not matches
    ):
        return False
    model = self.embedding_model
    sequences = self.answer_sequence_tokens
    emitted = [model.token_to_id[piece] for piece in generated_tokens if piece in model.token_to_id]
    if not emitted:
        return False
    for score, row_index, _ in matches[:ANSWER_START_TOP_K]:
        if score < ANSWER_SEQUENCE_MATCH_FLOOR or row_index >= len(sequences):
            continue
        raw_row = sequences[row_index]
        row_values = raw_row.tolist() if hasattr(raw_row, "tolist") else raw_row
        expected = [value for value in map(int, row_values) if value >= 0]
        if not expected:
            continue
        follow_id = self._next_sequence_token_id(expected, emitted)
        if follow_id is not None and self._allowed_generation_token(
            model.id_to_token[follow_id], generated_tokens
        ):
            return True
    return False
def _next_sequence_token_id(
self,
token_ids: list[int],
generated_ids: list[int],
) -> int | None:
if not generated_ids:
return token_ids[0]
if len(generated_ids) >= len(token_ids):
return None
if token_ids[: len(generated_ids)] != generated_ids:
return None
return token_ids[len(generated_ids)]
def _transition_prior(self, context_tokens: list[str]) -> Vector:
prior, _ = self._transition_prior_with_order(context_tokens)
return prior
def _transition_prior_with_order(
    self,
    context_tokens: list[str],
) -> tuple[Vector, int | None]:
    """Look up the first n-gram order in ``TRANSITION_ORDERS`` with a recorded successor table.

    The key is the last ``order`` context tokens. The first order whose table has
    a non-empty entry for that key wins. Its probabilities are placed at the
    successors' vocabulary ids and passed through ``_normalize_vector``.

    Returns:
        ``(prior, order)`` for the winning order. When no table matches, the
        result is an all-zero vector paired with ``None``.
    """
    model = self.embedding_model
    assert model is not None
    if self.transition_tables:
        for order in TRANSITION_ORDERS:
            if order > len(context_tokens):
                continue
            following = self.transition_tables.get(order, {}).get(tuple(context_tokens[-order:]))
            if not following:
                continue
            weights = [0.0] * len(model.id_to_token)
            for successor, probability in following.items():
                successor_id = model.token_to_id.get(successor)
                if successor_id is not None:
                    weights[successor_id] = probability
            return _normalize_vector(weights), order
    return [0.0] * len(model.id_to_token), None
def _transition_prior_array_with_order(
    self,
    context_tokens: list[str],
) -> tuple[object, int | None]:
    """Array counterpart of ``_transition_prior_with_order``.

    The result is a float64 vector normalised to sum to 1 when its mass is
    positive. It is returned with the order that matched, or as all zeros with
    ``None`` when no order matched.
    """
    assert np is not None
    assert self.embedding_model is not None
    vocabulary = self.embedding_model
    prior = np.zeros(len(vocabulary.id_to_token), dtype=np.float64)
    tables = self.transition_tables
    if tables:
        for order in TRANSITION_ORDERS:
            if order > len(context_tokens):
                continue
            following = tables.get(order, {}).get(tuple(context_tokens[-order:]))
            if not following:
                continue
            for successor, probability in following.items():
                successor_id = vocabulary.token_to_id.get(successor)
                if successor_id is not None:
                    prior[successor_id] = probability
            mass = float(prior.sum())
            if mass > 0.0:
                prior /= mass
            return prior, order
    return prior, None
def _copy_prior(self, context_tokens: list[str]) -> Vector:
    """Favour recently seen tokens of the current answer for copying.

    The source is everything after the last ``<answer>`` marker, or the whole
    context when there is no marker. The last 8 source tokens are visited from
    newest to oldest. Each eligible, in-vocabulary, non-special token receives
    ``0.82 ** distance``. An empty source returns an un-normalised zero vector.
    """
    model = self.embedding_model
    assert model is not None
    assert self.tokenizer is not None
    weights = [0.0] * len(model.id_to_token)
    source_tokens = context_tokens
    for position in range(len(context_tokens) - 1, -1, -1):
        if context_tokens[position] == "<answer>":
            source_tokens = context_tokens[position + 1 :]
            break
    if not source_tokens:
        return weights
    specials = self.tokenizer.special_tokens
    for distance, token in enumerate(reversed(source_tokens[-8:])):
        if token in specials or not self._eligible_copy_token(token):
            continue
        token_id = model.token_to_id.get(token)
        if token_id is not None:
            weights[token_id] += 0.82**distance
    return _normalize_vector(weights)
def _copy_prior_array(self, context_tokens: list[str]) -> object:
assert np is not None
assert self.embedding_model is not None
assert self.tokenizer is not None
prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
decay = 0.82
answer_start = None
for index in range(len(context_tokens) - 1, -1, -1):
if context_tokens[index] == "<answer>":
answer_start = index + 1
break
source_tokens = context_tokens[answer_start:] if answer_start is not None else context_tokens
for distance, token in enumerate(reversed(source_tokens[-8:])):
if token in self.tokenizer.special_tokens:
continue
if not self._eligible_copy_token(token):
continue
token_id = self.embedding_model.token_to_id.get(token)
if token_id is None:
continue
prior[token_id] += decay**distance
total = float(prior.sum())
if total > 0.0:
prior /= total
return prior
def _preference_prior(self) -> Vector:
assert self.embedding_model is not None
if not self.preference_bias or not any(value != 0.0 for value in self.preference_bias):
return [0.0 for _ in self.embedding_model.id_to_token]
eligible_indices = [
index
for index, token in enumerate(self.embedding_model.id_to_token)
if self.preference_bias[index] > 0.0 and self._eligible_preference_token(token)
]
if not eligible_indices:
return [0.0 for _ in self.embedding_model.id_to_token]
eligible_probabilities = self._calibrated_softmax(
[self.preference_bias[index] for index in eligible_indices]
)
prior = [0.0 for _ in self.embedding_model.id_to_token]
for index, probability in zip(eligible_indices, eligible_probabilities):
prior[index] = probability
return prior
def _preference_prior_array(self) -> object:
assert np is not None
assert self.embedding_model is not None
if self.preference_bias_array is None or not np.any(self.preference_bias_array != 0.0):
return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
if self.preference_valid_mask_array is None or not np.any(self.preference_valid_mask_array):
return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
positive_mask = self.preference_bias_array > 0.0
active_mask = self.preference_valid_mask_array & positive_mask
if not np.any(active_mask):
return np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
prior = np.zeros(len(self.embedding_model.id_to_token), dtype=np.float64)
prior[active_mask] = self._calibrated_softmax_array(
self.preference_bias_array[active_mask]
)
return prior
def _eligible_preference_token(self, token: str) -> bool:
assert self.tokenizer is not None
if token == self.tokenizer.unk_token or token in self.tokenizer.special_tokens:
return False
if not self._starts_new_word(token):
return False
rendered = self._render_token(token)
if not rendered.strip() or self._is_punctuation_piece(rendered):
return False
alphanumeric = "".join(character for character in rendered if character.isalnum())
return len(alphanumeric) >= 1
def _build_transition_tables(
    self,
    tokens: list[str],
) -> dict[int, dict[tuple[str, ...], dict[str, float]]]:
    """Count n-gram successors for each order in ``TRANSITION_ORDERS`` and normalise them.

    For each order, contexts are ranked by total successor count, descending,
    with ties broken by the context tuple. The ranking is truncated to
    ``config.max_transition_contexts_per_order`` when that value is set and
    non-negative. Within each context, successors are ranked by count and then
    by token, and truncated to ``config.max_transition_next_tokens`` when that
    is positive. The kept counts are then normalised into probabilities.
    """
    orders = sorted(TRANSITION_ORDERS)
    counts: dict[int, dict[tuple[str, ...], dict[str, int]]] = {order: {} for order in orders}
    for order in orders:
        table = counts[order]
        for end in range(order - 1, len(tokens) - 1):
            context = tuple(tokens[end - order + 1 : end + 1])
            successor = tokens[end + 1]
            bucket = table.setdefault(context, {})
            bucket[successor] = bucket.get(successor, 0) + 1
    context_cap = self.config.max_transition_contexts_per_order
    successor_cap = self.config.max_transition_next_tokens
    tables: dict[int, dict[tuple[str, ...], dict[str, float]]] = {order: {} for order in orders}
    for order, table in counts.items():
        ranked_contexts = sorted(
            table.items(),
            key=lambda entry: (-sum(entry[1].values()), entry[0]),
        )
        if context_cap is not None and context_cap >= 0:
            ranked_contexts = ranked_contexts[:context_cap]
        for context, bucket in ranked_contexts:
            successors = sorted(bucket.items(), key=lambda entry: (-entry[1], entry[0]))
            if successor_cap > 0:
                successors = successors[:successor_cap]
            kept_total = sum(count for _, count in successors)
            if kept_total <= 0:
                continue
            tables[order][context] = {
                successor: count / kept_total for successor, count in successors
            }
    return tables
def _serialize_transition_tables(self) -> dict[str, dict[str, dict[str, float]]]:
    """Convert the in-memory transition tables into string-keyed form.

    Orders become ``str(order)`` and n-gram tuples are encoded with
    ``_encode_ngram_key``. Successor maps are passed through unchanged.
    """
    tables = self.transition_tables
    assert tables is not None
    payload: dict[str, dict[str, dict[str, float]]] = {}
    for order, mapping in tables.items():
        payload[str(order)] = {
            _encode_ngram_key(context): successors
            for context, successors in mapping.items()
        }
    return payload
def _deserialize_transition_tables(
    self,
    payload: dict[str, dict[str, dict[str, float]]],
) -> dict[int, dict[tuple[str, ...], dict[str, float]]]:
    """Inverse of ``_serialize_transition_tables``.

    Every order in ``TRANSITION_ORDERS`` starts with an empty table. Each payload
    order is converted with ``int`` and its table is replaced. Keys are decoded
    with ``_decode_ngram_key``. Successor tokens are coerced to ``str`` and
    probabilities to ``float``.
    """
    tables: dict[int, dict[tuple[str, ...], dict[str, float]]] = {
        order: {} for order in sorted(TRANSITION_ORDERS)
    }
    for order_text, mapping in payload.items():
        order = int(order_text)
        decoded: dict[tuple[str, ...], dict[str, float]] = {}
        for encoded_key, successors in mapping.items():
            context = _decode_ngram_key(encoded_key)
            decoded[context] = {
                str(token): float(probability)
                for token, probability in successors.items()
            }
        tables[order] = decoded
    return tables
def _eligible_copy_token(self, token: str) -> bool:
rendered = self._render_token(token)
if not rendered.strip():
return False
if self._is_punctuation_piece(rendered):
return False
if not self._starts_new_word(token):
return False
alphanumeric = "".join(character for character in rendered if character.isalnum())
return len(alphanumeric) >= 2
def _allowed_generation_token(self, token: str, generated_tokens: list[str]) -> bool:
assert self.embedding_model is not None
if len(self.embedding_model.id_to_token) < 1024:
return True
if token == self.tokenizer.unk_token or token in self.tokenizer.special_tokens:
return False
rendered = self._render_token(token)
if rendered == "\n":
return bool(generated_tokens)
if not rendered.strip():
return False
if self._is_word_joiner_token(token):
return (
self._can_attach_word_joiner(generated_tokens)
or self._can_start_line_with_word_joiner(token, generated_tokens)
)
if self._is_structural_punctuation_token(token):
return bool(generated_tokens) or self._can_start_answer_with_structural_punctuation(token)
if self._is_structural_symbol_token(token):
return bool(generated_tokens) or self._starts_new_word(token)
if not self._starts_new_word(token):
return False
alphanumeric = "".join(character for character in rendered if character.isalnum())
return len(alphanumeric) >= 1 or not self._is_punctuation_piece(rendered)
def _would_repeat_recent_pattern(
self,
candidate: str,
generated_tokens: list[str],
recent_rendered_words: list[str] | None = None,
) -> bool:
if len(generated_tokens) >= 2 and generated_tokens[-1] == candidate and generated_tokens[-2] == candidate:
return True
if len(generated_tokens) >= 2:
trigram = tuple(generated_tokens[-2:] + [candidate])
recent_tokens = generated_tokens[-12:]
for index in range(max(0, len(recent_tokens) - 4)):
if tuple(recent_tokens[index : index + 3]) == trigram:
return True
rendered_words = recent_rendered_words
if rendered_words is None:
rendered_words = self._recent_rendered_words(generated_tokens)
candidate_word = self._render_token(candidate).casefold()
if (
rendered_words
and self._starts_new_word(candidate)
and any(character.isalnum() for character in candidate_word)
):
candidate_bigram = (rendered_words[-1], candidate_word)
recent_window = rendered_words[-10:]
recent_bigrams = {
(recent_window[index], recent_window[index + 1])
for index in range(len(recent_window) - 1)
}
if candidate_bigram in recent_bigrams:
return True
if (
len(candidate_word) > 2
and rendered_words[-10:].count(candidate_word) >= 2
and not self._is_common_connector_token(candidate)
):
return True
return False
def _recent_rendered_words(self, generated_tokens: list[str]) -> list[str]:
rendered_words: list[str] = []
for token in generated_tokens:
if not self._starts_new_word(token):
continue
rendered = self._render_token(token).casefold()
if any(character.isalnum() for character in rendered):
rendered_words.append(rendered)
return rendered_words
def _select_generation_token(
    self,
    distribution: dict[str, float],
    *,
    context_tokens: list[str] | None = None,
    generated_tokens: list[str] | None = None,
    temperature: float = DEFAULT_GENERATION_TEMPERATURE,
    top_k: int = DEFAULT_GENERATION_TOP_K,
    top_p: float = DEFAULT_GENERATION_TOP_P,
    repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
    preserve_dominant_candidates: bool = False,
) -> str:
    """Pick the next token from a ``{token: probability}`` distribution.

    The distribution is first filtered and re-weighted by
    ``_prepare_generation_candidates`` and then sampled, stochastically only
    when ``temperature > 0``. If no candidate survives, the fallback is the most
    probable token that is neither special nor unknown and is allowed by
    ``_allowed_generation_token``. Returns ``""`` when nothing qualifies.
    """
    assert self.tokenizer is not None
    history = generated_tokens or []
    candidates = self._prepare_generation_candidates(
        distribution,
        generated_tokens=history,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        preserve_dominant_candidates=preserve_dominant_candidates,
    )
    if candidates:
        return self._sample_generation_candidate(
            candidates,
            context_tokens=context_tokens or [],
            generated_tokens=history,
            stochastic=temperature > 0.0,
        )
    tokenizer = self.tokenizer
    ranked = sorted(distribution.items(), key=lambda item: item[1], reverse=True)
    for token, _probability in ranked:
        if token in tokenizer.special_tokens or token == tokenizer.unk_token:
            continue
        if self._allowed_generation_token(token, history):
            return token
    return ""
def _select_generation_token_from_array(
    self,
    probabilities: object,
    *,
    context_tokens: list[str],
    generated_tokens: list[str],
    temperature: float = DEFAULT_GENERATION_TEMPERATURE,
    top_k: int = DEFAULT_GENERATION_TOP_K,
    top_p: float = DEFAULT_GENERATION_TOP_P,
    repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
    preserve_dominant_candidates: bool = False,
) -> str:
    """Pick the next token from a dense probability array indexed by vocabulary id.

    Only a pool of the ``max(4 * top_k, 64)`` highest-probability ids is
    considered, capped at the array size, in descending order. From that pool,
    non-positive entries and special or unknown tokens are dropped. The rest is
    handed to ``_select_generation_token`` as a ``{token: probability}`` dict.

    Returns:
        The selected token, or ``""`` for an empty array.
    """
    assert np is not None
    assert self.tokenizer is not None
    assert self.embedding_model is not None
    values = np.asarray(probabilities, dtype=np.float64)
    if values.size == 0:
        return ""
    # max(..., 64) is always positive and values.size > 0 here, so the pool is
    # non-empty. The former `pool_size <= 0` fallback was unreachable and has
    # been removed.
    pool_size = min(values.size, max(top_k * 4, 64))
    if pool_size < values.size:
        # Select the pool in O(n), then order just the pool.
        candidate_indices = np.argpartition(values, -pool_size)[-pool_size:]
        candidate_indices = candidate_indices[np.argsort(values[candidate_indices])[::-1]]
    else:
        candidate_indices = np.argsort(values)[::-1]
    distribution: dict[str, float] = {}
    for raw_index in candidate_indices:
        index = int(raw_index)
        score = float(values[index])
        if score <= 0.0:
            continue
        token = self.embedding_model.id_to_token[index]
        if token in self.tokenizer.special_tokens or token == self.tokenizer.unk_token:
            continue
        distribution[token] = score
    return self._select_generation_token(
        distribution,
        context_tokens=context_tokens,
        generated_tokens=generated_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        preserve_dominant_candidates=preserve_dominant_candidates,
    )
def _prepare_generation_candidates(
    self,
    distribution: dict[str, float],
    *,
    generated_tokens: list[str],
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    preserve_dominant_candidates: bool = False,
) -> list[tuple[str, float]]:
    """Filter, re-weight and truncate a token distribution into sampling candidates.

    Tokens are visited in descending probability. A token is dropped when any of
    these holds:

    - it is special or unknown, or has a non-positive probability;
    - ``_allowed_generation_token`` rejects it;
    - it would repeat a recent pattern. This is waived when
      ``preserve_dominant_candidates`` is set and the token is within 80% of the
      best probability;
    - it starts a word whose alphanumeric text equals (casefolded) that of the
      previous generated token.

    Survivors receive multiplicative penalties and boosts: short words,
    repetition, mid-word pieces, first-token punctuation, and clause-length
    punctuation boosts. They are then cut to ``top_k`` when it is positive, and
    to the ``top_p`` nucleus when ``0 < top_p < 1``. The nucleus always keeps at
    least one candidate.

    Returns:
        ``(token, probability)`` pairs in descending score order. With
        ``temperature <= 0`` the result is a single greedy pair with probability
        ``1.0``. Otherwise scores are raised to ``1 / temperature`` and
        renormalised to sum to 1. The result is ``[]`` when nothing survives.

    NOTE(review): ``repetition_penalty`` is not validated. A value of 0 causes a
    float division by zero in the repetition penalties below; presumably callers
    pass a value > 0.
    """
    assert self.tokenizer is not None
    assert self.embedding_model is not None
    generated_word_count = self._generated_word_count(generated_tokens)
    clause_words = self._words_since_clause_break(generated_tokens)
    recent_rendered_words = self._recent_rendered_words(generated_tokens)
    best_probability = max(distribution.values(), default=0.0)
    adjusted: list[tuple[str, float]] = []
    for token, probability in sorted(distribution.items(), key=lambda item: item[1], reverse=True):
        if token in self.tokenizer.special_tokens:
            continue
        if token == self.tokenizer.unk_token or probability <= 0.0:
            continue
        if not self._allowed_generation_token(token, generated_tokens):
            continue
        repeats_recent_pattern = self._would_repeat_recent_pattern(
            token,
            generated_tokens,
            recent_rendered_words=recent_rendered_words,
        )
        # Repetitive candidates are dropped unless they are near-dominant and the
        # caller asked to preserve dominant candidates.
        if (
            repeats_recent_pattern
            and not (
                preserve_dominant_candidates
                and best_probability > 0.0
                and probability >= best_probability * 0.80
            )
        ):
            continue
        score = probability
        rendered = self._render_token(token)
        punctuation_token = self._is_structural_punctuation_token(token)
        starts_new_word = self._starts_new_word(token)
        alphanumeric = "".join(character for character in rendered if character.isalnum())
        # Never emit the same word twice in a row (case-insensitive, punctuation-blind).
        if generated_tokens and starts_new_word and alphanumeric:
            previous_rendered = self._render_token(generated_tokens[-1])
            previous_alphanumeric = "".join(
                character for character in previous_rendered if character.isalnum()
            )
            if previous_alphanumeric.casefold() == alphanumeric.casefold():
                continue
        common_connector = self._is_common_connector_token(token)
        # Single-character words are heavily discounted unless they are connectors.
        if (
            starts_new_word
            and len(alphanumeric) == 1
            and not common_connector
        ):
            score *= 0.08
        # Repetition penalties: frequency in the last 12 tokens, immediate repeat,
        # and presence in the last 4 tokens.
        recent_count = generated_tokens[-12:].count(token)
        if recent_count > 0 and not common_connector:
            score /= repetition_penalty ** (2 * recent_count)
        if generated_tokens and token == generated_tokens[-1]:
            score /= repetition_penalty**3
        if generated_tokens and token in generated_tokens[-4:] and not common_connector:
            score *= 0.35
        # A non-word-start piece right after a word-start token is discounted.
        if generated_tokens and not starts_new_word and self._starts_new_word(generated_tokens[-1]):
            score *= 0.08
        # First-token handling: weak punctuation and mid-word pieces are suppressed.
        if not generated_tokens and punctuation_token:
            if best_probability <= 0.0 or probability < best_probability * 0.80:
                score *= 0.01
        elif not generated_tokens and not starts_new_word:
            score *= 0.02
        if punctuation_token:
            # Discourage back-to-back punctuation.
            if generated_tokens and self._is_structural_punctuation_token(generated_tokens[-1]):
                score *= 0.05
            # Encourage punctuation as the clause (or the whole answer) grows long.
            if clause_words >= 6:
                score *= 1.0 + min(1.4, 0.18 * (clause_words - 5))
            elif generated_word_count >= 12:
                score *= 1.1
        if score > 0.0:
            adjusted.append((token, score))
    if not adjusted:
        return []
    adjusted.sort(key=lambda item: item[1], reverse=True)
    if top_k > 0:
        adjusted = adjusted[:top_k]
    if 0.0 < top_p < 1.0:
        # Nucleus truncation: keep candidates until their normalised mass reaches
        # top_p. The candidate that crosses the threshold is kept too.
        kept: list[tuple[str, float]] = []
        cumulative = 0.0
        total = sum(score for _, score in adjusted)
        for token, score in adjusted:
            normalized = score / total if total else 0.0
            kept.append((token, score))
            cumulative += normalized
            if cumulative >= top_p:
                break
        adjusted = kept
    if temperature <= 0.0:
        return [(adjusted[0][0], 1.0)]
    exponent = 1.0 / temperature
    tempered = [
        (token, score**exponent)
        for token, score in adjusted
        if score > 0.0
    ]
    total = sum(score for _, score in tempered)
    if total <= 0.0:
        return []
    return [(token, score / total) for token, score in tempered]
def _sample_generation_candidate(
self,
candidates: list[tuple[str, float]],
*,
context_tokens: list[str],
generated_tokens: list[str],
stochastic: bool = False,
) -> str:
if not candidates:
return ""
if len(candidates) == 1:
return candidates[0][0]
top_probability = candidates[0][1]
second_probability = candidates[1][1]
top_has_clear_half_majority = top_probability >= 0.5 and (
second_probability <= 0.0
or top_probability - second_probability >= 0.02
)
if top_has_clear_half_majority or (
second_probability > 0.0 and top_probability >= second_probability * 2.5
) or (
top_probability >= 0.08
and second_probability > 0.0
and top_probability >= second_probability * 1.35
):
return candidates[0][0]
if stochastic:
threshold = random.random()
else:
seed_payload = "\u0002".join([*context_tokens, "<generated>", *generated_tokens, str(len(candidates))])
seed = int.from_bytes(hashlib.sha256(seed_payload.encode("utf-8")).digest()[:8], "big")
threshold = random.Random(seed).random()
cumulative = 0.0
for token, probability in candidates:
cumulative += probability
if threshold <= cumulative:
return token
return candidates[-1][0]
def _top_entries_from_vector(
self,
values: Vector,
limit: int,
) -> list[dict[str, object]]:
if limit <= 0:
return []
ranked = sorted(
enumerate(values),
key=lambda item: item[1],
reverse=True,
)
return [
self._token_entry(index, probability)
for index, probability in ranked[:limit]
if probability > 0.0
]
def _token_entry(
self,
index: int,
probability: float,
) -> dict[str, object]:
assert self.embedding_model is not None
token = self.embedding_model.id_to_token[index]
return {
"token": token,
"text": self._render_token(token),
"probability": probability,
}
def _build_reasoning_summary(
self,
transition_order: int | None,
blend_weights: dict[str, float],
) -> str:
dominant_source = max(blend_weights.items(), key=lambda item: item[1])[0] if blend_weights else "base"
if transition_order is not None:
transition_message = f" Transition prior is using order-{transition_order} context."
else:
transition_message = " Transition prior found no matching n-gram."
return (
"Generation is running on analytical state, recurrent traces, and corpus-derived token transitions."
f"{transition_message}"
f" Dominant blend source: {dominant_source}."
)
def _generated_word_count(self, tokens: list[str]) -> int:
return len(self._decode_tokens(tokens).split())
def _is_structural_punctuation_text(self, text: str) -> bool:
if len(text) != 1:
return False
if self._is_word_joiner_text(text):
return False
category = unicodedata.category(text)
return category.startswith("P")
def _is_structural_punctuation_token(self, token: str) -> bool:
return self._is_structural_punctuation_text(self._render_token(token))
def _is_structural_symbol_token(self, token: str) -> bool:
rendered = self._render_token(token)
return len(rendered) == 1 and unicodedata.category(rendered).startswith("S")
def _is_word_joiner_token(self, token: str) -> bool:
return self._is_word_joiner_text(self._render_token(token))
def _is_word_joiner_text(self, text: str) -> bool:
if len(text) != 1:
return False
category = unicodedata.category(text)
if category in ("Pc", "Pd", "Lm"):
return True
name = unicodedata.name(text, "")
return "APOSTROPHE" in name or (
"SINGLE" in name and "QUOTATION MARK" in name
)
def _can_start_line_with_word_joiner(self, token: str, generated_tokens: list[str]) -> bool:
rendered = self._render_token(token)
if len(rendered) != 1 or unicodedata.category(rendered) != "Pd":
return False
if not self._starts_new_word(token):
return False
return not generated_tokens or self._render_token(generated_tokens[-1]) == "\n"
def _can_start_answer_with_structural_punctuation(self, token: str) -> bool:
rendered = self._render_token(token)
if len(rendered) != 1 or not self._starts_new_word(token):
return False
return unicodedata.category(rendered) in ("Ps", "Pi")
def _is_common_connector_token(self, token: str) -> bool:
rendered = self._render_token(token)
return rendered.isalpha() and len(rendered) <= 3
def _can_attach_word_joiner(self, generated_tokens: list[str]) -> bool:
if not generated_tokens:
return False
rendered = self._render_token(generated_tokens[-1])
if not rendered:
return False
if any(character.isalnum() for character in rendered):
return True
if len(rendered) != 1:
return False
return unicodedata.category(rendered) in ("Ps", "Pi")
def _words_since_clause_break(self, tokens: list[str]) -> int:
assert self.tokenizer is not None
words = 0
for token in reversed(tokens):
if token in self.tokenizer.special_tokens:
continue
rendered = self._render_token(token)
if self._is_structural_punctuation_text(rendered):
break
if self._starts_new_word(token) and not self._is_punctuation_piece(rendered):
words += 1
return words
def _should_stop_generation(self, generated_tokens: list[str]) -> bool:
if not generated_tokens:
return False
if not self._is_terminal_punctuation_text(self._render_token(generated_tokens[-1])):
return False
return self._generated_word_count(generated_tokens) >= 14
def _is_terminal_punctuation_text(self, text: str) -> bool:
if not self._is_structural_punctuation_text(text):
return False
name = unicodedata.name(text, "")
return (
"FULL STOP" in name
or "QUESTION MARK" in name
or "EXCLAMATION MARK" in name
)
def _starts_new_word(self, token: str) -> bool:
assert self.tokenizer is not None
if token in self.tokenizer.special_tokens:
return True
if token.startswith(self.tokenizer.word_prefix):
return True
return len(token) == 1 and not token.isalnum() and not self._is_word_joiner_token(token)
def _decode_tokens(self, tokens: list[str]) -> str:
assert self.tokenizer is not None
return self.tokenizer.decode(tokens)
def _render_token(self, token: str) -> str:
assert self.tokenizer is not None
if token.startswith(self.tokenizer.word_prefix):
return token[len(self.tokenizer.word_prefix) :]
return token
def _require_fit(self) -> None:
if (
self.tokenizer is None
or self.embedding_model is None
or self.memory_units is None
or self.readout_weights is None
or self.ternary_mask is None
or self.associative_keys is None
or self.associative_key_norms is None
or self.associative_values is None
or self.transition_tables is None
):
raise RuntimeError("Call fit() before using the REFRAMR model.")
def _ensure_numeric_caches(self) -> None:
if np is None:
return
if self.readout_weights_array is None:
self._refresh_numeric_caches()