Update inference.py
Browse files- inference.py +21 -11
inference.py
CHANGED
|
@@ -31,7 +31,7 @@ import torch.nn.functional as F
|
|
| 31 |
|
| 32 |
|
| 33 |
# ββ Auto-download from HuggingFace ββββββββββββββββββββββββββββ
|
| 34 |
-
HF_REPO = "
|
| 35 |
|
| 36 |
def download_models(repo_id=HF_REPO, local_dir="indus_models"):
|
| 37 |
"""Download all model files from HuggingFace."""
|
|
@@ -48,18 +48,25 @@ def download_models(repo_id=HF_REPO, local_dir="indus_models"):
|
|
| 48 |
|
| 49 |
|
| 50 |
def get_model_dir():
|
| 51 |
-
"""
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
local = Path("DATA/models")
|
| 54 |
if local.exists():
|
| 55 |
return local, Path("DATA")
|
| 56 |
-
#
|
| 57 |
-
downloaded = Path("indus_models")
|
| 58 |
-
if downloaded.exists():
|
| 59 |
-
return downloaded / "models", downloaded
|
| 60 |
-
# Auto-download
|
| 61 |
path = download_models()
|
| 62 |
-
return Path(path) / "models", Path(path)
|
| 63 |
|
| 64 |
|
| 65 |
# ββ Device βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -73,8 +80,11 @@ PAD_ID = 816
|
|
| 73 |
# ββ Load helpers βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 74 |
def load_tokenizer(data_dir):
|
| 75 |
from transformers import PreTrainedTokenizerFast
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
def load_bert_mlm(model_dir):
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
# ββ Auto-download from HuggingFace ββββββββββββββββββββββββββββ
|
| 34 |
+
HF_REPO = "hellosindh/indus-script-models" # update after upload
|
| 35 |
|
| 36 |
def download_models(repo_id=HF_REPO, local_dir="indus_models"):
|
| 37 |
"""Download all model files from HuggingFace."""
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
def get_model_dir():
|
| 51 |
+
"""
|
| 52 |
+
Find model directory.
|
| 53 |
+
Priority:
|
| 54 |
+
1. ./models/ (running from cloned HuggingFace repo)
|
| 55 |
+
2. DATA/models/ (running from original indus_script folder)
|
| 56 |
+
3. Auto-download from HuggingFace
|
| 57 |
+
"""
|
| 58 |
+
# Running from cloned repo β models/ is right here
|
| 59 |
+
cloned = Path("models")
|
| 60 |
+
if cloned.exists() and (cloned / "nanogpt_indus.pt").exists():
|
| 61 |
+
data = Path("data") if Path("data").exists() else Path(".")
|
| 62 |
+
return cloned, data
|
| 63 |
+
# Running from original indus_script folder
|
| 64 |
local = Path("DATA/models")
|
| 65 |
if local.exists():
|
| 66 |
return local, Path("DATA")
|
| 67 |
+
# Auto-download from HuggingFace
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
path = download_models()
|
| 69 |
+
return Path(path) / "models", Path(path) / "data"
|
| 70 |
|
| 71 |
|
| 72 |
# ββ Device βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 80 |
# ββ Load helpers βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 81 |
def load_tokenizer(data_dir):
|
| 82 |
from transformers import PreTrainedTokenizerFast
|
| 83 |
+
# Try data/indus_tokenizer first, then just data_dir itself
|
| 84 |
+
tok_path = data_dir / "indus_tokenizer"
|
| 85 |
+
if not tok_path.exists():
|
| 86 |
+
tok_path = data_dir
|
| 87 |
+
return PreTrainedTokenizerFast.from_pretrained(str(tok_path))
|
| 88 |
|
| 89 |
|
| 90 |
def load_bert_mlm(model_dir):
|