---
language:
- en
license: apache-2.0
tags:
- dialogue state tracking
- task-oriented dialog
---

# roberta-base-trippy-dst-multiwoz21

This is a TripPy model trained on [MultiWOZ 2.1](https://github.com/budzianowski/multiwoz) for use in [ConvLab-3](https://github.com/ConvLab/ConvLab-3).
It predicts informable slots, requestable slots, general actions, and domain indicator slots.
Expected joint goal accuracy on MultiWOZ 2.1 is in the range of 55-56%.
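
Joint goal accuracy, the metric quoted above, counts a turn as correct only when the predicted dialogue state matches the gold state on every slot at once. The sketch below illustrates that definition and is not the TripPy evaluation script. It assumes each state is a plain `dict` from hypothetical `domain-slot` keys to value strings and treats empty or `"none"` values as unfilled; the official scripts may also normalize value variants, which this sketch does not.

```python
from typing import Dict, List

# Hypothetical representation: "domain-slot" name -> value, e.g. {"hotel-area": "centre"}
State = Dict[str, str]


def _filled(state: State) -> State:
    """Drop unfilled slots (assumed to be "" or "none") so they do not affect matching."""
    return {slot: value for slot, value in state.items() if value not in ("", "none")}


def joint_goal_accuracy(predicted: List[State], gold: List[State]) -> float:
    """Fraction of turns whose full predicted state equals the gold state."""
    if len(predicted) != len(gold):
        raise ValueError("expected one predicted state per gold turn")
    if not gold:
        return 0.0
    # A turn counts only if every slot matches; one wrong slot fails the whole turn
    correct = sum(1 for p, g in zip(predicted, gold) if _filled(p) == _filled(g))
    return correct / len(gold)


pred = [{"hotel-area": "centre"}, {"hotel-area": "centre", "hotel-stars": "3"}]
gold = [{"hotel-area": "centre"}, {"hotel-area": "centre", "hotel-stars": "4"}]
print(joint_goal_accuracy(pred, gold))  # 0.5
```

A single wrong slot in the second turn is enough to mark that whole turn as incorrect, which is why joint goal accuracy is much stricter than per-slot accuracy.
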

For information about TripPy DST, refer to [TripPy: A Triple Copy Strategy for Value Independent Neural Dialog State Tracking](https://aclanthology.org/2020.sigdial-1.4/).
The training and evaluation code is available at the official [TripPy repository](https://gitlab.cs.uni-duesseldorf.de/general/dsml/trippy-public).
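
The "triple copy" in the paper title refers to three sources a slot value can be copied from: a span of the user utterance, a value the system informed in its previous turn, or the value of another slot already in the dialogue state (coreference). A per-slot gate picks the source, with `none` (keep the previous value) and `dontcare` as further options. The following is a simplified, hypothetical sketch of that update rule only: gate predictions and extracted spans are given as inputs, `update_state` is an illustrative helper rather than a ConvLab-3 or TripPy API, any additional gate classes the codebase may use are omitted, and `refer` is assumed to resolve against the state from before the current turn.

```python
from typing import Dict

# Hypothetical representation: "domain-slot" name -> value
State = Dict[str, str]


def update_state(
    state: State,
    gates: Dict[str, str],
    spans: Dict[str, str],
    informed: Dict[str, str],
    refers: Dict[str, str],
) -> State:
    """Apply one turn of simplified TripPy-style per-slot updates.

    gates:    slot -> "none", "dontcare", "span", "inform" or "refer"
    spans:    slot -> value extracted from the user utterance
    informed: slot -> value the system informed in its previous turn
    refers:   slot -> name of another slot whose value is copied
    """
    new_state = dict(state)  # copy, so slots gated "none" keep their previous value
    for slot, gate in gates.items():
        if gate == "none":
            continue
        if gate == "dontcare":
            new_state[slot] = "dontcare"
        elif gate == "span":
            new_state[slot] = spans[slot]  # copy from the user utterance
        elif gate == "inform":
            new_state[slot] = informed[slot]  # copy from the system inform memory
        elif gate == "refer":
            # Assumption: resolve against the state from before this turn
            new_state[slot] = state[refers[slot]]
        else:
            raise ValueError(f"unknown slot gate: {gate!r}")
    return new_state


previous = {"hotel-area": "centre"}
current = update_state(
    previous,
    gates={"hotel-area": "none", "hotel-stars": "span", "restaurant-area": "refer"},
    spans={"hotel-stars": "4"},
    informed={},
    refers={"restaurant-area": "hotel-area"},
)
print(current)  # {'hotel-area': 'centre', 'hotel-stars': '4', 'restaurant-area': 'centre'}
```

Because the value is always copied from the dialogue context rather than chosen from a fixed candidate list, this style of update does not depend on a predefined value ontology, which is the "value independent" property in the paper title.
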

## Training procedure

The model was trained on MultiWOZ 2.1 data via supervised learning using the [TripPy codebase](https://gitlab.cs.uni-duesseldorf.de/general/dsml/trippy-public).
The MultiWOZ 2.1 data was loaded with ConvLab-3's unified data format dataloader.
The pre-trained encoder is [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta) (base).
Encoder fine-tuning and training of the DST-specific classification heads ran for 10 epochs.

### Training hyperparameters

```
python3 run_dst.py \
    --task_name="unified" \
    --model_type="roberta" \
    --model_name_or_path="roberta-base" \
    --dataset_config=dataset_config/unified_multiwoz21.json \
    --do_lower_case \
    --learning_rate=1e-4 \
    --num_train_epochs=10 \
    --max_seq_length=180 \
    --per_gpu_train_batch_size=24 \
    --per_gpu_eval_batch_size=32 \
    --output_dir=results \
    --save_epochs=2 \
    --eval_all_checkpoints \
    --warmup_proportion=0.1 \
    --adam_epsilon=1e-6 \
    --weight_decay=0.01 \
    --fp16 \
    --do_train \
    --predict_type=dummy \
    --seed=42
```
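
The schedule-related flags above can be illustrated as follows, assuming the script applies a linear warmup over the first `--warmup_proportion` of all optimizer steps and then decays linearly to zero, which is the behavior of the `transformers` helper `get_linear_schedule_with_warmup`. The sketch derives the total step count from `--per_gpu_train_batch_size`, a GPU count and `--num_train_epochs` (assuming no gradient accumulation), then computes the learning rate at a given step with `--learning_rate=1e-4` as the peak. The helper names and the dataset size of 2,400 examples are hypothetical; 2,400 is not the size of the MultiWOZ 2.1 training set.

```python
import math


def total_training_steps(num_examples: int, per_gpu_batch_size: int = 24,
                         n_gpu: int = 1, num_epochs: int = 10) -> int:
    """Optimizer steps for the whole run, assuming no gradient accumulation."""
    # Each step consumes per_gpu_batch_size examples on every GPU
    steps_per_epoch = math.ceil(num_examples / (per_gpu_batch_size * n_gpu))
    return steps_per_epoch * num_epochs


def linear_warmup_decay_lr(step: int, total_steps: int, peak_lr: float = 1e-4,
                           warmup_proportion: float = 0.1) -> float:
    """Learning rate at `step` under linear warmup followed by linear decay to 0."""
    warmup_steps = int(total_steps * warmup_proportion)
    if step < warmup_steps:
        # Ramp linearly from 0 up to peak_lr over the warmup phase
        return peak_lr * step / max(1, warmup_steps)
    # Then decay linearly from peak_lr down to 0 at the final step
    remaining = max(0, total_steps - step)
    return peak_lr * remaining / max(1, total_steps - warmup_steps)


# Hypothetical run: 2,400 training examples on a single GPU
total = total_training_steps(2400)
for step in (0, 50, 100, 550, total):
    print(step, linear_warmup_decay_lr(step, total))
```

With these hypothetical numbers the run has 1,000 steps: the rate climbs to its 1e-4 peak at step 100, is back down to half the peak by step 550, and reaches 0 at the last step. Because the batch size flag is per GPU, the effective batch size, and therefore the step count, changes with the number of GPUs used.
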