from __future__ import annotations

from pathlib import Path
from threading import Thread
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from qwen_vl_utils import process_vision_info
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VLForConditionalGeneration,
)

DEFAULT_MODEL_PATH = "./checkpoints"

# NOTE: the literal answer tags were stripped from this file (likely by HTML
# sanitization). The "<answer>"/"</answer>" pair below is an assumption;
# restore whatever tags your checkpoint was actually trained to emit.
ANSWER_START = "<answer>"
ANSWER_END = "</answer>"

DEFAULT_SYSTEM_PROMPT = (
    "You are a professional AI dermatology assistant. "
    "Reason step by step, keep the reasoning concise, avoid repetition, "
    f"and always finish with {ANSWER_START}...{ANSWER_END}."
)
DEFAULT_MAX_NEW_TOKENS = 768
DEFAULT_CONTINUE_TOKENS = 256
DEFAULT_DO_SAMPLE = False
DEFAULT_TEMPERATURE = 0.2
DEFAULT_TOP_P = 0.9
DEFAULT_REPETITION_PENALTY = 1.15
DEFAULT_NO_REPEAT_NGRAM_SIZE = 3
DEFAULT_PROMPT = (
    "Act as a dermatologist. Analyze the visual features of this skin lesion "
    "step by step, and provide a final diagnosis."
)


def resolve_model_path(model_path: str = DEFAULT_MODEL_PATH) -> str:
    """Resolve a possibly relative checkpoint path against cwd and the repo root."""
    raw_path = Path(model_path).expanduser()
    repo_root = Path(__file__).resolve().parents[2]
    candidates = [raw_path]
    if not raw_path.is_absolute():
        candidates.append(Path.cwd() / raw_path)
        candidates.append(repo_root / raw_path)
        if raw_path.parts and raw_path.parts[0] == repo_root.name:
            candidates.append(repo_root.joinpath(*raw_path.parts[1:]))
    for candidate in candidates:
        if candidate.exists():
            return str(candidate)
    return str(raw_path)


def build_single_turn_messages(
    image_path: str,
    prompt: str,
    system_prompt: str = DEFAULT_SYSTEM_PROMPT,
) -> list[dict]:
    # The system prompt is folded into the user turn so the chat template only
    # has to render a single message.
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": f"{system_prompt}\n\n{prompt}"},
            ],
        }
    ]


def build_quantization_config() -> BitsAndBytesConfig:
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )


def resolve_quantized_device_map():
    if not torch.cuda.is_available():
        raise RuntimeError("INT4 quantized inference requires a CUDA GPU.")
    # Pin everything to one GPU: bitsandbytes 4-bit weights do not shard
    # cleanly across devices with a plain device map.
    return {"": f"cuda:{torch.cuda.current_device()}"}


class StopOnTokenSequence(StoppingCriteria):
    """Stop generation once the most recent tokens match a fixed id sequence."""

    def __init__(self, stop_ids: list[int]):
        super().__init__()
        self.stop_ids = stop_ids
        self.stop_length = len(stop_ids)

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        if self.stop_length == 0 or input_ids.shape[1] < self.stop_length:
            return False
        return input_ids[0, -self.stop_length :].tolist() == self.stop_ids
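
# Usage sketch for StopOnTokenSequence (a minimal example, not executed here;
# it assumes the checkpoint under DEFAULT_MODEL_PATH ships a tokenizer and
# that ANSWER_END is the tag the model was trained to close with):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained(resolve_model_path())
#     stop_ids = tokenizer.encode(ANSWER_END, add_special_tokens=False)
#     criteria = StoppingCriteriaList([StopOnTokenSequence(stop_ids)])
#     # model.generate(..., stopping_criteria=criteria)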


class ExpertBlock(nn.Module):
    """Bottleneck MLP expert: hidden_dim -> bottleneck_dim -> hidden_dim."""

    def __init__(self, hidden_dim, bottleneck_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, bottleneck_dim),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, hidden_dim),
        )

    def forward(self, x):
        return self.net(x)


class SkinAwareMoEAdapter(nn.Module):
    """Residual top-k mixture-of-experts adapter whose routing is biased by a
    per-token skin-tone probability vector."""

    def __init__(self, hidden_dim, num_experts=8, top_k=2, bottleneck_dim=64):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.router_img = nn.Linear(hidden_dim, num_experts, bias=False)
        self.router_skin = nn.Linear(3, num_experts, bias=False)
        self.experts = nn.ModuleList(
            [ExpertBlock(hidden_dim, bottleneck_dim) for _ in range(num_experts)]
        )

    def forward(
        self, x: torch.Tensor, skin_probs: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Route each token using image features plus a skin-tone bias term.
        img_logits = self.router_img(x)
        skin_bias = self.router_skin(skin_probs)
        router_logits = img_logits + skin_bias
        router_probs = F.softmax(router_logits, dim=-1)
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
        # Renormalize so the selected experts' weights sum to one per token.
        top_k_probs = top_k_probs / (top_k_probs.sum(dim=-1, keepdim=True) + 1e-6)

        final_output = torch.zeros_like(x)
        for expert_idx, expert in enumerate(self.experts):
            expert_mask = top_k_indices == expert_idx
            if expert_mask.any():
                rows, k_indices = torch.where(expert_mask)
                inp = x[rows]
                out = expert(inp)
                weights = top_k_probs[rows, k_indices].unsqueeze(-1)
                final_output.index_add_(0, rows, (out * weights).to(final_output.dtype))

        # Switch-style load-balancing loss: mean router probability times mean
        # selection frequency, summed over experts and rescaled by their count.
        mean_prob = router_probs.mean(0)
        mask_all = torch.zeros_like(router_probs)
        mask_all.scatter_(1, top_k_indices, 1.0)
        mean_freq = mask_all.mean(0)
        aux_loss = (mean_prob * mean_freq).sum() * self.num_experts
        return x + final_output, aux_loss
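
# CPU smoke test for the adapter (a minimal sketch, not called at import time):
# 16 patch tokens of width 1024 routed with a uniform 3-way skin-tone prior.
def _demo_moe_adapter():
    adapter = SkinAwareMoEAdapter(hidden_dim=1024, num_experts=8, top_k=2)
    tokens = torch.randn(16, 1024)
    skin = torch.full((16, 3), 1.0 / 3.0)
    out, aux = adapter(tokens, skin)
    assert out.shape == tokens.shape  # residual output keeps the token shape
    assert aux.ndim == 0              # scalar load-balancing loss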


class PatchDistillHead(nn.Module):
    """Side head over raw vision patches: projects them to embed_dim, predicts
    a 3-way skin-tone distribution per image, refines tokens with a stack of
    skin-aware MoE adapters, and mean-pools one vector per image."""

    def __init__(
        self,
        embed_dim: int = 1024,
        adapter_layers: int = 4,
        in_dim: Optional[int] = None,
        out_dim: Optional[int] = None,
        num_experts: int = 8,
        top_k: int = 2,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.in_proj = None if in_dim is None else nn.Linear(in_dim, embed_dim, bias=False)
        self.skin_classifier = nn.Sequential(
            nn.Linear(embed_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 3),
        )
        self.adapters = nn.ModuleList(
            [
                SkinAwareMoEAdapter(embed_dim, num_experts=num_experts, top_k=top_k)
                for _ in range(adapter_layers)
            ]
        )
        self.out_proj: nn.Module = (
            nn.Identity() if out_dim is None else nn.Linear(embed_dim, out_dim)
        )

    def _ensure_in_proj(self, din: int, device, dtype):
        # Lazily build the input projection the first time a patch width is seen.
        if self.in_proj is None:
            self.in_proj = nn.Linear(din, self.embed_dim, bias=False).to(
                device=device, dtype=dtype
            )

    def forward(self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor) -> dict:
        _, din = pixel_values.shape
        # Patches for all images arrive flattened into one tensor; split them
        # back per image using the t*h*w counts from the grid.
        counts = (
            image_grid_thw[:, 0] * image_grid_thw[:, 1] * image_grid_thw[:, 2]
        ).tolist()
        device, dtype = pixel_values.device, pixel_values.dtype
        self._ensure_in_proj(din, device, dtype)
        chunks = torch.split(pixel_values, counts, dim=0)
        pooled, all_skin_logits = [], []
        total_aux_loss = torch.tensor(0.0, device=device, dtype=dtype)
        for x in chunks:
            h = self.in_proj(x)
            global_feat = h.mean(dim=0, keepdim=True)
            skin_logits = self.skin_classifier(global_feat)
            skin_probs = F.softmax(skin_logits, dim=-1)
            all_skin_logits.append(skin_logits)
            skin_probs_expanded = skin_probs.expand(h.size(0), -1)
            for adapter in self.adapters:
                h, layer_loss = adapter(h, skin_probs_expanded)
                total_aux_loss += layer_loss
            pooled.append(h.mean(dim=0))
        vision_embed = torch.stack(pooled, dim=0)
        vision_proj = self.out_proj(vision_embed)
        return {
            "vision_embed": vision_embed,
            "vision_proj": vision_proj,
            "aux_loss": total_aux_loss,
            "skin_logits": torch.cat(all_skin_logits, dim=0),
        }

    def configure_out_dim(self, out_dim: int):
        if isinstance(self.out_proj, nn.Linear) and self.out_proj.out_features == out_dim:
            return
        self.out_proj = (
            nn.Linear(self.embed_dim, out_dim, bias=False)
            if out_dim != self.embed_dim
            else nn.Identity()
        )
        try:
            params = next(self.parameters())
            self.out_proj.to(device=params.device, dtype=params.dtype)
        except StopIteration:
            pass
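
# Shape smoke test for the distill head (a minimal sketch, not called at
# import time). The 1176 patch width matches Qwen2.5-VL's flattened patches
# (3 channels * 2 temporal * 14 * 14), and the fake grid encodes two images
# of 16 and 4 patches.
def _demo_distill_head():
    head = PatchDistillHead(embed_dim=1024, adapter_layers=2, in_dim=1176)
    grid = torch.tensor([[1, 4, 4], [1, 2, 2]])
    patches = torch.randn(20, 1176)
    side = head(pixel_values=patches, image_grid_thw=grid)
    assert side["vision_embed"].shape == (2, 1024)  # one pooled vector per image
    assert side["skin_logits"].shape == (2, 3)      # one skin-tone logit row per image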


class SkinVLModelWithAdapter(Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL with a skin-aware distillation side head and a learned
    vocabulary bias that nudges logits toward skin-related tokens."""

    def __init__(self, config):
        super().__init__(config)
        self.distill_head = PatchDistillHead(
            embed_dim=1024,
            adapter_layers=4,
            num_experts=8,
            top_k=2,
            in_dim=1176,  # Qwen2.5-VL flattened patch width (3 * 2 * 14 * 14)
        )
        bottleneck = 64
        self.text_bias = nn.Sequential(
            nn.Linear(1024, bottleneck, bias=False),
            nn.Tanh(),
            nn.Linear(bottleneck, config.hidden_size, bias=False),
        )
        self.logit_bias_scale = nn.Parameter(torch.tensor(2.5, dtype=torch.bfloat16))

    def forward(self, *args, **kwargs):
        # Pop the custom kwargs so the base Qwen forward never sees them
        # (skin_labels was previously fetched with get(), which leaked it into
        # super().forward()).
        skin_vocab_mask = kwargs.pop("skin_vocab_mask", None)
        skin_labels = kwargs.pop("skin_labels", None)
        pixel_values = kwargs.get("pixel_values", None)
        image_grid_thw = kwargs.get("image_grid_thw", None)
        if isinstance(pixel_values, list):
            try:
                pixel_values = torch.stack(pixel_values)
                kwargs["pixel_values"] = pixel_values
            except Exception:
                pass

        outputs = super().forward(*args, **kwargs)

        vision_embed = None
        loss_skin = torch.tensor(0.0, device=outputs.logits.device)
        aux_loss = torch.tensor(0.0, device=outputs.logits.device)
        if pixel_values is not None and image_grid_thw is not None:
            if not isinstance(pixel_values, torch.Tensor):
                if isinstance(pixel_values, list):
                    pixel_values = torch.stack(pixel_values)
                else:
                    pixel_values = torch.tensor(pixel_values)
            image_grid_thw = image_grid_thw.to(pixel_values.device)
            side = self.distill_head(
                pixel_values=pixel_values, image_grid_thw=image_grid_thw
            )
            vision_embed = side["vision_embed"]
            aux_loss = side["aux_loss"]
            if skin_labels is not None:
                skin_labels = skin_labels.to(side["skin_logits"].device)
                loss_skin = nn.CrossEntropyLoss()(side["skin_logits"], skin_labels)
            setattr(outputs, "vision_embed", vision_embed)
            setattr(outputs, "vision_proj", side["vision_proj"])
            setattr(outputs, "loss_skin", loss_skin)
            setattr(outputs, "aux_loss", aux_loss)
            setattr(outputs, "skin_logits", side["skin_logits"])
            pack_vision_proj = (
                side["vision_proj"]
                if side["vision_proj"] is not None
                else torch.tensor(0.0, device=aux_loss.device)
            )
            pack_skin_logits = (
                side["skin_logits"]
                if side["skin_logits"] is not None
                else torch.tensor(0.0, device=aux_loss.device)
            )
            # Repurpose the attentions slot to smuggle side outputs through
            # wrappers that only preserve standard ModelOutput fields.
            outputs.attentions = (pack_vision_proj, aux_loss, pack_skin_logits)
            self.latest_side_output = {
                "vision_proj": side["vision_proj"],
                "aux_loss": aux_loss,
                "skin_logits": side["skin_logits"],
            }

        if hasattr(outputs, "logits") and vision_embed is not None and skin_vocab_mask is not None:
            # Project the pooled vision features to a per-image vocabulary bias
            # and add it, masked to skin-related tokens, onto every position.
            bias_features = self.text_bias(vision_embed.to(self.logit_bias_scale.dtype))
            lm_weight = self.lm_head.weight.to(bias_features.dtype)
            vocab_bias = F.linear(bias_features, lm_weight)
            scale = self.logit_bias_scale.to(outputs.logits.dtype)
            outputs.logits = outputs.logits + (scale * vocab_bias[:, None, :] * skin_vocab_mask)

        if outputs.loss is not None:
            outputs.loss = outputs.loss + loss_skin + (0.01 * aux_loss)
        return outputs

    def freeze_all_but_distill(self):
        # Freeze the whole backbone, then re-enable only the side modules.
        self.requires_grad_(False)
        for params in self.distill_head.parameters():
            params.requires_grad_(True)
        for params in self.text_bias.parameters():
            params.requires_grad_(True)
        self.logit_bias_scale.requires_grad_(True)

    def configure_out_dim(self, out_dim: int):
        self.distill_head.configure_out_dim(out_dim)

    def project_only(self, vision_embed: torch.Tensor) -> torch.Tensor:
        return self.distill_head.out_proj(vision_embed)
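
# Adapter-only fine-tuning sketch (not executed here; assumptions: training
# uses a full-precision load rather than the INT4 inference path below, and
# the learning rate is a placeholder):
#
#     model = SkinVLModelWithAdapter.from_pretrained(
#         resolve_model_path(), torch_dtype=torch.bfloat16
#     )
#     model.freeze_all_but_distill()  # train only distill_head, text_bias, scale
#     trainable = [p for p in model.parameters() if p.requires_grad]
#     optimizer = torch.optim.AdamW(trainable, lr=1e-4)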


def load_quantized_model_and_processor(model_path: str = DEFAULT_MODEL_PATH):
    resolved_model_path = resolve_model_path(model_path)
    quantization_config = build_quantization_config()
    model = SkinVLModelWithAdapter.from_pretrained(
        resolved_model_path,
        device_map=resolve_quantized_device_map(),
        quantization_config=quantization_config,
        attn_implementation="sdpa",
    )
    model.eval()
    processor = AutoProcessor.from_pretrained(
        resolved_model_path,
        min_pixels=256 * 28 * 28,
        max_pixels=1280 * 28 * 28,
    )
    return model, processor


def get_model_device(model) -> torch.device:
    try:
        return model.device
    except AttributeError:
        return next(model.parameters()).device


def prepare_inputs(processor, model, messages: list[dict]):
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(get_model_device(model))
    # Some processor versions emit mm_token_type_ids, which generate() rejects.
    inputs.pop("mm_token_type_ids", None)
    return inputs


class QuantizedSkinGPTModel:
    def __init__(self, model_path: str = DEFAULT_MODEL_PATH):
        resolved_model_path = resolve_model_path(model_path)
        print(f"Loading INT4 model from {resolved_model_path}...")
        self.model, self.processor = load_quantized_model_and_processor(resolved_model_path)
        self.model_path = resolved_model_path
        self.device = get_model_device(self.model)
        # Stop once the closing answer tag has been emitted (ANSWER_END is an
        # assumption; the original literal was stripped from this file).
        self.stop_ids = self.processor.tokenizer.encode(ANSWER_END, add_special_tokens=False)
        print(f"Model loaded successfully on {self.device}.")

    @staticmethod
    def has_complete_answer(text: str) -> bool:
        return ANSWER_START in text and ANSWER_END in text

    def _build_generation_kwargs(
        self,
        inputs,
        max_new_tokens: int,
        do_sample: bool,
        temperature: float,
        repetition_penalty: float,
        top_p: float,
        no_repeat_ngram_size: int,
        streamer=None,
    ) -> dict:
        generation_kwargs = {
            **inputs,
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "repetition_penalty": repetition_penalty,
            "no_repeat_ngram_size": no_repeat_ngram_size,
            "use_cache": True,
            "stopping_criteria": StoppingCriteriaList([StopOnTokenSequence(self.stop_ids)]),
        }
        if streamer is not None:
            generation_kwargs["streamer"] = streamer
        if do_sample:
            # Sampling knobs are only valid when do_sample is enabled.
            generation_kwargs["temperature"] = temperature
            generation_kwargs["top_p"] = top_p
        return generation_kwargs

    def _generate_text(
        self,
        messages,
        max_new_tokens: int,
        do_sample: bool,
        temperature: float,
        repetition_penalty: float,
        top_p: float,
        no_repeat_ngram_size: int,
    ) -> str:
        inputs = prepare_inputs(self.processor, self.model, messages)
        generation_kwargs = self._build_generation_kwargs(
            inputs=inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            top_p=top_p,
            no_repeat_ngram_size=no_repeat_ngram_size,
        )
        with torch.inference_mode():
            generated_ids = self.model.generate(**generation_kwargs)
        # Strip the prompt tokens before decoding.
        generated_ids_trimmed = [
            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return output_text[0]

    def generate_response(
        self,
        messages,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
        continue_tokens: int = DEFAULT_CONTINUE_TOKENS,
        do_sample: bool = DEFAULT_DO_SAMPLE,
        temperature: float = DEFAULT_TEMPERATURE,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        top_p: float = DEFAULT_TOP_P,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
    ) -> str:
        output_text = self._generate_text(
            messages=messages,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            top_p=top_p,
            no_repeat_ngram_size=no_repeat_ngram_size,
        )
        # If the answer tags never closed, retry once with a larger budget.
        if not self.has_complete_answer(output_text) and continue_tokens > 0:
            output_text = self._generate_text(
                messages=messages,
                max_new_tokens=max_new_tokens + continue_tokens,
                do_sample=do_sample,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                top_p=top_p,
                no_repeat_ngram_size=no_repeat_ngram_size,
            )
        return output_text

    def generate_response_stream(
        self,
        messages,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
        continue_tokens: int = DEFAULT_CONTINUE_TOKENS,
        do_sample: bool = DEFAULT_DO_SAMPLE,
        temperature: float = DEFAULT_TEMPERATURE,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        top_p: float = DEFAULT_TOP_P,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
    ):
        inputs = prepare_inputs(self.processor, self.model, messages)
        streamer = TextIteratorStreamer(
            self.processor.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        generation_kwargs = self._build_generation_kwargs(
            inputs=inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            top_p=top_p,
            no_repeat_ngram_size=no_repeat_ngram_size,
            streamer=streamer,
        )

        # generate() blocks, so it runs on a worker thread while this
        # generator yields decoded chunks from the streamer.
        def _generate():
            with torch.inference_mode():
                self.model.generate(**generation_kwargs)

        thread = Thread(target=_generate)
        thread.start()
        partial_chunks = []
        for text_chunk in streamer:
            partial_chunks.append(text_chunk)
            yield text_chunk
        thread.join()

        # If the stream ended without a complete answer, regenerate with a
        # larger budget and yield only the new tail.
        partial_text = "".join(partial_chunks)
        if not self.has_complete_answer(partial_text) and continue_tokens > 0:
            completed_text = self._generate_text(
                messages=messages,
                max_new_tokens=max_new_tokens + continue_tokens,
                do_sample=do_sample,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                top_p=top_p,
                no_repeat_ngram_size=no_repeat_ngram_size,
            )
            if completed_text.startswith(partial_text):
                tail_text = completed_text[len(partial_text) :]
                if tail_text:
                    yield tail_text
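
# Minimal command-line usage sketch (assumption: "lesion.jpg" is a placeholder
# image path; the checkpoint and a CUDA GPU are required by the INT4 loader).
if __name__ == "__main__":
    bot = QuantizedSkinGPTModel()
    messages = build_single_turn_messages("lesion.jpg", DEFAULT_PROMPT)
    for chunk in bot.generate_response_stream(messages):
        print(chunk, end="", flush=True)
    print()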