| | |
| | |
| |
|
| | import os |
| | import sys |
| | import argparse |
| | from tqdm import trange |
| |
|
| | import torch |
| | import torch.optim |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from torch.autograd import Variable |
| |
|
| | lab_root = os.path.join(os.path.abspath(os.path.dirname(__file__)), '..', '..') |
| | sys.path.insert(1, lab_root) |
| | from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer |
| |
|
| | from IPython import embed |
| |
|
def top_k_logits(logits, k, probs=False):
    """Keep only the k largest entries of each row of `logits`.

    Entries strictly below a row's k-th largest value are replaced with a
    very large negative number (-1e10) so they vanish under a subsequent
    softmax, or with 0.0 when `probs` is True (i.e. `logits` already holds
    probabilities rather than raw scores).

    Args:
        logits: 2-D tensor of shape (batch, vocab).
        k: how many entries to keep per row; 0 disables filtering entirely.
        probs: if True, fill masked entries with 0.0 instead of -1e10.

    Returns:
        A tensor with the same shape as `logits`.
    """
    if k == 0:
        return logits
    # Per-row threshold: the k-th largest value, broadcast across the row.
    kth_largest = torch.topk(logits, k)[0][:, -1]
    threshold = kth_largest.view(-1, 1).expand_as(logits)
    fill_value = 0.0 if probs else -1e10
    filler = torch.full_like(logits, fill_value)
    return torch.where(logits < threshold, filler, logits)
| |
|
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1,
                    top_k=0, device='cuda', sample=True, return_past=False):
    """Autoregressively generate `length` tokens from `model`.

    Exactly one of `start_token` / `context` must be provided: `context` is
    a token-id sequence replicated across the batch, while `start_token`
    seeds every batch row with a single token.

    Args:
        model: language model returning (logits, past) given token ids.
        length: number of tokens to generate.
        start_token: single seed token id (mutually exclusive with context).
        batch_size: number of sequences to generate in parallel.
        context: seed token-id sequence (mutually exclusive with start_token).
        temperature: logit scaling before softmax.
        top_k: restrict sampling to the k most likely tokens (0 = no limit).
        device: torch device for the generated tensors.
        sample: draw from the distribution if True, else greedy argmax.
        return_past: also return the model's cached `past` state.

    Returns:
        Tensor of shape (batch_size, seed_len + length), or (output, past)
        when `return_past` is True.
    """
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
        seed = torch.tensor(context, device=device, dtype=torch.long)
        seed = seed.unsqueeze(0).repeat(batch_size, 1)
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        seed = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)

    generated = seed
    step_input = seed
    past = None
    with torch.no_grad():
        for _ in trange(length, ascii=True):
            # Only the newest token is fed; `past` caches earlier attention state.
            logits, past = model(step_input, past=past)
            next_logits = top_k_logits(logits[:, -1, :] / temperature, k=top_k)
            probs = F.softmax(next_logits, dim=-1)
            if sample:
                step_input = torch.multinomial(probs, num_samples=1)
            else:
                step_input = torch.topk(probs, k=1, dim=-1)[1]
            generated = torch.cat((generated, step_input), dim=1)

    if return_past:
        return generated, past
    return generated
| |
|
| |
|
def sample_from_hidden(model, length, hidden, context=None, past=None, temperature=1,
                       top_k=0, device='cuda', sample=True, noise_level=1e-1):
    """Generate tokens by decoding from (noised) hidden states.

    Each step projects `hidden` to logits via `model.forward_hidden`,
    samples a token, re-runs the model on that token to refresh the
    hidden state (read back from `model.hidden_states`), and perturbs it
    with `modify_hidden` before the next step.

    NOTE(review): relies on `model` exposing `forward_hidden` and a
    `hidden_states` attribute set as a side effect of the forward pass —
    confirm against the patched GPT2LMHeadModel in this repo.

    Args:
        model: language model with `forward_hidden` and `hidden_states`.
        length: number of tokens to generate.
        hidden: initial hidden-state tensor to decode from.
        context: optional seed token ids prepended to the output.
        past: unused here; the cache is rebuilt on the first iteration.
        temperature: logit scaling before softmax.
        top_k: restrict sampling to the k most likely tokens (0 = no limit).
        device: torch device for the context tensor.
        sample: draw from the distribution if True, else greedy argmax.
        noise_level: scale of the per-step hidden-state perturbation.

    Returns:
        Tensor of generated token ids, shape (1, [seed_len +] length).
    """
    # `if context` — empty-list context is treated the same as None.
    output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0) if context else None
    with torch.no_grad():
        for i in trange(length, ascii=True):
            logits = model.forward_hidden(hidden)
            logits = logits[:, -1, :] / temperature
            logits = top_k_logits(logits, k=top_k)
            # Name is historical: these are probabilities, not log-probs.
            log_probs = F.softmax(logits, dim=-1)
            if sample:
                prev = torch.multinomial(log_probs, num_samples=1)
            else:
                _, prev = torch.topk(log_probs, k=1, dim=-1)
            output = prev if output is None else torch.cat((output, prev), dim=1)
            if i == 0:
                # First step: run the whole sequence to build the `past` cache.
                _, past = model(output, past=None)
            else:
                _, past = model(prev, past=past)
            # Hidden state is read back as a forward-pass side effect.
            hidden = model.hidden_states
            # Perturb before the next decode step (source indentation was
            # mangled; this placement inside the loop matches the
            # per-step "playground" intent — TODO confirm).
            hidden = modify_hidden(hidden, noise_level)
    return output
| |
|
def modify_hidden(input_tensor, noise_level=1e-1):
    """Perturb a hidden-state tensor with random noise.

    A single uniform noise vector of length `input_tensor.shape[-1]` is
    drawn and broadcast across all leading dimensions. Note the noise is
    non-negative (uniform in [0, noise_level)), not zero-mean.

    Args:
        input_tensor: tensor whose last dimension is the hidden size.
        noise_level: scale factor applied to the uniform noise.

    Returns:
        A new tensor of the same shape with the noise added.
    """
    hidden_size = input_tensor.shape[-1]
    # Fix: allocate the noise on the input's own device instead of the
    # hard-coded .cuda(), so CPU runs (e.g. --nocuda) work too.
    noise = torch.rand(hidden_size, device=input_tensor.device) * noise_level
    return input_tensor + noise
| |
|
def compute_log_likelihood(model, phrase, tokenizer, device):
    """Print and return the total log-likelihood of `phrase` under `model`.

    Encodes the phrase, runs one forward pass, and reads off the
    probability the model assigned to each actual next token. Per-token
    log-likelihoods and their sum are printed.

    Args:
        model: language model returning (logits, past) for token ids.
        phrase: raw text to score.
        tokenizer: encoder/decoder for token ids.
        device: torch device for the input tensor.

    Returns:
        The summed log-likelihood as a numpy scalar.
    """
    token_ids = tokenizer.encode(phrase)
    print("Computing LL of phrase \"{}\"".format(phrase))
    print("After encoding, number of tokens {}".format(len(token_ids)))

    inputs = torch.tensor(token_ids, device=device, dtype=torch.long).unsqueeze(0)
    with torch.no_grad():
        logits, _ = model(inputs, past=None)

    # Position t predicts token t+1: drop the last prediction and the
    # first token so predictions and targets line up.
    targets = token_ids[1:]
    positions = range(len(targets))
    probs = F.softmax(logits[0, :-1], dim=-1)
    target_probs = probs[positions, targets]
    assert target_probs.dim() == 1

    target_lls = torch.log(target_probs)
    for tok, ll in zip(targets, target_lls):
        print("LL of token {} (\'{}\') ==> {:.4f}".format(tok, tokenizer.decode([tok]), ll))

    total_ll = np.sum([ll.item() for ll in target_lls])
    print("LL of the phrase (sum of the above): {}".format(total_ll))
    return total_ll
| |
|
| |
|
def get_embedding_grad(model, enc, context=None, target=40, device='cuda', ll_only=False, opt_embed=False):
    """Optimize input embeddings to raise the likelihood of `target` token(s).

    Runs a forward pass on `context`, reports the (negative) log-likelihood
    of `target` versus the model's top-1 prediction, and — when `opt_embed`
    is set — gradient-ascends on the model's cached input embeddings, then
    snaps each optimized embedding back to its nearest vocabulary embedding
    (L1 distance) and returns the resulting token ids.

    NOTE(review): despite the default `target=40` (an int), the code calls
    `len(target)` and iterates it, so callers must pass a sequence of
    token ids — confirm and fix the default.
    NOTE(review): relies on patched model attributes `transformer.i_embeds`
    and `forward_embed` — not part of stock pytorch_pretrained_bert.

    Args:
        model: language model (patched, see note above).
        enc: tokenizer used only for decoding printed token ids.
        context: token ids of the input text (required).
        target: sequence of target token ids to make likely.
        device: torch device for the context tensor.
        ll_only: if True, only print likelihoods and return None.
        opt_embed: if True, run the embedding-optimization loop.

    Returns:
        Tensor of token ids (original context concatenated with the
        snapped ids) when `opt_embed` is True; otherwise None.
    """
    assert context is not None, 'Input text is needed'
    # NOTE(review): token ids are cast to float here; an LM input usually
    # needs dtype long — presumably the patched model embeds floats or
    # converts internally. Verify against model.forward.
    context = torch.tensor(context, device=device, dtype=torch.float).unsqueeze(0)
    model.zero_grad()
    logits, past = model(context, past=None)

    # Score only the final position's next-token distribution.
    logits = logits[:, -1, :]
    # Name is historical: softmax yields probabilities, not log-probs.
    log_probs = F.softmax(logits, dim=-1)
    if len(target) > 1:
        nll = sum([-torch.log(log_probs[:, tar]) for tar in target])
    else:
        nll = - torch.log(log_probs[:, target])

    with torch.no_grad():
        log_probs = F.softmax(logits, dim=-1)
        top1, top1ind = torch.topk(log_probs, k=1, dim=-1)
        print('LL of target : {}'.format(-nll.data.squeeze().cpu().numpy()))
        print('LL of top 1 : {}'.format(torch.log(top1).data.squeeze().cpu().numpy()))

    if ll_only:
        return

    if opt_embed:
        # Keep a copy so we can report "orig -> new" per position later.
        orig_embed = model.transformer.i_embeds.clone()
        embed_vars = Variable(model.transformer.i_embeds, requires_grad=True)
        optimizer = torch.optim.Adam([embed_vars], lr=0.01)
        optimizer.zero_grad()
        for ss in range(50):
            # NOTE(review): backward() is called on the nll from the
            # previous iteration's graph, and zero_grad() runs only once
            # before the loop, so gradients accumulate across steps —
            # likely unintended; confirm before relying on results.
            nll.backward()
            optimizer.step()
            logits, past = model.forward_embed(embed_vars, past=None)
            logits = logits[:, -1, :]
            log_probs = F.softmax(logits, dim=-1)
            if len(target) > 1:
                nll = sum([-torch.log(log_probs[:, tar]) for tar in target])
            else:
                nll = - torch.log(log_probs[:, target])
            print('LL of target (step {}): {}'.format(ss, -nll.data.squeeze().cpu().numpy()))

        # Snap each optimized embedding to the closest row of the token
        # embedding matrix (L1 nearest neighbour over the vocabulary).
        output_ids = torch.empty_like(context.long())
        with torch.no_grad():
            all_embeds = model.transformer.wte.weight
            embed_vars_unbind = torch.unbind(embed_vars, dim=1)
            orig_embed_unbind = torch.unbind(orig_embed, dim=1)
            cc = 0
            for ie_new, ie_orig, orig_id in zip(embed_vars_unbind, orig_embed_unbind, context.squeeze(0)):
                new_id = (all_embeds - ie_new).abs().sum(1).argmin()
                print('emb {}: {} (`{}`) to {} (`{}`)'.format(cc, orig_id.tolist(), enc.decode([orig_id.tolist()]),
                                                              new_id.tolist(), enc.decode([new_id.tolist()])))
                output_ids[0, cc] = new_id
                cc += 1
        output_ids = torch.cat((context.long(), output_ids), dim=1)
        return output_ids

    # NOTE(review): the source had a long run of blank (mangled) lines
    # before this trailing return; code appears to have been removed.
    # As written, `output_ids` is undefined when opt_embed is False, so
    # reaching this raises NameError — confirm intended behavior.
    return output_ids
| |
|
def run_model():
    """CLI entry point: load a GPT-2 checkpoint and generate samples.

    Parses command-line flags, seeds RNGs, loads the tokenizer and model,
    then generates either conditional samples (from --cond-text or an
    interactive prompt) or unconditional samples (seeded with
    <|endoftext|>), printing each and optionally writing conditional
    samples to --output.

    NOTE(review): --opt_ll, --get_ll, --hidden_playground and
    --noise_level are parsed but never read in this visible code —
    presumably consumed by removed/other code paths.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', '-M', type=str, default='gpt-2_pt_models/774M/',
                        help='pretrained model name or path to local checkpoint')
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--nsamples", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=-1)
    parser.add_argument("--length", type=int, default=-1)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.')
    parser.add_argument('--nocuda', action='store_true', help='no cuda')
    parser.add_argument('--opt_ll', action='store_true', help='nll optimize')
    parser.add_argument('--get_ll', action='store_true', help='compute log likelihood of sentence')
    parser.add_argument('--hidden_playground', action='store_true', help='play around in the hidden representation')
    parser.add_argument("--noise_level", type=float, default=1e-1)
    parser.add_argument("--cond-text", type=str, default='', help='Prefix texts to condition on')
    parser.add_argument('--output', type=str, default=os.environ.get('GIT_RESULTS_MANAGER_DIR', None), help='output directory')
    args = parser.parse_args()
    print(args)

    if args.batch_size == -1:
        args.batch_size = 1
    # nsamples must divide evenly into batches.
    assert args.nsamples % args.batch_size == 0

    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.nocuda:
        device = torch.device("cpu")

    print('device is {}'.format(device))

    enc = GPT2Tokenizer.from_pretrained(args.model_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model.to(device)
    model.eval()

    # Default generation length: half the model's context window.
    if args.length == -1:
        args.length = model.config.n_ctx // 2
    elif args.length > model.config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)

    generated = 0
    # Up to 10 interactive rounds in the conditional case; the
    # unconditional case breaks out after one pass (see bottom).
    for _ in range(10):
        context_tokens = []
        if not args.unconditional:
            raw_text = args.cond_text
            while not raw_text:
                print('Prompt should not be empty!')
                raw_text = input("Model prompt >>> ")
            context_tokens = enc.encode(raw_text)
            for _ in range(args.nsamples // args.batch_size):
                out = sample_sequence(
                    model=model, length=args.length,
                    context=context_tokens,
                    start_token=None,
                    batch_size=args.batch_size,
                    temperature=args.temperature, top_k=args.top_k, device=device
                )
                # [:, 0:] keeps the prompt in the decoded output (no-op slice).
                out = out[:, 0:].tolist()
                for i in range(args.batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)
                    if args.output:
                        filepath = os.path.join(args.output, "generated_{}.txt".format(generated))
                        with open(filepath, "w") as f:
                            f.write(text)

        if args.unconditional:
            generated = 0
            for _ in range(args.nsamples // args.batch_size):
                out = sample_sequence(
                    model=model, length=args.length,
                    context=None,
                    start_token=enc.encoder['<|endoftext|>'],
                    batch_size=args.batch_size,
                    temperature=args.temperature, top_k=args.top_k, device=device
                )
                # Drop the <|endoftext|> seed token before decoding.
                out = out[:, 1:].tolist()
                for i in range(args.batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)

        if args.unconditional:
            break
| |
|
| |
|
# Script entry point: run the CLI when executed directly.
if __name__ == '__main__':
    run_model()
| |
|
| |
|
| |
|