import torch.nn as nn
from peft import LoraConfig, LoraModel


def wrap_lora(module: nn.Module, r: int, lora_alpha: float, lora_dropout: float) -> nn.Module:
    """Wrap *module* in a PEFT ``LoraModel`` and keep classifier weights trainable.

    A ``LoraConfig`` is built from the given hyper-parameters with
    ``bias="none"`` and a fixed list of target sub-module names, and the
    module is wrapped under the adapter name ``'default'``. Afterwards every
    parameter whose name contains ``classifier`` (case-insensitive) has
    ``requires_grad`` set to ``True``.

    Args:
        module: The model to wrap.
        r: Passed to ``LoraConfig`` as ``r``.
        lora_alpha: Passed to ``LoraConfig`` as ``lora_alpha``.
        lora_dropout: Passed to ``LoraConfig`` as ``lora_dropout``.

    Returns:
        The ``LoraModel`` wrapping ``module``.
    """
    config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        bias="none",
        # Target names cover the ESM++ and ESM2 attention types, as well as
        # any additional transformer blocks from Syndev.
        target_modules=[
            "layernorm_qkv.1",
            "out_proj",
            "query",
            "key",
            "value",
            "dense",
        ],
    )
    wrapped = LoraModel(module, config, 'default')

    # Re-enable gradients on the classifier parameters.
    # NOTE(review): presumably needed because the LoRA wrapper freezes
    # non-adapter weights -- confirm against the installed peft version.
    for param_name, param in wrapped.named_parameters():
        if 'classifier' not in param_name.lower():
            continue
        param.requires_grad = True

    return wrapped