Instructions to use curious008/BertForStorySkillClassification with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use curious008/BertForStorySkillClassification with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-classification", model="curious008/BertForStorySkillClassification", trust_remote_code=True)

# Load model directly
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("curious008/BertForStorySkillClassification", trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained("curious008/BertForStorySkillClassification", trust_remote_code=True)
```
- Notebooks
- Google Colab
- Kaggle
| from typing import Dict, List, Union | |
| from transformers import BertPreTrainedModel, BertModel,PreTrainedTokenizer | |
| import torch.nn as nn | |
| import torch | |
class BertForStorySkillClassification(BertPreTrainedModel):
    """BERT encoder followed by a single linear classification layer.

    The classifier is applied to the hidden state at sequence position 0
    (the ``[CLS]`` slot under standard BERT tokenization).
    """

    def __init__(self, config):
        """Build the encoder and the classification head.

        Args:
            config: Model configuration; ``num_labels`` and ``hidden_size``
                are read from it.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the encoder and the classification head.

        Args:
            input_ids: Token ids passed to the BERT encoder.
            attention_mask: Optional attention mask passed to the encoder.
            labels: Optional target class ids. When given, the loss is
                returned instead of the logits.
            **kwargs: Accepted and ignored. This lets a tokenizer output that
                contains extra keys (e.g. ``token_type_ids``) be unpacked
                directly into the call; only ``input_ids`` and
                ``attention_mask`` are forwarded to the encoder.

        Returns:
            The cross-entropy loss tensor when ``labels`` is provided,
            otherwise the logits of shape ``[batch_size, num_labels]``.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        # last_hidden_state is [batch_size, seq_len, hidden_size]; taking
        # position 0 yields [batch_size, hidden_size].
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_hidden_state)  # [batch_size, num_labels]
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return logits

    def predict(
        self,
        texts: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        batch_size: int = 32,
        return_probabilities: bool = False,
        device: Union[str, torch.device, None] = None,
    ) -> List[Dict]:
        """Classify one or more input texts.

        Args:
            texts: A single text or a list of texts, e.g.
                ``"Who are the characters in the story?"`` or
                ``["question 1", "question 2"]``.
            tokenizer: Tokenizer instance (must be compatible with the model).
            batch_size: Number of texts tokenized and scored per forward pass.
            return_probabilities: When True, each result also carries the
                softmax probability of the predicted label under ``"score"``.
            device: Device the tokenized inputs are moved to (e.g. ``"cuda"``
                or ``"cpu"``). Defaults to the device the model currently
                lives on. The model itself is not moved.

        Returns:
            One dict per input text, in input order:
            ``{"text": <input text>, "label": <predicted label>}``, plus
            ``"score": <probability>`` when ``return_probabilities`` is True.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )

        # Follow the model's own device unless the caller overrides it, so the
        # inputs and the weights end up on the same device by default.
        if device is None:
            device = self.device

        # Normalize the input to a list.
        if isinstance(texts, str):
            texts = [texts]

        predictions = []

        # Inference must run with dropout disabled; remember the current mode
        # so calling predict() in the middle of training does not leave the
        # model stuck in eval mode.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for start in range(0, len(texts), batch_size):
                    batch_texts = texts[start : start + batch_size]

                    # Tokenize and convert to tensors on the target device.
                    inputs = tokenizer(
                        batch_texts,
                        padding=True,
                        truncation=True,
                        return_tensors="pt",
                        max_length=512,  # BERT's usual maximum sequence length
                    ).to(device)

                    logits = self(**inputs)
                    probs = torch.softmax(logits, dim=-1)
                    scores, class_ids = torch.max(probs, dim=-1)

                    # One device-to-host transfer per batch instead of one
                    # .item() call per element.
                    batch_items = zip(
                        batch_texts, class_ids.tolist(), scores.tolist()
                    )
                    for text, class_id, score in batch_items:
                        result = {
                            "text": text,
                            "label": self.config.id2label[class_id],
                        }
                        if return_probabilities:
                            result["score"] = score
                        predictions.append(result)
        finally:
            self.train(was_training)

        return predictions