hellosindh committed on
Commit
3872f06
·
verified ·
1 Parent(s): 4409eca

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +21 -11
inference.py CHANGED
@@ -31,7 +31,7 @@ import torch.nn.functional as F
31
 
32
 
33
  # ── Auto-download from HuggingFace ────────────────────────────
34
- HF_REPO = "YOUR_USERNAME/indus-script-models" # update after upload
35
 
36
  def download_models(repo_id=HF_REPO, local_dir="indus_models"):
37
  """Download all model files from HuggingFace."""
@@ -48,18 +48,25 @@ def download_models(repo_id=HF_REPO, local_dir="indus_models"):
48
 
49
 
50
  def get_model_dir():
51
- """Find model directory — local DATA/models or downloaded."""
52
- # Try local development path first
 
 
 
 
 
 
 
 
 
 
 
53
  local = Path("DATA/models")
54
  if local.exists():
55
  return local, Path("DATA")
56
- # Try downloaded path
57
- downloaded = Path("indus_models")
58
- if downloaded.exists():
59
- return downloaded / "models", downloaded
60
- # Auto-download
61
  path = download_models()
62
- return Path(path) / "models", Path(path)
63
 
64
 
65
  # ── Device ─────────────────────────────────────────────────────
@@ -73,8 +80,11 @@ PAD_ID = 816
73
  # ── Load helpers ───────────────────────────────────────────────
74
  def load_tokenizer(data_dir):
75
  from transformers import PreTrainedTokenizerFast
76
- return PreTrainedTokenizerFast.from_pretrained(
77
- str(data_dir / "indus_tokenizer"))
 
 
 
78
 
79
 
80
  def load_bert_mlm(model_dir):
 
31
 
32
 
33
  # ── Auto-download from HuggingFace ────────────────────────────
34
+ HF_REPO = "hellosindh/indus-script-models" # update after upload
35
 
36
  def download_models(repo_id=HF_REPO, local_dir="indus_models"):
37
  """Download all model files from HuggingFace."""
 
48
 
49
 
50
  def get_model_dir():
51
+ """
52
+ Find model directory.
53
+ Priority:
54
+ 1. ./models/ (running from cloned HuggingFace repo)
55
+ 2. DATA/models/ (running from original indus_script folder)
56
+ 3. Auto-download from HuggingFace
57
+ """
58
+ # Running from cloned repo — models/ is right here
59
+ cloned = Path("models")
60
+ if cloned.exists() and (cloned / "nanogpt_indus.pt").exists():
61
+ data = Path("data") if Path("data").exists() else Path(".")
62
+ return cloned, data
63
+ # Running from original indus_script folder
64
  local = Path("DATA/models")
65
  if local.exists():
66
  return local, Path("DATA")
67
+ # Auto-download from HuggingFace
 
 
 
 
68
  path = download_models()
69
+ return Path(path) / "models", Path(path) / "data"
70
 
71
 
72
  # ── Device ─────────────────────────────────────────────────────
 
80
  # ── Load helpers ───────────────────────────────────────────────
81
  def load_tokenizer(data_dir):
82
  from transformers import PreTrainedTokenizerFast
83
+ # Try data/indus_tokenizer first, then just data_dir itself
84
+ tok_path = data_dir / "indus_tokenizer"
85
+ if not tok_path.exists():
86
+ tok_path = data_dir
87
+ return PreTrainedTokenizerFast.from_pretrained(str(tok_path))
88
 
89
 
90
  def load_bert_mlm(model_dir):