Upload shield-82m.ipynb
Browse files- shield-82m.ipynb +1 -0
shield-82m.ipynb
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.12.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[],"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"!pip install -q datasets transformers seqeval evaluate\n\nimport torch\nimport numpy as np\nfrom datasets import load_dataset\nfrom transformers import (AutoTokenizer, \n AutoModelForTokenClassification, \n DataCollatorForTokenClassification, \n TrainingArguments, \n Trainer)\nimport evaluate\n\nMODEL_NAME = \"distilroberta-base\"\nDATASET_NAME = \"ai4privacy/pii-masking-200k\"\nMAX_LENGTH = 512","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"import numpy as np\nfrom datasets import load_dataset\n\nprint(\"Loading dataset...\")\nraw_datasets = load_dataset(DATASET_NAME, split='train[:20000]').train_test_split(test_size=0.1)\n\nprint(\"Extracting label list...\")\nunique_labels = set()\nfor mask_list in raw_datasets[\"train\"][\"privacy_mask\"]:\n for item in mask_list:\n unique_labels.add(item[\"label\"])\n\nlabel_list = [\"O\"] + sorted(list(unique_labels))\nlabel2id = {l: i for i, l in enumerate(label_list)}\nid2label = {i: l for i, l in enumerate(label_list)}\n\nprint(f\"Found classes incl. 
def align_labels_with_spans(examples):
    """Tokenize `source_text` and project character-level PII spans onto tokens.

    Args:
        examples: batched dataset rows with "source_text" (str) and
            "privacy_mask" (list of {"start", "end", "label"} char spans).

    Returns:
        The tokenizer output dict with an added "labels" list: one label id per
        token, -100 for special tokens (masked from the loss).
    """
    tokenized_inputs = tokenizer(
        examples["source_text"],
        truncation=True,
        max_length=MAX_LENGTH,
        return_offsets_mapping=True,
        padding=False,  # dynamic padding happens later in the data collator
    )

    all_labels = []
    for i, spans in enumerate(examples["privacy_mask"]):
        offsets = tokenized_inputs["offset_mapping"][i]
        token_labels = []
        for o_start, o_end in offsets:
            # Special tokens (<s>, </s>) carry a (0, 0) offset — mask them out.
            if o_start == 0 and o_end == 0:
                token_labels.append(-100)
                continue

            label_id = 0  # default: outside any PII span ("O")
            for span in spans:
                # BUG FIX: the original required strict containment
                # (o_start >= span["start"] and o_end <= span["end"]), which
                # labels a token "O" whenever it merely straddles a span
                # boundary, leaking part of the PII into the "O" class.
                # Any character overlap now assigns the span's label.
                if o_start < span["end"] and o_end > span["start"]:
                    label_id = label2id[span["label"]]
                    break
            token_labels.append(label_id)
        all_labels.append(token_labels)

    tokenized_inputs["labels"] = all_labels
    return tokenized_inputs


print("Tokenizing...")
tokenized_datasets = raw_datasets.map(
    align_labels_with_spans,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
)

print("Data preparation for Shield 82M done.")

# %% Cell 3: seqeval metrics and model instantiation.
metric = evaluate.load("seqeval")


def compute_metrics(p):
    """Compute entity-level precision/recall/F1/accuracy via seqeval.

    `p` is an EvalPrediction-like (logits, labels) pair; positions labeled
    -100 (special tokens) are excluded before scoring.
    """
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }


model = AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
)
# %% Cell 4: parameter count sanity check (~82M for distilroberta-base).
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in `model`."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


params = count_parameters(model)
print(f"The model has {params:,} trainable params.")

# %% Cell 5: training.
# (Duplicate DataCollatorForTokenClassification import removed — it is already
# imported in Cell 1.)
training_args = TrainingArguments(
    output_dir="./Shield",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    report_to="none",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # FIX: without these, "best" silently defaults to lowest eval_loss; for a
    # PII redaction model we want the checkpoint with the best entity-level F1
    # as reported by compute_metrics.
    metric_for_best_model="f1",
    greater_is_better=True,
)

data_collator = DataCollatorForTokenClassification(tokenizer)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

# %% Cell 6: qualitative smoke test — redact PII from sample sentences.
# Coarse redaction tags for the fine-grained dataset labels; labels not listed
# here fall through to themselves via GROUPS.get(label, label).
GROUPS = {
    "FIRSTNAME": "PERSON", "MIDDLENAME": "PERSON", "LASTNAME": "PERSON",
    "BUILDINGNUMBER": "ADDRESS", "STREET": "ADDRESS", "CITY": "ADDRESS",
    "STATE": "ADDRESS", "ZIPCODE": "ADDRESS", "SECONDARYADDRESS": "ADDRESS",
    "EMAIL": "EMAIL", "PHONENUMBER": "PHONE", "PHONEIMEI": "PHONE",
    "DATE": "DOB", "TIME": "DOB",
}


def shield_filter_production(text):
    """Replace every detected PII span in `text` with a [TAG] placeholder.

    Tokenizes with offset mapping, runs the fine-tuned model, merges runs of
    consecutive tokens that map to the same coarse GROUPS tag into character
    spans, and splices the placeholders back into the original string.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        return_offsets_mapping=True,
    ).to(model.device)
    # offset_mapping is not a model input — pop it before the forward pass.
    offsets = inputs.pop("offset_mapping")[0].cpu().numpy()

    with torch.no_grad():
        logits = model(**inputs).logits
    predictions = torch.argmax(logits, dim=2)[0].cpu().numpy()

    spans_to_replace = []
    current_group = None
    start_char = -1
    last_char = -1

    for pred_id, offset in zip(predictions, offsets):
        # (0, 0) offsets mark special tokens — skip them.
        if offset[0] == 0 and offset[1] == 0:
            continue

        label = id2label[pred_id]

        if label == "O":
            # Leaving an entity run: flush the open span, if any.
            if current_group is not None:
                spans_to_replace.append((start_char, last_char, current_group))
                current_group = None
        else:
            group_tag = GROUPS.get(label, label)
            if current_group != group_tag:
                if current_group is not None:
                    spans_to_replace.append((start_char, last_char, current_group))
                current_group = group_tag
                start_char = offset[0]
            last_char = offset[1]

    if current_group is not None:
        spans_to_replace.append((start_char, last_char, current_group))

    # Splice right-to-left so earlier character offsets stay valid.
    filtered_text = text
    for start, end, tag in sorted(spans_to_replace, key=lambda x: x[0], reverse=True):
        filtered_text = filtered_text[:start] + f"[{tag}]" + filtered_text[end:]

    return filtered_text


tests = [
    "Mein Name ist Max Mustermann und ich wohne in der Hauptstraße 5, Berlin. Meine Email ist max@example.com.",
    "Liebe Lena, ich möchte dir heute mitteilen, dass ich ins Altmühltal umgezogen bin.",
    "Alice was born on 1990-01-02 and lives at 1 Main St.",
    "Mon e-mail est jean.dupont@example.fr et mon téléphone est +33 6 12 34 56 78.",
]

for t in tests:
    print(f"In: {t}")
    print(f"Out: {shield_filter_production(t)}\n")

# %% Cell 7: persist model + tokenizer for downstream use.
import os

save_directory = "./Shield-v1-final"
os.makedirs(save_directory, exist_ok=True)

trainer.save_model(save_directory)
tokenizer.save_pretrained(save_directory)

print(f"Model files saved to {save_directory}.")

# %% Cell 8: archive the saved model directory for download.
import shutil

shutil.make_archive("Shield_v1_Model", "zip", save_directory)

print("Success! Zipped successfully :D")
# %% Cell 9: stand-alone production wrapper around the published checkpoint.
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification


class ShieldFilter:
    """Redacts PII from free text using the fine-tuned Shield-82M model.

    Loads the tokenizer and token-classification head from `model_path`,
    groups fine-grained entity labels into coarse redaction tags, and exposes
    a single `protect()` method that returns the input text with every
    detected span replaced by a [TAG] placeholder.
    """

    def __init__(self, model_path="LH-Tech-AI/Shield-82M"):
        print(f"Loading Shield-82M from {model_path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForTokenClassification.from_pretrained(model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

        # Fine-grained dataset label -> coarse redaction tag. Labels absent
        # from this map fall through to themselves in protect().
        self.group_map = {
            # Personal
            "FIRSTNAME": "PERSON", "MIDDLENAME": "PERSON", "LASTNAME": "PERSON",
            "USERNAME": "PERSON", "PREFIX": "PERSON", "AGE": "AGE",
            "GENDER": "GENDER", "SEX": "GENDER",
            # Address and location
            "BUILDINGNUMBER": "ADDRESS", "STREET": "ADDRESS", "CITY": "ADDRESS",
            "STATE": "ADDRESS", "ZIPCODE": "ADDRESS", "SECONDARYADDRESS": "ADDRESS",
            "COUNTY": "ADDRESS", "NEARBYGPSCOORDINATE": "LOCATION",
            "ORDINALDIRECTION": "LOCATION",
            # Contact
            "EMAIL": "EMAIL", "PHONENUMBER": "PHONE", "PHONEIMEI": "PHONE",
            "URL": "URL",
            # Finances
            "IBAN": "BANK_ACCOUNT", "BIC": "BANK_ACCOUNT",
            "ACCOUNTNUMBER": "BANK_ACCOUNT",
            "CREDITCARDNUMBER": "CREDIT_CARD", "CREDITCARDCVV": "CREDIT_CARD",
            "CREDITCARDISSUER": "CREDIT_CARD",
            "BITCOINADDRESS": "CRYPTO", "ETHEREUMADDRESS": "CRYPTO",
            "LITECOINADDRESS": "CRYPTO",
            "AMOUNT": "AMOUNT", "CURRENCY": "AMOUNT", "CURRENCYCODE": "AMOUNT",
            "CURRENCYNAME": "AMOUNT", "CURRENCYSYMBOL": "AMOUNT",
            # IT & security
            "IP": "IT_INFO", "IPV4": "IT_INFO", "IPV6": "IT_INFO",
            "MAC": "IT_INFO", "PASSWORD": "PASSWORD", "PIN": "PASSWORD",
            "USERAGENT": "IT_INFO",
            # Work
            "COMPANYNAME": "ORGANIZATION", "JOBTITLE": "JOB",
            "JOBAREA": "JOB", "JOBTYPE": "JOB",
            # Documents and vehicles
            "SSN": "ID_DOC", "VEHICLEVIN": "VEHICLE", "VEHICLEVRM": "VEHICLE",
            # Time
            "DATE": "DOB", "DOB": "DOB", "TIME": "TIME",
        }

    def _collect_spans(self, pred_ids, char_offsets):
        """Merge consecutive same-group token predictions into char spans.

        Returns a list of (start_char, end_char, group_tag) tuples.
        """
        id2label = self.model.config.id2label
        merged = []
        open_group = None
        span_start = -1
        span_end = -1

        for pred_id, (tok_start, tok_end) in zip(pred_ids, char_offsets):
            # Special tokens carry a (0, 0) offset — ignore them.
            if tok_start == 0 and tok_end == 0:
                continue

            label = id2label[pred_id]

            if label == "O":
                # An "O" token closes any open entity run.
                if open_group is not None:
                    merged.append((span_start, span_end, open_group))
                    open_group = None
                continue

            tag = self.group_map.get(label, label)
            if open_group != tag:
                # Entering a new entity group: flush the previous run first.
                if open_group is not None:
                    merged.append((span_start, span_end, open_group))
                open_group = tag
                span_start = tok_start
            span_end = tok_end

        if open_group is not None:
            merged.append((span_start, span_end, open_group))
        return merged

    def protect(self, text):
        """Return `text` with every detected PII span replaced by [TAG]."""
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            return_offsets_mapping=True,
        ).to(self.device)

        # offset_mapping is not a model input — remove it before forwarding.
        char_offsets = encoded.pop("offset_mapping")[0].cpu().numpy()

        with torch.no_grad():
            logits = self.model(**encoded).logits
        pred_ids = torch.argmax(logits, dim=2)[0].cpu().numpy()

        spans = self._collect_spans(pred_ids, char_offsets)

        # Splice placeholders in from right to left so that earlier character
        # offsets remain valid as the string length changes.
        redacted = text
        for start, end, tag in sorted(spans, key=lambda s: s[0], reverse=True):
            redacted = redacted[:start] + f"[{tag}]" + redacted[end:]
        return redacted


if __name__ == "__main__":
    shield = ShieldFilter()
    sample = "My name is John Doe. Email: john@example.com. Phone: +49 123 45678."
    print(f"Original: {sample}")
    print(f"Protected: {shield.protect(sample)}")
|