| import itertools |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import matplotlib.pyplot as plt |
|
|
# Global matplotlib text settings, applied once at import time:
#  - prefer the SimHei font (a CJK-capable font) for sans-serif text;
#  - draw minus signs with the ASCII hyphen instead of U+2212, a common
#    workaround because fonts such as SimHei may lack the Unicode minus glyph.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
|
|
|
|
| |
|
|
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Print a confusion matrix and plot it as an annotated heatmap.

    The matrix is printed to stdout, drawn on the current matplotlib figure,
    each cell is annotated with its value, and the figure is shown with
    ``plt.show()``.

    Args:
        cm: 2-D array-like of counts. Row ``i`` is drawn at y-position ``i``
            under the "True Label" axis label; column ``j`` is drawn at
            x-position ``j`` under the "Predicted Label" axis label.
        classes: Tick labels, one per row/column of ``cm``.
        normalize: If True, divide each row by its sum so every row shows
            fractions of that true label (formatted with two decimals).
            Rows whose sum is zero are shown as 0 instead of NaN.
        title: Plot title. ``None`` results in an empty title.
        cmap: Matplotlib colormap used to colour the cells.
    """
    cm = np.asarray(cm)

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # A row with no samples would divide by zero (NaN + RuntimeWarning);
        # leave such rows at 0 instead.
        cm = np.divide(cm.astype(float), row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    # Honour the caller's colormap (previously hard-coded to Blues).
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, fontsize=16)
    plt.yticks(tick_marks, classes, fontsize=16)

    # 'd' only works for integer values; any float matrix (normalized or
    # not) is printed with two decimals.
    if normalize or not np.issubdtype(cm.dtype, np.integer):
        fmt = '.2f'
    else:
        fmt = 'd'

    # Cells brighter than half the maximum get white text so the annotation
    # stays readable against the dark end of the colormap.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 verticalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    # tight_layout must run after the axis labels exist; otherwise it does
    # not reserve space for them and they can be clipped.
    plt.tight_layout()
    plt.show()
|
|
|
|
|
|
# Confusion-matrix counts to visualise. plot_confusion_matrix draws rows on
# the "True Label" axis and columns on the "Predicted Label" axis, and
# normalises across each row.
cnf_matrix = np.array([
    [299,   6,   5,    3,   1,   4,  11],
    [  9,  51,   0,    2,   8,   2,   2],
    [  2,   1, 120,    6,  13,   9,   9],
    [  5,   1,   7, 1148,   2,   4,  18],
    [  0,   0,   9,    4, 442,   1,  22],
    [  2,   0,   7,    3,   0, 145,   5],
    [ 10,   0,   6,   11,  29,   0, 624],
])

# Tick labels, one per row/column above.
# NOTE(review): these look like abbreviated emotion categories (surprise,
# fear, anger, happiness, sadness, disgust, neutral) -- confirm against the
# dataset's label order.
class_names = ["SU", "FE", "AN", "HA", "SA", "DI", "NE"]

# Render at 200 dpi, showing per-row fractions and no title.
plt.figure(dpi=200)
plot_confusion_matrix(
    cnf_matrix,
    classes=class_names,
    normalize=True,
    title=None,
)
|
|