Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoModelForCausalLM, AutoProcessor | |
| import spaces | |
# Hugging Face Hub repository holding both the model weights and its processor.
MODEL_PATH = "PaddlePaddle/PaddleOCR-VL"

# Prefer the GPU whenever PyTorch reports one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Task label (as shown in the UI) -> text prompt sent to the model.
# Every prompt is the task label followed by a colon; insertion order is the
# order the choices appear in the UI.
PROMPTS = {
    task_label: f"{task_label}:"
    for task_label in (
        "OCR",
        "Table Recognition",
        "Formula Recognition",
        "Chart Recognition",
    )
}
# Load the model and its processor once at import time so every request
# reuses the same instances.
print(f"Loading model on {DEVICE}...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # the repository ships its own model code
)
model = model.to(DEVICE)
model.eval()  # inference only: disable training-mode layers such as dropout
processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
print("Model loaded successfully!")
# On ZeroGPU Spaces this attaches a GPU for the duration of each call; per the
# `spaces` documentation the decorator has no effect on other hardware.
@spaces.GPU
def process_image(image, task):
    """
    Run PaddleOCR-VL on an image for the selected recognition task.

    Args:
        image: A PIL Image, or anything ``PIL.Image.open`` accepts (e.g. a
            file path). ``None`` returns a user-facing hint instead of raising.
        task: A key of ``PROMPTS`` (e.g. "OCR", "Table Recognition"). Unknown
            values fall back to the "OCR" prompt.

    Returns:
        str: The text the model generated for the task, with the input
        prompt removed and surrounding whitespace stripped.
    """
    if image is None:
        return "Please upload an image first."

    # Normalise to an RGB PIL image. Non-PIL inputs are opened in a context
    # manager so the file handle is closed; convert() returns a loaded copy,
    # so the result stays valid after the file is closed.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    else:
        with Image.open(image) as opened:
            image = opened.convert("RGB")

    prompt = PROMPTS.get(task, PROMPTS["OCR"])

    # Single-turn chat: the image followed by the task prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=1024)

    # For a causal LM, generate() returns the prompt tokens followed by the
    # newly generated ones. Drop the prompt so the chat template and task
    # prompt are not echoed back into the result.
    prompt_length = inputs["input_ids"].shape[-1]
    generated = outputs[:, prompt_length:]
    return processor.batch_decode(generated, skip_special_tokens=True)[0].strip()
# Gradio UI: an image upload plus a task selector, producing a text result.
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Radio(
            choices=list(PROMPTS.keys()),
            value="OCR",
            label="Task Type",
        ),
    ],
    outputs=gr.Textbox(label="Result", lines=10),
    title="PaddleOCR-VL: Multilingual Document Parsing",
    description=(
        "Upload an image and select a task. This model supports OCR in 109 "
        "languages, table recognition, formula recognition, and chart "
        "recognition."
    ),
    # No sample images are bundled yet. To enable examples, upload e.g.
    # example.png to the Space and set: examples=[["example.png", "OCR"]]
    examples=None,
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    demo.launch()