Spaces:
Running
Running
| from transformers import TrOCRProcessor, VisionEncoderDecoderModel | |
| from PIL import Image | |
| import requests | |
| import gradio as gr | |
# App title and page text shown by the Gradio interface.
title = "ScriptSense"
description = "<p style='text-align: center; font-size: 22px; font-weight: bold;'>designed and crafted by aryan verma</p>"
article = "<p style='text-align: center; font-size: 14px;'>aryan verma | 241306064</p>"
# Hide Gradio's default page footer.
css = "footer {display: none !important;}"
# Sample inputs, one row per example, in predict()'s argument order:
# [image URL, sketchpad image, uploaded image]. The URL is left blank so
# predict() falls through to the image inputs instead of fetching a URL.
examples = [
    ["", "images/1.jpg", "images/1.jpg"],
    ["", "images/sample-handwritten-2.PNG", "images/sample-handwritten-2.PNG"],
]
# Hugging Face Hub checkpoint used for OCR. Presumably any checkpoint that is
# compatible with both TrOCRProcessor and VisionEncoderDecoderModel can be
# swapped in here. The processor and model must come from the same
# checkpoint, so the id lives in one constant instead of being repeated.
MODEL_ID = "microsoft/trocr-large-handwritten"
processor = TrOCRProcessor.from_pretrained(MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
# Prediction function for handwriting OCR.


def _to_rgb(img):
    """Return *img* as an RGB PIL image, flattening any transparency onto white.

    ``Image.convert("RGB")`` simply drops the alpha channel, so fully
    transparent pixels keep whatever colour they happen to store (often
    black). On a transparent drawing canvas that can hide black ink, so
    images carrying alpha (or a transparency key) are composited over an
    opaque white background before conversion.
    """
    has_alpha = "A" in img.getbands() or "transparency" in img.info
    if not has_alpha:
        return img.convert("RGB")
    rgba = img.convert("RGBA")
    background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(background, rgba).convert("RGB")


def _fetch_url_image(url, timeout=10):
    """Download *url* and open the payload as a (lazily loaded) PIL image.

    Raises:
        requests.RequestException: network failure, timeout, or HTTP 4xx/5xx.
        OSError: payload is not a recognisable image
            (includes PIL.UnidentifiedImageError).
    """
    from io import BytesIO  # local import keeps the module header unchanged

    # Without a timeout an unresponsive host would block this worker forever.
    response = requests.get(url, timeout=timeout)
    # Turn error pages into an exception instead of handing HTML to PIL.
    response.raise_for_status()
    # .content is decoded according to Content-Encoding (gzip etc.), unlike
    # the raw socket stream, and the connection is released once it is read.
    return Image.open(BytesIO(response.content))


def predict(ImageUrl, imgDraw, imgUplod):
    """Transcribe handwritten text from one of three image sources.

    Sources are tried in priority order: URL, then uploaded image, then the
    sketchpad drawing. Only the first one that is provided is used.

    Args:
        ImageUrl: URL typed into the text box. Empty, whitespace-only, or
            None means "not provided".
        imgDraw: Sketchpad value. Either a PIL image or (Gradio 4+) a dict
            whose "composite" entry holds the flattened drawing.
        imgUplod: Uploaded PIL image, or None.

    Returns:
        The transcribed text, or a user-facing error/prompt message.
    """
    image = None
    url = (ImageUrl or "").strip()

    if url:
        try:
            # Conversion forces PIL to decode the data, so it stays inside the try.
            image = _to_rgb(_fetch_url_image(url))
        except (
            requests.RequestException,
            OSError,
            ValueError,
            Image.DecompressionBombError,
        ):
            return "Error: Invalid Image URL"
    elif imgUplod is not None:
        image = _to_rgb(imgUplod)
    elif imgDraw is not None:
        if isinstance(imgDraw, dict):
            # Gradio 4+ sketchpad value; a missing or empty composite yields no image.
            composite = imgDraw.get("composite")
            if composite is not None:
                image = _to_rgb(composite)
        else:
            image = _to_rgb(imgDraw)

    if image is None:
        return "Please provide an image via URL, Sketchpad, or Upload."

    # Run the module-level TrOCR processor/model on the single image.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Gradio UI. There is one input component per predict() argument, in the same
# order: URL text box, drawing sketchpad, image upload.
input_components = [
    "text",
    gr.Sketchpad(type="pil"),
    gr.Image(type="pil"),
]

# Wire the components to predict() and show its string result as text.
interface = gr.Interface(
    fn=predict,
    inputs=input_components,
    outputs="text",
    examples=examples,
    title=title,
    description=description,
    article=article,
    css=css,
)

interface.launch()