| import gradio as gr |
| import lime |
| from lime.lime_text import LimeTextExplainer |
| import numpy as np |
| from datasets import load_dataset |
| from sklearn.feature_extraction.text import TfidfVectorizer |
| from sklearn.model_selection import train_test_split |
| from sklearn.linear_model import LogisticRegression |
| from sklearn.pipeline import make_pipeline |
| import shap |
| import matplotlib.pyplot as plt |
| import io |
| from PIL import Image |
| import pandas as pd |
|
|
|
|
| |
# Load the IMDB sentiment dataset (train and test splits of labelled reviews).
dataset = load_dataset('imdb')

# Pull each field out as a whole column instead of iterating the split row by
# row: the row-wise form builds a dict per review and walks every split twice
# just to pick one key out of it.  `list(...)` keeps the result a plain list
# regardless of the container type the installed `datasets` version returns.
text_train = list(dataset['train']['text'])
y_train = list(dataset['train']['label'])
text_test = list(dataset['test']['text'])
y_test = list(dataset['test']['label'])

# TF-IDF features over the 5000 most frequent non-stop-word terms.  The
# vocabulary is fitted on the training reviews only and then applied to the
# test reviews.
vectorizer = TfidfVectorizer(stop_words='english', max_features=5000)
X_train = vectorizer.fit_transform(text_train)
X_test = vectorizer.transform(text_test)

# Hold out 20% of the training data as a validation split (fixed seed so the
# split is reproducible).  NOTE: X_val / y_val / X_test / y_test are not used
# further down in this file; they are kept so the module's names are unchanged.
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.2, random_state=42
)

# Linear classifier over the TF-IDF features; label 1 is treated as
# 'Positive' and anything else as 'Negative' by explain_text().
model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)

# LIME explainer for raw text; class_names order matches labels 0 and 1.
lime_explainer = LimeTextExplainer(class_names=['Negative', 'Positive'])

# SHAP explainer for the linear model, using the training matrix as the
# background data.
shap_explainer = shap.LinearExplainer(model, X_train)
|
|
def _highlight_text_shap(text, word_importances, feature_names, max_num_features):
    """Return *text* as HTML with its most influential words coloured.

    Args:
        text: Raw input text; it is split on whitespace.
        word_importances: Per-feature importance values, indexed in the same
            order as ``feature_names``.
        feature_names: Vocabulary terms, one per importance value.
        max_num_features: Maximum number of distinct terms to highlight.

    Returns:
        The whitespace-separated words re-joined with single spaces.  Words
        whose normalised form is one of the top terms (ranked by absolute
        importance) are wrapped in a ``<span>``: red for a positive value,
        blue otherwise.  All words are HTML-escaped.
    """
    import html  # stdlib; imported locally because the file does not import it

    words = text.split()
    # Normalise each word the same way for lookup: keep alphanumerics, lowercase.
    cleaned_words = [''.join(filter(str.isalnum, word)).lower() for word in words]
    present = set(cleaned_words)

    # Only vocabulary terms that occur as whole words are candidates.  A
    # substring test would let e.g. "art" match inside "party" and occupy one
    # of the top slots without ever being highlighted.
    word_to_importance = {
        name: word_importances[idx]
        for idx, name in enumerate(feature_names)
        if name in present
    }
    ranked = sorted(
        word_to_importance.items(), key=lambda item: abs(item[1]), reverse=True
    )
    top_words = dict(ranked[:max_num_features])

    highlighted_text = []
    for word, cleaned in zip(words, cleaned_words):
        # The text comes straight from the user and is rendered as HTML.
        safe_word = html.escape(word)
        importance = top_words.get(cleaned)
        if importance is None:
            highlighted_text.append(safe_word)
        else:
            color = 'red' if importance > 0 else 'blue'
            highlighted_text.append(f'<span style="color:{color}">{safe_word}</span>')

    return ' '.join(highlighted_text)


def _plot_top_shap_values(shap_values, feature_names, max_num_features):
    """Render a bar chart of the largest-magnitude SHAP values as an image array.

    Features are ranked by absolute SHAP value so that strong negative
    contributions are shown alongside positive ones.  The figure is always
    closed after rendering.
    """
    values = np.asarray(shap_values, dtype=float).ravel()
    order = np.argsort(np.abs(values))[::-1][:max_num_features]
    top_features = [str(feature_names[i]) for i in order]
    top_values = values[order]
    colors = ['red' if val > 0 else 'blue' for val in top_values]

    # Use an explicit figure/axes pair rather than pyplot's implicit "current
    # figure", which is shared global state.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(top_features, top_values, color=colors)
        ax.invert_yaxis()  # most important feature at the top
        ax.set_xlabel('SHAP Value')
        ax.set_title(f'Top {max_num_features} Feature Importance')
        fig.tight_layout()
        return fig_to_nparray(fig)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)


def explain_text(input_text):
    """Classify *input_text* and explain the prediction with LIME and SHAP.

    Args:
        input_text: The text entered by the user.

    Returns:
        A 5-tuple ``(label_name, lime_img, shap_fig, lime_html,
        shap_highlighted_text)``: the predicted label ('Positive' or
        'Negative'), the LIME bar chart as an image array, the SHAP bar chart
        as an image array, LIME's HTML rendering, and the input text as HTML
        with the top SHAP terms coloured.

    Raises:
        gr.Error: If the input contains no alphanumeric character.
    """
    max_num_features = 10

    # Nothing to explain without at least one token; LIME's perturbation
    # sampling is expected to fail on such input, so report it clearly instead.
    if not input_text or not any(ch.isalnum() for ch in input_text):
        raise gr.Error("Please enter some text to explain.")

    input_vector = vectorizer.transform([input_text])
    predicted_label = model.predict(input_vector)[0]
    label_name = 'Positive' if predicted_label == 1 else 'Negative'

    def predict_proba_for_lime(texts):
        # LIME passes raw strings; vectorise them before scoring.
        return model.predict_proba(vectorizer.transform(texts))

    lime_exp = lime_explainer.explain_instance(
        input_text, predict_proba_for_lime, num_features=max_num_features
    )
    lime_fig = lime_exp.as_pyplot_figure()
    try:
        lime_img = fig_to_nparray(lime_fig)
    finally:
        plt.close(lime_fig)

    lime_html = lime_exp.as_html()

    # One row was explained, so take the first (only) row of SHAP values.
    shap_values = shap_explainer.shap_values(input_vector)[0]
    feature_names = vectorizer.get_feature_names_out()

    shap_fig = _plot_top_shap_values(shap_values, feature_names, max_num_features)
    shap_highlighted_text = _highlight_text_shap(
        input_text, shap_values, feature_names, max_num_features
    )

    return label_name, lime_img, shap_fig, lime_html, shap_highlighted_text
|
|
def fig_to_nparray(fig):
    """Render a matplotlib figure to PNG in memory and return it as a NumPy array.

    The figure is saved into an in-memory buffer, re-opened with PIL, and the
    decoded pixels are copied into a new array.
    """
    with io.BytesIO() as png_stream:
        fig.savefig(png_stream, format='png')
        # Rewind so PIL reads the PNG from its first byte.
        png_stream.seek(0)
        with Image.open(png_stream) as picture:
            # np.array() decodes and copies the pixels before the buffer closes.
            return np.array(picture)
|
|
|
|
| |
# Input widget: a two-line free-text box.
_text_input = gr.Textbox(lines=2, placeholder="Enter your text here...")

# Output widgets, in the same order as the tuple returned by explain_text().
_result_views = [
    gr.Label(label="Predicted Label"),
    gr.Image(type="numpy", label="LIME Explanation"),
    gr.Image(type="numpy", label="SHAP Explanation"),
    gr.HTML(label="LIME Highlighted Text Explanation"),
    gr.HTML(label="SHAP Highlighted Text Explanation"),
]

# Wire the explanation function to the widgets.
iface = gr.Interface(
    fn=explain_text,
    inputs=_text_input,
    outputs=_result_views,
    title="LIME and SHAP Explanations",
    description="Enter a text sample to see its prediction and explanations using LIME and SHAP.",
)

# Start the web app.
iface.launch()