| | |
| | |
| | """ |
| | preprocess input data into feature and stores binary as python shelve DB |
| | each chunk is gzipped JSON string |
| | """ |
| | import argparse |
| | import gzip |
| | import json |
| | import subprocess as sp |
| | import shelve |
| | import os |
| | from os.path import dirname, exists, join |
| |
|
| | import torch |
| | from lsp_model import GPT2Tokenizer |
| | from tqdm import tqdm |
| |
|
| | from env import END_OF_TEXT_TOKEN |
| | from gpt2_training.train_utils import InputFeatures_train as InputFeatures |
| |
|
| |
|
| | def _get_file_len(corpus): |
| | n_line = int(sp.check_output(f"wc -l {corpus}".split(), |
| | universal_newlines=True).split()[0]) |
| | return n_line |
| |
|
| |
|
| | def _norm_text(text): |
| | w, *toks = text.strip().split() |
| | try: |
| | w = float(w) |
| | except Exception: |
| | toks = [w] + toks |
| | w = 1.0 |
| | return w, ' '.join(toks) |
| |
|
| |
|
def _get_inputs_from_text(text, tokenizer):
    """Tokenize one tab-separated ``context<TAB>response`` line.

    The context side may contain several turns joined by ' EOS '; each
    turn (and the response) may start with a float weight (see
    _norm_text).  Returns parallel lists ``(weights, token_id_lists)``.
    A response with weight 0 is dropped entirely.
    """
    context_part, response_part = text.strip().split('\t')
    weights = []
    token_ids = []
    for turn in context_part.split(' EOS '):
        turn_weight, turn_text = _norm_text(turn)
        weights.append(turn_weight)
        token_ids.append(tokenizer.encode(turn_text))
    response_weight, response_text = _norm_text(response_part)
    if response_weight != 0:
        weights.append(response_weight)
        token_ids.append(tokenizer.encode(response_text))
    return weights, token_ids
| |
|
| |
|
def _make_features(id_, weights, inputs, tokenizer, max_len):
    """Pack consecutive dialogue turns into features of at most max_len tokens.

    Walks the turn list, accumulating turns (plus one EOS separator each)
    into a sliding window; whenever the window overflows, a feature is
    flushed via _make_feature and the window restarts from the last turn
    so adjacent features overlap by one turn.  A single turn longer than
    max_len flushes the current window and is itself discarded.
    Returns the list of InputFeatures (possibly empty).
    """
    end_of_text_id = tokenizer.encoder[END_OF_TEXT_TOKEN]
    features = []
    sents = []   # token-id lists accumulated in the current window
    ws = []      # per-turn weights matching sents
    len_ = 0     # token count of the window, incl. one separator per turn
    i = 0        # offset added to id_ so emitted features get distinct ids
    for ids, w in zip(inputs, weights):
        if len(ids) > max_len:
            # turn is unusably long: flush what we have (a feature needs at
            # least a context turn and a response turn) and drop this turn
            if len(sents) >= 2:
                feat = _make_feature(id_ + i, sents, ws, end_of_text_id)
                if feat is not None:
                    features.append(feat)
                    i += 1
            len_ = 0
            sents = []
            ws = []
            continue
        elif len_ > max_len:
            # window overflowed: flush it, then restart the window from the
            # most recent turn so consecutive features overlap by one turn
            feat = _make_feature(id_ + i, sents, ws, end_of_text_id)
            if feat is not None:
                features.append(feat)
                i += 1
            len_ = len(sents[-1]) + 1
            sents = sents[-1:]
            ws = ws[-1:]
        len_ += (len(ids) + 1)  # +1 accounts for the EOS separator
        sents.append(ids)
        ws.append(w)
    # flush the final partial window if it still holds a context + response
    if len(sents) >= 2:
        feat = _make_feature(id_ + i, sents, ws, end_of_text_id)
        if feat is not None:
            features.append(feat)

    return features
| |
|
| |
|
def _make_feature(id_, sents, ws, eos):
    """Assemble one training example from a window of turns.

    sents is a list of token-id lists (first entry is pure context), ws the
    matching per-turn weights.  Builds EOS-joined input_ids with aligned
    lm_labels (-1 where no loss is taken), per-token weights, and
    token_type_ids identifying the turn index.  Returns an InputFeatures,
    or None when nothing in the window is trainable.
    """
    # no turn after the context has a non-zero weight -> nothing to learn
    if all(w == 0 for w in ws[1:]):
        return None
    # join turns with EOS; the trailing separator is stripped
    input_ids = [tid for s in sents for tid in s + [eos]][:-1]
    lm_labels = []
    weights = []
    token_type_ids = []
    for i, (s, w) in enumerate(zip(sents, ws)):
        if i == 0:
            # first turn is pure context: never predicted, type id 0
            lm_labels += [-1] * len(s)
            weights += [0.0] * len(s)
            token_type_ids += [0] * len(s)
            continue
        token_type_ids += [i] * (len(s) + 1)
        if w == 0.0:
            # zero-weight turn: keep tokens as context but mask the loss
            lm_labels += [-1] * (len(s) + 1)
            weights += [0.0] * (len(s) + 1)
        else:
            lm_labels += (s + [eos])
            weights += [w] * (len(s) + 1)

    # trim the trailing run of -1 labels: it contributes no loss
    i = len(lm_labels) - 1
    while i >= 0 and lm_labels[i] == -1:
        i -= 1
    input_ids = input_ids[:i + 1]
    lm_labels = lm_labels[:i + 1]
    weights = weights[:i + 1]
    token_type_ids = token_type_ids[:i + 1]

    # everything got trimmed away -> no usable feature.  (This replaces a
    # leftover `import pdb; pdb.set_trace()` debugging hook; callers
    # already skip None results.)
    if not input_ids:
        return None

    # pad to a multiple of 8 (efficient tensor shapes); pads take no loss
    while len(input_ids) % 8 != 0:
        input_ids.append(0)
        token_type_ids.append(0)
        lm_labels.append(-1)
        weights.append(0.0)

    position_ids = list(range(len(input_ids)))
    assert (len(input_ids) == len(position_ids) == len(token_type_ids)
            == len(lm_labels) == len(weights))
    assert len(input_ids) % 8 == 0
    return InputFeatures(id_, input_ids, position_ids, token_type_ids,
                         lm_labels, weights)
| |
|
| |
|
def main(args):
    """Convert args.corpus (a .tsv dialogue file) into a shelve DB of
    gzipped-JSON feature chunks, plus meta.json and a saved tokenizer.

    The DB directory name encodes max_seq_len and the reverse/two_turn
    flags; refuses to overwrite an existing DB.
    """
    toker = GPT2Tokenizer.from_pretrained('gpt2')
    attrs = []
    if args.reverse:
        attrs.append('reverse')
    if args.two_turn:
        attrs.append('2turn')
    # e.g. corpus.tsv -> corpus.128len.reverse.db/db  (strips the ".tsv")
    if attrs:
        db_path = (f'{args.corpus[:-4]}.{args.max_seq_len}len.'
                   f'{".".join(attrs)}.db/db')
    else:
        db_path = f'{args.corpus[:-4]}.{args.max_seq_len}len.db/db'
    if exists(dirname(db_path)):
        raise ValueError('Found existing DB, please backup')
    else:
        os.makedirs(dirname(db_path))
    with open(args.corpus, "r", encoding="utf-8") as reader, \
            shelve.open(db_path, 'n') as db:
        chunk = []
        n_chunk = 0
        n_example = 0
        for line in tqdm(reader, total=_get_file_len(args.corpus)):
            try:
                # flush a full chunk before processing the next line
                if len(chunk) >= args.chunk_size:
                    db[f'chunk_{n_chunk}'] = gzip.compress(
                        json.dumps(chunk[:args.chunk_size]).encode('utf-8'))
                    chunk = chunk[args.chunk_size:]
                    n_chunk += 1

                weights, inputs = _get_inputs_from_text(line, toker)
                if args.reverse:
                    # swap src/tgt direction by reversing the turn order
                    weights = list(reversed(weights))
                    inputs = list(reversed(inputs))
                if args.two_turn:
                    weights = weights[:2]
                    inputs = inputs[:2]
                # need at least a context turn and a response turn
                if len(weights) < 2:
                    continue
                features = _make_features(n_example, weights, inputs,
                                          toker, args.max_seq_len)
                for feature in features:
                    # store plain dicts so chunks are JSON-serializable
                    chunk.append(vars(feature))
                    n_example += 1
            except Exception as e:
                # best-effort: log the bad line and keep going
                print('!!! prepro exception !!!', e)
                continue
        # save the final (possibly partial) chunk
        db[f'chunk_{n_chunk}'] = gzip.compress(
            json.dumps(chunk).encode('utf-8'))

    # record settings needed to reproduce / consume the DB
    meta = {'n_example': n_example,
            'chunk_size': args.chunk_size,
            'max_seq_len': args.max_seq_len,
            'reverse': args.reverse,
            'two_turn': args.two_turn}
    with open(join(dirname(db_path), 'meta.json'), 'w') as writer:
        json.dump(meta, writer, indent=4)
    torch.save(toker, join(dirname(db_path), 'tokenizer.pt'))
| |
|
| |
|
if __name__ == '__main__':
    # CLI specification as a flag -> options table
    _ARG_SPECS = [
        ('--corpus', dict(required=True,
                          help='file name of training corpus (should be .tsv)')),
        ('--chunk_size', dict(type=int, default=65536,
                              help='num of data examples in a storing chunk')),
        ('--max_seq_len', dict(type=int, default=128,
                               help='discard data longer than this')),
        ('--reverse', dict(action='store_true',
                           help='reverse the src tgt')),
        ('--two_turn', dict(action='store_true',
                            help='take only the first 2 turns')),
    ]
    parser = argparse.ArgumentParser()
    for flag, options in _ARG_SPECS:
        parser.add_argument(flag, **options)

    main(parser.parse_args())
| |
|