File size: 4,923 Bytes
c8e426d
 
6a0b93e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2eadc64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff83735
 
2eadc64
ff83735
2eadc64
dcb04c4
e749740
 
 
 
 
 
 
 
 
 
 
05e5639
 
 
 
 
e749740
 
 
dcb04c4
 
 
 
 
 
 
 
 
 
 
2eadc64
c8e426d
 
 
e5d95e7
c8e426d
 
 
 
e5d95e7
c8e426d
 
e5d95e7
c8e426d
66bf19a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch, torchvision, os
from PIL import Image

# %% image loading
def hfImageToTensor(image, width: int = 1024, height: int = 512) -> torch.Tensor:
    """
    Convert a Hugging Face image (PIL.Image, numpy array, or torch.Tensor)
    to a resized tensor of shape (3, height, width).

    Non-tensor inputs are converted with ``to_tensor`` (float32, scaled to
    [0, 1]); tensor inputs are resized as-is, keeping their dtype.

    Args:
        image: Input image (PIL.Image, numpy array, or torch.Tensor).
        width (int): Target width. Defaults to 1024.
        height (int): Target height. Defaults to 512.

    Returns:
        torch.Tensor: Image tensor of shape (3, height, width).
    """
    if not isinstance(image, torch.Tensor):
        image = torchvision.transforms.functional.to_tensor(image)
    return torchvision.transforms.functional.resize(image, [height, width])

# %% preprocessing
def preprocessing(image_tensor: torch.Tensor) -> torch.Tensor:
    """
    Standardize the image tensor with ImageNet statistics and add a batch
    dimension.

    Args:
        image_tensor (torch.Tensor): Float image tensor of shape (3, H, W).

    Returns:
        torch.Tensor: Preprocessed tensor of shape (1, 3, H, W).
    """
    # Per-channel ImageNet mean/std, broadcast over H and W.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    standardized = (image_tensor - mean) / std
    return standardized.unsqueeze(0)
    
# %% print mask on a sem seg style
def print_mask(mask: torch.Tensor, numClasses: int = 19) -> torch.Tensor:
    """
    Convert a class-index segmentation mask into an RGB color mask using the
    Cityscapes palette.

    Pixels labeled 255 (the common "ignore" value) are left black, as is any
    class index outside the palette.

    Args:
        mask (torch.Tensor): Integer segmentation mask of shape (H, W).
        numClasses (int, optional): Number of classes to colorize. Values
            above the palette size (19) are clamped to avoid an IndexError.
            Defaults to 19.

    Returns:
        torch.Tensor: uint8 color mask of shape (3, H, W).
    """
    colors = [
        (128, 64, 128),  # 0: road
        (244, 35, 232),  # 1: sidewalk
        (70, 70, 70),    # 2: building
        (102, 102, 156), # 3: wall
        (190, 153, 153), # 4: fence
        (153, 153, 153), # 5: pole
        (250, 170, 30),  # 6: traffic light
        (220, 220, 0),   # 7: traffic sign
        (107, 142, 35),  # 8: vegetation
        (152, 251, 152), # 9: terrain
        (70, 130, 180),  # 10: sky
        (220, 20, 60),   # 11: person
        (255, 0, 0),     # 12: rider
        (0, 0, 142),     # 13: car
        (0, 0, 70),      # 14: truck
        (0, 60, 100),    # 15: bus
        (0, 80, 100),    # 16: train
        (0, 0, 230),     # 17: motorcycle
        (119, 11, 32)    # 18: bicycle
    ]

    # Zero-initialized, so the 255 "ignore" label and any unlisted class
    # stay black without an explicit assignment.
    new_mask = torch.zeros((mask.shape[0], mask.shape[1], 3), dtype=torch.uint8)
    # Clamp to the palette size so numClasses > 19 cannot raise IndexError.
    for i in range(min(numClasses, len(colors))):
        new_mask[mask == i] = torch.tensor(colors[i], dtype=torch.uint8)
    # (H, W, 3) -> (3, H, W) to match the usual CHW image layout.
    return new_mask.permute(2, 0, 1)


def legendHandling() -> list[list]:
    """
    Return the legend entries for the Cityscapes-style segmentation classes,
    ordered by class ID.

    Each entry is a list of:
        - Class ID (int)
        - Class name (str)
        - Human-readable color description (str)
        - RGB color tuple (tuple[int, int, int]), matching print_mask's palette

    Returns:
        list[list]: The 19 legend entries, sorted by class ID.
    """
    # The literal is already written in class-ID order, so no sort is needed.
    return [[0, "road", "dark purple", (128, 64, 128)], [1, "sidewalk", "light purple / pink", (244, 35, 232)], [2, "building", "dark gray", (70, 70, 70)],
            [3, "wall", "blue + grey", (102, 102, 156)], [4, "fence", "beige", (190, 153, 153)], [5, "pole", "grey", (153, 153, 153)], [6, "traffic light", "orange", (250, 170, 30)],
            [7, "traffic sign", "yellow", (220, 220, 0)], [8, "vegetation", "dark green", (107, 142, 35)], [9, "terrain", "light green", (152, 251, 152)], [10, "sky", "blue", (70, 130, 180)],
            [11, "person", "dark red", (220, 20, 60)], [12, "rider", "light red", (255, 0, 0)], [13, "car", "blue", (0, 0, 142)], [14, "truck", "dark blue", (0, 0, 70)],
            [15, "bus", "dark blue", (0, 60, 100)], [16, "train", "blue + green", (0, 80, 100)], [17, "motorcycle", "light blue", (0, 0, 230)], [18, "bicycle", "velvet", (119, 11, 32)]
    ]


# %% postprocessing
def postprocessing(pred: torch.Tensor) -> Image.Image:
    """
    Convert the model's output mask to a colorized PIL image for display.

    Args:
        pred (torch.Tensor): Model output tensor of shape (1, H, W) holding
            integer class indices.

    Returns:
        PIL.Image.Image: RGB color image of size (W, H).
    """
    # Drop the batch dim, move to CPU, cast to uint8 for print_mask,
    # then hand the (3, H, W) color mask to PIL.
    colorized = print_mask(pred.squeeze(0).cpu().to(torch.uint8))
    return torchvision.transforms.functional.to_pil_image(colorized)


# %% preloaded images
def loadPreloadedImages(*args: str) -> list[Image.Image]:
    """
    Load every image from one or more directories, sorted by full file path.

    Args:
        *args (str): Paths to directories containing images.

    Returns:
        list[Image.Image]: The images converted to RGB, ordered by their
        full file path (directory + filename).
    """
    entries = []
    for imageDir in args:
        for name in os.listdir(imageDir):
            # ".jpeg" with the leading dot — the bare "jpeg" suffix also
            # matched filenames like "notajpeg" that have no extension.
            if name.endswith((".png", ".jpg", ".jpeg")):
                path = os.path.join(imageDir, name)
                entries.append((path, Image.open(path).convert("RGB")))
    # Sort by path so the ordering is deterministic across directories.
    entries.sort(key=lambda entry: entry[0])
    return [image for _, image in entries]