| """
|
| CitationModule: Understands scientific citation structure.
|
| Detects citation spans, tracks provenance, and estimates claim confidence.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import re
|
| from typing import Optional, Tuple, List
|
|
|
|
|
class CitationModule(nn.Module):
    """Understands scientific citation structure.

    - Detects citation spans such as ``(Author, Year)``, ``[1]``,
      ``[Author, Year]`` and bare ``et al.`` mentions.
    - Learns that cited claims carry different epistemic weight, via a
      sigmoid provenance gate applied to span representations.
    - Estimates a per-token claim confidence score.
    - Tracks claim provenance through the context window.

    ``forward`` returns a tuple ``(enhanced, confidence)``.
    """

    def __init__(self, d_model: int):
        """Initialize CitationModule.

        Args:
            d_model: Model dimension.
        """
        super().__init__()
        self.d_model = d_model

        # 3-way per-token classifier over {none, inline, reference};
        # its logits are consumed by compute_citation_loss.
        self.citation_detector = nn.Linear(d_model, 3)

        # Sigmoid gate that reweights token features inside citation spans.
        self.provenance_gate = nn.Linear(d_model, d_model)

        # Scalar per-token confidence head (sigmoid-activated in forward).
        self.confidence_head = nn.Linear(d_model, 1)

        # Learned embedding added to citation spans, indexed by type id
        # (0 = unknown, 1 = inline, 2 = reference).
        self.citation_type_embedding = nn.Embedding(3, d_model)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize weights: N(0, 0.02) for weights, zeros for biases."""
        for module in [
            self.citation_detector,
            self.provenance_gate,
            self.confidence_head,
            self.citation_type_embedding,
        ]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            # nn.Embedding has no bias; the hasattr check skips it safely.
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    def detect_citation_spans(
        self,
        text: str,
    ) -> List[Tuple[int, int, str]]:
        """Detect citation spans in text.

        Supports: ``(Author, Year)``, ``[1]`` / ``[1,2]`` / ``[1-3]``,
        ``[Author, Year]``, and bare ``et al.`` mentions.

        Args:
            text: Input text string.

        Returns:
            List of ``(start_char, end_char, citation_type)`` tuples.
            Only ``"inline"`` is currently emitted; ``"reference"`` is
            reserved for bibliography-style entries.
        """
        spans: List[Tuple[int, int, str]] = []

        # (Author, Year) or (Author et al., Year) style.
        for match in re.finditer(r'\([A-Za-z\s]+(?:et al\.)?,?\s*\d{4}\)', text):
            spans.append((match.start(), match.end(), "inline"))

        # Numeric brackets: [1], [1,2], [1-3].
        for match in re.finditer(r'\[\d+(?:[-,]\d+)*\]', text):
            spans.append((match.start(), match.end(), "inline"))

        # [Author, Year] style.
        for match in re.finditer(r'\[[A-Za-z\s]+,?\s*\d{4}\]', text):
            spans.append((match.start(), match.end(), "inline"))

        # Bare "et al." mentions. The previous pattern ended with \b, which
        # requires a WORD character after the period and therefore never
        # matched "et al." followed by space/comma/end-of-string — fixed by
        # dropping the trailing \b. Matches already covered by a span found
        # above (e.g. inside "(Smith et al., 2020)") are skipped to avoid
        # processing the same region twice.
        for match in re.finditer(r'\bet al\.', text):
            start, end = match.span()
            if not any(s < end and start < e for s, e, _ in spans):
                spans.append((start, end, "inline"))

        return spans

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        citation_spans: Optional[List[List[Tuple[int, int, str]]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through citation module.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            text: Optional original text strings, one per batch element.
                Used to auto-detect citation spans when ``citation_spans``
                is not provided.
            citation_spans: Optional pre-computed spans per batch element,
                each as (start_token, end_token, citation_type).

        Returns:
            Tuple of:
                - Citation-enhanced representation (batch, seq_len, d_model).
                - Per-token confidence in [0, 1], shape (batch, seq_len, 1).

        Note: the return annotation previously claimed a single tensor,
        but two values have always been returned; the annotation now
        matches the actual behavior.
        """
        batch, seq_len, d_model = x.shape

        # Derive token-level spans from raw text when none are supplied.
        if citation_spans is None and text is not None:
            citation_spans = []
            for b in range(batch):
                spans = self.detect_citation_spans(text[b])
                token_spans = []
                for start_char, end_char, ctype in spans:
                    # Crude char->token mapping assuming ~4 chars per token;
                    # TODO(review): confirm against the actual tokenizer.
                    start_tok = max(0, start_char // 4)
                    end_tok = min(seq_len, end_char // 4 + 1)
                    token_spans.append((start_tok, end_tok, ctype))
                citation_spans.append(token_spans)

        # Note: detection logits are intentionally NOT computed here — they
        # were previously computed and discarded; compute_citation_loss
        # recomputes them when the auxiliary loss is needed.

        output = x.clone()

        if citation_spans:
            for b in range(batch):
                spans_b = citation_spans[b] if b < len(citation_spans) else []

                for start_tok, end_tok, ctype in spans_b:
                    if end_tok <= start_tok:
                        continue

                    # Map citation type string to an embedding row.
                    if ctype == "inline":
                        type_id = 1
                    elif ctype == "reference":
                        type_id = 2
                    else:
                        type_id = 0

                    type_emb = self.citation_type_embedding(
                        torch.tensor(type_id, device=x.device)
                    )

                    # Gate the span features (epistemic reweighting), then
                    # add the type embedding broadcast over the span.
                    span_slice = x[b, start_tok:end_tok, :]
                    gated = span_slice * torch.sigmoid(
                        self.provenance_gate(span_slice)
                    )
                    gated = gated + type_emb.unsqueeze(0)
                    output[b, start_tok:end_tok, :] = gated

        confidence = torch.sigmoid(self.confidence_head(x))

        return output, confidence

    def compute_citation_loss(
        self,
        x: torch.Tensor,
        citation_mask: torch.Tensor,
        confidence: torch.Tensor,
    ) -> torch.Tensor:
        """Compute auxiliary loss for citation detection and confidence.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            citation_mask: Ground-truth citation class per token
                (batch, seq_len); values must be valid class indices in
                {0, 1, 2} (0 = no citation, 1 = inline, 2 = reference).
            confidence: Predicted confidence scores (batch, seq_len, 1).

        Returns:
            Combined scalar loss: detection cross-entropy plus
            0.1 * confidence MSE against the (float) mask.
        """
        # 3-class token classification against the ground-truth mask.
        logits = self.citation_detector(x)
        detection_loss = F.cross_entropy(
            logits.view(-1, 3),
            citation_mask.long().view(-1),
        )

        # Regress confidence toward 1.0 inside citations, 0.0 outside.
        confidence_loss = F.mse_loss(
            confidence.squeeze(-1),
            citation_mask.float(),
        )

        return detection_loss + 0.1 * confidence_loss
|
|
|
|
|
def test_citation_module():
    """Smoke-test CitationModule: forward shapes and auxiliary loss."""
    dim, n_batch, n_tokens = 512, 2, 128

    cit_mod = CitationModule(dim)

    inputs = torch.randn(n_batch, n_tokens, dim)
    samples = [
        "The theory of relativity (Einstein, 1905) revolutionized physics. See also [1, 2].",
        "According to Smith et al., the results are significant. Further reading: [Doe, 2020]."
    ]

    enhanced, conf = cit_mod(inputs, text=samples)
    print(f"Input shape: {inputs.shape}")
    print(f"Output shape: {enhanced.shape}")
    print(f"Confidence shape: {conf.shape}")
    assert enhanced.shape == inputs.shape
    assert conf.shape == (n_batch, n_tokens, 1)

    # Synthetic ground truth: one citation span per batch element.
    gold_mask = torch.zeros(n_batch, n_tokens)
    gold_mask[0, 20:25] = 1.0
    gold_mask[1, 10:18] = 1.0
    aux_loss = cit_mod.compute_citation_loss(inputs, gold_mask, conf)
    print(f"Citation loss: {aux_loss.item():.4f}")

    print("CitationModule test passed!")
|
|
|
|
|
# Run the module smoke test when this file is executed directly.
if __name__ == "__main__":
    test_citation_module()
|
|
|