#!/usr/bin/env python3
"""Download and prepare instruction datasets for Bee LoRA training.
Fetches curated subsets of high-quality instruction data from HuggingFace,
saves as JSONL for training pipeline consumption.
Usage:
python scripts/download_datasets.py --output_dir ./datasets
Datasets:
- OpenOrca (subset: 10k random samples)
- CodeAlpaca (coding instructions, ~20k)
- teknium/OpenHermes-2.5 (high-quality, ~10k subset)
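
Output:
    One JSONL file per dataset with {"instruction", "input", "output"} records,
    plus a combined, shuffled train_mixed.jsonl.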
"""
import argparse
import json
import logging
import os
import random
from pathlib import Path
from typing import Optional

from datasets import load_dataset

logger = logging.getLogger("bee.data")


def _format_alpaca(ex) -> dict:
"""Convert Alpaca-style example to {instruction, input, output} dict."""
return {
"instruction": ex.get("instruction", ex.get("prompt", "")),
"input": ex.get("input", ""),
"output": ex.get("output", ex.get("response", ex.get("completion", ""))),
}


def _format_openorca(ex) -> dict:
"""Convert OpenOrca example."""
return {
"instruction": ex.get("question", ex.get("prompt", "")),
"input": "",
"output": ex.get("response", ex.get("answer", ex.get("completion", ""))),
}


def download_openorca(output_dir: str, max_samples: int = 10000):
logger.info("Downloading OpenOrca (subset: %d)...", max_samples)
try:
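        # Stream so only the first max_samples examples are fetched instead of
        # downloading the full OpenOrca corpus.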
ds = load_dataset("Open-Orca/OpenOrca", split="train", streaming=True)
samples = []
for i, ex in enumerate(ds):
if i >= max_samples:
break
samples.append(_format_openorca(ex))
_save_jsonl(os.path.join(output_dir, "openorca.jsonl"), samples)
logger.info("Saved %d OpenOrca samples", len(samples))
except Exception as e:
logger.warning("OpenOrca download failed: %s", e)


def download_code_alpaca(output_dir: str):
logger.info("Downloading CodeAlpaca...")
try:
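        # ~18k rows, small enough to load in full rather than stream.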
ds = load_dataset("iamtarun/python_code_instructions_18k_alpaca", split="train")
samples = [_format_alpaca(ex) for ex in ds]
_save_jsonl(os.path.join(output_dir, "codealpaca.jsonl"), samples)
logger.info("Saved %d CodeAlpaca samples", len(samples))
except Exception as e:
logger.warning("CodeAlpaca download failed: %s", e)


def download_openhermes(output_dir: str, max_samples: int = 10000):
logger.info("Downloading OpenHermes 2.5 (subset: %d)...", max_samples)
try:
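        # OpenHermes-2.5 stores each sample as a ShareGPT-style "conversations"
        # list of {"from": ..., "value": ...} turns rather than flat
        # instruction/output fields, so the relevant turns are extracted below.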
ds = load_dataset("teknium/OpenHermes-2.5", split="train", streaming=True)
samples = []
for i, ex in enumerate(ds):
if i >= max_samples:
break
            # Take the first human turn as the instruction and the first gpt
            # turn as the output; skip records missing either.
            convs = ex.get("conversations", [])
            instruction = next((t.get("value", "") for t in convs if t.get("from") == "human"), "")
            output = next((t.get("value", "") for t in convs if t.get("from") == "gpt"), "")
            if not instruction or not output:
                continue
            samples.append({"instruction": instruction, "input": "", "output": output})
_save_jsonl(os.path.join(output_dir, "openhermes.jsonl"), samples)
logger.info("Saved %d OpenHermes samples", len(samples))
except Exception as e:
logger.warning("OpenHermes download failed: %s", e)


def _save_jsonl(path: str, data: list):
Path(path).parent.mkdir(parents=True, exist_ok=True)
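    # Write as JSON Lines: one JSON object per line.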
with open(path, "w") as f:
for item in data:
f.write(json.dumps(item) + "\n")


def prepare_mixed_dataset(output_dir: str, datasets: Optional[list] = None):
"""Combine all downloaded datasets into a single shuffled training file."""
datasets = datasets or ["openorca.jsonl", "codealpaca.jsonl", "openhermes.jsonl"]
all_samples = []
for fname in datasets:
path = os.path.join(output_dir, fname)
if os.path.exists(path):
            with open(path, encoding="utf-8") as f:
                file_samples = [json.loads(line) for line in f]
            all_samples.extend(file_samples)
            logger.info("Loaded %s: %d samples", fname, len(file_samples))
else:
logger.warning("Missing dataset: %s", path)
random.shuffle(all_samples)
_save_jsonl(os.path.join(output_dir, "train_mixed.jsonl"), all_samples)
logger.info("Mixed dataset: %d total samples", len(all_samples))
return len(all_samples)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--output_dir", default="./datasets")
parser.add_argument("--openorca_samples", type=int, default=10000)
parser.add_argument("--openhermes_samples", type=int, default=10000)
parser.add_argument("--skip_openorca", action="store_true")
parser.add_argument("--skip_codealpaca", action="store_true")
parser.add_argument("--skip_openhermes", action="store_true")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
os.makedirs(args.output_dir, exist_ok=True)
if not args.skip_openorca:
download_openorca(args.output_dir, args.openorca_samples)
if not args.skip_codealpaca:
download_code_alpaca(args.output_dir)
if not args.skip_openhermes:
download_openhermes(args.output_dir, args.openhermes_samples)
n = prepare_mixed_dataset(args.output_dir)
logger.info("Dataset preparation complete: %d samples in %s/train_mixed.jsonl", n, args.output_dir)
if __name__ == "__main__":
main()