import argparse
import concurrent.futures
import json
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

import torch
from tqdm import tqdm
from vllm import LLM, SamplingParams
|
class VLLMGenerator:
    """Batched generation of hypothetical fact-checking passages with vLLM.

    The vLLM engine is sharded over every visible CUDA device via tensor
    parallelism; claims are turned into prompts and processed in batches,
    with the ``n`` sampled completions per claim ranked by total logprob.
    """

    def __init__(self, model_name: str, n: int = 8, max_tokens: int = 512,
                 temperature: float = 0.7, top_p: float = 1.0,
                 frequency_penalty: float = 0.0, presence_penalty: float = 0.0,
                 stop: Optional[List[str]] = None, batch_size: int = 32):
        """Initialize the vLLM engine and sampling parameters.

        Args:
            model_name: HF model id or local path to load.
            n: Number of completions sampled per prompt.
            max_tokens: Max tokens generated per completion.
            temperature, top_p, frequency_penalty, presence_penalty:
                Standard sampling knobs forwarded to ``SamplingParams``.
            stop: Stop strings; defaults to ``['\n\n\n']`` when None.
            batch_size: Max concurrent sequences in the engine and the
                batch size callers should use with ``process_batch``.
        """
        # Fix: the original signature used a mutable default argument
        # (stop=['\n\n\n']); a None sentinel avoids sharing one list
        # object across all calls/instances.
        if stop is None:
            stop = ['\n\n\n']
        self.device_count = torch.cuda.device_count()
        print(f"Initializing with {self.device_count} GPUs")
        self.llm = LLM(
            model=model_name,
            tensor_parallel_size=self.device_count,
            max_model_len=4096,
            gpu_memory_utilization=0.95,
            enforce_eager=True,
            trust_remote_code=True,
            max_num_batched_tokens=4096,
            max_num_seqs=batch_size
        )
        self.sampling_params = SamplingParams(
            n=n,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            stop=stop,
            logprobs=1  # request per-token logprobs so completions can be ranked
        )
        self.batch_size = batch_size
        self.tokenizer = self.llm.get_tokenizer()
        print(f"Initialization complete. Batch size: {batch_size}")

    def parse_response(self, responses):
        """Extract completion texts from vLLM outputs, ranked by logprob.

        Args:
            responses: Iterable of vLLM ``RequestOutput``-like objects,
                each with an ``outputs`` list whose items expose ``text``
                and ``logprobs``.

        Returns:
            One list per prompt, containing the stripped completion
            texts sorted by total sequence logprob, most probable first.
        """
        all_outputs = []
        for response in responses:
            scored = []
            for output in response.outputs:
                text = output.text.strip()
                try:
                    # Sum the chosen-token logprobs over the whole sequence.
                    logprob = sum(lp.logprob
                                  for item in output.logprobs
                                  for lp in item.values())
                except (TypeError, AttributeError):
                    # Fix: was a bare ``except:`` that swallowed everything
                    # (even KeyboardInterrupt). Only handle the cases where
                    # logprobs are missing/None or malformed.
                    logprob = 0
                scored.append((text, logprob))
            texts = [t for t, _ in sorted(scored, key=lambda tup: tup[1], reverse=True)]
            all_outputs.append(texts)
        return all_outputs

    def prepare_prompt(self, claim: str, model_name: str) -> str:
        """Build the generation prompt for one claim.

        OLMo models receive the raw instruction; other models get the
        tokenizer's chat template followed by a Llama-3-style assistant
        header priming the passage.
        """
        base_prompt = f"Please write a fact-checking article passage to support, refute, indicate not enough evidence, or present conflicting evidence regarding the claim.\nClaim: {claim}"

        if "OLMo" in model_name:
            return base_prompt
        else:
            messages = [{"role": "user", "content": base_prompt}]
            return self.tokenizer.apply_chat_template(messages, tokenize=False) + "<|start_header_id|>assistant<|end_header_id|>\n\nPassage: "

    def process_batch(self, batch: List[Dict[str, Any]], model_name: str) -> tuple[List[Dict[str, Any]], float]:
        """Generate ranked hypothetical passages for one batch of examples.

        Each example dict must carry a "claim" key; on success it gains a
        'hypo_fc_docs' key with the ranked completions. On failure the
        batch is returned unannotated (best-effort) after logging.

        Returns:
            (batch, elapsed_seconds) — the same batch list, mutated in place.
        """
        start_time = time.time()
        prompts = [self.prepare_prompt(example["claim"], model_name) for example in batch]

        try:
            results = self.llm.generate(prompts, sampling_params=self.sampling_params)
            outputs = self.parse_response(results)

            for example, output in zip(batch, outputs):
                example['hypo_fc_docs'] = output

            batch_time = time.time() - start_time
            return batch, batch_time
        except Exception as e:
            # Deliberate best-effort: log and return the batch unannotated
            # rather than aborting the whole run.
            print(f"Error processing batch: {str(e)}")
            return batch, time.time() - start_time
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
def main(args):
    """Generate hypothetical fact-checking passages for every input claim.

    Reads a JSON list of example dicts (each with a "claim" field) from
    ``args.target_data``, generates ranked hypothetical passages for each
    with vLLM (stored under 'hypo_fc_docs'), and writes the augmented
    examples to ``args.json_output``.
    """
    total_start_time = time.time()
    print(f"Script started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    print("Loading data...")
    with open(args.target_data, 'r', encoding='utf-8') as json_file:
        examples = json.load(json_file)
    print(f"Loaded {len(examples)} examples")

    print("Initializing generator...")
    generator = VLLMGenerator(
        model_name=args.model,
        batch_size=32
    )

    processed_data = []
    # Chunk the examples into engine-sized batches.
    batches = [examples[i:i + generator.batch_size]
               for i in range(0, len(examples), generator.batch_size)]

    print(f"\nProcessing {len(batches)} batches...")
    with tqdm(total=len(examples), desc="Processing examples") as pbar:
        for batch in batches:
            processed_batch, _batch_time = generator.process_batch(batch, args.model)
            processed_data.extend(processed_batch)
            # Fix: the original never advanced the progress bar, so it
            # sat at 0 for the entire run.
            pbar.update(len(processed_batch))

    with open(args.json_output, "w", encoding="utf-8") as output_json:
        json.dump(processed_data, output_json, ensure_ascii=False, indent=4)

    # Fix: total_start_time was recorded but never reported.
    print(f"Total elapsed: {time.time() - total_start_time:.1f}s")
| |
|
| | if __name__ == "__main__": |
| | parser = argparse.ArgumentParser() |
| | parser.add_argument('-i', '--target_data', default='data_store/averitec/dev.json') |
| | parser.add_argument('-o', '--json_output', default='data_store/hyde_fc.json') |
| | parser.add_argument('-m', '--model', default="meta-llama/Llama-3.1-8B-Instruct") |
| | args = parser.parse_args() |
| | main(args) |