import functools
import io
import os
import random

import gradio as gr
import numpy as np
import tensorflow as tf
from datasets import load_dataset
from PIL import Image
|
|
| |
# Candidate locations for the trained Keras model, tried in order: first
# relative to the current working directory, then relative to this script's
# own directory so the app also starts when launched from another folder.
_MODEL_RELATIVE_PATH = os.path.join("saved_model", "Sports_Balls_Classification.h5")
_MODEL_CANDIDATE_PATHS = (
    _MODEL_RELATIVE_PATH,
    os.path.join(os.path.dirname(os.path.abspath(__file__)), _MODEL_RELATIVE_PATH),
)


def _load_model():
    """Load the classifier from the first candidate path that exists.

    Returns:
        The Keras model returned by ``tf.keras.models.load_model``.

    Raises:
        FileNotFoundError: if the model file exists at none of the candidate paths.
        Any error raised by ``load_model`` for a file that exists but cannot be
        loaded is propagated unchanged rather than silently retried.
    """
    for path in _MODEL_CANDIDATE_PATHS:
        if os.path.exists(path):
            return tf.keras.models.load_model(path)
    raise FileNotFoundError(
        "Could not find the model file; looked in: " + ", ".join(_MODEL_CANDIDATE_PATHS)
    )


model = _load_model()
|
|
| |
# Human-readable label for each model output unit. ``classify_sports_ball``
# pairs CLASS_NAMES[i] with the i-th predicted probability, so this order must
# stay in sync with the trained model. The names are in alphabetical order,
# presumably the order the training data loader assigned — confirm if retraining.
CLASS_NAMES = [
    "american_football",
    "baseball",
    "basketball",
    "billiard_ball",
    "bowling_ball",
    "cricket_ball",
    "football",
    "golf_ball",
    "hockey_ball",
    "hockey_puck",
    "rugby_ball",
    "shuttlecock",
    "table_tennis_ball",
    "tennis_ball",
    "volleyball",
]
|
|
def preprocess_image(img, target_size=(225, 225)):
    """Convert an image into a single-item batch for ``model.predict``.

    Args:
        img: A ``PIL.Image.Image``, or a filesystem path (``str`` or
            ``os.PathLike``) to an image file.
        target_size: ``(width, height)`` passed to ``PIL.Image.Image.resize``.
            NOTE(review): assumed to match the input size the model was
            trained with — confirm against the training pipeline.

    Returns:
        A float32 numpy array of shape ``(1, height, width, 3)`` holding RGB
        values scaled from ``[0, 255]`` to ``[0, 1]``.

    Raises:
        ValueError: if ``img`` is ``None``.
    """
    if img is None:
        raise ValueError("No image provided.")

    if isinstance(img, (str, os.PathLike)):
        # Open from disk inside a context manager so the file handle is
        # released as soon as the pixels have been converted.
        with Image.open(img) as opened:
            img = opened.convert("RGB")
    else:
        img = img.convert("RGB")

    img = img.resize(target_size)
    img_array = np.asarray(img, dtype=np.float32) / 255.0
    # Add the leading batch axis expected by Keras.
    return np.expand_dims(img_array, axis=0)
|
|
def classify_sports_ball(image):
    """Classify a sports-ball image and return per-class confidences.

    Args:
        image: A ``PIL.Image.Image`` (as supplied by the ``gr.Image(type="pil")``
            input), a path accepted by ``preprocess_image``, or ``None``.

    Returns:
        ``None`` when ``image`` is ``None``, which clears the ``gr.Label``
        output. Otherwise, a dict mapping every name in ``CLASS_NAMES`` to its
        predicted probability, ordered from most to least likely.

    Raises:
        gr.Error: if preprocessing or inference fails, or if the model's output
            size does not match ``CLASS_NAMES``. Gradio shows the message to the
            user instead of trying to render it as a label. (A ``{"error": str}``
            dict is not a valid ``gr.Label`` value.)
    """
    # Gradio fires ``change`` listeners with ``None`` when the image is cleared.
    if image is None:
        return None

    try:
        input_tensor = preprocess_image(image)
        probs = model.predict(input_tensor, verbose=0)[0]
    except Exception as e:
        raise gr.Error(f"Classification failed: {e}") from e

    # zip() would silently truncate on a mismatch, so check explicitly.
    if len(probs) != len(CLASS_NAMES):
        raise gr.Error(
            f"Model returned {len(probs)} scores but "
            f"{len(CLASS_NAMES)} class names are defined."
        )

    pred_dict = {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
    # Most likely class first (gr.Label also sorts, but other callers benefit).
    return dict(sorted(pred_dict.items(), key=lambda item: item[1], reverse=True))
|
|
@functools.lru_cache(maxsize=1)
def _load_test_split():
    """Load the ``test`` split of the Sports-Balls dataset once and cache it.

    ``lru_cache`` does not cache raised exceptions, so a failed load is
    retried on the next call.
    """
    # NOTE(review): ``trust_remote_code=True`` allows the Hub repository to run
    # its own loading code on this machine. Drop it if the dataset is stored in
    # a script-free format (e.g. parquet / imagefolder).
    return load_dataset("AIOmarRehan/Sports-Balls", split="test", trust_remote_code=True)


def _extract_image(sample):
    """Return the image stored in one dataset row as a PIL image, or ``None``."""
    # Prefer well-known column names, then fall back to any PIL-valued column.
    image = None
    for col in ("image", "img", "photo", "picture"):
        if col in sample:
            image = sample[col]
            break
    if image is None:
        image = next((val for val in sample.values() if isinstance(val, Image.Image)), None)

    if image is None or isinstance(image, Image.Image):
        return image

    if isinstance(image, dict):
        # Undecoded ``datasets.Image`` value: {"bytes": ..., "path": ...}.
        if image.get("bytes") is not None:
            return Image.open(io.BytesIO(image["bytes"]))
        if image.get("path") is None:
            return None
        image = image["path"]

    return Image.open(image)


def load_random_dataset_image():
    """Pick a random image from the dataset's ``test`` split.

    Returns:
        A ``PIL.Image.Image``, or ``None`` if no image could be loaded. This
        is a best-effort helper for the "Random Dataset" button, so failures
        are printed rather than raised.
    """
    try:
        dataset = _load_test_split()
        if len(dataset) == 0:
            print("Error loading dataset: the test split is empty")
            return None
        sample = dataset[random.randrange(len(dataset))]
        return _extract_image(sample)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return None
|
|
| |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Sports Ball Classifier

        Upload an image of a sports ball to classify it. The model uses InceptionV3 transfer learning
        to identify 15 different types of sports balls.

        **Supported Sports Balls:**
        American Football, Baseball, Basketball, Billiard Ball, Bowling Ball, Cricket Ball, Football,
        Golf Ball, Hockey Ball, Hockey Puck, Rugby Ball, Shuttlecock, Table Tennis Ball, Tennis Ball, Volleyball
        """
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Upload Sports Ball Image",
                scale=1,
            )
            with gr.Row():
                submit_button = gr.Button("Classify", variant="primary", scale=2)
                random_button = gr.Button("Random Dataset", variant="secondary", scale=1)

        with gr.Column():
            output = gr.Label(label="Prediction Confidence", num_top_classes=5)

    with gr.Row():
        gr.Markdown(
            """
            ### How to Use:
            1. Upload or drag-and-drop an image containing a sports ball, or click 'Random Dataset' to load a sample from the test split
            2. The image is classified automatically; click the 'Classify' button to run it again
            3. View the prediction results with confidence scores

            ### Model Details:
            - Architecture: InceptionV3 (transfer learning from ImageNet)
            - Training: Two-stage training (feature extraction + fine-tuning)
            - Accuracy: High performance across all 15 sports ball classes
            - Preprocessing: Automatic resizing to 225x225 and scaling of pixel values to [0, 1]
            """
        )

    # Event wiring. ``image_input.change`` fires both on user uploads and when
    # another event (the random button) writes a new image into the component.
    # The random button therefore only loads the image; chaining another
    # classification with ``.then`` would run inference twice.
    submit_button.click(fn=classify_sports_ball, inputs=image_input, outputs=output)
    random_button.click(fn=load_random_dataset_image, outputs=image_input)
    image_input.change(fn=classify_sports_ball, inputs=image_input, outputs=output)
|
|
if __name__ == "__main__":
    # Listen on all network interfaces (0.0.0.0) on port 7860, without a public
    # Gradio share link. Presumably chosen for container / Space deployment.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, share=False)
    demo.launch(**launch_options)
|
|