| | import os |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | import requests |
| | from datasets import Dataset, DownloadMode, load_dataset |
| | from gradio_client import Client |
| |
|
| | from src.my_logger import setup_logger |
| |
|
# --- Configuration (read once at import time) ---
# Required environment variables: SUBREDDIT, USERNAME, PROCESSED_DATASET.
SUBREDDIT = os.environ["SUBREDDIT"]
USERNAME = os.environ["USERNAME"]
# Hub repo id of the raw (original) scraped dataset.
OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
# Hub repo id of the processed dataset that carries embeddings.
PROCESSED_DATASET = os.environ['PROCESSED_DATASET']
# Gradio Space that serves the embedding endpoint.
# (was an f-string with no placeholders — plain literal is equivalent; ruff F541)
embeddings_space = "derek-thomas/nomic-embeddings"
# JSON list of post ids that must be excluded from the dataset.
FILTER_IDS_URL = "https://huggingface.co/spaces/reddit-tools-HF/dataset-creator-reddit-bestofredditorupdates/raw/main/filter_ids.json"
# Optional token; None means anonymous access to the Space.
HF_TOKEN = os.environ.get("HF_TOKEN")

logger = setup_logger(__name__)
| |
|
| |
|
def load_datasets():
    """Fetch both datasets from the Hugging Face Hub.

    Each download is forced (``FORCE_REDOWNLOAD``) so a stale local
    cache never masks upstream changes.

    :return: tuple of (processed dataset, original dataset)
    """
    logger.info(f"Trying to download {PROCESSED_DATASET}")
    processed = load_dataset(PROCESSED_DATASET, download_mode=DownloadMode.FORCE_REDOWNLOAD)
    logger.info(f"Loaded {PROCESSED_DATASET}")

    logger.info(f"Trying to download {OG_DATASET}")
    original = load_dataset(OG_DATASET, download_mode=DownloadMode.FORCE_REDOWNLOAD)
    logger.info(f"Loaded {OG_DATASET}")

    return processed, original
| |
|
| |
|
def merge_and_update_datasets(dataset, original_dataset):
    """Sync the processed dataset with the original one.

    Rows whose ``content`` changed — and rows that are new in the
    original dataset — get their embedding recomputed via the
    embeddings Space; unchanged rows keep their stored embedding.

    :param dataset: processed dataset (train split has 'id', 'content', 'embedding')
    :param original_dataset: freshly scraped dataset; source of truth for 'content'
    :return: tuple of (updated dataset, number of rows re-embedded)
    """
    client = Client(embeddings_space, hf_token=HF_TOKEN)

    odf = original_dataset['train'].to_pandas()
    df = dataset['train'].to_pandas()

    # Drop blacklisted rows before doing any embedding work.
    odf = remove_filtered_rows(odf, FILTER_IDS_URL)

    # Left-join so every surviving original row is kept; processed columns
    # get no suffix, original ones get '_odf'. Rows new to odf end up with
    # null 'content' and 'embedding'.
    merged_df = pd.merge(odf, df[['id', 'content', 'embedding']], on='id', how='left', suffixes=('_odf', ''))

    # Invalidate embeddings whose source content changed; new rows are
    # already null from the left join.
    merged_df['embedding'] = np.where(merged_df['content_odf'] != merged_df['content'], None, merged_df['embedding'])

    # The original content is authoritative; the bookkeeping columns
    # 'new'/'updated' are not carried into the processed dataset.
    merged_df = merged_df.drop(columns=['content', 'new', 'updated'])
    merged_df.rename(columns={'content_odf': 'content'}, inplace=True)

    # BUGFIX: count the rows the loop below will actually re-embed (null
    # embeddings). The previous count compared content columns only, so it
    # could disagree with the real amount of work (e.g. rows whose stored
    # embedding was already null).
    null_embedding_mask = merged_df['embedding'].isnull()
    updated_row_count = int(null_embedding_mask.sum())

    logger.info(f"Updating {updated_row_count} rows...")
    # The Space endpoint embeds one document per call, hence the row loop.
    for index, row in merged_df[null_embedding_mask].iterrows():
        merged_df.at[index, 'embedding'] = update_embeddings(content=row['content'], client=client)

    dataset['train'] = Dataset.from_pandas(merged_df)
    logger.info(f"Updated {updated_row_count} rows")
    return dataset, updated_row_count
| |
|
| |
|
def remove_filtered_rows(df: pd.DataFrame, url: str) -> pd.DataFrame:
    """
    Removes rows from the DataFrame where the 'id' is present in the JSON file at the given URL.

    :param df: Input DataFrame to be filtered.
    :param url: URL to the JSON file containing the filter IDs.
    :return: DataFrame with rows containing IDs present in the JSON file removed.
    :raises requests.HTTPError: if the filter-list download fails.
    """
    # Fail fast: without a timeout a dead host hangs forever, and without
    # raise_for_status() an error page would be fed straight to .json().
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    filter_ids = response.json()

    logger.info(f"Loaded {len(filter_ids)} IDs from {url}")

    # Compare ids as strings so a numeric 'id' dtype cannot cause misses
    # against the (string) ids in the JSON list.
    filtered_df = df[~df['id'].astype(str).isin(filter_ids)]

    logger.info(f"Filtered {len(df) - len(filtered_df)} rows from the DataFrame")

    return filtered_df
| |
|
| |
|
def update_embeddings(content, client):
    """Embed one document via the remote embeddings Space.

    :param content: document text to embed
    :param client: gradio_client.Client bound to the embeddings Space
    :return: the embedding as a numpy array
    """
    prefixed = 'search_document: ' + content
    result = client.predict(prefixed, api_name="/embed")
    return np.array(result)
| |
|