| | |
| | import torch |
| | from MultiTaskConvLSTM import ConvLSTMNetwork |
| | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score |
| | import torch |
| | import torch.nn as nn |
| | from tqdm.auto import tqdm |
| | from utils import ( |
| | mse, mae, nash_sutcliffe_efficiency, r2_score, pearson_correlation, |
| | spearman_correlation, percentage_error, percentage_bias, |
| | kendall_tau, spatial_correlation |
| | ) |
| | import torch.optim as optim |
# --- Runtime / model configuration ---------------------------------------

# Inference device; this script runs entirely on CPU.
device = 'cpu'

# Spatial grid dimensions of the data (H x W = 81 x 97 = 7857 grid points,
# matching the flattened last axis described in the dataloader notes below).
height = 81
width = 97

# Temporal window sizes: one input step, one forecast step.
set_lookback = 1
set_forecast_horizon = 1

batch_size = 16
time_steps_out = set_forecast_horizon
channels = 8

# Human-readable channel names for the input variables.
# NOTE(review): this list has 7 entries while `channels` is 8 — presumably one
# channel (e.g. the precipitation target itself) is unnamed here; confirm
# against the preprocessing pipeline.
variable_names = ['10 metre U wind component', '10 metre V wind component', '2 metre dewpoint temperature', '2 metre temperature', 'Total column rain water', 'Total precipitation', 'Time-integrated surface latent heat net flux']

# Multi-task ConvLSTM: one regression head and one classification head
# (see `evaluate` below, which unpacks two outputs from the forward pass).
model = ConvLSTMNetwork(
    input_dim=8 * set_lookback,
    hidden_dims=[8, 32, 64],
    kernel_size=(3,3),
    num_layers=3,
    output_channels=64 * set_forecast_horizon,
    batch_first=True
).to(device)

# Loss functions: MSE for the regression head, BCE for the (sigmoid-activated)
# classification head.
loss_fn = nn.MSELoss()
bce_loss_fn = nn.BCELoss()

# Optimizer is constructed for interface parity with training; it is not
# stepped anywhere in this evaluation script.
optimizer = optim.AdamW(model.parameters(), lr = 0.005)

# Restore trained weights from the checkpoint (trained without vegetation
# variables, per the filename).
checkpoint = torch.load("MultiTaskConvLSTM_no_veg_variables.pth", map_location = device)
model.load_state_dict(checkpoint['model_state_dict'])

# Redundant with the .to(device) above, but harmless.
model.to(device)

# Evaluation mode: disables dropout / uses running batch-norm statistics.
model.eval()

print("Model loaded successfully")

# Precipitation label threshold (mm/h) used when the 0/1 targets were built,
# and the channel index of precipitation in the full variable stack.
# NOTE(review): neither name is referenced later in this file — presumably
# used by the preprocessing that produced `y_zero`; verify before removing.
threshold = 0.1
precip_index = 10
| |
|
def evaluate(model, test_loader, reg_loss_fn, class_loss_fn, device, variable_names,
             height, width, class_threshold=0.7):
    """
    Evaluate the model on the test set for both regression and classification tasks.

    Runs the model over ``test_loader`` without gradients, accumulates the
    regression and classification losses, computes a suite of regression
    metrics (from ``utils``) and classification metrics (from scikit-learn)
    over the flattened predictions, prints everything, and saves the flattened
    tensors to the file ``'results'`` via ``torch.save``.

    Parameters
    ----------
    model : nn.Module
        Multi-task model whose forward pass returns
        ``(regression_output, classification_output)``.
    test_loader : iterable
        Yields ``(X, y, y_zero)`` batches shaped ``(B, T, C, H*W)``; see the
        dataloader-format notes elsewhere in this file.
    reg_loss_fn, class_loss_fn : callable
        Regression and classification loss functions (e.g. MSELoss, BCELoss).
    device : str or torch.device
        Device to run inference on.
    variable_names : list[str]
        Channel names; currently unused here, kept for interface compatibility.
    height, width : int
        Spatial grid dimensions used to unflatten the last axis (H*W -> H, W).
    class_threshold : float, optional
        Probability cutoff for binarizing classification outputs. Defaults to
        0.7, matching the previously hard-coded value.

    Returns
    -------
    tuple
        ``(test_total_loss, regression_metrics, classification_metrics)``.
    """
    model.eval()

    test_reg_loss = 0.0
    test_class_loss = 0.0
    test_total_loss = 0.0
    n_samples = 0  # total samples seen, for a correct per-sample loss average

    y_true_reg = []
    y_pred_reg = []
    y_pred_reg2 = []  # regression output gated by the classification head
    y_true_class = []
    y_pred_class = []

    with torch.no_grad():
        for X_test, y_test, y_zero_test in tqdm(test_loader, desc="Evaluating on Test Set"):
            X_test = X_test.to(device)
            y_test = y_test.to(device)
            y_zero_test = y_zero_test.to(device)

            # Unflatten the spatial axis: (B, T, C, H*W) -> (B, T, C, H, W).
            batch_size, time_steps_in, channels_in, grid_points = X_test.shape
            _, time_steps_out, channels_out, _ = y_test.shape
            X_test = X_test.view(batch_size, time_steps_in, channels_in, height, width)
            y_test = y_test.view(batch_size, time_steps_out, channels_out, height, width)
            y_zero_test = y_zero_test.view(batch_size, time_steps_out, channels_out, height, width)

            regression_output, classification_output = model(X_test)

            # Binarize the classification probabilities at the chosen cutoff.
            classification_predictions = (classification_output > class_threshold).float()

            reg_loss = reg_loss_fn(regression_output, y_test)
            class_loss = class_loss_fn(classification_output, y_zero_test)
            total_loss = reg_loss + class_loss

            # Classification-gated regression: keep the regression value where the
            # classifier predicts 0, otherwise substitute the class prediction (1.0).
            regression_output2 = torch.where(
                classification_predictions == 0, regression_output, classification_predictions
            )

            # Weight each batch loss by its size and divide by n_samples at the
            # end. (BUG FIX: the original divided by len(test_loader) — the
            # number of *batches* — which misweights any partial final batch.)
            test_reg_loss += reg_loss.item() * X_test.size(0)
            test_class_loss += class_loss.item() * X_test.size(0)
            test_total_loss += total_loss.item() * X_test.size(0)
            n_samples += X_test.size(0)

            y_true_reg.append(y_test.cpu())
            y_pred_reg.append(regression_output.cpu())
            y_pred_reg2.append(regression_output2.cpu())
            y_true_class.append(y_zero_test.cpu())
            y_pred_class.append(classification_output.cpu())

    # Per-sample averages (guard against an empty loader).
    denom = max(n_samples, 1)
    test_reg_loss /= denom
    test_class_loss /= denom
    test_total_loss /= denom

    print(f"Test Regression Loss: {test_reg_loss:.16f}")
    print(f"Test Classification Loss: {test_class_loss:.16f}")
    print(f"Test Total Loss: {test_total_loss:.16f}")

    y_true_reg_flat = torch.cat(y_true_reg, dim=0).flatten()
    y_pred_reg_flat = torch.cat(y_pred_reg, dim=0).flatten()
    y_pred_reg2_flat = torch.cat(y_pred_reg2, dim=0).flatten()
    y_true_class_flat = torch.cat(y_true_class, dim=0).flatten()
    y_pred_class_flat = torch.cat(y_pred_class, dim=0).flatten()

    # Regression metrics. (BUG FIX: the original dict listed "NSE" twice; the
    # duplicate key silently overwrote the first entry and has been removed.)
    regression_metrics = {
        "MSE": mse(y_true_reg_flat, y_pred_reg_flat),
        "MAE": mae(y_true_reg_flat, y_pred_reg_flat),
        "NSE": nash_sutcliffe_efficiency(y_true_reg_flat, y_pred_reg_flat),
        "R2": r2_score(y_true_reg_flat, y_pred_reg_flat),
        "Pearson": pearson_correlation(y_true_reg_flat, y_pred_reg_flat),
        "Spearman": spearman_correlation(y_true_reg_flat, y_pred_reg_flat),
        "Percentage Error": percentage_error(y_true_reg_flat, y_pred_reg_flat),
        "Percentage Bias": percentage_bias(y_true_reg_flat, y_pred_reg_flat),
        "Kendall Tau": kendall_tau(y_true_reg_flat, y_pred_reg_flat),
        "Spatial Correlation": spatial_correlation(y_true_reg_flat, y_pred_reg_flat),
    }

    print("\nRegression Metrics:")
    for metric, value in regression_metrics.items():
        print(f"{metric}: {value:.16f}")

    # Classification metrics use the same cutoff as the in-loop predictions,
    # computed once instead of being re-derived per metric.
    hard_predictions = (y_pred_class_flat > class_threshold)
    classification_metrics = {
        "Accuracy": accuracy_score(y_true_class_flat, hard_predictions),
        "Precision": precision_score(y_true_class_flat, hard_predictions),
        "Recall": recall_score(y_true_class_flat, hard_predictions),
        "F1": f1_score(y_true_class_flat, hard_predictions),
        # ROC-AUC is threshold-free; it takes the raw probabilities.
        "ROC-AUC": roc_auc_score(y_true_class_flat, y_pred_class_flat),
    }

    print("\nClassification Metrics:")
    for metric, value in classification_metrics.items():
        print(f"{metric}: {value:.16f}")

    # Persist flattened tensors for offline analysis. The gated regression
    # output is now saved too (previously it was computed but discarded).
    torch.save({
        'y_true_reg': y_true_reg_flat,
        'y_pred_reg': y_pred_reg_flat,
        'y_pred_reg_masked': y_pred_reg2_flat,
        'y_true_class': y_true_class_flat,
        'y_pred_class': y_pred_class_flat,
    }, 'results')

    return test_total_loss, regression_metrics, classification_metrics
| |
|
| |
|
| | """ |
| | EXPECTED DATALOADER BATCH FORMAT (normalized_test_data): |
| | |
| | Each batch must be a tuple: (X_batch, y_batch, y_zero_batch) |
| | |
| | X_batch contains the previous hours variables. y_batch contains the next hour's precipitation. |
| | y_zero_batch contains the next hour's precipitation thresholded as 0 for precipitation <= 0.1 mm/h and |
| | 1 for precipitation > 0.1 mm/h. |
| | |
| | Shapes BEFORE reshaping inside `evaluate`: |
| | X_batch: (B, T_in, C_in, G) # G = H*W = 81*97 = 7857 |
| | y_batch: (B, T_out, C_out, G) |
| | y_zero_batch: (B, T_out, C_out, G) # binary 0/1 "zero-precip" targets |
| | |
| | If your preprocessing produces (B,T, C, H, W), reshape to (B, T, C, H*W) before inference. |
| | |
| | DTypes: |
| | X_batch, y_batch: torch.float32 |
| | y_zero_batch: torch.float32 (will be used with BCELoss) |
| | |
| | Reshaping done in 'evaluate': |
| | X_test = X_batch.view(B, T_in, C_in, H, W) -> (B, T_in, C_in, 81, 97) |
| | y_test = y_batch.view(B, T_out, C_out, H, W) -> (B, T_out, C_out, 81, 97) |
| | y_zero_test = y_zero_batch.view(B, T_out, C_out, H, W) |
| | |
| | Model input: |
| | model expects X_test shaped (B, T_in, input_dim, H, W) |
| | where input_dim == 8 * set_lookback (with set_lookback=1 -> input_dim=8) |
| | |
| | Notes: |
| | • Make sure G == H*W (i.e., 7857 for 81x97). |
| | • C_out for precipitation should be 1 (one target channel), and y_zero_batch |
| | is the 0/1 mask for “zero precipitation” at each pixel & time. |
| | • y_zero_batch should be probabilities/labels in {0,1} for BCELoss. |
| | """ |
| |
|
# Load the preprocessed test set (inputs normalized, vegetation variables
# excluded per the filename). The saved object is iterated directly as the
# test loader, yielding (X, y, y_zero) batches as documented above.
normalized_test_data = torch.load("data/normalized_test_data_no_veg_input.pth")

# Run the full evaluation pass: prints losses and metrics, writes the
# flattened predictions to the 'results' file, and returns the summaries.
test_total_loss, regression_metrics, classification_metrics = evaluate(
    model=model,
    test_loader=normalized_test_data,
    reg_loss_fn=loss_fn,
    class_loss_fn=bce_loss_fn,
    device=device,
    variable_names=variable_names,
    height=height,
    width=width,
)