File size: 3,091 Bytes
234780c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e558c0a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

# Run inference on CPU only — no CUDA/MPS detection is attempted anywhere below.
device = "cpu"

def load_vlm(model_name):
    """Load a Florence-2 checkpoint from the `microsoft/` hub namespace.

    Args:
        model_name: Bare checkpoint name, e.g. ``'Florence-2-base'``.

    Returns:
        A ``(model, processor)`` pair on success, or ``(None, None)`` if
        anything goes wrong during download/instantiation (best-effort:
        the failure is printed, not raised, so the app can still start
        with whichever model did load).
    """
    repo_id = f'microsoft/{model_name}'
    try:
        print(f"Loading {model_name}...")
        # trust_remote_code is required: Florence-2 ships custom modeling code.
        vlm = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
        vlm = vlm.to(device).eval()
        proc = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    except Exception as e:
        print(f"Error loading {model_name}: {e}")
        return None, None
    return vlm, proc

# Load both models eagerly at import time; either pair may be (None, None)
# if the download/instantiation failed — describe_image() checks for that.
model_base, proc_base = load_vlm('Florence-2-base')
model_large, proc_large = load_vlm('Florence-2-large')

def describe_image(uploaded_image, model_choice):
    """Generate a detailed caption for an uploaded image.

    Args:
        uploaded_image: PIL image (or numpy array) from the Gradio widget;
            may be None when nothing was uploaded.
        model_choice: Radio value — "Florence-2-base" or anything else
            selects the large model.

    Returns:
        The caption string, or a user-facing error message.
    """
    if uploaded_image is None:
        return "Please upload an image."

    # Select model based on UI choice
    if model_choice == "Florence-2-base":
        model, processor = model_base, proc_base
    else:
        model, processor = model_large, proc_large

    # load_vlm returns (None, None) on failure, so one check covers both.
    if model is None or processor is None:
        return f"{model_choice} failed to load."

    if not isinstance(uploaded_image, Image.Image):
        uploaded_image = Image.fromarray(uploaded_image)
    # Normalize mode: RGBA/palette/grayscale uploads would otherwise break
    # the processor's 3-channel pixel pipeline.
    uploaded_image = uploaded_image.convert("RGB")

    # Core generation logic
    inputs = processor(text="<MORE_DETAILED_CAPTION>", images=uploaded_image, return_tensors="pt").to(device)

    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            do_sample=False,
        )

    # Keep special tokens: Florence-2's post_process_generation parses them.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    result = processor.post_process_generation(
        generated_text,
        task="<MORE_DETAILED_CAPTION>",
        image_size=(uploaded_image.width, uploaded_image.height)
    )

    return result["<MORE_DETAILED_CAPTION>"]

# Gradio UI: inputs (image, model picker, button) on the left, caption on the right.
css = ".submit-btn { background-color: #4682B4 !important; color: white !important; }"

with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
    gr.Markdown("# **Florence-2 Models Image Captions**")
    gr.Markdown("> Select the model to use. **Base** is faster; **Large** is more accurate.")

    with gr.Row():
        with gr.Column():
            img_in = gr.Image(label="Upload Image", type="pil")
            picker = gr.Radio(
                value="Florence-2-base",
                choices=["Florence-2-base", "Florence-2-large"],
                label="Model Choice",
            )
            run_btn = gr.Button("Generate Caption", elem_classes="submit-btn")

        with gr.Column():
            caption_box = gr.Textbox(label="Generated Caption", lines=6, interactive=True)

    # Wire the button to the captioning function.
    run_btn.click(fn=describe_image, inputs=[img_in, picker], outputs=caption_box)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)