| | --- |
| | license: apache-2.0 |
| | base_model: microsoft/swinv2-large-patch4-window12to24-192to384-22kto1k-ft |
| | model-index: |
| | - name: THW |
| | results: |
| | - task: |
| | name: Image Classification |
| | type: image-classification |
| | dataset: |
| | name: None |
| | type: None |
| | config: None |
| | split: None |
| | args: None |
| | metrics: |
| | - name: None |
| | type: None |
| | value: None |
| | --- |
| | <!-- This model card has been generated automatically according to the information the Trainer had access to. You |
| | should probably proofread and complete it, then remove this comment. --> |
| |
|
| | # Normal1919/THW |
| |
|
| | This model is a fine-tuned version of [microsoft/swinv2-large-patch4-window12to24-192to384-22kto1k-ft](https://huggingface.co/microsoft/swinv2-large-patch4-window12to24-192to384-22kto1k-ft) on a private dataset. |
| |
|
| | # How to use |
| |
|
| | ```python |
import torch
import torchvision
import torchvision.transforms as transforms
from transformers import AutoModelForImageClassification
from matplotlib import pyplot as plt

# Hugging Face Hub repository id of the fine-tuned classifier.
model_name = "Normal1919/THW"
# Image to classify; replace with your own file.
image_path = "test_img/c9f00dbb7e8fe20538fcc71b1dc0fbb913029959.png"
# Number of highest-scoring labels to show in the bar chart.
top_k = 10

model = AutoModelForImageClassification.from_pretrained(model_name)
model.eval()  # inference mode (disables dropout, etc.)
# Optional: uncomment to compile the model for faster repeated inference (PyTorch 2.x).
# model = torch.compile(model)

# Preprocessing pipeline. The 256x256 input size and the mean/std values are the
# ones this example has always used; presumably they match the fine-tuning setup
# (TODO confirm with the model author).
image_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.697, 0.633, 0.635], std=[0.3135, 0.320, 0.315]),
])

with torch.no_grad():
    # Decode straight to 3-channel RGB so grayscale, grayscale+alpha and RGBA files
    # all match the 3-channel Normalize above (manual channel surgery missed the
    # 2-channel case).
    image_raw = torchvision.io.read_image(
        image_path, mode=torchvision.io.ImageReadMode.RGB
    )
    pixel_values = image_transform(image_raw).unsqueeze(0)  # add batch dimension

    outputs = model(pixel_values=pixel_values)
    # Per-label sigmoid scores, as in the original example: each lies in [0, 1]
    # independently, so they need not sum to 1.
    probs = torch.sigmoid(outputs.logits)[0]

print(model.config.id2label[probs.argmax().item()])

# Pair every label with its score, highest first. The label count is read from
# the model config rather than hard-coded.
cha_names = [model.config.id2label[i] for i in range(model.config.num_labels)]
cha_probs = probs.numpy()
names_probs = sorted(zip(cha_names, cha_probs), key=lambda x: x[1], reverse=True)

print(names_probs)

# Slicing (instead of indexing top_k times) also works when there are fewer labels.
names_show = [name for name, _ in names_probs[:top_k]]
probs_show = [prob for _, prob in names_probs[:top_k]]

# SimHei is a CJK-capable font so non-Latin label names render; it must be
# installed on the system for this to take effect.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.figure(figsize=(12, 8))
plt.bar(names_show, probs_show)
plt.show()
| | ``` |