| | import re |
| | from typing import Dict, List, Any |
| | from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline |
| | import torch |
| |
|
| |
|
class EndpointHandler:
    """Inference-endpoint handler that extracts a start date and an end date
    from free text using a fine-tuned DeBERTa extractive-QA model.

    The handler loads the model/tokenizer once at startup and answers the two
    fixed questions "What is the start date?" / "What is the end date?" per call.
    """

    def __init__(self, path: str = ""):
        """Load the fine-tuned QA model and tokenizer from ``path``.

        Args:
            path: Root directory of the deployed repository; the model is
                expected under ``<path>/deberta-qa-finetuned``.
        """
        self.date_model_path = path + "/deberta-qa-finetuned"
        self.date_tokenizer = AutoTokenizer.from_pretrained(self.date_model_path)
        self.date_model = AutoModelForQuestionAnswering.from_pretrained(self.date_model_path)

        # Keep the target device in one place; inputs are moved to the same
        # device in _extract_date (previously an ad-hoc .cuda() call).
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.date_model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """
        data args:
            inputs (:obj:`str`): free text possibly containing two dates.
        Return:
            A :obj:`dict` with keys ``"start_date"`` and ``"end_date"``;
            will be serialized and returned.
        """
        text = data["inputs"]
        start_date = self.remove_special_characters(self.extract_start_date(text))
        end_date = self.remove_special_characters(self.extract_end_date(text))
        return {"start_date": start_date, "end_date": end_date}

    def remove_special_characters(self, s: str) -> str:
        """Strip punctuation runs that are not sandwiched between digits.

        Separators inside a date such as ``08-08-2025`` or ``08/08/2025`` are
        preserved (each dash/slash touches a digit on both sides), while
        leading/trailing punctuation like a stray comma or period is removed.
        """
        return re.sub(r'(?<!\d)[^\w\s/]+|[^\w\s/]+(?!\d)', '', s).strip()

    def _extract_date(self, text: str, question: str) -> str:
        """Run extractive QA and return the decoded answer span.

        Shared implementation for extract_start_date/extract_end_date, which
        previously duplicated this code verbatim.
        """
        inputs = self.date_tokenizer(question, text, return_tensors="pt")
        # Match the model's device exactly instead of assuming bare .cuda().
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = self.date_model(**inputs)

        # Most likely answer span; end index is exclusive, hence the +1.
        # If the model predicts end <= start the slice is empty and the
        # decoded answer is "".
        answer_start = torch.argmax(outputs.start_logits)
        answer_end = torch.argmax(outputs.end_logits) + 1

        answer_tokens = inputs["input_ids"][0][answer_start:answer_end]
        return self.date_tokenizer.decode(answer_tokens, skip_special_tokens=True)

    def extract_start_date(self, text: str) -> str:
        """Return the start-date span found in ``text`` (may be empty)."""
        return self._extract_date(text, "What is the start date?")

    def extract_end_date(self, text: str) -> str:
        """Return the end-date span found in ``text`` (may be empty)."""
        return self._extract_date(text, "What is the end date?")
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: instantiate the handler against the local model directory
    # and run one sample payload. EndpointHandler is defined in this module,
    # so no self-import (`from handler import EndpointHandler`) is needed —
    # that import shadowed the in-file class and broke if the file was
    # renamed away from handler.py.
    my_handler = EndpointHandler(path=".")

    non_holiday_payload = {"inputs": "I am quite excited how this will turn out 08-08-2025 - 09-08-2025"}

    non_holiday_pred = my_handler(non_holiday_payload)

    print(non_holiday_pred)