| | from io import BytesIO |
| |
|
| | import string |
| | import gradio as gr |
| | import requests |
| | from utils import Endpoint, get_token |
| |
|
| |
|
def encode_image(image):
    """Write *image* as JPEG into an in-memory buffer and return the buffer.

    Args:
        image: Object exposing ``save(fp, format=...)``. The UI passes the
            value of a ``gr.Image(type="pil")`` component.

    Returns:
        A ``BytesIO`` holding the JPEG bytes, rewound to offset 0 so it can
        be read (or uploaded) from the start.
    """
    jpeg_stream = BytesIO()
    image.save(jpeg_stream, format="JPEG")
    # Rewind: save() leaves the cursor at the end of the written data.
    jpeg_stream.seek(0)
    return jpeg_stream
| |
|
| |
|
def query_chat_api(
    image,
    prompt,
    decoding_method,
    temperature,
    len_penalty,
    repetition_penalty,
    timeout=120,
):
    """POST an image and a text prompt to the endpoint's ``/api/generate`` route.

    Args:
        image: Image object accepted by ``encode_image``.
        prompt: Text prompt sent alongside the image.
        decoding_method: Name of the decoding strategy; nucleus sampling is
            enabled only when this equals ``"Nucleus sampling"``.
        temperature: Forwarded as ``temperature``.
        len_penalty: Forwarded as ``length_penalty``.
        repetition_penalty: Forwarded as ``repetition_penalty``.
        timeout: Seconds to wait for the server before giving up (passed to
            ``requests.post``). ``None`` waits indefinitely.

    Returns:
        The decoded JSON body on HTTP 200; otherwise a string starting with
        ``"Error: "`` describing the failure (non-200 response body, or the
        transport error message).
    """
    url = endpoint.url + "/api/generate"

    headers = {
        "User-Agent": "BLIP-2 HuggingFace Space",
        "Auth-Token": get_token(),
    }

    data = {
        "prompt": prompt,
        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
        "temperature": temperature,
        "length_penalty": len_penalty,
        "repetition_penalty": repetition_penalty,
    }

    files = {"image": encode_image(image)}

    try:
        # Without a timeout a stalled backend would block this worker forever.
        response = requests.post(
            url, data=data, files=files, headers=headers, timeout=timeout
        )
    except requests.RequestException as err:
        # Report transport failures the same way as HTTP failures below.
        return "Error: " + str(err)

    if response.status_code == 200:
        return response.json()
    return "Error: " + response.text
| |
|
| |
|
def query_caption_api(
    image,
    decoding_method,
    temperature,
    len_penalty,
    repetition_penalty,
    timeout=120,
):
    """POST an image to the endpoint's ``/api/caption`` route.

    Args:
        image: Image object accepted by ``encode_image``.
        decoding_method: Name of the decoding strategy; nucleus sampling is
            enabled only when this equals ``"Nucleus sampling"``.
        temperature: Forwarded as ``temperature``.
        len_penalty: Forwarded as ``length_penalty``.
        repetition_penalty: Forwarded as ``repetition_penalty``.
        timeout: Seconds to wait for the server before giving up (passed to
            ``requests.post``). ``None`` waits indefinitely.

    Returns:
        The decoded JSON body on HTTP 200; otherwise a string starting with
        ``"Error: "`` describing the failure (non-200 response body, or the
        transport error message).
    """
    url = endpoint.url + "/api/caption"

    headers = {
        "User-Agent": "BLIP-2 HuggingFace Space",
        "Auth-Token": get_token(),
    }

    data = {
        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
        "temperature": temperature,
        "length_penalty": len_penalty,
        "repetition_penalty": repetition_penalty,
    }

    files = {"image": encode_image(image)}

    try:
        # Without a timeout a stalled backend would block this worker forever.
        response = requests.post(
            url, data=data, files=files, headers=headers, timeout=timeout
        )
    except requests.RequestException as err:
        # Report transport failures the same way as HTTP failures below.
        return "Error: " + str(err)

    if response.status_code == 200:
        return response.json()
    return "Error: " + response.text
| |
|
| |
|
def postprocess_output(output):
    """Ensure the first generated string ends with a punctuation mark.

    Appends ``"."`` to ``output[0]`` when its last character is not in
    ``string.punctuation``. Only the first element is touched.

    Args:
        output: List of strings; modified in place.

    Returns:
        The same ``output`` object. An empty list, or a list whose first
        element is an empty string, is returned unchanged.
    """
    # Nothing to punctuate; also avoids IndexError on output[0] / output[0][-1].
    if not output or not output[0]:
        return output

    if output[0][-1] not in string.punctuation:
        output[0] += "."

    return output
| |
|
| |
|
def inference_chat(
    image,
    text_input,
    decoding_method,
    temperature,
    length_penalty,
    repetition_penalty,
    history=None,
):
    """Run one chat turn and return the updated chatbot display and state.

    The new user input is appended to ``history``, the whole history is
    joined with spaces into a single prompt, and the prompt is sent with the
    image through ``query_chat_api``.

    Args:
        image: Image forwarded to ``query_chat_api``.
        text_input: The user's new message.
        decoding_method: Forwarded to ``query_chat_api``.
        temperature: Forwarded to ``query_chat_api``.
        length_penalty: Forwarded to ``query_chat_api``.
        repetition_penalty: Forwarded to ``query_chat_api``.
        history: Flat list of prior messages (user inputs and replies
            alternating); modified in place. A new list is created when
            omitted.

    Returns:
        A dict mapping the module-level ``chatbot`` component to a list of
        ``(user, reply)`` tuples and the ``state`` component to the updated
        history list.
    """
    # Fresh list per call: a mutable default would be shared between callers.
    if history is None:
        history = []
    history.append(text_input)

    prompt = " ".join(history)

    output = query_chat_api(
        image, prompt, decoding_method, temperature, length_penalty, repetition_penalty
    )

    if isinstance(output, str):
        # query_chat_api reports failures as an "Error: ..." string. Show it
        # as the reply for this turn, but drop the failed turn from the state
        # so it is not fed back into later prompts.
        failed_input = history.pop()
        chat = list(zip(history[::2], history[1::2]))
        chat.append((failed_input, output))
        return {chatbot: chat, state: history}

    # On success the response is treated as a list of generated strings.
    output = postprocess_output(output)
    history.extend(output)

    # Pair consecutive entries: (user input, reply), (user input, reply), ...
    chat = list(zip(history[::2], history[1::2]))

    return {chatbot: chat, state: history}
| |
|
| |
|
def inference_caption(
    image,
    decoding_method,
    temperature,
    length_penalty,
    repetition_penalty,
):
    """Caption *image* through ``query_caption_api`` and return the text.

    Args:
        image: Image forwarded to ``query_caption_api``.
        decoding_method: Forwarded to ``query_caption_api``.
        temperature: Forwarded to ``query_caption_api``.
        length_penalty: Forwarded to ``query_caption_api``.
        repetition_penalty: Forwarded to ``query_caption_api``.

    Returns:
        The first element of the API response. If the API call failed, the
        full ``"Error: ..."`` string is returned instead; an empty response
        yields ``""``.
    """
    output = query_caption_api(
        image, decoding_method, temperature, length_penalty, repetition_penalty
    )

    if isinstance(output, str):
        # Failure message from query_caption_api: indexing it would return
        # only its first character, so pass the whole message through.
        return output

    # Guard the empty case so an empty response does not raise IndexError.
    return output[0] if output else ""
| |
|
| |
|
# Static page content; each string below is rendered with gr.Markdown.
title = """<h1 align="center">BLIP-2</h1>"""

description = """Gradio demo for BLIP-2, image-to-text generation from Salesforce Research. To use it, simply upload your image, or click one of the examples to load them.
<br> <strong>Disclaimer</strong>: This is a research prototype and is not intended for production use. No data including but not restricted to text and images is collected."""

article = """<strong>Paper</strong>: <a href='https://arxiv.org/abs/2301.12597' target='_blank'>BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models</a>
<br> <strong>Code</strong>: BLIP2 is now integrated into GitHub repo: <a href='https://github.com/salesforce/LAVIS' target='_blank'>LAVIS: a One-stop Library for Language and Vision</a>
<br> <strong>🤗 `transformers` integration</strong>: You can now use `transformers` to use our BLIP-2 models! Check out the <a href='https://huggingface.co/docs/transformers/main/en/model_doc/blip-2' target='_blank'> official docs </a>
<p> <strong>Project Page</strong>: <a href='https://github.com/salesforce/LAVIS/tree/main/projects/blip2' target='_blank'> BLIP2 on LAVIS</a>
<br> <strong>Description</strong>: Captioning results from <strong>BLIP2_OPT_6.7B</strong>. Chat results from <strong>BLIP2_FlanT5xxl</strong>.

<p><strong>We have now suspended the official BLIP2 demo from March 23. 2023. </strong>
<p><strong>For example usage, see notebooks https://github.com/salesforce/LAVIS/tree/main/examples.</strong>
"""

# Shared endpoint object; the query functions read its ``url`` attribute.
endpoint = Endpoint()

# Each row is [image file, chat prompt], matching the Examples inputs
# (image_input, chat_input).
examples = [
    ["house.png", "How could someone get out of the house?"],
    ["flower.jpg", "Question: What is this flower and where is it's origin? Answer:"],
    ["pizza.jpg", "What are steps to cook it?"],
    ["sunset.jpg", "Here is a romantic message going along the photo:"],
    ["forbidden_city.webp", "In what dynasties was this place built?"],
]
| |
|
# NOTE(review): the original indentation of this section was lost; the nesting
# below is reconstructed from the order of the `with` statements -- confirm
# against the deployed layout.
# NOTE(review): the `#component-21` CSS selector presumably depends on the
# order in which components are created, so that order is kept unchanged.
with gr.Blocks(
    css="""
.message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
#component-21 > div.wrap.svelte-w6rprc {height: 600px;}
"""
) as iface:
    # Conversation history handed to / returned from inference_chat.
    state = gr.State([])

    # Page header: title, description, then the article text.
    for banner in (title, description, article):
        gr.Markdown(banner)

    with gr.Row():
        # Left column: image and decoding controls.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", interactive=False)

            sampling = gr.Radio(
                choices=["Beam search", "Nucleus sampling"],
                value="Beam search",
                label="Text Decoding Method",
                interactive=True,
            )

            temperature = gr.Slider(
                minimum=0.5,
                maximum=1.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature (used with nucleus sampling)",
            )

            len_penalty = gr.Slider(
                minimum=-1.0,
                maximum=2.0,
                value=1.0,
                step=0.2,
                interactive=True,
                label="Length Penalty (set to larger for longer sequence, used with beam search)",
            )

            rep_penalty = gr.Slider(
                minimum=1.0,
                maximum=5.0,
                value=1.5,
                step=0.5,
                interactive=True,
                label="Repeat Penalty (larger value prevents repetition)",
            )

        # Decoding controls shared by the caption and chat handlers, in the
        # positional order those handlers expect.
        decoding_controls = [sampling, temperature, len_penalty, rep_penalty]

        # Right column: captioning on top, chat below.
        with gr.Column(scale=1.8):

            with gr.Column():
                caption_output = gr.Textbox(lines=1, label="Caption Output")
                caption_button = gr.Button(
                    value="Caption it!", interactive=True, variant="primary"
                )
                caption_button.click(
                    fn=inference_caption,
                    inputs=[image_input, *decoding_controls],
                    outputs=[caption_output],
                )

            gr.Markdown("""Trying prompting your input for chat; e.g. example prompt for QA, \"Question: {} Answer:\" Use proper punctuation (e.g., question mark).""")
            with gr.Row():
                with gr.Column(
                    scale=1.5,
                ):
                    chatbot = gr.Chatbot(
                        label="Chat Output (from FlanT5)",
                    )

                with gr.Column(scale=1):
                    chat_input = gr.Textbox(lines=1, label="Chat Input")

                    # Pressing Enter and clicking "Submit" run the same
                    # handler with the same inputs and outputs.
                    chat_inputs = [image_input, chat_input, *decoding_controls, state]
                    chat_outputs = [chatbot, state]

                    chat_input.submit(
                        fn=inference_chat,
                        inputs=chat_inputs,
                        outputs=chat_outputs,
                    )

                    with gr.Row():
                        clear_button = gr.Button(value="Clear", interactive=True)
                        # Reset the input box, the chat display and the history.
                        clear_button.click(
                            fn=lambda: ("", [], []),
                            inputs=[],
                            outputs=[chat_input, chatbot, state],
                            queue=False,
                        )

                        submit_button = gr.Button(
                            value="Submit", interactive=True, variant="primary"
                        )
                        submit_button.click(
                            fn=inference_chat,
                            inputs=chat_inputs,
                            outputs=chat_outputs,
                        )

            # A new image invalidates the chat, the caption and the history.
            image_input.change(
                fn=lambda: ("", "", []),
                inputs=[],
                outputs=[chatbot, caption_output, state],
                queue=False,
            )

    # Rebinds `examples` from the raw rows to the Examples component.
    examples = gr.Examples(
        examples=examples,
        inputs=[image_input, chat_input],
    )

iface.queue(concurrency_count=1, api_open=False, max_size=10)
iface.launch(enable_queue=True)