Instructions to use yuneun92/koCSN_SAPR with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use yuneun92/koCSN_SAPR with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("text-classification", model="yuneun92/koCSN_SAPR")

# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("yuneun92/koCSN_SAPR", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Author: | |
| """ | |
| import re | |
| import torch.nn as nn | |
| import torch.nn.functional as functional | |
| import torch | |
| from transformers import AutoModel | |
| import torch.autograd as autograd | |
def get_nonlinear(nonlinear):
    """
    Look up an activation module by name.

    Args:
        nonlinear: one of 'relu', 'tanh', 'sigmoid' or 'softmax'.

    Returns:
        A freshly constructed ``nn.Module`` for the requested activation
        ('softmax' normalises over the last dimension).

    Raises:
        ValueError: if ``nonlinear`` is not one of the supported names
            (including unhashable values, which cannot be dictionary keys).
    """
    nonlinear_dict = {
        'relu': nn.ReLU(),
        'tanh': nn.Tanh(),
        'sigmoid': nn.Sigmoid(),
        'softmax': nn.Softmax(dim=-1),
    }
    try:
        return nonlinear_dict[nonlinear]
    # Only a failed lookup means "invalid name"; a bare ``except`` would also
    # swallow KeyboardInterrupt / SystemExit.
    except (KeyError, TypeError) as err:
        raise ValueError('not a valid nonlinear type!') from err
class SeqPooling(nn.Module):
    """
    Collapse each sequence of a batch into a single vector.

    Every sequence is pooled independently along its first dimension, so the
    sequences in one batch may have different lengths. Supported
    ``pooling_type`` values are 'max_pooling', 'mean_pooling' and
    'attentive_pooling'; the latter owns a learnable query vector of size
    ``hidden_dim``.
    """

    def __init__(self, pooling_type, hidden_dim):
        super().__init__()
        self.pooling_type = pooling_type
        self.hidden_dim = hidden_dim
        if pooling_type == 'attentive_pooling':
            # Learnable query against which every position is scored.
            self.query_vec = nn.Parameter(torch.randn(hidden_dim))

    def max_pool(self, seq):
        """Element-wise maximum over the first dimension of ``seq``."""
        return torch.max(seq, dim=0).values

    def mean_pool(self, seq):
        """Element-wise mean over the first dimension of ``seq``."""
        return torch.mean(seq, dim=0)

    def attn_pool(self, seq):
        """Softmax-weighted sum of the rows of a 2-D ``seq``."""
        query_column = self.query_vec.unsqueeze(1)
        position_scores = torch.matmul(seq, query_column).reshape(-1)
        position_weights = torch.softmax(position_scores, dim=0)
        return torch.matmul(position_weights.unsqueeze(0), seq).reshape(-1)

    def forward(self, batch_seq):
        """
        Pool every sequence in ``batch_seq`` and stack the results along a
        new leading dimension.

        An unsupported ``pooling_type`` surfaces as a ``KeyError`` when the
        first sequence is pooled.
        """
        poolers = {
            'max_pooling': self.max_pool,
            'mean_pooling': self.mean_pool,
            'attentive_pooling': self.attn_pool,
        }
        pooled = []
        for seq in batch_seq:
            # Lookup stays inside the loop so it only happens per sequence.
            pooled.append(poolers[self.pooling_type](seq))
        return torch.stack(pooled, dim=0)
class MLP_Scorer(nn.Module):
    """
    Two-layer perceptron that maps a feature vector to a single score.

    The first linear layer projects ``classifier_input_size`` features to
    ``args.classifier_intermediate_dim``; the second projects to one output.
    The activation named by ``args.nonlinear_type`` is applied after each of
    the two layers, including the output layer.
    """

    def __init__(self, args, classifier_input_size):
        super().__init__()
        input_layer = nn.Linear(classifier_input_size, args.classifier_intermediate_dim)
        output_layer = nn.Linear(args.classifier_intermediate_dim, 1)
        self.scorer = nn.ModuleList([input_layer, output_layer])
        self.nonlinear = get_nonlinear(args.nonlinear_type)

    def forward(self, x):
        """Run ``x`` through every layer, applying the activation after each."""
        hidden = x
        for layer in self.scorer:
            hidden = layer(hidden)
            hidden = self.nonlinear(hidden)
        return hidden
| class KCSN(nn.Module): | |
| """ | |
| Candidate Scoring Network. | |
| It's built on BERT with an MLP and other simple components. | |
| """ | |
# NOTE(review): leading indentation was lost in this copy of the file, so the
# nesting of the loops and branches in forward() cannot be read off the text.
# The comments below describe individual statements and hedge wherever the
# meaning depends on nesting.
#
# __init__ reads args.bert_pretrained_dir (passed to AutoModel.from_pretrained),
# args.pooling_type and args.dropout, and hands args to MLP_Scorer. The scorer
# input width is 3 * hidden_size because forward() concatenates three pooled
# vectors (quote, context, candidate).
| def __init__(self, args): | |
| super(KCSN, self).__init__() | |
| self.args = args | |
| self.bert_model = AutoModel.from_pretrained(args.bert_pretrained_dir) | |
| self.pooling = SeqPooling(args.pooling_type, self.bert_model.config.hidden_size) | |
| self.mlp_scorer = MLP_Scorer(args, self.bert_model.config.hidden_size * 3) | |
| self.dropout = nn.Dropout(args.dropout) | |
# forward(): score every candidate of one instance.
#   features       - indexable; features[i].input_ids / features[i].input_mask
#                    are wrapped into one-row long tensors for the encoder.
#   sent_char_lens, mention_poses, quote_idxes - iterated in lockstep, one
#                    entry per candidate.
#   mention_poses  - each entry is indexed at [0]..[3]: [0] and [3] index into
#                    cut_css[i]; [1] and [2] feed the proportional fallback.
#   true_index     - index into `scores` of the true candidate.
#   device         - target of every .to(device) call.
#   tokens_list    - tokens_list[i] is an iterable of token strings ('#' is
#                    stripped from each).
#   cut_css        - cut_css[i] is an iterable of segments, each joined with
#                    ''.join(); presumably tokenised sentences - TODO confirm.
# Returns (scores, scores_false, scores_true); see the early return below for
# the case where no candidate could be encoded.
| def forward(self, features, sent_char_lens, mention_poses, quote_idxes, true_index, device, tokens_list, cut_css): | |
| # encoding | |
| qs_hid = [] | |
| ctx_hid = [] | |
| cdd_hid = [] | |
# unk_loc counts candidates starting at 1 (incremented before use);
# unk_loc_li collects the counters of candidates that get skipped.
| unk_loc_li = [] | |
| unk_loc = 0 | |
| for i, (cdd_sent_char_lens, cdd_mention_pos, cdd_quote_idx) in enumerate( | |
| zip(sent_char_lens, mention_poses, quote_idxes)): | |
| unk_loc = unk_loc + 1 | |
| bert_output = self.bert_model( | |
| torch.tensor([features[i].input_ids], dtype=torch.long).to(device), | |
| token_type_ids=None, | |
| attention_mask=torch.tensor([features[i].input_mask], dtype=torch.long).to(device) | |
| ) | |
| modified_list = [s.replace('#', '') for s in tokens_list[i]] | |
# Segment-boundary search: accum_char_len starts at [0] and receives one
# boundary (a value of cnt) per segment of cut_css[i]. num_vid (-999) is the
# "boundary not found" marker and num_check == 1000 the "nothing found yet"
# sentinel.
| cnt = 1 | |
| verify = 0 | |
| num_check = 0 | |
| num_vid = -999 | |
| accum_char_len = [0] | |
| for idx, txt in enumerate(cut_css[i]): | |
| result_string = ''.join(txt) | |
# NOTE(review): the pattern is a character CLASS built from the last 7
# characters of the segment, so pattern.match(letter) asks "is this letter one
# of those characters", not "does the text end with this suffix". Only
# ] [ ? - ! are escaped; other class metacharacters (e.g. '^', '\') are not.
| replace_dict = {']': r'\]', '[': r'\[', '?': r'\?', '-': r'\-', '!': r'\!'} | |
| string_processing = result_string[-7:].translate(str.maketrans(replace_dict)) | |
| pattern = re.compile(rf'[{string_processing}]') | |
| cnt = 1 | |
# If the sentinel is still 1000 here, -999 is appended to mark a segment whose
# boundary was not found; the sentinel is then re-armed.
| if num_check == 1000: | |
| accum_char_len.append(num_vid) | |
| num_check = 1000 | |
| for string in modified_list: | |
| string_nospace = string.replace(' ','') | |
# Skip remaining tokens once accum_char_len already holds more than idx + 1
# entries.
| if len(accum_char_len) > idx + 1: | |
| continue | |
| for letter in string_nospace: | |
| match_result = pattern.match(letter) | |
# verify counts matching letters and is reset to 0 in the else branch; when it
# reaches len(result_string[-7:]) and cnt exceeds the last stored boundary, cnt
# is stored as the next boundary.
| if match_result: | |
| verify += 1 | |
| if verify == len(result_string[-7:]): | |
| if cnt > accum_char_len[-1]: | |
| accum_char_len.append(cnt) | |
| verify = 0 | |
| num_check = len(accum_char_len) | |
| else: | |
| verify = 0 | |
| cnt = cnt + 1 | |
| if num_check == 1000: | |
| accum_char_len.append(num_vid) | |
# Any -999 marker: remember this candidate's 1-based counter and skip it; a
# fixed score is inserted for it after scoring (see the end of forward()).
| if -999 in accum_char_len: | |
| unk_loc_li.append(unk_loc) | |
| continue | |
# Drop the first position of the encoder output and keep
# sum(cdd_sent_char_lens) positions; the quote slice is taken between two
# consecutive boundaries.
| CSS_hid = bert_output['last_hidden_state'][0][1:sum(cdd_sent_char_lens) + 1].to(device) | |
| qs_hid.append(CSS_hid[accum_char_len[cdd_quote_idx]:accum_char_len[cdd_quote_idx + 1]]) | |
| ## Locate the speaker mention and index it in BERT-tokenizer positions (translated from the original note) | |
| cnt = 1 | |
| cdd_mention_pos_bert_li = [] | |
| cdd_mention_pos_unk = [] | |
| name = cut_css[i][cdd_mention_pos[0]][cdd_mention_pos[3]] | |
| # extract only name | |
| # (the original repeated the note above in Korean) | |
# NOTE(review): search() returns None when `name` contains no &Cdd& marker;
# the later name_process.group(0) calls would then raise AttributeError.
| cdd_pattern = re.compile(r'&C[0-5][0-9]&') | |
| name_process = cdd_pattern.search(name) | |
| # find candidate location in bert output | |
| # (the original repeated the note above in Korean) | |
# NOTE(review): this is a character class - it matches any single one of
# '[', 'U', 'N', 'K', ']' - not the literal token "[UNK]".
| pattern_unk = re.compile(r'[\[UNK\]]') | |
| # Upper bound for the search: once the result for this segment is found, do not scan past it (translated from the original note) | |
| if len(accum_char_len) < cdd_mention_pos[0]+1: | |
| maxx_len = accum_char_len[len(accum_char_len)-1] | |
| elif len(accum_char_len) == cdd_mention_pos[0]+1: | |
| maxx_len = accum_char_len[-1] + 1000 | |
| else: | |
| maxx_len = accum_char_len[cdd_mention_pos[0]+1] | |
| # Used to find the speaker marker contained in the tokens (translated from the original note) | |
| start_name = None | |
| name_match = '&' | |
| for string in modified_list: | |
| string_nospace = string.replace(' ','') | |
| for letter in string_nospace: | |
| match_result_unk = pattern_unk.match(letter) | |
| if match_result_unk: | |
| cdd_mention_pos_unk.append(cnt) | |
# While start_name is True, letters are accumulated into name_match.
| if start_name is True: | |
| name_match += letter | |
| if (name_match == name_process.group(0) or letter == '&') and len( | |
| cdd_mention_pos_bert_li) < 3 and maxx_len > cnt >= accum_char_len[ | |
| cdd_mention_pos[0]]: # original note: when '&' is present, treat it as a person marker | |
| start_name = True # start accumulating letters into name_match | |
| if len(cdd_mention_pos_bert_li) == 1 and name_match != name_process.group(0): # original note: '&' seen a second time without a match | |
| start_name = None | |
| name_match = '&' | |
| cdd_mention_pos_bert_li = [] | |
| elif name_match == name_process.group(0): # original note: second append (full marker matched) | |
| cdd_mention_pos_bert_li.append(cnt) | |
| start_name = None | |
| name_match = '&' | |
| else: | |
| cdd_mention_pos_bert_li.append(cnt-1) | |
| cnt += 1 | |
# NOTE(review): '&' binds tighter than '==' / '!=', so the next condition
# parses as the chained comparison
#     len(cdd_mention_pos_bert_li) == (0 & len(cdd_mention_pos_unk)) != 0
# i.e. `len(...) == 0 and 0 != 0`, which is always False. The UNK fallback
# below is therefore unreachable; `and` was probably intended - confirm before
# changing, since enabling it alters model behaviour.
| if len(cdd_mention_pos_bert_li) == 0 & len(cdd_mention_pos_unk) != 0: | |
| cdd_mention_pos_bert_li.extend([cdd_mention_pos_unk[0], cdd_mention_pos_unk[0]+1]) | |
# Fallback when exactly two positions were not collected: rescale
# cdd_mention_pos[1] / [2] by accum_char_len[-1] / sum(cdd_sent_char_lens) and
# make sure the resulting span is not empty.
| elif len(cdd_mention_pos_bert_li) != 2: | |
| cdd_mention_pos_bert_li = [] | |
| cdd_mention_pos_bert_li.extend([int(cdd_mention_pos[1] * accum_char_len[-1]/sum( | |
| cdd_sent_char_lens)), int(cdd_mention_pos[2] * accum_char_len[-1]/sum( | |
| cdd_sent_char_lens))]) | |
| if cdd_mention_pos_bert_li[0] == cdd_mention_pos_bert_li[1]: | |
| cdd_mention_pos_bert_li[1] = cdd_mention_pos_bert_li[1]+1 | |
| # Choose the context (ctx): the information around the candidate (translated from the original note) | |
| # Only one entry in cdd_sent_char_lens: a single all-zero row stands in for the context | |
| if len(cdd_sent_char_lens) == 1: | |
| ctx_hid.append(torch.zeros(1, CSS_hid.size(1)).to(device)) | |
| # Mention segment index is 0: context is everything before accum_char_len[-2] (original note: from the first sentence up to just before the final quote) | |
| elif cdd_mention_pos[0] == 0: | |
| ctx_hid.append(CSS_hid[:accum_char_len[-2]]) | |
| # Otherwise: context is everything from accum_char_len[1] onward (original note: the speaker comes after the quote, take the second segment to the end) | |
| else: | |
| ctx_hid.append(CSS_hid[accum_char_len[1]:]) | |
| cdd_mention_pos_bert = (cdd_mention_pos[0], cdd_mention_pos_bert_li[0], | |
| cdd_mention_pos_bert_li[1]) | |
| cdd_hid.append(CSS_hid[cdd_mention_pos_bert[1]:cdd_mention_pos_bert[2]]) | |
| # pooling | |
# NOTE(review): when nothing was appended to qs_hid, the method returns the
# string '1' and two ints instead of a tensor and two lists - callers must
# handle this different return type.
| if not qs_hid: | |
| scores = '1' | |
| scores_false = 1 | |
| scores_true = 1 | |
| return scores, scores_false, scores_true | |
| qs_rep = self.pooling(qs_hid).to(device) | |
| ctx_rep = self.pooling(ctx_hid).to(device) | |
| cdd_rep = self.pooling(cdd_hid).to(device) | |
| # concatenate | |
| feature_vector = torch.cat([qs_rep, ctx_rep, cdd_rep], dim=-1).to(device) | |
| # dropout | |
| feature_vector = self.dropout(feature_vector).to(device) | |
| # scoring | |
| scores = self.mlp_scorer(feature_vector).view(-1).to(device) | |
# Re-insert a fixed score of -0.9 for every skipped candidate so that `scores`
# lines up with the original candidate order; unk_loc_li holds 1-based
# counters, hence the `i - 1`.
| for i in unk_loc_li: | |
| # Element to insert | |
| new_element = torch.tensor([-0.9000], requires_grad=True).to(device) | |
| # Use torch.cat() with slicing to insert the element at a specific index | |
| index_to_insert = i - 1 | |
| scores = torch.cat((scores[:index_to_insert], new_element, scores[index_to_insert:]), | |
| dim=0).to(device) | |
# scores_false: every score except the one at true_index. scores_true: the
# true score repeated scores.size(0) - 1 times.
| scores_false = [scores[i] for i in range(scores.size(0)) if i != true_index] | |
| scores_true = [scores[true_index] for i in range(scores.size(0) - 1)] | |
| return scores, scores_false, scores_true | |