santh-cpu commited on
Commit
ec6dc5b
·
verified ·
1 Parent(s): 2be77de

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +97 -2
README.md CHANGED
@@ -33,10 +33,105 @@ To use this model in your own application, download the weights directly from th
33
  from huggingface_hub import hf_hub_download
34
  import torch
35
 
36
- # Download weights
37
  weights_path = hf_hub_download(repo_id="santh-cpu/ai_code_detect", filename="pytorch_model.bin")
38
 
39
- # Load into your architecture
40
  model = TemporalFusionClassifier(base_model)
41
  model.load_state_dict(torch.load(weights_path))
42
  model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  from huggingface_hub import hf_hub_download
34
  import torch
35
 
 
36
  weights_path = hf_hub_download(repo_id="santh-cpu/ai_code_detect", filename="pytorch_model.bin")
37
 
 
38
  model = TemporalFusionClassifier(base_model)
39
  model.load_state_dict(torch.load(weights_path))
40
  model.eval()
41
+ ```
42
+
43
+ ### Example
44
+ ```python
45
+ import torch
46
+ import torch.nn as nn
47
+ import torch.nn.functional as F
48
+ from transformers import RobertaTokenizer, T5EncoderModel, AutoTokenizer, AutoModelForMaskedLM
49
+ from huggingface_hub import hf_hub_download
50
+
51
+ class TemporalFusionClassifier(nn.Module):
52
+ def __init__(self, base, metric_dim=7):
53
+ super().__init__()
54
+ self.base = base
55
+ h = base.config.hidden_size
56
+
57
+ self.metric_cnn = nn.Sequential(
58
+ nn.Conv1d(metric_dim, 32, 3, padding=1),
59
+ nn.BatchNorm1d(32),
60
+ nn.ReLU(),
61
+ nn.MaxPool1d(2),
62
+ nn.Conv1d(32, 64, 3, padding=1),
63
+ nn.BatchNorm1d(64),
64
+ nn.ReLU(),
65
+ nn.AdaptiveAvgPool1d(1)
66
+ )
67
+
68
+ self.classifier = nn.Sequential(
69
+ nn.Linear(h + 64, 1024),
70
+ nn.ReLU(),
71
+ nn.Dropout(0.1),
72
+ nn.Linear(1024, 1)
73
+ )
74
+
75
+ def forward(self, input_ids, attention_mask, metric_vector):
76
+ out = self.base(input_ids=input_ids, attention_mask=attention_mask)
77
+ hidden = out.last_hidden_state
78
+ mask = attention_mask.unsqueeze(-1).float()
79
+ pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-4)
80
+
81
+ cnn_features = self.metric_cnn(metric_vector.transpose(1, 2)).squeeze(-1)
82
+ return self.classifier(torch.cat([pooled, cnn_features], dim=1))
83
+
84
+ class AICodeDetector:
85
+ def __init__(self, repo_id="santh-cpu/ai_code_detect"):
86
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
87
+ self.max_len = 256
88
+
89
+ self.cb_tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base-mlm")
90
+ self.cb_model = AutoModelForMaskedLM.from_pretrained("microsoft/codebert-base-mlm").to(self.device).eval()
91
+
92
+ self.t5_tokenizer = RobertaTokenizer.from_pretrained("Salesforce/codet5-base")
93
+ base_t5 = T5EncoderModel.from_pretrained("Salesforce/codet5-base")
94
+
95
+ weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
96
+ self.detector = TemporalFusionClassifier(base_t5).to(self.device)
97
+ self.detector.load_state_dict(torch.load(weights_path, map_location=self.device))
98
+ self.detector.eval()
99
+
100
+ def analyze(self, code_snippet):
101
+ with torch.no_grad():
102
+ cb_in = self.cb_tokenizer(code_snippet, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_len).to(self.device)
103
+ logits = self.cb_model(**cb_in).logits
104
+
105
+ seq_len = cb_in["attention_mask"][0].sum().item()
106
+ metrics = torch.zeros((1, self.max_len, 7), device=self.device)
107
+
108
+ if seq_len > 1:
109
+ seq_logits = logits[0:1, :seq_len-1, :]
110
+ seq_labels = cb_in["input_ids"][0:1, 1:seq_len]
111
+ probs = F.softmax(seq_logits, dim=-1)
112
+
113
+ entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1)
114
+ ranks = (torch.argsort(seq_logits, dim=-1, descending=True) == seq_labels.unsqueeze(-1)).nonzero(as_tuple=True)[2].view(1, -1) + 1
115
+
116
+ token_metrics = torch.stack([
117
+ torch.log(probs.gather(2, seq_labels.unsqueeze(-1)).squeeze(-1) + 1e-9),
118
+ torch.log(ranks.float()),
119
+ entropy,
120
+ (ranks <= 10).float(),
121
+ ((ranks > 10) & (ranks <= 100)).float(),
122
+ ((ranks > 100) & (ranks <= 1000)).float(),
123
+ (ranks > 1000).float()
124
+ ], dim=-1)
125
+ metrics[0, :token_metrics.size(1), :] = token_metrics[0]
126
+
127
+ clean_metrics = torch.nan_to_num(metrics, nan=0.0, posinf=10.0, neginf=-100.0)
128
+ t5_in = self.t5_tokenizer(code_snippet, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_len).to(self.device)
129
+ prob = torch.sigmoid(self.detector(t5_in["input_ids"], t5_in["attention_mask"], clean_metrics)).item()
130
+
131
+ return {"prediction": "AI Generated" if prob > 0.5 else "Human Written", "ai_probability": round(prob * 100, 2)}
132
+
133
+ if __name__ == "__main__":
134
+ detector = AICodeDetector()
135
+ sample = "def fib(n):\n a, b = 0, 1\n for _ in range(n):\n yield a\n a, b = b, a + b"
136
+ print(detector.analyze(sample))
137
+ ```