k-l-lambda commited on
Commit
764d3da
·
1 Parent(s): 4fb4554

load model online

Browse files
Files changed (2) hide show
  1. app.py +29 -6
  2. requirements.txt +1 -0
app.py CHANGED
@@ -10,7 +10,8 @@ Right column:
10
  Generation streams patch-by-patch: raw decoded text (with `[r:x/y]` stream
11
  markers) goes to the run log, while the measure-segmented postprocessed text
12
  fills the editor. The backend is the int8 + two-level KV-cache ONNX generator
13
- (see lilyscript/generator.py); models load from a local dir for now.
 
14
  """
15
 
16
  import os
@@ -26,9 +27,13 @@ from lilyscript.generator import StreamingLilyletGenerator
26
  from lilyscript.postprocess import postprocess
27
 
28
  HERE = os.path.dirname(os.path.abspath(__file__))
29
- # TODO: swap for huggingface_hub.snapshot_download(repo_id=...) to pull the int8
30
- # ONNX weights from the hub instead of a local dir.
31
- MODEL_DIR = os.environ.get('LILYSCRIPT_MODEL_DIR', os.path.join(HERE, 'models'))
 
 
 
 
32
  ASSET_DIR = os.path.join(HERE, 'assets')
33
  EXAMPLES_DIR = os.path.join(HERE, 'examples')
34
  OUTPUT_DIR = os.path.join(HERE, 'outputs')
@@ -92,13 +97,31 @@ def _init_logging ():
92
  _init_logging()
93
 
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  def get_generator ():
96
  '''Lazily build the (heavy) ONNX generator on first use.'''
97
  global _GEN
98
  if _GEN is None:
99
- LOG.info('loading ONNX generator from %s ...', MODEL_DIR)
 
100
  t0 = time.perf_counter()
101
- _GEN = StreamingLilyletGenerator(MODEL_DIR, ASSET_DIR)
102
  LOG.info('generator ready (%.1fs)', time.perf_counter() - t0)
103
  return _GEN
104
 
 
10
  Generation streams patch-by-patch: raw decoded text (with `[r:x/y]` stream
11
  markers) goes to the run log, while the measure-segmented postprocessed text
12
  fills the editor. The backend is the int8 + two-level KV-cache ONNX generator
13
+ (see lilyscript/generator.py); weights are pulled from the HF model repo
14
+ `k-l-lambda/LilyNota` on first use (override with LILYSCRIPT_MODEL_DIR locally).
15
  """
16
 
17
  import os
 
27
  from lilyscript.postprocess import postprocess
28
 
29
  HERE = os.path.dirname(os.path.abspath(__file__))
30
+ # Model weights are pulled from the HuggingFace model repo `k-l-lambda/LilyNota`
31
+ # at first use (the int8 + KV-cache ONNX bundle lives under its `onnx/` dir).
32
+ # For local development, point LILYSCRIPT_MODEL_DIR at a local onnx dir to skip
33
+ # the download.
34
+ HF_MODEL_REPO = os.environ.get('LILYSCRIPT_MODEL_REPO', 'k-l-lambda/LilyNota')
35
+ HF_MODEL_SUBDIR = 'onnx' # weights + geometry + tokenizer live here in the repo
36
+ MODEL_DIR = os.environ.get('LILYSCRIPT_MODEL_DIR') # set -> use this local dir instead of the hub
37
  ASSET_DIR = os.path.join(HERE, 'assets')
38
  EXAMPLES_DIR = os.path.join(HERE, 'examples')
39
  OUTPUT_DIR = os.path.join(HERE, 'outputs')
 
97
  _init_logging()
98
 
99
 
100
+ def resolve_model_dir ():
101
+ '''Where the ONNX weights live. If LILYSCRIPT_MODEL_DIR is set, use it as-is
102
+ (local dev). Otherwise pull the `onnx/` bundle from the HF model repo and
103
+ return its local snapshot path. The tokenizer is NOT pulled — it's read from
104
+ the app's own assets/ dir — so we only fetch the weight files.'''
105
+ if MODEL_DIR:
106
+ return MODEL_DIR
107
+ from huggingface_hub import snapshot_download
108
+ LOG.info('downloading model weights from hf:%s (%s/) ...', HF_MODEL_REPO, HF_MODEL_SUBDIR)
109
+ local = snapshot_download(
110
+ repo_id=HF_MODEL_REPO,
111
+ allow_patterns=[f'{HF_MODEL_SUBDIR}/patch_kv_int8.onnx', f'{HF_MODEL_SUBDIR}/token_kv_int8.onnx',
112
+ f'{HF_MODEL_SUBDIR}/wte.npy', f'{HF_MODEL_SUBDIR}/geometry.json'],
113
+ )
114
+ return os.path.join(local, HF_MODEL_SUBDIR)
115
+
116
+
117
  def get_generator ():
118
  '''Lazily build the (heavy) ONNX generator on first use.'''
119
  global _GEN
120
  if _GEN is None:
121
+ model_dir = resolve_model_dir()
122
+ LOG.info('loading ONNX generator from %s ...', model_dir)
123
  t0 = time.perf_counter()
124
+ _GEN = StreamingLilyletGenerator(model_dir, ASSET_DIR)
125
  LOG.info('generator ready (%.1fs)', time.perf_counter() - t0)
126
  return _GEN
127
 
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  gradio==6.18.0
2
  onnxruntime
3
  numpy
 
 
1
  gradio==6.18.0
2
  onnxruntime
3
  numpy
4
+ huggingface-hub