| | """ |
| | This examples show how to train a Cross-Encoder for the MS Marco dataset (https://github.com/microsoft/MSMARCO-Passage-Ranking). |
| | |
| | The query and the passage are passed simoultanously to a Transformer network. The network then returns |
| | a score between 0 and 1 how relevant the passage is for a given query. |
| | |
| | The resulting Cross-Encoder can then be used for passage re-ranking: You retrieve for example 100 passages |
| | for a given query, for example with ElasticSearch, and pass the query+retrieved_passage to the CrossEncoder |
| | for scoring. You sort the results then according to the output of the CrossEncoder. |
| | |
| | This gives a significant boost compared to out-of-the-box ElasticSearch / BM25 ranking. |
| | |
| | Running this script: |
| | python train_cross-encoder.py |
| | """ |
# Standard library
import gzip
import logging
import os
import tarfile
from datetime import datetime

# Third-party
import tqdm
from torch.utils.data import DataLoader

from sentence_transformers import InputExample, LoggingHandler, util
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
| |
|
| | |
# Log INFO-level messages with timestamps; LoggingHandler is the
# sentence-transformers-provided handler (presumably so log lines render
# cleanly alongside progress bars — see library docs).
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
| | |
| |
|
| |
|
| | |
# --- Training configuration ---
model_name = 'distilroberta-base'
train_batch_size = 32
num_epochs = 1
# Unique output directory: model name (slashes sanitized) + timestamp.
model_save_path = 'output/training_ms-marco_cross-encoder-'+model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# For every positive (query, passage) training pair, this many negative
# pairs are generated.
# NOTE(review): name keeps the original misspelling ("ration" for "ratio")
# because later code in this file references it.
pos_neg_ration = 4

# Upper bound on the number of training pairs read from the triplets file.
# Was `2e7` (a float); an int is the proper type for a count and compares
# identically against the integer counter.
max_train_samples = 20_000_000
| |
|
| | |
# Cross-encoder that outputs a single relevance score per (query, passage)
# pair (num_labels=1); inputs longer than 512 tokens are truncated.
model = CrossEncoder(model_name, num_labels=1, max_length=512)

# All MS MARCO files are downloaded into / read from this folder.
data_folder = 'msmarco-data'
os.makedirs(data_folder, exist_ok=True)
| |
|
| |
|
| | |
# Load the MS MARCO passage collection as a pid -> passage-text mapping.
# On first run, collection.tar.gz is downloaded and extracted into data_folder.
corpus = {}
collection_filepath = os.path.join(data_folder, 'collection.tsv')
if not os.path.exists(collection_filepath):
    tar_filepath = os.path.join(data_folder, 'collection.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download collection.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

# Each line is "<pid>\t<passage>".
with open(collection_filepath, 'r', encoding='utf8') as collection_file:
    for row in collection_file:
        pid, passage_text = row.strip().split("\t")
        corpus[pid] = passage_text
| |
|
| |
|
| | |
# Load the training queries as a qid -> query-text mapping.
# On first run, queries.tar.gz is downloaded and extracted into data_folder.
queries = {}
queries_filepath = os.path.join(data_folder, 'queries.train.tsv')
if not os.path.exists(queries_filepath):
    tar_filepath = os.path.join(data_folder, 'queries.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download queries.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

# Each line is "<qid>\t<query>".
with open(queries_filepath, 'r', encoding='utf8') as query_file:
    for row in query_file:
        qid, query_text = row.strip().split("\t")
        queries[qid] = query_text
| |
|
| |
|
| |
|
| | |
# Containers for the training pairs and the held-out dev set.
train_samples = []
dev_samples = {}

# Dev set size: the first 200 distinct qids encountered in the train-eval
# triplets, keeping at most 200 negative passages per query (positives are
# not capped).
num_dev_queries = 200
num_max_dev_negatives = 200

# The train-eval split holds (qid, positive pid, negative pid) triplets
# reserved for evaluation; these qids are excluded from training later.
train_eval_filepath = os.path.join(data_folder, 'msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz')
if not os.path.exists(train_eval_filepath):
    logging.info("Download "+os.path.basename(train_eval_filepath))
    util.http_get('https://sbert.net/datasets/msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz', train_eval_filepath)

with gzip.open(train_eval_filepath, 'rt') as triplet_file:
    for row in triplet_file:
        qid, pos_pid, neg_pid = row.strip().split()

        if qid not in dev_samples:
            # Once the dev set is full, triplets for unseen qids are ignored.
            if len(dev_samples) >= num_dev_queries:
                continue
            dev_samples[qid] = {'query': queries[qid], 'positive': set(), 'negative': set()}

        entry = dev_samples[qid]
        entry['positive'].add(corpus[pos_pid])
        if len(entry['negative']) < num_max_dev_negatives:
            entry['negative'].add(corpus[neg_pid])
| |
|
| |
|
| | |
# Build the training pairs from the shuffled train triplets. Rows whose qid
# landed in the dev set are skipped; roughly one positive pair is emitted for
# every `pos_neg_ration` negative pairs.
train_filepath = os.path.join(data_folder, 'msmarco-qidpidtriples.rnd-shuf.train.tsv.gz')
if not os.path.exists(train_filepath):
    logging.info("Download "+os.path.basename(train_filepath))
    util.http_get('https://sbert.net/datasets/msmarco-qidpidtriples.rnd-shuf.train.tsv.gz', train_filepath)

cnt = 0
with gzip.open(train_filepath, 'rt') as triplet_file:
    for row in tqdm.tqdm(triplet_file, unit_scale=True):
        qid, pos_pid, neg_pid = row.strip().split()

        # Dev queries must never appear in the training data.
        if qid in dev_samples:
            continue

        query = queries[qid]
        take_positive = (cnt % (pos_neg_ration + 1)) == 0
        passage = corpus[pos_pid] if take_positive else corpus[neg_pid]
        label = 1 if take_positive else 0

        train_samples.append(InputExample(texts=[query, passage], label=label))
        cnt += 1

        if cnt >= max_train_samples:
            break
| |
|
| | |
# Wrap the InputExample list in a shuffled DataLoader; CrossEncoder.fit
# handles batching/collation internally.
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)

# Re-ranking evaluator over the held-out dev queries, run periodically
# during training.
evaluator = CERerankingEvaluator(dev_samples, name='train-eval')

# Number of learning-rate warmup steps passed to fit() below.
warmup_steps = 5000
logging.info("Warmup-steps: {}".format(warmup_steps))
| |
|
| |
|
| | |
# Train the cross-encoder. The evaluator runs every 10k steps and the best
# checkpoint is written to model_save_path; use_amp enables automatic mixed
# precision (presumably requires a CUDA GPU — see library docs).
model.fit(train_dataloader=train_dataloader,
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=10000,
          warmup_steps=warmup_steps,
          output_path=model_save_path,
          use_amp=True)

# Also persist the final (last-step) model, independent of dev performance.
model.save(model_save_path+'-latest')