| | |
| | import argparse |
| | import os |
| | import os.path as osp |
| | import re |
| | import sys |
| |
|
| | import torch |
| | from huggingface_hub import snapshot_download |
| | from peft import PeftModel |
| | from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, |
| | BitsAndBytesConfig, SiglipImageProcessor, |
| | SiglipVisionModel, Dinov2Model, |
| | GenerationConfig) |
| | from transformers.generation.streamers import TextStreamer |
| |
|
| | from xtuner.dataset.utils import expand2square, load_image |
| | from xtuner.model.utils import prepare_inputs_labels_for_multimodal |
| | from xtuner.tools.utils import get_stop_criteria |
| | from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, |
| | PROMPT_TEMPLATE, SYSTEM_TEMPLATE) |
| |
|
# Mapping from the CLI `--torch-dtype` aliases to the dtypes passed to
# `from_pretrained` ('auto' is forwarded verbatim to transformers).
TORCH_DTYPE_MAP = {
    'fp16': torch.float16,
    'bf16': torch.bfloat16,
    'fp32': torch.float32,
    'auto': 'auto',
}
| |
|
| |
|
| | def remove_prefix(state_dict, prefix): |
| | new_state_dict = {} |
| | for key, value in state_dict.items(): |
| | if key.startswith(prefix): |
| | new_key = key[len(prefix):] |
| | new_state_dict[new_key] = value |
| | else: |
| | new_state_dict[key] = value |
| | return new_state_dict |
| |
|
| |
|
def parse_args():
    """Parse the chat CLI arguments.

    Returns:
        argparse.Namespace: parsed arguments. ``--adapter``/``--llava`` are
        mutually exclusive, as are ``--system``/``--system-template``.
    """
    parser = argparse.ArgumentParser(description='Chat with a HF model')
    parser.add_argument(
        'model_name_or_path', help='Hugging Face model name or path')
    adapter_group = parser.add_mutually_exclusive_group()
    adapter_group.add_argument(
        '--adapter', default=None, help='adapter name or path')
    adapter_group.add_argument(
        '--llava', default=None, help='llava name or path')
    parser.add_argument(
        '--siglip', default=None, help='siglip visual encoder name or path')
    parser.add_argument(
        '--visual-select-layer',
        # type=int is required: the value is used as an index into
        # `hidden_states` later; without it a CLI-supplied value stays a
        # string and raises TypeError at generation time.
        type=int,
        default=-2,
        help='visual select layer')
    parser.add_argument(
        '--dino', default=None, help='dino visual encoder name or path')
    parser.add_argument('--image', default=None, help='image')
    parser.add_argument(
        '--torch-dtype',
        default='fp16',
        choices=TORCH_DTYPE_MAP.keys(),
        help='Override the default `torch.dtype` and load the model under '
        'a specific `dtype`.')
    parser.add_argument(
        '--prompt-template',
        choices=PROMPT_TEMPLATE.keys(),
        default=None,
        help='Specify a prompt template')
    system_group = parser.add_mutually_exclusive_group()
    system_group.add_argument(
        '--system', default=None, help='Specify the system text')
    system_group.add_argument(
        '--system-template',
        choices=SYSTEM_TEMPLATE.keys(),
        default=None,
        help='Specify a system template')
    parser.add_argument(
        '--bits',
        type=int,
        choices=[4, 8, None],
        default=None,
        help='LLM bits')
    parser.add_argument(
        '--bot-name', type=str, default='BOT', help='Name for Bot')
    parser.add_argument(
        '--with-plugins',
        nargs='+',
        choices=['calculate', 'solve', 'search'],
        help='Specify plugins to use')
    parser.add_argument(
        '--no-streamer', action='store_true', help='Whether to with streamer')
    parser.add_argument(
        '--lagent', action='store_true', help='Whether to use lagent')
    parser.add_argument(
        '--stop-words', nargs='+', type=str, default=[], help='Stop words')
    parser.add_argument(
        '--offload-folder',
        default=None,
        help='The folder in which to offload the model weights (or where the '
        'model weights are already offloaded).')
    parser.add_argument(
        '--max-new-tokens',
        type=int,
        default=2048,
        help='Maximum number of new tokens allowed in generated text')
    parser.add_argument(
        '--temperature',
        type=float,
        default=0.1,
        help='The value used to modulate the next token probabilities.')
    parser.add_argument(
        '--top-k',
        type=int,
        default=40,
        help='The number of highest probability vocabulary tokens to '
        'keep for top-k-filtering.')
    parser.add_argument(
        '--top-p',
        type=float,
        default=0.75,
        help='If set to float < 1, only the smallest set of most probable '
        'tokens with probabilities that add up to top_p or higher are '
        'kept for generation.')
    parser.add_argument(
        '--repetition-penalty',
        type=float,
        default=1.0,
        help='The parameter for repetition penalty. 1.0 means no penalty.')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='Random seed for reproducible text generation')
    args = parser.parse_args()
    return args
| |
|
| |
|
def get_input():
    """Helper function for getting input from users."""
    # Keep prompting until a full (possibly multi-line) message is read;
    # an empty line terminates the multi-line input.
    while True:
        print(('\ndouble enter to end input (EXIT: exit chat, '
               'RESET: reset history) >>> '),
              end='')
        try:
            return '\n'.join(iter(input, ''))
        except UnicodeDecodeError:
            print('Invalid characters detected. Please enter again.')
| |
|
| |
|
def main():
    """Interactive chat entry point.

    Builds the (optionally quantized) LLM described by the CLI args, plus a
    SigLIP + DINOv2 visual stack and projector when ``--llava`` is given,
    then runs a console REPL (``RESET`` clears history, ``EXIT`` quits).
    """
    args = parse_args()
    torch.manual_seed(args.seed)

    # --- quantization / model-loading kwargs -----------------------------
    quantization_config = None
    load_in_8bit = False
    if args.bits == 4:
        # 4-bit NF4 double quantization with fp16 compute.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')
    elif args.bits == 8:
        load_in_8bit = True
    model_kwargs = {
        'quantization_config': quantization_config,
        'load_in_8bit': load_in_8bit,
        'device_map': 'auto',
        'offload_folder': args.offload_folder,
        'trust_remote_code': True,
        'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
    }
    if args.lagent:
        # lagent mode: self-contained ReAct agent REPL with Google search;
        # none of the multimodal machinery below is used.
        from lagent.actions import ActionExecutor, GoogleSearch
        from lagent.agents import (CALL_PROTOCOL_CN, FORCE_STOP_PROMPT_CN,
                                   ReAct, ReActProtocol)
        from lagent.llms import HFTransformerCasualLM

        try:
            SERPER_API_KEY = os.environ['SERPER_API_KEY']
        except Exception:
            print('Please obtain the `SERPER_API_KEY` from https://serper.dev '
                  'and set it using `export SERPER_API_KEY=xxx`.')
            sys.exit(1)

        # HFTransformerCasualLM's kwargs do not include trust_remote_code.
        model_kwargs.pop('trust_remote_code')
        llm = HFTransformerCasualLM(
            args.model_name_or_path, model_kwargs=model_kwargs)
        if args.adapter is not None:
            print(f'Loading adapter from {args.adapter}...')
            llm.model = PeftModel.from_pretrained(
                llm.model,
                args.adapter,
                offload_folder=args.offload_folder,
                trust_remote_code=True)
        search_tool = GoogleSearch(api_key=SERPER_API_KEY)
        chatbot = ReAct(
            llm=llm,
            action_executor=ActionExecutor(actions=[search_tool]),
            protocol=ReActProtocol(
                call_protocol=CALL_PROTOCOL_CN,
                force_stop=FORCE_STOP_PROMPT_CN))
        while True:
            text = get_input()
            while text.strip() == 'RESET':
                print('Log: History responses have been removed!')
                chatbot._session_history = []
                # NOTE(review): `inputs` is never read in lagent mode; this
                # assignment appears to be dead code.
                inputs = ''
                text = get_input()
            if text.strip() == 'EXIT':
                print('Log: Exit!')
                exit(0)
            response = chatbot.chat(text)
            print(response.response)
    else:
        # --- plugin switches (only wired up for the moss_sft template) ---
        if args.with_plugins is None:
            inner_thoughts_open = False
            calculate_open = False
            solve_open = False
            search_open = False
        else:
            assert args.prompt_template == args.system_template == 'moss_sft'
            from plugins import plugins_api
            inner_thoughts_open = True
            calculate_open = 'calculate' in args.with_plugins
            solve_open = 'solve' in args.with_plugins
            search_open = 'search' in args.with_plugins
            # Import eagerly so a missing plugin module fails before the REPL.
            if calculate_open:
                from plugins import calculate  # noqa: F401
            if solve_open:
                from plugins import solve  # noqa: F401
            if search_open:
                from plugins import search  # noqa: F401

        # --- LLM + tokenizer ---------------------------------------------
        llm = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
                                                   **model_kwargs)
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path,
            trust_remote_code=True,
            encode_special_tokens=True)
        print(f'Load LLM from {args.model_name_or_path}')
        if args.adapter is not None:
            llm = PeftModel.from_pretrained(
                llm,
                args.adapter,
                offload_folder=args.offload_folder,
                trust_remote_code=True)
            print(f'Load adapter from {args.adapter}')
        if args.llava is not None:
            # Resolve a local directory or download the repo from the Hub.
            llava_path = snapshot_download(
                repo_id=args.llava) if not osp.isdir(
                    args.llava) else args.llava

            # --- visual encoders (SigLIP + DINOv2) -----------------------
            if 'visual_encoder' in os.listdir(llava_path):
                # NOTE(review): parse_args() defines no `--visual-encoder`
                # option, so `args.visual_encoder` raises AttributeError
                # here; this branch also leaves `siglip`, `dino` and
                # `image_processor` undefined for the code below. Looks like
                # leftover from the original single-encoder script — confirm
                # whether this branch is still meant to be supported.
                assert args.visual_encoder is None, (
                    "Please don't specify the `--visual-encoder` since passed "
                    '`--llava` contains a visual encoder!')
                visual_encoder_path = osp.join(llava_path, 'visual_encoder')
            else:
                assert args.siglip is not None, (
                    'Please specify the `--siglip`!')
                assert args.dino is not None, (
                    'Please specify the `--dino`!')
                siglip = SiglipVisionModel.from_pretrained(
                    args.siglip,
                    torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
                image_processor = SiglipImageProcessor.from_pretrained(
                    args.siglip)
                print(f'Load siglip from {args.siglip}')
                dino = Dinov2Model.from_pretrained(
                    args.dino,
                    torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
                print(f'Load dino from {args.dino}')

            # --- optional LoRA adapters shipped inside the llava repo ----
            if 'llm_adapter' in os.listdir(llava_path):
                adapter_path = osp.join(llava_path, 'llm_adapter')
                llm = PeftModel.from_pretrained(
                    llm,
                    adapter_path,
                    offload_folder=args.offload_folder,
                    trust_remote_code=True)
                print(f'Load LLM adapter from {args.llava}')
            if 'visual_encoder_adapter' in os.listdir(llava_path):
                adapter_path = osp.join(llava_path, 'visual_encoder_adapter')
                # NOTE(review): `visual_encoder` is never assigned anywhere
                # in this function, so this raises NameError whenever the
                # adapter directory exists — confirm intended target
                # (siglip/dino?).
                visual_encoder = PeftModel.from_pretrained(
                    visual_encoder,
                    adapter_path,
                    offload_folder=args.offload_folder)
                print(f'Load visual_encoder adapter from {args.llava}')

            # --- projector (maps visual features into LLM embeddings) ----
            projector_path = osp.join(llava_path, 'projector')
            projector = AutoModel.from_pretrained(
                projector_path,
                torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype],
                trust_remote_code=True)
            print(f'Load projector from {args.llava}')

            projector.cuda()
            projector.eval()
            siglip.cuda()
            siglip.eval()
            dino.cuda()
            dino.eval()

        llm.eval()

        # --- one-off image preprocessing and visual feature extraction ---
        if args.image is not None:
            image = load_image(args.image)
            # Pad to a square using the processor's mean color, then run the
            # standard SigLIP preprocessing; batch dim added for the model.
            image = expand2square(
                image, tuple(int(x * 255) for x in image_processor.image_mean))
            image = image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            image = image.cuda().unsqueeze(0)

            # SigLIP features from the selected hidden layer are concatenated
            # with DINOv2 last-layer features (first token dropped) along the
            # channel dim, then projected into the LLM embedding space.
            siglip_out = siglip(
                image, output_hidden_states=True).hidden_states[args.visual_select_layer]
            dino_out = dino(
                image, output_hidden_states=True).hidden_states[-1][:, 1:]
            visual_out = torch.cat((siglip_out, dino_out), dim=-1)
            pixel_values = projector(visual_out)

        # --- stop words, streamer, generation config ---------------------
        stop_words = args.stop_words
        sep = ''
        if args.prompt_template:
            template = PROMPT_TEMPLATE[args.prompt_template]
            stop_words += template.get('STOP_WORDS', [])
            sep = template.get('SEP', '')
        stop_criteria = get_stop_criteria(
            tokenizer=tokenizer, stop_words=stop_words)

        if args.no_streamer:
            streamer = None
        else:
            streamer = TextStreamer(tokenizer, skip_prompt=True)

        gen_config = GenerationConfig(
            max_new_tokens=args.max_new_tokens,
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
        )

        # --- chat REPL ----------------------------------------------------
        n_turn = 0
        inputs = ''  # accumulated decoded conversation (prompt + responses)
        while True:
            text = get_input()
            while text.strip() == 'RESET':
                print('Log: History responses have been removed!')
                n_turn = 0
                inputs = ''
                text = get_input()
            if text.strip() == 'EXIT':
                print('Log: Exit!')
                exit(0)

            # The image placeholder token is injected on the first turn only.
            if args.image is not None and n_turn == 0:
                text = DEFAULT_IMAGE_TOKEN + '\n' + text

            if args.prompt_template:
                prompt_text = ''
                template = PROMPT_TEMPLATE[args.prompt_template]
                if 'SYSTEM' in template and n_turn == 0:
                    system_text = None
                    if args.system_template is not None:
                        system_text = SYSTEM_TEMPLATE[
                            args.system_template].format(
                                round=n_turn + 1, bot_name=args.bot_name)
                    elif args.system is not None:
                        system_text = args.system
                    if system_text is not None:
                        prompt_text += template['SYSTEM'].format(
                            system=system_text,
                            round=n_turn + 1,
                            bot_name=args.bot_name)
                prompt_text += template['INSTRUCTION'].format(
                    input=text, round=n_turn + 1, bot_name=args.bot_name)
                if args.prompt_template == args.system_template == 'moss_sft':
                    # NOTE(review): str.replace returns a new string; the
                    # four calls below discard their result, so the
                    # "disabled" rewrites never take effect — confirm and
                    # assign back (prompt_text = prompt_text.replace(...)).
                    if not inner_thoughts_open:
                        prompt_text.replace('- Inner thoughts: enabled.',
                                            '- Inner thoughts: disabled.')
                    if not calculate_open:
                        prompt_text.replace(('- Calculator: enabled. API: '
                                             'Calculate(expression)'),
                                            '- Calculator: disabled.')
                    if not solve_open:
                        prompt_text.replace(
                            '- Equation solver: enabled. API: Solve(equation)',
                            '- Equation solver: disabled.')
                    if not search_open:
                        prompt_text.replace(
                            '- Web search: enabled. API: Search(query)',
                            '- Web search: disabled.')
            else:
                prompt_text = text
            inputs += prompt_text
            if args.image is None:
                # Text-only path: special (BOS) tokens only on the very
                # first turn; afterwards the history already contains them.
                if n_turn == 0:
                    ids = tokenizer.encode(inputs, return_tensors='pt')
                else:
                    ids = tokenizer.encode(
                        inputs, return_tensors='pt', add_special_tokens=False)

                if args.with_plugins is not None:
                    # Pass 1: let the model emit `<|Commands|>: ... <eoc>`,
                    # run the commands through plugins_api, then continue
                    # generation with the tool output appended.
                    generate_output = llm.generate(
                        inputs=ids.cuda(),
                        generation_config=gen_config,
                        streamer=streamer,
                        stopping_criteria=stop_criteria).cpu()
                    generate_output_text = tokenizer.decode(
                        generate_output[0][len(ids[0]):])
                    if streamer is None:
                        end = '' if generate_output_text[-1] == '\n' else '\n'
                        print(generate_output_text, end=end)
                    pattern = r'<\|Commands\|>:(.*?)<eoc>'
                    command_text = ', '.join(
                        re.findall(pattern, generate_output_text))
                    extent_text = plugins_api(
                        command_text,
                        calculate_open=calculate_open,
                        solve_open=solve_open,
                        search_open=search_open)
                    end = '' if extent_text[-1] == '\n' else '\n'
                    print(extent_text, end=end)
                    extent_text_ids = tokenizer.encode(
                        extent_text,
                        return_tensors='pt',
                        add_special_tokens=False)
                    new_ids = torch.cat((generate_output, extent_text_ids),
                                        dim=1)

                    # Pass 2: resume generation after the tool output.
                    generate_output = llm.generate(
                        inputs=new_ids.cuda(),
                        generation_config=gen_config,
                        streamer=streamer,
                        stopping_criteria=stop_criteria)
                    if streamer is None:
                        output_text = tokenizer.decode(
                            generate_output[0][len(new_ids[0]):])
                        end = '' if output_text[-1] == '\n' else '\n'
                        print(output_text, end=end)
                else:
                    generate_output = llm.generate(
                        inputs=ids.cuda(),
                        generation_config=gen_config,
                        streamer=streamer,
                        stopping_criteria=stop_criteria)
                    if streamer is None:
                        output_text = tokenizer.decode(
                            generate_output[0][len(ids[0]):])
                        end = '' if output_text[-1] == '\n' else '\n'
                        print(output_text, end=end)
                # Carry the full decoded sequence forward as the history.
                inputs = tokenizer.decode(generate_output[0])
            else:
                # Multimodal path: split the prompt around the image token
                # and splice IMAGE_TOKEN_INDEX between the encoded chunks.
                chunk_encode = []
                for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
                    if idx == 0 and n_turn == 0:
                        cur_encode = tokenizer.encode(chunk)
                    else:
                        cur_encode = tokenizer.encode(
                            chunk, add_special_tokens=False)
                    chunk_encode.append(cur_encode)
                # Exactly one image token is expected in the conversation.
                assert len(chunk_encode) == 2
                ids = []
                for idx, cur_chunk_encode in enumerate(chunk_encode):
                    ids.extend(cur_chunk_encode)
                    if idx != len(chunk_encode) - 1:
                        ids.append(IMAGE_TOKEN_INDEX)
                ids = torch.tensor(ids).cuda().unsqueeze(0)
                mm_inputs = prepare_inputs_labels_for_multimodal(
                    llm=llm, input_ids=ids, pixel_values=pixel_values)

                generate_output = llm.generate(
                    **mm_inputs,
                    generation_config=gen_config,
                    streamer=streamer,
                    bos_token_id=tokenizer.bos_token_id,
                    stopping_criteria=stop_criteria)
                if streamer is None:
                    output_text = tokenizer.decode(generate_output[0])
                    end = '' if output_text[-1] == '\n' else '\n'
                    print(output_text, end=end)
                inputs += tokenizer.decode(generate_output[0])
            n_turn += 1
            inputs += sep
            # Reset history once the last response hit the length cap.
            if len(generate_output[0]) >= args.max_new_tokens:
                print(
                    'Remove the memory of history responses, since '
                    f'it exceeds the length limitation {args.max_new_tokens}.')
                n_turn = 0
                inputs = ''


if __name__ == '__main__':
    main()
| |
|