Eraly-ml committed on
Commit
5e03908
·
verified ·
1 Parent(s): ab29cac

Add inference script

Browse files
Files changed (1) hide show
  1. inference.py +68 -0
inference.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference script for UnixCoder-MIL
3
+ =====================================
4
+ Usage: Simply run this script with your code samples
5
+ """
6
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModel, AutoConfig, AutoModelForSequenceClassification
13
+
14
# Run on GPU when one is visible, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Human-readable label for each classifier output, index-aligned with the
# 4 logits produced by MilUnixCoder.
CLASS_NAMES = ["Human", "AI-Generated", "Hybrid", "Adversarial"]
16
+
17
class MilUnixCoder(nn.Module):
    """UnixCoder backbone with multiple-instance-learning (MIL) pooling.

    Inputs longer than ``chunk_size`` tokens are sliced into overlapping
    windows; each window is encoded independently and the per-window logits
    are max-pooled, so a sample's score is driven by its most indicative
    window.
    """

    def __init__(self, model_name="microsoft/unixcoder-base", chunk_size=512, stride=256, max_chunks=16):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name)
        self.unixcoder = AutoModel.from_pretrained(model_name)
        self.chunk_size = chunk_size
        self.stride = stride
        self.max_chunks = max_chunks
        # 4-way head: Human / AI-Generated / Hybrid / Adversarial.
        self.classifier = nn.Linear(self.config.hidden_size, 4)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask=None):
        """Return (batch, 4) logits, max-pooled over token windows.

        Args:
            input_ids: (batch, seq_len) token id tensor.
            attention_mask: optional (batch, seq_len) mask; defaults to all
                ones (every position attended).
        """
        batch_size, seq_len = input_ids.size()
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if seq_len > self.chunk_size:
            # Slide a chunk_size window along the sequence with the
            # configured stride. NOTE(review): unfold drops any trailing
            # tokens that do not fill a complete window.
            win_ids = input_ids.unfold(1, self.chunk_size, self.stride)
            win_mask = attention_mask.unfold(1, self.chunk_size, self.stride)
            num_chunks = min(win_ids.size(1), self.max_chunks)
            flat_ids = win_ids[:, :num_chunks, :].reshape(-1, self.chunk_size)
            flat_mask = win_mask[:, :num_chunks, :].reshape(-1, self.chunk_size)
        else:
            num_chunks = 1
            flat_ids, flat_mask = input_ids, attention_mask
        encoded = self.unixcoder(input_ids=flat_ids, attention_mask=flat_mask)
        # First-position ([CLS]) hidden state of every window.
        cls_states = encoded.last_hidden_state[:, 0, :]
        chunk_logits = self.classifier(self.dropout(cls_states))
        # MIL max-pooling: keep the strongest per-class evidence across windows.
        return chunk_logits.view(batch_size, num_chunks, -1).max(dim=1).values
39
+
40
def load_model(repo_id="YoungDSMLKZ/UnixCoder-MIL", base_model="microsoft/unixcoder-base"):
    """Load the tokenizer and the fine-tuned MIL classifier.

    Args:
        repo_id: Hugging Face Hub repo id (or local directory) holding the
            tokenizer files and ``model.safetensors``. Defaults to the
            published checkpoint, so existing ``load_model()`` calls are
            unchanged.
        base_model: backbone checkpoint used to build the architecture
            before the fine-tuned weights are loaded over it.

    Returns:
        (model, tokenizer) with the model moved to DEVICE and in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = MilUnixCoder(base_model)
    weights_path = os.path.join(repo_id, "model.safetensors")
    if not os.path.isfile(weights_path):
        # Bug fix: the original passed "<repo_id>/model.safetensors" straight
        # to load_file(), which only reads local paths — it fails unless the
        # repo happens to be checked out in the current directory. Fall back
        # to fetching the weights from the Hub (huggingface_hub is a hard
        # dependency of transformers, so no new install is required).
        from huggingface_hub import hf_hub_download
        weights_path = hf_hub_download(repo_id, "model.safetensors")
    model.load_state_dict(load_file(weights_path))
    model.to(DEVICE).eval()
    return model, tokenizer
47
+
48
def predict(code: str, model, tokenizer) -> dict:
    """Classify a single code sample.

    Args:
        code: source code text to classify.
        model: a loaded MilUnixCoder (or compatible callable returning logits).
        tokenizer: the matching tokenizer.

    Returns:
        ``{"class": <label from CLASS_NAMES>, "confidence": <softmax prob>}``.
    """
    encoding = tokenizer(
        code, return_tensors="pt", truncation=True, max_length=4096, padding=True
    ).to(DEVICE)
    with torch.no_grad():
        logits = model(encoding["input_ids"], encoding["attention_mask"])
    probabilities = F.softmax(logits, dim=-1)[0]
    best = torch.argmax(probabilities).item()
    return {"class": CLASS_NAMES[best], "confidence": probabilities[best].item()}
56
+
57
def _run_demo() -> None:
    """Load the published checkpoint and classify a small example snippet."""
    print("Loading model...")
    model, tokenizer = load_model()

    # Example usage
    test_code = """
def hello_world():
    print("Hello, World!")
"""

    result = predict(test_code, model, tokenizer)
    print(f"Predicted: {result['class']} (confidence: {result['confidence']:.2%})")


if __name__ == "__main__":
    _run_demo()