VikTsrv commited on
Commit
93101e3
·
1 Parent(s): 5006afe

fix app.py

Browse files
Files changed (1) hide show
  1. app.py +261 -99
app.py CHANGED
@@ -1,3 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # import gradio as gr
2
  # import torch
3
  # import torch.nn as nn
@@ -7,16 +161,19 @@
7
  # import easyocr
8
  # import json
9
  # import os
 
 
 
10
 
11
- # import spaces # добавьте в начале
 
 
 
 
12
 
13
- # @spaces.GPU(duration=60) # добавьте перед predict
14
- # def predict_demo(image, caption_text=""):
15
- # # ... ваш код
16
  # # ======================
17
- # # ФИКСИРУЕМ ПУТИ (важно для Spaces!)
18
  # # ======================
19
- # # Модели и веса лежат в той же папке, что и app.py
20
  # BASE_DIR = os.path.dirname(os.path.abspath(__file__))
21
 
22
  # # Загрузка названий классов
@@ -28,74 +185,82 @@
28
 
29
 
30
  # # ======================
31
- # # ЗАГРУЗКА МОДЕЛЕЙ (один раз, с кешированием)
32
  # # ======================
33
- # @gr.cache_resource
34
- # def load_models():
35
- # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
- # print(f"Using device: {DEVICE}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- # # Визуальный энкодер
39
- # visual = models.resnet50(weights=None)
40
- # visual.fc = nn.Identity()
41
- # visual.load_state_dict(torch.load(os.path.join(BASE_DIR, "resnet50_encoder.pth"), map_location=DEVICE))
 
 
 
 
 
42
  # visual.to(DEVICE)
43
  # visual.eval()
44
  # for p in visual.parameters():
45
  # p.requires_grad = False
46
 
47
- # # Текстовые энкодеры
48
  # tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
49
- # ocr_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
50
- # caption_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
 
 
51
 
52
  # for p in ocr_encoder.parameters():
53
  # p.requires_grad = False
54
  # for p in caption_encoder.parameters():
55
  # p.requires_grad = False
56
 
57
- # # Классификатор
58
- # class ConcatFusionModel(nn.Module):
59
- # def __init__(self, num_classes, dropout=0.3):
60
- # super().__init__()
61
- # self.classifier = nn.Sequential(
62
- # nn.Linear(2048 + 312 + 312, 512),
63
- # nn.BatchNorm1d(512),
64
- # nn.ReLU(),
65
- # nn.Dropout(dropout),
66
- # nn.Linear(512, num_classes)
67
- # )
68
-
69
- # def forward(self, v, ocr, cap):
70
- # x = torch.cat([v, ocr, cap], dim=1)
71
- # return self.classifier(x)
72
-
73
  # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
74
- # model.load_state_dict(torch.load(os.path.join(BASE_DIR, "best_concat_model.pth"), map_location=DEVICE))
 
75
  # model.to(DEVICE)
76
  # model.eval()
77
 
78
  # # EasyOCR
79
  # reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda"))
80
 
81
- # # Трансформы
82
  # val_transform = transforms.Compose([
83
  # transforms.Resize(256),
84
  # transforms.CenterCrop(224),
85
  # transforms.ToTensor(),
86
- # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 
87
  # ])
88
 
89
- # return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform, DEVICE
90
 
91
 
92
- # # Загружаем всё при старте
93
- # visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform, DEVICE = load_models()
94
 
95
 
96
  # # ======================
97
  # # ФУНКЦИЯ ПРЕДСКАЗАНИЯ
98
  # # ======================
 
99
  # def predict(image, caption_text=""):
100
  # image = image.convert("RGB")
101
 
@@ -110,19 +275,23 @@
110
  # v = torch.flatten(v, 1)
111
 
112
  # # OCR encode
113
- # ocr_enc = tokenizer(ocr_text, truncation=True, padding="max_length", max_length=64, return_tensors="pt")
114
- # ocr_ids = ocr_enc["input_ids"].to(DEVICE)
115
- # ocr_mask = ocr_enc["attention_mask"].to(DEVICE)
116
  # with torch.no_grad():
117
- # ocr_out = ocr_encoder(input_ids=ocr_ids, attention_mask=ocr_mask)
 
 
 
118
  # ocr = ocr_out.last_hidden_state[:, 0]
119
 
120
  # # Caption encode
121
- # cap_enc = tokenizer(caption_text, truncation=True, padding="max_length", max_length=128, return_tensors="pt")
122
- # cap_ids = cap_enc["input_ids"].to(DEVICE)
123
- # cap_mask = cap_enc["attention_mask"].to(DEVICE)
124
  # with torch.no_grad():
125
- # cap_out = caption_encoder(input_ids=cap_ids, attention_mask=cap_mask)
 
 
 
126
  # cap = cap_out.last_hidden_state[:, 0]
127
 
128
  # # Предсказание
@@ -140,10 +309,11 @@
140
  # demo = gr.Interface(
141
  # fn=predict,
142
  # inputs=[
143
- # gr.Image(type="pil", label="Загрузите изображение"),
144
- # gr.Textbox(label="Подпись (необязательно)", placeholder="Введите текст подписи...")
 
145
  # ],
146
- # outputs=gr.Label(num_top_classes=5, label="Предсказанные категории"),
147
  # title="Мультимодальный классификатор контента",
148
  # description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)"
149
  # )
@@ -151,7 +321,6 @@
151
  # if __name__ == "__main__":
152
  # demo.launch()
153
 
154
-
155
  import gradio as gr
156
  import torch
157
  import torch.nn as nn
@@ -185,7 +354,7 @@ NUM_CLASSES = len(id2label)
185
 
186
 
187
  # ======================
188
- # ОПРЕДЕЛЕНИЕ МОДЕЛИ
189
  # ======================
190
  class ConcatFusionModel(nn.Module):
191
  def __init__(self, num_classes, dropout=0.3):
@@ -208,53 +377,46 @@ class ConcatFusionModel(nn.Module):
208
 
209
 
210
  # ======================
211
- # ЗАГРУЗКА МОДЕЛЕЙ
212
  # ======================
213
- @gr.cache
214
- def load_models():
215
- # Визуальный энкодер (загружаем предобученный из torchvision)
216
- visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
217
- visual.fc = nn.Identity() # убираем классификатор
218
- visual.to(DEVICE)
219
- visual.eval()
220
- for p in visual.parameters():
221
- p.requires_grad = False
222
-
223
- # Текстовые энкодеры (загружаем предобученные из Hugging Face)
224
- tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
225
- ocr_encoder = AutoModel.from_pretrained(
226
- "cointegrated/rubert-tiny2").to(DEVICE).eval()
227
- caption_encoder = AutoModel.from_pretrained(
228
- "cointegrated/rubert-tiny2").to(DEVICE).eval()
229
-
230
- for p in ocr_encoder.parameters():
231
- p.requires_grad = False
232
- for p in caption_encoder.parameters():
233
- p.requires_grad = False
234
-
235
- # Классификационная голова (обученная)
236
- model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
237
- model.load_state_dict(torch.load(os.path.join(
238
- BASE_DIR, "concat_model.pth"), map_location=DEVICE))
239
- model.to(DEVICE)
240
- model.eval()
241
-
242
- # EasyOCR
243
- reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda"))
244
-
245
- # Трансформы для изображений
246
- val_transform = transforms.Compose([
247
- transforms.Resize(256),
248
- transforms.CenterCrop(224),
249
- transforms.ToTensor(),
250
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
251
- 0.229, 0.224, 0.225]),
252
- ])
253
-
254
- return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform
255
-
256
-
257
- visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform = load_models()
258
 
259
 
260
  # ======================
 
1
+ # # import gradio as gr
2
+ # # import torch
3
+ # # import torch.nn as nn
4
+ # # from torchvision import models, transforms
5
+ # # from PIL import Image
6
+ # # from transformers import AutoModel, AutoTokenizer
7
+ # # import easyocr
8
+ # # import json
9
+ # # import os
10
+
11
+ # # import spaces # добавьте в начале
12
+
13
+ # # @spaces.GPU(duration=60) # добавьте перед predict
14
+ # # def predict_demo(image, caption_text=""):
15
+ # # # ... ваш код
16
+ # # # ======================
17
+ # # # ФИКСИРУЕМ ПУТИ (важно для Spaces!)
18
+ # # # ======================
19
+ # # # Модели и веса лежат в той же папке, что и app.py
20
+ # # BASE_DIR = os.path.dirname(os.path.abspath(__file__))
21
+
22
+ # # # Загрузка названий классов
23
+ # # with open(os.path.join(BASE_DIR, "class_names.json"), "r") as f:
24
+ # # id2label = json.load(f)
25
+ # # id2label = {int(k): v for k, v in id2label.items()}
26
+
27
+ # # NUM_CLASSES = len(id2label)
28
+
29
+
30
+ # # # ======================
31
+ # # # ЗАГРУЗКА МОДЕЛЕЙ (один раз, с кешированием)
32
+ # # # ======================
33
+ # # @gr.cache_resource
34
+ # # def load_models():
35
+ # # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+ # # print(f"Using device: {DEVICE}")
37
+
38
+ # # # Визуальный энкодер
39
+ # # visual = models.resnet50(weights=None)
40
+ # # visual.fc = nn.Identity()
41
+ # # visual.load_state_dict(torch.load(os.path.join(BASE_DIR, "resnet50_encoder.pth"), map_location=DEVICE))
42
+ # # visual.to(DEVICE)
43
+ # # visual.eval()
44
+ # # for p in visual.parameters():
45
+ # # p.requires_grad = False
46
+
47
+ # # # Текстовые энкодеры
48
+ # # tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
49
+ # # ocr_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
50
+ # # caption_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
51
+
52
+ # # for p in ocr_encoder.parameters():
53
+ # # p.requires_grad = False
54
+ # # for p in caption_encoder.parameters():
55
+ # # p.requires_grad = False
56
+
57
+ # # # Классификатор
58
+ # # class ConcatFusionModel(nn.Module):
59
+ # # def __init__(self, num_classes, dropout=0.3):
60
+ # # super().__init__()
61
+ # # self.classifier = nn.Sequential(
62
+ # # nn.Linear(2048 + 312 + 312, 512),
63
+ # # nn.BatchNorm1d(512),
64
+ # # nn.ReLU(),
65
+ # # nn.Dropout(dropout),
66
+ # # nn.Linear(512, num_classes)
67
+ # # )
68
+
69
+ # # def forward(self, v, ocr, cap):
70
+ # # x = torch.cat([v, ocr, cap], dim=1)
71
+ # # return self.classifier(x)
72
+
73
+ # # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
74
+ # # model.load_state_dict(torch.load(os.path.join(BASE_DIR, "best_concat_model.pth"), map_location=DEVICE))
75
+ # # model.to(DEVICE)
76
+ # # model.eval()
77
+
78
+ # # # EasyOCR
79
+ # # reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda"))
80
+
81
+ # # # Трансформы
82
+ # # val_transform = transforms.Compose([
83
+ # # transforms.Resize(256),
84
+ # # transforms.CenterCrop(224),
85
+ # # transforms.ToTensor(),
86
+ # # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
87
+ # # ])
88
+
89
+ # # return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform, DEVICE
90
+
91
+
92
+ # # # Загружаем всё при старте
93
+ # # visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform, DEVICE = load_models()
94
+
95
+
96
+ # # # ======================
97
+ # # # ФУНКЦИЯ ПРЕДСКАЗАНИЯ
98
+ # # # ======================
99
+ # # def predict(image, caption_text=""):
100
+ # # image = image.convert("RGB")
101
+
102
+ # # # OCR
103
+ # # ocr_result = reader.readtext(np.array(image), detail=0, paragraph=True)
104
+ # # ocr_text = " ".join(ocr_result) if ocr_result else ""
105
+
106
+ # # # Image
107
+ # # image_tensor = val_transform(image).unsqueeze(0).to(DEVICE)
108
+ # # with torch.no_grad():
109
+ # # v = visual(image_tensor)
110
+ # # v = torch.flatten(v, 1)
111
+
112
+ # # # OCR encode
113
+ # # ocr_enc = tokenizer(ocr_text, truncation=True, padding="max_length", max_length=64, return_tensors="pt")
114
+ # # ocr_ids = ocr_enc["input_ids"].to(DEVICE)
115
+ # # ocr_mask = ocr_enc["attention_mask"].to(DEVICE)
116
+ # # with torch.no_grad():
117
+ # # ocr_out = ocr_encoder(input_ids=ocr_ids, attention_mask=ocr_mask)
118
+ # # ocr = ocr_out.last_hidden_state[:, 0]
119
+
120
+ # # # Caption encode
121
+ # # cap_enc = tokenizer(caption_text, truncation=True, padding="max_length", max_length=128, return_tensors="pt")
122
+ # # cap_ids = cap_enc["input_ids"].to(DEVICE)
123
+ # # cap_mask = cap_enc["attention_mask"].to(DEVICE)
124
+ # # with torch.no_grad():
125
+ # # cap_out = caption_encoder(input_ids=cap_ids, attention_mask=cap_mask)
126
+ # # cap = cap_out.last_hidden_state[:, 0]
127
+
128
+ # # # Предсказание
129
+ # # with torch.no_grad():
130
+ # # logits = model(v, ocr, cap)
131
+ # # probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
132
+
133
+ # # result = {id2label[i]: float(probs[i]) for i in range(NUM_CLASSES)}
134
+ # # return dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
135
+
136
+
137
+ # # # ======================
138
+ # # # GRADIO ИНТЕРФЕЙС
139
+ # # # ======================
140
+ # # demo = gr.Interface(
141
+ # # fn=predict,
142
+ # # inputs=[
143
+ # # gr.Image(type="pil", label="Загрузите изображение"),
144
+ # # gr.Textbox(label="Подпись (необязательно)", placeholder="Введите текст подписи...")
145
+ # # ],
146
+ # # outputs=gr.Label(num_top_classes=5, label="Предсказанные категории"),
147
+ # # title="Мультимодальный классификатор контента",
148
+ # # description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)"
149
+ # # )
150
+
151
+ # # if __name__ == "__main__":
152
+ # # demo.launch()
153
+
154
+
155
  # import gradio as gr
156
  # import torch
157
  # import torch.nn as nn
 
161
  # import easyocr
162
  # import json
163
  # import os
164
+ # import numpy as np
165
+
166
+ # import spaces
167
 
168
+ # # ======================
169
+ # # УСТАНОВКА УСТРОЙСТВА
170
+ # # ======================
171
+ # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
172
+ # print(f"Using device: {DEVICE}")
173
 
 
 
 
174
  # # ======================
175
+ # # ПУТИ
176
  # # ======================
 
177
  # BASE_DIR = os.path.dirname(os.path.abspath(__file__))
178
 
179
  # # Загрузка названий классов
 
185
 
186
 
187
  # # ======================
188
+ # # ОПРЕДЕЛЕНИЕ МОДЕЛИ
189
  # # ======================
190
+ # class ConcatFusionModel(nn.Module):
191
+ # def __init__(self, num_classes, dropout=0.3):
192
+ # super().__init__()
193
+ # self.classifier = nn.Sequential(
194
+ # nn.Linear(2048 + 312 + 312, 512),
195
+ # nn.BatchNorm1d(512),
196
+ # nn.ReLU(),
197
+ # nn.Dropout(dropout),
198
+ # nn.Linear(512, 256),
199
+ # nn.BatchNorm1d(256),
200
+ # nn.ReLU(),
201
+ # nn.Dropout(0.3),
202
+ # nn.Linear(256, num_classes)
203
+ # )
204
+
205
+ # def forward(self, v, ocr, cap):
206
+ # x = torch.cat([v, ocr, cap], dim=1)
207
+ # return self.classifier(x)
208
 
209
+
210
+ # # ======================
211
+ # # ЗАГРУЗКА МОДЕЛЕЙ
212
+ # # ======================
213
+ # @gr.cache
214
+ # def load_models():
215
+ # # Визуальный энкодер (загружаем предобученный из torchvision)
216
+ # visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
217
+ # visual.fc = nn.Identity() # убираем классификатор
218
  # visual.to(DEVICE)
219
  # visual.eval()
220
  # for p in visual.parameters():
221
  # p.requires_grad = False
222
 
223
+ # # Текстовые энкодеры (загружаем предобученные из Hugging Face)
224
  # tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
225
+ # ocr_encoder = AutoModel.from_pretrained(
226
+ # "cointegrated/rubert-tiny2").to(DEVICE).eval()
227
+ # caption_encoder = AutoModel.from_pretrained(
228
+ # "cointegrated/rubert-tiny2").to(DEVICE).eval()
229
 
230
  # for p in ocr_encoder.parameters():
231
  # p.requires_grad = False
232
  # for p in caption_encoder.parameters():
233
  # p.requires_grad = False
234
 
235
+ # # Классификационная голова (обученная)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
237
+ # model.load_state_dict(torch.load(os.path.join(
238
+ # BASE_DIR, "concat_model.pth"), map_location=DEVICE))
239
  # model.to(DEVICE)
240
  # model.eval()
241
 
242
  # # EasyOCR
243
  # reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda"))
244
 
245
+ # # Трансформы для изображений
246
  # val_transform = transforms.Compose([
247
  # transforms.Resize(256),
248
  # transforms.CenterCrop(224),
249
  # transforms.ToTensor(),
250
+ # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
251
+ # 0.229, 0.224, 0.225]),
252
  # ])
253
 
254
+ # return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform
255
 
256
 
257
+ # visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform = load_models()
 
258
 
259
 
260
  # # ======================
261
  # # ФУНКЦИЯ ПРЕДСКАЗАНИЯ
262
  # # ======================
263
+ # @spaces.GPU(duration=60)
264
  # def predict(image, caption_text=""):
265
  # image = image.convert("RGB")
266
 
 
275
  # v = torch.flatten(v, 1)
276
 
277
  # # OCR encode
278
+ # ocr_enc = tokenizer(ocr_text, truncation=True,
279
+ # padding="max_length", max_length=64, return_tensors="pt")
 
280
  # with torch.no_grad():
281
+ # ocr_out = ocr_encoder(
282
+ # input_ids=ocr_enc["input_ids"].to(DEVICE),
283
+ # attention_mask=ocr_enc["attention_mask"].to(DEVICE)
284
+ # )
285
  # ocr = ocr_out.last_hidden_state[:, 0]
286
 
287
  # # Caption encode
288
+ # cap_enc = tokenizer(caption_text, truncation=True,
289
+ # padding="max_length", max_length=128, return_tensors="pt")
 
290
  # with torch.no_grad():
291
+ # cap_out = caption_encoder(
292
+ # input_ids=cap_enc["input_ids"].to(DEVICE),
293
+ # attention_mask=cap_enc["attention_mask"].to(DEVICE)
294
+ # )
295
  # cap = cap_out.last_hidden_state[:, 0]
296
 
297
  # # Предсказание
 
309
  # demo = gr.Interface(
310
  # fn=predict,
311
  # inputs=[
312
+ # gr.Image(type="pil", label="📸 Загрузите изображение"),
313
+ # gr.Textbox(label="📝 Подпись (необязательно)",
314
+ # placeholder="Введите текст подписи...")
315
  # ],
316
+ # outputs=gr.Label(num_top_classes=5, label="🎯 Предсказанные категории"),
317
  # title="Мультимодальный классификатор контента",
318
  # description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)"
319
  # )
 
321
  # if __name__ == "__main__":
322
  # demo.launch()
323
 
 
324
  import gradio as gr
325
  import torch
326
  import torch.nn as nn
 
354
 
355
 
356
  # ======================
357
+ # ОПРЕДЕЛЕНИЕ МОДЕЛИ (НАРУЖУ, НЕ ВНУТРИ load_models!)
358
  # ======================
359
  class ConcatFusionModel(nn.Module):
360
  def __init__(self, num_classes, dropout=0.3):
 
377
 
378
 
379
  # ======================
380
+ # ЗАГРУЗКА МОДЕЛЕЙ (без декоратора, глобально)
381
  # ======================
382
+ # Визуальный энкодер
383
+ visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
384
+ visual.fc = nn.Identity()
385
+ visual.to(DEVICE)
386
+ visual.eval()
387
+ for p in visual.parameters():
388
+ p.requires_grad = False
389
+
390
+ # Текстовые энкодеры
391
+ tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
392
+ ocr_encoder = AutoModel.from_pretrained(
393
+ "cointegrated/rubert-tiny2").to(DEVICE).eval()
394
+ caption_encoder = AutoModel.from_pretrained(
395
+ "cointegrated/rubert-tiny2").to(DEVICE).eval()
396
+
397
+ for p in ocr_encoder.parameters():
398
+ p.requires_grad = False
399
+ for p in caption_encoder.parameters():
400
+ p.requires_grad = False
401
+
402
+ # Классификационная голова
403
+ model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
404
+ model.load_state_dict(torch.load(os.path.join(
405
+ BASE_DIR, "best_concat_model.pth"), map_location=DEVICE))
406
+ model.to(DEVICE)
407
+ model.eval()
408
+
409
+ # EasyOCR
410
+ reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda"))
411
+
412
+ # Трансформы
413
+ val_transform = transforms.Compose([
414
+ transforms.Resize(256),
415
+ transforms.CenterCrop(224),
416
+ transforms.ToTensor(),
417
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
418
+ std=[0.229, 0.224, 0.225]),
419
+ ])
 
 
 
 
 
 
 
420
 
421
 
422
  # ======================