| --- |
| license: apache-2.0 |
| datasets: |
| - WorkInTheDark/FairytaleQA |
| language: |
| - en |
| metrics: |
| - f1 |
| - accuracy |
| - recall |
| base_model: |
| - google-bert/bert-base-uncased |
| pipeline_tag: text-classification |
| library_name: transformers |
| --- |
| # BertForStorySkillClassification |
|
|
| ## Model Overview |
| `BertForStorySkillClassification` is a BERT-based text classification model designed to categorize story-related questions into one of the following 7 classes: |
| 1. **Character** |
| 2. **Setting** |
| 3. **Feeling** |
| 4. **Action** |
| 5. **Causal Relationship** |
| 6. **Outcome Resolution** |
| 7. **Prediction** |
|
|
| This model is suitable for applications in education, literary analysis, and story comprehension. |
|
|
| --- |
|
|
| ## Model Architecture |
| - **Base Model**: `bert-base-uncased` |
| - **Classification Layer**: A fully connected layer on top of BERT for 7-class classification. |
- **Input**: One of three formats: a question alone (e.g., "Who is the main character in the story?"), a QA pair (e.g., "why could n't alice get a doll as a child ? \<SEP> because her family was very poor"), or a QA pair with context (e.g., "why could n't alice get a doll as a child ? \<SEP> because her family was very poor \<context> alice is ...").
| - **Output**: Predicted label and confidence score. |
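The three input formats above can be assembled programmatically. A minimal sketch (`build_input` is a hypothetical helper, not part of the model; the literal `<SEP>` and `<context>` separators follow the examples above):

```python
def build_input(question, answer=None, context=None):
    """Assemble a model input string from a question, an optional answer,
    and an optional story context, using the separators shown above."""
    text = question
    if answer is not None:
        text += " <SEP> " + answer
    if context is not None:
        text += " <context> " + context
    return text

# Question only
print(build_input("Who is the main character in the story?"))
# QA pair
print(build_input("why could n't alice get a doll as a child ?",
                  "because her family was very poor"))
```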
|
|
| --- |
|
|
| ## Quick Start |
|
|
| ### Install Dependencies |
| Ensure you have the `transformers` library installed: |
| ```bash |
| pip install transformers |
| ``` |
|
|
| ### Load Model and Tokenizer |
|
|
| ```python |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer |
| |
| model = AutoModelForSequenceClassification.from_pretrained("curious008/BertForStorySkillClassification") |
| tokenizer = AutoTokenizer.from_pretrained("curious008/BertForStorySkillClassification") |
| ``` |
|
|
### Use the `predict` Method for Inference
|
|
| ```python |
| # Single text prediction |
| result = model.predict( |
| texts="Where does this story take place?", |
| tokenizer=tokenizer, |
| return_probabilities=True |
| ) |
| print(result) |
| # Output: [{'text': 'Where does this story take place?', 'label': 'setting', 'score': 0.93178}] |
| |
| # Batch prediction |
| results = model.predict( |
    texts=["Why is the character sad?", "How does the story end?", "why could n't alice get a doll as a child ? <SEP> because her family was very poor "],
| tokenizer=tokenizer, |
| batch_size=16, |
| device="cuda" |
| ) |
| print(results) |
| """ |
| output: |
| [{'text': 'Why is the character sad?', 'label': 'causal relationship'}, |
| {'text': 'How does the story end?', 'label': 'action'}, |
| {'text': "why could n't alice get a doll as a child ? <SEP> because her family was very poor ", |
| 'label': 'causal relationship'}] |
| """ |
| ``` |
|
|
| ## Training Details |
| ### Dataset |
| Source: [FairytaleQAData](https://github.com/uci-soe/FairytaleQAData) |
|
|
| ### Training Parameters |
- Learning Rate: 2e-5
- Batch Size: 32
- Epochs: 3
- Optimizer: AdamW
|
|
| ### Performance Metrics |
- Accuracy: 97.3%
- Recall: 96.59%
- F1 Score: 96.96%
|
|
| ## Notes |
| 1. **Input Length**: The model supports a maximum input length of 512 tokens. Longer texts will be truncated. |
2. **Device Support**: The model supports both CPU and GPU inference; GPU is recommended for faster performance.
3. **Tokenizer**: Always use the matching tokenizer (`AutoTokenizer`) for the model.
|
|
| ## Citation |
|
|
| If you use this model, please cite the following: |
|
|
| ``` |
| @misc{BertForStorySkillClassification, |
| author = {curious}, |
| title = {BertForStorySkillClassification: A BERT-based Model for Story Question Classification}, |
| year = {2025}, |
| publisher = {Hugging Face}, |
| howpublished = {\url{https://huggingface.co/curious008/BertForStorySkillClassification}} |
| } |
| ``` |
|
|
| ## License |
| This model is open-sourced under the Apache 2.0 License. For more details, see the [LICENSE](https://www.apache.org/licenses/LICENSE-2.0) file. |