File size: 3,397 Bytes
7a8afa9
40b1cc9
7a8afa9
0fbc572
7a8afa9
2aa22b3
 
 
 
40b1cc9
 
2aa22b3
3c29912
217c8d8
 
7a8afa9
217c8d8
7a8afa9
217c8d8
7a8afa9
217c8d8
7a8afa9
 
 
2aa22b3
 
7a8afa9
 
 
40b1cc9
 
 
 
0fbc572
40b1cc9
35799ef
0fbc572
 
 
 
 
35799ef
 
 
 
40b1cc9
 
7a8afa9
 
 
 
 
 
 
 
 
 
40b1cc9
 
 
 
 
 
35799ef
 
 
40b1cc9
 
 
 
2aa22b3
 
 
7a8afa9
 
 
 
 
 
 
 
 
 
2aa22b3
7a8afa9
2aa22b3
 
 
 
 
 
7a8afa9
 
2aa22b3
7a8afa9
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os

import requests
import torch
from ddgs import DDGS
from transformers import AutoModelForCausalLM, AutoTokenizer

# Instruction text placed in front of each user message before generation.
SYSTEM_PROMPT = """You are Stack 2.9, an expert AI coding assistant.
- Answer questions naturally and helpfully
- When the user asks for code, write clean complete code
- When the user asks a question, answer in plain language
- Be concise and practical
- If asked to search the internet, use the search: command"""

# Location of the model weights. Defaults to the original local checkpoint
# directory; set the STACK_MODEL_PATH environment variable to point at a
# different checkpoint without editing this file.
MODEL_NAME = os.environ.get(
    "STACK_MODEL_PATH",
    "/Users/walidsobhi/stack-2-9-final-model",
)

# The default target is a local directory, so the message reports the actual
# location instead of claiming a download source.
print(f"Loading model from {MODEL_NAME}...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print("✅ Ready!\n")

# Generation settings
MAX_TOKENS = 200      # upper bound on newly generated tokens per reply
TEMPERATURE = 0.4     # sampling temperature
TOP_P = 0.9           # nucleus sampling threshold
REP_PENALTY = 1.2     # repetition penalty passed to generate()

print(f"Settings: max_tokens={MAX_TOKENS}, temperature={TEMPERATURE}, top_p={TOP_P}")
print("Commands: search:<query> - search the web, quit/exit - stop\n")

def web_search(query, count=5):
    """Search the web using DuckDuckGo (no API key needed).

    Args:
        query: Search terms. A blank or whitespace-only query is rejected
            without contacting the search backend.
        count: Maximum number of result snippets to return.

    Returns:
        On success, ``{"success": True, "results": [...], "query": query}``
        where each result is the first 200 characters of a hit's body text.
        On failure, ``{"success": False, "error": message}``. Failures are
        reported through this dict rather than raised.
    """
    if not query or not query.strip():
        return {"success": False, "error": "Empty search query"}

    try:
        results = []
        with DDGS() as ddgs:
            for hit in ddgs.text(query, max_results=count):
                # A hit without body text must not abort the whole search
                # and discard the snippets already collected.
                snippet = (hit.get("body") or "")[:200]
                if not snippet:
                    continue
                results.append(snippet)
                if len(results) >= count:
                    break
    except Exception as e:
        # Boundary handler: callers expect an error dict, never an exception,
        # whatever the search backend or network raises.
        return {"success": False, "error": str(e)}

    if results:
        return {"success": True, "results": results, "query": query}
    return {"success": False, "error": "No results found"}

# Interactive loop: read one line, handle built-in commands, otherwise send
# the text to the model and print its reply. Ends on quit/exit/q, Ctrl-C,
# or end of input.
while True:
    try:
        # Strip once so padded input such as "quit " or " search: x" is
        # still recognised as a command.
        prompt = input("You: ").strip()
        if not prompt:
            continue
        if prompt.lower() in ('quit', 'exit', 'q'):
            break

        # Handle search command
        if prompt.lower().startswith("search:"):
            query = prompt[len("search:"):].strip()
            print("🔍 Searching...")
            result = web_search(query)
            if result["success"]:
                print(f"✅ Results for '{result['query']}':\n")
                for i, r in enumerate(result["results"], 1):
                    print(f"  {i}. {r}")
            else:
                print(f"❌ Search failed: {result['error']}")
            continue

        # Prepend system prompt
        full_prompt = f"{SYSTEM_PROMPT}\n\nUser: {prompt}\nAssistant:"
        inputs = tokenizer(full_prompt, return_tensors='pt').to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            repetition_penalty=REP_PENALTY,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

        # The output sequence starts with the prompt tokens. Decode only the
        # tokens after them: splitting the decoded text on "Assistant:" picks
        # the wrong segment whenever the user text or the generated text
        # contains that marker.
        prompt_length = inputs["input_ids"].shape[-1]
        response = tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True
        ).strip()

        # Cut the reply where the model starts inventing a further turn.
        for stop in ('\n\n\n', 'User:', 'You:'):
            if stop in response:
                response = response.split(stop)[0].strip()

        print(f"AI: {response}\n")

    except (KeyboardInterrupt, EOFError):
        # EOFError: stdin closed (Ctrl-D or exhausted piped input); exit
        # cleanly instead of crashing with a traceback.
        print("\nExiting...")
        break

print("Goodbye!")