Instructions to use yuneun92/koCSN_SAPR with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use yuneun92/koCSN_SAPR with Transformers:
  ```python
  # Use a pipeline as a high-level helper
  from transformers import pipeline

  pipe = pipeline("text-classification", model="yuneun92/koCSN_SAPR")
  ```

  ```python
  # Load model directly
  from transformers import AutoModel

  model = AutoModel.from_pretrained("yuneun92/koCSN_SAPR", dtype="auto")
  ```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Author: | |
| """ | |
| import copy | |
| from typing import Any | |
| from ckonlpy.tag import Twitter | |
| from tqdm import tqdm | |
| import re | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| from sklearn.model_selection import train_test_split | |
| # Shared morphological analyzer (ckonlpy `Twitter`) that accepts user-added dictionary entries; the loader functions in this file register every character alias with it as a 'Noun' via add_dictionary(). | |
| twitter = Twitter() | |
def load_data(filename) -> Any:
    """
    Load and return the object serialized in the given file.

    Parameters:
        - filename: location of the saved data, passed unchanged to ``torch.load``

    Returns:
        - the deserialized object, exactly as ``torch.load`` produces it

    NOTE(review): depending on the installed torch version and its
    ``weights_only`` default, ``torch.load`` may unpickle arbitrary objects,
    so only load files that come from a trusted source.
    """
    loaded = torch.load(filename)
    return loaded
def NML(seg_sents, mention_positions, ws):
    """
    Nearest Mention Location: among all positions where one candidate speaker
    is mentioned, return the one closest to the quote.

    The sentence at index ``ws`` is treated as the quote. Distance is counted
    in words:
      * mention before the quote: words of the sentences strictly between the
        mention's sentence and the quote, plus the words that follow the
        mention inside its own sentence;
      * mention after the quote: words of the sentences strictly between the
        quote and the mention's sentence, plus the words that precede the
        mention inside its own sentence;
      * mention inside the quote sentence itself: the fixed value ``ws * 2``.

    Parameters:
        - seg_sents: list of sentences, each sentence a list of words
        - mention_positions: every mention of the candidate,
          [(sentence_index, word_index), ...]
        - ws: index of the quote sentence inside seg_sents (the window size)

    Returns:
        - the element of mention_positions with the smallest distance; on a
          tie the one that appears first in mention_positions wins

    Raises:
        - IndexError: if mention_positions is empty
    """

    def word_dist(pos):
        """Return the word-level distance between the mention at ``pos`` and the quote."""
        sent_idx = pos[0]
        if sent_idx == ws:
            # A mention inside the quote sentence gets a fixed distance.
            return ws * 2
        if sent_idx < ws:
            between = seg_sents[sent_idx + 1:ws]
            remainder = seg_sents[sent_idx][pos[1] + 1:]
        else:
            between = seg_sents[ws + 1:sent_idx]
            remainder = seg_sents[sent_idx][:pos[1]]
        return sum(len(sent) for sent in between) + len(remainder)

    positions = list(mention_positions)
    if not positions:
        # Keep the exception type the previous ``sorted(...)[0]`` produced.
        raise IndexError("mention_positions is empty")

    # min() returns the first minimal element, which is the same element a
    # stable sort would place at index 0, without sorting the whole list.
    return min(positions, key=word_dist)
def max_len_cut(seg_sents, mention_pos, max_len):
    """
    Trim the sentences in place until their total character count fits max_len.

    While the total is too large, one word is removed from the sentence that
    currently has the most characters (the first such sentence on a tie).
    Words are removed from the end of a sentence towards its start. When the
    cut position reaches the mention word, the cut skips it and continues
    with the words before it, and the mention's word index is shifted to
    follow the removals.

    Parameters:
        - seg_sents: list of sentences, each a list of words (modified in place)
        - mention_pos: [sentence_index, word_index] of the mention; must be
          mutable because the word index is updated in place
        - max_len: maximum total number of characters allowed

    Returns:
        - seg_sents: the same list object, after trimming
        - mention_pos: the same list object, with the adjusted word index
    """
    char_totals = [sum(map(len, words)) for words in seg_sents]
    remaining_chars = sum(char_totals)
    # Per sentence: index of the next word to drop (starts at the last word).
    cut_cursor = [len(words) - 1 for words in seg_sents]

    while remaining_chars > max_len:
        # Index of the first sentence with the largest character count.
        target = max(range(len(char_totals)), key=char_totals.__getitem__)
        in_mention_sentence = target == mention_pos[0]

        # Never drop the mention itself: step over it ...
        if in_mention_sentence and cut_cursor[target] == mention_pos[1]:
            cut_cursor[target] -= 1
        # ... and once cutting happens before the mention, its index shifts left.
        # NOTE(review): if the cursor has already run past the start of the
        # sentence (-1), this path removes the mention word itself.
        if in_mention_sentence and cut_cursor[target] < mention_pos[1]:
            mention_pos[1] -= 1

        victim = cut_cursor[target]
        dropped_chars = len(seg_sents[target][victim])
        char_totals[target] -= dropped_chars
        remaining_chars -= dropped_chars

        del seg_sents[target][victim]
        cut_cursor[target] = victim - 1

    return seg_sents, mention_pos
def seg_and_mention_location(raw_sents_in_list, alias2id):
    """
    Split sentences into words and record where each character is mentioned.

    Sentences are split on whitespace. A word counts as a mention when it
    contains a placeholder of the form ``&C<1-2 digits>&`` (e.g. ``&C01&``);
    only the first placeholder inside a word is considered.

    Parameters:
        - raw_sents_in_list: list of raw sentence strings
        - alias2id: dictionary mapping a placeholder/alias string to a character ID

    Returns:
        - seg_sents: list of sentences, each a list of words
        - character_mention_poses: every mention position per character,
          {character_id: [[sent_idx, word_idx], ...]}, in order of appearance
        - name_list_index: list of the mentioned character IDs, in order of
          first appearance

    Raises:
        - KeyError: if a placeholder found in the text is missing from alias2id
    """
    # Compile once instead of looking the pattern up for every word.
    mention_pattern = re.compile(r'&C\d{1,2}&')

    character_mention_poses = {}
    seg_sents = []

    for sent_idx, sent in enumerate(raw_sents_in_list):
        words = sent.split()

        for word_idx, word in enumerate(words):
            match = mention_pattern.search(word)
            if match is None:
                continue
            character_id = alias2id[match.group(0)]
            character_mention_poses.setdefault(character_id, []).append([sent_idx, word_idx])

        seg_sents.append(words)

    name_list_index = list(character_mention_poses)

    return seg_sents, character_mention_poses, name_list_index
def create_CSS(seg_sents, candidate_mention_poses, args):
    """
    Build a candidate-specific segment (CSS) for every candidate speaker of one instance.

    For each candidate, the mention nearest to the quote is located and the
    sentences spanning from that mention to the quote (sentence index
    ``args.ws``) are copied, trimmed to ``args.length_limit`` characters, and
    joined into one string.

    Parameters:
        - seg_sents: list of sentences, each a list of words. It is expected to
          hold 2 * ws + 1 sentences with the quote at index ws, but this is not
          enforced here.
        - candidate_mention_poses: mention positions per candidate,
          {character index: [[sentence index, word index], ...], ...}
        - args: object providing ``ws``, ``length_limit`` and ``model_name``
          ('CSN' or 'KCSN')

    Returns:
        Five lists with one element per candidate, ordered like
        list(candidate_mention_poses.keys()):
        - many_css: the CSS text; words are concatenated without separators
          for 'CSN' and joined with single spaces for 'KCSN'
        - many_sent_char_lens: per CSS, the character length of each sentence
          (sum of word lengths, separators not counted)
        - many_mention_poses: per CSS, a tuple
          (sentence index of the nearest mention,
           character index of its first character,
           character index of its last character + 1)
          and, for 'KCSN' only, a fourth item holding the word index.
          NOTE(review): the character indices are computed from word lengths
          only, so for 'KCSN' they do not include the joining spaces.
        - many_quote_idxes: per CSS, the sentence index of the quote
        - many_cut_css: per CSS, the trimmed sentences as lists of words

    Raises:
        - ValueError: if args.model_name is neither 'CSN' nor 'KCSN' and there
          is at least one candidate to process
    """
    ws = args.ws
    max_len = args.length_limit
    model_name = args.model_name

    many_css = []
    many_sent_char_lens = []
    many_mention_poses = []
    many_quote_idxes = []
    many_cut_css = []

    for mention_positions in candidate_mention_poses.values():
        nearest_pos = NML(seg_sents, mention_positions, ws)

        # Copy the span between the nearest mention and the quote so that the
        # in-place trimming below never touches the caller's sentences.
        if nearest_pos[0] <= ws:
            CSS = copy.deepcopy(seg_sents[nearest_pos[0]:ws + 1])
            mention_pos = [0, nearest_pos[1]]
            quote_idx = ws - nearest_pos[0]
        else:
            CSS = copy.deepcopy(seg_sents[ws:nearest_pos[0] + 1])
            mention_pos = [nearest_pos[0] - ws, nearest_pos[1]]
            quote_idx = 0

        cut_CSS, mention_pos = max_len_cut(CSS, mention_pos, max_len)
        sent_char_lens = [sum(len(word) for word in sent) for sent in cut_CSS]

        # Character span of the mention inside the concatenated CSS.
        mention_sent = cut_CSS[mention_pos[0]]
        mention_pos_left = sum(sent_char_lens[:mention_pos[0]]) + sum(
            len(word) for word in mention_sent[:mention_pos[1]])
        mention_pos_right = mention_pos_left + len(mention_sent[mention_pos[1]])

        if model_name == 'CSN':
            mention_pos = (mention_pos[0], mention_pos_left, mention_pos_right)
            cat_CSS = ''.join([''.join(sent) for sent in cut_CSS])
        elif model_name == 'KCSN':
            mention_pos = (mention_pos[0], mention_pos_left, mention_pos_right, mention_pos[1])
            cat_CSS = ' '.join([' '.join(sent) for sent in cut_CSS])
        else:
            # Previously this fell through and failed on an unbound cat_CSS.
            raise ValueError(
                "Unsupported model_name {!r}: expected 'CSN' or 'KCSN'".format(model_name))

        many_css.append(cat_CSS)
        many_sent_char_lens.append(sent_char_lens)
        many_mention_poses.append(mention_pos)
        many_quote_idxes.append(quote_idx)
        many_cut_css.append(cut_CSS)

    return many_css, many_sent_char_lens, many_mention_poses, many_quote_idxes, many_cut_css
class ISDataset(Dataset):
    """
    Dataset subclass for speaker identification that wraps an in-memory
    sequence of preprocessed instances.
    """

    def __init__(self, data_list):
        """Keep a reference to ``data_list``; its items are served unchanged."""
        super().__init__()
        self.data = data_list

    def __len__(self):
        """Return the number of stored instances."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the instance stored at position ``idx``."""
        return self.data[idx]
def build_data_loader(data_file, alias2id, args, save_name=None) -> DataLoader:
    """
    Parse an instance file and wrap the preprocessed instances in a DataLoader.

    The file is read as consecutive records of 31 lines. Within a record
    (0-based line offset):
        - 0: last token is the instance index
        - 1-21: context sentences (blank ones are dropped before processing)
        - 22: last token is the speaker name/alias
        - 24: last token is the category
        - 25: every token after the first forms the name
        - 26 / 27 / 28: last token is the scene / place / time
        - 29: last token is the cut position; the instance is stored here
        - 23 and 30: ignored

    Parameters:
        - data_file: path of the UTF-8 text file to parse
        - alias2id: dictionary mapping character aliases to IDs; every alias is
          also registered as a 'Noun' in the module-level analyzer
        - args: run arguments forwarded to create_CSS
        - save_name: if given, the list of instances is saved there with torch.save

    Returns:
        - a DataLoader with batch_size=1 that yields each instance tuple as is:
          (seg_sents, css, sent_char_lens, mention_poses, quote_idxes, cut_css,
           one_hot_label, true_index, category, name_list_index, name, scene,
           place, time, cut_position, candidates_list, instance_index)
    """
    # Register every alias in the analyzer's dictionary.
    for alias in alias2id:
        twitter.add_dictionary(alias, 'Noun')

    # Read the file line by line.
    with open(data_file, 'r', encoding='utf-8') as fin:
        data_lines = fin.readlines()

    # Preprocess.
    data_list = []

    for line_no, raw_line in enumerate(tqdm(data_lines)):
        offset = line_no % 31
        text = raw_line.strip()

        if offset == 0:
            # Start of a new record.
            instance_index = text.split()[-1]
            raw_sents_in_list = []
        elif offset < 22:
            raw_sents_in_list.append(text)
        elif offset == 22:
            speaker_name = text.split()[-1]

            # Drop empty sentences.
            filtered_list = [sent for sent in raw_sents_in_list if sent]

            # Split sentences and locate character mentions.
            seg_sents, candidate_mention_poses, name_list_index = seg_and_mention_location(
                filtered_list, alias2id)

            # Build the candidate-specific segments.
            css, sent_char_lens, mention_poses, quote_idxes, cut_css = create_CSS(
                seg_sents, candidate_mention_poses, args)

            # Candidate list.
            candidates_list = list(candidate_mention_poses.keys())

            # One-hot label over the candidates; the index falls back to 0
            # when the speaker is not among them.
            one_hot_label = [1 if character_idx == alias2id[speaker_name] else 0
                             for character_idx in candidate_mention_poses]
            try:
                true_index = one_hot_label.index(1)
            except ValueError:
                true_index = 0
        elif offset == 24:
            category = text.split()[-1]
        elif offset == 25:
            name = ' '.join(text.split()[1:])
        elif offset == 26:
            scene = text.split()[-1]
        elif offset == 27:
            place = text.split()[-1]
        elif offset == 28:
            time = text.split()[-1]
        elif offset == 29:
            cut_position = text.split()[-1]
            instance = (seg_sents, css, sent_char_lens, mention_poses, quote_idxes,
                        cut_css, one_hot_label, true_index, category, name_list_index,
                        name, scene, place, time, cut_position, candidates_list,
                        instance_index)
            data_list.append(instance)

    # Build the data loader; each batch is the single instance itself.
    data_loader = DataLoader(ISDataset(data_list), batch_size=1,
                             collate_fn=lambda batch: batch[0])

    # Persist the instance list when a name was supplied.
    if save_name is not None:
        torch.save(data_list, save_name)

    return data_loader
def load_data_loader(saved_filename: str) -> DataLoader:
    """
    Rebuild a DataLoader from a previously saved list of instances.

    Parameters:
        - saved_filename: path of the file to read through load_data

    Returns:
        - a DataLoader with batch_size=1 that yields each stored instance as is
    """
    # Load the saved instance list and wrap it in the dataset class.
    dataset = ISDataset(load_data(saved_filename))

    return DataLoader(dataset, batch_size=1, collate_fn=lambda batch: batch[0])
def split_train_val_test(data_file, alias2id, args, save_name=None, test_size=0.2, val_size=0.1, random_state=13):
    """
    Parse an instance file, split it into train/validation/test sets and
    build a DataLoader for each.

    The file is read as consecutive records of 31 lines. Within a record
    (0-based line offset): 0 = instance index (last token), 1-21 = context
    sentences (blank ones are dropped), 22 = speaker name (last token),
    24 = category, 25 = name (all tokens after the first), 26 = scene,
    27 = place, 28 = time, 29 = cut position (the instance is stored here);
    offsets 23 and 30 are ignored.

    Parameters:
        - data_file: path of the UTF-8 text file to split
        - alias2id: dictionary mapping character aliases to IDs; every alias is
          also registered as a 'Noun' in the module-level analyzer
        - args: run arguments forwarded to create_CSS
        - save_name: base file name for saving the splits. "_train", "_val"
          and "_test" are inserted before the file extension
          (e.g. "data.pt" -> "data_train.pt"); a name without an extension
          simply gets the suffix appended.
        - test_size: share of all instances used for the test set (default: 0.2)
        - val_size: share of the remaining (non-test) instances used for the
          validation set (default: 0.1)
        - random_state: random seed used for both splits (default: 13)

    Returns:
        - train_loader: DataLoader over the training instances
        - val_loader: DataLoader over the validation instances
        - test_loader: DataLoader over the test instances
    """
    import os

    # Register every alias in the analyzer's dictionary.
    for alias in alias2id:
        twitter.add_dictionary(alias, 'Noun')

    # Load the raw lines of the instance file.
    with open(data_file, 'r', encoding='utf-8') as fin:
        data_lines = fin.readlines()

    # Preprocess.
    data_list = []

    for i, line in enumerate(tqdm(data_lines)):
        offset = i % 31

        if offset == 0:
            # Start of a new record.
            instance_index = line.strip().split()[-1]
            raw_sents_in_list = []
            continue

        if offset < 22:
            raw_sents_in_list.append(line.strip())

        elif offset == 22:
            speaker_name = line.strip().split()[-1]

            # Drop empty sentences.
            filtered_list = [li for li in raw_sents_in_list if li]

            # Split sentences and locate character mentions.
            seg_sents, candidate_mention_poses, name_list_index = seg_and_mention_location(
                filtered_list, alias2id)

            # Build the candidate-specific segments.
            css, sent_char_lens, mention_poses, quote_idxes, cut_css = create_CSS(
                seg_sents, candidate_mention_poses, args)

            # Candidate list.
            candidates_list = list(candidate_mention_poses.keys())

            # One-hot label over the candidates; the index falls back to 0
            # when the speaker is not among them.
            one_hot_label = [0 if character_idx != alias2id[speaker_name]
                             else 1 for character_idx in candidate_mention_poses.keys()]
            true_index = one_hot_label.index(1) if 1 in one_hot_label else 0

        elif offset == 24:
            category = line.strip().split()[-1]

        elif offset == 25:
            name = ' '.join(line.strip().split()[1:])

        elif offset == 26:
            scene = line.strip().split()[-1]

        elif offset == 27:
            place = line.strip().split()[-1]

        elif offset == 28:
            time = line.strip().split()[-1]

        elif offset == 29:
            cut_position = line.strip().split()[-1]
            data_list.append((seg_sents, css, sent_char_lens, mention_poses, quote_idxes,
                              cut_css, one_hot_label, true_index, category, name_list_index,
                              name, scene, place, time, cut_position, candidates_list,
                              instance_index))

    # Split into train / validation / test. The validation share is taken
    # from what is left after the test set has been removed.
    train_data, test_data = train_test_split(
        data_list, test_size=test_size, random_state=random_state)
    train_data, val_data = train_test_split(
        train_data, test_size=val_size, random_state=random_state)

    # Every batch is the single instance itself.
    train_loader = DataLoader(ISDataset(train_data), batch_size=1, collate_fn=lambda x: x[0])
    val_loader = DataLoader(ISDataset(val_data), batch_size=1, collate_fn=lambda x: x[0])
    test_loader = DataLoader(ISDataset(test_data), batch_size=1, collate_fn=lambda x: x[0])

    if save_name is not None:
        # Insert the split tag before the extension. Plain str.replace(".pt", ...)
        # wrote all three splits to the same file when the name had no ".pt"
        # and rewrote every ".pt" occurrence in the path.
        root, ext = os.path.splitext(save_name)
        torch.save(train_data, root + "_train" + ext)
        torch.save(val_data, root + "_val" + ext)
        torch.save(test_data, root + "_test" + ext)

    return train_loader, val_loader, test_loader