| | import os |
| | import subprocess |
| | import spaces |
| | import torch |
| |
|
| | import gradio as gr |
| |
|
| | from gradio_client.client import DEFAULT_TEMP_DIR |
| | from playwright.sync_api import sync_playwright |
| | from threading import Thread |
| | from transformers import AutoProcessor, AutoModelForCausalLM, TextIteratorStreamer |
| | from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension |
| | from typing import List |
| | from PIL import Image |
| |
|
| | from transformers.image_transforms import resize, to_channel_dimension_format |
| |
|
| |
|
# Install flash-attn at startup without compiling its CUDA kernels (best effort:
# failures are not checked). The flag is merged into the current environment;
# passing ONLY the flag as `env` would strip PATH, HOME and any virtualenv
# variables from the pip process.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

MODEL_ID = "HuggingFaceM4/VLM_WebSight_finetuned"

DEVICE = torch.device("cuda")
PROCESSOR = AutoProcessor.from_pretrained(
    MODEL_ID,
)
MODEL = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(DEVICE)

# Number of `<image>` placeholder tokens inserted into the prompt per image:
# the resampler's latent count when a perceiver resampler is used, otherwise one
# token per vision patch ((image_size // patch_size) ** 2).
if MODEL.config.use_resampler:
    image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
else:
    image_seq_len = (
        MODEL.config.vision_config.image_size // MODEL.config.vision_config.patch_size
    ) ** 2
BOS_TOKEN = PROCESSOR.tokenizer.bos_token
# Token ids of the image marker strings; passed to generate() as bad_words_ids so
# the model never emits them in its output.
BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
| |
|
| |
|
| | |
| |
|
def convert_to_rgb(image):
    """Return *image* in RGB mode.

    An image already in "RGB" mode is returned as-is (the same object). Any
    other mode is converted to RGBA and alpha-composited onto a white
    background before being converted to RGB, so transparent areas end up white.
    """
    if image.mode != "RGB":
        rgba = image.convert("RGBA")
        white_background = Image.new("RGBA", rgba.size, (255, 255, 255))
        return Image.alpha_composite(white_background, rgba).convert("RGB")
    return image
| |
|
| | |
| | |
def custom_transform(x):
    """Turn one PIL image into the pixel tensor fed to the model.

    Pipeline: force RGB, resize to 960x960 with bilinear resampling, rescale
    pixel values by 1/255, normalize with the processor's image mean/std, and
    move the channel axis first before wrapping the result in a torch tensor.
    """
    image_processor = PROCESSOR.image_processor

    pixels = to_numpy_array(convert_to_rgb(x))
    pixels = resize(pixels, (960, 960), resample=PILImageResampling.BILINEAR)
    scaled = image_processor.rescale(pixels, scale=1 / 255)
    normalized = image_processor.normalize(
        scaled,
        mean=image_processor.image_mean,
        std=image_processor.image_std,
    )
    channels_first = to_channel_dimension_format(normalized, ChannelDimension.FIRST)
    return torch.tensor(channels_first)
| |
|
| | |
| |
|
| |
|
# Paths of the example screenshots shown in the "Templates Gallery".
# os.listdir() returns entries in arbitrary order, so sort them to keep the
# gallery layout stable across restarts and platforms.
IMAGE_GALLERY_PATHS = [
    f"example_images/{ex_image}"
    for ex_image in sorted(os.listdir("example_images"))
]
| |
|
| |
|
def install_playwright(*browsers):
    """Download Playwright browser binaries (best effort).

    Args:
        *browsers: Optional browser names (e.g. "chromium") forwarded to
            ``playwright install``. With none given, Playwright installs its
            default set of browsers.

    Errors are printed rather than raised so the app can still start.
    """
    import sys

    # Invoke the CLI through the running interpreter, so the browsers match the
    # Playwright package this process imports and no `playwright` entry point
    # on PATH is required.
    command = [sys.executable, "-m", "playwright", "install", *browsers]
    try:
        subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing/unexecutable command, which previously
        # escaped this handler and aborted startup.
        print(f"Error during Playwright installation: {e}")
    else:
        print("Playwright installation successful.")


# render_webpage() only launches Chromium, so skip downloading other browsers.
install_playwright("chromium")
| |
|
| |
|
def add_file_gallery(
    selected_state: gr.SelectData,
    gallery_list,
):
    """Open the gallery image the user clicked.

    Args:
        selected_state: Gradio selection event data; ``.index`` is the position
            of the clicked gallery item.
        gallery_list: The gallery's current value as passed in by Gradio. It is
            NOT a list of strings: the code reads ``.root[i].image.path`` from
            it. NOTE(review): this matches the Gradio 4.x ``GalleryData``
            payload; confirm against the pinned Gradio version.

    Returns:
        The selected image as a PIL image.
    """
    selected_item = gallery_list.root[selected_state.index]
    return Image.open(selected_item.image.path)
| |
|
| |
|
def render_webpage(
    html_css_code,
):
    """Render an HTML/CSS document in headless Chromium and screenshot it.

    Args:
        html_css_code: The HTML document (with any inline CSS) to render.

    Returns:
        A full-page PNG screenshot of the rendered page as a PIL image.
    """
    import io

    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)
        try:
            context = browser.new_context(
                user_agent=(
                    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0"
                    " Safari/537.36"
                )
            )
            page = context.new_page()
            page.set_content(html_css_code)
            # Wait for network activity to settle so externally referenced
            # resources have a chance to load before the screenshot.
            page.wait_for_load_state("networkidle")
            # Capture to memory: writing a hash-named PNG into the temp dir per
            # call accumulated files and could collide on hash().
            png_bytes = page.screenshot(full_page=True)
            context.close()
        finally:
            # Always release the browser, even if rendering raised; closing a
            # launched browser also closes its contexts and pages.
            browser.close()

    return Image.open(io.BytesIO(png_bytes))
| |
|
| |
|
@spaces.GPU(duration=180)
def model_inference(
    image,
):
    """Stream the HTML the model generates for a screenshot, then render it.

    Args:
        image: The screenshot, as a PIL image (the UI input uses type="pil").

    Yields:
        ``(generated_text, rendered_image)`` tuples. While tokens stream in,
        ``rendered_image`` is None; the last tuple carries a screenshot of the
        complete generated HTML.

    Raises:
        ValueError: If ``image`` is None.
    """
    if image is None:
        raise ValueError("`image` is None. It should be a PIL image.")

    # Prompt = BOS + the image wrapped in marker tokens, with one `<image>`
    # placeholder per image token expected by the model.
    inputs = PROCESSOR.tokenizer(
        f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs["pixel_values"] = PROCESSOR.image_processor(
        [image],
        transform=custom_transform,
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(
        PROCESSOR.tokenizer,
        skip_prompt=True,
    )
    generation_kwargs = dict(
        inputs,
        bad_words_ids=BAD_WORDS_IDS,
        max_length=4096,
        streamer=streamer,
    )
    # generate() blocks until done, so run it in a worker thread and consume
    # the decoded text from the streamer here.
    thread = Thread(
        target=MODEL.generate,
        kwargs=generation_kwargs,
    )
    thread.start()

    generated_text = ""
    for new_text in streamer:
        # The decoded stream includes the "</s>" end token; keep it out of the HTML.
        generated_text += new_text.replace("</s>", "")
        yield generated_text, None
    thread.join()

    # Render only after the stream ends, so the final chunk is included and a
    # page is rendered even when generation stopped at max_length without "</s>".
    yield generated_text, render_webpage(generated_text)
| |
|
| |
|
# Output components are built outside the layout and placed later with
# .render(); this lets the Clear button reference them before they are rendered.
generated_html = gr.Code(elem_id="generated_html", label="Extracted HTML")
rendered_html = gr.Image(
    show_share_button=False,
    show_download_button=False,
    label="Rendered HTML",
)
| | |
| | |
| | |
| |
|
| |
|
# Custom CSS passed to gr.Blocks: caps the app container at 1000px wide, lays
# out <h1> contents as a centered flex row, and animates width / flex-grow
# changes on every element.
css = """
.gradio-container{max-width: 1000px!important}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""
| |
|
| |
|
# UI layout: input screenshot and action buttons on the left, rendered page on
# the right, generated HTML below, and a gallery of example screenshots.
with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as demo:
    gr.Markdown(
        "Since the model used for this demo *does not generate images*, it is more effective to input standalone website elements or sites with minimal image content."
    )
    with gr.Row(equal_height=True):
        with gr.Column(scale=4, min_width=250):
            imagebox = gr.Image(
                type="pil",
                label="Screenshot to extract",
                visible=True,
                sources=["upload", "clipboard"],
            )
            with gr.Group():
                with gr.Row():
                    submit_btn = gr.Button(
                        value="▶️ Submit", visible=True, min_width=120
                    )
                    clear_btn = gr.ClearButton(
                        [imagebox, generated_html, rendered_html], value="🧹 Clear", min_width=120
                    )
                    regenerate_btn = gr.Button(
                        value="🔄 Regenerate", visible=True, min_width=120
                    )
        with gr.Column(scale=4):
            rendered_html.render()

    with gr.Row():
        generated_html.render()

    with gr.Row():
        template_gallery = gr.Gallery(
            value=IMAGE_GALLERY_PATHS,
            label="Templates Gallery",
            allow_preview=False,
            columns=5,
            elem_id="gallery",
            show_share_button=False,
            height=400,
        )

    # Uploading, Submit and Regenerate all run inference on the current image.
    # Regenerate is wired here only; a second regenerate_btn.click listener
    # made every click run the GPU inference twice.
    gr.on(
        triggers=[
            imagebox.upload,
            submit_btn.click,
            regenerate_btn.click,
        ],
        fn=model_inference,
        inputs=[imagebox],
        outputs=[generated_html, rendered_html],
    )
    # Picking a gallery example loads it into the input box, then runs inference.
    template_gallery.select(
        fn=add_file_gallery,
        inputs=[template_gallery],
        outputs=[imagebox],
    ).success(
        fn=model_inference,
        inputs=[imagebox],
        outputs=[generated_html, rendered_html],
    )

demo.queue(max_size=40, api_open=False)
demo.launch(max_threads=400)
| |
|