| | """ |
| | Module: tokenization_args.py |
| | |
| | This module defines the `TokenizationArgs` dataclass, which encapsulates all the configurable parameters |
| | required for the tokenization process in the TEDDY project. These parameters control how gene expression |
| | data and biological annotations are tokenized for training. |
| | |
| | Main Features: |
| | - Provides a structured way to define and manage tokenization arguments. |
| | - Supports configuration for gene selection, sequence truncation, and annotation inclusion. |
| | - Includes options for handling PerturbSeq-specific flags and preprocessing steps. |
| | - Allows for flexible mapping of biological annotations (e.g., disease, tissue, cell type, sex). |
| | - Enables reproducibility through random seed control for gene selection. |
| | |
| | Dependencies: |
| | - `dataclasses`: For defining the `TokenizationArgs` dataclass. |
| | |
| | Usage: |
| | 1. Import the `TokenizationArgs` class: |
| | ```python |
| | from teddy.tokenizer.tokenization_args import TokenizationArgs
| | ``` |
| | 2. Define tokenization arguments for a specific tokenization task: |
| | ```python |
| | tokenization_args = TokenizationArgs( |
| | tokenizer_name_or_path="path/to/tokenizer", |
| | ... |
| | ) |
| | ``` |
| | 3. Pass the `tokenization_args` object to the tokenization function: |
| | ```python |
| | tokenized_data = tokenize(data, tokenization_args) |
| | ``` |
| | """ |
| |
|
from dataclasses import dataclass, field
from typing import Optional
| |
|
| |
|
@dataclass
class TokenizationArgs:
    """Configuration container for the TEDDY tokenization pipeline.

    Groups every knob consumed by ``tokenize()``: gene selection,
    sequence truncation/padding, PerturbSeq-specific handling,
    biological-annotation mappings, sharding, I/O directories, and the
    random seed that makes gene selection reproducible.
    """

    # --- Tokenizer & gene selection -----------------------------------
    # Required (no default): path/name of the tokenizer to load.
    tokenizer_name_or_path: str = field(metadata={"help": "Path to tokenizer used."})
    gene_id_column: str = field(
        default="index", metadata={"help": "Field to use while accessing gene_ids for values."}
    )
    random_genes: bool = field(
        default=False, metadata={"help": "whether we want random genes (True) selection or top expressed ones (False)"}
    )
    # FIX: help text was a copy-paste of tokenizer_name_or_path's
    # ("Path to tokenizer used."); replaced with a description matching the
    # field name — confirm exact semantics against tokenize().
    include_zero_genes: bool = field(
        default=False,
        metadata={"help": "Whether to include genes with zero expression values in the tokenized sequence."},
    )

    # --- Special tokens ------------------------------------------------
    add_cls: bool = field(default=False, metadata={"help": "Whether to add cls token to the start of the sequence."})
    # FIX: annotated `int` but defaults to None -> Optional[int].
    cls_token_id: Optional[int] = field(default=None, metadata={"help": "Token id for cls token."})

    # --- PerturbSeq-specific flags -------------------------------------
    perturbseq: bool = field(
        default=False,
        metadata={"help": "[PerturbSeq specific flag] Whether to add perturbation token during tokenization."},
    )
    tokenize_perturbseq_for_train: bool = field(
        default=True,
        metadata={
            # FIX: typo "tokennized" -> "tokenized".
            "help": "[PerturbSeq specific flag] Whether to tokenize labels to prepare data for training or to simply prepare tokenized perturbation flags for inference."
        },
    )
    add_tokens: tuple = field(
        default=(),
        metadata={
            "help": "Enter a tuple of string values for tokens. Will be pre-pended to the gene id sequence. Can be used instead of add_cls"
        },
    )

    # --- Labels & annotations ------------------------------------------
    add_disease_annotation: bool = field(default=False)
    # FIX: annotated `str` but defaults to None -> Optional[str].
    label_column: Optional[str] = field(
        default=None, metadata={"help": "Which column to use as a label for a classification task."}
    )

    # --- Sequence length / binning -------------------------------------
    max_shard_samples: int = field(default=500, metadata={"help": "Number of samples included in sharding."})
    max_seq_len: int = field(default=3001, metadata={"help": "Max seq length used for data processing"})
    pad_length: int = field(
        default=3001,
        metadata={"help": "Pad sequence to x length so that all arrays in all batches are same length"},
    )
    truncation_method: str = field(
        default="max",
        metadata={
            # FIX: typo "expresison" -> "expression".
            "help": "Indicate here how to restrict the number of genes to obtain max_seq_len from the full set of expression values. Options: max, random"
        },
    )
    # FIX: annotated `int` but defaults to None -> Optional[int].
    bins: Optional[int] = field(default=None, metadata={"help": "Number of bins used when required for data processing"})
    # FIX: typo "continiously" -> "continuously".
    rescale_labels: bool = field(default=False, metadata={"help": "If true, labels are binned or continuously ranked"})
    continuous_rank: bool = field(
        default=False, metadata={"help": "If true, gene values are overwritten with linspace[-1, 1] by rank."}
    )

    # --- Biological annotations (disease / tissue / cell type / sex) ---
    bio_annotations: bool = field(
        default=False, metadata={"help": "If true, include disease, tissue type, cell type, sex"}
    )
    bio_annotation_masking_prob: float = field(
        default=0.15, metadata={"help": "Mask annotation tokens with this probability"}
    )
    # FIX for the four mapping paths below: annotated `str` but default to
    # None -> Optional[str].
    disease_mapping: Optional[str] = field(
        default=None, metadata={"help": "Path to json mapping from disease names to standard disease categories"}
    )
    tissue_mapping: Optional[str] = field(
        default=None, metadata={"help": "Path to json mapping from tissue names to standard tissue categories"}
    )
    cell_mapping: Optional[str] = field(
        default=None, metadata={"help": "Path to json mapping from cell type names to standard cell types"}
    )
    sex_mapping: Optional[str] = field(
        default=None, metadata={"help": "Path to json mapping from sex names to standard sex categories"}
    )

    # --- I/O ------------------------------------------------------------
    load_dir: str = field(default="", metadata={"help": "Directory where h5ad data is loaded from."})
    save_dir: str = field(
        default="",
        metadata={
            "help": "Directory where tokenization function will save data. tokenize() saves tokenized in data_path.replace(load_dir, save_dir)"
        },
    )

    # --- Reproducibility -------------------------------------------------
    gene_seed: int = field(default=42, metadata={"help": "Random seed that controls randomness of gene selection"})
| |
|