import os
import time
import copy
import logging

import torch
import torch.nn.functional as F
import torch.distributed as dist
import pandas as pd

from collections import defaultdict
from dataclasses import dataclass
from torch import nn, Tensor
from transformers import PreTrainedModel, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.file_utils import ModelOutput
from typing import Dict, Optional

from tevatron.tree import TreeBuilder
from tevatron.arguments import (
    GLENP2ModelArguments as ModelArguments,
    GLENP2TrainingArguments as TrainingArguments,
)

logger = logging.getLogger(__name__)


@dataclass
class EncoderDecoderOutputForSeq2SeqLM(ModelOutput):
    q_reps: Optional[Tensor] = None
    p_reps: Optional[Tensor] = None
    loss: Optional[Tensor] = None
    scores: Optional[Tensor] = None


class EncoderDecoderModelForSeq2SeqLM(nn.Module):
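    """Encoder-decoder retrieval model wrapping a query LM and a passage LM (GLEN P2).

    `encode_query`, `encode_passage`, and `compute_similarity` are abstract and must be
    implemented by subclasses. `build`/`load` attach `config`, `model_args`, and
    `softmax_temperature` as class attributes, and `self.trainer` is expected to be
    attached externally by the training code before `forward` is called with passages.
    """
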
    TRANSFORMER_CLS = AutoModelForSeq2SeqLM

    def __init__(
        self,
        lm_q: PreTrainedModel,
        lm_p: PreTrainedModel,
        untie_encoder: bool = False,
        negatives_x_device: bool = False,
        tokenizer=None,
    ):
        super().__init__()
        self.lm_q = lm_q
        self.lm_p = lm_p
        self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
        self.negatives_x_device = negatives_x_device
        self.untie_encoder = untie_encoder
        if self.negatives_x_device:
            if not dist.is_initialized():
                raise ValueError(
                    "Distributed training has not been initialized for representation all gather."
                )
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()

        self.tokenizer = tokenizer
        self.gen_r1_list = []

    def forward(
        self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None
    ):
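        """Encode query/passage batches and, in training mode, compute the losses.

        Returns an EncoderDecoderOutputForSeq2SeqLM with the contrastive loss plus
        optional query-to-docid and cosine losses, similarity scores, and the
        query/passage representations.
        """
        # Build one-step decoder inputs (just the decoder start token) so each
        # encoder-decoder runs a single decoding step per example.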
        decoder_start_token_id = self.config.decoder_start_token_id
        if query is not None:
            query["decoder_input_ids"] = (
                torch.zeros(query["input_ids"].shape[0], 1, dtype=torch.long)
                .fill_(decoder_start_token_id)
                .to(query["input_ids"].device)
            )
            query["decoder_attention_mask"] = torch.ones(
                query["input_ids"].shape[0], 1, dtype=torch.long
            ).to(query["input_ids"].device)
        if passage is not None:
            passage["decoder_input_ids"] = (
                torch.zeros(passage["input_ids"].shape[0], 1, dtype=torch.long)
                .fill_(decoder_start_token_id)
                .to(passage["input_ids"].device)
            )
            passage["decoder_attention_mask"] = torch.ones(
                passage["input_ids"].shape[0], 1, dtype=torch.long
            ).to(passage["input_ids"].device)

        q_reps = self.encode_query(query)

        if passage is None:
            p_reps, p_reps_dt = None, None
        else:
            p_reps, p_attention, p_reps_dt = self.encode_passage(passage)

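        # For "self" negatives, record on the training dataset which original doc ids
        # share each predicted docid prefix.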
        if (
            self.trainer.data_args.negative_passage_type == "self"
            and p_reps is not None
        ):
            predid2oldid = self.trainer.train_dataset.predid2oldid
            oldid2predid = self.trainer.train_dataset.oldid2predid
            pred_docids = p_attention.argmax(dim=-1).cpu().tolist()
            for oldid, predid in zip(passage["oldid"].cpu().tolist(), pred_docids):
                for j in range(1, len(predid) + 1):
                    cur_predid = "-".join([str(x) for x in predid[:j]])
                    predid2oldid[cur_predid].add(oldid)
                    oldid2predid[oldid] = cur_predid

        if q_reps is None or p_reps is None:
            return EncoderDecoderOutputForSeq2SeqLM(q_reps=q_reps, p_reps=p_reps_dt)

        loss, lm_loss = 0, 0
        gen_r1 = 0
        if self.training:
            if self.negatives_x_device:
                q_reps = self._dist_gather_tensor(q_reps)
                p_reps = self._dist_gather_tensor(p_reps)

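            # Contrastive (InfoNCE-style) loss over in-batch passages; each query's
            # positive passage sits at index i * (passages per query).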
            scores = self.compute_similarity(q_reps, p_reps)
            scores = scores.view(q_reps.size(0), -1)
            scores = scores / self.softmax_temperature

            target = torch.arange(
                scores.size(0), device=scores.device, dtype=torch.long
            )
            target = target * (p_reps.size(0) // q_reps.size(0))

            loss = self.compute_loss(scores, target)

            if "infonce_loss" in self.model_args.__dict__:
                loss = loss * self.model_args.infonce_loss

            if self.negatives_x_device:
                loss = loss * self.world_size

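            # Project query representations onto the shared token-embedding table to
            # get per-position vocabulary logits for docid token prediction.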
            lm_logits = q_reps @ self.lm_q.shared.weight.T

            if self.model_args.mask_special_tokens_for_decoding:
                special_token_ids = self.tokenizer.all_special_ids
                special_token_ids = [
                    x
                    for x in special_token_ids
                    if x
                    not in [
                        self.tokenizer.bos_token_id,
                        self.tokenizer.eos_token_id,
                        self.tokenizer.pad_token_id,
                    ]
                ]
                lm_logits[:, :, special_token_ids] = -1e9

            lm_logits = lm_logits / self.softmax_temperature

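            # Select each query's positive passage and take the argmax of its attention
            # distribution as the docid-generation targets.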
            pos_doc = torch.arange(q_reps.size(0), dtype=torch.long)
            pos_doc = pos_doc * (p_reps.size(0) // q_reps.size(0))
            pos_p_attention = p_attention[pos_doc, :, :]

            lm_targets = pos_p_attention.argmax(dim=-1)
            if (
                "q_to_docid_loss" in self.model_args.__dict__
                and self.model_args.q_to_docid_loss > 0
            ):
                lm_loss = self.cross_entropy(
                    lm_logits.view(-1, lm_logits.size(-1)), lm_targets.view(-1)
                )
                loss += lm_loss * self.model_args.q_to_docid_loss

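            # Optional cosine loss aligning the query's logits at the positive
            # document's top docid tokens with that document's own token weights.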
            if (
                "cosine_point_loss" in self.model_args.__dict__
                and self.model_args.cosine_point_loss > 0
            ):
                pos_doc_lm_logits = p_reps_dt[pos_doc] @ self.lm_p.shared.weight.T
                query_lm_logits = lm_logits

                pos_doc_id_weight, pos_doc_id = pos_doc_lm_logits.max(dim=-1)
                query_id_weight = query_lm_logits.gather(
                    dim=-1, index=pos_doc_id.unsqueeze(-1)
                ).squeeze(-1)

                cosine_sim = F.cosine_similarity(
                    pos_doc_id_weight, query_id_weight, dim=-1
                )
                cosine_loss = torch.mean(1 - cosine_sim)

                loss += cosine_loss * self.model_args.cosine_point_loss

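            # Track exact-match docid generation accuracy (all positions correct)
            # for logging.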
            query_preds = lm_logits.argmax(dim=-1)
            gen_r1 = (query_preds == lm_targets).float().mean(dim=1) == 1
            self.gen_r1_list += gen_r1.cpu().tolist()

        else:
            scores = self.compute_similarity(q_reps, p_reps)
            loss = None
        return EncoderDecoderOutputForSeq2SeqLM(
            loss=loss, scores=scores, q_reps=q_reps, p_reps=p_reps_dt
        )

    def encode_passage(self, psg):
        raise NotImplementedError("EncoderDecoderModel is an abstract class")

    def encode_query(self, qry):
        raise NotImplementedError("EncoderDecoderModel is an abstract class")

    def compute_similarity(self, q_reps, p_reps):
        raise NotImplementedError("EncoderDecoderModel is an abstract class")

    def compute_loss(self, scores, target):
        return self.cross_entropy(scores, target)

    def build_tree(self, log_step: int = None, log_file: str = None):
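        """Read the generated docid file, log collision statistics, and (optionally)
        build a prefix tree over tokenized docids for constrained decoding.

        Also stores the docid/oldid lookup dictionaries on the model instance.
        """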
        docid_df = pd.read_csv(
            self.model_args.docid_file_name,
            sep="\t",
            names=["oldid", "docid", "docid_logit", "text"],
            dtype={"oldid": str, "docid": str, "docid_logit": str, "text": str},
        ).loc[:, ["oldid", "docid", "docid_logit"]]

        lines = []

        time_str = time.strftime("%Y%m%d-%H%M%S")
        lines.append(f"TIME={time_str}")
        lines.append(f"STEP={log_step}")

        num_uniques = len(set(docid_df.docid))
        unique_ratio = num_uniques / len(docid_df)
        lines.append(
            f"num_uniques: {num_uniques}/{len(docid_df)} ({unique_ratio*100:.2f}%)"
        )
        lines.append("[Frequent Collision]")
        lines.append(docid_df.docid.value_counts()[:5].to_string())
        num_unique_tokens = [0] * 10
        docid2num_docs = dict(docid_df.docid.value_counts())
        for docid in docid_df.docid.unique():
            tokens = docid.split("<->")
            num_unique_token = len(set(tokens))
            num_unique_tokens[num_unique_token - 1] += 1
        lines.append(f"distribution of number of unique tokens: {num_unique_tokens}")

        docid2oldids = defaultdict(list)
        oldid2docid_logit = dict()
        for docid, oldid, docid_logit in docid_df[
            ["docid", "oldid", "docid_logit"]
        ].values:
            docid2oldids[docid].append(oldid)
            oldid2docid_logit[oldid] = torch.tensor(
                [float(x) for x in docid_logit.split("<->")]
            )

        oldid2docid = dict(zip(docid_df.oldid, docid_df.docid))
        if self.model_args.tree == 1:
            builder = TreeBuilder()
            all_id = []
            for docid in list(oldid2docid.values()):
                toks = docid.split("<->")
                toks = self.tokenizer.convert_tokens_to_ids(toks)
                if len(toks) != self.model_args.max_output_length - 1:
                    print(
                        f"length of docid {docid} is not "
                        f"{self.model_args.max_output_length - 1}"
                    )
                    toks = toks[: self.model_args.max_output_length - 1]
                all_id.append(toks)
                builder.add(toks)
            self.root = builder.build()

        self.docid2num_docs = docid2num_docs
        self.docid2oldids = docid2oldids
        self.oldid2docid_logit = oldid2docid_logit
        self.oldid2docid = oldid2docid

        if log_file is not None:
            with open(log_file, "a") as f:
                f.write("\n".join(lines) + "\n")
        else:
            print("\n".join(lines))

    def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
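        """All-gather `t` across ranks, keeping the gradient-carrying local copy in
        this rank's slot of the concatenated result."""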
        if t is None:
            return None
        t = t.contiguous()

        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)

        all_tensors[self.process_rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)

        return all_tensors

    @classmethod
    def build(
        cls,
        model_args: ModelArguments,
        train_args: TrainingArguments,
        tokenizer=None,
        **hf_kwargs,
    ):
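        """Instantiate the model for training, loading (possibly untied) query and
        passage weights from `model_args.model_name_or_path`."""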
        cls.config = hf_kwargs["config"]
        cls.softmax_temperature = model_args.softmax_temperature
        cls.model_args = model_args

        if os.path.isdir(model_args.model_name_or_path):
            if model_args.untie_encoder:
                _qry_model_path = os.path.join(
                    model_args.model_name_or_path, "query_model"
                )
                _psg_model_path = os.path.join(
                    model_args.model_name_or_path, "passage_model"
                )
                if not os.path.exists(_qry_model_path):
                    _qry_model_path = model_args.model_name_or_path
                    _psg_model_path = model_args.model_name_or_path
                logger.info(f"loading query model weight from {_qry_model_path}")
                lm_q = cls.TRANSFORMER_CLS.from_pretrained(_qry_model_path, **hf_kwargs)
                logger.info(f"loading passage model weight from {_psg_model_path}")
                lm_p = cls.TRANSFORMER_CLS.from_pretrained(_psg_model_path, **hf_kwargs)
            else:
                lm_q = cls.TRANSFORMER_CLS.from_pretrained(
                    model_args.model_name_or_path, **hf_kwargs
                )
                lm_p = lm_q
        else:
            lm_q = cls.TRANSFORMER_CLS.from_pretrained(
                model_args.model_name_or_path, **hf_kwargs
            )
            lm_p = copy.deepcopy(lm_q) if model_args.untie_encoder else lm_q

        model = cls(
            lm_q=lm_q,
            lm_p=lm_p,
            negatives_x_device=train_args.negatives_x_device,
            untie_encoder=model_args.untie_encoder,
            tokenizer=tokenizer,
        )
        return model

    @classmethod
    def load(
        cls,
        model_args: ModelArguments,
        tokenizer: Optional[AutoTokenizer] = None,
        **hf_kwargs,
    ):
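        """Instantiate the model for inference, preferring separate query/passage
        weights under `query_model`/`passage_model` when they exist."""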
        cls.config = hf_kwargs["config"]
        cls.softmax_temperature = model_args.softmax_temperature
        cls.model_args = model_args

        untie_encoder = True
        if os.path.isdir(model_args.model_name_or_path):
            _qry_model_path = os.path.join(model_args.model_name_or_path, "query_model")
            _psg_model_path = os.path.join(
                model_args.model_name_or_path, "passage_model"
            )
            if os.path.exists(_qry_model_path):
                logger.info("found separate weight for query/passage encoders")
                logger.info(f"loading query model weight from {_qry_model_path}")
                lm_q = cls.TRANSFORMER_CLS.from_pretrained(_qry_model_path, **hf_kwargs)
                logger.info(f"loading passage model weight from {_psg_model_path}")
                lm_p = cls.TRANSFORMER_CLS.from_pretrained(_psg_model_path, **hf_kwargs)
                untie_encoder = False
            else:
                logger.info("try loading tied weight")
                logger.info(f"loading model weight from {model_args.model_name_or_path}")
                lm_q = cls.TRANSFORMER_CLS.from_pretrained(
                    model_args.model_name_or_path, **hf_kwargs
                )
                lm_p = lm_q
        else:
            logger.info("try loading tied weight")
            logger.info(f"loading model weight from {model_args.model_name_or_path}")
            lm_q = cls.TRANSFORMER_CLS.from_pretrained(
                model_args.model_name_or_path, **hf_kwargs
            )
            lm_p = lm_q

        model = cls(lm_q=lm_q, lm_p=lm_p, untie_encoder=untie_encoder)
        return model

    def save(self, output_dir: str):
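        """Save the encoder(s); untied models are written to `query_model` and
        `passage_model` subdirectories."""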
        if self.untie_encoder:
            os.makedirs(os.path.join(output_dir, "query_model"))
            os.makedirs(os.path.join(output_dir, "passage_model"))
            self.lm_q.save_pretrained(os.path.join(output_dir, "query_model"))
            self.lm_p.save_pretrained(os.path.join(output_dir, "passage_model"))
        else:
            self.lm_q.save_pretrained(output_dir)