| | from typing import Dict, Any, List |
| | import torch |
| | from transformers import AutoTokenizer, AutoModel |
| | import os |
| | import json |
| |
|
class EndpointHandler:
    """Inference endpoint scoring (query, candidate-label) relevance.

    Each (query, candidate) pair is rendered into a single marked-up text
    (``[QUERY] … [LABEL_NAME] … [LABEL_DESCRIPTION] …``), encoded with a
    transformer, and the [CLS] embedding is passed through a linear head +
    sigmoid to yield a score in (0, 1).

    Payloads (under ``data["inputs"]`` or at the top level):
      * single query: ``{"query": str, "candidates": [{"label", "description"}, ...]}``
      * batched:      ``{"queries": [str, ...], "candidates": [...]}``
    """

    def __init__(self, path: str = ""):
        """Load tokenizer, encoder, and the JSON-serialized scoring head from *path*.

        Args:
            path: Model repository/directory; must also contain
                ``classifier_head.json`` with ``scorer_weight`` / ``scorer_bias``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.tokenizer.add_special_tokens({
            "additional_special_tokens": ["[QUERY]", "[LABEL_NAME]", "[LABEL_DESCRIPTION]"]
        })

        self.model = AutoModel.from_pretrained(path).to(self.device)
        # If the special tokens were not already part of the saved tokenizer,
        # their ids would index past the embedding matrix and crash the forward
        # pass. resize_token_embeddings is a no-op when the sizes already agree.
        self.model.resize_token_embeddings(len(self.tokenizer))

        head_path = os.path.join(path, "classifier_head.json")
        with open(head_path, "r", encoding="utf-8") as f:
            head = json.load(f)

        # Rebuild the scalar scoring head from its serialized weights,
        # pinning the dtype to the freshly created layer's dtype.
        self.classifier = torch.nn.Linear(self.model.config.hidden_size, 1).to(self.device)
        self.classifier.weight.data = torch.tensor(
            head["scorer_weight"], dtype=self.classifier.weight.dtype
        ).to(self.device)
        self.classifier.bias.data = torch.tensor(
            head["scorer_bias"], dtype=self.classifier.bias.dtype
        ).to(self.device)

        self.model.eval()

        self.max_batch_size = 128  # max (query, candidate) pairs per forward pass
        self.max_length = 64       # tokenizer truncation/padding length

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Dispatch to batched or single-query processing based on payload keys."""
        payload = data.get("inputs", data)

        if "queries" in payload:
            return self._process_batch(payload)
        else:
            return self._process_single(payload)

    def _build_text(self, query: str, candidate: Dict[str, Any]) -> str:
        """Render one (query, candidate) pair into the marked-up input text."""
        return (
            f"[QUERY] {query} [LABEL_NAME] {candidate['label']} "
            f"[LABEL_DESCRIPTION] {candidate['description']}"
        )

    def _process_single(self, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Score one query against all candidates (backward-compatible path).

        Returns:
            Candidates with 4-decimal ``score``, sorted by score descending.
        """
        query = payload["query"]
        candidates = payload["candidates"]
        results = []

        with torch.no_grad():
            for entry in candidates:
                tokens = self.tokenizer(
                    self._build_text(query, entry),
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length
                ).to(self.device)

                out = self.model(**tokens)
                cls = out.last_hidden_state[:, 0, :]  # [CLS] embedding
                score = torch.sigmoid(self.classifier(cls)).item()
                results.append({
                    "label": entry["label"],
                    "description": entry["description"],
                    "score": round(score, 4)
                })

        return sorted(results, key=lambda x: x["score"], reverse=True)

    def _process_batch(self, payload: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
        """Score every query against every candidate in chunked forward passes.

        Returns:
            One list of scored candidates per query (sorted by score
            descending), in the same order as ``payload["queries"]``.
        """
        queries = payload["queries"]
        candidates = payload["candidates"]

        # Flatten to one text per (query, candidate) combination; the flat
        # index is q_idx * len(candidates) + c_idx, which the regrouping
        # below relies on.
        all_texts = [
            self._build_text(query, candidate)
            for query in queries
            for candidate in candidates
        ]

        all_scores: List[float] = []

        with torch.no_grad():
            for start in range(0, len(all_texts), self.max_batch_size):
                batch_texts = all_texts[start:start + self.max_batch_size]

                tokens = self.tokenizer(
                    batch_texts,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length
                ).to(self.device)

                out = self.model(**tokens)
                cls = out.last_hidden_state[:, 0, :]  # [CLS] embeddings, (batch, hidden)
                # squeeze(-1) always yields a 1-D (batch,) vector — unlike a
                # bare squeeze(), it cannot collapse a size-1 batch to a scalar.
                scores = torch.sigmoid(self.classifier(cls)).squeeze(-1)
                all_scores.extend(scores.cpu().tolist())

        # Regroup the flat score list back into per-query result lists.
        results: List[List[Dict[str, Any]]] = []
        for q_idx in range(len(queries)):
            query_results = []
            for c_idx, candidate in enumerate(candidates):
                score = all_scores[q_idx * len(candidates) + c_idx]
                query_results.append({
                    "label": candidate["label"],
                    "description": candidate["description"],
                    "score": round(score, 4)
                })

            query_results.sort(key=lambda x: x["score"], reverse=True)
            results.append(query_results)

        return results

    def get_batch_stats(self) -> Dict[str, Any]:
        """Return batch-processing configuration for diagnostics."""
        return {
            "max_batch_size": self.max_batch_size,
            "max_length": self.max_length,
            "device": str(self.device),
            "model_name": getattr(self.model.config, "name_or_path", "unknown")
        }
| |
|
| |
|