# # # import gradio as gr # # # import torch # # # import torch.nn as nn # # # from torchvision import models, transforms # # # from PIL import Image # # # from transformers import AutoModel, AutoTokenizer # # # import easyocr # # # import json # # # import os # # # import spaces # добавьте в начале # # # @spaces.GPU(duration=60) # добавьте перед predict # # # def predict_demo(image, caption_text=""): # # # # ... ваш код # # # # ====================== # # # # ФИКСИРУЕМ ПУТИ (важно для Spaces!) # # # # ====================== # # # # Модели и веса лежат в той же папке, что и app.py # # # BASE_DIR = os.path.dirname(os.path.abspath(__file__)) # # # # Загрузка названий классов # # # with open(os.path.join(BASE_DIR, "class_names.json"), "r") as f: # # # id2label = json.load(f) # # # id2label = {int(k): v for k, v in id2label.items()} # # # NUM_CLASSES = len(id2label) # # # # ====================== # # # # ЗАГРУЗКА МОДЕЛЕЙ (один раз, с кешированием) # # # # ====================== # # # @gr.cache_resource # # # def load_models(): # # # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # # # print(f"Using device: {DEVICE}") # # # # Визуальный энкодер # # # visual = models.resnet50(weights=None) # # # visual.fc = nn.Identity() # # # visual.load_state_dict(torch.load(os.path.join(BASE_DIR, "resnet50_encoder.pth"), map_location=DEVICE)) # # # visual.to(DEVICE) # # # visual.eval() # # # for p in visual.parameters(): # # # p.requires_grad = False # # # # Текстовые энкодеры # # # tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2") # # # ocr_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval() # # # caption_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval() # # # for p in ocr_encoder.parameters(): # # # p.requires_grad = False # # # for p in caption_encoder.parameters(): # # # p.requires_grad = False # # # # Классификатор # # # class ConcatFusionModel(nn.Module): # # # def __init__(self, num_classes, dropout=0.3): # # # super().__init__() # # # self.classifier = nn.Sequential( # # # nn.Linear(2048 + 312 + 312, 512), # # # nn.BatchNorm1d(512), # # # nn.ReLU(), # # # nn.Dropout(dropout), # # # nn.Linear(512, num_classes) # # # ) # # # def forward(self, v, ocr, cap): # # # x = torch.cat([v, ocr, cap], dim=1) # # # return self.classifier(x) # # # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3) # # # model.load_state_dict(torch.load(os.path.join(BASE_DIR, "best_concat_model.pth"), map_location=DEVICE)) # # # model.to(DEVICE) # # # model.eval() # # # # EasyOCR # # # reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda")) # # # # Трансформы # # # val_transform = transforms.Compose([ # # # transforms.Resize(256), # # # transforms.CenterCrop(224), # # # transforms.ToTensor(), # # # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # # # ]) # # # return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform, DEVICE # # # # Загружаем всё при старте # # # visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform, DEVICE = load_models() # # # # ====================== # # # # ФУНКЦИЯ ПРЕДСКАЗАНИЯ # # # # ====================== # # # def predict(image, caption_text=""): # # # image = image.convert("RGB") # # # # OCR # # # ocr_result = reader.readtext(np.array(image), detail=0, paragraph=True) # # # ocr_text = " ".join(ocr_result) if ocr_result else "" # # # # Image # # # image_tensor = val_transform(image).unsqueeze(0).to(DEVICE) # # # with torch.no_grad(): # # # v = visual(image_tensor) # # # v = torch.flatten(v, 1) # # # # OCR encode # # # ocr_enc = tokenizer(ocr_text, truncation=True, padding="max_length", max_length=64, return_tensors="pt") # # # ocr_ids = ocr_enc["input_ids"].to(DEVICE) # # # ocr_mask = ocr_enc["attention_mask"].to(DEVICE) # # # with torch.no_grad(): # # # ocr_out = ocr_encoder(input_ids=ocr_ids, attention_mask=ocr_mask) # # # ocr = ocr_out.last_hidden_state[:, 0] # # # # Caption encode # # # cap_enc = tokenizer(caption_text, truncation=True, padding="max_length", max_length=128, return_tensors="pt") # # # cap_ids = cap_enc["input_ids"].to(DEVICE) # # # cap_mask = cap_enc["attention_mask"].to(DEVICE) # # # with torch.no_grad(): # # # cap_out = caption_encoder(input_ids=cap_ids, attention_mask=cap_mask) # # # cap = cap_out.last_hidden_state[:, 0] # # # # Предсказание # # # with torch.no_grad(): # # # logits = model(v, ocr, cap) # # # probs = torch.softmax(logits, dim=1)[0].cpu().numpy() # # # result = {id2label[i]: float(probs[i]) for i in range(NUM_CLASSES)} # # # return dict(sorted(result.items(), key=lambda x: x[1], reverse=True)) # # # # ====================== # # # # GRADIO ИНТЕРФЕЙС # # # # ====================== # # # demo = gr.Interface( # # # fn=predict, # # # inputs=[ # # # gr.Image(type="pil", label="Загрузите изображение"), # # # gr.Textbox(label="Подпись (необязательно)", placeholder="Введите текст подписи...") # # # ], # # # outputs=gr.Label(num_top_classes=5, label="Предсказанные категории"), # # # title="Мультимодальный классификатор контента", # # # description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)" # # # ) # # # if __name__ == "__main__": # # # demo.launch() # # import gradio as gr # # import torch # # import torch.nn as nn # # from torchvision import models, transforms # # from PIL import Image # # from transformers import AutoModel, AutoTokenizer # # import easyocr # # import json # # import os # # import numpy as np # # import spaces # # # ====================== # # # УСТАНОВКА УСТРОЙСТВА # # # ====================== # # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # # print(f"Using device: {DEVICE}") # # # ====================== # # # ПУТИ # # # ====================== # # BASE_DIR = os.path.dirname(os.path.abspath(__file__)) # # # Загрузка названий классов # # with open(os.path.join(BASE_DIR, "class_names.json"), "r") as f: # # id2label = json.load(f) # # id2label = {int(k): v for k, v in id2label.items()} # # NUM_CLASSES = len(id2label) # # # ====================== # # # ОПРЕДЕЛЕНИЕ МОДЕЛИ # # # ====================== # # class ConcatFusionModel(nn.Module): # # def __init__(self, num_classes, dropout=0.3): # # super().__init__() # # self.classifier = nn.Sequential( # # nn.Linear(2048 + 312 + 312, 512), # # nn.BatchNorm1d(512), # # nn.ReLU(), # # nn.Dropout(dropout), # # nn.Linear(512, 256), # # nn.BatchNorm1d(256), # # nn.ReLU(), # # nn.Dropout(0.3), # # nn.Linear(256, num_classes) # # ) # # def forward(self, v, ocr, cap): # # x = torch.cat([v, ocr, cap], dim=1) # # return self.classifier(x) # # # ====================== # # # ЗАГРУЗКА МОДЕЛЕЙ # # # ====================== # # @gr.cache # # def load_models(): # # # Визуальный энкодер (загружаем предобученный из torchvision) # # visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) # # visual.fc = nn.Identity() # убираем классификатор # # visual.to(DEVICE) # # visual.eval() # # for p in visual.parameters(): # # p.requires_grad = False # # # Текстовые энкодеры (загружаем предобученные из Hugging Face) # # tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2") # # ocr_encoder = AutoModel.from_pretrained( # # "cointegrated/rubert-tiny2").to(DEVICE).eval() # # caption_encoder = AutoModel.from_pretrained( # # "cointegrated/rubert-tiny2").to(DEVICE).eval() # # for p in ocr_encoder.parameters(): # # p.requires_grad = False # # for p in caption_encoder.parameters(): # # p.requires_grad = False # # # Классификационная голова (обученная) # # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3) # # model.load_state_dict(torch.load(os.path.join( # # BASE_DIR, "concat_model.pth"), map_location=DEVICE)) # # model.to(DEVICE) # # model.eval() # # # EasyOCR # # reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda")) # # # Трансформы для изображений # # val_transform = transforms.Compose([ # # transforms.Resize(256), # # transforms.CenterCrop(224), # # transforms.ToTensor(), # # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[ # # 0.229, 0.224, 0.225]), # # ]) # # return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform # # visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform = load_models() # # # ====================== # # # ФУНКЦИЯ ПРЕДСКАЗАНИЯ # # # ====================== # # @spaces.GPU(duration=60) # # def predict(image, caption_text=""): # # image = image.convert("RGB") # # # OCR # # ocr_result = reader.readtext(np.array(image), detail=0, paragraph=True) # # ocr_text = " ".join(ocr_result) if ocr_result else "" # # # Image # # image_tensor = val_transform(image).unsqueeze(0).to(DEVICE) # # with torch.no_grad(): # # v = visual(image_tensor) # # v = torch.flatten(v, 1) # # # OCR encode # # ocr_enc = tokenizer(ocr_text, truncation=True, # # padding="max_length", max_length=64, return_tensors="pt") # # with torch.no_grad(): # # ocr_out = ocr_encoder( # # input_ids=ocr_enc["input_ids"].to(DEVICE), # # attention_mask=ocr_enc["attention_mask"].to(DEVICE) # # ) # # ocr = ocr_out.last_hidden_state[:, 0] # # # Caption encode # # cap_enc = tokenizer(caption_text, truncation=True, # # padding="max_length", max_length=128, return_tensors="pt") # # with torch.no_grad(): # # cap_out = caption_encoder( # # input_ids=cap_enc["input_ids"].to(DEVICE), # # attention_mask=cap_enc["attention_mask"].to(DEVICE) # # ) # # cap = cap_out.last_hidden_state[:, 0] # # # Предсказание # # with torch.no_grad(): # # logits = model(v, ocr, cap) # # probs = torch.softmax(logits, dim=1)[0].cpu().numpy() # # result = {id2label[i]: float(probs[i]) for i in range(NUM_CLASSES)} # # return dict(sorted(result.items(), key=lambda x: x[1], reverse=True)) # # # ====================== # # # GRADIO ИНТЕРФЕЙС # # # ====================== # # demo = gr.Interface( # # fn=predict, # # inputs=[ # # gr.Image(type="pil", label="📸 Загрузите изображение"), # # gr.Textbox(label="📝 Подпись (необязательно)", # # placeholder="Введите текст подписи...") # # ], # # outputs=gr.Label(num_top_classes=5, label="🎯 Предсказанные категории"), # # title="Мультимодальный классификатор контента", # # description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)" # # ) # # if __name__ == "__main__": # # demo.launch() # import gradio as gr # import torch # import torch.nn as nn # from torchvision import models, transforms # from PIL import Image # from transformers import AutoModel, AutoTokenizer # import easyocr # import json # import os # import numpy as np # import spaces # # ====================== # # УСТАНОВКА УСТРОЙСТВА # # ====================== # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # print(f"Using device: {DEVICE}") # # ====================== # # ПУТИ # # ====================== # BASE_DIR = os.path.dirname(os.path.abspath(__file__)) # # Загрузка названий классов # with open(os.path.join(BASE_DIR, "class_names.json"), "r") as f: # id2label = json.load(f) # id2label = {int(k): v for k, v in id2label.items()} # NUM_CLASSES = len(id2label) # # ====================== # # ОПРЕДЕЛЕНИЕ МОДЕЛИ (НАРУЖУ, НЕ ВНУТРИ load_models!) # # ====================== # class ConcatFusionModel(nn.Module): # def __init__(self, num_classes, dropout=0.3): # super().__init__() # self.classifier = nn.Sequential( # nn.Linear(2048 + 312 + 312, 512), # nn.BatchNorm1d(512), # nn.ReLU(), # nn.Dropout(dropout), # nn.Linear(512, 256), # nn.BatchNorm1d(256), # nn.ReLU(), # nn.Dropout(0.3), # nn.Linear(256, num_classes) # ) # def forward(self, v, ocr, cap): # x = torch.cat([v, ocr, cap], dim=1) # return self.classifier(x) # # ====================== # # ЗАГРУЗКА МОДЕЛЕЙ (без декоратора, глобально) # # ====================== # # Визуальный энкодер # visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) # visual.fc = nn.Identity() # visual.to(DEVICE) # visual.eval() # for p in visual.parameters(): # p.requires_grad = False # # Текстовые энкодеры # tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2") # ocr_encoder = AutoModel.from_pretrained( # "cointegrated/rubert-tiny2").to(DEVICE).eval() # caption_encoder = AutoModel.from_pretrained( # "cointegrated/rubert-tiny2").to(DEVICE).eval() # for p in ocr_encoder.parameters(): # p.requires_grad = False # for p in caption_encoder.parameters(): # p.requires_grad = False # # Классификационная голова # # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3) # # model.load_state_dict(torch.load(os.path.join( # # BASE_DIR, "concat_model_head.pth"), map_location=DEVICE)) # # model.to(DEVICE) # # model.eval() # # В демо-скрипте # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3) # head_state = torch.load("best_head_only.pth", map_location=DEVICE) # model.load_state_dict(head_state, strict=False) # strict=False позволяет игнорировать отсутствие энкодеров # model.to(DEVICE) # model.eval() # # EasyOCR # reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda")) # # Трансформы # val_transform = transforms.Compose([ # transforms.Resize(256), # transforms.CenterCrop(224), # transforms.ToTensor(), # transforms.Normalize(mean=[0.485, 0.456, 0.406], # std=[0.229, 0.224, 0.225]), # ]) # # ====================== # # ФУНКЦИЯ ПРЕДСКАЗАНИЯ # # ====================== # @spaces.GPU(duration=60) # def predict(image, caption_text=""): # image = image.convert("RGB") # # OCR # ocr_result = reader.readtext(np.array(image), detail=0, paragraph=True) # ocr_text = " ".join(ocr_result) if ocr_result else "" # # Image # image_tensor = val_transform(image).unsqueeze(0).to(DEVICE) # with torch.no_grad(): # v = visual(image_tensor) # v = torch.flatten(v, 1) # # OCR encode # ocr_enc = tokenizer(ocr_text, truncation=True, # padding="max_length", max_length=64, return_tensors="pt") # with torch.no_grad(): # ocr_out = ocr_encoder( # input_ids=ocr_enc["input_ids"].to(DEVICE), # attention_mask=ocr_enc["attention_mask"].to(DEVICE) # ) # ocr = ocr_out.last_hidden_state[:, 0] # # Caption encode # cap_enc = tokenizer(caption_text, truncation=True, # padding="max_length", max_length=128, return_tensors="pt") # with torch.no_grad(): # cap_out = caption_encoder( # input_ids=cap_enc["input_ids"].to(DEVICE), # attention_mask=cap_enc["attention_mask"].to(DEVICE) # ) # cap = cap_out.last_hidden_state[:, 0] # # Предсказание # with torch.no_grad(): # logits = model(v, ocr, cap) # probs = torch.softmax(logits, dim=1)[0].cpu().numpy() # result = {id2label[i]: float(probs[i]) for i in range(NUM_CLASSES)} # return dict(sorted(result.items(), key=lambda x: x[1], reverse=True)) # # ====================== # # GRADIO ИНТЕРФЕЙС # # ====================== # demo = gr.Interface( # fn=predict, # inputs=[ # gr.Image(type="pil", label="📸 Загрузите изображение"), # gr.Textbox(label="📝 Подпись (необязательно)", # placeholder="Введите текст подписи...") # ], # outputs=gr.Label(num_top_classes=5, label="🎯 Предсказанные категории"), # title="Мультимодальный классификатор контента", # description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)" # ) # if __name__ == "__main__": # demo.launch() import gradio as gr import torch import torch.nn as nn from torchvision import models, transforms from PIL import Image from transformers import AutoModel, AutoTokenizer import easyocr import json import os import numpy as np # ====================== # УСТАНОВКА УСТРОЙСТВА # ====================== DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {DEVICE}") # ====================== # ПУТИ # ====================== BASE_DIR = os.path.dirname(os.path.abspath(__file__)) # Загрузка названий классов with open(os.path.join(BASE_DIR, "class_names.json"), "r") as f: id2label = json.load(f) id2label = {int(k): v for k, v in id2label.items()} NUM_CLASSES = len(id2label) # ====================== # ТОЛЬКО ГОЛОВА # ====================== class FusionHead(nn.Module): def __init__(self, num_classes, dropout=0.3): super().__init__() self.classifier = nn.Sequential( nn.Linear(2048 + 312 + 312, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(dropout), nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, num_classes) ) def forward(self, v, ocr, cap): x = torch.cat([v, ocr, cap], dim=1) return self.classifier(x) # ====================== # ЗАГРУЗКА ЭНКОДЕРОВ И ГОЛОВЫ # ====================== # Визуальный энкодер (предобученный) visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) visual.fc = nn.Identity() visual.to(DEVICE) visual.eval() for p in visual.parameters(): p.requires_grad = False # Текстовые энкодеры tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2") ocr_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval() caption_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval() for p in ocr_encoder.parameters(): p.requires_grad = False for p in caption_encoder.parameters(): p.requires_grad = False # Голова (загружаем только веса головы) model = FusionHead(NUM_CLASSES, dropout=0.3) head_state = torch.load(os.path.join(BASE_DIR, "concat_model_head.pth"), map_location=DEVICE) model.load_state_dict(head_state, strict=True) # strict=True, потому что в файле только голова model.to(DEVICE) model.eval() # EasyOCR reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda")) # Трансформы val_transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # ====================== # ФУНКЦИЯ ПРЕДСКАЗАНИЯ # ====================== def predict(image, caption_text=""): image = image.convert("RGB") # OCR ocr_result = reader.readtext(np.array(image), detail=0, paragraph=True) ocr_text = " ".join(ocr_result) if ocr_result else "" # Image image_tensor = val_transform(image).unsqueeze(0).to(DEVICE) with torch.no_grad(): v = visual(image_tensor) v = torch.flatten(v, 1) # OCR encode ocr_enc = tokenizer(ocr_text, truncation=True, padding="max_length", max_length=64, return_tensors="pt") with torch.no_grad(): ocr_out = ocr_encoder( input_ids=ocr_enc["input_ids"].to(DEVICE), attention_mask=ocr_enc["attention_mask"].to(DEVICE) ) ocr = ocr_out.last_hidden_state[:, 0] # Caption encode cap_enc = tokenizer(caption_text, truncation=True, padding="max_length", max_length=128, return_tensors="pt") with torch.no_grad(): cap_out = caption_encoder( input_ids=cap_enc["input_ids"].to(DEVICE), attention_mask=cap_enc["attention_mask"].to(DEVICE) ) cap = cap_out.last_hidden_state[:, 0] # Предсказание with torch.no_grad(): logits = model(v, ocr, cap) probs = torch.softmax(logits, dim=1)[0].cpu().numpy() result = {id2label[i]: float(probs[i]) for i in range(NUM_CLASSES)} return dict(sorted(result.items(), key=lambda x: x[1], reverse=True)) # ====================== # GRADIO ИНТЕРФЕЙС # ====================== demo = gr.Interface( fn=predict, inputs=[ gr.Image(type="pil", label="📸 Загрузите изображение"), gr.Textbox(label="📝 Подпись (необязательно)", placeholder="Введите текст подписи...") ], outputs=gr.Label(num_top_classes=5, label="🎯 Предсказанные категории"), title="Мультимодальный классификатор контента", description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)" ) if __name__ == "__main__": demo.launch()