---
license: mit
language:
- en
library_name: transformers
tags:
- lora
- peft
- reinforcement-learning
- domain-adaptation
- sentence-embeddings
- curriculum-learning
- multi-task-learning
- rag
- information-retrieval
- cross-domain
- sentence-transformers
base_model:
- sentence-transformers/all-MiniLM-L6-v2
- EphAsad/FireDevourerEmbedder-RL-v3.6
pipeline_tag: sentence-similarity
datasets:
- sentence-transformers/stsb
- nyu-mll/multi_nli
- quora
- google-research-datasets/paws
- nyu-mll/glue
- GBaker/MedQA-USMLE-4-options-hf
- lex_glue
- gbharti/finance-alpaca
- scientific_papers
model-index:
- name: DomainEmbedder-v2.6
  results:
  - task:
      type: domain-classification
      name: Domain Classification
    metrics:
    - type: accuracy
      value: 0.925
      name: Training Accuracy
    - type: accuracy
      value: 0.56
      name: Stress-Test Accuracy
---

# DomainEmbedder-v2.6

> **High-Information-Density Embeddings for Cross-Domain RAG and Retrieval**

DomainEmbedder-v2.6 produces **information-dense embeddings** optimized for retrieval-augmented generation (RAG) and cross-domain similarity matching. It combines a multi-task base embedder with domain-adaptive LoRA routing.

## What This Model Does

| Component | Description |
|-----------|-------------|
| **Base Embedder** | FireDevourerEmbedder-RL-v3.6, trained on 5 NLP tasks with RL-based task weighting |
| **Domain LoRAs** | 5 specialized adapters (Medical, Legal, Code, Finance, Scientific) |
| **RL Policy** | Selects a domain adapter for each input |

**Why this matters for RAG/Retrieval:**

- Embeddings encode multiple facets of meaning (similarity, entailment, paraphrase, question similarity)
- Domain routing provides context-appropriate representations
- Together, these aim to make retrieval more precise across diverse content types

## Key Innovation: Dual RL Architecture

| Stage | RL Application | Purpose |
|-------|----------------|---------|
| Base Model Training | Task Weight Policy | Dynamically balance 5 NLP objectives during training |
| Domain Extension | Adapter Selection Policy | Route each input to a domain LoRA at inference |

The novel element is that RL appears at both stages: a task-weighting policy shapes the base model during training, and an RL-trained routing policy selects the domain adapter at inference time.
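
The card does not document how the task-weight policy is parameterized, what it observes, or how it is rewarded, so the following is only a hypothetical sketch of the general idea: a policy emits one logit per task, a softmax turns those logits into weights, and the training objective is the weighted sum of the per-task losses. The function names are illustrative and this is not the author's training code.

```python
import math

# Hypothetical sketch: the card does not specify how the task-weight policy is trained.
def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def weighted_multitask_loss(task_losses, policy_logits):
    """Combine per-task losses using softmax weights from a policy's logits.

    task_losses and policy_logits are dicts keyed by the same task names.
    Returns the weighted total loss and the weight assigned to each task.
    """
    tasks = sorted(task_losses)
    weights = softmax([policy_logits[t] for t in tasks])
    total = sum(w * task_losses[t] for w, t in zip(weights, tasks))
    return total, dict(zip(tasks, weights))
```

With equal logits every task receives the same weight and the objective reduces to the plain mean of the task losses; raising one task's logit shifts weight toward that objective at the expense of the others.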

## Quick Start

### Installation

```bash
pip install torch transformers peft
```

### Loading the Model

```python
import copy

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')

# Define the base embedder architecture
class FireDevourerEmbedder(nn.Module):
    def __init__(self, base_model_name='sentence-transformers/all-MiniLM-L6-v2'):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model_name)
        self.hidden_size = 384

        # Task heads
        self.sts_head = nn.Sequential(nn.Linear(384, 1), nn.Sigmoid())
        self.nli_head = nn.Linear(384, 3)
        self.qqp_head = nn.Linear(384, 2)
        self.paws_head = nn.Linear(384, 2)
        self.domain_head = nn.Linear(384, 5)

    def mean_pool(self, token_embeddings, attention_mask):
        # Average the token embeddings, ignoring padding positions
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

    def forward(self, input_ids, attention_mask, task='encode'):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        embedding = self.mean_pool(outputs.last_hidden_state, attention_mask)
        if task == 'encode':
            return embedding
        elif task == 'domain':
            return self.domain_head(embedding)
        # Other heads (sts, nli, qqp, paws) can be wired up here as needed
        raise ValueError(f"Unsupported task: {task!r}")

# Define RL Policy Network
class RLPolicyNetwork(nn.Module):
    def __init__(self, input_dim=384, hidden_dim=128, num_actions=5):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.policy_head = nn.Linear(hidden_dim, num_actions)  # actor: domain probabilities
        self.value_head = nn.Linear(hidden_dim, 1)             # critic: state value

    def forward(self, x):
        features = self.network(x)
        policy = torch.softmax(self.policy_head(features), dim=-1)
        value = self.value_head(features)
        return policy, value

# Local directory containing the files listed under "Package Contents"
model_dir = "path/to/DomainEmbedder-v2.6"

# 1. Load base model with checkpoint
base_model = FireDevourerEmbedder()
checkpoint = torch.load(f"{model_dir}/FireDevourerEmbedder-RL-v3.6.pt", map_location=device)
# strict=False silently skips missing or unexpected keys
base_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
base_model.to(device)
base_model.eval()

# 2. Load RL policy
rl_policy = RLPolicyNetwork()
rl_checkpoint = torch.load(f"{model_dir}/rl_policy.pt", map_location=device)
rl_policy.load_state_dict(rl_checkpoint['policy_state_dict'])
rl_policy.to(device)
rl_policy.eval()

# 3. Load the LoRA adapters. PEFT injects LoRA layers into the module it wraps,
#    so wrap a copy of the encoder to keep base_model un-adapted for routing.
domains = ['medical', 'legal', 'code', 'finance', 'scientific']
lora_model = PeftModel.from_pretrained(
    copy.deepcopy(base_model.encoder),
    f"{model_dir}/medical_lora",
    adapter_name='medical'
)
# Register the remaining adapters under their domain names
for domain in domains[1:]:
    lora_model.load_adapter(f"{model_dir}/{domain}_lora", adapter_name=domain)
lora_model.to(device)
lora_model.eval()
```

### Computing Embeddings with Domain Selection

```python
def get_domain_embedding(text, base_model, rl_policy, lora_model, tokenizer, device):
    """Route a single text to a domain and return its embedding.

    With lora_model=None the base embedding is returned. Otherwise lora_model must
    be a PeftModel with one adapter per domain name; the selected one is applied.
    """
    domains = ['medical', 'legal', 'code', 'finance', 'scientific']
    inputs = tokenizer(text, return_tensors='pt', padding=True,
                       truncation=True, max_length=512).to(device)

    with torch.no_grad():
        # Base embedding: the state the RL policy routes on
        base_emb = base_model(inputs['input_ids'], inputs['attention_mask'], task='encode')

        # Pick the most probable domain
        policy_probs, _ = rl_policy(base_emb)
        domain_idx = torch.argmax(policy_probs, dim=-1).item()
        selected_domain = domains[domain_idx]
        confidence = policy_probs[0, domain_idx].item()

        # Re-encode with the selected domain's LoRA adapter, if one is provided
        embedding = base_emb
        if lora_model is not None:
            lora_model.set_adapter(selected_domain)
            outputs = lora_model(input_ids=inputs['input_ids'],
                                 attention_mask=inputs['attention_mask'])
            embedding = base_model.mean_pool(outputs.last_hidden_state,
                                             inputs['attention_mask'])

    return {
        'embedding': embedding,
        'domain': selected_domain,
        'confidence': confidence,
        'all_probs': policy_probs[0].cpu().numpy()
    }

# Example usage: routing only, so the base embedding is returned. Pass a PeftModel
# holding one adapter per domain name instead of None to get the adapted embedding.
result = get_domain_embedding(
    "What are the symptoms of diabetes?",
    base_model, rl_policy, None, tokenizer, device
)
print(f"Domain: {result['domain']} (confidence: {result['confidence']:.2%})")
```
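
For retrieval, the returned embeddings still have to be compared against an index. The card does not name a similarity function, so cosine similarity (the usual choice for MiniLM-based sentence embeddings) is an assumption here. This dependency-free sketch ranks candidate documents for a query; with the model, each vector would come from something like `result['embedding'][0].tolist()`. The helper names are hypothetical.

```python
import math

# Assumption: cosine similarity; the card does not name a similarity metric.
def cosine_similarity(a, b):
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)

def top_k(query_emb, doc_embs, k=3):
    """Return (doc_index, score) pairs for the k documents most similar to the query."""
    scored = [(i, cosine_similarity(query_emb, d)) for i, d in enumerate(doc_embs)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]
```

This is a linear scan for clarity; for large collections a vector index would take its place.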

## Architecture

```
                 Input Text
                      │
                      ▼
┌────────────────────────────────────────────┐
│ MiniLM-L6-v2 Encoder (FROZEN)              │
│ + Optional LoRA Adapter (domain-specific)  │
│ 384-dimensional output                     │
└────────────────────────────────────────────┘
                      │
          ┌───────────┴───────────┐
          │                       │
          ▼                       ▼
 ┌─────────────────┐    ┌──────────────────┐
 │ Base Embedding  │    │ RL Policy Net    │
 │ (384-dim)       │    │ (66K params)     │
 └─────────────────┘    └──────────────────┘
                                  │
                                  ▼
                          Domain Selection
                       [Medical, Legal, Code,
                        Finance, Scientific]
                                  │
                                  ▼
                   Load corresponding LoRA adapter
                                  │
                                  ▼
                      Domain-Adapted Embedding
```

### Component Details

| Component | Specification |
|-----------|---------------|
| Base Encoder | MiniLM-L6-v2 (22M params) |
| Embedding Dim | 384 |
| LoRA Rank | 16 |
| LoRA Alpha | 32 |
| LoRA Target | Query, Value projections |
| LoRA Params | 147,456 per adapter (0.645%) |
| RL Policy | 66,566 params |
| Domains | Medical, Legal, Code, Finance, Scientific |
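
The LoRA and RL-policy sizes in the table can be reproduced by hand. The sketch below assumes LoRA matrices on the query and value projections of all 6 encoder layers (hidden size 384, biases not adapted) and the `RLPolicyNetwork` layer sizes from Quick Start. It also assumes the 0.645% figure is measured against the encoder plus the adapter, the way PEFT's `print_trainable_parameters()` reports it, with 22,713,216 taken as the assumed parameter count of the all-MiniLM-L6-v2 `BertModel` including its pooler.

```python
# Assumed architecture numbers for MiniLM-L6-v2
NUM_LAYERS = 6
HIDDEN = 384
BASE_PARAMS = 22_713_216  # assumed BertModel parameter count, including the pooler

def lora_params(rank=16, num_layers=NUM_LAYERS, hidden=HIDDEN, targets=2):
    # Each adapted hidden x hidden projection gets A (rank x hidden) and B (hidden x rank)
    return num_layers * targets * (rank * hidden + hidden * rank)

def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix plus bias

def rl_policy_params(input_dim=384, hidden_dim=128, num_actions=5):
    trunk = linear_params(input_dim, hidden_dim) + linear_params(hidden_dim, hidden_dim)
    return trunk + linear_params(hidden_dim, num_actions) + linear_params(hidden_dim, 1)

lora = lora_params()
print(lora)                                          # 147456
print(f"{100 * lora / (BASE_PARAMS + lora):.3f}%")   # 0.645%
print(rl_policy_params())                            # 66566
```

Under PEFT's default (non-rsLoRA) scaling, alpha 32 with rank 16 multiplies each adapter update by alpha / r = 2.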

## Performance

### Base Model: Multi-Task Embedding Quality

The base FireDevourerEmbedder achieves a **0.71 average** across 5 distinct NLP tasks:

| Task | Dataset | Score | What It Measures |
|------|---------|-------|------------------|
| Question Similarity | QQP | 0.8636 | Intent matching |
| Paraphrase Detection | PAWS | 0.8459 | Adversarial robustness |
| Paraphrase Detection | MRPC | 0.7744 | News domain paraphrase |
| NLI | MultiNLI | 0.7465 | Logical relationships |
| Semantic Similarity | STS-B | 0.3366 | Fine-grained similarity |
| **Average** | | **0.7134** | **Cross-task capability** |

**Philosophy**: Some individual task scores (most visibly STS-B) are traded for cross-domain information density, which makes the embeddings more versatile for RAG and retrieval across diverse content.

### Domain Routing Accuracy

**Training Results (In-Distribution)**

| Metric | Value |
|--------|-------|
| Domain Accuracy | 92.5% |
| Average Reward | 1.527 |
| Training Steps | 5,000 |

**Stress-Test Benchmark (Semantically Similar Cross-Domain Phrases)**

The benchmark deliberately uses complex, semantically similar phrases across domains to test robustness:

| Metric | DomainEmbedder (RL+LoRA) | Base Model | Difference |
|--------|--------------------------|------------|------------|
| Domain Accuracy | 56.0% | 20.4% | **+35.6 pts** |
| Avg Confidence | 28.5% | 77.6% | Smaller confidence-accuracy gap |

### Per-Domain Breakdown

| Domain | DomainEmbedder | Base Model | Note |
|--------|----------------|------------|------|
| Finance | 78.0% | 0.0% | +78.0 pts |
| Medical | 73.0% | 0.0% | +73.0 pts |
| Legal | 53.0% | 15.0% | +38.0 pts |
| Scientific | 48.0% | 1.0% | +47.0 pts |
| Code | 28.0% | 86.0% | -58.0 pts (base over-predicted code) |

**Key Insight**: The base model was heavily biased toward "code": it reached 86% on code phrases but near 0% on finance, medical, and scientific ones, while reporting high average confidence (77.6%). The RL+LoRA system corrects this, spreading predictions more evenly across domains with lower, less overconfident scores.
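
The headline stress-test accuracies equal the unweighted means of the per-domain rows, which is consistent with an equal number of test phrases per domain (an assumption; the card does not give the split). The same tables make "less overconfident" concrete: subtracting accuracy from average confidence gives a signed calibration gap for each model. The sketch below only re-derives numbers already in the tables.

```python
# Per-domain stress-test accuracies (%) copied from the table above
embedder_acc = {"finance": 78.0, "medical": 73.0, "legal": 53.0, "scientific": 48.0, "code": 28.0}
base_acc = {"finance": 0.0, "medical": 0.0, "legal": 15.0, "scientific": 1.0, "code": 86.0}
# Average confidence (%) from the stress-test summary table
embedder_conf, base_conf = 28.5, 77.6

def macro_average(per_domain):
    # Unweighted mean; equals overall accuracy only if every domain has equally many test phrases (assumed)
    return sum(per_domain.values()) / len(per_domain)

embedder_mean = macro_average(embedder_acc)  # 56.0
base_mean = macro_average(base_acc)          # 20.4
# Positive gap = overconfident, negative gap = underconfident
print(f"DomainEmbedder: acc {embedder_mean:.1f}%, gap {embedder_conf - embedder_mean:+.1f} pts")
print(f"Base model:     acc {base_mean:.1f}%, gap {base_conf - base_mean:+.1f} pts")
```

The base model is overconfident by about 57 points, while DomainEmbedder is underconfident by about 27.5 points, so its confidence sits closer to its accuracy even though it is not perfectly calibrated.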

## Training Details

### Domain Training Data

| Domain | Samples | Sources |
|--------|---------|---------|
| Medical | 40,000 | MedQA-USMLE, MedQuAD, PubMedQA, Medical Meadow, ChatDoctor |
| Legal | 40,000 | EUR-LEX, CaseHold, ECTHR-A, ECTHR-B |
| Code | 40,000 | Code Alpaca, MBPP, Code Contests, Python Instructions |
| Finance | 40,000 | Finance Alpaca, FinGPT-FiQA, Financial QA |
| Scientific | 40,000 | arXiv, PubMed (87.3% real + 12.7% augmented) |
| **Total** | **200,000** | |

### LoRA Training Configuration

| Parameter | Value |
|-----------|-------|
| Epochs | 3 per domain |
| Batch Size | 32 |
| Learning Rate | 2e-4 |
| Loss | Contrastive (InfoNCE-style) |
| Trainable Params | 147,456 (0.645% of base) |
| Warmup Steps | 500 |
| Max Gradient Norm | 1.0 |
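
The card describes the LoRA loss only as "Contrastive (InfoNCE-style)". One common form, sketched below without dependencies, scores each query against every positive in the batch and applies cross-entropy with the matching pair as the target, so the other rows act as in-batch negatives. The cosine similarity and the temperature of 0.05 are assumptions, not documented settings.

```python
import math

# Assumptions: cosine similarity and temperature 0.05; the card only says "InfoNCE-style".
def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def info_nce_loss(queries, positives, temperature=0.05):
    """Mean in-batch InfoNCE loss; positives[i] is the match for queries[i]."""
    losses = []
    for i, q in enumerate(queries):
        logits = [cosine(q, p) / temperature for p in positives]
        m = max(logits)  # log-sum-exp with max subtraction for stability
        log_denominator = m + math.log(sum(math.exp(x - m) for x in logits))
        losses.append(log_denominator - logits[i])  # -log softmax of the true pair
    return sum(losses) / len(losses)
```

With batch size 32, each query sees one positive and 31 in-batch negatives; a lower temperature sharpens the softmax and penalizes near-miss negatives more strongly.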

### RL Training (Supervised A2C)

| Parameter | Value |
|-----------|-------|
| Algorithm | Actor-Critic (A2C) |
| Total Steps | 5,000 |
| Episodes per Step | 5 |
| Gamma (discount) | 0.99 |
| Entropy Coef | 0.1 (high exploration) |
| Value Coef | 0.5 |
| Correctness Bonus | +1.0 |
| Correctness Penalty | -0.5 |
| Baseline Decay | 0.99 |
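
A minimal sketch of how the listed A2C settings could fit together, assuming each routing decision is a one-step episode (so the discount γ = 0.99 has no effect here) and that the critic's value estimate serves as the baseline in the advantage. The card does not document the full reward: the reported average reward of 1.527 exceeds the +1.0 per-decision bonus, so it is either accumulated differently or includes further terms. Reading "baseline decay" as an exponential moving average of rewards is likewise an assumption, and all function names are hypothetical.

```python
import math

CORRECT_BONUS = 1.0    # "Correctness Bonus" in the table
WRONG_PENALTY = -0.5   # "Correctness Penalty" in the table
ENTROPY_COEF = 0.1     # "Entropy Coef"
VALUE_COEF = 0.5       # "Value Coef"
BASELINE_DECAY = 0.99  # "Baseline Decay" (interpretation assumed)

def routing_reward(predicted, target):
    # Correctness term only; any further reward terms are undocumented
    return CORRECT_BONUS if predicted == target else WRONG_PENALTY

def update_baseline(baseline, reward, decay=BASELINE_DECAY):
    # Assumed meaning: exponential moving average of observed rewards
    return decay * baseline + (1 - decay) * reward

def a2c_loss(probs, value, action, reward):
    """Actor-critic loss for one single-step routing decision.

    probs: policy probabilities over domains; value: critic estimate V(s);
    action: index of the chosen domain; reward: scalar reward received.
    """
    advantage = reward - value  # critic's estimate used as the baseline
    policy_loss = -math.log(probs[action]) * advantage
    value_loss = advantage ** 2
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    # Entropy is subtracted so that minimizing the loss encourages exploration
    return policy_loss + VALUE_COEF * value_loss - ENTROPY_COEF * entropy
```

In a PyTorch implementation the advantage is normally detached in the policy term, so that only the value loss trains the critic.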

### Curriculum Learning Phases

| Phase | Steps | Data | Accuracy |
|-------|-------|------|----------|
| 1 (Easy) | 0-1,500 | Clear domain examples (10K) | 68.8% → 87.5% |
| 2 (Moderate) | 1,500-3,500 | Easy + ambiguous (20K) | 87.5% → 89.3% |
| 3 (Hard) | 3,500-5,000 | All data incl. hybrid (28K) | 89.3% → 92.5% |
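
The phase schedule can be expressed as a lookup from training step to phase. This sketch treats each step range as half-open (a phase starts at its first step and ends just before the next phase begins), which is an assumption about the boundaries; how examples are sampled within a phase is not described in the card, and `curriculum_phase` is a hypothetical helper.

```python
# Phase table from above; half-open [start, end) step ranges are an assumption
PHASES = [
    (0, 1_500, "easy"),          # clear domain examples (10K)
    (1_500, 3_500, "moderate"),  # easy + ambiguous examples (20K)
    (3_500, 5_000, "hard"),      # all data including hybrid examples (28K)
]

def curriculum_phase(step):
    """Return (phase_number, phase_name) for a training step."""
    for number, (start, end, name) in enumerate(PHASES, start=1):
        if start <= step < end:
            return number, name
    raise ValueError(f"step {step} is outside the 0-5,000 training range")
```

A training loop would call this at each step and draw batches from the data pool that belongs to the returned phase.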

### Training Progress

| Version | Step | Accuracy | Reward |
|---------|------|----------|--------|
| v2.1 | 500 | 68.8% | 1.100 |
| v2.2 | 1,000 | 80.1% | 1.336 |
| v2.3 | 1,500 | 87.5% | 1.454 |
| v2.4 | 2,000 | 88.9% | 1.480 |
| v2.5 | 3,000 | 89.3% | 1.507 |
| **v2.6** | **4,000** | **92.5%** | **1.527** |

## Package Contents

```
DomainEmbedder-v2.6/
├── FireDevourerEmbedder-RL-v3.6.pt    # Base model checkpoint (86.7 MB)
├── rl_policy.pt                       # Trained RL policy (0.27 MB)
├── metadata.json                      # Training metadata
├── README.md                          # This file
├── medical_lora/                      # Medical domain adapter (0.6 MB)
│   ├── adapter_config.json
│   └── adapter_model.safetensors
├── legal_lora/                        # Legal domain adapter (0.6 MB)
├── code_lora/                         # Code domain adapter (0.6 MB)
├── finance_lora/                      # Finance domain adapter (0.6 MB)
└── scientific_lora/                   # Scientific domain adapter (0.6 MB)
```

**Total Size**: ~90 MB (self-contained)
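
Before running the Quick Start loader, it can help to confirm that a local copy of the package is complete. This sketch checks for the files in the tree above; it assumes every `*_lora/` directory contains the same two files shown for `medical_lora/` (PEFT's `from_pretrained` expects an `adapter_config.json` plus adapter weights), and `missing_files` is a hypothetical helper.

```python
from pathlib import Path

DOMAINS = ["medical", "legal", "code", "finance", "scientific"]
# Assumed to be present in every adapter directory, as shown for medical_lora/
ADAPTER_FILES = ["adapter_config.json", "adapter_model.safetensors"]

# Relative paths taken from the package tree above
EXPECTED_FILES = [
    "FireDevourerEmbedder-RL-v3.6.pt",
    "rl_policy.pt",
    "metadata.json",
] + [f"{domain}_lora/{name}" for domain in DOMAINS for name in ADAPTER_FILES]

def missing_files(model_dir):
    """Return the expected package files that are absent from model_dir."""
    root = Path(model_dir)
    return [rel for rel in EXPECTED_FILES if not (root / rel).is_file()]
```

An empty list from `missing_files(model_dir)` means every listed file is present.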

## Intended Use

### Best Use Cases

- **RAG Systems**: Domain-aware retrieval for multi-domain knowledge bases
- **Cross-Domain Search**: Finding similar content across Medical, Legal, Code, Finance, and Scientific domains
- **Document Classification**: Automatic domain routing for document processing pipelines
- **Semantic Similarity**: Information-dense embeddings for precise matching
- **Multi-Domain Chatbots**: Context-appropriate responses based on the detected domain

### Limitations

- **English Only**: Trained exclusively on English data
- **Max Length**: Inputs are truncated to 512 tokens
- **Domain Coverage**: 5 domains only (Medical, Legal, Code, Finance, Scientific)
- **Stress-Test Accuracy**: 56% on semantically similar cross-domain queries
- **STS-B Trade-off**: Lower fine-grained similarity (0.34 on STS-B) in exchange for broader task coverage
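
Because inputs beyond 512 tokens are truncated, long documents in a RAG index are usually split before embedding. A minimal sketch, assuming token ids produced without special tokens (for example `tokenizer(text, add_special_tokens=False)['input_ids']`) and a budget of 510 content tokens, which leaves room for the `[CLS]` and `[SEP]` tokens the tokenizer adds back. The 64-token overlap is an arbitrary illustrative default, and `chunk_token_ids` is a hypothetical helper.

```python
def chunk_token_ids(token_ids, max_tokens=510, overlap=64):
    """Split a token-id list into overlapping windows of at most max_tokens ids."""
    if not 0 <= overlap < max_tokens:
        raise ValueError("overlap must satisfy 0 <= overlap < max_tokens")
    step = max_tokens - overlap
    chunks = []
    for start in range(0, len(token_ids), step):
        chunks.append(token_ids[start:start + max_tokens])
        if start + max_tokens >= len(token_ids):
            break  # this window already reaches the end of the document
    return chunks
```

Each chunk can then be decoded (for example with `tokenizer.decode(chunk)`) and embedded separately, keeping a mapping from every chunk back to its source document.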

## Citation

```bibtex
@misc{domainembedder2025,
  author = {Asad, Zain},
  title = {DomainEmbedder: Domain-Adaptive Embeddings with Dual RL and LoRA},
  year = {2025},
  publisher = {Hugging Face},
  note = {Multi-task base embedder with RL-based task weighting + domain-specific LoRA adapters with curriculum learning}
}
```

## Author

**Zain Asad**

## License

MIT License