File size: 1,619 Bytes
c525245
 
87748d8
de8eaee
c525245
 
de8eaee
 
 
c525245
de8eaee
 
87748d8
de8eaee
c525245
87748d8
de8eaee
 
87748d8
 
 
 
de8eaee
 
 
c525245
87748d8
de8eaee
 
 
 
 
87748d8
de8eaee
c525245
de8eaee
 
 
c525245
 
de8eaee
87748d8
c525245
de8eaee
 
c525245
 
de8eaee
 
 
17f3020
 
 
87748d8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Base model and fine-tuned LoRA adapter repo IDs on the Hugging Face Hub.
BASE    = "microsoft/phi-3-mini-4k-instruct"
ADAPTER = "Znilsson/survivalai-phi3-lora"   # <-- replace if your adapter repo ID differs
# HF access token read from the environment; None if unset (sufficient for
# public repos — only the adapter load below passes it explicitly).
TOKEN   = os.environ.get("HF_TOKEN")

print("Loading tokenizer...")
# trust_remote_code: Phi-3 ships custom code in its Hub repo.
tokenizer = AutoTokenizer.from_pretrained(BASE, trust_remote_code=True)

print("Loading base model (fp16)...")
model = AutoModelForCausalLM.from_pretrained(
    BASE,
    # NOTE(review): `dtype=` is the newer transformers spelling; older
    # releases expect `torch_dtype=` — confirm the pinned transformers
    # version accepts this keyword.
    dtype=torch.float16,
    device_map="auto",           # let accelerate place weights on GPU/CPU
    trust_remote_code=True,
    low_cpu_mem_usage=True,      # stream weights instead of full CPU copy
)

print("Attaching + merging LoRA adapter...")
model = PeftModel.from_pretrained(model, ADAPTER, token=TOKEN)
# Fold the LoRA deltas into the base weights and drop the PEFT wrapper,
# so inference runs on a plain transformers model.
model = model.merge_and_unload()
model.eval()

def _build_prompt(message, history):
    """Render prior turns plus the new message into Phi-3's chat template.

    Accepts both Gradio history formats: legacy ``(user, assistant)``
    tuples/lists, and the Gradio 5 "messages" format of
    ``{"role": ..., "content": ...}`` dicts.
    """
    parts = []
    for turn in history:
        if isinstance(turn, dict):
            # Gradio "messages" format: one dict per message.
            role = turn.get("role", "user")
            parts.append(f"<|{role}|>\n{turn.get('content', '')}<|end|>\n")
        else:
            # Legacy pair format: one (user, assistant) tuple per exchange.
            user, assistant = turn
            parts.append(f"<|user|>\n{user}<|end|>\n<|assistant|>\n{assistant}<|end|>\n")
    parts.append(f"<|user|>\n{message}<|end|>\n<|assistant|>\n")
    return "".join(parts)

def chat(message, history):
    """Generate one assistant reply for the Gradio ChatInterface.

    Parameters:
        message: the user's latest message (str).
        history: prior conversation turns; either (user, assistant) pairs
            (gradio < 5) or role/content dicts (gradio 5 "messages" format).

    Returns:
        The model's generated reply with surrounding whitespace stripped.
    """
    prompt = _build_prompt(message, history)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=400,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            # Phi-3 has no dedicated pad token; reuse EOS to silence warnings.
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens (skip the echoed prompt).
    resp = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return resp.strip()

# Wire the chat function into a standard Gradio chat UI.
demo = gr.ChatInterface(
    fn=chat,
    title="SurvivalAI",
    description="Fine-tuned Phi-3-mini on survival & emergency preparedness corpus.",
)

# Launch the web server only when run as a script (e.g. on a HF Space),
# not when imported as a module.
if __name__ == "__main__":
    demo.launch()