k-l-lambda committed on
Commit
a8785dd
·
1 Parent(s): 4252956

refined mask_monitor.

Browse files
Files changed (1) hide show
  1. lilyscript/mask_monitor.py +95 -31
lilyscript/mask_monitor.py CHANGED
@@ -2,34 +2,37 @@
2
 
3
  A lightweight, parse-free counterpart to tools/lilylet_blacklist_gen.py's
4
  BlacklistMonitor. It does NOT call any parser/oracle and has no Node dependency:
5
- it simply trusts a pre-discovered 2-gram blacklist and masks the forbidden next
6
- tokens during sampling. Suitable for the Gradio app's live generation path.
7
 
8
  Wire-compatible with StreamingLilyletGenerator's `monitor` hook:
9
- banned() -> ids to mask for the next draw (blacklist[current 2-gram])
10
  accept(id)->bool -> always True (we trust the mask; never re-parse)
11
- commit_forced(id) -> advance running text/context for a forced token
12
  Plus mark()/rollback() so the generator's `[r:0/<measures>]` priming re-sample
13
- (a probe-then-discard draw) can rewind the running text before the forced redraw.
14
 
15
- The 2-gram context is the last two CONTENT token ids of the marker-stripped
16
- stream (markers aren't real Lilylet), recomputed from a bounded text suffix so it
17
- "sees through" `[r:x/y]` markers exactly as the discovery monitor did.
 
 
18
  """
19
 
20
  import os
21
  import re
22
  import json
23
 
24
- # A complete stream marker `[r:<digits>/<digits>]`; a trailing partial `[r…`.
25
- # (Mirror tools/lilylet_blacklist_gen.py: only strip a confirmed `[r` prefix, never
26
- # a bare `[`, which is real content — a header `[composer …]` or beam `c8[`.)
27
- _MARKER_COMPLETE = re.compile(r'\[r:\d+/\d*\]')
28
- _MARKER_PARTIAL_TAIL = re.compile(r'\[r(:(\d+(/\d*)?)?)?$')
29
-
30
-
31
- def _clean (raw):
32
- return _MARKER_PARTIAL_TAIL.sub('', _MARKER_COMPLETE.sub('', raw))
 
33
 
34
 
35
  def _whitespace_ids (tokenizer):
@@ -71,8 +74,16 @@ class MaskMonitor:
71
  self._ws = set(_whitespace_ids(self.tk))
72
  self._key_lengths = sorted({len(k) for k in self.blacklist}, reverse=True) if self.blacklist else []
73
  self._max_ctx = max(self._key_lengths) if self._key_lengths else 0
74
- self.raw = ''
75
- self._ctx_ids = [] # last <=_max_ctx content non-whitespace ids
 
 
 
 
 
 
 
 
76
  self._mark = None
77
 
78
  def _is_content (self, tid):
@@ -85,19 +96,71 @@ class MaskMonitor:
85
  tid = int(tid)
86
  return '' if not self._is_content(tid) else self.tk.text_by_id.get(tid, '')
87
 
88
- def _sync_ctx (self):
89
- clean = _clean(self.raw)
90
- ids = self.tk.encode(clean[-256:])
91
- ctx = [i for i in ids if self._is_ctx(i)]
92
- self._ctx_ids = ctx[-self._max_ctx:] if self._max_ctx else []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  # ---- generator-facing API ----
95
 
96
  def banned (self):
97
- '''Union forbidden sets of every stored key that is a suffix of the context.'''
 
 
 
98
  if not self._key_lengths:
99
  return ()
100
  ctx = self._ctx_ids
 
 
101
  out = set()
102
  for n in self._key_lengths:
103
  if n <= len(ctx):
@@ -112,17 +175,18 @@ class MaskMonitor:
112
  return True
113
 
114
  def commit_forced (self, tid):
115
- self.raw += self._text(tid)
116
- self._sync_ctx()
117
 
118
  # ---- priming support (probe-then-discard rewind) ----
119
 
120
  def mark (self):
121
- '''Remember the current running text so a discarded probe patch can be undone.'''
122
- self._mark = self.raw
123
 
124
  def rollback (self):
125
  '''Rewind to the last mark() (drops a probe patch's effect on the context).'''
126
  if self._mark is not None:
127
- self.raw = self._mark
128
- self._sync_ctx()
 
 
2
 
3
  A lightweight, parse-free counterpart to tools/lilylet_blacklist_gen.py's
4
  BlacklistMonitor. It does NOT call any parser/oracle and has no Node dependency:
5
+ it trusts a pre-discovered variable-length n-gram blacklist and masks the
6
+ forbidden next tokens during sampling. Suitable for the Gradio live-gen path.
7
 
8
  Wire-compatible with StreamingLilyletGenerator's `monitor` hook:
9
+ banned() -> ids to mask for the next draw (suffix match on the context)
10
  accept(id)->bool -> always True (we trust the mask; never re-parse)
11
+ commit_forced(id) -> advance the running context for a forced token
12
  Plus mark()/rollback() so the generator's `[r:0/<measures>]` priming re-sample
13
+ (a probe-then-discard draw) can rewind the context before the forced redraw.
14
 
15
+ The context is the last N CONTENT token ids (whitespace dropped, `[r:x/y]` stream
16
+ markers excluded), maintained INCREMENTALLY in id space — no per-token tokenizer
17
+ call. Markers are detected by buffering a potential `[r:]` run and classifying
18
+ the tiny buffer text with the same regexes _clean uses, so the resulting context
19
+ is identical to clean_for_parse()+tokenize but at near-zero cost.
20
  """
21
 
22
  import os
23
  import re
24
  import json
25
 
26
+ # Anchored marker regexes for the id-space state machine — applied to the tiny
27
+ # marker BUFFER text (a candidate `[r:…]` run), never the whole stream:
28
+ # - COMPLETE `\[r:\d+/\d*\]` -> drop the buffer entirely (a finished marker)
29
+ # - viable PARTIAL `\[r(:(\d+(/\d*)?)?)?` -> keep buffering; it is excluded from
30
+ # the context (mirrors discovery's _MARKER_PARTIAL_TAIL at end-of-stream)
31
+ # - the lone `[` is a viable partial too, but stays VISIBLE as content (a bare
32
+ # `[` is a header `[composer …]` / beam `c8[`, which _clean keeps)
33
+ # - anything else -> not a marker -> flush the buffer back into the context
34
+ _MARK_COMPLETE_FULL = re.compile(r'\[r:\d+/\d*\]\Z')
35
+ _MARK_PARTIAL_FULL = re.compile(r'\[r(:(\d+(/\d*)?)?)?\Z')
36
 
37
 
38
  def _whitespace_ids (tokenizer):
 
74
  self._ws = set(_whitespace_ids(self.tk))
75
  self._key_lengths = sorted({len(k) for k in self.blacklist}, reverse=True) if self.blacklist else []
76
  self._max_ctx = max(self._key_lengths) if self._key_lengths else 0
77
+ self._lbrack = self.tk.encode('[')[0]
78
+ # Context is maintained INCREMENTALLY in id space — no per-token tokenizer
79
+ # call. The only thing that needs care is excluding `[r:x/y]` stream markers,
80
+ # whose chars share ids with real content. We buffer a *potential* marker
81
+ # (always starting at `[`) and decide using the SAME regexes as _clean applied
82
+ # to the tiny buffer text (id->char is a cheap dict lookup), so the resulting
83
+ # context is provably identical to _clean()+encode of the full stream.
84
+ self._ctx_ids = [] # confirmed visible content ids (whitespace dropped)
85
+ self._buf = [] # token ids of an in-progress potential marker
86
+ self._buf_text = '' # their concatenated text (starts with '[')
87
  self._mark = None
88
 
89
  def _is_content (self, tid):
 
96
  tid = int(tid)
97
  return '' if not self._is_content(tid) else self.tk.text_by_id.get(tid, '')
98
 
99
+ def _push (self, tid):
100
+ self._ctx_ids.append(int(tid))
101
+ if len(self._ctx_ids) > self._max_ctx:
102
+ del self._ctx_ids[:-self._max_ctx]
103
+
104
+ def _push_content (self, tid):
105
+ '''Route a non-marker token into the context: drop whitespace, keep content.'''
106
+ if self._is_ctx(tid):
107
+ self._push(tid)
108
+
109
+ def _flush_buf (self):
110
+ '''The buffered tokens are NOT a marker after all -> they are real content;
111
+ replay them through the content path, then clear the buffer.'''
112
+ buf = self._buf
113
+ self._buf = []
114
+ self._buf_text = ''
115
+ for tid in buf:
116
+ self._push_content(tid)
117
+
118
+ def _feed (self, tid):
119
+ '''Advance the incremental context with one committed token id.'''
120
+ tid = int(tid)
121
+ if not self._is_content(tid):
122
+ # pad/bos/eos: not content, and breaks any pending marker buffer.
123
+ if self._buf:
124
+ self._flush_buf()
125
+ return
126
+
127
+ if self._buf:
128
+ # extend the potential marker and re-classify the buffer text.
129
+ text = self._buf_text + self._text(tid)
130
+ if _MARK_COMPLETE_FULL.match(text):
131
+ # a full `[r:x/y]` -> the whole buffer contributes nothing; drop it.
132
+ self._buf = []
133
+ self._buf_text = ''
134
+ return
135
+ if _MARK_PARTIAL_FULL.match(text):
136
+ # still a viable marker prefix (`[r`, `[r:`, `[r:5`, `[r:5/`, `[r:5/3`)
137
+ self._buf.append(tid)
138
+ self._buf_text = text
139
+ return
140
+ # not a marker: flush the buffer as content, then process tid afresh below.
141
+ self._flush_buf()
142
+
143
+ if tid == self._lbrack:
144
+ # start a potential marker (a lone `[` is itself real content until/unless
145
+ # an `r` follows; that visibility is handled in banned()).
146
+ self._buf = [tid]
147
+ self._buf_text = '['
148
+ return
149
+
150
+ self._push_content(tid)
151
 
152
  # ---- generator-facing API ----
153
 
154
  def banned (self):
155
+ '''Union forbidden sets of every stored key that is a suffix of the context.
156
+
157
+ The visible context is _ctx_ids plus a pending lone `[` (which _clean keeps);
158
+ a longer `[r…` partial buffer is stripped (matches _MARKER_PARTIAL_TAIL).'''
159
  if not self._key_lengths:
160
  return ()
161
  ctx = self._ctx_ids
162
+ if self._buf_text == '[':
163
+ ctx = ctx + [self._lbrack]
164
  out = set()
165
  for n in self._key_lengths:
166
  if n <= len(ctx):
 
175
  return True
176
 
177
  def commit_forced (self, tid):
178
+ if self._max_ctx:
179
+ self._feed(tid)
180
 
181
  # ---- priming support (probe-then-discard rewind) ----
182
 
183
  def mark (self):
184
+ '''Snapshot the context so a discarded priming-probe patch can be undone.'''
185
+ self._mark = (list(self._ctx_ids), list(self._buf), self._buf_text)
186
 
187
  def rollback (self):
188
  '''Rewind to the last mark() (drops a probe patch's effect on the context).'''
189
  if self._mark is not None:
190
+ self._ctx_ids = list(self._mark[0])
191
+ self._buf = list(self._mark[1])
192
+ self._buf_text = self._mark[2]