| # # # import gradio as gr | |
| # # # import torch | |
| # # # import torch.nn as nn | |
| # # # from torchvision import models, transforms | |
| # # # from PIL import Image | |
| # # # from transformers import AutoModel, AutoTokenizer | |
| # # # import easyocr | |
| # # # import json | |
| # # # import os | |
| # # # import spaces # добавьте в начале | |
| # # # @spaces.GPU(duration=60) # добавьте перед predict | |
| # # # def predict_demo(image, caption_text=""): | |
| # # # # ... ваш код | |
| # # # # ====================== | |
| # # # # ФИКСИРУЕМ ПУТИ (важно для Spaces!) | |
| # # # # ====================== | |
| # # # # Модели и веса лежат в той же папке, что и app.py | |
| # # # BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| # # # # Загрузка названий классов | |
| # # # with open(os.path.join(BASE_DIR, "class_names.json"), "r") as f: | |
| # # # id2label = json.load(f) | |
| # # # id2label = {int(k): v for k, v in id2label.items()} | |
| # # # NUM_CLASSES = len(id2label) | |
| # # # # ====================== | |
| # # # # ЗАГРУЗКА МОДЕЛЕЙ (один раз, с кешированием) | |
| # # # # ====================== | |
| # # # @gr.cache_resource | |
| # # # def load_models(): | |
| # # # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # # # print(f"Using device: {DEVICE}") | |
| # # # # Визуальный энкодер | |
| # # # visual = models.resnet50(weights=None) | |
| # # # visual.fc = nn.Identity() | |
| # # # visual.load_state_dict(torch.load(os.path.join(BASE_DIR, "resnet50_encoder.pth"), map_location=DEVICE)) | |
| # # # visual.to(DEVICE) | |
| # # # visual.eval() | |
| # # # for p in visual.parameters(): | |
| # # # p.requires_grad = False | |
| # # # # Текстовые энкодеры | |
| # # # tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2") | |
| # # # ocr_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval() | |
| # # # caption_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval() | |
| # # # for p in ocr_encoder.parameters(): | |
| # # # p.requires_grad = False | |
| # # # for p in caption_encoder.parameters(): | |
| # # # p.requires_grad = False | |
| # # # # Классификатор | |
| # # # class ConcatFusionModel(nn.Module): | |
| # # # def __init__(self, num_classes, dropout=0.3): | |
| # # # super().__init__() | |
| # # # self.classifier = nn.Sequential( | |
| # # # nn.Linear(2048 + 312 + 312, 512), | |
| # # # nn.BatchNorm1d(512), | |
| # # # nn.ReLU(), | |
| # # # nn.Dropout(dropout), | |
| # # # nn.Linear(512, num_classes) | |
| # # # ) | |
| # # # def forward(self, v, ocr, cap): | |
| # # # x = torch.cat([v, ocr, cap], dim=1) | |
| # # # return self.classifier(x) | |
| # # # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3) | |
| # # # model.load_state_dict(torch.load(os.path.join(BASE_DIR, "best_concat_model.pth"), map_location=DEVICE)) | |
| # # # model.to(DEVICE) | |
| # # # model.eval() | |
| # # # # EasyOCR | |
| # # # reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda")) | |
| # # # # Трансформы | |
| # # # val_transform = transforms.Compose([ | |
| # # # transforms.Resize(256), | |
| # # # transforms.CenterCrop(224), | |
| # # # transforms.ToTensor(), | |
| # # # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), | |
| # # # ]) | |
| # # # return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform, DEVICE | |
| # # # # Загружаем всё при старте | |
| # # # visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform, DEVICE = load_models() | |
| # # # # ====================== | |
| # # # # ФУНКЦИЯ ПРЕДСКАЗАНИЯ | |
| # # # # ====================== | |
| # # # def predict(image, caption_text=""): | |
| # # # image = image.convert("RGB") | |
| # # # # OCR | |
| # # # ocr_result = reader.readtext(np.array(image), detail=0, paragraph=True) | |
| # # # ocr_text = " ".join(ocr_result) if ocr_result else "" | |
| # # # # Image | |
| # # # image_tensor = val_transform(image).unsqueeze(0).to(DEVICE) | |
| # # # with torch.no_grad(): | |
| # # # v = visual(image_tensor) | |
| # # # v = torch.flatten(v, 1) | |
| # # # # OCR encode | |
| # # # ocr_enc = tokenizer(ocr_text, truncation=True, padding="max_length", max_length=64, return_tensors="pt") | |
| # # # ocr_ids = ocr_enc["input_ids"].to(DEVICE) | |
| # # # ocr_mask = ocr_enc["attention_mask"].to(DEVICE) | |
| # # # with torch.no_grad(): | |
| # # # ocr_out = ocr_encoder(input_ids=ocr_ids, attention_mask=ocr_mask) | |
| # # # ocr = ocr_out.last_hidden_state[:, 0] | |
| # # # # Caption encode | |
| # # # cap_enc = tokenizer(caption_text, truncation=True, padding="max_length", max_length=128, return_tensors="pt") | |
| # # # cap_ids = cap_enc["input_ids"].to(DEVICE) | |
| # # # cap_mask = cap_enc["attention_mask"].to(DEVICE) | |
| # # # with torch.no_grad(): | |
| # # # cap_out = caption_encoder(input_ids=cap_ids, attention_mask=cap_mask) | |
| # # # cap = cap_out.last_hidden_state[:, 0] | |
| # # # # Предсказание | |
| # # # with torch.no_grad(): | |
| # # # logits = model(v, ocr, cap) | |
| # # # probs = torch.softmax(logits, dim=1)[0].cpu().numpy() | |
| # # # result = {id2label[i]: float(probs[i]) for i in range(NUM_CLASSES)} | |
| # # # return dict(sorted(result.items(), key=lambda x: x[1], reverse=True)) | |
| # # # # ====================== | |
| # # # # GRADIO ИНТЕРФЕЙС | |
| # # # # ====================== | |
| # # # demo = gr.Interface( | |
| # # # fn=predict, | |
| # # # inputs=[ | |
| # # # gr.Image(type="pil", label="Загрузите изображение"), | |
| # # # gr.Textbox(label="Подпись (необязательно)", placeholder="Введите текст подписи...") | |
| # # # ], | |
| # # # outputs=gr.Label(num_top_classes=5, label="Предсказанные категории"), | |
| # # # title="Мультимодальный классификатор контента", | |
| # # # description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)" | |
| # # # ) | |
| # # # if __name__ == "__main__": | |
| # # # demo.launch() | |
| # # import gradio as gr | |
| # # import torch | |
| # # import torch.nn as nn | |
| # # from torchvision import models, transforms | |
| # # from PIL import Image | |
| # # from transformers import AutoModel, AutoTokenizer | |
| # # import easyocr | |
| # # import json | |
| # # import os | |
| # # import numpy as np | |
| # # import spaces | |
| # # # ====================== | |
| # # # УСТАНОВКА УСТРОЙСТВА | |
| # # # ====================== | |
| # # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # # print(f"Using device: {DEVICE}") | |
| # # # ====================== | |
| # # # ПУТИ | |
| # # # ====================== | |
| # # BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| # # # Загрузка названий классов | |
| # # with open(os.path.join(BASE_DIR, "class_names.json"), "r") as f: | |
| # # id2label = json.load(f) | |
| # # id2label = {int(k): v for k, v in id2label.items()} | |
| # # NUM_CLASSES = len(id2label) | |
| # # # ====================== | |
| # # # ОПРЕДЕЛЕНИЕ МОДЕЛИ | |
| # # # ====================== | |
| # # class ConcatFusionModel(nn.Module): | |
| # # def __init__(self, num_classes, dropout=0.3): | |
| # # super().__init__() | |
| # # self.classifier = nn.Sequential( | |
| # # nn.Linear(2048 + 312 + 312, 512), | |
| # # nn.BatchNorm1d(512), | |
| # # nn.ReLU(), | |
| # # nn.Dropout(dropout), | |
| # # nn.Linear(512, 256), | |
| # # nn.BatchNorm1d(256), | |
| # # nn.ReLU(), | |
| # # nn.Dropout(0.3), | |
| # # nn.Linear(256, num_classes) | |
| # # ) | |
| # # def forward(self, v, ocr, cap): | |
| # # x = torch.cat([v, ocr, cap], dim=1) | |
| # # return self.classifier(x) | |
| # # # ====================== | |
| # # # ЗАГРУЗКА МОДЕЛЕЙ | |
| # # # ====================== | |
| # # @gr.cache | |
| # # def load_models(): | |
| # # # Визуальный энкодер (загружаем предобученный из torchvision) | |
| # # visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) | |
| # # visual.fc = nn.Identity() # убираем классификатор | |
| # # visual.to(DEVICE) | |
| # # visual.eval() | |
| # # for p in visual.parameters(): | |
| # # p.requires_grad = False | |
| # # # Текстовые энкодеры (загружаем предобученные из Hugging Face) | |
| # # tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2") | |
| # # ocr_encoder = AutoModel.from_pretrained( | |
| # # "cointegrated/rubert-tiny2").to(DEVICE).eval() | |
| # # caption_encoder = AutoModel.from_pretrained( | |
| # # "cointegrated/rubert-tiny2").to(DEVICE).eval() | |
| # # for p in ocr_encoder.parameters(): | |
| # # p.requires_grad = False | |
| # # for p in caption_encoder.parameters(): | |
| # # p.requires_grad = False | |
| # # # Классификационная голова (обученная) | |
| # # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3) | |
| # # model.load_state_dict(torch.load(os.path.join( | |
| # # BASE_DIR, "concat_model.pth"), map_location=DEVICE)) | |
| # # model.to(DEVICE) | |
| # # model.eval() | |
| # # # EasyOCR | |
| # # reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda")) | |
| # # # Трансформы для изображений | |
| # # val_transform = transforms.Compose([ | |
| # # transforms.Resize(256), | |
| # # transforms.CenterCrop(224), | |
| # # transforms.ToTensor(), | |
| # # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[ | |
| # # 0.229, 0.224, 0.225]), | |
| # # ]) | |
| # # return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform | |
| # # visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform = load_models() | |
| # # # ====================== | |
| # # # ФУНКЦИЯ ПРЕДСКАЗАНИЯ | |
| # # # ====================== | |
| # # @spaces.GPU(duration=60) | |
| # # def predict(image, caption_text=""): | |
| # # image = image.convert("RGB") | |
| # # # OCR | |
| # # ocr_result = reader.readtext(np.array(image), detail=0, paragraph=True) | |
| # # ocr_text = " ".join(ocr_result) if ocr_result else "" | |
| # # # Image | |
| # # image_tensor = val_transform(image).unsqueeze(0).to(DEVICE) | |
| # # with torch.no_grad(): | |
| # # v = visual(image_tensor) | |
| # # v = torch.flatten(v, 1) | |
| # # # OCR encode | |
| # # ocr_enc = tokenizer(ocr_text, truncation=True, | |
| # # padding="max_length", max_length=64, return_tensors="pt") | |
| # # with torch.no_grad(): | |
| # # ocr_out = ocr_encoder( | |
| # # input_ids=ocr_enc["input_ids"].to(DEVICE), | |
| # # attention_mask=ocr_enc["attention_mask"].to(DEVICE) | |
| # # ) | |
| # # ocr = ocr_out.last_hidden_state[:, 0] | |
| # # # Caption encode | |
| # # cap_enc = tokenizer(caption_text, truncation=True, | |
| # # padding="max_length", max_length=128, return_tensors="pt") | |
| # # with torch.no_grad(): | |
| # # cap_out = caption_encoder( | |
| # # input_ids=cap_enc["input_ids"].to(DEVICE), | |
| # # attention_mask=cap_enc["attention_mask"].to(DEVICE) | |
| # # ) | |
| # # cap = cap_out.last_hidden_state[:, 0] | |
| # # # Предсказание | |
| # # with torch.no_grad(): | |
| # # logits = model(v, ocr, cap) | |
| # # probs = torch.softmax(logits, dim=1)[0].cpu().numpy() | |
| # # result = {id2label[i]: float(probs[i]) for i in range(NUM_CLASSES)} | |
| # # return dict(sorted(result.items(), key=lambda x: x[1], reverse=True)) | |
| # # # ====================== | |
| # # # GRADIO ИНТЕРФЕЙС | |
| # # # ====================== | |
| # # demo = gr.Interface( | |
| # # fn=predict, | |
| # # inputs=[ | |
| # # gr.Image(type="pil", label="📸 Загрузите изображение"), | |
| # # gr.Textbox(label="📝 Подпись (необязательно)", | |
| # # placeholder="Введите текст подписи...") | |
| # # ], | |
| # # outputs=gr.Label(num_top_classes=5, label="🎯 Предсказанные категории"), | |
| # # title="Мультимодальный классификатор контента", | |
| # # description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)" | |
| # # ) | |
| # # if __name__ == "__main__": | |
| # # demo.launch() | |
| # import gradio as gr | |
| # import torch | |
| # import torch.nn as nn | |
| # from torchvision import models, transforms | |
| # from PIL import Image | |
| # from transformers import AutoModel, AutoTokenizer | |
| # import easyocr | |
| # import json | |
| # import os | |
| # import numpy as np | |
| # import spaces | |
| # # ====================== | |
| # # УСТАНОВКА УСТРОЙСТВА | |
| # # ====================== | |
| # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # print(f"Using device: {DEVICE}") | |
| # # ====================== | |
| # # ПУТИ | |
| # # ====================== | |
| # BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| # # Загрузка названий классов | |
| # with open(os.path.join(BASE_DIR, "class_names.json"), "r") as f: | |
| # id2label = json.load(f) | |
| # id2label = {int(k): v for k, v in id2label.items()} | |
| # NUM_CLASSES = len(id2label) | |
| # # ====================== | |
| # # ОПРЕДЕЛЕНИЕ МОДЕЛИ (НАРУЖУ, НЕ ВНУТРИ load_models!) | |
| # # ====================== | |
| # class ConcatFusionModel(nn.Module): | |
| # def __init__(self, num_classes, dropout=0.3): | |
| # super().__init__() | |
| # self.classifier = nn.Sequential( | |
| # nn.Linear(2048 + 312 + 312, 512), | |
| # nn.BatchNorm1d(512), | |
| # nn.ReLU(), | |
| # nn.Dropout(dropout), | |
| # nn.Linear(512, 256), | |
| # nn.BatchNorm1d(256), | |
| # nn.ReLU(), | |
| # nn.Dropout(0.3), | |
| # nn.Linear(256, num_classes) | |
| # ) | |
| # def forward(self, v, ocr, cap): | |
| # x = torch.cat([v, ocr, cap], dim=1) | |
| # return self.classifier(x) | |
| # # ====================== | |
| # # ЗАГРУЗКА МОДЕЛЕЙ (без декоратора, глобально) | |
| # # ====================== | |
| # # Визуальный энкодер | |
| # visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) | |
| # visual.fc = nn.Identity() | |
| # visual.to(DEVICE) | |
| # visual.eval() | |
| # for p in visual.parameters(): | |
| # p.requires_grad = False | |
| # # Текстовые энкодеры | |
| # tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2") | |
| # ocr_encoder = AutoModel.from_pretrained( | |
| # "cointegrated/rubert-tiny2").to(DEVICE).eval() | |
| # caption_encoder = AutoModel.from_pretrained( | |
| # "cointegrated/rubert-tiny2").to(DEVICE).eval() | |
| # for p in ocr_encoder.parameters(): | |
| # p.requires_grad = False | |
| # for p in caption_encoder.parameters(): | |
| # p.requires_grad = False | |
| # # Классификационная голова | |
| # # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3) | |
| # # model.load_state_dict(torch.load(os.path.join( | |
| # # BASE_DIR, "concat_model_head.pth"), map_location=DEVICE)) | |
| # # model.to(DEVICE) | |
| # # model.eval() | |
| # # В демо-скрипте | |
| # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3) | |
| # head_state = torch.load("best_head_only.pth", map_location=DEVICE) | |
| # model.load_state_dict(head_state, strict=False) # strict=False позволяет игнорировать отсутствие энкодеров | |
| # model.to(DEVICE) | |
| # model.eval() | |
| # # EasyOCR | |
| # reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda")) | |
| # # Трансформы | |
| # val_transform = transforms.Compose([ | |
| # transforms.Resize(256), | |
| # transforms.CenterCrop(224), | |
| # transforms.ToTensor(), | |
| # transforms.Normalize(mean=[0.485, 0.456, 0.406], | |
| # std=[0.229, 0.224, 0.225]), | |
| # ]) | |
| # # ====================== | |
| # # ФУНКЦИЯ ПРЕДСКАЗАНИЯ | |
| # # ====================== | |
| # @spaces.GPU(duration=60) | |
| # def predict(image, caption_text=""): | |
| # image = image.convert("RGB") | |
| # # OCR | |
| # ocr_result = reader.readtext(np.array(image), detail=0, paragraph=True) | |
| # ocr_text = " ".join(ocr_result) if ocr_result else "" | |
| # # Image | |
| # image_tensor = val_transform(image).unsqueeze(0).to(DEVICE) | |
| # with torch.no_grad(): | |
| # v = visual(image_tensor) | |
| # v = torch.flatten(v, 1) | |
| # # OCR encode | |
| # ocr_enc = tokenizer(ocr_text, truncation=True, | |
| # padding="max_length", max_length=64, return_tensors="pt") | |
| # with torch.no_grad(): | |
| # ocr_out = ocr_encoder( | |
| # input_ids=ocr_enc["input_ids"].to(DEVICE), | |
| # attention_mask=ocr_enc["attention_mask"].to(DEVICE) | |
| # ) | |
| # ocr = ocr_out.last_hidden_state[:, 0] | |
| # # Caption encode | |
| # cap_enc = tokenizer(caption_text, truncation=True, | |
| # padding="max_length", max_length=128, return_tensors="pt") | |
| # with torch.no_grad(): | |
| # cap_out = caption_encoder( | |
| # input_ids=cap_enc["input_ids"].to(DEVICE), | |
| # attention_mask=cap_enc["attention_mask"].to(DEVICE) | |
| # ) | |
| # cap = cap_out.last_hidden_state[:, 0] | |
| # # Предсказание | |
| # with torch.no_grad(): | |
| # logits = model(v, ocr, cap) | |
| # probs = torch.softmax(logits, dim=1)[0].cpu().numpy() | |
| # result = {id2label[i]: float(probs[i]) for i in range(NUM_CLASSES)} | |
| # return dict(sorted(result.items(), key=lambda x: x[1], reverse=True)) | |
| # # ====================== | |
| # # GRADIO ИНТЕРФЕЙС | |
| # # ====================== | |
| # demo = gr.Interface( | |
| # fn=predict, | |
| # inputs=[ | |
| # gr.Image(type="pil", label="📸 Загрузите изображение"), | |
| # gr.Textbox(label="📝 Подпись (необязательно)", | |
| # placeholder="Введите текст подписи...") | |
| # ], | |
| # outputs=gr.Label(num_top_classes=5, label="🎯 Предсказанные категории"), | |
| # title="Мультимодальный классификатор контента", | |
| # description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)" | |
| # ) | |
| # if __name__ == "__main__": | |
| # demo.launch() | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| from transformers import AutoModel, AutoTokenizer | |
| import easyocr | |
| import json | |
| import os | |
| import numpy as np | |
| # ====================== | |
| # УСТАНОВКА УСТРОЙСТВА | |
| # ====================== | |
| DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| print(f"Using device: {DEVICE}") | |
| # ====================== | |
| # ПУТИ | |
| # ====================== | |
| BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| # Загрузка названий классов | |
| with open(os.path.join(BASE_DIR, "class_names.json"), "r") as f: | |
| id2label = json.load(f) | |
| id2label = {int(k): v for k, v in id2label.items()} | |
| NUM_CLASSES = len(id2label) | |
| # ====================== | |
| # ТОЛЬКО ГОЛОВА | |
| # ====================== | |
| class FusionHead(nn.Module): | |
| def __init__(self, num_classes, dropout=0.3): | |
| super().__init__() | |
| self.classifier = nn.Sequential( | |
| nn.Linear(2048 + 312 + 312, 512), | |
| nn.BatchNorm1d(512), | |
| nn.ReLU(), | |
| nn.Dropout(dropout), | |
| nn.Linear(512, 256), | |
| nn.BatchNorm1d(256), | |
| nn.ReLU(), | |
| nn.Dropout(0.3), | |
| nn.Linear(256, num_classes) | |
| ) | |
| def forward(self, v, ocr, cap): | |
| x = torch.cat([v, ocr, cap], dim=1) | |
| return self.classifier(x) | |
| # ====================== | |
| # ЗАГРУЗКА ЭНКОДЕРОВ И ГОЛОВЫ | |
| # ====================== | |
| # Визуальный энкодер (предобученный) | |
| visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) | |
| visual.fc = nn.Identity() | |
| visual.to(DEVICE) | |
| visual.eval() | |
| for p in visual.parameters(): | |
| p.requires_grad = False | |
| # Текстовые энкодеры | |
| tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2") | |
| ocr_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval() | |
| caption_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval() | |
| for p in ocr_encoder.parameters(): | |
| p.requires_grad = False | |
| for p in caption_encoder.parameters(): | |
| p.requires_grad = False | |
| # Голова (загружаем только веса головы) | |
| model = FusionHead(NUM_CLASSES, dropout=0.3) | |
| head_state = torch.load(os.path.join(BASE_DIR, "concat_model_head.pth"), map_location=DEVICE) | |
| model.load_state_dict(head_state, strict=True) # strict=True, потому что в файле только голова | |
| model.to(DEVICE) | |
| model.eval() | |
| # EasyOCR | |
| reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda")) | |
| # Трансформы | |
| val_transform = transforms.Compose([ | |
| transforms.Resize(256), | |
| transforms.CenterCrop(224), | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), | |
| ]) | |
| # ====================== | |
| # ФУНКЦИЯ ПРЕДСКАЗАНИЯ | |
| # ====================== | |
| def predict(image, caption_text=""): | |
| image = image.convert("RGB") | |
| # OCR | |
| ocr_result = reader.readtext(np.array(image), detail=0, paragraph=True) | |
| ocr_text = " ".join(ocr_result) if ocr_result else "" | |
| # Image | |
| image_tensor = val_transform(image).unsqueeze(0).to(DEVICE) | |
| with torch.no_grad(): | |
| v = visual(image_tensor) | |
| v = torch.flatten(v, 1) | |
| # OCR encode | |
| ocr_enc = tokenizer(ocr_text, truncation=True, padding="max_length", max_length=64, return_tensors="pt") | |
| with torch.no_grad(): | |
| ocr_out = ocr_encoder( | |
| input_ids=ocr_enc["input_ids"].to(DEVICE), | |
| attention_mask=ocr_enc["attention_mask"].to(DEVICE) | |
| ) | |
| ocr = ocr_out.last_hidden_state[:, 0] | |
| # Caption encode | |
| cap_enc = tokenizer(caption_text, truncation=True, padding="max_length", max_length=128, return_tensors="pt") | |
| with torch.no_grad(): | |
| cap_out = caption_encoder( | |
| input_ids=cap_enc["input_ids"].to(DEVICE), | |
| attention_mask=cap_enc["attention_mask"].to(DEVICE) | |
| ) | |
| cap = cap_out.last_hidden_state[:, 0] | |
| # Предсказание | |
| with torch.no_grad(): | |
| logits = model(v, ocr, cap) | |
| probs = torch.softmax(logits, dim=1)[0].cpu().numpy() | |
| result = {id2label[i]: float(probs[i]) for i in range(NUM_CLASSES)} | |
| return dict(sorted(result.items(), key=lambda x: x[1], reverse=True)) | |
| # ====================== | |
| # GRADIO ИНТЕРФЕЙС | |
| # ====================== | |
| demo = gr.Interface( | |
| fn=predict, | |
| inputs=[ | |
| gr.Image(type="pil", label="📸 Загрузите изображение"), | |
| gr.Textbox(label="📝 Подпись (необязательно)", placeholder="Введите текст подписи...") | |
| ], | |
| outputs=gr.Label(num_top_classes=5, label="🎯 Предсказанные категории"), | |
| title="Мультимодальный классификатор контента", | |
| description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)" | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() |