#!/usr/bin/env python3
# usage.py from the korean-pii-e5-base model repository
# (uploaded by vijaym via huggingface_hub, commit 2bf9c60).
"""Minimal runnable example for FrameByFrame/korean-pii-e5-base.
pip install "transformers>=4.40" torch safetensors
python usage.py
"""
import os, re
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Checkpoint to load; override it with the MODEL_ID environment variable.
MODEL_ID = os.environ.get("MODEL_ID", "FrameByFrame/korean-pii-e5-base")

# Tokenizer and model are loaded once, at import time; extract_pii() uses both.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
)
# Inference only: move to the GPU when one is available and switch to eval mode.
if torch.cuda.is_available():
    model = model.cuda()
model.eval()
_TRAILING_JOSA = ["μ΄μ—μš”","이라고","μž…λ‹ˆλ‹€","이야","μ΄λž‘","ν•œν…Œ","μ—κ²Œ","으둜","이가","μ΄λŠ”",
"μ—μ„œ","이고","μ˜ˆμš”","씨","λ‹˜","이","κ°€","은","λŠ”","을","λ₯Ό","μ•Ό","μ•„","에","의","λž‘","께","κ³ "]
_DATE_END = re.compile(r".*(?:일|[0-9])", re.S)
def _normalize(text, label, s, e):
while s < e and text[s] in " .,\t\n": s += 1
while e > s and text[e - 1] in " .,\t\n": e -= 1
if label == "private_date":
m = _DATE_END.match(text[s:e])
if m and m.end() > 0:
e = s + m.end()
elif label in ("private_person", "personal_handle", "private_address"):
for _ in range(2):
seg = text[s:e]
for j in _TRAILING_JOSA:
if seg.endswith(j) and (e - s) - len(j) >= 2:
e -= len(j); break
else:
break
return s, e
def _token_predictions(text, max_length, stride):
    """Classify every token of *text*, covering long inputs with overlapping windows.

    The tokenizer splits *text* into windows of at most ``max_length`` tokens
    (special tokens included).  Each window shares ``stride`` tokens with its
    predecessor, as documented for ``stride`` with
    ``return_overflowing_tokens=True``.

    Returns:
        ``[(char_start, char_end, label), ...]`` in text order for every
        non-special token.  For tokens repeated in an overlap, the prediction
        from the earlier window is kept.
    """
    enc = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        return_special_tokens_mask=True,
        padding=True,  # windows differ in length; pad positions are flagged as special
        return_tensors="pt",
    )
    # Bookkeeping outputs, not model inputs.
    offsets = enc.pop("offset_mapping").tolist()
    special = enc.pop("special_tokens_mask").tolist()
    enc.pop("overflow_to_sample_mapping", None)
    with torch.no_grad():
        logits = model(**{k: v.to(model.device) for k, v in enc.items()}).logits
    preds = logits.argmax(-1).tolist()
    id2label = model.config.id2label

    tokens = []
    for window, (w_offsets, w_special, w_preds) in enumerate(zip(offsets, special, preds)):
        content = [
            (cs, ce, id2label[int(lid)])
            for (cs, ce), is_special, lid in zip(w_offsets, w_special, w_preds)
            if not is_special
        ]
        # Every window after the first repeats the last `stride` tokens of the previous one.
        tokens.extend(content if window == 0 else content[stride:])
    return tokens


def _group_bio(tokens):
    """Merge per-token labels into ``[category, char_start, char_end]`` spans.

    An ``O`` label or an empty-offset token closes the current span.  A ``B-``
    or ``S-`` prefix, or a change of category, starts a new span.  Any other
    prefix (e.g. ``I-``) extends the current span.
    """
    spans, active = [], None
    for cs, ce, label in tokens:
        if cs == ce or label == "O":
            if active:
                spans.append(active)
                active = None
            continue
        prefix, cat = label.split("-", 1)
        if prefix in ("B", "S") or not active or active[0] != cat:
            if active:
                spans.append(active)
            active = [cat, cs, ce]
        else:
            active[2] = ce
    if active:
        spans.append(active)
    return spans


def extract_pii(text, max_length=256, stride=None):
    """Detect PII spans in *text*.

    Args:
        text: Input string.
        max_length: Tokens per model window, special tokens included.  Longer
            inputs are covered by several overlapping windows instead of being
            truncated, so PII near the end of long texts is still found.
        stride: Tokens shared by consecutive windows.  Defaults to
            ``max_length // 8``.

    Returns:
        A list of ``{"label", "start", "end", "text"}`` dicts in text order.
        ``start``/``end`` are character offsets into *text* after
        ``_normalize`` trimming, and ``text == text[start:end]``.
    """
    if stride is None:
        stride = max_length // 8
    out = []
    for cat, s, e in _group_bio(_token_predictions(text, max_length, stride)):
        s, e = _normalize(text, cat, s, e)
        if text[s:e].strip():  # drop spans trimmed down to nothing
            out.append({"label": cat, "start": s, "end": e, "text": text[s:e]})
    return out
def redact(text, spans=None):
    """Replace each PII span in *text* with ``[LABEL]`` (label upper-cased).

    Args:
        text: Input string.
        spans: Optional span dicts with ``label``/``start``/``end`` keys, as
            returned by ``extract_pii``.  Pass them to avoid running the model
            a second time.  When omitted, ``extract_pii(text)`` is called.

    Returns:
        The redacted string.  A span that overlaps an earlier one is folded
        into that earlier placeholder, so no part of it is left in clear text.
    """
    if spans is None:
        spans = extract_pii(text)
    pieces, cursor = [], 0
    for span in sorted(spans, key=lambda sp: sp["start"]):
        if span["start"] < cursor:
            cursor = max(cursor, span["end"])  # overlap: extend previous placeholder
            continue
        pieces.append(text[cursor:span["start"]])
        pieces.append(f"[{span['label'].upper()}]")
        cursor = span["end"]
    pieces.append(text[cursor:])
    return "".join(pieces)
if __name__ == "__main__":
    # Demo sentences: a name + phone number, an account number + e-mail address,
    # and a name + date of birth.
    examples = [
        "김민수님의 번호는 010-1234-5678입니다.",
        "계좌 110-234-567890으로 입금하고 minsu@example.com으로 알려주세요.",
        "이수진 고객님 생년월일은 1985년 3월 12일입니다.",
    ]
    for t in examples:
        print(t)
        for sp in extract_pii(t):
            print(f" {sp['label']:16} {sp['text']!r}")
        print(" REDACT:", redact(t))
        print()