| import zipfile |
| import os |
| import plotly.express as px |
| import plotly.graph_objects as go |
| from torch.utils.data.dataloader import DataLoader as dl |
| import yaml |
| from io import StringIO |
| import torch as t |
| import numpy as np |
| import pandas as pd |
| from torch.utils.data import Dataset as torch_dset |
| from PIL import Image |
| import torchvision.transforms.functional as tvfunc |
| import json |
| from matplotlib import pyplot as plt |
| import matplotlib.patches as patches |
| from matplotlib.font_manager import FontProperties |
| import pathlib as pl |
| import matplotlib as mpl |
| import streamlit as st |
| from streamlit.runtime.uploaded_file_manager import UploadedFile |
| import einops as eo |
| import copy |
|
|
| |
| from tqdm.auto import tqdm |
| import time |
| import requests |
|
|
| from matplotlib.patches import Rectangle |
| from matplotlib import font_manager |
| from models import LitModel, EnsembleModel |
| from loss_functions import corn_label_from_logits |
| import classic_correction_algos as calgo |
| import analysis_funcs as anf |
|
|
# Folder that export_csv writes its CSV/JSON output into.
TEMP_FOLDER = pl.Path("results")
# Names of every font registered with matplotlib's font manager.
AVAILABLE_FONTS = [x.name for x in font_manager.fontManager.ttflist]
# Folder that rendered plot images are written into.
PLOTS_FOLDER = pl.Path("plots")
TEMP_FIGURE_STIMULUS_PATH = PLOTS_FOLDER / "temp_matplotlib_plot_stimulus.png"
# NOTE(review): built exactly like AVAILABLE_FONTS; presumably kept for older callers.
all_fonts = [x.name for x in font_manager.fontManager.ttflist]
# Import-time side effect: switch matplotlib to the "agg" backend.
mpl.use("agg")


# Folder name for model files (presumably the DIST checkpoints; not read in this section).
DIST_MODELS_FOLDER = pl.Path("models")
# Per-channel mean/std passed to tvfunc.normalize in DSet.__getitem__.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# NOTE(review): same path as PLOTS_FOLDER; this is the name get_save_path reads.
gradio_plots = pl.Path("plots")


# Substrings marking a line as an event/message line; trial_to_dfs skips any line
# containing one of these when collecting raw gaze samples.
event_strs = [
    "EFIX",
    "EFIX R",
    "EFIX L",
    "SSACC",
    "ESACC",
    "SFIX",
    "MSG",
    "SBLINK",
    "EBLINK",
    "BUTTON",
    "INPUT",
    "END",
    "START",
    "DISPLAY ON",
]
# Human-readable description and field pattern for a subset of the event keywords.
names_dict = {
    "SSACC": {"Descr": "Start of Saccade", "Pattern": "SSACC <eye > <stime>"},
    "ESACC": {
        "Descr": "End of Saccade",
        "Pattern": "ESACC <eye > <stime> <etime > <dur> <sxp > <syp> <exp > <eyp> <ampl > <pv >",
    },
    "SFIX": {"Descr": "Start of Fixation", "Pattern": "SFIX <eye > <stime>"},
    "EFIX": {"Descr": "End of Fixation", "Pattern": "EFIX <eye > <stime> <etime > <dur> <axp > <ayp> <aps >"},
    "SBLINK": {"Descr": "Start of Blink", "Pattern": "SBLINK <eye > <stime>"},
    "EBLINK": {"Descr": "End of Blink", "Pattern": "EBLINK <eye > <stime> <etime > <dur>"},
    "DISPLAY ON": {"Descr": "Actual start of Trial", "Pattern": "DISPLAY ON"},
}
# Metadata keywords; index 2 ("FRAMERATE") is looked up by asc_lines_to_trials_by_trail_id.
metadata_strs = ["DISPLAY COORDS", "GAZE_COORDS", "FRAMERATE"]


# Names of the selectable correction algorithms. The same list object is also stored in
# the Streamlit session state at import time.
ALGO_CHOICES = st.session_state["ALGO_CHOICES"] = [
    "warp",
    "regress",
    "compare",
    "attach",
    "segment",
    "split",
    "stretch",
    "chain",
    "slice",
    "cluster",
    "merge",
    "Wisdom_of_Crowds",
    "DIST",
    "DIST-Ensemble",
    "Wisdom_of_Crowds_with_DIST",
    "Wisdom_of_Crowds_with_DIST_Ensemble",
]
# Qualitative plotly colour palette.
COLORS = px.colors.qualitative.Alphabet
|
|
|
|
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serialises numpy values, paths and uploaded files.

    Based on https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable
    """

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``.

        * ``np.ndarray`` -> nested list (``tolist``)
        * numpy scalar (``np.generic``, e.g. ``np.float32``/``np.int64``) -> Python scalar
        * ``pathlib.Path`` / ``UploadedFile`` -> ``str(obj)``

        Anything else is delegated to ``json.JSONEncoder.default``, which raises
        ``TypeError``.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # numpy scalars that are not Python float/int subclasses would otherwise
        # reach the base implementation and raise TypeError.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, pl.Path) or isinstance(obj, UploadedFile):
            return str(obj)
        return super().default(obj)
|
|
|
|
class DSet(torch_dset):
    """Dataset yielding fixation sequences with optional extra inputs.

    Each item is a tuple assembled in this fixed order, skipping parts that are not
    configured::

        (in_sequence[i], [chars_center_coords_padded[i]], [image], [attention_mask],
         out_categories[i])

    * the character coordinates are included when ``chars_center_coords_padded`` is not None
    * the image is included when ``return_images_for_conv`` is True
    * the attention mask is included when ``padding_list`` is not None
    """

    def __init__(
        self,
        in_sequence: t.Tensor,
        chars_center_coords_padded: t.Tensor,
        out_categories: t.Tensor,
        trialslist: list,
        padding_list: list = None,
        padding_at_end: bool = False,
        return_images_for_conv: bool = False,
        im_partial_string: str = "fixations_chars_channel_sep",
        input_im_shape=(224, 224),
    ) -> None:
        """Store the inputs; see the class docstring for how items are assembled.

        Args:
            in_sequence: per-trial input sequences, indexable by item index.
            chars_center_coords_padded: per-trial character coordinates, or None.
            out_categories: per-trial targets.
            trialslist: trial dicts; each needs a "plot_file" entry when
                ``return_images_for_conv`` is True.
            padding_list: number of padded positions per trial, or None for no mask.
            padding_at_end: True if the padding sits at the end of the sequence,
                False if it sits at the start.
            return_images_for_conv: whether to load and return the trial image.
            im_partial_string: replaces "fixations_words" in each plot file name.
            input_im_shape: image size as (height, width); any sequence is accepted.
        """
        super().__init__()

        self.in_sequence = in_sequence
        self.chars_center_coords_padded = chars_center_coords_padded
        self.out_categories = out_categories
        self.padding_list = padding_list
        self.padding_at_end = padding_at_end
        self.trialslist = trialslist
        self.return_images_for_conv = return_images_for_conv
        # Always keep a list so the size comparison in _load_image works whether the
        # caller passed a list or a tuple.
        self.input_im_shape = list(input_im_shape)
        if return_images_for_conv:
            self.im_partial_string = im_partial_string
            self.plot_files = [
                str(x["plot_file"]).replace("fixations_words", im_partial_string) for x in self.trialslist
            ]

    def _load_image(self, index):
        """Load, resize if needed, and normalise the image belonging to item ``index``."""
        # The context manager closes the file; to_tensor reads the pixels before that.
        with Image.open(self.plot_files[index]) as im:
            # PIL reports (width, height); input_im_shape is (height, width).
            if [im.size[1], im.size[0]] != self.input_im_shape:
                im = tvfunc.resize(im, self.input_im_shape)
            return tvfunc.normalize(tvfunc.to_tensor(im), IMAGENET_MEAN, IMAGENET_STD)

    def _attention_mask(self, index):
        """Build a long mask with 1 for real positions and 0 for padded ones."""
        n_pad = self.padding_list[index]
        mask = t.ones(self.in_sequence[index].shape[:-1], dtype=t.long)
        if self.padding_at_end:
            # Guarded because mask[-0:] would select (and zero) the whole mask.
            if n_pad > 0:
                mask[-n_pad:] = 0
        else:
            mask[:n_pad] = 0
        return mask

    def __getitem__(self, index):
        """Return the tuple for item ``index`` (layout described in the class docstring)."""
        parts = [self.in_sequence[index]]
        if self.chars_center_coords_padded is not None:
            parts.append(self.chars_center_coords_padded[index])
        if self.return_images_for_conv:
            parts.append(self._load_image(index))
        if self.padding_list is not None:
            parts.append(self._attention_mask(index))
        parts.append(self.out_categories[index])
        return tuple(parts)

    def __len__(self):
        """Return the number of items (first dimension of ``in_sequence``)."""
        if isinstance(self.in_sequence, t.Tensor):
            return self.in_sequence.shape[0]
        else:
            return len(self.in_sequence)
|
|
|
|
def download_url(url, target_filename):
    """Download ``url`` and write the response body to ``target_filename``.

    Args:
        url: address to fetch with an HTTP GET.
        target_filename: path of the file to (over)write in binary mode.

    Returns:
        0 on success (kept for existing callers).

    Raises:
        requests.HTTPError: if the server answers with an error status, so an error
            page is never saved as if it were the requested file.
        requests.RequestException: on connection problems or timeout.
    """
    # The timeout bounds connection setup and each wait for data, not the whole download.
    r = requests.get(url, timeout=60)
    r.raise_for_status()
    with open(target_filename, "wb") as f:
        f.write(r.content)
    return 0
|
|
|
|
def asc_to_trial_ids(asc_file, close_gap_between_words=True):
    """Parse an .asc file and index its paragraph trials by trial id.

    Args:
        asc_file: path or uploaded-file object accepted by file_to_trials_and_lines.
        close_gap_between_words: forwarded to file_to_trials_and_lines.

    Returns:
        ``(trials_by_ids, lines)``: a dict mapping each paragraph trial's "trial_id" to
        its trial dict, and the list of lines read from the file.
    """
    session = st.session_state
    if "logger" in session:
        session["logger"].debug("asc_to_trial_ids entered")

    # The file is always decoded as ISO-8859-15 (first entry of the candidate list).
    encodings = ["ISO-8859-15", "UTF-8"]
    asc_encoding = encodings[0]
    trials_dict, lines = file_to_trials_and_lines(
        asc_file, asc_encoding, close_gap_between_words=close_gap_between_words
    )

    trials_by_ids = {}
    for trial_index in trials_dict["paragraph_trials"]:
        trial_entry = trials_dict[trial_index]
        trials_by_ids[trial_entry["trial_id"]] = trial_entry

    if hasattr(asc_file, "name") and "logger" in session:
        session["logger"].info(f"Found {len(trials_by_ids)} trials in {asc_file.name}.")
    return trials_by_ids, lines
|
|
|
|
def get_trials_list(asc_file=None, close_gap_between_words=True):
    """Load the paragraph trials of an .asc file.

    Args:
        asc_file: file to parse. When None, ``st.session_state["single_asc_file"]`` is
            used if it is present and not None.
        close_gap_between_words: forwarded to asc_to_trial_ids.

    Returns:
        ``(trial_keys, trials_by_ids, lines, asc_file)`` on success. When no file is
        available a single ``None`` is returned instead of a tuple, so callers must
        check before unpacking.
    """
    if "logger" in st.session_state:
        st.session_state["logger"].debug("get_trials_list entered")

    # Identity check: `== None` would call a custom __eq__ on file-like objects.
    if asc_file is None:
        if "single_asc_file" in st.session_state.keys() and st.session_state["single_asc_file"] is not None:
            asc_file = st.session_state["single_asc_file"]
        else:
            if "logger" in st.session_state:
                st.session_state["logger"].warning("Asc file is None")
            return None

    if hasattr(asc_file, "name"):
        if "logger" in st.session_state:
            st.session_state["logger"].info(f"get_trials_list entered with asc_file {asc_file.name}")

    trials_by_ids, lines = asc_to_trial_ids(asc_file, close_gap_between_words=close_gap_between_words)
    trial_keys = list(trials_by_ids.keys())

    return trial_keys, trials_by_ids, lines, asc_file
|
|
|
|
def save_trial_to_json(trial, savename):
    """Write ``trial`` to ``savename`` as indented UTF-8 JSON.

    Side effect: the "dffix" entry, if present, is removed from ``trial`` itself before
    dumping, so the caller's dict loses that key too.
    numpy arrays, paths and uploaded files are serialised via NumpyEncoder.
    """
    trial.pop("dffix", None)
    with open(savename, "w", encoding="utf-8") as json_file:
        json.dump(trial, json_file, ensure_ascii=False, indent=4, cls=NumpyEncoder)
|
|
|
|
def export_csv(dffix, trial):
    """Save the fixation table as CSV and the trial info as JSON in TEMP_FOLDER.

    Args:
        dffix: fixation DataFrame, or a dict holding it under the "value" key.
        trial: trial dict; "trial_id" and "fname" are used to build the file names.
            Note that save_trial_to_json removes a "dffix" entry from this dict.

    Returns:
        ``(csv_name, trial_name)``: the paths (as strings) of the written CSV and JSON.
    """
    if isinstance(dffix, dict):
        dffix = dffix["value"]
    trial_id = trial["trial_id"]
    # Create the output folder on demand; to_csv fails if the directory is missing.
    TEMP_FOLDER.mkdir(parents=True, exist_ok=True)
    savename = TEMP_FOLDER.joinpath(pl.Path(trial["fname"]).stem)
    trial_name = f"{savename}_{trial_id}_trial_info.json"
    csv_name = f"{savename}_{trial_id}.csv"
    dffix.to_csv(csv_name)
    if "logger" in st.session_state:
        st.session_state["logger"].info(f"Saved processed data as {csv_name}")
    save_trial_to_json(trial, trial_name)
    if "logger" in st.session_state:
        st.session_state["logger"].info(f"Saved processed trial data as {trial_name}")

    return csv_name, trial_name
|
|
|
|
def get_all_classic_preds(dffix, trial, classic_algos_cfg):
    """Run every configured classic algorithm and collect its corrected y values.

    Works on a deep copy of ``classic_algos_cfg``, so the caller's configuration is not
    modified by the algorithms. Each algorithm is applied to the DataFrame returned by
    the previous one.

    Returns:
        ``(dffix, corrections)`` where ``corrections`` holds one array per algorithm,
        taken from the ``y_<algo>`` column, in configuration order.
    """
    cfg_snapshot = copy.deepcopy(classic_algos_cfg)
    corrections = []
    for algo in cfg_snapshot:
        dffix = calgo.apply_classic_algo(dffix, trial, algo, cfg_snapshot[algo])
        corrected_column = dffix.loc[:, f"y_{algo}"]
        corrections.append(np.asarray(corrected_column))
    return dffix, corrections
|
|
|
|
def apply_woc(dffix, trial, corrections, algo_choice):
    """Add the wisdom-of-the-crowd result for ``algo_choice`` to ``dffix``.

    Writes three columns: ``y_<algo_choice>`` (the combined y values),
    ``y_<algo_choice>_correction`` (difference to the original "y", rounded to one
    decimal) and ``line_num_y_<algo_choice>`` (position of each combined y value in
    ``trial["y_char_unique"]``).

    Returns:
        The same DataFrame with the new columns.
    """
    crowd_y = calgo.wisdom_of_the_crowd(corrections)
    dffix.loc[:, f"y_{algo_choice}"] = crowd_y

    shift = dffix.loc[:, f"y_{algo_choice}"] - dffix.loc[:, "y"]
    dffix[f"y_{algo_choice}_correction"] = shift.round(1)

    line_positions = trial["y_char_unique"]
    line_numbers = []
    for y_value in crowd_y:
        line_numbers.append(line_positions.index(y_value))
    dffix.loc[:, f"line_num_y_{algo_choice}"] = line_numbers
    return dffix
|
|
|
|
def _summarise_spacing(diffs, allow_multiple_values):
    """Reduce an array of unique coordinate differences to the reported spacing."""
    if len(diffs) == 0:
        # Fewer than two coordinates: there is no spacing to measure.
        return 0
    if len(diffs) == 1:
        return diffs[0]
    if allow_multiple_values:
        return diffs
    return np.min(diffs)


def calc_xdiff_ydiff(line_xcoords_no_pad, line_ycoords_no_pad, line_heights, allow_multiple_values=False):
    """Estimate the horizontal and vertical spacing of character coordinates.

    Args:
        line_xcoords_no_pad: x coordinates (the caller in this file passes sorted unique values).
        line_ycoords_no_pad: y coordinates, same convention.
        line_heights: character heights; the first one is used as the vertical spacing
            when all y coordinates are identical (single line of text).
        allow_multiple_values: if True and several distinct spacings exist, return the
            array of all of them instead of the smallest.

    Returns:
        ``(x_diff, y_diff)``. A spacing is 0 when fewer than two coordinates are given
        (previously a single x coordinate raised ValueError from ``np.min``).
    """
    x_diff = _summarise_spacing(np.unique(np.diff(line_xcoords_no_pad)), allow_multiple_values)

    if np.unique(line_ycoords_no_pad).shape[0] == 1:
        return x_diff, line_heights[0]

    y_diff = _summarise_spacing(np.unique(np.diff(line_ycoords_no_pad)), allow_multiple_values)
    return x_diff, y_diff
|
|
|
|
def add_words(trial, close_gap_between_words=True):
    """Group the characters in ``trial["chars_list"]`` into word boxes.

    Characters are visited in order. A word is closed whenever the current character
    is a space or one of ``, ; . :``, starts further left than the previous character,
    or is the last character of the trial. The closed word spans from the current word
    start up to, but not including, the current character.

    Args:
        trial: dict with a "chars_list" of character dicts (keys used here: "char",
            "char_xmin", "char_xmax", "char_ymin", "char_ymax", "char_x_center",
            "char_y_center", "assigned_line").
        close_gap_between_words: if True, neighbouring words on the same line are
            widened so their boxes meet halfway across the gap between them.

    Returns:
        List of word dicts with keys "word", "word_xmin", "word_xmax", "word_ymin",
        "word_ymax", "word_x_center", "word_y_center" and "assigned_line".
    """
    chars_list_reconstructed = []
    words_list = []
    word_start_idx = 0
    chars_df = pd.DataFrame(trial["chars_list"])
    chars_df["char_width"] = chars_df.char_xmax - chars_df.char_xmin
    # NOTE(review): space_width is computed but never used in this function.
    space_width = chars_df.loc[chars_df["char"] == " ", "char_width"].mean()

    for idx, char_dict in enumerate(trial["chars_list"]):
        # NOTE(review): on_line_num is assigned but never read.
        on_line_num = char_dict["assigned_line"]
        chars_list_reconstructed.append(char_dict)
        if (
            char_dict["char"] in [" ", ",", ";", ".", ":"]
            or (
                # Current character starts left of the previous one (presumably a line wrap).
                len(chars_list_reconstructed) > 2
                and (chars_list_reconstructed[-1]["char_xmin"] < chars_list_reconstructed[-2]["char_xmin"])
            )
            or len(chars_list_reconstructed) == len(trial["chars_list"])
        ):
            # NOTE(review): triggered is set here and in the else branch but never read.
            triggered = True
            word_xmin = chars_list_reconstructed[word_start_idx]["char_xmin"]
            # The word ends at the character before the current one.
            # NOTE(review): [-2] raises IndexError when the trial has a single character.
            word_xmax = chars_list_reconstructed[-2]["char_xmax"]
            # Vertical extent is taken from the first character of the word.
            word_ymin = chars_list_reconstructed[word_start_idx]["char_ymin"]
            word_ymax = chars_list_reconstructed[word_start_idx]["char_ymax"]
            word_x_center = (word_xmax - word_xmin) / 2 + word_xmin
            word_y_center = (word_ymax - word_ymin) / 2 + word_ymin
            # Join all characters from the word start up to (excluding) the current one.
            word = "".join(
                [
                    chars_list_reconstructed[idx]["char"]
                    for idx in range(word_start_idx, len(chars_list_reconstructed) - 1)
                ]
            )
            assigned_line = chars_list_reconstructed[word_start_idx]["assigned_line"]

            word_dict = dict(
                word=word,
                word_xmin=word_xmin,
                word_xmax=word_xmax,
                word_ymin=word_ymin,
                word_ymax=word_ymax,
                word_x_center=word_x_center,
                word_y_center=word_y_center,
                assigned_line=assigned_line,
            )
            # A space is dropped; any other trigger character starts the next word.
            if char_dict["char"] != " ":
                word_start_idx = idx
            else:
                word_start_idx = idx + 1
            words_list.append(word_dict)
        else:
            triggered = False
    # The last word excludes the final character (see the join above). If that
    # character differs from the last letter of the last word, it is appended as its
    # own one-character word.
    # NOTE(review): if both letters are equal the final character is not added at all,
    # and an empty last word makes [-1] raise IndexError - confirm this is intended.
    last_letter_in_word = word_dict["word"][-1]
    last_letter_in_chars_list_reconstructed = char_dict["char"]
    if last_letter_in_word != last_letter_in_chars_list_reconstructed:
        word_dict = dict(
            word=char_dict["char"],
            word_xmin=char_dict["char_xmin"],
            word_xmax=char_dict["char_xmax"],
            word_ymin=char_dict["char_ymin"],
            word_ymax=char_dict["char_ymax"],
            word_x_center=char_dict["char_x_center"],
            word_y_center=char_dict["char_y_center"],
            # Reuses the line of the most recently closed word.
            assigned_line=assigned_line,
        )
        words_list.append(word_dict)

    if close_gap_between_words:
        for widx in range(1, len(words_list)):
            if words_list[widx]["assigned_line"] == words_list[widx - 1]["assigned_line"]:
                # Split the horizontal gap evenly between the two neighbouring words.
                word_sep_half_width = (words_list[widx]["word_xmin"] - words_list[widx - 1]["word_xmax"]) / 2
                words_list[widx - 1]["word_xmax"] = words_list[widx - 1]["word_xmax"] + word_sep_half_width
                words_list[widx]["word_xmin"] = words_list[widx]["word_xmin"] - word_sep_half_width

    return words_list
|
|
|
|
def asc_lines_to_trials_by_trail_id(
    lines: list, paragraph_trials_only=False, fname: str = "", close_gap_between_words=True
) -> dict:
    """Split the lines of an .asc file into trials and collect per-trial metadata.

    A first pass over all lines finds trial boundaries ("TRIALID" ... "TRIAL_RESULT" or
    "stop_trial"), display coordinates and the frame rate. A second pass over each
    retained trial extracts timing markers and the character boxes ("REGION CHAR"
    lines), from which line assignment, spacing and word boxes are derived.

    Args:
        lines: lines of the .asc file.
        paragraph_trials_only: if True, every entry except "paragraph_trials" and the
            paragraph trials themselves is dropped before the second pass.
        fname: file name stored in every trial; an object with a ``name`` attribute is
            replaced by that attribute.
        close_gap_between_words: forwarded to add_words.

    Returns:
        Dict keyed by integer trial index (one dict per trial) plus the bookkeeping
        keys "paragraph_trials", "display_coords", "fps" and "max_trial_idx". When
        ``paragraph_trials_only`` is True and no paragraph trial exists, the dict is
        returned early without the last three keys.
    """
    if hasattr(fname, "name"):
        fname = fname.name
    # -999 marks "not found yet" for both values.
    fps = -999
    display_coords = -999
    trials_dict = dict(paragraph_trials=[], paragraph_trial_IDs=[])
    trial_idx = -1
    removed_trial_ids = []
    # ---- Pass 1: locate trials and file-level metadata ----
    for idx, l in enumerate(lines):
        parts = l.strip().split(" ")
        if "TRIALID" in l:
            trial_id = parts[-1]
            trial_idx += 1
            # The first letter of the id decides the trial type; anything that is not
            # "F" or "P" counts as a paragraph trial.
            if trial_id[0] == "F":
                trial_is = "question"
            elif trial_id[0] == "P":
                trial_is = "practice"
            else:
                trial_is = "paragraph"
                trials_dict["paragraph_trials"].append(trial_idx)
                trials_dict["paragraph_trial_IDs"].append(trial_id)
            trials_dict[trial_idx] = dict(trial_id=trial_id, trial_id_idx=idx, trial_is=trial_is, filename=fname)
            last_trial_skipped = False

        elif "TRIAL_RESULT" in l or "stop_trial" in l:
            # NOTE(review): raises KeyError if this line appears before any TRIALID.
            trials_dict[trial_idx]["trial_result_idx"] = idx
            # The timestamp is the part after the tab in the first space-separated token.
            trials_dict[trial_idx]["trial_result_timestamp"] = int(parts[0].split("\t")[1])
            if len(parts) > 2:
                trials_dict[trial_idx]["trial_result_number"] = int(parts[2])
        # Only the first coordinates line is used: afterwards display_coords is a tuple.
        elif "DISPLAY COORDS" in l and isinstance(display_coords, int):
            display_coords = (float(parts[-4]), float(parts[-3]), float(parts[-2]), float(parts[-1]))
        elif "GAZE_COORDS" in l and isinstance(display_coords, int):
            display_coords = (float(parts[-4]), float(parts[-3]), float(parts[-2]), float(parts[-1]))
        elif "FRAMERATE" in l:
            l_idx = parts.index(metadata_strs[2])
            fps = float(parts[l_idx + 1])
        elif "TRIAL ABORTED" in l or "TRIAL REPEATED" in l:
            # Step the index back once per trial so the next TRIALID reuses (overwrites)
            # the slot of the aborted/repeated trial.
            # NOTE(review): trial_is / trial_id / last_trial_skipped are undefined if
            # this line occurs before the first TRIALID.
            if not last_trial_skipped:
                if trial_is == "paragraph":
                    trials_dict["paragraph_trials"].remove(trial_idx)
                trial_idx -= 1
                removed_trial_ids.append(trial_id)
                last_trial_skipped = True

    if paragraph_trials_only:
        # Iterate over a copy because entries are popped from the original.
        trials_dict_temp = trials_dict.copy()
        for k in trials_dict_temp.keys():
            # Also removes "paragraph_trial_IDs" and all non-paragraph trials.
            if k not in ["paragraph_trials"] + trials_dict_temp["paragraph_trials"]:
                trials_dict.pop(k)
        if len(trials_dict_temp["paragraph_trials"]):
            trial_idx = trials_dict_temp["paragraph_trials"][-1]
        else:
            return trials_dict
    trials_dict["display_coords"] = display_coords
    trials_dict["fps"] = fps
    trials_dict["max_trial_idx"] = trial_idx
    enum = trials_dict["paragraph_trials"] if "paragraph_trials" in trials_dict.keys() else range(len(trials_dict))
    # ---- Pass 2: per-trial timing markers and character boxes ----
    for trial_idx in enum:
        if trial_idx not in trials_dict.keys():
            continue
        chars_list = []
        if "display_coords" not in trials_dict[trial_idx].keys():
            trials_dict[trial_idx]["display_coords"] = trials_dict["display_coords"]
        trial_start_idx = trials_dict[trial_idx]["trial_id_idx"]
        # NOTE(review): raises KeyError for a trial without a TRIAL_RESULT/stop_trial line.
        trial_end_idx = trials_dict[trial_idx]["trial_result_idx"]
        trial_lines = lines[trial_start_idx:trial_end_idx]
        for idx, l in enumerate(trial_lines):
            parts = l.strip().split(" ")
            if "START" in l and " MSG" not in l:
                # Samples are taken to begin 7 lines after the START line
                # (presumably skipping the recording header - TODO confirm).
                trials_dict[trial_idx]["start_idx"] = trial_start_idx + idx + 7
                trials_dict[trial_idx]["start_time"] = int(parts[0].split("\t")[1])
            elif "END" in l and "ENDBUTTON" not in l and " MSG" not in l:
                # Samples are taken to stop 2 lines before the END line.
                trials_dict[trial_idx]["end_idx"] = trial_start_idx + idx - 2
                trials_dict[trial_idx]["end_time"] = int(parts[0].split("\t")[1])
            elif "SYNCTIME" in l:
                trials_dict[trial_idx]["synctime"] = trial_start_idx + idx
                trials_dict[trial_idx]["synctime_time"] = int(parts[0].split("\t")[1])
            elif "GAZE TARGET OFF" in l:
                trials_dict[trial_idx]["gaze_targ_off_time"] = int(parts[0].split("\t")[1])
            elif "GAZE TARGET ON" in l:
                trials_dict[trial_idx]["gaze_targ_on_time"] = int(parts[0].split("\t")[1])
            elif "DISPLAY_SENTENCE" in l:
                # Stored under the same key as "GAZE TARGET ON".
                trials_dict[trial_idx]["gaze_targ_on_time"] = int(parts[0].split("\t")[1])
            elif "REGION CHAR" in l:
                rg_idx = parts.index("CHAR")
                # More than 8 tokens after "CHAR": treated as a space character whose
                # coordinates are shifted by one token (splitting on " " presumably
                # produced an extra empty token).
                if len(parts[rg_idx:]) > 8:
                    char = " "
                    idx_correction = 1
                # Exactly 3 tokens: treated as a space whose coordinates are the first
                # four tokens of the following line (the correction maps rg_idx + 4 to 0).
                elif len(parts[rg_idx:]) == 3:
                    char = " "
                    if "REGION CHAR" not in trial_lines[idx + 1]:
                        parts = trial_lines[idx + 1].strip().split(" ")
                    idx_correction = -rg_idx - 4
                else:
                    char = parts[rg_idx + 3]
                    idx_correction = 0
                try:
                    char_dict = {
                        "char": char,
                        "char_xmin": float(parts[rg_idx + 4 + idx_correction]),
                        "char_ymin": float(parts[rg_idx + 5 + idx_correction]),
                        "char_xmax": float(parts[rg_idx + 6 + idx_correction]),
                        "char_ymax": float(parts[rg_idx + 7 + idx_correction]),
                    }
                    char_dict["char_y_center"] = (char_dict["char_ymax"] - char_dict["char_ymin"]) / 2 + char_dict[
                        "char_ymin"
                    ]
                    char_dict["char_x_center"] = (char_dict["char_xmax"] - char_dict["char_xmin"]) / 2 + char_dict[
                        "char_xmin"
                    ]
                    chars_list.append(char_dict)
                except Exception as e:
                    # A malformed character line is logged (when a logger exists) and skipped.
                    if "logger" in st.session_state:
                        st.session_state["logger"].warning(f"char_dict creation failed for parts {parts}")
                    if "logger" in st.session_state:
                        st.session_state["logger"].warning(e)

        # Prefer the gaze-target/sentence onset as trial start, else the START timestamp.
        if "gaze_targ_on_time" in trials_dict[trial_idx]:
            trials_dict[trial_idx]["trial_start_time"] = trials_dict[trial_idx]["gaze_targ_on_time"]
        else:
            trials_dict[trial_idx]["trial_start_time"] = trials_dict[trial_idx]["start_time"]

        if len(chars_list) > 0:
            # Assign each character to a line: lines are numbered in order of first
            # appearance of their vertical centre.
            line_ycoords = []
            for idx in range(len(chars_list)):
                chars_list[idx]["char_line_y"] = (
                    chars_list[idx]["char_ymax"] - chars_list[idx]["char_ymin"]
                ) / 2 + chars_list[idx]["char_ymin"]
                if chars_list[idx]["char_line_y"] not in line_ycoords:
                    line_ycoords.append(chars_list[idx]["char_line_y"])
            for idx in range(len(chars_list)):
                chars_list[idx]["assigned_line"] = line_ycoords.index(chars_list[idx]["char_line_y"])

            line_heights = [x["char_ymax"] - x["char_ymin"] for x in chars_list]
            line_xcoords_all = [x["char_x_center"] for x in chars_list]
            # np.unique returns the sorted distinct values.
            line_xcoords_no_pad = np.unique(line_xcoords_all)

            line_ycoords_all = [x["char_y_center"] for x in chars_list]
            line_ycoords_no_pad = np.unique(line_ycoords_all)

            trials_dict[trial_idx]["x_char_unique"] = list(line_xcoords_no_pad)
            trials_dict[trial_idx]["y_char_unique"] = list(line_ycoords_no_pad)
            x_diff, y_diff = calc_xdiff_ydiff(
                line_xcoords_no_pad, line_ycoords_no_pad, line_heights, allow_multiple_values=False
            )
            trials_dict[trial_idx]["x_diff"] = float(x_diff)
            trials_dict[trial_idx]["y_diff"] = float(y_diff)
            trials_dict[trial_idx]["num_char_lines"] = len(line_ycoords_no_pad)
            trials_dict[trial_idx]["line_heights"] = line_heights
            trials_dict[trial_idx]["chars_list"] = chars_list

            words_list = add_words(trials_dict[trial_idx], close_gap_between_words=close_gap_between_words)
            trials_dict[trial_idx]["words_list"] = words_list

    return trials_dict
|
|
|
|
def file_to_trials_and_lines(uploaded_file, asc_encoding: str = "ISO-8859-15", close_gap_between_words=True):
    """Read an .asc file, parse its paragraph trials and attach the stimulus text.

    Args:
        uploaded_file: path (``str``/``pathlib.Path``) or an object exposing
            ``getvalue()`` that returns the file content as bytes.
        asc_encoding: text encoding of the file.
        close_gap_between_words: forwarded to asc_lines_to_trials_by_trail_id.

    Returns:
        ``(trials_dict, lines)``. Every trial that has a "chars_list" additionally gets
        "sentence_list" (one string per text line), "text" (the lines joined with
        newlines) and "max_line" (highest line index).
    """
    if isinstance(uploaded_file, (str, pl.Path)):
        with open(uploaded_file, "r", encoding=asc_encoding) as f:
            lines = f.readlines()
    else:
        stringio = StringIO(uploaded_file.getvalue().decode(asc_encoding))
        loaded_str = stringio.read()
        lines = loaded_str.split("\n")
    trials_dict = asc_lines_to_trials_by_trail_id(
        lines, True, uploaded_file, close_gap_between_words=close_gap_between_words
    )

    # Fallback for a result without a "paragraph_trials" entry.
    # NOTE(review): the parser in this file always sets that key, so this branch looks
    # unreachable here; it is kept unchanged.
    if "paragraph_trials" not in trials_dict.keys() and "trial_is" in trials_dict[0].keys():
        paragraph_trials = []
        for k in range(trials_dict["max_trial_idx"]):
            if trials_dict[k]["trial_is"] == "paragraph":
                paragraph_trials.append(k)
        trials_dict["paragraph_trials"] = paragraph_trials

    enum = (
        trials_dict["paragraph_trials"]
        if "paragraph_trials" in trials_dict.keys()
        else range(trials_dict["max_trial_idx"])
    )
    for k in enum:
        trial = trials_dict[k]
        if "chars_list" not in trial:
            continue
        chars_list = trial["chars_list"]
        # Use the highest line index instead of trusting the last character to be on
        # the last line.
        max_line = max(char["assigned_line"] for char in chars_list)
        chars_on_lines = {line_idx: [] for line_idx in range(max_line + 1)}
        for char in chars_list:
            chars_on_lines[char["assigned_line"]].append(char["char"])
        sentence_list = ["".join(line_chars) for line_chars in chars_on_lines.values()]
        # Join every line with a newline; previously the first and second line were
        # concatenated without a separator.
        text = "\n".join(sentence_list)
        trial["sentence_list"] = sentence_list
        trial["text"] = text
        trial["max_line"] = max_line

    return trials_dict, lines
|
|
|
|
def get_plot_props(trial, available_fonts):
    """Pick font, font size, dpi and screen resolution for plotting a trial.

    Returns:
        ``(font, font_size, dpi, screen_res)``. The font falls back to
        "DejaVu Sans Mono" when the trial has no "font" or the font is not in
        ``available_fonts``; the size falls back to 21 only when the trial has no
        "font". ``dpi`` is always 100. ``screen_res`` is taken from entries 2 and 3 of
        ``trial["display_coords"]`` or defaults to (1920, 1080).
    """
    fallback_font = "DejaVu Sans Mono"
    dpi = 100

    if "font" not in trial.keys():
        font, font_size = fallback_font, 21
    else:
        font = trial["font"]
        font_size = trial["font_size"]
        if font not in available_fonts:
            font = fallback_font

    if "display_coords" not in trial.keys():
        screen_res = (1920, 1080)
    else:
        coords = trial["display_coords"]
        screen_res = (coords[2], coords[3])

    return font, font_size, dpi, screen_res
|
|
|
|
def trial_to_dfs(
    trial: dict, lines: list, use_synctime: bool = False, save_lines_to_txt=False, cut_out_outer_fixations=False
):
    """Extract raw gaze samples and fixations of one trial from the .asc lines.

    Args:
        trial: trial dict holding line indices ("start_idx"/"end_idx", or "synctime"
            and "trial_result_idx" when ``use_synctime`` is set) and "trial_start_time".
        lines: all lines of the .asc file.
        use_synctime: start right after the SYNCTIME line and stop at the trial result
            line, if the trial has a "synctime" entry.
        save_lines_to_txt: dump the trial lines plus up to 500 lines of context on each
            side to "Lines_plus500.txt" (debugging aid).
        cut_out_outer_fixations: drop fixations outside -10 < x < 1050 and -10 < y < 800.

    Returns:
        ``(df, dffix, trial)``: raw samples (columns idx, x, y, p), fixations, and the
        trial dict updated with "efix_count", "sfix_count", "sblink_count" and
        "eye_to_use". ``dffix`` is an empty DataFrame when no fixation was found.

    Raises:
        AssertionError: if fixation start times are not strictly increasing.
    """
    if use_synctime and "synctime" in trial:
        idx0, idxend = trial["synctime"] + 1, trial["trial_result_idx"]
    else:
        idx0, idxend = trial["start_idx"], trial["end_idx"]

    line_dicts = []
    fixations_dicts = []
    blink_started = False
    fixation_started = False
    efix_count = 0
    sfix_count = 0
    sblink_count = 0

    if save_lines_to_txt:
        with open("Lines_plus500.txt", "w") as f:
            # Clamp the start: a negative index would slice from the end of the list.
            f.writelines(lines[max(idx0 - 500, 0) : idxend + 500])

    trial_lines = lines[idx0 : idxend + 1]

    # Use the eye of the first fixation-end event; default to the right eye.
    eye_to_use = "R"
    for l in trial_lines:
        if "EFIX R" in l:
            eye_to_use = "R"
            break
        elif "EFIX L" in l:
            eye_to_use = "L"
            break

    for l in trial_lines:
        parts = [x.strip() for x in l.split("\t")]
        if f"EFIX {eye_to_use}" in l:
            efix_count += 1
            if fixation_started:
                # "." marks missing values; such events are skipped entirely.
                if parts[1] == "." and parts[2] == ".":
                    continue
                fixations_dicts.append(
                    {
                        # First field is "EFIX <eye> <stime>"; the start time is its last token.
                        "start_time": float(parts[0].split()[-1].strip()),
                        "end_time": float(parts[1].strip()),
                        "duration": float(parts[2].strip()),
                        "x": float(parts[3].strip()),
                        "y": float(parts[4].strip()),
                        "pupil_size": float(parts[5].strip()),
                    }
                )
                if len(fixations_dicts) >= 2:
                    assert (
                        fixations_dicts[-1]["start_time"] > fixations_dicts[-2]["start_time"]
                    ), "start times not in order"
                fixation_started = False
        elif f"SFIX {eye_to_use}" in l:
            sfix_count += 1
            fixation_started = True
        elif f"SBLINK {eye_to_use}" in l:
            sblink_count += 1
            blink_started = True

        # Outside of blinks, every line without an event keyword is a raw sample.
        if not blink_started and not any(x in l for x in event_strs):
            # Four fields are read below, so shorter lines are skipped as well.
            if len(parts) < 4 or (parts[1] == "." and parts[2] == "."):
                continue
            line_dicts.append(
                {
                    "idx": float(parts[0].strip()),
                    "x": float(parts[1].strip()),
                    "y": float(parts[2].strip()),
                    "p": float(parts[3].strip()),
                }
            )
        elif f"EBLINK {eye_to_use}" in l:
            blink_started = False

    df = pd.DataFrame(line_dicts)
    dffix = pd.DataFrame(fixations_dicts)
    if len(fixations_dicts) > 0:
        dffix["corrected_start_time"] = dffix.start_time - trial["trial_start_time"]
        dffix["corrected_end_time"] = dffix.end_time - trial["trial_start_time"]
        dffix["fix_duration"] = dffix.corrected_end_time.values - dffix.corrected_start_time.values
        assert all(np.diff(dffix["corrected_start_time"]) > 0), "start times not in order"
        # Only filter when there are fixations: an empty frame has no x/y columns.
        if cut_out_outer_fixations:
            dffix = dffix[(dffix.x > -10) & (dffix.y > -10) & (dffix.x < 1050) & (dffix.y < 800)]

    trial["efix_count"] = efix_count
    trial["eye_to_use"] = eye_to_use
    trial["sfix_count"] = sfix_count
    trial["sblink_count"] = sblink_count
    return df, dffix, trial
|
|
|
|
def get_save_path(fpath, fname_ending):
    """Return the PNG path inside ``gradio_plots`` named ``<fpath stem>_<fname_ending>.png``."""
    file_name = f"{fpath.stem}_{fname_ending}.png"
    return gradio_plots.joinpath(file_name)
|
|
|
|
def save_im_load_convert(fpath, fig, fname_ending, mode):
    """Save ``fig`` as PNG, convert the image to ``mode`` and overwrite the file with it.

    The target path comes from get_save_path(fpath, fname_ending).

    Returns:
        The converted PIL image.
    """
    target_path = get_save_path(fpath, fname_ending)
    fig.savefig(target_path)
    rendered = Image.open(target_path)
    converted = rendered.convert(mode)
    converted.save(target_path)
    return converted
|
|
|
|
def get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, dffix=None, prefix="word"):
    """Create a frameless figure the size of the screen with an inverted y axis.

    The axes fill the whole figure and have their decorations switched off. The axis
    limits span the fixation coordinates when ``dffix`` is given; otherwise they span
    the ``<prefix>_x_center`` / ``<prefix>_y_center`` columns of ``words_df`` widened by
    the margins.

    Returns:
        ``(fig, ax)``.
    """
    fig_size = (screen_res[0] / dpi, screen_res[1] / dpi)
    fig = plt.figure(figsize=fig_size, dpi=dpi)
    ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
    ax.set_axis_off()

    if dffix is None:
        y_centers = words_df[f"{prefix}_y_center"]
        y_limits = (y_centers.min() - y_margin, y_centers.max() + y_margin)
        x_centers = words_df[f"{prefix}_x_center"]
        x_limits = (x_centers.min() - x_margin, x_centers.max() + x_margin)
    else:
        y_limits = (dffix.y.min(), dffix.y.max())
        x_limits = (dffix.x.min(), dffix.x.max())

    ax.set_ylim(y_limits)
    ax.set_xlim(x_limits)
    # Flip the y axis so y grows downwards.
    ax.invert_yaxis()
    fig.add_axes(ax)
    return fig, ax
|
|
|
|
def plot_text_boxes_fixations(
    fpath,
    dpi,
    screen_res,
    data_dir_sub,
    set_font_size: bool,
    font_size: int,
    use_words: bool,
    save_channel_repeats: bool,
    save_combo_grey_and_rgb: bool,
    dffix=None,
    trial=None,
):
    """Render stimulus text, text boxes and fixations of a trial as image files.

    Three greyscale images (text, boxes, fixation scatter) are always rendered and
    stacked into one 3-channel image, which is saved to ``fpath``. The intermediate
    images are written next to each other in the plots folder (see get_save_path).

    Args:
        fpath: path used to derive the output file names; also the CSV that is read
            when ``dffix`` is None and the file the stacked image is written to.
        dpi: resolution of the rendered figures.
        screen_res: (width, height) of the screen in pixels.
        data_dir_sub: unused; kept for interface compatibility.
        set_font_size: use ``font_size`` as given; otherwise derive the size from
            ``trial["font_size"]`` and ``dpi``.
        font_size: font size used when ``set_font_size`` is True.
        use_words: draw word boxes ("words_list") instead of character boxes
            ("chars_list").
        save_channel_repeats: additionally save each greyscale image repeated over
            three channels.
        save_combo_grey_and_rgb: additionally save combined RGB and greyscale plots
            showing text, boxes and fixations together.
        dffix: fixation DataFrame with "x" and "y"; read from ``fpath`` when None.
        trial: trial dict; read from the "_trial.json" sibling of ``fpath`` when None.
    """
    if isinstance(fpath, str):
        fpath = pl.Path(fpath)
    prefix = "word" if use_words else "char"
    if dffix is None:
        dffix = pd.read_csv(fpath)
    if trial is None:
        json_fpath = str(fpath).replace("_fixations.csv", "_trial.json")
        with open(json_fpath, "r") as f:
            trial = json.load(f)

    words_df = pd.DataFrame(trial[f"{prefix}s_list"])
    x_min = words_df[f"{prefix}_xmin"]
    x_max = words_df[f"{prefix}_xmax"]
    y_min = words_df[f"{prefix}_ymin"]
    y_max = words_df[f"{prefix}_ymax"]

    if f"{prefix}_x_center" not in words_df.columns:
        words_df[f"{prefix}_x_center"] = (x_max - x_min) / 2 + x_min
        words_df[f"{prefix}_y_center"] = (y_max - y_min) / 2 + y_min
    x_center = words_df[f"{prefix}_x_center"]
    y_center = words_df[f"{prefix}_y_center"]
    labels = words_df[prefix]
    n_boxes = len(words_df)

    x_margin = x_center.mean() / 8
    y_margin = y_center.mean() / 4
    # Fixations fade in by order: the first is drawn at alpha 0.25, the last at 1.
    times = np.linspace(0.25, 1, len(dffix))

    # The font family used to be defined only when set_font_size was True, which
    # raised NameError otherwise; "monospace" is now used in both cases.
    font = "monospace"
    if not set_font_size:
        font_size = trial["font_size"] * 27 // dpi
    font_props = FontProperties(family=font, style="normal", size=font_size)

    def new_fig_ax():
        # Fresh full-screen axes framed around the text.
        return get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)

    def add_box(ax, idx, linewidth, edgecolor, facecolor):
        # Box of item idx, shifted by one pixel to the left and up.
        ax.add_patch(
            patches.Rectangle(
                (x_min[idx] - 1, y_min[idx] - 1),
                x_max[idx] - x_min[idx],
                y_max[idx] - y_min[idx],
                alpha=0.9,
                linewidth=linewidth,
                edgecolor=edgecolor,
                facecolor=facecolor,
            )
        )

    def add_label(ax, idx, **text_kwargs):
        # Text of item idx centred on its box.
        ax.text(
            x_center[idx],
            y_center[idx],
            labels[idx],
            horizontalalignment="center",
            verticalalignment="center",
            fontproperties=font_props,
            **text_kwargs,
        )

    if save_combo_grey_and_rgb:
        # Colour plot: blue fixations, red boxes, green text.
        fig, ax = new_fig_ax()
        ax.scatter(dffix.x, dffix.y, alpha=times, facecolor="b")
        for idx in range(n_boxes):
            add_label(ax, idx, color="g")
            add_box(ax, idx, linewidth=0.8, edgecolor="r", facecolor="none")
        save_im_load_convert(fpath, fig, f"{prefix}s_combo_rgb", "RGB")
        plt.close("all")

        # Same content in black, stored as greyscale.
        fig, ax = new_fig_ax()
        ax.scatter(dffix.x, dffix.y, facecolor="k", alpha=times)
        for idx in range(n_boxes):
            add_label(ax, idx)
            add_box(ax, idx, linewidth=0.8, edgecolor="k", facecolor="none")
        save_im_load_convert(fpath, fig, f"{prefix}s_combo_grey", "L")
        plt.close("all")

    # Channel 1: text only (the nearly invisible scatter is kept from the original).
    fig, ax = new_fig_ax()
    ax.scatter(x_center, y_center, s=1, facecolor="k", alpha=0.01)
    for idx in range(n_boxes):
        add_label(ax, idx)
    words_grey_im = save_im_load_convert(fpath, fig, f"{prefix}s_grey", "L")
    plt.close("all")

    # Channel 2: filled boxes only.
    fig, ax = new_fig_ax()
    ax.scatter(x_center, y_center, s=1, facecolor="k", alpha=0.1)
    for idx in range(n_boxes):
        add_box(ax, idx, linewidth=1, edgecolor="k", facecolor="grey")
    word_boxes_grey_im = save_im_load_convert(fpath, fig, f"{prefix}_boxes_grey", "L")
    plt.close("all")

    # Channel 3: fixations only.
    fig, ax = new_fig_ax()
    ax.scatter(dffix.x, dffix.y, facecolor="k", alpha=times)
    fix_scatter_grey_im = save_im_load_convert(fpath, fig, "fix_scatter_grey", "L")
    plt.close("all")

    arr_combo = np.stack(
        [
            np.asarray(words_grey_im),
            np.asarray(word_boxes_grey_im),
            np.asarray(fix_scatter_grey_im),
        ],
        axis=2,
    )
    im_combo = Image.fromarray(arr_combo)
    save_path = get_save_path(fpath, f"{prefix}s_channel_sep")
    print(f"save_path for im combo is {save_path}")
    # NOTE(review): the image is written to fpath, not to the save_path printed above.
    # Callers appear to pass the target image path as fpath, so this is left unchanged.
    im_combo.save(fpath)

    if save_channel_repeats:
        repeats = (
            (words_grey_im, f"{prefix}s_channel_repeat"),
            (word_boxes_grey_im, f"{prefix}boxes_channel_repeat"),
            (fix_scatter_grey_im, "fix_channel_repeat"),
        )
        for grey_im, fname_ending in repeats:
            # Repeat the single grey channel three times to get an RGB-shaped image.
            arr_combo = np.stack([np.asarray(grey_im)] * 3, axis=2)
            Image.fromarray(arr_combo).save(get_save_path(fpath, fname_ending))
|
|
|
|
def add_line_overlaps_to_sample(trial, sample):
    """Append a "line overlap" feature column to a fixation tensor.

    For every row of ``sample`` the value in column 1 is compared against the
    vertical extent of each text line, taken from the unique ``char_ymin`` and
    ``char_ymax`` values of ``trial["chars_list"]``. The index of the first
    line whose extent contains the value is stored, or ``-1`` if there is none.

    Args:
        trial: Mapping with a ``"chars_list"`` entry whose records provide
            ``char_ymin`` and ``char_ymax``.
        sample: 2-D tensor with one row per fixation; column 1 is used as y.

    Returns:
        ``sample`` with one extra float32 column holding the line index.
    """
    char_df = pd.DataFrame(trial["chars_list"])
    # The line extents do not depend on the fixation, so build them once
    # instead of once per row.
    # NOTE(review): pairing the two unique() arrays assumes every line has a
    # distinct ymin and a distinct ymax in the same order - confirm.
    line_bounds = list(zip(char_df.char_ymin.unique(), char_df.char_ymax.unique()))

    line_overlaps = []
    for row in sample:
        y_val = row[1]
        line_overlap = t.tensor(-1, dtype=t.float32)
        for idx, (y_min, y_max) in enumerate(line_bounds):
            if y_min <= y_val <= y_max:
                line_overlap = t.tensor(idx, dtype=t.float32)
                break
        line_overlaps.append(line_overlap)

    if line_overlaps:
        line_olaps_tensor = t.stack(line_overlaps, dim=0)
    else:
        # t.stack rejects an empty list; an empty sample gets an empty column.
        line_olaps_tensor = t.zeros((0,), dtype=t.float32)
    return t.cat([sample, line_olaps_tensor.unsqueeze(1)], dim=1)
|
|
|
|
def norm_coords_by_letter_min_x_y(
    sample_idx: int,
    trialslist: list,
    samplelist: list,
    chars_center_coords_list: list = None,
):
    """Shift one sample so the smallest character x/y minimum becomes the origin.

    The minimum ``char_xmin`` and ``char_ymin`` of the trial's ``chars_list``
    are subtracted from columns 0 and 1 of ``samplelist[sample_idx]``; all
    other columns are left untouched. The trial also receives an
    ``"x_char_unique"`` entry with the unique ``char_xmin`` values. If
    ``chars_center_coords_list`` is given, the same x/y offset is subtracted
    in place from its entry (from both coordinate pairs when that entry has
    four columns).

    Returns:
        ``(trialslist, samplelist, chars_center_coords_list)``.
    """
    trial = trialslist[sample_idx]
    chars_df = pd.DataFrame(trial["chars_list"])
    trial["x_char_unique"] = chars_df.char_xmin.unique()

    x_origin = chars_df.char_xmin.min()
    y_origin = chars_df.char_ymin.min()

    sample = samplelist[sample_idx]
    # One offset per feature column; only the first two columns are shifted.
    offsets = t.zeros((1, sample.shape[1]), dtype=sample.dtype, device=sample.device)
    offsets[0, 0] = offsets[0, 0] + 1 * x_origin
    offsets[0, 1] = offsets[0, 1] + 1 * y_origin
    samplelist[sample_idx] = sample - offsets

    if chars_center_coords_list is None:
        return trialslist, samplelist, chars_center_coords_list

    xy_offsets = offsets.squeeze(0)[:2]
    has_two_coordinate_pairs = chars_center_coords_list[sample_idx].shape[-1] == xy_offsets.shape[-1] * 2
    if has_two_coordinate_pairs:
        chars_center_coords_list[sample_idx][:, :2] -= xy_offsets
        chars_center_coords_list[sample_idx][:, 2:] -= xy_offsets
    else:
        chars_center_coords_list[sample_idx] -= xy_offsets
    return trialslist, samplelist, chars_center_coords_list
|
|
|
|
def norm_coords_by_letter_positions(
    sample_idx: int,
    trialslist: list,
    samplelist: list,
    meanlist: list = None,
    stdlist: list = None,
    return_mean_std_lists=False,
    norm_by_char_averages=False,
    chars_center_coords_list: list = None,
    add_normalised_values_as_features=False,
):
    """Scale the x/y columns of one sample by text dimensions.

    Columns 0 and 1 of ``samplelist[sample_idx]`` are divided either by the
    average positive character width/height (``norm_by_char_averages=True``)
    or by the text width (max ``char_xmax`` minus min ``char_xmin``) and the
    smallest entry of ``trial["line_heights"]``. The trial also receives an
    ``"x_char_unique"`` entry.

    Args:
        sample_idx: Index of the trial/sample to process.
        trialslist: List of trial dicts; each needs ``"chars_list"``.
        samplelist: List of 2-D tensors, modified in place at ``sample_idx``.
        meanlist, stdlist: Accumulators that receive the per-column mean/std
            of the normalised sample when ``return_mean_std_lists`` is set.
            New lists are created if they are omitted.
        return_mean_std_lists: Also return the two accumulators.
        norm_by_char_averages: Select the character-average scaling.
        chars_center_coords_list: Optional list whose entry is scaled in place
            by the same x/y factors.
        add_normalised_values_as_features: Append the scaled columns as extra
            features instead of replacing the raw ones.

    Returns:
        ``(trialslist, samplelist, chars_center_coords_list)``, or with
        ``meanlist, stdlist`` inserted before the last element when
        ``return_mean_std_lists`` is true.

    Raises:
        AssertionError: If the scaling vector or the computed statistics
            contain NaN.
    """
    trial = trialslist[sample_idx]
    chars_df = pd.DataFrame(trial["chars_list"])
    trial["x_char_unique"] = chars_df.char_xmin.unique()

    min_x_chars = chars_df.char_xmin.min()
    max_x_chars = chars_df.char_xmax.max()

    sample = samplelist[sample_idx]
    norm_vector_multi = t.ones((1, sample.shape[1]), dtype=sample.dtype, device=sample.device)
    if norm_by_char_averages:
        chars_list = trial["chars_list"]
        char_widths = np.asarray([x["char_xmax"] - x["char_xmin"] for x in chars_list])
        char_heights = np.asarray([x["char_ymax"] - x["char_ymin"] for x in chars_list])
        # Zero-sized boxes are excluded from the averages.
        x_scale = np.mean(char_widths[char_widths > 0])
        y_scale = np.mean(char_heights[char_heights > 0])
    else:
        x_scale = max_x_chars - min_x_chars
        y_scale = min(np.unique(trial["line_heights"]))
    norm_vector_multi[0, 0] = norm_vector_multi[0, 0] * x_scale
    norm_vector_multi[0, 1] = norm_vector_multi[0, 1] * y_scale
    assert not t.any(t.isnan(norm_vector_multi)), "Nan found in char norming vector"

    norm_vector_multi = norm_vector_multi.squeeze(0)
    if add_normalised_values_as_features:
        # Keep only the factors that differ from the neutral value 1.
        # NOTE(review): a genuine scale of exactly 1 would be dropped too.
        norm_vector_multi = norm_vector_multi[norm_vector_multi != 1]
        normed_features = sample[:, : norm_vector_multi.shape[0]] / norm_vector_multi
        samplelist[sample_idx] = t.cat([sample, normed_features], dim=1)
    else:
        samplelist[sample_idx] = sample / norm_vector_multi

    if chars_center_coords_list is not None:
        norm_vector_multi = norm_vector_multi[:2]
        if chars_center_coords_list[sample_idx].shape[-1] == norm_vector_multi.shape[-1] * 2:
            chars_center_coords_list[sample_idx][:, :2] /= norm_vector_multi
            chars_center_coords_list[sample_idx][:, 2:] /= norm_vector_multi
        else:
            chars_center_coords_list[sample_idx] /= norm_vector_multi

    if not return_mean_std_lists:
        return trialslist, samplelist, chars_center_coords_list

    # Tolerate omitted accumulators instead of failing on ``None.append``.
    meanlist = [] if meanlist is None else meanlist
    stdlist = [] if stdlist is None else stdlist
    mean_val = samplelist[sample_idx].mean(axis=0).cpu().numpy()
    meanlist.append(mean_val)
    std_val = samplelist[sample_idx].std(axis=0).cpu().numpy()
    stdlist.append(std_val)
    # ``~any(...)`` on a Python bool is always truthy (~True == -2), so the
    # previous checks could never fire; use ``not`` instead.
    assert not np.isnan(mean_val).any(), "Nan found in mean_val"
    # The sample standard deviation of a single row is NaN by definition, so
    # only check it when there are at least two rows.
    if samplelist[sample_idx].shape[0] > 1:
        assert not np.isnan(std_val).any(), "Nan found in std_val"
    return trialslist, samplelist, meanlist, stdlist, chars_center_coords_list
|
|
|
|
def remove_compile_from_model(model):
    """Replace wrapped submodules of ``model`` by their ``_orig_mod`` originals.

    ``_orig_mod`` is the attribute under which ``torch.compile`` wrappers
    expose the wrapped module. If ``model.project`` carries it, the six
    submodules listed below are each replaced by their ``_orig_mod``;
    otherwise a message is printed and nothing changes.

    Returns:
        The same ``model`` object.
    """
    wrapped_attrs = (
        "project",
        "chars_conv",
        "chars_classifier",
        "layer_norm_in",
        "bert_model",
        "linear",
    )
    if not hasattr(model.project, "_orig_mod"):
        print(f"remove_compile_from_model not done since model.project {model.project} has no orig_mod")
        return model
    for attr_name in wrapped_attrs:
        setattr(model, attr_name, getattr(model, attr_name)._orig_mod)
    return model
|
|
|
|
def remove_compile_from_dict(state_dict):
    """Rename keys of ``state_dict`` in place, dropping ``._orig_mod.`` segments.

    Every key is popped and re-inserted under the name obtained by replacing
    ``"._orig_mod."`` with ``"."``.

    Returns:
        The same (mutated) ``state_dict`` object.
    """
    # Iterate over a snapshot because the mapping is modified inside the loop.
    for old_key in tuple(state_dict.keys()):
        value = state_dict.pop(old_key)
        state_dict[old_key.replace("._orig_mod.", ".")] = value
    return state_dict
|
|
|
|
def add_text_to_ax(
    chars_list,
    ax,
    font_to_use="DejaVu Sans Mono",
    fontsize=21,
    prefix="char",
    plot_boxes=True,
    plot_text=True,
    box_annotations=None,
):
    """Draw stimulus items (text and/or bounding boxes) onto a Matplotlib axis.

    Args:
        chars_list: Iterable of dicts with the keys ``{prefix}``,
            ``{prefix}_xmin``/``_xmax``/``_ymin``/``_ymax`` and
            ``{prefix}_x_center``/``_y_center``.
        ax: Axis to draw on.
        font_to_use: Font family for text and annotations.
        fontsize: Font size for the text; annotations use ``fontsize / 1.5``.
        prefix: Key prefix of the items (e.g. ``"char"`` or ``"word"``).
        plot_boxes: Draw a grey rectangle around every item.
        plot_text: Draw the item's text at its centre.
        box_annotations: Optional iterable, zipped with ``chars_list``, whose
            values are written at the horizontal centre of each box at y0.

    Returns:
        None. Nothing is drawn if both ``plot_boxes`` and ``plot_text`` are
        false.
    """
    font_props = FontProperties(family=font_to_use, style="normal", size=fontsize)
    if not (plot_boxes or plot_text):
        return None

    annotated = box_annotations is not None
    entries = zip(chars_list, box_annotations) if annotated else chars_list
    for entry in entries:
        if annotated:
            item, annot_text = entry
        else:
            item = entry
        x0, y0 = item[f"{prefix}_xmin"], item[f"{prefix}_ymin"]
        box_width = item[f"{prefix}_xmax"] - item[f"{prefix}_xmin"]
        box_height = item[f"{prefix}_ymax"] - item[f"{prefix}_ymin"]

        if plot_text:
            ax.text(
                item[f"{prefix}_x_center"],
                item[f"{prefix}_y_center"],
                item[prefix],
                horizontalalignment="center",
                verticalalignment="center",
                fontproperties=font_props,
            )
        if plot_boxes:
            box = Rectangle((x0, y0), box_width, box_height, edgecolor="grey", facecolor="none", lw=0.8, alpha=0.4)
            ax.add_patch(box)
            # NOTE(review): annotations are drawn only together with boxes;
            # the nesting was reconstructed from unindented source - confirm.
            if annotated:
                ax.annotate(
                    str(annot_text),
                    (x0 + box_width / 2, y0),
                    horizontalalignment="center",
                    verticalalignment="center",
                    fontproperties=FontProperties(family=font_to_use, style="normal", size=fontsize / 1.5),
                )
|
|
|
|
def plot_fixations_and_text(
    dffix: pd.DataFrame,
    trial: dict,
    plot_prefix="chars_",
    show=False,
    returnfig=False,
    save=False,
    savelocation="plot.png",
    font_to_use="DejaVu Sans Mono",
    fontsize=20,
    plot_classic=True,
    plot_boxes=True,
    plot_text=True,
    fig_size=(14, 8),
    dpi=300,
    turn_axis_on=True,
    algo_choice="slice",
):
    """Plot raw fixations, and optionally corrected ones, over the stimulus.

    The stimulus is taken from ``trial[f"{plot_prefix}list"]`` if present.
    Raw fixations are drawn as black crosses. If ``plot_classic`` is set and
    ``dffix`` has a ``line_num_{algo_choice}`` column, the ``y_{algo_choice}``
    positions are drawn as green stars with an arrow from each corrected
    position to the raw position.

    Args:
        dffix: Fixations with ``x`` and ``y`` columns.
        trial: Trial dict, optionally holding the stimulus list.
        plot_prefix: Stimulus key prefix; the item prefix passed on is
            ``plot_prefix[:-2]``.
        show, save, savelocation: Show the plot / save it to ``savelocation``.
        returnfig: Return the figure instead of closing it.
        turn_axis_on: Keep the axis decorations visible.

    Returns:
        The figure if ``returnfig`` is true, otherwise ``None``.
    """
    fig, ax = plt.subplots(1, 1, figsize=fig_size, tight_layout=True, dpi=dpi)

    stimulus_key = f"{plot_prefix}list"
    if stimulus_key in trial.keys():
        add_text_to_ax(
            trial[stimulus_key],
            ax,
            font_to_use,
            fontsize=fontsize,
            prefix=plot_prefix[:-2],
            plot_boxes=plot_boxes,
            plot_text=plot_text,
        )
    ax.plot(dffix.x, dffix.y, "kX", label="Raw Fixations", alpha=0.9)

    if plot_classic and f"line_num_{algo_choice}" in dffix.columns:
        corrected_col = f"y_{algo_choice}"
        ax.scatter(
            dffix.x,
            dffix[corrected_col],
            marker="*",
            color="tab:green",
            label=f"{algo_choice} Prediction",
            alpha=0.9,
        )
        arrow_points = zip(dffix.x.values, dffix[corrected_col].values, dffix.x, dffix.y)
        for start_x, start_y, end_x, end_y in arrow_points:
            ax.arrow(start_x, start_y, end_x - start_x, end_y - start_y, color="tab:green", alpha=0.6)

    ax.set_ylabel("y (pixel)")
    ax.set_xlabel("x (pixel)")
    # Screen coordinates grow downwards, so flip the y axis.
    ax.invert_yaxis()
    ax.legend(bbox_to_anchor=(1, 1), loc="upper left")
    if not turn_axis_on:
        ax.axis("off")
    if save:
        plt.savefig(savelocation, dpi=dpi)
    if show:
        plt.show()
    if returnfig:
        return fig
    plt.close()
    return None
|
|
|
|
def make_folders(gradio_temp_folder, gradio_temp_unzipped_folder, gradio_plots):
    """Create the three working directories if they do not exist yet.

    Each argument must provide ``mkdir(exist_ok=True)`` (e.g. a
    ``pathlib.Path``). Parent directories are not created.

    Returns:
        Always ``0``.
    """
    for folder in (gradio_temp_folder, gradio_temp_unzipped_folder, gradio_plots):
        folder.mkdir(exist_ok=True)
    return 0
|
|
|
|
def get_classic_cfg(fname):
    """Load the configuration of the classic correction algorithms from JSON.

    Args:
        fname: Path of the JSON file.

    Returns:
        The parsed JSON document (a dict keyed by algorithm name in the
        configurations used by this module).
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    # The former ``cfg["slice"] = cfg["slice"]`` self-assignment was a no-op
    # that only raised KeyError for files without a "slice" entry; it is gone.
    with open(fname, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
def find_and_load_model(model_date="20240104-223349"):
    """Locate the config and checkpoint for ``model_date`` and load the model.

    Both files are searched in ``DIST_MODELS_FOLDER`` by the patterns
    ``*{model_date}*.yaml`` and ``*{model_date}*.ckpt``; the first match of
    each is used.

    Returns:
        ``(model, model_cfg)``. ``(None, None)`` if no config file is found;
        ``(None, model_cfg)`` if the checkpoint is missing or fails to load.
    """
    logger = st.session_state["logger"] if "logger" in st.session_state else None

    cfg_files = list(DIST_MODELS_FOLDER.glob(f"*{model_date}*.yaml"))
    if not cfg_files:
        if logger is not None:
            logger.warning(f"No model cfg yaml found for {model_date}")
        return None, None
    with open(cfg_files[0]) as f:
        model_cfg = yaml.safe_load(f)
    model_cfg["system_type"] = "linux"

    # Use the same folder constant as for the config, and treat a missing
    # checkpoint like a failed load instead of raising IndexError.
    ckpt_files = list(DIST_MODELS_FOLDER.glob(f"*{model_date}*.ckpt"))
    if not ckpt_files:
        if logger is not None:
            logger.warning(f"No model checkpoint found for {model_date}")
        return None, model_cfg

    model = load_model(ckpt_files[0], model_cfg)
    return model, model_cfg
|
|
|
|
def load_model(model_file, cfg):
    """Build a ``LitModel`` and load the weights stored in ``model_file``.

    The hyper-parameters embedded in the checkpoint
    (``["hyper_parameters"]["cfg"]``) take precedence over ``cfg``. The model
    is returned in eval mode and frozen.

    Returns:
        The loaded model, or ``None`` if reading the checkpoint failed. The
        failure is logged when a logger is stored in the session state.
    """
    try:
        # NOTE: t.load unpickles the file; only load checkpoints from trusted sources.
        checkpoint = t.load(model_file, map_location="cpu")
        hparams = checkpoint["hyper_parameters"]["cfg"] if "hyper_parameters" in checkpoint.keys() else cfg
        weights = checkpoint["state_dict"]
    except Exception as e:
        if "logger" in st.session_state:
            st.session_state["logger"].warning(e)
            st.session_state["logger"].warning(f"Failed to load {model_file}")
        return None

    model = LitModel(
        [1, 500, 3],
        hparams["hidden_dim_bert"],
        hparams["num_attention_heads"],
        hparams["n_layers_BERT"],
        hparams["loss_function"],
        1e-4,
        hparams["weight_decay"],
        hparams,
        hparams["use_lr_warmup"],
        hparams["use_reduce_on_plateau"],
        track_gradient_histogram=hparams["track_gradient_histogram"],
        register_forw_hook=hparams["track_activations_via_hook"],
        char_dims=hparams["char_dims"],
    )
    # Strip compile wrappers from both sides so the parameter names line up.
    model = remove_compile_from_model(model)
    weights = remove_compile_from_dict(weights)
    with t.no_grad():
        model.load_state_dict(weights, strict=False)
    model.eval()
    model.freeze()
    return model
|
|
|
|
def set_up_models(dist_models_folder):
    """Load all DIST checkpoints from a folder and assemble the ensemble.

    Checkpoints are split by whether their file name contains
    ``normalize_by_line_height_and_width_True`` or ``..._False``. The model
    date is the second ``_``-separated part of the file stem.

    Returns:
        Dict with the keys ``ensemble_model_avg``,
        ``model_cfg_without_norm_df``, ``model_cfg_with_norm_df``,
        ``DIST_MODEL_DATE_WITH_NORM``, ``single_DIST_model`` and
        ``single_DIST_model_cfg``.
    """
    if "logger" in st.session_state:
        st.session_state["logger"].info("Loading Ensemble")

    ckpts_with_norm = list(dist_models_folder.glob("*normalize_by_line_height_and_width_True*.ckpt"))
    ckpts_without_norm = list(dist_models_folder.glob("*normalize_by_line_height_and_width_False*.ckpt"))
    DIST_MODEL_DATE_WITH_NORM = ckpts_with_norm[0].stem.split("_")[1]

    def load_all(ckpt_files):
        # One (model, cfg) pair per checkpoint; either part may be None.
        return [find_and_load_model(ckpt.stem.split("_")[1]) for ckpt in ckpt_files]

    loaded_without_norm = load_all(ckpts_without_norm)
    loaded_with_norm = load_all(ckpts_with_norm)

    # The first available config of each group represents the whole group.
    model_cfg_without_norm_df = [cfg for _, cfg in loaded_without_norm if cfg is not None][0]
    model_cfg_with_norm_df = [cfg for _, cfg in loaded_with_norm if cfg is not None][0]

    models_without_norm_df = [model for model, _ in loaded_without_norm if model is not None]
    models_with_norm_df = [model for model, _ in loaded_with_norm if model is not None]

    ensemble_model_avg = EnsembleModel(
        models_without_norm_df, models_with_norm_df, learning_rate=0.0058, use_simple_average=True
    )

    single_DIST_model, single_DIST_model_cfg = find_and_load_model(model_date=DIST_MODEL_DATE_WITH_NORM)
    return {
        "ensemble_model_avg": ensemble_model_avg,
        "model_cfg_without_norm_df": model_cfg_without_norm_df,
        "model_cfg_with_norm_df": model_cfg_with_norm_df,
        "DIST_MODEL_DATE_WITH_NORM": DIST_MODEL_DATE_WITH_NORM,
        "single_DIST_model": single_DIST_model,
        "single_DIST_model_cfg": single_DIST_model_cfg,
    }
|
|
|
|
def prep_data_for_dist(model_cfg, dffix, trial=None):
    """Turn one trial's fixations into a single-item dataset and loader.

    Steps: select ``model_cfg["sample_cols"]`` from ``dffix``, optionally add
    the line-overlap feature, apply the coordinate normalisations enabled in
    ``model_cfg``, standardise with ``sample_means`` / ``sample_std``, render
    the stimulus image via ``plot_text_boxes_fixations`` and wrap everything
    in a ``DSet``.

    Args:
        model_cfg: Model configuration dict.
        dffix: Fixation dataframe, or a dict holding it under ``"value"``.
        trial: Trial dict; defaults to ``st.session_state["trial"]``.

    Returns:
        ``(val_loader, val_set)``.

    Raises:
        AssertionError: If the selected features contain NaN.
    """
    if "logger" in st.session_state:
        st.session_state["logger"].debug("prep_data_for_dist entered")
    trial = st.session_state["trial"] if trial is None else trial
    if isinstance(dffix, dict):
        dffix = dffix["value"]

    features = dffix.loc[:, model_cfg["sample_cols"]].to_numpy()
    sample_tensor = t.tensor(features, dtype=t.float32)
    if model_cfg["add_line_overlap_feature"]:
        sample_tensor = add_line_overlaps_to_sample(trial, sample_tensor)
    assert not t.any(t.isnan(sample_tensor)), "NaNs found in sample tensor"

    # The normalisation helpers work on lists of samples and trials.
    samplelist_eval = [sample_tensor]
    trialslist_eval = [trial]
    chars_center_coords_list_eval = None

    if model_cfg["norm_coords_by_letter_min_x_y"]:
        for sample_idx in range(len(samplelist_eval)):
            trialslist_eval, samplelist_eval, chars_center_coords_list_eval = norm_coords_by_letter_min_x_y(
                sample_idx,
                trialslist_eval,
                samplelist_eval,
                chars_center_coords_list=chars_center_coords_list_eval,
            )

    if model_cfg["normalize_by_line_height_and_width"]:
        meanlist_eval, stdlist_eval = [], []
        for sample_idx in range(len(samplelist_eval)):
            (
                trialslist_eval,
                samplelist_eval,
                meanlist_eval,
                stdlist_eval,
                chars_center_coords_list_eval,
            ) = norm_coords_by_letter_positions(
                sample_idx,
                trialslist_eval,
                samplelist_eval,
                meanlist_eval,
                stdlist_eval,
                return_mean_std_lists=True,
                norm_by_char_averages=model_cfg["norm_by_char_averages"],
                chars_center_coords_list=chars_center_coords_list_eval,
                add_normalised_values_as_features=model_cfg["add_normalised_values_as_features"],
            )

    # NOTE(review): the standardisation is applied for every configuration;
    # the nesting was reconstructed from unindented source - confirm.
    means = t.tensor(model_cfg["sample_means"], dtype=t.float32)
    stds = t.tensor(model_cfg["sample_std"], dtype=t.float32)
    sample_tensor = ((samplelist_eval[0] - means) / stds).unsqueeze(0)

    if "logger" in st.session_state:
        st.session_state["logger"].info(f"Using path {trial['plot_file']} for plotting")
    plot_text_boxes_fixations(
        fpath=trial["plot_file"],
        dpi=250,
        screen_res=(1024, 768),
        data_dir_sub=None,
        set_font_size=True,
        font_size=4,
        use_words=False,
        save_channel_repeats=False,
        save_combo_grey_and_rgb=False,
        dffix=dffix,
        trial=trial,
    )

    val_set = DSet(
        sample_tensor,
        None,
        t.zeros((1, sample_tensor.shape[1])),
        trialslist_eval,
        padding_list=[0],
        padding_at_end=model_cfg["padding_at_end"],
        return_images_for_conv=True,
        im_partial_string=model_cfg["im_partial_string"],
        input_im_shape=model_cfg["char_plot_shape"],
    )
    val_loader = dl(val_set, batch_size=1, shuffle=False, num_workers=0)
    return val_loader, val_set
|
|
|
|
def fold_in_seq_dim(out, y=None):
    """Merge the batch and sequence axes of ``out`` (and ``y``, if given).

    Args:
        out: Array/tensor of shape ``(batch, seq, classes)``.
        y: Optional targets of shape ``(batch, seq)`` or ``(batch, seq, c)``.

    Returns:
        ``(out, y)`` with the first two axes flattened into one; ``y`` is
        ``None`` if it was not given.
    """
    _, seq_len, _ = out.shape
    flat_out = eo.rearrange(out, "b s c -> (b s) c", s=seq_len)
    if y is None:
        return flat_out, None
    pattern = "b s c -> (b s) c" if len(y.shape) > 2 else "b s -> (b s)"
    return flat_out, eo.rearrange(y, pattern, s=seq_len)
|
|
|
|
def logits_to_pred(out, y=None):
    """Convert model logits into label predictions of shape ``(batch, seq)``.

    The logits are flattened over batch and sequence, turned into labels by
    ``corn_label_from_logits`` and reshaped back. Given targets are squeezed
    and reshaped the same way.

    Returns:
        ``(preds, y)``; ``y`` is ``None`` if no targets were given.
    """
    seq_len = out.shape[1]
    flat_logits, flat_y = fold_in_seq_dim(out, y)
    preds = eo.rearrange(corn_label_from_logits(flat_logits), "(b s) -> b s", s=seq_len)
    if flat_y is None:
        return preds, flat_y
    return preds, eo.rearrange(flat_y.squeeze(), "(b s) -> b s", s=seq_len)
|
|
|
|
def get_DIST_preds(dffix, trial, models_dict=None):
    """Add the single-model DIST line assignment to ``dffix``.

    On success the columns ``line_num_DIST``, ``y_DIST`` and
    ``y_DIST_correction`` are written. Failures during inference are logged
    and leave ``dffix`` without these columns.

    Args:
        dffix: Fixation dataframe with at least a ``y`` column.
        trial: Trial dict providing ``y_char_unique``, ``num_char_lines``,
            ``chars_list`` and ``trial_id``.
        models_dict: Optional dict with ``single_DIST_model`` and
            ``single_DIST_model_cfg``; the Streamlit session state is used
            (and re-initialised if empty) when omitted.

    Returns:
        ``dffix``.
    """
    algo_choice = "DIST"
    logger = st.session_state["logger"] if "logger" in st.session_state else None

    if models_dict is None:
        if st.session_state["single_DIST_model"] is None or st.session_state["single_DIST_model_cfg"] is None:
            st.session_state["single_DIST_model"], st.session_state["single_DIST_model_cfg"] = find_and_load_model(
                model_date=st.session_state["DIST_MODEL_DATE_WITH_NORM"]
            )
            if logger is not None:
                logger.info("Model is None, reiniting model")
        # Read the model after a possible re-initialisation; previously the
        # re-init path left ``model`` and ``loader`` unbound.
        model = st.session_state["single_DIST_model"]
        model_cfg = st.session_state["single_DIST_model_cfg"]
    else:
        model = models_dict["single_DIST_model"]
        model_cfg = models_dict["single_DIST_model_cfg"]

    if model is None or model_cfg is None:
        if logger is not None:
            logger.warning("DIST model could not be loaded, skipping DIST correction")
        return dffix

    loader, _ = prep_data_for_dist(model_cfg, dffix, trial)
    batch = next(iter(loader))

    if "cpu" not in str(model.device):
        # Move to the model's own device rather than the default CUDA device.
        batch = [x.to(model.device) for x in batch]
    try:
        out = model(batch)
        preds, _ = logits_to_pred(out, y=None)
        if logger is not None:
            logger.debug(f"y_char_unique are {trial['y_char_unique']} for trial {trial['trial_id']}")
            logger.debug(f"trial keys are {trial.keys()} for trial {trial['trial_id']}")
            logger.debug(f"chars_list has len {len(trial['chars_list'])} for trial {trial['trial_id']}")
            logger.debug(f"y_char_unique {trial['y_char_unique']} for trial {trial['trial_id']}")
        if len(trial["y_char_unique"]) < 1:
            y_char_unique = pd.DataFrame(trial["chars_list"]).char_y_center.sort_values().unique()
        else:
            y_char_unique = trial["y_char_unique"]
        num_lines = trial["num_char_lines"] - 1
        # reshape(-1) keeps a 1-D array even for a single fixation, where
        # squeeze() would give a 0-d array that cannot be iterated.
        preds = t.clamp(preds, 0, num_lines).reshape(-1).cpu().numpy()
        y_pred_DIST = [y_char_unique[idx] for idx in preds]

        dffix[f"line_num_{algo_choice}"] = preds
        dffix[f"y_{algo_choice}"] = np.round(y_pred_DIST, decimals=1)
        dffix[f"y_{algo_choice}_correction"] = (dffix.loc[:, f"y_{algo_choice}"] - dffix.loc[:, "y"]).round(1)
    except Exception as e:
        # Best effort: a failing model must not abort the other algorithms.
        if logger is not None:
            logger.warning(f"Exception on model(batch) for DIST \n{e}")
    return dffix
|
|
|
|
def get_DIST_ensemble_preds(
    dffix,
    trial,
    model_cfg_without_norm_df,
    model_cfg_with_norm_df,
    ensemble_model_avg,
):
    """Add the DIST-Ensemble line assignment to ``dffix``.

    The fixations are prepared once per model configuration, both batches are
    passed to ``ensemble_model_avg`` and its ``out_avg`` output is converted
    into line indices. Writes the columns ``line_num_DIST-Ensemble``,
    ``y_DIST-Ensemble`` and ``y_DIST-Ensemble_correction``.

    Args:
        dffix: Fixation dataframe with at least a ``y`` column.
        trial: Trial dict providing ``y_char_unique``, ``num_char_lines``,
            ``chars_list`` and ``trial_id``.
        model_cfg_without_norm_df, model_cfg_with_norm_df: Configurations for
            the two groups of ensemble members.
        ensemble_model_avg: The ensemble model.

    Returns:
        ``dffix`` with the new columns.
    """
    algo_choice = "DIST-Ensemble"
    loader_without_norm, _ = prep_data_for_dist(model_cfg_without_norm_df, dffix, trial)
    loader_with_norm, _ = prep_data_for_dist(model_cfg_with_norm_df, dffix, trial)
    batch_without_norm = next(iter(loader_without_norm))
    batch_with_norm = next(iter(loader_with_norm))
    out = ensemble_model_avg((batch_without_norm, batch_with_norm))
    preds, _ = logits_to_pred(out[0]["out_avg"], y=None)

    if len(trial["y_char_unique"]) < 1:
        y_char_unique = pd.DataFrame(trial["chars_list"]).char_y_center.sort_values().unique()
    else:
        y_char_unique = trial["y_char_unique"]
    num_lines = trial["num_char_lines"] - 1
    # reshape(-1) keeps a 1-D array even for a single fixation, where
    # squeeze() would give a 0-d array and the list comprehension would fail.
    preds = t.clamp(preds, 0, num_lines).reshape(-1).cpu().numpy()
    if "logger" in st.session_state:
        st.session_state["logger"].debug(f"preds are {preds} for trial {trial['trial_id']}")
    y_pred_DIST = [y_char_unique[idx] for idx in preds]

    dffix[f"line_num_{algo_choice}"] = preds
    dffix[f"y_{algo_choice}"] = np.round(y_pred_DIST, decimals=1)
    dffix[f"y_{algo_choice}_correction"] = (dffix.loc[:, f"y_{algo_choice}"] - dffix.loc[:, "y"]).round(1)
    return dffix
|
|
|
|
def get_EDIST_preds_with_model_check(dffix, trial, ensemble_model_avg=None, models_dict=None):
    """Run the DIST-Ensemble correction, building the ensemble if necessary.

    Model resolution order:

    1. ``models_dict`` (keys ``model_cfg_without_norm_df``,
       ``model_cfg_with_norm_df``, ``ensemble_model_avg``) if given.
    2. The ensemble stored in the Streamlit session state.
    3. The ``ensemble_model_avg`` argument, with the configurations taken
       from the session state.
    4. A fresh ensemble built from the checkpoints in ``DIST_MODELS_FOLDER``
       and cached in the session state.

    Returns:
        ``dffix`` with the DIST-Ensemble columns added.
    """
    if models_dict is not None:
        return get_DIST_ensemble_preds(
            dffix,
            trial,
            models_dict["model_cfg_without_norm_df"],
            models_dict["model_cfg_with_norm_df"],
            models_dict["ensemble_model_avg"],
        )

    if ensemble_model_avg is None and "ensemble_model_avg" not in st.session_state:
        if "logger" in st.session_state:
            st.session_state["logger"].info("Ensemble Model is None, reiniting model")
        dist_models_with_norm = DIST_MODELS_FOLDER.glob("*normalize_by_line_height_and_width_True*.ckpt")
        dist_models_without_norm = DIST_MODELS_FOLDER.glob("*normalize_by_line_height_and_width_False*.ckpt")

        loaded_without_norm = [find_and_load_model(m_file.stem.split("_")[1]) for m_file in dist_models_without_norm]
        loaded_with_norm = [find_and_load_model(m_file.stem.split("_")[1]) for m_file in dist_models_with_norm]

        model_cfg_without_norm_df = [x[1] for x in loaded_without_norm if x[1] is not None][0]
        model_cfg_with_norm_df = [x[1] for x in loaded_with_norm if x[1] is not None][0]

        models_without_norm_df = [x[0] for x in loaded_without_norm if x[0] is not None]
        models_with_norm_df = [x[0] for x in loaded_with_norm if x[0] is not None]

        ensemble_model_avg = EnsembleModel(
            models_without_norm_df, models_with_norm_df, learning_rate=0.0, use_simple_average=True
        )
        st.session_state["ensemble_model_avg"] = ensemble_model_avg
        st.session_state["model_cfg_without_norm_df"] = model_cfg_without_norm_df
        st.session_state["model_cfg_with_norm_df"] = model_cfg_with_norm_df
    else:
        model_cfg_without_norm_df = st.session_state["model_cfg_without_norm_df"]
        model_cfg_with_norm_df = st.session_state["model_cfg_with_norm_df"]
        # The session's ensemble wins, as before. A passed-in ensemble is now
        # used when the session has none, instead of raising KeyError.
        if "ensemble_model_avg" in st.session_state:
            ensemble_model_avg = st.session_state["ensemble_model_avg"]

    return get_DIST_ensemble_preds(
        dffix,
        trial,
        model_cfg_without_norm_df,
        model_cfg_with_norm_df,
        ensemble_model_avg,
    )
|
|
|
|
def correct_df(
    dffix,
    algo_choice,
    trial=None,
    for_multi=False,
    ensemble_model_avg=None,
    is_outside_of_streamlit=False,
    classic_algos_cfg=None,
    models_dict=None,
):
    """Apply one or several line-assignment algorithms to the fixations.

    Args:
        dffix: Fixation dataframe (or a dict holding it under ``"value"``)
            with ``x`` and ``y`` columns.
        algo_choice: Algorithm name or list of names. ``"DIST"``,
            ``"DIST-Ensemble"`` and the ``"Wisdom_of_Crowds*"`` variants are
            handled here; any other name is looked up in ``classic_algos_cfg``
            and passed to ``calgo.apply_classic_algo``.
        trial: Trial dict; taken from the session state when omitted and
            ``for_multi`` is false.
        for_multi: Return only the dataframe (no CSV export).
        ensemble_model_avg: Optional ensemble, used by
            ``"Wisdom_of_Crowds_with_DIST_Ensemble"``.
        is_outside_of_streamlit: Use plain ``tqdm`` instead of ``stqdm``.
        classic_algos_cfg: Per-algorithm configuration; taken from the session
            state when omitted.
        models_dict: Optional pre-loaded models, passed to the DIST helpers.

    Returns:
        ``dffix`` if ``for_multi`` is true or if an ``x``/``y`` column is
        missing; otherwise ``(dffix, export_csv(dffix, trial))``.
    """
    if is_outside_of_streamlit:
        stqdm = tqdm
    else:
        from stqdm import stqdm
    if classic_algos_cfg is None:
        classic_algos_cfg = st.session_state["classic_algos_cfg"]
    if trial is None and not for_multi:
        trial = st.session_state["trial"]
    logger = st.session_state["logger"] if "logger" in st.session_state else None
    if logger is not None:
        # ``trial`` may still be None here when ``for_multi`` is set.
        trial_id = trial["trial_id"] if trial is not None else None
        logger.info(f"Applying {algo_choice} to fixations for trial {trial_id}")

    if isinstance(dffix, dict):
        dffix = dffix["value"]
    # The second test used to repeat "x", so a missing "y" went unnoticed.
    if "x" not in dffix.keys() or "y" not in dffix.keys():
        if logger is not None:
            logger.warning("x or y not in dffix")
            logger.warning(dffix.columns)
        return dffix

    algo_choices = algo_choice if isinstance(algo_choice, list) else [algo_choice]
    for algo_choice in stqdm(algo_choices, desc="Applying correction algorithms"):
        st_proc = time.process_time()
        st_wall = time.time()

        if algo_choice == "DIST":
            dffix = get_DIST_preds(dffix, trial, models_dict=models_dict)
        elif algo_choice == "DIST-Ensemble":
            dffix = get_EDIST_preds_with_model_check(dffix, trial, models_dict=models_dict)
        elif algo_choice == "Wisdom_of_Crowds_with_DIST":
            dffix, corrections = get_all_classic_preds(dffix, trial, classic_algos_cfg)
            dffix = get_DIST_preds(dffix, trial, models_dict=models_dict)
            # The DIST result is appended three times to the list of votes.
            for _ in range(3):
                corrections.append(np.asarray(dffix.loc[:, "y_DIST"]))
            dffix = apply_woc(dffix, trial, corrections, algo_choice)
        elif algo_choice == "Wisdom_of_Crowds_with_DIST_Ensemble":
            dffix, corrections = get_all_classic_preds(dffix, trial, classic_algos_cfg)
            dffix = get_EDIST_preds_with_model_check(dffix, trial, ensemble_model_avg, models_dict=models_dict)
            for _ in range(3):
                corrections.append(np.asarray(dffix.loc[:, "y_DIST-Ensemble"]))
            dffix = apply_woc(dffix, trial, corrections, algo_choice)
        elif algo_choice == "Wisdom_of_Crowds":
            dffix, corrections = get_all_classic_preds(dffix, trial, classic_algos_cfg)
            dffix = apply_woc(dffix, trial, corrections, algo_choice)
        else:
            algo_cfg = classic_algos_cfg[algo_choice]
            dffix = calgo.apply_classic_algo(dffix, trial, algo_choice, algo_cfg)
            # NOTE(review): the correction column is computed for the classic
            # branch only; the nesting was reconstructed from unindented
            # source - confirm.
            dffix[f"y_{algo_choice}_correction"] = (dffix.loc[:, f"y_{algo_choice}"] - dffix.loc[:, "y"]).round(1)

        time_proc = time.process_time() - st_proc
        time_wall = time.time() - st_wall
        if logger is not None:
            logger.info(f"time_proc {algo_choice} {time_proc}")
            logger.info(f"time_wall {algo_choice} {time_wall}")

    if for_multi:
        return dffix
    if "start_time" in dffix.columns:
        # Drop only the timing columns that exist, so a missing "end_time"
        # does not raise.
        timing_cols = [col for col in ("start_time", "end_time") if col in dffix.columns]
        dffix = dffix.drop(axis=1, labels=timing_cols)
    return dffix, export_csv(dffix, trial)
|
|
def set_font_from_chars_list(trial):
    """Estimate the font size of a trial from its line spacing.

    The line spacing is the smallest distance between consecutive unique
    ``char_y_center`` values of ``trial["chars_list"]``, rounded to the
    nearest 0.5. The font size is that spacing times 0.333, rounded to the
    nearest 0.25. If no spacing can be derived (no ``chars_list``, an empty
    one, or text on a single line) the default spacing ``1 / 0.333 * 18`` is
    used, which yields a font size of 18.

    Args:
        trial: Trial dict, optionally holding ``"chars_list"``.

    Returns:
        The estimated font size as a float.
    """
    y_diff = 1 / 0.333 * 18
    if "chars_list" in trial:
        chars_df = pd.DataFrame(trial["chars_list"])
        if "char_y_center" in chars_df.columns:
            # np.unique sorts, so the differences are positive line distances
            # whatever the order of the characters.
            line_diffs = np.diff(np.unique(chars_df.char_y_center))
            # Single-line text has no spacing; keep the default instead of
            # calling np.min on an empty array.
            if len(line_diffs) > 0:
                # NOTE(review): the 0.5 rounding is applied to every derived
                # spacing; the nesting was reconstructed from unindented
                # source - confirm.
                y_diff = round(np.min(line_diffs) * 2) / 2
    font_size = y_diff * 0.333
    return round(font_size * 4, ndigits=0) / 4
|
|
def get_font_and_font_size_from_trial(trial):
    """Return the font face and font size to use for plotting a trial.

    The values come from ``get_plot_props``. If it yields no font size, the
    trial's ``"font_size"`` entry is used; failing that, the size is
    estimated by ``set_font_from_chars_list``.

    Returns:
        ``(font_face, font_size)``.
    """
    font_face, font_size, _dpi, _screen_res = get_plot_props(trial, AVAILABLE_FONTS)
    if font_size is not None:
        return font_face, font_size
    if "font_size" in trial:
        return font_face, trial["font_size"]
    return font_face, set_font_from_chars_list(trial)
|
|
|
|
def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``, element-wise."""
    decay = np.exp(-1 * x)
    return 1 / (1 + decay)
|
|
|
|
def _draw_fixation_path(ax, x_values, y_values, label, color):
    """Plot a fixation sequence as a thin line with an arrow head per step."""
    ax.plot(x_values, y_values, label=label, color=color, alpha=0.6, linewidth=0.6)

    xs = x_values.values
    ys = y_values.values
    x_steps = xs[1:] - xs[:-1]
    y_steps = ys[1:] - ys[:-1]
    for x_start, y_start, x_step, y_step in zip(xs[:-1], ys[:-1], x_steps, y_steps):
        # A very short arrow at the start of each segment shows the direction
        # of travel without covering the line itself.
        ax.annotate(
            "",
            xytext=(x_start, y_start),
            xy=(x_start + 0.001 * x_step, y_start + 0.001 * y_step),
            arrowprops=dict(arrowstyle="fancy", color=color),
            size=8,
            alpha=0.3,
        )


def matplotlib_plot_df(
    dffix,
    trial,
    algo_choice,
    stimulus_prefix="word",
    desired_dpi=300,
    fix_to_plot=[],
    stim_info_to_plot=["Words", "Word boxes"],
    box_annotations=None,
):
    """Render the stimulus and, optionally, fixation paths at screen size.

    The figure is sized so that one data unit is one pixel of
    ``trial["display_coords"]`` (1920x1080 if absent), with no margins and an
    inverted y axis.

    Args:
        dffix: Fixation dataframe with ``x``, ``y`` and optionally
            ``y_{algo}`` columns.
        trial: Trial dict; ``chars_list``, ``{stimulus_prefix}s_list``,
            ``display_coords``, ``font`` and ``font_size`` are used if present.
        algo_choice: Algorithm name or list of names whose corrected
            fixations may be drawn.
        stimulus_prefix: Prefix of the stimulus list used for the boxes.
        desired_dpi: Figure resolution.
        fix_to_plot: May contain ``"Uncorrected Fixations"`` and/or
            ``"Corrected Fixations"``.
        stim_info_to_plot: May contain ``"Words"`` and/or ``"Word boxes"``.
        box_annotations: Optional annotations for the stimulus boxes.

    Returns:
        ``(fig, width_in_pixels, height_in_pixels)``.
    """
    if "chars_list" in trial:
        font_face, font_size = get_font_and_font_size_from_trial(trial)
        font_size = font_size * 0.65
    else:
        st.warning("No character or word information available to plot")

    if "display_coords" in trial:
        desired_width_in_pixels = trial["display_coords"][2] + 1
        desired_height_in_pixels = trial["display_coords"][3] + 1
    else:
        desired_width_in_pixels = 1920
        desired_height_in_pixels = 1080

    figure_width = desired_width_in_pixels / desired_dpi
    figure_height = desired_height_in_pixels / desired_dpi

    fig = plt.figure(figsize=(figure_width, figure_height), dpi=desired_dpi)
    ax = fig.add_subplot(1, 1, 1)
    # Let the axis fill the whole canvas so pixels map 1:1 to data units.
    fig.subplots_adjust(bottom=0)
    fig.subplots_adjust(top=1)
    fig.subplots_adjust(right=1)
    fig.subplots_adjust(left=0)

    if "font" in trial and trial["font"] in AVAILABLE_FONTS:
        font_to_use = trial["font"]
    else:
        font_to_use = "DejaVu Sans Mono"
    # NOTE(review): this overrides the font size derived above; the nesting
    # was reconstructed from unindented source - confirm.
    if "font_size" in trial:
        font_size = trial["font_size"]
    else:
        font_size = 20

    if f"{stimulus_prefix}s_list" in trial:
        add_text_to_ax(
            trial[f"{stimulus_prefix}s_list"],
            ax,
            font_to_use,
            prefix=stimulus_prefix,
            fontsize=font_size / 3.89,
            plot_text=False,
            plot_boxes="Word boxes" in stim_info_to_plot,
            box_annotations=box_annotations,
        )

    if "chars_list" in trial:
        add_text_to_ax(
            trial["chars_list"],
            ax,
            font_to_use,
            prefix="char",
            fontsize=font_size / 3.89,
            plot_text="Words" in stim_info_to_plot,
            plot_boxes=False,
            box_annotations=None,
        )

    if "Uncorrected Fixations" in fix_to_plot:
        _draw_fixation_path(ax, dffix.x, dffix.y, label="Raw fixations", color="blue")

    if "Corrected Fixations" in fix_to_plot:
        algo_choices = algo_choice if isinstance(algo_choice, list) else [algo_choice]
        for algo_idx, algo_name in enumerate(algo_choices):
            corrected_col = f"y_{algo_name}"
            if corrected_col in dffix.columns:
                # Corrected traces used to be labelled "Raw fixations" as well.
                _draw_fixation_path(
                    ax,
                    dffix.x,
                    dffix.loc[:, corrected_col],
                    label=f"{algo_name} corrected",
                    color=COLORS[algo_idx],
                )

    ax.set_xlim((0, desired_width_in_pixels))
    ax.set_ylim((0, desired_height_in_pixels))
    ax.invert_yaxis()

    return fig, desired_width_in_pixels, desired_height_in_pixels
|
|
|
|
def plotly_plot_with_image(
    dffix,
    trial,
    algo_choice,
    to_plot_list=["Uncorrected Fixations", "Words", "corrected fixations", "Word boxes"],
    scale_factor=0.5,
):
    """Build an interactive Plotly figure of fixations over the stimulus image.

    The stimulus is rendered with ``matplotlib_plot_df``, saved to
    ``TEMP_FIGURE_STIMULUS_PATH`` and used as the background image. All
    coordinates are multiplied by ``scale_factor``.

    Args:
        dffix: Fixation dataframe with ``x``, ``y``, ``duration`` and
            optionally ``y_{algo}`` columns.
        trial: Trial dict describing the stimulus.
        algo_choice: Algorithm name or list of names to show.
        to_plot_list: Any of ``"Uncorrected Fixations"``,
            ``"Corrected Fixations"`` (matched case-insensitively),
            ``"Words"`` and ``"Word boxes"``.
        scale_factor: Scale applied to the pixel coordinates and sizes.

    Returns:
        The ``plotly.graph_objects.Figure``.
    """
    mpl_fig, img_width, img_height = matplotlib_plot_df(
        dffix, trial, algo_choice, desired_dpi=300, fix_to_plot=[], stim_info_to_plot=to_plot_list
    )
    try:
        TEMP_FIGURE_STIMULUS_PATH.parent.mkdir(parents=True, exist_ok=True)
        mpl_fig.savefig(TEMP_FIGURE_STIMULUS_PATH)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(mpl_fig)

    scaled_width = img_width * scale_factor
    scaled_height = img_height * scale_factor

    fig = go.Figure()
    # Invisible helper trace spanning the image so the axes cover it fully.
    fig.add_trace(
        go.Scatter(
            x=[0, scaled_width],
            y=[scaled_height, 0],
            mode="markers",
            marker_opacity=0,
            name="scale_helper",
            showlegend=False,
        )
    )
    fig.update_xaxes(visible=False, range=[0, scaled_width])
    fig.update_yaxes(
        visible=False,
        range=[scaled_height, 0],
        scaleanchor="x",
    )

    if "Words" in to_plot_list or "Word boxes" in to_plot_list:
        # The context manager closes the file handle once the image is added.
        with Image.open(str(TEMP_FIGURE_STIMULUS_PATH)) as imsource:
            fig.add_layout_image(
                dict(
                    x=0,
                    sizex=scaled_width,
                    y=0,
                    sizey=scaled_height,
                    xref="x",
                    yref="y",
                    opacity=1.0,
                    layer="below",
                    sizing="stretch",
                    source=imsource,
                )
            )

    if "Uncorrected Fixations" in to_plot_list:
        duration_offset = dffix.duration - dffix.duration.min()
        duration_span = duration_offset.max()
        if duration_span > 0:
            duration_norm = duration_offset / duration_span
        else:
            # Identical durations (or one fixation) would give 0/0 = NaN
            # marker sizes; use the mid-point instead.
            duration_norm = duration_offset * 0 + 0.5
        duration = sigmoid((duration_norm - 0.5) * 3) * 50 * scale_factor
        fig.add_trace(
            go.Scatter(
                x=dffix.x * scale_factor,
                y=dffix.y * scale_factor,
                mode="markers+lines+text",
                name="Raw fixations",
                marker=dict(
                    color=COLORS[-1],
                    symbol="arrow",
                    size=duration.values,
                    angleref="previous",
                    line=dict(color="black", width=duration.values / 10),
                ),
                line_width=2 * scale_factor,
                text=np.arange(len(dffix.x)),
                textposition="middle right",
                textfont=dict(
                    family="sans serif",
                    size=18 * scale_factor,
                ),
                hoverinfo="text+x+y",
                opacity=0.9,
            )
        )

    # The default list spells this entry in lower case, so an exact match on
    # "Corrected Fixations" never drew the corrected fixations by default.
    plot_corrected = any(str(item).lower() == "corrected fixations" for item in to_plot_list)
    if plot_corrected:
        algo_choices = algo_choice if isinstance(algo_choice, list) else [algo_choice]
        for algo_idx, algo_name in enumerate(algo_choices):
            if f"y_{algo_name}" in dffix.columns:
                fig.add_trace(
                    go.Scatter(
                        x=dffix.x * scale_factor,
                        y=dffix.loc[:, f"y_{algo_name}"] * scale_factor,
                        mode="markers",
                        name=f"{algo_name} corrected",
                        marker_color=COLORS[algo_idx],
                        marker_size=10 * scale_factor,
                        hoverinfo="text+x+y",
                        opacity=0.75,
                    )
                )

    fig.update_layout(
        plot_bgcolor=None,
        width=scaled_width,
        height=scaled_height,
        margin={"l": 0, "r": 0, "t": 0, "b": 0},
        legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="right", x=0.8),
    )
    return fig
|
|
|
|
def plot_y_corr(dffix, algo_choice, margin=dict(t=40, l=10, r=10, b=1)):
    """Plot the per-fixation y correction of one or more algorithms.

    Args:
        dffix: Fixation dataframe, or a dict holding the dataframe under the
            key ``"value"``. The ``x`` column gives the number of fixations and
            ``y_<algo>_correction`` columns give the plotted values.
        algo_choice: Name of a correction algorithm, or a list of names.
        margin: Plotly layout margin. The default dict is only read here,
            never mutated.

    Returns:
        A ``plotly.graph_objects.Figure`` with one marker trace per algorithm
        whose correction column exists. If no such column exists, a warning is
        logged and an empty figure with the same layout is returned.
    """
    # Unwrap first: the fixation count below needs the dataframe itself.
    if isinstance(dffix, dict):
        dffix = dffix["value"]
    num_datapoints = len(dffix.x)

    layout = dict(
        plot_bgcolor="white",
        autosize=True,
        margin=margin,
        xaxis=dict(
            title="Fixation Index",
            linecolor="black",
            range=[-1, num_datapoints + 1],
            showgrid=False,
            mirror="all",
            showline=True,
        ),
        yaxis=dict(
            title="y correction",
            side="left",
            linecolor="black",
            showgrid=False,
            mirror="all",
            showline=True,
        ),
        legend=dict(orientation="v", yanchor="middle", y=0.95, xanchor="left", x=1.05),
    )
    fig = go.Figure(layout=layout)

    algo_choices = algo_choice if isinstance(algo_choice, list) else [algo_choice]
    # Keep the original position of each algorithm so its colour stays stable
    # even when another algorithm's column is missing.
    available = [
        (algo_idx, algo_name)
        for algo_idx, algo_name in enumerate(algo_choices)
        if f"y_{algo_name}_correction" in dffix.columns
    ]
    if not available:
        st.session_state["logger"].warning("No correction column found in dataframe")
        return fig

    for algo_idx, algo_name in available:
        fig.add_trace(
            go.Scatter(
                x=np.arange(num_datapoints),
                y=dffix.loc[:, f"y_{algo_name}_correction"],
                mode="markers",
                name=f"{algo_name} y correction",
                marker_color=COLORS[algo_idx],
                marker_size=3,
                showlegend=True,
            )
        )
    fig.update_yaxes(zeroline=True, zerolinewidth=1, zerolinecolor="black")

    return fig
|
|
|
|
def download_example_ascs(EXAMPLES_FOLDER, EXAMPLES_ASC_ZIP_FILENAME, OSF_DOWNLAOD_LINK, EXAMPLES_FOLDER_PATH):
    """Make sure the example ``.asc`` files are available and return their paths.

    The archive is downloaded with ``download_url`` if it is not on disk yet.
    It is then extracted into ``EXAMPLES_FOLDER`` unless ``EXAMPLES_FOLDER_PATH``
    already holds the expected number of ``.asc`` files.

    Args:
        EXAMPLES_FOLDER: Directory the archive is extracted into; created
            (including parents) if missing.
        EXAMPLES_ASC_ZIP_FILENAME: Path of the zip archive on disk.
        OSF_DOWNLAOD_LINK: URL passed to ``download_url`` when the archive is missing.
        EXAMPLES_FOLDER_PATH: ``pathlib.Path`` of the directory that is searched
            for ``*.asc`` files.

    Returns:
        List of paths matching ``*.asc`` in ``EXAMPLES_FOLDER_PATH`` (empty if
        the directory does not exist or extraction failed).
    """
    expected_asc_count = 4

    os.makedirs(EXAMPLES_FOLDER, exist_ok=True)

    if not os.path.exists(EXAMPLES_ASC_ZIP_FILENAME):
        download_url(OSF_DOWNLAOD_LINK, EXAMPLES_ASC_ZIP_FILENAME)

    if os.path.exists(EXAMPLES_ASC_ZIP_FILENAME):
        # A missing target directory also means "not extracted yet"; it may only
        # come into existence through the extraction itself.
        if EXAMPLES_FOLDER_PATH.is_dir():
            present_count = len(list(EXAMPLES_FOLDER_PATH.glob("*.asc")))
        else:
            present_count = 0
        if present_count != expected_asc_count:
            try:
                with zipfile.ZipFile(EXAMPLES_ASC_ZIP_FILENAME, "r") as zip_ref:
                    zip_ref.extractall(EXAMPLES_FOLDER)
            except Exception as e:
                # Best effort: a broken archive must not take the app down.
                st.session_state["logger"].warning(e)
                st.session_state["logger"].warning(f"Extracting {EXAMPLES_ASC_ZIP_FILENAME} failed")

    if not EXAMPLES_FOLDER_PATH.is_dir():
        return []
    return list(EXAMPLES_FOLDER_PATH.glob("*.asc"))
|
|
|
|
def process_trial_choice_single_csv(trial, algo_choice, file=None):
    """Prepare a trial loaded from a single CSV file for plotting and correction.

    If the trial has no ``"dffix"`` entry yet, the fixation dataframe is taken
    from ``st.session_state["trials_by_ids_single_csv"]`` and the trial's
    ``"plot_file"`` and ``"fname"`` entries are filled from ``file`` (or from
    ``st.session_state["single_csv_file"]`` when ``file`` is ``None``).
    Character information is stored on the trial, and the fixations are
    corrected with ``correct_df`` when an algorithm is given.

    Args:
        trial: Trial dict; ``"trial_id"`` and ``"chars_list"`` are read and the
            dict is updated in place.
        algo_choice: Correction algorithm(s) forwarded to ``correct_df``;
            ``None`` skips the correction.
        file: Object with a ``name`` attribute identifying the source file.

    Returns:
        Tuple ``(dffix, trial, dpi, screen_res, font, font_size)``.
    """
    trial_id = trial["trial_id"]

    if "dffix" not in trial:
        if file is None:
            file = st.session_state["single_csv_file"]
        plot_filename = f"{file.name}_{trial_id}_2ndInput_chars_channel_sep.png"
        trial["plot_file"] = str(PLOTS_FOLDER.joinpath(plot_filename))
        trial["fname"] = str(file.name)
        dffix = st.session_state["trials_by_ids_single_csv"][trial_id]["dffix"]
        trial["dffix"] = dffix
    else:
        dffix = trial["dffix"]

    plot_props = get_plot_props(trial, AVAILABLE_FONTS)
    font, font_size, dpi, screen_res = plot_props

    chars_df = pd.DataFrame(trial["chars_list"])
    trial["chars_df"] = chars_df.to_dict()
    sorted_line_centers = chars_df.char_y_center.sort_values().unique()
    trial["y_char_unique"] = list(sorted_line_centers)

    if algo_choice is not None:
        corrected = correct_df(dffix, algo_choice, trial)
        dffix, _ = corrected

    return dffix, trial, dpi, screen_res, font, font_size
|
|
|
|
def add_default_font_and_character_props_to_state(trial):
    """Derive default text-layout properties from a trial's character list.

    Args:
        trial: Trial dict with a ``"chars_list"`` entry whose items provide
            ``"char_xmin"`` and ``"char_y_center"``. It is also passed to
            ``get_font_and_font_size_from_trial``.

    Returns:
        Tuple ``(y_diff, x_txt_start, y_txt_start, font_face, font_size,
        line_height)``. ``y_diff`` is the smallest gap between distinct line
        centres, rounded to the nearest 0.5; ``line_height`` is the same value.

    Raises:
        ValueError: If the characters lie on fewer than two distinct lines, so
            no line spacing can be inferred.
    """
    chars_list = trial["chars_list"]
    chars_df = pd.DataFrame(chars_list)

    # Sort the line centres so the gaps are positive even when the characters
    # are not listed top to bottom.
    line_centers = np.sort(chars_df.char_y_center.unique())
    line_gaps = np.diff(line_centers)
    if line_gaps.size == 0:
        raise ValueError(
            "Cannot infer line spacing: 'chars_list' has fewer than two distinct 'char_y_center' values"
        )
    y_diff = round(np.min(line_gaps) * 2) / 2

    x_txt_start = chars_list[0]["char_xmin"]
    y_txt_start = chars_list[0]["char_y_center"]

    font_face, font_size = get_font_and_font_size_from_trial(trial)

    line_height = y_diff
    return y_diff, x_txt_start, y_txt_start, font_face, font_size, line_height
|
|
def get_all_measures(trial, dffix, prefix, use_corrected_fixations=True, correction_algo="warp"):
    """Collect the ``*_own`` measures from ``analysis_funcs`` into one dataframe.

    Args:
        trial: Trial object forwarded to every measure function.
        dffix: Fixation dataframe. With ``use_corrected_fixations`` a deep copy
            is made and its ``y`` column is replaced by ``y_<correction_algo>``;
            otherwise the dataframe is passed on as is.
        prefix: Name of the unit column; each measure result is indexed by
            ``f"{prefix}_index"``.
        use_corrected_fixations: Whether to use the corrected y positions.
        correction_algo: Algorithm whose ``y_<algo>`` column supplies them.

    Returns:
        Dataframe with ``f"{prefix}_num"`` and ``prefix`` as the first two
        columns, followed by the measure columns.
    """
    if use_corrected_fixations:
        fixations = copy.deepcopy(dffix)
        fixations["y"] = fixations[f"y_{correction_algo}"]
    else:
        fixations = dffix

    index_column = f"{prefix}_index"

    # Order in which the measure functions are invoked.
    call_order = (
        "initial_landing_position",
        "second_pass_duration",
        "number_of_fixations",
        "initial_fixation_duration",
        "first_of_many_duration",
        "total_fixation_duration",
        "gaze_duration",
        "go_past_duration",
        "initial_landing_distance",
        "landing_distances",
        "number_of_regressions_in",
    )
    measures = {}
    for measure_name in call_order:
        measure_func = getattr(anf, f"{measure_name}_own")
        measures[measure_name] = measure_func(trial, fixations, prefix).set_index(index_column)

    # Order in which the measure columns appear in the result.
    column_order = (
        "number_of_fixations",
        "initial_fixation_duration",
        "first_of_many_duration",
        "total_fixation_duration",
        "gaze_duration",
        "go_past_duration",
        "second_pass_duration",
        "initial_landing_position",
        "initial_landing_distance",
        "landing_distances",
        "number_of_regressions_in",
    )
    own_measure_df = pd.concat(
        [measures[measure_name].drop(prefix, axis=1) for measure_name in column_order],
        axis=1,
    )

    # Re-attach the unit column once and move it to the front.
    own_measure_df[prefix] = measures["number_of_fixations"][prefix]
    own_measure_df.insert(0, prefix, own_measure_df.pop(prefix))
    own_measure_df.insert(0, f"{prefix}_num", np.arange(own_measure_df.shape[0]))
    return own_measure_df