import itertools
import json
import linecache
import math
import os
import pickle
import socket
from collections.abc import Iterable
from logging import getLogger
from pathlib import Path
from typing import Callable, Union

import git
import numpy as np
import torch
import torch.distributed as dist
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from sentence_splitter import add_newline_to_end_of_each_sentence
from torch import nn
from torch.utils.data import Dataset, Sampler

from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
from transformers.models.bart.modeling_bart import shift_tokens_right
from transformers.utils import cached_property


try:
    from fairseq.data.data_utils import batch_by_size

    FAIRSEQ_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    FAIRSEQ_AVAILABLE = False


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """From fairseq"""
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)

    nll_loss = nll_loss.sum()
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss


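# Illustrative usage sketch (not part of the original module; shapes are an
# assumption): given log-probabilities of shape (batch, seq_len, vocab) and
# integer targets of shape (batch, seq_len), the call returns the smoothed
# loss and the unsmoothed NLL term, both summed over non-ignored positions.
#
#     lprobs = torch.log_softmax(torch.randn(2, 5, 100), dim=-1)
#     target = torch.randint(0, 100, (2, 5))
#     loss, nll_loss = label_smoothed_nll_loss(lprobs, target, epsilon=0.1)

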
def lmap(f: Callable, x: Iterable) -> list:
    """list(map(f, x))"""
    return list(map(f, x))


def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
    """Uses sacrebleu's corpus_bleu implementation."""
    return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}


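# Illustrative sketch: sacrebleu takes a flat list of hypotheses and a list of
# reference *streams*, which is why refs_lns is wrapped in a list above. For an
# identical hypothesis/reference pair of 4+ tokens the score is expected to be 100:
#
#     calculate_bleu(["the quick brown fox jumps"], ["the quick brown fox jumps"])
#     # -> {"bleu": 100.0}

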
def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], dict]:
    def non_pad_len(tokens: np.ndarray) -> int:
        return np.count_nonzero(tokens != tokenizer.pad_token_id)

    def decode_pred(pred: EvalPrediction) -> tuple[list[str], list[str]]:
        pred_ids = pred.predictions
        label_ids = pred.label_ids
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        # replace masked-out label positions (-100) with pad so batch_decode works
        label_ids[label_ids == -100] = tokenizer.pad_token_id
        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        pred_str = lmap(str.strip, pred_str)
        label_str = lmap(str.strip, label_str)
        return pred_str, label_str

    def summarization_metrics(pred: EvalPrediction) -> dict:
        pred_str, label_str = decode_pred(pred)
        rouge: dict = calculate_rouge(pred_str, label_str)
        summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
        rouge.update({"gen_len": summ_len})
        return rouge

    def translation_metrics(pred: EvalPrediction) -> dict:
        pred_str, label_str = decode_pred(pred)
        bleu: dict = calculate_bleu(pred_str, label_str)
        gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
        bleu.update({"gen_len": gen_len})
        return bleu

    compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
    return compute_metrics_fn


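# Illustrative sketch: the returned callable plugs into a Seq2SeqTrainer-style
# loop as compute_metrics. Routing is name-based: any task name containing
# "summarization" gets ROUGE, everything else gets BLEU; both add "gen_len".
#
#     metrics_fn = build_compute_metrics_fn("summarization", tokenizer)
#     # metrics_fn(EvalPrediction(predictions=pred_ids, label_ids=label_ids))
#     # -> {"rouge1": ..., "rouge2": ..., "rougeL": ..., "rougeLsum": ..., "gen_len": ...}

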
def trim_batch(
    input_ids,
    pad_token_id,
    attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id"""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])


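# Illustrative sketch with pad_token_id=0: columns that are all padding are
# dropped; ragged rows keep their own padding within the surviving columns.
#
#     ids = torch.tensor([[5, 6, 0, 0], [7, 0, 0, 0]])
#     trim_batch(ids, pad_token_id=0)  # -> tensor([[5, 6], [7, 0]])

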
class AbstractSeq2SeqDataset(Dataset):
    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        prefix="",
        **dataset_kwargs,
    ):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
        self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
        self.len_file = Path(data_dir).joinpath(type_path + ".len")
        if os.path.exists(self.len_file):
            self.src_lens = pickle_load(self.len_file)
            self.used_char_len = False
        else:
            self.src_lens = self.get_char_lens(self.src_file)
            self.used_char_len = True
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
        self.tokenizer = tokenizer
        self.prefix = prefix if prefix is not None else ""

        if n_obs is not None:
            self.src_lens = self.src_lens[:n_obs]
        self.pad_token_id = self.tokenizer.pad_token_id
        self.dataset_kwargs = dataset_kwargs
        dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {})

    def __len__(self):
        return len(self.src_lens)

    @staticmethod
    def get_char_lens(data_file):
        return [len(x) for x in Path(data_file).open().readlines()]

    @cached_property
    def tgt_lens(self):
        """Length in characters of target documents"""
        return self.get_char_lens(self.tgt_file)

    def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
        if distributed:
            return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
        else:
            return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)

    def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs):
        assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`"
        assert not self.used_char_len, "You must call python make_len_file.py before calling make_dynamic_sampler"
        sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False))

        def num_tokens_in_example(i):
            return min(self.src_lens[i], self.max_target_length)

        # group the sorted indices into batches that respect the token budget
        batch_sampler: list[list[int]] = batch_by_size(
            sorted_indices,
            num_tokens_fn=num_tokens_in_example,
            max_tokens=max_tokens_per_batch,
            required_batch_size_multiple=64,
        )
        shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))]
        # move the (approximately) largest batch to the front so an OOM surfaces immediately
        approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches]
        largest_batch_idx = np.argmax(approximate_toks_per_batch)
        shuffled_batches[0], shuffled_batches[largest_batch_idx] = (
            shuffled_batches[largest_batch_idx],
            shuffled_batches[0],
        )
        return shuffled_batches

    def __getitem__(self, item):
        raise NotImplementedError("You must implement this")

    def collate_fn(self, batch):
        raise NotImplementedError("You must implement this")


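# Illustrative sketch (assumes a data_dir laid out as train.source/train.target
# etc., as the constructor above expects; "data_dir" is a placeholder): a
# concrete subclass is instantiated and its sampler and collate_fn are handed
# to a DataLoader together.
#
#     ds = Seq2SeqDataset(tokenizer, "data_dir", max_source_length=512, max_target_length=128)
#     sampler = ds.make_sortish_sampler(batch_size=8)
#     # DataLoader(ds, batch_size=8, sampler=sampler, collate_fn=ds.collate_fn)

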
class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
    def __getitem__(self, index) -> dict[str, torch.Tensor]:
        """Call tokenizer on src and tgt_lines"""
        index = index + 1  # linecache counts lines from 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length)
        target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length)

        source_ids = source_inputs["input_ids"].squeeze()
        target_ids = target_inputs["input_ids"].squeeze()
        src_mask = source_inputs["attention_mask"].squeeze()
        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "labels": target_ids,
        }

    def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
        """Only used by LegacyDataset"""
        return tokenizer(
            [line],
            max_length=max_length,
            padding="max_length" if pad_to_max_length else None,
            truncation=True,
            return_tensors=return_tensors,
            **self.dataset_kwargs,
        )

    def collate_fn(self, batch) -> dict[str, torch.Tensor]:
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["labels"] for x in batch])
        pad_token_id = self.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        batch = {
            "input_ids": source_ids,
            "attention_mask": source_mask,
            "labels": y,
        }
        return batch


class Seq2SeqDataset(AbstractSeq2SeqDataset):
    """A dataset that calls prepare_seq2seq_batch."""

    def __getitem__(self, index) -> dict[str, str]:
        index = index + 1  # linecache counts lines from 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}

    def collate_fn(self, batch) -> dict[str, torch.Tensor]:
        """Call prepare_seq2seq_batch."""
        batch_encoding: dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            tgt_texts=[x["tgt_texts"] for x in batch],
            max_length=self.max_source_length,
            max_target_length=self.max_target_length,
            return_tensors="pt",
            **self.dataset_kwargs,
        ).data
        batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
        return batch_encoding


class Seq2SeqDataCollator:
    def __init__(self, tokenizer, data_args, decoder_start_token_id, tpu_num_cores=None):
        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id
        self.decoder_start_token_id = decoder_start_token_id
        assert self.pad_token_id is not None, (
            f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
        )
        self.data_args = data_args
        self.tpu_num_cores = tpu_num_cores
        self.dataset_kwargs = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
        if data_args.src_lang is not None:
            self.dataset_kwargs["src_lang"] = data_args.src_lang
        if data_args.tgt_lang is not None:
            self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang

    def __call__(self, batch) -> dict[str, torch.Tensor]:
        if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
            batch = self._encode(batch)
            input_ids, attention_mask, labels = (
                batch["input_ids"],
                batch["attention_mask"],
                batch["labels"],
            )
        else:
            input_ids = torch.stack([x["input_ids"] for x in batch])
            attention_mask = torch.stack([x["attention_mask"] for x in batch])
            labels = torch.stack([x["labels"] for x in batch])

        labels = trim_batch(labels, self.pad_token_id)
        input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)

        if isinstance(self.tokenizer, T5Tokenizer):
            decoder_input_ids = self._shift_right_t5(labels)
        else:
            decoder_input_ids = shift_tokens_right(labels, self.pad_token_id, self.decoder_start_token_id)

        batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "labels": labels,
        }
        return batch

    def _shift_right_t5(self, input_ids):
        # shift labels one position to the right, prepending pad (T5's decoder start token)
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = self.pad_token_id
        return shifted_input_ids

    def _encode(self, batch) -> dict[str, torch.Tensor]:
        batch_encoding = self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            tgt_texts=[x["tgt_texts"] for x in batch],
            max_length=self.data_args.max_source_length,
            max_target_length=self.data_args.max_target_length,
            padding="max_length" if self.tpu_num_cores is not None else "longest",
            return_tensors="pt",
            **self.dataset_kwargs,
        )
        return batch_encoding.data


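# Illustrative sketch (assumes data_args carries max_source_length,
# max_target_length, src_lang and tgt_lang, as used above):
#
#     collator = Seq2SeqDataCollator(tokenizer, data_args, model.config.decoder_start_token_id)
#     # DataLoader(train_dataset, batch_size=8, collate_fn=collator)

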
class SortishSampler(Sampler):
    "Go through the text data by order of src length with a bit of randomness. From fastai repo."

    def __init__(self, data, batch_size, shuffle=True):
        self.data, self.bs, self.shuffle = data, batch_size, shuffle

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))


def sortish_sampler_indices(data: list, bs: int, shuffle=True) -> np.ndarray:
    "Go through the text data by order of src length with a bit of randomness. From fastai repo."
    if not shuffle:
        return np.argsort(np.array(data) * -1)

    def key_fn(i):
        return data[i]

    idxs = np.random.permutation(len(data))
    sz = bs * 50
    ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
    # sort by length (descending) within each megachunk of 50 batches
    sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
    sz = bs
    ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
    # put the chunk containing the largest example first, shuffle the rest
    max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx])
    ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]
    sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=int)
    sort_idx = np.concatenate((ck_idx[0], sort_idx))
    return sort_idx


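# Illustrative sketch: with shuffle=False this reduces to a plain descending
# argsort of the lengths; with shuffle=True the order is only "sortish".
#
#     lens = [3, 10, 1, 7, 2, 9]
#     sortish_sampler_indices(lens, bs=2, shuffle=False)
#     # -> array([1, 5, 3, 0, 4, 2])

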
class DistributedSortishSampler(Sampler):
    """Copied from torch DistributedSampler"""

    def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        if add_extra_examples:
            self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
            self.total_size = self.num_samples * self.num_replicas
        else:
            self.total_size = len(dataset)
            self.num_samples = len(self.available_indices)
        self.batch_size = batch_size
        self.add_extra_examples = add_extra_examples
        self.shuffle = shuffle

    def __iter__(self) -> Iterable:
        g = torch.Generator()
        g.manual_seed(self.epoch)

        sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
        sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle)
        indices = [self.available_indices[i] for i in sortish_indices]
        assert len(indices) == self.num_samples
        return iter(indices)

    @cached_property
    def available_indices(self) -> list:
        indices = list(range(len(self.dataset)))
        # add extra samples to make the total evenly divisible across replicas
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size
        # subsample: each rank takes a strided slice of the padded index list
        available_indices = indices[self.rank : self.total_size : self.num_replicas]
        return available_indices

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch


logger = getLogger(__name__)


def use_task_specific_params(model, task):
    """Update config with task specific params (e.g. for summarization)."""
    task_specific_params = model.config.task_specific_params

    if task_specific_params is not None:
        pars = task_specific_params.get(task, {})
        logger.info(f"setting model.config to task specific params for {task}:\n {pars}")
        logger.info("note: command line args may override some of these")
        model.config.update(pars)


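# Illustrative sketch: for a checkpoint whose config ships task_specific_params
# (e.g. a CNN/DM summarization model), this copies values such as num_beams and
# max_length into the top-level config before generation.
#
#     use_task_specific_params(model, "summarization")

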
def pickle_load(path):
    """pickle.load(path)"""
    with open(path, "rb") as f:
        return pickle.load(f)


def pickle_save(obj, path):
    """pickle.dump(obj, path)"""
    with open(path, "wb") as f:
        return pickle.dump(obj, f)


def flatten_list(summary_ids: list[list]):
    return list(itertools.chain.from_iterable(summary_ids))


def save_git_info(folder_path: str) -> None:
    """Save git information to output_dir/git_log.json"""
    repo_infos = get_git_info()
    save_json(repo_infos, os.path.join(folder_path, "git_log.json"))


def save_json(content, path, indent=4, **json_dump_kwargs):
    with open(path, "w") as f:
        json.dump(content, f, indent=indent, sort_keys=True, **json_dump_kwargs)


def load_json(path):
    with open(path) as f:
        return json.load(f)


def get_git_info():
    try:
        repo = git.Repo(search_parent_directories=True)
        repo_infos = {
            "repo_id": str(repo),
            "repo_sha": str(repo.head.object.hexsha),
            "repo_branch": str(repo.active_branch),
            "hostname": str(socket.gethostname()),
        }
        return repo_infos
    except TypeError:
        # repo.active_branch raises TypeError on a detached HEAD
        return {
            "repo_id": None,
            "repo_sha": None,
            "repo_branch": None,
            "hostname": None,
        }


ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"]


def extract_rouge_mid_statistics(dct):
    new_dict = {}
    for k1, v1 in dct.items():
        mid = v1.mid
        new_dict[k1] = {stat: round(getattr(mid, stat), 4) for stat in ["precision", "recall", "fmeasure"]}
    return new_dict


def calculate_rouge(
    pred_lns: list[str],
    tgt_lns: list[str],
    use_stemmer=True,
    rouge_keys=ROUGE_KEYS,
    return_precision_and_recall=False,
    bootstrap_aggregation=True,
    newline_sep=True,
) -> dict:
    """Calculate rouge using rouge_scorer package.

    Args:
        pred_lns: list of summaries generated by model
        tgt_lns: list of groundtruth summaries (e.g. contents of val.target)
        use_stemmer: Bool indicating whether Porter stemmer should be used to
            strip word suffixes to improve matching.
        rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum
        return_precision_and_recall: (False) whether to also return precision and recall.
        bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True, if False
            this function returns a collections.defaultdict[metric: list of values for each observation for each subscore]
        newline_sep: (default=True) whether to add a newline between sentences. This is essential for calculating
            rougeL on multi-sentence summaries (CNN/DM dataset).

    Returns:
        Dict[score: value] if aggregate else defaultdict(list) keyed by rouge_keys

    """
    scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()
    for tgt, pred in zip(tgt_lns, pred_lns):
        # rougeLsum expects "\n"-separated sentences within a summary
        if newline_sep:
            pred = add_newline_to_end_of_each_sentence(pred)
            tgt = add_newline_to_end_of_each_sentence(tgt)
        # rouge_scorer's score() expects (target, prediction) in that order
        scores = scorer.score(tgt, pred)
        aggregator.add_scores(scores)

    if bootstrap_aggregation:
        result = aggregator.aggregate()
        if return_precision_and_recall:
            return extract_rouge_mid_statistics(result)
        else:
            return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}

    else:
        return aggregator._scores


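# Illustrative sketch: for an identical prediction/reference pair every
# F-measure is 1.0, reported on a 0-100 scale.
#
#     calculate_rouge(["the cat sat on the mat"], ["the cat sat on the mat"])
#     # -> {"rouge1": 100.0, "rouge2": 100.0, "rougeL": 100.0, "rougeLsum": 100.0}

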
def freeze_params(model: nn.Module):
    """Set requires_grad=False for each of model.parameters()"""
    for par in model.parameters():
        par.requires_grad = False


def freeze_embeds(model):
    """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
    model_type = model.config.model_type

    if model_type in ["t5", "mt5"]:
        freeze_params(model.shared)
        for d in [model.encoder, model.decoder]:
            freeze_params(d.embed_tokens)
    elif model_type == "fsmt":
        for d in [model.model.encoder, model.model.decoder]:
            freeze_params(d.embed_positions)
            freeze_params(d.embed_tokens)
    else:
        freeze_params(model.model.shared)
        for d in [model.model.encoder, model.model.decoder]:
            freeze_params(d.embed_positions)
            freeze_params(d.embed_tokens)


def grad_status(model: nn.Module) -> Iterable:
    return (par.requires_grad for par in model.parameters())


def any_requires_grad(model: nn.Module) -> bool:
    return any(grad_status(model))


def assert_all_frozen(model):
    model_grads: list[bool] = list(grad_status(model))
    n_require_grad = sum(lmap(int, model_grads))
    npars = len(model_grads)
    assert not any(model_grads), f"{n_require_grad / npars:.1%} of {npars} weights require grad"


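# Illustrative sketch (assumes a seq2seq model exposing get_encoder()): freeze
# a submodule, then verify nothing slipped through.
#
#     freeze_params(model.get_encoder())
#     assert_all_frozen(model.get_encoder())

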
def assert_not_all_frozen(model):
    model_grads: list[bool] = list(grad_status(model))
    npars = len(model_grads)
    assert any(model_grads), f"none of {npars} weights require grad"


def parse_numeric_n_bool_cl_kwargs(unparsed_args: list[str]) -> dict[str, Union[int, float, bool]]:
    """
    Parse an argv list of unspecified command line args to a dict.
    Assumes all values are either numeric or boolean in the form of true/false.
    """
    result = {}
    assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
    num_pairs = len(unparsed_args) // 2
    for pair_num in range(num_pairs):
        i = 2 * pair_num
        assert unparsed_args[i].startswith("--")
        if unparsed_args[i + 1].lower() == "true":
            value = True
        elif unparsed_args[i + 1].lower() == "false":
            value = False
        else:
            try:
                value = int(unparsed_args[i + 1])
            except ValueError:
                # fall through to float; a non-numeric value raises ValueError here
                value = float(unparsed_args[i + 1])

        result[unparsed_args[i][2:]] = value
    return result


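# Illustrative sketch:
#
#     parse_numeric_n_bool_cl_kwargs(["--num_beams", "8", "--early_stopping", "true"])
#     # -> {"num_beams": 8, "early_stopping": True}

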
def write_txt_file(ordered_tgt, path):
    # use a context manager so the handle is flushed and closed on exit
    with Path(path).open("w") as f:
        for ln in ordered_tgt:
            f.write(ln + "\n")


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]


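# Illustrative sketch: the final chunk may be shorter than n.
#
#     list(chunks([1, 2, 3, 4, 5], 2))  # -> [[1, 2], [3, 4], [5]]

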
def check_output_dir(args, expected_items=0):
    """
    Checks whether to bail out if output_dir already exists and has more than expected_items in it

    `args`: needs to have the following attributes:
      - output_dir
      - do_train
      - overwrite_output_dir

    `expected_items`: normally 0 (default) - i.e. empty dir, but in some cases a few files are expected (e.g. recovery from OOM)
    """
    if (
        os.path.exists(args.output_dir)
        and len(os.listdir(args.output_dir)) > expected_items
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({args.output_dir}) already exists and "
            f"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). "
            "Use --overwrite_output_dir to overcome."
        )