| |
| import os |
| import warnings |
|
|
| import torch |
| from mmengine.dist import master_only |
| from mmengine.hooks import Hook |
| from mmengine.model import is_model_wrapper |
| from mmengine.utils import mkdir_or_exist |
| from mmengine.utils.misc import get_object_from_string |
| from transformers import GenerationConfig, StoppingCriteriaList |
|
|
| from xtuner.dataset.utils import expand2square, load_image |
| from xtuner.model.utils import prepare_inputs_labels_for_multimodal |
| from xtuner.registry import BUILDER |
| from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, |
| StopWordStoppingCriteria) |
|
|
|
|
class EvaluateChatHook_ST(Hook):
    """Generate from fixed evaluation prompts during training and log it.

    Generation is triggered:

    * once before training (with ``max_new_tokens=50``);
    * after a training iteration, every ``every_n_iters`` iterations (as
      decided by ``Hook.every_n_train_iters``). It also runs whenever the
      runner's iteration-based ``CheckpointHook`` is about to save. In that
      case the outputs are also written to
      ``<log_dir>/vis_data/eval_outputs_iter_<iter>.txt``;
    * once after training;
    * after validation, but only when ``every_n_iters`` is ``None``.

    If ``evaluation_images`` is given, generation goes through the model's
    ``predict_forward`` (image + text). Otherwise the prompts are tokenized
    and passed to ``model.generate``.

    Args:
        tokenizer: Tokenizer config, built with ``BUILDER.build``.
        evaluation_inputs (str | list[str]): Prompt(s) to generate from.
        evaluation_images (str | list[str] | None): Image(s) paired with
            the prompts. Give either a single image, which is reused for
            every prompt, or exactly one image per prompt. Images are
            loaded eagerly with ``load_image``.
        system (str): Text prepended to every prompt. Only used in the
            text-only path.
        every_n_iters (int | None): Iteration interval for generation
            during training. ``None`` disables per-iteration generation and
            enables generation after validation instead.
        max_new_tokens (int): Default generation length.
        stop_word (str | None): Deprecated. It is added to ``stop_words``.
        stop_words (list[str] | None): Words that stop generation. Only
            used in the text-only path. The caller's list is not modified.
        generation_kwargs (dict | None): Overrides merged on top of the
            default ``GenerationConfig`` arguments.
    """

    priority = 'LOW'

    def __init__(self,
                 tokenizer,
                 evaluation_inputs,
                 evaluation_images=None,
                 system='',
                 every_n_iters=None,
                 max_new_tokens=600,
                 stop_word=None,
                 stop_words=None,
                 generation_kwargs=None):
        # ``None`` defaults replace the former ``[]`` / ``{}`` defaults.
        # The shared default list used to be mutated by the deprecated
        # ``stop_word`` branch below, which leaked stop words into every
        # later instance. Copying also keeps the caller's list untouched.
        stop_words = list(stop_words) if stop_words is not None else []

        self.evaluation_inputs = evaluation_inputs
        if isinstance(self.evaluation_inputs, str):
            self.evaluation_inputs = [self.evaluation_inputs]

        self.evaluation_images = evaluation_images
        if isinstance(self.evaluation_images, str):
            self.evaluation_images = [self.evaluation_images]
        if self.evaluation_images is not None:
            assert len(self.evaluation_images) in [
                1, len(self.evaluation_inputs)
            ], ('`evaluation_images` must contain either one image or one '
                'image per entry of `evaluation_inputs`.')
            # Broadcast a single image so it pairs with every prompt.
            if len(self.evaluation_images) == 1:
                self.evaluation_images = [self.evaluation_images[0]] * len(
                    self.evaluation_inputs)
            self.evaluation_images = [
                load_image(img) for img in self.evaluation_images
            ]

        if stop_word is not None:
            warnings.warn(
                ('The `stop_word` argument is deprecated and will be removed '
                 'in v0.3.0, use `stop_words` instead.'), DeprecationWarning)
            stop_words.append(stop_word)

        # The text-only path formats ``system + instruction`` per prompt.
        self.instruction = '{input}'
        self.system = system
        self.every_n_iters = every_n_iters
        self.max_new_tokens = max_new_tokens
        self.tokenizer = BUILDER.build(tokenizer)

        default_generation_kwargs = dict(
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.1,
            top_p=0.75,
            top_k=40,
            eos_token_id=self.tokenizer.eos_token_id,
            # Fall back to EOS when the tokenizer defines no pad token.
            pad_token_id=self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None else
            self.tokenizer.eos_token_id)
        default_generation_kwargs.update(generation_kwargs or {})
        self.gen_config = GenerationConfig(**default_generation_kwargs)

        self.stop_criteria = StoppingCriteriaList()
        for word in stop_words:
            self.stop_criteria.append(
                StopWordStoppingCriteria(self.tokenizer, word))

        # Consumed by ``_generate_samples`` to do a one-time ``model.to``.
        self.is_first_run = True

    @master_only
    def _save_eval_output(self, runner, eval_outputs):
        """Write ``eval_outputs`` to a per-iteration text file.

        The file is ``<log_dir>/vis_data/eval_outputs_iter_<iter>.txt``.
        ``@master_only`` restricts the write to the master process.
        """
        save_path = os.path.join(runner.log_dir, 'vis_data',
                                 f'eval_outputs_iter_{runner.iter}.txt')
        mkdir_or_exist(os.path.dirname(save_path))
        with open(save_path, 'w', encoding='utf-8') as f:
            for i, output in enumerate(eval_outputs):
                f.write(f'Eval output {i + 1}:\n{output}\n\n')

    def _eval_images(self,
                     runner,
                     model,
                     device,
                     max_new_tokens=None,
                     save_eval_output=False):
        """Generate for each (image, prompt) pair via ``predict_forward``.

        Each prompt is prefixed with ``DEFAULT_IMAGE_TOKEN`` and a newline.
        The model's echoed input text plus its prediction is logged and,
        if ``save_eval_output`` is set, saved to disk.

        ``device`` and ``max_new_tokens`` are accepted for signature parity
        with ``_eval_language`` but are not used here.
        NOTE(review): ``before_train`` passes ``max_new_tokens=50``, but that
        limit has no effect on this path. The length is presumably
        configured inside ``predict_forward``; confirm this.
        """
        eval_outputs = []
        model.preparing_for_generation(metainfo={})
        for sample_image, sample_input in zip(self.evaluation_images,
                                              self.evaluation_inputs):
            prompt = DEFAULT_IMAGE_TOKEN + '\n' + sample_input
            result = model.predict_forward(image=sample_image, text=prompt)
            inputs = result['input_text']
            generation_output = result['prediction']
            runner.logger.info(f'Sample output:\n'
                               f'{inputs + generation_output}\n')
            if save_eval_output:
                eval_outputs.append(f'{inputs + generation_output}\n')

        if save_eval_output:
            self._save_eval_output(runner, eval_outputs)

    def _eval_language(self,
                       runner,
                       model,
                       device,
                       max_new_tokens=None,
                       save_eval_output=False):
        """Generate for each text prompt with ``model.generate``.

        Each prompt is rendered from ``system + instruction``. The template
        receives ``input``, ``round=1`` and every key of ``runner.cfg``.
        The rendered prompt is tokenized, moved to ``device`` and generated
        with ``self.gen_config`` and ``self.stop_criteria``.
        ``max_new_tokens`` is passed explicitly to ``generate``. The decoded
        sequence is logged and, if ``save_eval_output`` is set, saved.
        """
        eval_outputs = []
        prompt_template = self.system + self.instruction
        for sample_input in self.evaluation_inputs:
            inputs = prompt_template.format(
                input=sample_input, round=1, **runner.cfg)
            input_ids = self.tokenizer.encode(inputs, return_tensors='pt')
            input_ids = input_ids.to(device)
            generation_output = model.generate(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                generation_config=self.gen_config,
                stopping_criteria=self.stop_criteria)
            generation_output = self.tokenizer.decode(generation_output[0])
            runner.logger.info(f'Sample output:\n{generation_output}\n')
            if save_eval_output:
                eval_outputs.append(f'{generation_output}\n')

        if save_eval_output:
            self._save_eval_output(runner, eval_outputs)

    def _generate_samples(self,
                          runner,
                          max_new_tokens=None,
                          save_eval_output=False):
        """Switch the model to inference mode, generate, then restore it.

        Activation checkpointing is disabled, ``llm.config.use_cache`` is
        set to ``True`` and the model is put in eval mode. After generation,
        the previous checkpointing and ``use_cache`` settings are restored
        and the model returns to train mode. The restore runs in
        ``finally``, so a failing generation does not leave the model in
        inference mode.
        """
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens
        model = runner.model
        if is_model_wrapper(model):
            model = model.module

        device = next(iter(model.parameters())).device

        if self.is_first_run:
            # One-time move of the whole model, including any buffers or
            # non-parameter state, onto the parameter device.
            # NOTE(review): presumably needed for sharded / quantized
            # setups where some state is not already on ``device``. Confirm.
            model.to(device)
            self.is_first_run = False

        # Remember the training-time settings so they can be restored.
        is_checkpointing = model.llm.is_gradient_checkpointing
        use_cache = model.llm.config.use_cache

        model.activation_checkpointing_disable()
        model.llm.config.use_cache = True
        model.eval()
        try:
            if self.evaluation_images is not None:
                self._eval_images(runner, model, device, max_new_tokens,
                                  save_eval_output)
            else:
                self._eval_language(runner, model, device, max_new_tokens,
                                    save_eval_output)
        finally:
            if is_checkpointing:
                model.activation_checkpointing_enable()
            model.llm.config.use_cache = use_cache
            model.train()

    def before_train(self, runner):
        """Run a short (50-token) generation before training starts."""
        runner.logger.info('before_train in EvaluateChatHook.')
        self._generate_samples(runner, max_new_tokens=50)

    def _is_save_checkpoint(self, runner):
        """Return whether an iteration-based ``CheckpointHook`` saves now.

        Returns ``False`` when there is no hook whose class is named
        ``CheckpointHook``, or when that hook saves by epoch. Otherwise the
        result mirrors the hook's own interval / ``save_last`` checks.
        """
        checkpoint_hook = None
        for hook in runner.hooks:
            if type(hook).__name__ == 'CheckpointHook':
                checkpoint_hook = hook
                break
        if checkpoint_hook is None or checkpoint_hook.by_epoch:
            return False

        if checkpoint_hook.every_n_train_iters(
                runner, checkpoint_hook.interval, checkpoint_hook.save_begin) or \
                (checkpoint_hook.save_last and
                 checkpoint_hook.is_last_train_iter(runner)):
            return True

        return False

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch=None,
                         outputs=None) -> None:
        """Generate on the ``every_n_iters`` schedule or at checkpoint saves.

        Outputs are written to disk only when a checkpoint is being saved.
        Does nothing when ``every_n_iters`` is ``None``.
        """
        if self.every_n_iters is None:
            return

        save_eval_output = self._is_save_checkpoint(runner)

        do_chat = (
            save_eval_output
            or self.every_n_train_iters(runner, self.every_n_iters))
        if not do_chat:
            return

        runner.logger.info('after_train_iter in EvaluateChatHook.')
        self._generate_samples(runner, save_eval_output=save_eval_output)

    def after_train(self, runner):
        """Run one final generation after training ends."""
        runner.logger.info('after_train in EvaluateChatHook.')
        self._generate_samples(runner)

    def after_val(self, runner) -> None:
        """Generate after validation, only when there is no iteration schedule."""
        if self.every_n_iters is not None:
            return
        runner.logger.info('after_val in EvaluateChatHook.')
        self._generate_samples(runner)