File size: 5,043 Bytes
747451d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2022 STMicroelectronics.
#  * All rights reserved.
#  *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from hydra.core.hydra_config import HydraConfig
import numpy as np
from sklearn.preprocessing import normalize
from sklearn.metrics import confusion_matrix
from PIL import Image
from typing import Dict


def vis_training_curves(history=None, output_dir: str = None) -> None:
    """
    Plot per-epoch training vs. validation accuracy and loss and save them as an image.

    The figure has two side-by-side panels (accuracy on the left, loss on the right),
    each with one line per run ("train" and "validation"). It is written to
    ``<output_dir>/Training_curves.png`` and then closed, so repeated calls do not
    accumulate open matplotlib figures.

    Args:
        history: The history object returned by the model.fit() method. Its
            ``history`` attribute must map ``'accuracy'``, ``'val_accuracy'``,
            ``'loss'`` and ``'val_loss'`` to per-epoch sequences of equal length.
        output_dir (str): The directory where the training curves plot is saved.

    Returns:
        None
    """
    metrics = history.history
    # Extract the accuracy and loss values for training and validation data
    acc = metrics['accuracy']
    val_acc = metrics['val_accuracy']
    loss = metrics['loss']
    val_loss = metrics['val_loss']
    epochs_range = range(len(acc))

    # One long-format frame per run so seaborn can split the lines by the 'run' column
    df_val = pd.DataFrame(
        {'run': 'validation', 'step': epochs_range, 'epoch_accuracy': val_acc, 'epoch_loss': val_loss})
    df_train = pd.DataFrame({'run': 'train', 'step': epochs_range, 'epoch_accuracy': acc, 'epoch_loss': loss})
    # ignore_index yields a clean 0..N-1 index without adding a stray 'index' column
    df = pd.concat([df_val, df_train], ignore_index=True)

    # Plot the training curves
    fig = plt.figure(figsize=(16, 6))
    plt.subplot(1, 2, 1)
    sns.lineplot(data=df, x="step", y="epoch_accuracy", hue="run").set_title("accuracy")
    plt.grid()
    plt.subplot(1, 2, 2)
    sns.lineplot(data=df, x="step", y="epoch_loss", hue="run").set_title("loss")
    plt.grid()
    plt.savefig(os.path.join(output_dir, 'Training_curves.png'))
    # pyplot keeps every figure alive until it is explicitly closed; without this,
    # repeated calls leak memory and trigger the "too many open figures" warning.
    plt.close(fig)
    

def plot_confusion_matrix(cm: np.ndarray = None,
                          class_names: list = None,
                          title: str = "f",
                          model_name: str = "f",
                          output_dir: str = None) -> None:
    """
    Plot a row-normalized confusion matrix as a seaborn heatmap and save it as a PNG.

    Each row is divided by its total, so cell (i, j) is the fraction of the samples of
    the class labelled on row i that fall in column j. Rows whose total is zero are
    drawn as zeros. The image is written to ``<output_dir>/<model_name>.png`` and the
    figure is closed afterwards.

    Args:
        cm (numpy.ndarray): Square confusion matrix of counts. Rows are drawn under the
            "True Label" axis and columns under the "Predicted Label" axis.
        class_names (list): Class names used for both axes, in matrix order. When there
            are more than 20 of them, fonts are shrunk and x tick labels are rotated.
        title (str): Text used as the figure title.
        model_name (str): Base name (without extension) of the saved image file.
        output_dir (str): The directory where the image is saved.

    Returns:
        None
    """
    import contextlib

    counts = np.asarray(cm, dtype=float)
    row_totals = counts.sum(axis=1, keepdims=True)
    # A class with no samples has a zero row total; keep that row at 0 instead of
    # producing NaNs (and a divide-by-zero RuntimeWarning).
    confusion_normalized = np.divide(counts, row_totals,
                                     out=np.zeros_like(counts),
                                     where=row_totals != 0)
    axis_labels = list(class_names)
    many_classes = len(axis_labels) > 20

    # Shrink fonts for large matrices only while drawing this plot. The previous global
    # sns.set(font_scale=0.5) permanently changed the theme for every later figure.
    # "notebook" is the context sns.set() uses; font_scale is ignored without a named context.
    if many_classes:
        font_context = sns.plotting_context("notebook", font_scale=0.5)
    else:
        font_context = contextlib.nullcontext()

    fig = plt.figure(figsize=(14, 14))
    with font_context:
        # Create the axes inside the context so its tick labels pick up the scaled fonts.
        ax = fig.add_subplot(1, 1, 1)
        sns.heatmap(confusion_normalized, xticklabels=axis_labels, yticklabels=axis_labels,
                    cmap='Blues', annot=True, fmt='.2f', square=True, ax=ax)
        if many_classes:
            # Rotate after heatmap() has installed its tick labels, so the rotation is
            # not overridden by the labels heatmap sets.
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    ax.set_title(title, fontsize=10)
    ax.set_ylabel("True Label", fontsize=10)
    ax.set_xlabel("Predicted Label", fontsize=10)
    # Lay out only once every label exists, so tight_layout leaves room for them.
    fig.tight_layout(pad=3)
    fig.savefig(os.path.join(output_dir, f"{model_name}.png"))
    plt.close(fig)

def display_figures(cfg: Dict = None):
    """
    Open every PNG image found in ``cfg.output_dir`` with the system image viewer.

    Nothing is done unless ``cfg.general.display_figures`` is truthy. Files are shown
    in sorted filename order.

    Args:
        cfg: Run configuration, read through attribute access
            (``cfg.general.display_figures``, ``cfg.output_dir``). Despite the ``Dict``
            hint, this is presumably an OmegaConf/Hydra ``DictConfig``; confirm against callers.

    Returns:
        None
    """
    if not cfg.general.display_figures:
        return
    # os.listdir order is arbitrary; sort so figures always appear in the same order.
    png_files = sorted(f for f in os.listdir(cfg.output_dir) if f.endswith('.png'))
    for png_file in png_files:
        # The context manager releases the file handle Image.open keeps open.
        with Image.open(os.path.join(cfg.output_dir, png_file)) as img:
            img.show()

def compute_confusion_matrix2(y_true, y_pred):
    '''Computes a confusion matrix for multiclass monolabel classification.

       Takes one-hot encoded labels and predictions as input. The output always has one
       row and one column per one-hot class, even for classes that never occur in
       either array.

       Inputs
       ------
       y_true : np.ndarray, (n_samples, n_classes) True labels. Must be one-hot encoded labels.
       y_pred : np.ndarray, (n_samples, n_classes) Predicted labels. Must be one-hot encoded labels.

       Outputs
       -------
       matrix : ndarray, (n_classes, n_classes) : Confusion matrix (sklearn convention:
                rows are true classes, columns are predicted classes).'''
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # The class count comes from the one-hot width, not from the labels that happen to occur.
    n_classes = max(y_true.shape[1], y_pred.shape[1])

    # Convert one-hot vectors to integer class indices
    true_idx = np.argmax(y_true, axis=1)
    pred_idx = np.argmax(y_pred, axis=1)

    # Without explicit labels, sklearn sizes the matrix from the labels it observes,
    # silently dropping the row/column of any absent class.
    matrix = confusion_matrix(true_idx, pred_idx, labels=np.arange(n_classes))

    return matrix

def compute_multilabel_confusion_matrices():
    """
    Placeholder for per-class (2x2) confusion matrices in multilabel inference.

    Not implemented yet: it only prints a notice to stdout and returns None.
    """
    notice = "this feature is not available yet!"
    print(notice)
    return None

def plot_multilabel_confusion_matrices():
    """
    Placeholder for plotting one 2x2 confusion matrix per class in multilabel inference.

    Not implemented yet: it only prints a notice to stdout and returns None.
    """
    notice = "this feature is not available yet!"
    print(notice)
    return None