Pujan-Dev commited on
Commit
c8f4e3f
·
1 Parent(s): ddbc845

Add document forgery detection feature and refactor model loading

Browse files

- Introduced class for detecting document forgery using ELA-trained EfficientNet model.
- Updated to support loading document forgery model from local path.
- Added new API endpoint to check if an uploaded document is forged.
- Refactored imports in various modules for consistency and clarity.

features/real_forged_classifier/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """Package for real_forged_classifier feature.
2
+
3
+ This module ensures package-relative imports work when importing
4
+ `features.real_forged_classifier.*` from the application.
5
+ """
6
+
7
+ __all__ = [
8
+ 'controller', 'routes', 'preprocessor', 'inferencer', 'model_loader', 'model'
9
+ ]
features/real_forged_classifier/controller.py CHANGED
@@ -1,6 +1,15 @@
1
  from typing import IO
2
- from preprocessor import preprocessor
3
- from inferencer import interferencer
 
 
 
 
 
 
 
 
 
4
 
5
  class ClassificationController:
6
  """
@@ -34,3 +43,71 @@ class ClassificationController:
34
 
35
  # Create a single instance of the controller
36
  controller = ClassificationController()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import IO
2
+ import io
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ from torchvision import transforms
7
+
8
+ from .preprocessor import preprocessor
9
+ from .inferencer import interferencer
10
+ from .model_loader import models
11
+ from config import Config
12
+
13
 
14
  class ClassificationController:
15
  """
 
43
 
44
  # Create a single instance of the controller
45
  controller = ClassificationController()
46
+
47
+ class documentForger:
48
+ """
49
+ Document forgery detector that uses the ELA-trained EfficientNet model
50
+ when available (models.doc_model). Returns a dict with verdict and confidence.
51
+ """
52
+ def is_forged(self, document_file: IO) -> dict:
53
+ # Ensure a document model is loaded
54
+ if not hasattr(models, 'doc_model') or models.doc_model is None:
55
+ return {"error": "Document forgery model not available."}
56
+
57
+ # Read file bytes
58
+ try:
59
+ data = document_file.read()
60
+ img = Image.open(io.BytesIO(data)).convert('RGB')
61
+ except Exception as e:
62
+ return {"error": f"Could not open document image: {e}"}
63
+
64
+ # Compute ELA map (same approach as the notebook)
65
+ try:
66
+ buf = io.BytesIO()
67
+ img.save(buf, format='JPEG', quality=90)
68
+ buf.seek(0)
69
+ recompressed = Image.open(buf).convert('RGB')
70
+
71
+ ela_arr = np.abs(np.array(img, dtype=np.float32) - np.array(recompressed, dtype=np.float32))
72
+ p99 = np.percentile(ela_arr, 99)
73
+ if p99 > 0:
74
+ ela_arr = np.clip(ela_arr * (255.0 / p99), 0, 255).astype(np.uint8)
75
+ else:
76
+ ela_arr = ela_arr.astype(np.uint8)
77
+
78
+ ela_pil = Image.fromarray(ela_arr, mode='RGB')
79
+ except Exception as e:
80
+ return {"error": f"Failed to compute ELA: {e}"}
81
+
82
+ # Transform and run through model
83
+ transform = transforms.Compose([
84
+ transforms.Resize((224, 224)),
85
+ transforms.ToTensor(),
86
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
87
+ ])
88
+
89
+ tensor = transform(ela_pil).unsqueeze(0).to(models.device)
90
+
91
+ with torch.no_grad():
92
+ logits = models.doc_model(tensor)
93
+ probs = torch.softmax(logits, dim=1)[0, 1].item()
94
+
95
+ # Interpret confidence using configurable thresholds (values in 0..1)
96
+ low = getattr(Config, 'DOCUMENT_FORGERY_POSSIBLE_LOW', 0.40)
97
+ high = getattr(Config, 'DOCUMENT_FORGERY_FORGED_LOW', 0.55)
98
+
99
+ if probs < low:
100
+ verdict = 'LIKELY AUTHENTIC'
101
+ elif probs < high:
102
+ verdict = 'POSSIBLY FORGED'
103
+ else:
104
+ verdict = 'LIKELY FORGED'
105
+
106
+ return {
107
+ "verdict": verdict,
108
+ "confidence": float(probs),
109
+ "confidence_pct": round(float(probs) * 100, 2),
110
+ }
111
+
112
+ # Create a single instance of the document forger
113
+ document_forger = documentForger()
features/real_forged_classifier/inferencer.py CHANGED
@@ -3,7 +3,7 @@ import torch.nn.functional as F
3
  import numpy as np
4
 
5
  # Import the globally loaded models instance
6
- from model_loader import models
7
 
8
  class Interferencer:
9
  """
@@ -26,6 +26,10 @@ class Interferencer:
26
  Returns:
27
  dict: A dictionary containing the classification label and confidence score.
28
  """
 
 
 
 
29
  # 1. Get model outputs (logits)
30
  outputs = self.fft_model(image_tensor)
31
 
 
3
  import numpy as np
4
 
5
  # Import the globally loaded models instance
6
+ from .model_loader import models
7
 
8
  class Interferencer:
9
  """
 
26
  Returns:
27
  dict: A dictionary containing the classification label and confidence score.
28
  """
29
+ # 0. Ensure model is loaded
30
+ if self.fft_model is None:
31
+ return {"error": "FFT model not loaded."}
32
+
33
  # 1. Get model outputs (logits)
34
  outputs = self.fft_model(image_tensor)
35
 
features/real_forged_classifier/model_loader.py CHANGED
@@ -1,61 +1,167 @@
1
- import torch
2
  from pathlib import Path
3
- from huggingface_hub import hf_hub_download
4
- from model import FFTCNN # Import the model architecture
5
  from config import Config
6
 
 
 
 
 
 
 
 
7
  class ModelLoader:
 
 
 
 
 
8
  """
9
- A class to load and hold the PyTorch CNN model.
10
- """
11
- def __init__(self, model_repo_id: str, model_filename: str):
12
- """
13
- Initializes the ModelLoader and loads the model.
14
 
15
- Args:
16
- model_repo_id (str): The repository ID on Hugging Face.
17
- model_filename (str): The name of the model file (.pth) in the repository.
18
- """
19
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
20
- print(f"Using device: {self.device}")
 
 
 
 
 
 
 
 
 
 
21
 
22
- self.fft_model = self._load_fft_model(repo_id=model_repo_id, filename=model_filename)
23
- print("FFT CNN model loaded successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- def _load_fft_model(self, repo_id: str, filename: str):
26
- """
27
- Downloads and loads the FFT CNN model from a Hugging Face Hub repository.
 
28
 
29
- Args:
30
- repo_id (str): The repository ID on Hugging Face.
31
- filename (str): The name of the model file (.pth) in the repository.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- Returns:
34
- The loaded PyTorch model object.
35
- """
36
- print(f"Downloading FFT CNN model from Hugging Face repo: {repo_id}")
37
  try:
38
- # Download the model file from the Hub. It returns the cached path.
39
  model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=Config.HF_TOKEN)
40
  print(f"Model downloaded to: {model_path}")
41
-
42
- # Initialize the model architecture
43
  model = FFTCNN()
44
-
45
- # Load the saved weights (state_dict) into the model
46
  model.load_state_dict(torch.load(model_path, map_location=torch.device(self.device)))
47
-
48
- # Set the model to evaluation mode
49
  model.to(self.device)
50
  model.eval()
51
-
52
  return model
53
  except Exception as e:
54
- print(f"Error downloading or loading model from Hugging Face: {e}")
55
  raise
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # --- Global Model Instance ---
58
  MODEL_REPO_ID = Config.REAL_FORGED_MODEL_REPO_ID
59
  MODEL_FILENAME = Config.REAL_FORGED_MODEL_FILENAME
60
- models = ModelLoader(model_repo_id=MODEL_REPO_ID, model_filename=MODEL_FILENAME)
 
61
 
 
 
1
  from pathlib import Path
2
+ from typing import Any
3
+ from .model import FFTCNN # Import the FFT CNN architecture (package-relative)
4
  from config import Config
5
 
6
+
7
+ # NOTE: EfficientNet/nn imports are done lazily when torch is available.
8
+ ELAForgeryNet = None # will be constructed dynamically when needed
9
+ torch = None
10
+ TORCH_AVAILABLE = False
11
+
12
+
13
  class ModelLoader:
14
+ """A class to load and hold PyTorch models used by this feature.
15
+
16
+ It loads:
17
+ - an FFT-based CNN (downloaded from Hugging Face Hub)
18
+ - an ELA-based document forgery detector (local .pth by default)
19
  """
 
 
 
 
 
20
 
21
+ def __init__(self, model_repo_id: str, model_filename: str, doc_model_path: str = None):
22
+ # Try to import torch once and expose module-level variables
23
+ global torch, TORCH_AVAILABLE
24
+ try:
25
+ import torch as _torch
26
+ torch = _torch
27
+ TORCH_AVAILABLE = True
28
+ except Exception:
29
+ torch = None
30
+ TORCH_AVAILABLE = False
31
+ print("[WARN] PyTorch not available; model loading will be skipped until torch is installed.")
32
+ if TORCH_AVAILABLE:
33
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
34
+ else:
35
+ self.device = "cpu"
36
+ print(f"Using device: {self.device} (torch available: {TORCH_AVAILABLE})")
37
 
38
+ # Load FFT CNN from HF Hub
39
+ self.fft_model = None
40
+ if TORCH_AVAILABLE:
41
+ try:
42
+ self.fft_model = self._load_fft_model(repo_id=model_repo_id, filename=model_filename)
43
+ print("FFT CNN model loaded successfully from Hub.")
44
+ except Exception:
45
+ # Try local fallback path (if provided in config)
46
+ self.fft_model = None
47
+ local_path = Path(getattr(Config, 'REAL_FORGED_MODEL_LOCAL_PATH', ''))
48
+ if local_path and local_path.exists():
49
+ try:
50
+ print(f"Attempting to load FFT model from local path: {local_path}")
51
+ model = FFTCNN()
52
+ state = torch.load(str(local_path), map_location=torch.device(self.device))
53
+ state_dict = state.get('state_dict', state) if isinstance(state, dict) else state
54
+ model.load_state_dict(state_dict, strict=False)
55
+ model.to(self.device)
56
+ model.eval()
57
+ self.fft_model = model
58
+ print("FFT CNN model loaded successfully from local path.")
59
+ except Exception as e:
60
+ print(f"Failed to load local FFT model: {e}")
61
+ else:
62
+ print("No local FFT model path configured or file missing; FFT model not loaded.")
63
+ else:
64
+ print("Skipping FFT model load because PyTorch is not installed.")
65
 
66
+ # Load document forgery model (ELA CNN) from local path if present
67
+ self.doc_model = None
68
+ if doc_model_path is None:
69
+ doc_model_path = Config.DOCUMENT_FORGERY_MODEL_PATH
70
 
71
+ self.doc_model = None
72
+ if TORCH_AVAILABLE:
73
+ try:
74
+ self.doc_model = self._load_document_forgery_model(Path(doc_model_path))
75
+ if self.doc_model is not None:
76
+ print("Document forgery (ELA) model loaded successfully.")
77
+ except Exception as e:
78
+ print(f"Warning: failed to load document forgery model: {e}")
79
+ else:
80
+ print("Skipping document forgery model load because PyTorch is not installed.")
81
+
82
+ def _load_fft_model(self, repo_id: str, filename: str):
83
+ """Downloads and loads the FFT CNN model from a Hugging Face Hub repository."""
84
+ print(f"Attempting to download FFT CNN model from Hugging Face repo: {repo_id}")
85
+ try:
86
+ from huggingface_hub import hf_hub_download
87
+ except Exception as e:
88
+ raise RuntimeError(f"huggingface_hub not available: {e}")
89
 
 
 
 
 
90
  try:
 
91
  model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=Config.HF_TOKEN)
92
  print(f"Model downloaded to: {model_path}")
93
+
 
94
  model = FFTCNN()
 
 
95
  model.load_state_dict(torch.load(model_path, map_location=torch.device(self.device)))
 
 
96
  model.to(self.device)
97
  model.eval()
 
98
  return model
99
  except Exception as e:
100
+ print(f"Error downloading or loading FFT model from Hugging Face: {e}")
101
  raise
102
 
103
+ def _load_document_forgery_model(self, path: Path):
104
+ """Load the ELA-based document forgery model from a local .pth checkpoint.
105
+
106
+ Returns the model instance or None if the file does not exist.
107
+ """
108
+ # If the configured path doesn't exist, try sensible fallbacks in the repo.
109
+ if not path.exists():
110
+ print(f"Document forgery model file not found at configured path: {path}")
111
+
112
+ # 1) Try features/Model/document_forgery/ela_cnn_model.pth relative to repo root
113
+ repo_root = Path(__file__).resolve().parents[2]
114
+ candidate = repo_root / 'features' / 'Model' / 'document_forgery' / 'ela_cnn_model.pth'
115
+ if candidate.exists():
116
+ path = candidate
117
+ print(f"Found document forgery model at fallback path: {path}")
118
+ else:
119
+ # 2) Search the repo for any file named ela_cnn_model.pth
120
+ print("Searching repository for 'ela_cnn_model.pth'...")
121
+ matches = list(repo_root.rglob('ela_cnn_model.pth'))
122
+ if matches:
123
+ path = matches[0]
124
+ print(f"Found document forgery model at: {path}")
125
+ else:
126
+ print("Document forgery model not found in repository; skipping load.")
127
+ return None
128
+
129
+ print(f"Loading document forgery model from: {path}")
130
+ # Build the ELA model architecture lazily (requires torchvision & torch.nn)
131
+ try:
132
+ import torchvision.models as tv_models
133
+ import torch.nn as nn
134
+ except Exception as e:
135
+ raise RuntimeError(f"Required packages for ELA model not available: {e}")
136
+
137
+ backbone = tv_models.efficientnet_b0(weights='IMAGENET1K_V1')
138
+ in_features = backbone.classifier[1].in_features
139
+ backbone.classifier = nn.Sequential(
140
+ nn.Dropout(p=0.4),
141
+ nn.Linear(in_features, 256),
142
+ nn.ReLU(inplace=True),
143
+ nn.Dropout(p=0.2),
144
+ nn.Linear(256, 2),
145
+ )
146
+ model = backbone
147
+ state = torch.load(str(path), map_location=torch.device(self.device))
148
+
149
+ # The checkpoint might be either a state_dict or a full checkpoint dict
150
+ if isinstance(state, dict) and 'state_dict' in state:
151
+ state_dict = state['state_dict']
152
+ else:
153
+ state_dict = state
154
+
155
+ # Attempt to load state dict; allow strict=False to be tolerant to minor key name differences
156
+ model.load_state_dict(state_dict, strict=False)
157
+ model.to(self.device)
158
+ model.eval()
159
+ return model
160
+
161
+
162
  # --- Global Model Instance ---
163
  MODEL_REPO_ID = Config.REAL_FORGED_MODEL_REPO_ID
164
  MODEL_FILENAME = Config.REAL_FORGED_MODEL_FILENAME
165
+ DOC_MODEL_PATH = Config.DOCUMENT_FORGERY_MODEL_PATH
166
+ models = ModelLoader(model_repo_id=MODEL_REPO_ID, model_filename=MODEL_FILENAME, doc_model_path=DOC_MODEL_PATH)
167
 
features/real_forged_classifier/preprocessor.py CHANGED
@@ -6,7 +6,7 @@ import cv2
6
  from torchvision import transforms
7
 
8
  # Import the globally loaded models instance
9
- from model_loader import models
10
 
11
  class ImagePreprocessor:
12
  """
 
6
  from torchvision import transforms
7
 
8
  # Import the globally loaded models instance
9
+ from .model_loader import models
10
 
11
  class ImagePreprocessor:
12
  """
features/real_forged_classifier/routes.py CHANGED
@@ -1,8 +1,8 @@
1
  from fastapi import APIRouter, File, UploadFile, HTTPException, status
2
  from fastapi.responses import JSONResponse
3
 
4
- # Import the controller instance
5
- from controller import controller
6
 
7
  # Create an API router
8
  router = APIRouter()
@@ -35,3 +35,20 @@ async def classify_image_endpoint(image: UploadFile = File(...)):
35
 
36
  return JSONResponse(content=result, status_code=status.HTTP_200_OK)
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import APIRouter, File, UploadFile, HTTPException, status
2
  from fastapi.responses import JSONResponse
3
 
4
+ # Import the controller instance and document forger
5
+ from .controller import controller, document_forger
6
 
7
  # Create an API router
8
  router = APIRouter()
 
35
 
36
  return JSONResponse(content=result, status_code=status.HTTP_200_OK)
37
 
38
+ @router.post("/isforged", summary="Check if the document is forged")
39
+ async def is_forged_endpoint(file: UploadFile = File(...)):
40
+ """Run the document forgery detector on an uploaded image file.
41
+
42
+ Accepts image uploads (multipart/form-data) and returns a JSON verdict with confidence.
43
+ """
44
+ if not file.content_type.startswith("image/"):
45
+ raise HTTPException(
46
+ status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
47
+ detail="Unsupported file type. Please upload an image (e.g., JPEG, PNG)."
48
+ )
49
+
50
+ result = document_forger.is_forged(file.file)
51
+ if isinstance(result, dict) and result.get("error"):
52
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result.get("error"))
53
+
54
+ return JSONResponse(content=result, status_code=status.HTTP_200_OK)