# NOTE: the three lines below are commit metadata accidentally pasted from the
# hosting UI, kept as comments so the file parses as valid Python:
# VikTsrv's picture
# fix app
# 61c75a3
# # # import gradio as gr
# # # import torch
# # # import torch.nn as nn
# # # from torchvision import models, transforms
# # # from PIL import Image
# # # from transformers import AutoModel, AutoTokenizer
# # # import easyocr
# # # import json
# # # import os
# # # import spaces # добавьте в начале
# # # @spaces.GPU(duration=60) # добавьте перед predict
# # # def predict_demo(image, caption_text=""):
# # # # ... ваш код
# # # # ======================
# # # # ФИКСИРУЕМ ПУТИ (важно для Spaces!)
# # # # ======================
# # # # Модели и веса лежат в той же папке, что и app.py
# # # BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# # # # Загрузка названий классов
# # # with open(os.path.join(BASE_DIR, "class_names.json"), "r") as f:
# # # id2label = json.load(f)
# # # id2label = {int(k): v for k, v in id2label.items()}
# # # NUM_CLASSES = len(id2label)
# # # # ======================
# # # # ЗАГРУЗКА МОДЕЛЕЙ (один раз, с кешированием)
# # # # ======================
# # # @gr.cache_resource
# # # def load_models():
# # # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # # print(f"Using device: {DEVICE}")
# # # # Визуальный энкодер
# # # visual = models.resnet50(weights=None)
# # # visual.fc = nn.Identity()
# # # visual.load_state_dict(torch.load(os.path.join(BASE_DIR, "resnet50_encoder.pth"), map_location=DEVICE))
# # # visual.to(DEVICE)
# # # visual.eval()
# # # for p in visual.parameters():
# # # p.requires_grad = False
# # # # Текстовые энкодеры
# # # tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
# # # ocr_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
# # # caption_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
# # # for p in ocr_encoder.parameters():
# # # p.requires_grad = False
# # # for p in caption_encoder.parameters():
# # # p.requires_grad = False
# # # # Классификатор
# # # class ConcatFusionModel(nn.Module):
# # # def __init__(self, num_classes, dropout=0.3):
# # # super().__init__()
# # # self.classifier = nn.Sequential(
# # # nn.Linear(2048 + 312 + 312, 512),
# # # nn.BatchNorm1d(512),
# # # nn.ReLU(),
# # # nn.Dropout(dropout),
# # # nn.Linear(512, num_classes)
# # # )
# # # def forward(self, v, ocr, cap):
# # # x = torch.cat([v, ocr, cap], dim=1)
# # # return self.classifier(x)
# # # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
# # # model.load_state_dict(torch.load(os.path.join(BASE_DIR, "best_concat_model.pth"), map_location=DEVICE))
# # # model.to(DEVICE)
# # # model.eval()
# # # # EasyOCR
# # # reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda"))
# # # # Трансформы
# # # val_transform = transforms.Compose([
# # # transforms.Resize(256),
# # # transforms.CenterCrop(224),
# # # transforms.ToTensor(),
# # # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# # # ])
# # # return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform, DEVICE
# # # # Загружаем всё при старте
# # # visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform, DEVICE = load_models()
# # # # ======================
# # # # ФУНКЦИЯ ПРЕДСКАЗАНИЯ
# # # # ======================
# # # def predict(image, caption_text=""):
# # # image = image.convert("RGB")
# # # # OCR
# # # ocr_result = reader.readtext(np.array(image), detail=0, paragraph=True)
# # # ocr_text = " ".join(ocr_result) if ocr_result else ""
# # # # Image
# # # image_tensor = val_transform(image).unsqueeze(0).to(DEVICE)
# # # with torch.no_grad():
# # # v = visual(image_tensor)
# # # v = torch.flatten(v, 1)
# # # # OCR encode
# # # ocr_enc = tokenizer(ocr_text, truncation=True, padding="max_length", max_length=64, return_tensors="pt")
# # # ocr_ids = ocr_enc["input_ids"].to(DEVICE)
# # # ocr_mask = ocr_enc["attention_mask"].to(DEVICE)
# # # with torch.no_grad():
# # # ocr_out = ocr_encoder(input_ids=ocr_ids, attention_mask=ocr_mask)
# # # ocr = ocr_out.last_hidden_state[:, 0]
# # # # Caption encode
# # # cap_enc = tokenizer(caption_text, truncation=True, padding="max_length", max_length=128, return_tensors="pt")
# # # cap_ids = cap_enc["input_ids"].to(DEVICE)
# # # cap_mask = cap_enc["attention_mask"].to(DEVICE)
# # # with torch.no_grad():
# # # cap_out = caption_encoder(input_ids=cap_ids, attention_mask=cap_mask)
# # # cap = cap_out.last_hidden_state[:, 0]
# # # # Предсказание
# # # with torch.no_grad():
# # # logits = model(v, ocr, cap)
# # # probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
# # # result = {id2label[i]: float(probs[i]) for i in range(NUM_CLASSES)}
# # # return dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
# # # # ======================
# # # # GRADIO ИНТЕРФЕЙС
# # # # ======================
# # # demo = gr.Interface(
# # # fn=predict,
# # # inputs=[
# # # gr.Image(type="pil", label="Загрузите изображение"),
# # # gr.Textbox(label="Подпись (необязательно)", placeholder="Введите текст подписи...")
# # # ],
# # # outputs=gr.Label(num_top_classes=5, label="Предсказанные категории"),
# # # title="Мультимодальный классификатор контента",
# # # description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)"
# # # )
# # # if __name__ == "__main__":
# # # demo.launch()
# # import gradio as gr
# # import torch
# # import torch.nn as nn
# # from torchvision import models, transforms
# # from PIL import Image
# # from transformers import AutoModel, AutoTokenizer
# # import easyocr
# # import json
# # import os
# # import numpy as np
# # import spaces
# # # ======================
# # # УСТАНОВКА УСТРОЙСТВА
# # # ======================
# # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # print(f"Using device: {DEVICE}")
# # # ======================
# # # ПУТИ
# # # ======================
# # BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# # # Загрузка названий классов
# # with open(os.path.join(BASE_DIR, "class_names.json"), "r") as f:
# # id2label = json.load(f)
# # id2label = {int(k): v for k, v in id2label.items()}
# # NUM_CLASSES = len(id2label)
# # # ======================
# # # ОПРЕДЕЛЕНИЕ МОДЕЛИ
# # # ======================
# # class ConcatFusionModel(nn.Module):
# # def __init__(self, num_classes, dropout=0.3):
# # super().__init__()
# # self.classifier = nn.Sequential(
# # nn.Linear(2048 + 312 + 312, 512),
# # nn.BatchNorm1d(512),
# # nn.ReLU(),
# # nn.Dropout(dropout),
# # nn.Linear(512, 256),
# # nn.BatchNorm1d(256),
# # nn.ReLU(),
# # nn.Dropout(0.3),
# # nn.Linear(256, num_classes)
# # )
# # def forward(self, v, ocr, cap):
# # x = torch.cat([v, ocr, cap], dim=1)
# # return self.classifier(x)
# # # ======================
# # # ЗАГРУЗКА МОДЕЛЕЙ
# # # ======================
# # @gr.cache
# # def load_models():
# # # Визуальный энкодер (загружаем предобученный из torchvision)
# # visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
# # visual.fc = nn.Identity() # убираем классификатор
# # visual.to(DEVICE)
# # visual.eval()
# # for p in visual.parameters():
# # p.requires_grad = False
# # # Текстовые энкодеры (загружаем предобученные из Hugging Face)
# # tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
# # ocr_encoder = AutoModel.from_pretrained(
# # "cointegrated/rubert-tiny2").to(DEVICE).eval()
# # caption_encoder = AutoModel.from_pretrained(
# # "cointegrated/rubert-tiny2").to(DEVICE).eval()
# # for p in ocr_encoder.parameters():
# # p.requires_grad = False
# # for p in caption_encoder.parameters():
# # p.requires_grad = False
# # # Классификационная голова (обученная)
# # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
# # model.load_state_dict(torch.load(os.path.join(
# # BASE_DIR, "concat_model.pth"), map_location=DEVICE))
# # model.to(DEVICE)
# # model.eval()
# # # EasyOCR
# # reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda"))
# # # Трансформы для изображений
# # val_transform = transforms.Compose([
# # transforms.Resize(256),
# # transforms.CenterCrop(224),
# # transforms.ToTensor(),
# # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
# # 0.229, 0.224, 0.225]),
# # ])
# # return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform
# # visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform = load_models()
# # # ======================
# # # ФУНКЦИЯ ПРЕДСКАЗАНИЯ
# # # ======================
# # @spaces.GPU(duration=60)
# # def predict(image, caption_text=""):
# # image = image.convert("RGB")
# # # OCR
# # ocr_result = reader.readtext(np.array(image), detail=0, paragraph=True)
# # ocr_text = " ".join(ocr_result) if ocr_result else ""
# # # Image
# # image_tensor = val_transform(image).unsqueeze(0).to(DEVICE)
# # with torch.no_grad():
# # v = visual(image_tensor)
# # v = torch.flatten(v, 1)
# # # OCR encode
# # ocr_enc = tokenizer(ocr_text, truncation=True,
# # padding="max_length", max_length=64, return_tensors="pt")
# # with torch.no_grad():
# # ocr_out = ocr_encoder(
# # input_ids=ocr_enc["input_ids"].to(DEVICE),
# # attention_mask=ocr_enc["attention_mask"].to(DEVICE)
# # )
# # ocr = ocr_out.last_hidden_state[:, 0]
# # # Caption encode
# # cap_enc = tokenizer(caption_text, truncation=True,
# # padding="max_length", max_length=128, return_tensors="pt")
# # with torch.no_grad():
# # cap_out = caption_encoder(
# # input_ids=cap_enc["input_ids"].to(DEVICE),
# # attention_mask=cap_enc["attention_mask"].to(DEVICE)
# # )
# # cap = cap_out.last_hidden_state[:, 0]
# # # Предсказание
# # with torch.no_grad():
# # logits = model(v, ocr, cap)
# # probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
# # result = {id2label[i]: float(probs[i]) for i in range(NUM_CLASSES)}
# # return dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
# # # ======================
# # # GRADIO ИНТЕРФЕЙС
# # # ======================
# # demo = gr.Interface(
# # fn=predict,
# # inputs=[
# # gr.Image(type="pil", label="📸 Загрузите изображение"),
# # gr.Textbox(label="📝 Подпись (необязательно)",
# # placeholder="Введите текст подписи...")
# # ],
# # outputs=gr.Label(num_top_classes=5, label="🎯 Предсказанные категории"),
# # title="Мультимодальный классификатор контента",
# # description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)"
# # )
# # if __name__ == "__main__":
# # demo.launch()
# import gradio as gr
# import torch
# import torch.nn as nn
# from torchvision import models, transforms
# from PIL import Image
# from transformers import AutoModel, AutoTokenizer
# import easyocr
# import json
# import os
# import numpy as np
# import spaces
# # ======================
# # УСТАНОВКА УСТРОЙСТВА
# # ======================
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {DEVICE}")
# # ======================
# # ПУТИ
# # ======================
# BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# # Загрузка названий классов
# with open(os.path.join(BASE_DIR, "class_names.json"), "r") as f:
# id2label = json.load(f)
# id2label = {int(k): v for k, v in id2label.items()}
# NUM_CLASSES = len(id2label)
# # ======================
# # ОПРЕДЕЛЕНИЕ МОДЕЛИ (НАРУЖУ, НЕ ВНУТРИ load_models!)
# # ======================
# class ConcatFusionModel(nn.Module):
# def __init__(self, num_classes, dropout=0.3):
# super().__init__()
# self.classifier = nn.Sequential(
# nn.Linear(2048 + 312 + 312, 512),
# nn.BatchNorm1d(512),
# nn.ReLU(),
# nn.Dropout(dropout),
# nn.Linear(512, 256),
# nn.BatchNorm1d(256),
# nn.ReLU(),
# nn.Dropout(0.3),
# nn.Linear(256, num_classes)
# )
# def forward(self, v, ocr, cap):
# x = torch.cat([v, ocr, cap], dim=1)
# return self.classifier(x)
# # ======================
# # ЗАГРУЗКА МОДЕЛЕЙ (без декоратора, глобально)
# # ======================
# # Визуальный энкодер
# visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
# visual.fc = nn.Identity()
# visual.to(DEVICE)
# visual.eval()
# for p in visual.parameters():
# p.requires_grad = False
# # Текстовые энкодеры
# tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
# ocr_encoder = AutoModel.from_pretrained(
# "cointegrated/rubert-tiny2").to(DEVICE).eval()
# caption_encoder = AutoModel.from_pretrained(
# "cointegrated/rubert-tiny2").to(DEVICE).eval()
# for p in ocr_encoder.parameters():
# p.requires_grad = False
# for p in caption_encoder.parameters():
# p.requires_grad = False
# # Классификационная голова
# # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
# # model.load_state_dict(torch.load(os.path.join(
# # BASE_DIR, "concat_model_head.pth"), map_location=DEVICE))
# # model.to(DEVICE)
# # model.eval()
# # В демо-скрипте
# model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
# head_state = torch.load("best_head_only.pth", map_location=DEVICE)
# model.load_state_dict(head_state, strict=False) # strict=False позволяет игнорировать отсутствие энкодеров
# model.to(DEVICE)
# model.eval()
# # EasyOCR
# reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda"))
# # Трансформы
# val_transform = transforms.Compose([
# transforms.Resize(256),
# transforms.CenterCrop(224),
# transforms.ToTensor(),
# transforms.Normalize(mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225]),
# ])
# # ======================
# # ФУНКЦИЯ ПРЕДСКАЗАНИЯ
# # ======================
# @spaces.GPU(duration=60)
# def predict(image, caption_text=""):
# image = image.convert("RGB")
# # OCR
# ocr_result = reader.readtext(np.array(image), detail=0, paragraph=True)
# ocr_text = " ".join(ocr_result) if ocr_result else ""
# # Image
# image_tensor = val_transform(image).unsqueeze(0).to(DEVICE)
# with torch.no_grad():
# v = visual(image_tensor)
# v = torch.flatten(v, 1)
# # OCR encode
# ocr_enc = tokenizer(ocr_text, truncation=True,
# padding="max_length", max_length=64, return_tensors="pt")
# with torch.no_grad():
# ocr_out = ocr_encoder(
# input_ids=ocr_enc["input_ids"].to(DEVICE),
# attention_mask=ocr_enc["attention_mask"].to(DEVICE)
# )
# ocr = ocr_out.last_hidden_state[:, 0]
# # Caption encode
# cap_enc = tokenizer(caption_text, truncation=True,
# padding="max_length", max_length=128, return_tensors="pt")
# with torch.no_grad():
# cap_out = caption_encoder(
# input_ids=cap_enc["input_ids"].to(DEVICE),
# attention_mask=cap_enc["attention_mask"].to(DEVICE)
# )
# cap = cap_out.last_hidden_state[:, 0]
# # Предсказание
# with torch.no_grad():
# logits = model(v, ocr, cap)
# probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
# result = {id2label[i]: float(probs[i]) for i in range(NUM_CLASSES)}
# return dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
# # ======================
# # GRADIO ИНТЕРФЕЙС
# # ======================
# demo = gr.Interface(
# fn=predict,
# inputs=[
# gr.Image(type="pil", label="📸 Загрузите изображение"),
# gr.Textbox(label="📝 Подпись (необязательно)",
# placeholder="Введите текст подписи...")
# ],
# outputs=gr.Label(num_top_classes=5, label="🎯 Предсказанные категории"),
# title="Мультимодальный классификатор контента",
# description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)"
# )
# if __name__ == "__main__":
# demo.launch()
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
from transformers import AutoModel, AutoTokenizer
import easyocr
import json
import os
import numpy as np
# ======================
# DEVICE SETUP
# ======================
# Use CUDA when available; every model and tensor below is moved to this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
# ======================
# PATHS
# ======================
# Resolve resources relative to this file so the app works regardless of the
# current working directory (important on Hugging Face Spaces).
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Load the class-id -> label mapping. JSON keys are strings, so convert to int.
# encoding="utf-8" is explicit because labels may contain non-ASCII (Russian)
# text and the platform default encoding is not guaranteed to be UTF-8.
with open(os.path.join(BASE_DIR, "class_names.json"), "r", encoding="utf-8") as f:
    id2label = json.load(f)
id2label = {int(k): v for k, v in id2label.items()}
NUM_CLASSES = len(id2label)
# ======================
# ТОЛЬКО ГОЛОВА
# ======================
class FusionHead(nn.Module):
    """Classification head over concatenated multimodal features.

    Fuses a 2048-d visual embedding (ResNet-50), a 312-d OCR-text embedding
    and a 312-d caption embedding (rubert-tiny2 CLS tokens) and maps them to
    class logits through a two-layer MLP with batch norm and dropout.

    Args:
        num_classes: number of output classes.
        dropout: dropout probability used in BOTH dropout layers
            (previously the second layer was hard-coded to 0.3, silently
            ignoring this parameter; default 0.3 preserves old behavior).
    """

    def __init__(self, num_classes, dropout=0.3):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(2048 + 312 + 312, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout),  # fixed: was nn.Dropout(0.3) regardless of `dropout`
            nn.Linear(256, num_classes)
        )

    def forward(self, v, ocr, cap):
        """Concatenate (v, ocr, cap) along dim 1 and return class logits."""
        x = torch.cat([v, ocr, cap], dim=1)
        return self.classifier(x)
# ======================
# ЗАГРУЗКА ЭНКОДЕРОВ И ГОЛОВЫ
# ======================
# Visual encoder: ImageNet-pretrained ResNet-50 used as a frozen 2048-d
# feature extractor (the final fc classifier is replaced with identity).
visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
visual.fc = nn.Identity()
visual.to(DEVICE)
visual.eval()
for p in visual.parameters():
    p.requires_grad = False
# Text encoders: the same frozen rubert-tiny2 backbone is instantiated twice,
# once for OCR text and once for captions; each yields a 312-d CLS embedding.
tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
ocr_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
caption_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
for p in ocr_encoder.parameters():
    p.requires_grad = False
for p in caption_encoder.parameters():
    p.requires_grad = False
# Fusion head: the checkpoint stores only the trained head weights.
# weights_only=True restricts torch.load to tensor data instead of arbitrary
# unpickling (safer); a state_dict contains only tensors, so this is sufficient.
model = FusionHead(NUM_CLASSES, dropout=0.3)
head_state = torch.load(os.path.join(BASE_DIR, "concat_model_head.pth"),
                        map_location=DEVICE, weights_only=True)
model.load_state_dict(head_state, strict=True)  # strict=True: the file holds the head only
model.to(DEVICE)
model.eval()
# EasyOCR reader for Russian + English; uses the GPU only when one is available.
reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda"))
# Validation-style preprocessing matching the ResNet-50 ImageNet recipe.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# ======================
# ФУНКЦИЯ ПРЕДСКАЗАНИЯ
# ======================
def predict(image, caption_text=""):
    """Classify an image from visual, OCR-text and caption features.

    Args:
        image: PIL.Image to classify, or None (Gradio passes None when the
            user submits without uploading an image).
        caption_text: optional user-supplied caption; empty string by default.

    Returns:
        dict mapping class label -> probability, sorted by probability in
        descending order (suitable for gr.Label). Empty dict when no image
        is given.
    """
    # Guard: avoid crashing on an empty image input.
    if image is None:
        return {}
    image = image.convert("RGB")
    # OCR: extract any text visible in the picture (paragraph mode joins lines).
    ocr_result = reader.readtext(np.array(image), detail=0, paragraph=True)
    ocr_text = " ".join(ocr_result) if ocr_result else ""
    # All encoders and the head are frozen — run inference without gradients.
    with torch.no_grad():
        # Visual features: 2048-d ResNet-50 embedding.
        image_tensor = val_transform(image).unsqueeze(0).to(DEVICE)
        v = torch.flatten(visual(image_tensor), 1)
        # OCR text features: CLS token of rubert-tiny2 (312-d).
        ocr_enc = tokenizer(ocr_text, truncation=True, padding="max_length",
                            max_length=64, return_tensors="pt")
        ocr = ocr_encoder(
            input_ids=ocr_enc["input_ids"].to(DEVICE),
            attention_mask=ocr_enc["attention_mask"].to(DEVICE),
        ).last_hidden_state[:, 0]
        # Caption features: CLS token, with a longer budget for user text.
        cap_enc = tokenizer(caption_text, truncation=True, padding="max_length",
                            max_length=128, return_tensors="pt")
        cap = caption_encoder(
            input_ids=cap_enc["input_ids"].to(DEVICE),
            attention_mask=cap_enc["attention_mask"].to(DEVICE),
        ).last_hidden_state[:, 0]
        # Fuse the three modalities and classify.
        logits = model(v, ocr, cap)
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
    result = {id2label[i]: float(probs[i]) for i in range(NUM_CLASSES)}
    # Sort so the most probable class comes first.
    return dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
# ======================
# GRADIO ИНТЕРФЕЙС
# ======================
# ======================
# GRADIO UI
# ======================
# Two inputs (image + optional caption) mapped to one Label output showing
# the top-5 predicted categories.
_image_input = gr.Image(type="pil", label="📸 Загрузите изображение")
_caption_input = gr.Textbox(label="📝 Подпись (необязательно)", placeholder="Введите текст подписи...")
_label_output = gr.Label(num_top_classes=5, label="🎯 Предсказанные категории")

demo = gr.Interface(
    fn=predict,
    inputs=[_image_input, _caption_input],
    outputs=_label_output,
    title="Мультимодальный классификатор контента",
    description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)",
)

if __name__ == "__main__":
    demo.launch()