---
license: apache-2.0
datasets:
- WorkInTheDark/FairytaleQA
language:
- en
metrics:
- f1
- accuracy
- recall
base_model:
- google-bert/bert-base-uncased
pipeline_tag: text-classification
library_name: transformers
---

# BertForStorySkillClassification
## Model Overview
`BertForStorySkillClassification` is a BERT-based text classification model that categorizes story-related questions into one of the following 7 classes:
1. **Character**
2. **Setting**
3. **Feeling**
4. **Action**
5. **Causal Relationship**
6. **Outcome Resolution**
7. **Prediction**

This model is suitable for applications in education, literary analysis, and story comprehension.

---

## Model Architecture
- **Base Model**: `bert-base-uncased`
- **Classification Layer**: A fully connected layer on top of BERT for 7-class classification.
- **Input**: One of three formats: a question alone (e.g. `"Who is the main character in the story?"`), a QA pair joined by the `<SEP>` marker (e.g. `"why could n't alice get a doll as a child ? <SEP> because her family was very poor "`), or a QA pair followed by the `<context>` marker and story text (e.g. `"why could n't alice get a doll as a child ? <SEP> because her family was very poor <context> alice is ... "`).
- **Output**: Predicted label and confidence score.
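The three input formats above can be assembled with plain string formatting before tokenization. A minimal sketch (the helper names are illustrative, not part of the model's API):

```python
# Illustrative helpers (not part of the released model) showing how the three
# documented input formats can be assembled before tokenization.

def question_only(question: str) -> str:
    """Format 1: the question by itself."""
    return question

def qa_pair(question: str, answer: str) -> str:
    """Format 2: question and answer joined by the <SEP> marker."""
    return f"{question} <SEP> {answer}"

def qa_pair_with_context(question: str, answer: str, context: str) -> str:
    """Format 3: QA pair followed by the <context> marker and story text."""
    return f"{question} <SEP> {answer} <context> {context}"

print(qa_pair("why could n't alice get a doll as a child ?",
              "because her family was very poor"))
# why could n't alice get a doll as a child ? <SEP> because her family was very poor
```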

---

## Quick Start

### Install Dependencies
Ensure you have the `transformers` library installed:
```bash
pip install transformers
```

### Load Model and Tokenizer

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("curious008/BertForStorySkillClassification")
tokenizer = AutoTokenizer.from_pretrained("curious008/BertForStorySkillClassification")
```

### Use the `predict` Method for Inference

```python
# Single text prediction
result = model.predict(
    texts="Where does this story take place?",
    tokenizer=tokenizer,
    return_probabilities=True
)
print(result)
# Output: [{'text': 'Where does this story take place?', 'label': 'setting', 'score': 0.93178}]

# Batch prediction
results = model.predict(
    texts=[
        "Why is the character sad?",
        "How does the story end?",
        "why could n't alice get a doll as a child ? <SEP> because her family was very poor ",
    ],
    tokenizer=tokenizer,
    batch_size=16,
    device="cuda"
)
print(results)
"""
Output:
[{'text': 'Why is the character sad?', 'label': 'causal relationship'},
 {'text': 'How does the story end?', 'label': 'action'},
 {'text': "why could n't alice get a doll as a child ? <SEP> because her family was very poor ",
  'label': 'causal relationship'}]
"""
```
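If you prefer the standard `transformers` forward pass over the custom `predict` helper, the label and confidence score can be recovered from the raw logits with a softmax. A minimal sketch using dummy logits in place of a real forward pass (the `id2label` mapping below mirrors the class list above and is an assumption; the authoritative mapping is `model.config.id2label`):

```python
import math

# Assumed label order; check model.config.id2label for the real mapping.
id2label = {0: "character", 1: "setting", 2: "feeling", 3: "action",
            4: "causal relationship", 5: "outcome resolution", 6: "prediction"}

def logits_to_prediction(logits):
    """Softmax over raw logits, then pick the highest-probability class."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # subtract max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return {"label": id2label[best], "score": round(probs[best], 5)}

# Dummy logits standing in for model(**inputs).logits[0].tolist()
print(logits_to_prediction([0.1, 3.2, -0.5, 0.0, 0.4, -1.1, 0.2]))
```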

## Training Details

### Dataset
Source: [FairytaleQAData](https://github.com/uci-soe/FairytaleQAData)

### Training Parameters
- Learning Rate: 2e-5
- Batch Size: 32
- Epochs: 3
- Optimizer: AdamW
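The hyperparameters above correspond to a standard fine-tuning setup. An illustrative `TrainingArguments` fragment mirroring the reported values (this is a sketch, not the authors' actual training script, which was not released):

```python
from transformers import TrainingArguments

# Configuration fragment mirroring the reported hyperparameters;
# the actual training script for this model was not released.
training_args = TrainingArguments(
    output_dir="bert-story-skill",
    learning_rate=2e-5,              # reported learning rate
    per_device_train_batch_size=32,  # reported batch size
    num_train_epochs=3,              # reported epochs
    optim="adamw_torch",             # AdamW optimizer
)
```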

### Performance Metrics
- Accuracy: 97.3%
- Recall: 96.59%
- F1 Score: 96.96%

## Notes
1. **Input Length**: The model supports a maximum input length of 512 tokens. Longer texts will be truncated.
2. **Device Support**: The model supports both CPU and GPU inference. GPU is recommended for faster performance.
3. **Tokenizer**: Always use the matching tokenizer (`AutoTokenizer`) for the model.

## Citation

If you use this model, please cite the following:

```
@misc{BertForStorySkillClassification,
  author       = {curious},
  title        = {BertForStorySkillClassification: A BERT-based Model for Story Question Classification},
  year         = {2025},
  publisher    = {Hugging Face},
  howpublished = {\url{https://huggingface.co/curious008/BertForStorySkillClassification}}
}
```

## License
This model is open-sourced under the Apache 2.0 License. For the full license text, see the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).