GLEN-model/src/tevatron/modeling/encoder_decoder_seq2seq.py
QuanTH02's picture
Phase 1
3d5551b
import os
import time
import copy
import logging
import torch
import torch.nn.functional as F
import torch.distributed as dist
import pandas as pd
from collections import defaultdict
from dataclasses import dataclass
from torch import nn, Tensor
from transformers import PreTrainedModel, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.file_utils import ModelOutput
from typing import Dict, Optional
from tevatron.tree import TreeBuilder
from tevatron.arguments import (
GLENP2ModelArguments as ModelArguments,
GLENP2TrainingArguments as TrainingArguments,
)
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
@dataclass
class EncoderDecoderOutputForSeq2SeqLM(ModelOutput):
    """Output container returned by ``EncoderDecoderModelForSeq2SeqLM.forward``.

    Every field defaults to None; which ones are populated depends on whether
    forward only encoded one side, ran in training mode, or ran in eval mode.
    The field order is part of the ModelOutput tuple interface.
    """

    # Query representations from encode_query (in training with
    # negatives_x_device, forward returns the cross-rank gathered tensor).
    q_reps: Optional[Tensor] = None
    # Passage-side output: forward stores the third value returned by
    # encode_passage (p_reps_dt) here, not the p_reps used for scoring.
    p_reps: Optional[Tensor] = None
    # Training loss; forward sets this to None in eval mode.
    loss: Optional[Tensor] = None
    # Query/passage similarity scores from compute_similarity.
    scores: Optional[Tensor] = None
class EncoderDecoderModelForSeq2SeqLM(nn.Module):
    """Encoder-decoder (seq2seq LM) dual encoder for generative retrieval.

    Subclasses implement ``encode_query``, ``encode_passage`` and
    ``compute_similarity``. In training mode ``forward`` combines an in-batch
    contrastive loss with optional query-to-docid losses whose targets are the
    argmax docid tokens of each query's positive passage.

    ``build`` / ``load`` store ``config``, ``softmax_temperature`` and
    ``model_args`` as *class* attributes (shared by every instance). An
    optional ``trainer`` attribute, attached by the caller, enables caching of
    predicted docids for prefix-aware negative sampling.
    """

    TRANSFORMER_CLS = AutoModelForSeq2SeqLM

    def __init__(
        self,
        lm_q: PreTrainedModel,
        lm_p: PreTrainedModel,
        untie_encoder: bool = False,
        negatives_x_device: bool = False,
        tokenizer=None,
    ):
        """Wrap the query and passage seq2seq models.

        Args:
            lm_q: model used for queries.
            lm_p: model used for passages (the same object as ``lm_q`` when
                the encoders are tied, as ``build``/``load`` construct it).
            untie_encoder: whether ``lm_q`` and ``lm_p`` hold separate
                weights; selects the directory layout written by ``save``.
            negatives_x_device: gather representations from all ranks so they
                serve as extra in-batch negatives.
            tokenizer: used to mask special tokens during docid decoding and
                to convert docid tokens to ids in ``build_tree``.

        Raises:
            ValueError: if ``negatives_x_device`` is set but
                ``torch.distributed`` has not been initialized.
        """
        super().__init__()
        self.lm_q = lm_q
        self.lm_p = lm_p
        self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
        self.negatives_x_device = negatives_x_device
        self.untie_encoder = untie_encoder
        if self.negatives_x_device:
            if not dist.is_initialized():
                raise ValueError(
                    "Distributed training has not been initialized for representation all gather."
                )
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        self.tokenizer = tokenizer
        # Per-query flags (one per training query) telling whether the
        # query's greedy docid matched its positive passage's docid exactly.
        self.gen_r1_list = []

    def forward(
        self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None
    ):
        """Encode queries and/or passages and, when both are given, score them.

        Args:
            query: tokenized queries (``input_ids``, ``attention_mask``; B x L).
                Mutated in place: one-step ``decoder_input_ids`` and
                ``decoder_attention_mask`` are added.
            passage: tokenized passages (B * train_n_passages x L), mutated the
                same way. Must also carry ``oldid`` when prefix-aware docid
                caching is active (see ``_cache_predicted_docids``).

        Returns:
            EncoderDecoderOutputForSeq2SeqLM. When only one side is given,
            only that side's representation is set. In training mode ``loss``
            is the combined contrastive + docid loss; in eval mode ``loss`` is
            None and ``scores`` holds the query/passage similarities.
        """
        # query, passage: input_ids (B * L), attention_mask (B * L)
        decoder_start_token_id = self.config.decoder_start_token_id
        if query is not None:
            self._add_decoder_start(query, decoder_start_token_id)  # (B, 1)
        if passage is not None:
            # (B * train_n_passages, 1)
            self._add_decoder_start(passage, decoder_start_token_id)

        q_reps = self.encode_query(query)
        if passage is None:
            p_reps, p_reps_dt = None, None  # for gradcache
        else:
            p_reps, p_attention, p_reps_dt = self.encode_passage(passage)
            self._cache_predicted_docids(passage, p_reps, p_attention)

        # for inference: only one side was encoded
        if q_reps is None or p_reps is None:
            return EncoderDecoderOutputForSeq2SeqLM(q_reps=q_reps, p_reps=p_reps_dt)

        # for eval
        if not self.training:
            scores = self.compute_similarity(q_reps, p_reps)
            return EncoderDecoderOutputForSeq2SeqLM(
                loss=None, scores=scores, q_reps=q_reps, p_reps=p_reps_dt
            )

        # for training
        # The docid losses below pair each query with its own positive through
        # p_attention / p_reps_dt, which are never gathered across ranks, so
        # they must use this rank's (ungathered) representations.
        q_reps_local, p_reps_local = q_reps, p_reps
        if self.negatives_x_device:
            q_reps = self._dist_gather_tensor(q_reps)
            p_reps = self._dist_gather_tensor(p_reps)

        scores = self.compute_similarity(q_reps, p_reps)  # (B, B * train_n_passages)
        scores = scores.view(q_reps.size(0), -1)  # (B, B * train_n_passages)
        scores = scores / self.softmax_temperature
        # Query i's positive is the first of its train_n_passages passages.
        target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        target = target * (p_reps.size(0) // q_reps.size(0))
        loss = self.compute_loss(scores, target)

        # Optional loss weights are only applied when the argument object
        # actually defines them.
        arg_fields = vars(self.model_args)
        if "infonce_loss" in arg_fields:
            loss = loss * self.model_args.infonce_loss
        if self.negatives_x_device:
            loss = loss * self.world_size  # counter average weight reduction

        lm_logits = q_reps_local @ self.lm_q.shared.weight.T  # (B, L, V)
        if self.model_args.mask_special_tokens_for_decoding:
            lm_logits[:, :, self._maskable_special_token_ids()] = -1e9
        lm_logits = lm_logits / self.softmax_temperature  # (B, L, V)

        # For each query, select the corresponding positive document
        n_passages = p_reps_local.size(0) // q_reps_local.size(0)
        pos_doc = torch.arange(q_reps_local.size(0), dtype=torch.long) * n_passages  # (B)
        # (B * train_n_passages, L, V) -> (B, L, V)
        pos_p_attention = p_attention[pos_doc, :, :]
        lm_targets = pos_p_attention.argmax(dim=-1)  # (B, L)

        if "q_to_docid_loss" in arg_fields and self.model_args.q_to_docid_loss > 0:
            lm_loss = self.cross_entropy(
                lm_logits.view(-1, lm_logits.size(-1)), lm_targets.view(-1)
            )
            loss = loss + lm_loss * self.model_args.q_to_docid_loss

        if "cosine_point_loss" in arg_fields and self.model_args.cosine_point_loss > 0:
            pos_doc_lm_logits = p_reps_dt[pos_doc] @ self.lm_p.shared.weight.T  # (B, L, V)
            pos_doc_id_weight, pos_doc_id = pos_doc_lm_logits.max(dim=-1)  # (B, L)
            # Query logit at each position for the passage's top docid token.
            query_id_weight = lm_logits.gather(
                dim=-1, index=pos_doc_id.unsqueeze(-1)
            ).squeeze(-1)  # (B, L)
            cosine_sim = F.cosine_similarity(
                pos_doc_id_weight, query_id_weight, dim=-1
            )  # (B)
            cosine_loss = torch.mean(1 - cosine_sim)  # ()
            loss = loss + cosine_loss * self.model_args.cosine_point_loss

        query_preds = lm_logits.argmax(dim=-1)  # (B, L)
        # A query is a hit only if every docid position is predicted correctly.
        gen_r1 = (query_preds == lm_targets).all(dim=1)  # (B)
        self.gen_r1_list += gen_r1.cpu().tolist()

        return EncoderDecoderOutputForSeq2SeqLM(
            loss=loss, scores=scores, q_reps=q_reps, p_reps=p_reps_dt
        )

    @staticmethod
    def _add_decoder_start(inputs, decoder_start_token_id):
        """Add a one-step decoder prompt (the start token) to ``inputs`` in place."""
        input_ids = inputs["input_ids"]
        batch_size = input_ids.shape[0]
        inputs["decoder_input_ids"] = torch.full(
            (batch_size, 1),
            decoder_start_token_id,
            dtype=torch.long,
            device=input_ids.device,
        )
        inputs["decoder_attention_mask"] = torch.ones(
            batch_size, 1, dtype=torch.long, device=input_ids.device
        )

    def _cache_predicted_docids(self, passage, p_reps, p_attention):
        """Record predicted docid prefixes per passage for prefix-aware negatives.

        Active only when a ``trainer`` is attached to the model and its
        ``data_args.negative_passage_type`` is ``"self"``. Updates the train
        dataset's ``predid2oldid`` (each "-"-joined prefix -> set of oldids)
        and ``oldid2predid`` (oldid -> full predicted docid) in place.
        """
        trainer = getattr(self, "trainer", None)
        if trainer is None or p_reps is None:
            return
        if trainer.data_args.negative_passage_type != "self":
            return
        predid2oldid = trainer.train_dataset.predid2oldid
        oldid2predid = trainer.train_dataset.oldid2predid
        # Only pay for the device->host copy when the cache is actually used.
        pred_docids = p_attention.argmax(dim=-1).cpu().tolist()
        for oldid, predid in zip(passage["oldid"].cpu().tolist(), pred_docids):
            prefix = ""
            for token_id in predid:
                prefix = f"{prefix}-{token_id}" if prefix else str(token_id)
                predid2oldid[prefix].add(oldid)
            if prefix:
                oldid2predid[oldid] = prefix

    def _maskable_special_token_ids(self):
        """Return the tokenizer's special-token ids except BOS, EOS and PAD."""
        keep = {
            self.tokenizer.bos_token_id,
            self.tokenizer.eos_token_id,
            self.tokenizer.pad_token_id,
        }
        return [x for x in self.tokenizer.all_special_ids if x not in keep]

    def encode_passage(self, psg):
        raise NotImplementedError("EncoderDecoderModel is an abstract class")

    def encode_query(self, qry):
        raise NotImplementedError("EncoderDecoderModel is an abstract class")

    def compute_similarity(self, q_reps, p_reps):
        raise NotImplementedError("EncoderDecoderModel is an abstract class")

    def compute_loss(self, scores, target):
        """Cross-entropy of ``scores`` against the positive-passage indices."""
        return self.cross_entropy(scores, target)

    def build_tree(self, log_step: int = None, log_file: str = None):
        """Load the docid assignment file and build docid lookup tables.

        Reads ``model_args.docid_file_name`` (TSV with columns oldid, docid,
        docid_logit, text; docid tokens and logits are ``<->``-separated),
        logs collision statistics, and sets ``docid2num_docs``,
        ``docid2oldids``, ``oldid2docid_logit`` and ``oldid2docid``. When
        ``model_args.tree == 1`` it also builds a prefix tree over docid
        token ids into ``self.root``.

        Args:
            log_step: step number recorded in the statistics.
            log_file: file the statistics are appended to; printed when None.
        """
        docid_df = pd.read_csv(
            self.model_args.docid_file_name,
            sep="\t",
            names=["oldid", "docid", "docid_logit", "text"],
            dtype={"oldid": str, "docid": str, "docid_logit": str, "text": str},
        ).loc[:, ["oldid", "docid", "docid_logit"]]

        # Log the time and step
        lines = [f"TIME={time.strftime('%Y%m%d-%H%M%S')}", f"STEP={log_step}"]

        # Check the uniqueness of docid and build docid2num_docs
        docid_counts = docid_df.docid.value_counts()
        num_uniques = len(set(docid_df.docid))
        unique_ratio = num_uniques / len(docid_df) if len(docid_df) else 0.0
        lines.append(
            f"num_uniques: {num_uniques}/{len(docid_df)} ({unique_ratio*100:.2f}%)"
        )
        lines.append("[Frequent Collision]")
        lines.append(docid_counts.head(5).to_string())

        # Histogram: slot k counts unique docids made of k + 1 distinct tokens.
        # At least 10 slots as before, but grow instead of raising IndexError.
        unique_token_counts = [
            len(set(docid.split("<->"))) for docid in docid_df.docid.unique()
        ]
        num_unique_tokens = [0] * max([10, *unique_token_counts])
        for count in unique_token_counts:
            num_unique_tokens[count - 1] += 1
        lines.append(f"distribution of number of unique tokens: {num_unique_tokens}")
        docid2num_docs = dict(docid_counts)

        # Build docid2oldids, oldid2docid_logit
        docid2oldids = defaultdict(list)
        oldid2docid_logit = dict()
        for docid, oldid, docid_logit in docid_df[
            ["docid", "oldid", "docid_logit"]
        ].values:
            docid2oldids[docid].append(oldid)
            oldid2docid_logit[oldid] = torch.tensor(
                [float(x) for x in docid_logit.split("<->")]
            )

        # Build oldid2docid
        oldid2docid = dict(zip(docid_df.oldid, docid_df.docid))

        if self.model_args.tree == 1:
            builder = TreeBuilder()
            expected_len = self.model_args.max_output_length - 1
            for docid in oldid2docid.values():
                toks = self.tokenizer.convert_tokens_to_ids(docid.split("<->"))
                if len(toks) != expected_len:
                    print(
                        f"length of docid of {docid} is not {self.model_args.max_output_length}"
                    )
                    toks = toks[:expected_len]
                builder.add(toks)
            self.root = builder.build()

        self.docid2num_docs = docid2num_docs
        self.docid2oldids = docid2oldids
        self.oldid2docid_logit = oldid2docid_logit
        self.oldid2docid = oldid2docid

        report = "\n".join(lines)
        if log_file is not None:
            # docid tokens may be non-ASCII (e.g. SentencePiece pieces)
            with open(log_file, "a", encoding="utf-8") as f:
                f.write(report + "\n")
        else:
            print(report)

    def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
        """All-gather ``t`` along dim 0 from every rank (None passes through).

        ``dist.all_gather`` returns copies without autograd history, so this
        rank's slot is replaced by ``t`` itself to keep its gradient path.
        """
        if t is None:
            return None
        t = t.contiguous()
        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)
        all_tensors[self.process_rank] = t
        return torch.cat(all_tensors, dim=0)

    @classmethod
    def build(
        cls,
        model_args: ModelArguments,
        train_args: TrainingArguments,
        tokenizer=None,
        **hf_kwargs,  # config, cache_dir
    ):
        """Create a model for training from ``model_args.model_name_or_path``.

        A local directory with ``untie_encoder`` loads ``query_model/`` and
        ``passage_model/`` when present (otherwise the directory itself for
        both). A non-directory name is loaded once and deep-copied for the
        passage side when ``untie_encoder`` is set; otherwise one model is
        shared. Also stores ``config`` (required in ``hf_kwargs``),
        ``softmax_temperature`` and ``model_args`` on the class.
        """
        cls.config = hf_kwargs["config"]
        cls.softmax_temperature = model_args.softmax_temperature
        cls.model_args = model_args
        model_path = model_args.model_name_or_path
        if os.path.isdir(model_path):  # load local
            if model_args.untie_encoder:
                _qry_model_path = os.path.join(model_path, "query_model")
                _psg_model_path = os.path.join(model_path, "passage_model")
                if not os.path.exists(_qry_model_path):
                    _qry_model_path = model_path
                    _psg_model_path = model_path
                logger.info(f"loading query model weight from {_qry_model_path}")
                lm_q = cls.TRANSFORMER_CLS.from_pretrained(_qry_model_path, **hf_kwargs)
                logger.info(f"loading passage model weight from {_psg_model_path}")
                lm_p = cls.TRANSFORMER_CLS.from_pretrained(_psg_model_path, **hf_kwargs)
            else:
                lm_q = cls.TRANSFORMER_CLS.from_pretrained(model_path, **hf_kwargs)
                lm_p = lm_q
        else:  # load pre-trained
            lm_q = cls.TRANSFORMER_CLS.from_pretrained(model_path, **hf_kwargs)
            lm_p = copy.deepcopy(lm_q) if model_args.untie_encoder else lm_q
        return cls(
            lm_q=lm_q,
            lm_p=lm_p,
            negatives_x_device=train_args.negatives_x_device,
            untie_encoder=model_args.untie_encoder,
            tokenizer=tokenizer,
        )

    @classmethod
    def load(
        cls,
        model_args: ModelArguments,
        tokenizer: Optional[AutoTokenizer] = None,
        **hf_kwargs,  # config, cache_dir
    ):
        """Load a trained model (e.g. for inference).

        Separate ``query_model/`` and ``passage_model/`` subdirectories are
        loaded as untied encoders; otherwise a single tied model is loaded.
        Also stores ``config``, ``softmax_temperature`` and ``model_args``
        on the class.
        """
        cls.config = hf_kwargs["config"]
        cls.softmax_temperature = model_args.softmax_temperature
        cls.model_args = model_args
        model_path = model_args.model_name_or_path
        _qry_model_path = os.path.join(model_path, "query_model")
        _psg_model_path = os.path.join(model_path, "passage_model")
        # Separate weights mean the encoders are *untied*; the flag was
        # previously inverted, which made save() drop the passage model.
        untie_encoder = os.path.isdir(model_path) and os.path.exists(_qry_model_path)
        if untie_encoder:
            logger.info("found separate weight for query/passage encoders")
            logger.info(f"loading query model weight from {_qry_model_path}")
            lm_q = cls.TRANSFORMER_CLS.from_pretrained(_qry_model_path, **hf_kwargs)
            logger.info(f"loading passage model weight from {_psg_model_path}")
            lm_p = cls.TRANSFORMER_CLS.from_pretrained(_psg_model_path, **hf_kwargs)
        else:
            logger.info("try loading tied weight")
            logger.info(f"loading model weight from {model_path}")
            lm_q = cls.TRANSFORMER_CLS.from_pretrained(model_path, **hf_kwargs)
            lm_p = lm_q
        # Forward the tokenizer: special-token masking and build_tree need it.
        return cls(
            lm_q=lm_q, lm_p=lm_p, untie_encoder=untie_encoder, tokenizer=tokenizer
        )

    def save(self, output_dir: str):
        """Save weights: ``query_model/`` + ``passage_model/`` when untied, else one model."""
        if self.untie_encoder:
            for sub_dir, lm in (("query_model", self.lm_q), ("passage_model", self.lm_p)):
                path = os.path.join(output_dir, sub_dir)
                # exist_ok: re-saving into an existing checkpoint must not fail
                os.makedirs(path, exist_ok=True)
                lm.save_pretrained(path)
        else:
            self.lm_q.save_pretrained(output_dir)