| import os |
| import pandas as pd |
| import textwrap |
| import random |
| from datasets import load_dataset |
| from IPython.display import display |
|
|
| class LMSYSChat1MHandler: |
| def __init__(self, hf_token, streaming=False, verbose=True): |
| self.hf_token = hf_token |
| self.streaming = streaming |
| self.lmsys_dataset = load_dataset( |
| 'lmsys/lmsys-chat-1m', |
| revision="main", |
| token=self.hf_token, |
| streaming=self.streaming |
| ) |
| self.verbose = verbose |
| if verbose: |
| print(self.lmsys_dataset) |
| self.df_sample = None |
| self.df_prompts = None |
| self.unwrapped_turns_df = None |
|
|
| if not self.streaming and verbose: |
| print('Data is cached at:\n') |
| for file_info in self.lmsys_dataset['train'].cache_files: |
| filename = file_info['filename'] |
| file_size = os.path.getsize(filename) |
| i = int((len(filename) - 41) / 2) |
| print(f"Filename: {filename[:i]}*{filename[-41:]}\nSize: {file_size} bytes") |
|
|
| def extract_df_sample(self, n_samples=None, conversation_ids=None): |
| """ |
| Extracts a sample of conversations or specific conversations based on their conversation IDs. |
| |
| Parameters: |
| - n_samples (int): Number of random samples to extract. Ignored if `conversation_ids` is provided. |
| - conversation_ids (list): List of conversation IDs to extract. If provided, this takes precedence over `n_samples`. |
| |
| Returns: |
| - pd.DataFrame: A DataFrame containing the extracted conversations. |
| """ |
| if conversation_ids: |
| |
| df_sample = self.lmsys_dataset['train'].to_pandas() |
| df_sample = df_sample[df_sample['conversation_id'].isin(conversation_ids)] |
| print(f"Retrieved {len(df_sample)} conversations based on specified IDs") |
| else: |
| |
| if not self.streaming: |
| df_sample = self.lmsys_dataset['train'].to_pandas().sample(n_samples) |
| print(f"Retrieved {len(df_sample)} random conversations from lmsys/lmsys-chat-1m") |
| else: |
| |
| streamed_samples = [] |
| for i, row in enumerate(self.lmsys_dataset['train']): |
| streamed_samples.append(row) |
| if i + 1 == n_samples: |
| break |
| |
| random.shuffle(streamed_samples) |
| df_sample = pd.DataFrame(streamed_samples) |
| |
| self.df_sample = df_sample |
| if self.verbose and len(df_sample) > 4: |
| display(df_sample.head(2)) |
| print('...') |
| display(df_sample.tail(2)) |
| return df_sample |
| |
| def parquet_sampling(self, n_samples): |
| base_url = "https://huggingface.co/datasets/lmsys/lmsys-chat-1m/resolve/main/data/" |
| data_files = [ |
| "train-00000-of-00006-4feeb3f83346a0e9.parquet", |
| "train-00001-of-00006-4030672591c2f478.parquet", |
| "train-00002-of-00006-1779b7cec9462180.parquet", |
| "train-00003-of-00006-2fa862bfed56af1f.parquet", |
| "train-00004-of-00006-18f4bdd50c103e71.parquet", |
| "train-00005-of-00006-fe1acc5d10a9f0e2.parquet" |
| ] |
| sample_file = random.choice(data_files) |
| print(f"Sampling from {sample_file}") |
| data_files = {"train": base_url + sample_file} |
| parquet_sample = load_dataset("parquet", data_files=data_files, split="train") |
| df_sample = parquet_sample.to_pandas().sample(n_samples) |
| print(f"Retrieved {len(df_sample)} random conversations from lmsys/lmsys-chat-1m/{sample_file}") |
| self.df_sample = df_sample |
| if self.verbose and len(df_sample) > 4: |
| display(df_sample.head(2)) |
| print('...') |
| display(df_sample.tail(2)) |
| return df_sample |
|
|
| def add_turns_to_conversations(self): |
| """ |
| Adds 'turn' keys to each conversation in the 'conversation' column of the dataframe. |
| """ |
| self.df_sample['conversation'] = self.df_sample['conversation'].apply( |
| lambda conv: Conversation(conv).add_turns() |
| ) |
| df_with_turns = self.df_sample |
| return df_with_turns |
| |
| def unwrap_turns(self): |
| """ |
| Creates a dataframe where each row corresponds to a pair of user-assistant messages in a conversation and turn. |
| The 'prompt' column contains the user's message, and the 'response' column contains the assistant's message. |
| Each row includes a 'turn_id' column, which numbers the turns uniquely per conversation. |
| """ |
| paired_data = [] |
| for _, row in self.df_sample.iterrows(): |
| conversation_id = row['conversation_id'] |
| row_data = row.to_dict() |
| row_data.pop('conversation') |
|
|
| current_prompt = None |
| turn_id = None |
|
|
| for message in row['conversation']: |
| if message['role'] == 'user': |
| current_prompt = message['content'] |
| turn_id = f"{conversation_id}{message['turn']:03}" |
| elif message['role'] == 'assistant' and current_prompt is not None: |
| |
| paired_row = { |
| **row_data, |
| 'turn_n': message['turn'], |
| 'prompt': current_prompt, |
| 'response': message['content'], |
| } |
| paired_data.append(paired_row) |
| current_prompt = None |
|
|
| unwrapped_turns_df = pd.DataFrame(paired_data) |
| unwrapped_turns_df.rename(columns={"turn": "conversation_turns"}, inplace=True) |
| self.unwrapped_turns_df = unwrapped_turns_df |
| return unwrapped_turns_df |
|
|
| def extract_prompts(self, filter_language=None, min_char_length=20, max_char_length=500, exclusions=None): |
| """ |
| Extracts user prompts from the sample dataframe, optionally filtering by language and limiting the character length. |
| |
| Parameters: |
| - filter_language (list of str or None): A list of specific languages to filter prompts by. If None, no language |
| filter is applied. Examples of valid values include ['English'], ['English', 'Portuguese'], or |
| ['Spanish', 'French', 'German']. |
| - min_char_length (int): The minimum character length for user prompts to include. Defaults to 20. |
| - max_char_length (int): The maximum character length for user prompts to include. Defaults to 500. |
| - exclusions (str or None): Path to a text file containing phrases. Prompts containing any of these phrases |
| will be excluded from the results. If None, no exclusions are applied. |
| |
| Returns: |
| - pd.DataFrame: A DataFrame containing extracted prompts with columns 'prompt' and 'language'. |
| """ |
| df_sample = self.df_sample |
| if filter_language: |
| extracted_data = df_sample[df_sample['language'].isin(filter_language)].apply( |
| lambda row: [ |
| {'content': entry['content'], 'language': row['language']} |
| for entry in row['conversation'] |
| if entry['role'] == 'user' and min_char_length <= len(entry['content']) <= max_char_length |
| ], axis=1 |
| ).explode().dropna() |
| else: |
| extracted_data = df_sample.apply( |
| lambda row: [ |
| {'content': entry['content'], 'language': row['language']} |
| for entry in row['conversation'] |
| if entry['role'] == 'user' and min_char_length <= len(entry['content']) <= max_char_length |
| ], axis=1 |
| ).explode().dropna() |
|
|
| df_prompts = pd.DataFrame(extracted_data.tolist()) |
| df_prompts.rename(columns={'content': 'prompt'}, inplace=True) |
|
|
| orig_length = len(df_prompts) |
| if exclusions: |
| |
| with open(exclusions, 'r') as f: |
| exclusions = [line.strip() for line in f.readlines()] |
| df_prompts = df_prompts[~df_prompts['prompt'].apply(lambda x: any(exclusion in x for exclusion in exclusions))] |
| print(f"Excluded {orig_length - len(df_prompts)} prompts.") |
|
|
| self.df_prompts = df_prompts |
| if self.verbose and len(df_sample) > 4: |
| display(df_prompts.head(2)) |
| print('...') |
| display(df_prompts.tail(2)) |
| return df_prompts |
| |
| def extract_prompt_sample(self): |
| prompt_sample = self.df_prompts.sample(1)['prompt'].values[0] |
| if self.verbose: |
| wrapped_message = textwrap.fill(prompt_sample, width=120) |
| print(wrapped_message) |
| return prompt_sample |
| |
| def search_conversations(self, search_term): |
| """ |
| Searches the dataset for a given string and returns a DataFrame with matching records. |
| |
| Parameters: |
| - search_term (str): The string to search for in the dataset. |
| |
| Returns: |
| - pd.DataFrame: A DataFrame containing conversations where the search term is found. |
| """ |
| if self.streaming: |
| raise ValueError("Search is not supported in streaming mode.") |
| df = self.lmsys_dataset['train'].to_pandas() |
| |
| matching_records = df[df['conversation'].apply( |
| lambda conv: any(search_term.lower() in message['content'].lower() for message in conv) |
| )] |
| if self.verbose: |
| print(f"Found {len(matching_records)} matching conversations for search term: '{search_term}'") |
| return matching_records |
| |
| def print_language_counts(self, df): |
| language_counts = df['language'].value_counts() |
| print("Language Record Counts:") |
| print(language_counts.to_frame('Count').reset_index().rename(columns={'index': 'Language'})) |
|
|
|
|
| class Conversation: |
| def __init__(self, conversation_data): |
| """ |
| Initializes the Conversation object with the conversation data. |
| |
| Parameters: |
| - conversation_data (list): A list of dictionaries representing a conversation. |
| """ |
| self.conversation_data = conversation_data |
|
|
| def add_turns(self): |
| """ |
| Adds a 'turn' key to each dictionary in the conversation, |
| identifying the turn (pair of user and assistant messages). |
| |
| Returns: |
| - list: The updated conversation with 'turn' keys added. |
| """ |
| turn_counter = 0 |
| for message in self.conversation_data: |
| if message['role'] == 'user': |
| turn_counter += 1 |
| message['turn'] = turn_counter |
| return self.conversation_data |
| |
| def pretty_print(self, user_prefix, assistant_prefix, width=80): |
| """ |
| Prints the conversation with specified prefixes and wrapped text. |
| |
| Parameters: |
| - user_prefix (str): Prefix to prepend to user messages. |
| - assistant_prefix (str): Prefix to prepend to assistant messages. |
| - width (int): Maximum characters per line for wrapping. |
| """ |
| wrapper = textwrap.TextWrapper(width=width) |
| |
| for message in self.conversation_data: |
| if message['role'] == 'user': |
| prefix = user_prefix |
| elif message['role'] == 'assistant': |
| prefix = assistant_prefix |
| else: |
| continue |
| |
| |
| wrapped_content = "\n".join( |
| wrapper.fill(line) for line in message['content'].splitlines() |
| ) |
| print(f"{prefix} {wrapped_content}\n") |