| | --- |
| | license: mit |
| | language: |
| | - ru |
| | widget: |
| | - text: 'привет' |
| | example_title: example_1 |
| | - text: 'тебя как звать' |
| | example_title: example_2 |
| | - text: 'как приготовить рагу' |
| | example_title: example_3 |
| | - text: 'в чем смысл жизни' |
| | example_title: example_4 |
| | - text: 'у меня кот сбежал' |
| | example_title: example_5 |
| | - text: 'что такое спидометр' |
| | example_title: example_6 |
| | - text: 'меня артур зовут' |
| | example_title: example_7 |
| | --- |
| | # Den4ikAI/ruBert-tiny-replicas-classifier |
| | Описание классов: |
| | 1. about_user - реагирует, когда пользователь говорит о себе. Например, "меня зовут андрей" |
| | 2. question - реагирует на вопросы |
| | 3. instruct - реагирует на вопросы, ответ на которые представляет собой инструкцию. Например, "как установить windows, как приготовить борщ" |
| | 4. about_system - реагирует на вопросы о личности ассистента. Например, "как тебя зовут, ты кто такая" |
| | 5. problem - реагирует на реплики, где пользователь рассказывает о своих проблемах. Например, "у меня болит зуб, мне проткнули колесо" |
| | 6. dialogue - реагирует на диалоговые реплики. Например, "привет" |
| |
|
| | Примечание: модель обучалась без знаков '?' |
| |
|
| | # Использование |
| | ```python |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| | |
| | |
| | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| | tokenizer = AutoTokenizer.from_pretrained('Den4ikAI/ruBert-tiny-replicas-classifier') |
| | model = AutoModelForSequenceClassification.from_pretrained('Den4ikAI/ruBert-tiny-replicas-classifier') |
| | model.to(device) |
| | model.eval() |
| | |
| | classes = ['instruct', 'question', 'dialogue', 'problem', 'about_system', 'about_user'] |
| | |
| | |
| | def get_sentence_type(text): |
| | inputs = tokenizer(text, max_length=512, add_special_tokens=False, return_tensors='pt').to(device) |
| | with torch.no_grad(): |
| | logits = model(**inputs).logits |
| | probas = list(torch.sigmoid(logits)[0].cpu().detach().numpy()) |
| | out = classes[probas.index(max(probas))] |
| | return out |
| | |
| | while 1: |
| | print(get_sentence_type(input(":> "))) |
| | |
| | ``` |