| import torch |
| from transformers import AutoTokenizer, AutoModelForTokenClassification |
|
|
class ShieldFilter:
    """Redact sensitive spans in text using a token-classification model.

    The model labels every token; consecutive tokens whose labels fall into
    the same group (see ``group_map``) are collapsed into one span, and each
    span is replaced in the text by a ``[GROUP]`` placeholder.
    """

    # Token budget of a single model pass (special tokens included).
    MAX_LENGTH = 512
    # Tokens shared by consecutive windows, so that an entity cut by a window
    # boundary is still seen in one piece by the following window.
    WINDOW_STRIDE = 64
    # Number of windows sent through the model per forward pass; bounds
    # memory use on very long inputs.
    WINDOW_BATCH = 16

    def __init__(self, model_path="LH-Tech-AI/Shield-82M"):
        """Load tokenizer and model from ``model_path`` and prepare inference.

        The model is moved to CUDA when available (CPU otherwise) and put in
        eval mode.
        """
        print(f"Loading Shield-82M from {model_path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForTokenClassification.from_pretrained(model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

        # Model labels collapsed into the coarser tags used as placeholders.
        # A label that is not listed here is used verbatim as its own tag.
        self.group_map = {
            "FIRSTNAME": "PERSON", "MIDDLENAME": "PERSON", "LASTNAME": "PERSON",
            "BUILDINGNUMBER": "ADDRESS", "STREET": "ADDRESS", "CITY": "ADDRESS",
            "STATE": "ADDRESS", "ZIPCODE": "ADDRESS", "SECONDARYADDRESS": "ADDRESS",
            "EMAIL": "EMAIL", "PHONENUMBER": "PHONE", "PHONEIMEI": "PHONE",
            "DATE": "DOB", "TIME": "DOB",
        }

    def protect(self, text):
        """Return ``text`` with every detected span replaced by ``[GROUP]``.

        Input longer than ``MAX_LENGTH`` tokens is split by the tokenizer into
        overlapping windows (``WINDOW_STRIDE`` tokens of overlap) instead of
        being truncated, so the tail of a long text is filtered too. Text that
        fits into a single window is processed exactly as one pass.

        Args:
            text: The string to filter.

        Returns:
            The filtered string.
        """
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=self.MAX_LENGTH,
            stride=self.WINDOW_STRIDE,
            padding=True,  # windows must share one length to form a tensor
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
        )

        # Both entries are tokenizer bookkeeping, not model inputs.
        offsets = encoded.pop("offset_mapping").tolist()
        encoded.pop("overflow_to_sample_mapping", None)

        spans = []
        with torch.no_grad():
            for lo in range(0, len(offsets), self.WINDOW_BATCH):
                hi = lo + self.WINDOW_BATCH
                batch = {
                    name: tensor[lo:hi].to(self.device)
                    for name, tensor in encoded.items()
                }
                logits = self.model(**batch).logits
                batch_predictions = torch.argmax(logits, dim=2).cpu().tolist()
                for pred_ids, window_offsets in zip(batch_predictions, offsets[lo:hi]):
                    spans.extend(self._collect_spans(pred_ids, window_offsets))

        return self._apply_spans(text, self._merge_spans(spans))

    def _collect_spans(self, pred_ids, offsets):
        """Turn one window's per-token predictions into ``(start, end, tag)`` spans."""
        id2label = self.model.config.id2label
        spans = []
        current_group = None
        span_start = -1
        span_end = -1

        for pred_id, (tok_start, tok_end) in zip(pred_ids, offsets):
            # Tokens with a (0, 0) offset (special and padding tokens) do not
            # map to any character of the text.
            if tok_start == 0 and tok_end == 0:
                continue

            label = id2label[pred_id]
            group = None if label == "O" else self.group_map.get(label, label)

            if group != current_group:
                if current_group is not None:
                    spans.append((span_start, span_end, current_group))
                current_group = group
                span_start = tok_start
            span_end = tok_end

        if current_group is not None:
            spans.append((span_start, span_end, current_group))
        return spans

    @staticmethod
    def _merge_spans(spans):
        """Sort spans by position and fuse those that overlap.

        Overlaps arise when the same stretch of text is labelled in two
        overlapping windows. The fused span keeps the tag of the span that
        sorts first (earliest start, longest on ties). Spans that merely touch
        are kept separate.
        """
        merged = []
        for start, end, tag in sorted(spans, key=lambda span: (span[0], -span[1])):
            if merged and start < merged[-1][1]:
                prev_start, prev_end, prev_tag = merged[-1]
                merged[-1] = (prev_start, max(prev_end, end), prev_tag)
            else:
                merged.append((start, end, tag))
        return merged

    @staticmethod
    def _apply_spans(text, spans):
        """Replace each sorted, non-overlapping span of ``text`` with ``[tag]``."""
        pieces = []
        cursor = 0
        for start, end, tag in spans:
            pieces.append(text[cursor:start])
            pieces.append(f"[{tag}]")
            cursor = end
        pieces.append(text[cursor:])
        return "".join(pieces)
|
|
if __name__ == "__main__":
    # Small demo: filter one hard-coded sentence and show it before and after.
    pii_filter = ShieldFilter()
    demo_text = "My name is John Doe. Email: john@example.com. Phone: +49 123 45678."
    print(f"Original: {demo_text}")
    redacted = pii_filter.protect(demo_text)
    print(f"Protected: {redacted}")