File size: 2,879 Bytes
b5c3af0
 
15db5ed
 
 
b5c3af0
15db5ed
b5c3af0
15db5ed
b5c3af0
15db5ed
 
b5c3af0
15db5ed
 
 
 
 
c997032
 
15db5ed
 
b5c3af0
15db5ed
 
 
b5c3af0
15db5ed
 
b5c3af0
15db5ed
 
 
 
b5c3af0
15db5ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5c3af0
15db5ed
b5c3af0
 
15db5ed
 
 
b5c3af0
 
 
 
 
15db5ed
b5c3af0
 
 
 
 
 
15db5ed
b5c3af0
15db5ed
 
 
 
 
 
 
 
 
 
b5c3af0
 
15db5ed
 
 
 
 
 
 
b5c3af0
 
15db5ed
 
 
 
 
 
b5c3af0
 
15db5ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5c3af0
 
15db5ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face Hub repository used for both the tokenizer and the weights.
MODEL_ID = "lazarus19/AuroraImageGen"

# Pick the device label and tensor precision together: half precision when
# a CUDA device is available, full precision otherwise.
if torch.cuda.is_available():
    device = "cuda"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

# Both objects are created once, when the module is imported.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# device_map="auto" delegates weight placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch_dtype, device_map="auto"
)

# Generate function
def generate(
    prompt,
    max_new_tokens,
    temperature,
    top_p,
):
    """Sample a text completion for *prompt* from the module-level model.

    Args:
        prompt: Text to condition on. ``None``, empty, or whitespace-only
            input is rejected with a hint instead of reaching the model.
        max_new_tokens: Upper bound on the number of generated tokens.
            Coerced to ``int`` because a UI slider may deliver a float.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.

    Returns:
        The decoded first output sequence with special tokens removed, or
        the string ``"Please enter a prompt."`` when no prompt was given.
    """
    # Guard against None as well as blank text so a cleared input box
    # yields the hint rather than an AttributeError.
    if not prompt or not prompt.strip():
        return "Please enter a prompt."

    encoded = tokenizer(prompt, return_tensors="pt")

    # Move every input tensor to the device the model lives on.
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    with torch.no_grad():
        outputs = model.generate(
            **encoded,
            max_new_tokens=int(max_new_tokens),
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            # Reuse EOS as the padding id during generation.
            pad_token_id=tokenizer.eos_token_id,
        )

    # NOTE(review): the whole first sequence is decoded, which for a causal
    # LM presumably starts with the prompt tokens - confirm the echo of the
    # prompt in the response is intended.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Stylesheet handed to gr.Blocks: centres the main column and caps its width.
css = (
    "\n"
    "#col-container {\n"
    "    margin: 0 auto;\n"
    "    max-width: 900px;\n"
    "}\n"
)

# Sample prompts passed to gr.Examples for the prompt box.
examples = [
    "Write a short story about a robot explorer.",
    "Explain quantum computing in simple terms.",
    "Create a fantasy character profile.",
]

# Page layout: one column (styled via #col-container) holding the title,
# prompt box, response box, collapsible sampling controls, the trigger
# button and the example prompts, created in that order.
with gr.Blocks(css=css) as demo, gr.Column(elem_id="col-container"):
    gr.Markdown("# AuroraImageGen Chat")

    prompt = gr.Textbox(label="Prompt", lines=6, placeholder="Enter your prompt...")
    output = gr.Textbox(label="Response", lines=20)

    # Sampling controls start collapsed (open=False).
    with gr.Accordion("Advanced Settings", open=False):
        max_new_tokens = gr.Slider(
            label="Max New Tokens", minimum=32, maximum=2048, step=32, value=512
        )
        temperature = gr.Slider(
            label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7
        )
        top_p = gr.Slider(
            label="Top-P", minimum=0.1, maximum=1.0, step=0.05, value=0.9
        )

    run_button = gr.Button("Generate", variant="primary")

    # The examples are wired to the prompt box only.
    gr.Examples(examples=examples, inputs=[prompt])

    # The input order mirrors generate's positional parameters.
    run_button.click(
        fn=generate,
        inputs=[prompt, max_new_tokens, temperature, top_p],
        outputs=output,
    )

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()