Bachstelze commited on
Commit
de7924c
·
1 Parent(s): 76e50c6

save plot figure

Browse files
Files changed (1) hide show
  1. test/test_model.py +26 -2
test/test_model.py CHANGED
@@ -3,10 +3,30 @@ import pandas as pd
3
  from sklearn.linear_model import LinearRegression
4
  from sklearn.metrics import mean_absolute_error, r2_score
5
  import re # For using regular expressions
 
 
6
 
7
  train_path = "../Datasets_all/A2_dataset_80.csv"
8
  test_path = "../Datasets_all/A2_dataset_20.csv"
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def extract_missing_feature(error_message):
11
  # Use regex to find feature names in the ValueError message
12
  match = re.search(r"Feature names unseen at fit time:\s*-\s*(.+)", error_message)
@@ -39,8 +59,8 @@ def load_and_evaluate_model(model_path):
39
  X_test = test_df[features_cols]
40
  y_test = test_df[target_col]
41
 
42
- #define y_pred
43
- y_pred = 0
44
 
45
  # Continue to predict until no ValueErrors occur
46
  while True:
@@ -81,6 +101,10 @@ def load_and_evaluate_model(model_path):
81
  # Save predictions to CSV
82
  test_df["Predicted_AimoScore"] = y_pred
83
  test_df.to_csv("predicted_test.csv", index=False)
 
 
 
 
84
 
85
 
86
  if __name__ == "__main__":
 
3
  from sklearn.linear_model import LinearRegression
4
  from sklearn.metrics import mean_absolute_error, r2_score
5
  import re # For using regular expressions
6
+ import matplotlib.pyplot as plt
7
+ import datetime
8
 
9
  train_path = "../Datasets_all/A2_dataset_80.csv"
10
  test_path = "../Datasets_all/A2_dataset_20.csv"
11
 
12
+ def save_prediction_plot(y_test, y_test_pred_baseline, baseline_test_r2):
13
+ # Visualize baseline predictions
14
+ fig, axes = plt.subplots(figsize=(5, 5))
15
+
16
+ # Actual vs Predicted
17
+ axes.scatter(y_test, y_test_pred_baseline, alpha=0.5)
18
+ axes.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--', lw=2)
19
+ axes.set_xlabel('Actual AimoScore')
20
+ axes.set_ylabel('Predicted AimoScore')
21
+ axes.set_title(f'Baseline: Actual vs Predicted (R²={baseline_test_r2:.4f})')
22
+ axes.grid(True, alpha=0.3)
23
+
24
+ # Save the figure
25
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") # e.g., 20260130_143210
26
+ fig_path = f"baseline_actual_vs_predicted_{timestamp}.png"
27
+ plt.savefig(fig_path, dpi=300, bbox_inches='tight')
28
+ print(f"Figure saved to {fig_path}")
29
+
30
  def extract_missing_feature(error_message):
31
  # Use regex to find feature names in the ValueError message
32
  match = re.search(r"Feature names unseen at fit time:\s*-\s*(.+)", error_message)
 
59
  X_test = test_df[features_cols]
60
  y_test = test_df[target_col]
61
 
62
+ #define y_pred and r2
63
+ y_pred, r2 = 0, 0
64
 
65
  # Continue to predict until no ValueErrors occur
66
  while True:
 
101
  # Save predictions to CSV
102
  test_df["Predicted_AimoScore"] = y_pred
103
  test_df.to_csv("predicted_test.csv", index=False)
104
+ else:
105
+ print("no predictions!!!")
106
+
107
+ save_prediction_plot(y_test, y_pred, r2)
108
 
109
 
110
  if __name__ == "__main__":