| | import os
|
| | import random
|
| | import argparse
|
| | from pathlib import Path
|
| | from tqdm import tqdm
|
| |
|
def split_data(input_file, train_file, test_file, test_ratio=0.1, seed=42):
    """
    Split tokenized data into training and test sets.

    Performs a two-pass split: the first pass counts lines, the second
    streams each line into either the train or test output, so the whole
    dataset never has to fit in memory.

    Args:
        input_file (str): Path to the input file containing tokenized sequences
        train_file (str): Path to write training sequences
        test_file (str): Path to write test sequences
        test_ratio (float): Proportion of data to use for testing (default: 0.1)
        seed (int): Random seed for reproducibility (default: 42)

    Raises:
        ValueError: If test_ratio is not in the range [0, 1].
    """
    if not 0.0 <= test_ratio <= 1.0:
        raise ValueError(f"test_ratio must be in [0, 1], got {test_ratio}")

    random.seed(seed)

    # First pass: count lines so we can size the split and the progress bar.
    print("Counting sequences in the file...")
    with open(input_file, 'r', encoding='utf-8') as f:
        total_sequences = sum(1 for _ in f)

    print(f"Total sequences found: {total_sequences}")

    test_count = int(total_sequences * test_ratio)
    train_count = total_sequences - test_count

    # Sample only the test indices instead of shuffling a full index list:
    # O(test_count) memory rather than O(total_sequences).
    test_indices = set(random.sample(range(total_sequences), test_count))

    print(f"Splitting data: {train_count} training sequences, {test_count} test sequences")

    # Second pass: stream every line to exactly one of the two outputs.
    # Explicit encoding keeps the split byte-faithful across platforms.
    with open(input_file, 'r', encoding='utf-8') as infile, \
            open(train_file, 'w', encoding='utf-8') as train_out, \
            open(test_file, 'w', encoding='utf-8') as test_out:

        for i, line in tqdm(enumerate(infile), total=total_sequences, desc="Splitting data"):
            if i in test_indices:
                test_out.write(line)
            else:
                train_out.write(line)

    print(f"Done! Training data saved to {train_file}, test data saved to {test_file}")

    # Report output sizes as a quick sanity check on the split.
    train_size_mb = os.path.getsize(train_file) / (1024 * 1024)
    test_size_mb = os.path.getsize(test_file) / (1024 * 1024)
    print(f"Training file size: {train_size_mb:.2f} MB")
    print(f"Test file size: {test_size_mb:.2f} MB")
|
| |
|
if __name__ == "__main__":
    # CLI entry point: declare flags, parse them, and run the split.
    cli = argparse.ArgumentParser(description="Split tokenized data into training and test sets")
    cli.add_argument("--input", type=str, default="./data/output.txt",
                     help="Input file path")
    cli.add_argument("--train", type=str, default="./data/train.txt",
                     help="Output path for training data")
    cli.add_argument("--test", type=str, default="./data/test.txt",
                     help="Output path for test data")
    cli.add_argument("--test-ratio", type=float, default=0.1,
                     help="Proportion of data to use for testing")
    cli.add_argument("--seed", type=int, default=42,
                     help="Random seed for reproducibility")

    opts = cli.parse_args()

    split_data(
        opts.input,
        opts.train,
        opts.test,
        test_ratio=opts.test_ratio,
        seed=opts.seed,
    )