| import torch
|
| import logging
|
| import os
|
|
|
| logger = logging.getLogger(__name__)
|
| from torchvision import transforms
|
| from PIL import Image
|
|
|
|
|
class SBInputExample(object):
    """Raw (untokenized) sequence-labelling example paired with an image."""

    def __init__(self, guid, text_a, text_b, img_id, label=None, auxlabel=None):
        """Create an example.

        Args:
            guid: Unique identifier of the example.
            text_a: Untokenized text of the first sequence; the only one
                required for single-sequence tasks.
            text_b: Optional untokenized text of a second sequence, only
                needed for sequence-pair tasks.
            img_id: Identifier of the image associated with the text.
            label: Optional label(s); expected for train and dev examples,
                may be omitted for test examples.
            auxlabel: Optional auxiliary label(s).
        """
        # Text fields.
        self.guid, self.text_a, self.text_b = guid, text_a, text_b
        # Image reference.
        self.img_id = img_id
        # Supervision; either may be None.
        self.label, self.auxlabel = label, auxlabel
|
|
|
|
|
class SBInputFeatures(object):
    """Model-ready inputs for one example.

    Every attribute simply stores the constructor argument of the same name:
    token ids, two attention masks, segment ids, the processed image
    (``img_feat``) and the label / auxiliary-label ids.
    """

    def __init__(self, input_ids, input_mask, added_input_mask, segment_ids, img_feat, label_id, auxlabel_id):
        # Token ids and masks.
        self.input_ids, self.input_mask = input_ids, input_mask
        self.added_input_mask, self.segment_ids = added_input_mask, segment_ids
        # Image input.
        self.img_feat = img_feat
        # Target ids.
        self.label_id, self.auxlabel_id = label_id, auxlabel_id
|
|
|
|
|
def sbreadfile(filename):
    """Read a token-per-line multimodal NER file.

    The file is decoded as UTF-8 and every line is stripped. Recognised lines:

    * ``IMGID:<id>`` -- sets the image of the following sentence; the stored
      image name is ``<id> + '.jpg'``.
    * an empty line -- ends the current sentence (if it has any tokens).
    * ``<word><TAB><label>`` -- one token of the current sentence; its
      auxiliary label is the first character of ``<label>`` (e.g. ``B``,
      ``I`` or ``O`` for BIO tags).

    Any other line (a different number of tab-separated fields) is ignored.
    A trailing sentence without a final empty line is still collected.

    Args:
        filename: Path of the file to read.

    Returns:
        ``(data, imgs, auxlabels)``: equal-length lists where ``data[i]`` is
        ``(words, labels)`` for sentence ``i``, ``imgs[i]`` its image file name
        ('' if no ``IMGID`` preceded it) and ``auxlabels[i]`` its auxiliary
        labels.
    """
    print("Chuẩn bị dữ liệu cho ", filename)
    data = []
    imgs = []
    auxlabels = []
    sentence = []
    label = []
    auxlabel = []
    imgid = ''

    # The context manager guarantees the file handle is closed, even on error.
    with open(filename, encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if line.startswith('IMGID:'):
                imgid = line.split('IMGID:')[1] + '.jpg'
                continue
            if line == '':
                # Sentence boundary: emit the collected tokens, if any.
                if len(sentence) > 0:
                    data.append((sentence, label))
                    imgs.append(imgid)
                    auxlabels.append(auxlabel)
                    sentence = []
                    label = []
                    auxlabel = []
                    imgid = ''
                continue
            splits = line.split('\t')
            if len(splits) == 2:
                word, cur_label = splits
                sentence.append(word)
                label.append(cur_label)
                auxlabel.append(cur_label[0])

    # Flush the last sentence when the file does not end with an empty line.
    if len(sentence) > 0:
        data.append((sentence, label))
        imgs.append(imgid)
        auxlabels.append(auxlabel)

    print("Số lượng mẫu: " + str(len(data)))
    print("Số lượng hình ảnh: " + str(len(imgs)))
    return data, imgs, auxlabels
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DataProcessor(object):
    """Abstract interface for dataset readers.

    Subclasses provide the train/dev examples and the label inventory; the
    ``_read_sbtsv`` helper parses one data file via ``sbreadfile``.
    """

    def get_train_examples(self, data_dir):
        """Return the training examples found under ``data_dir``."""
        raise NotImplementedError

    def get_dev_examples(self, data_dir):
        """Return the development examples found under ``data_dir``."""
        raise NotImplementedError

    def get_labels(self):
        """Return the list of labels of this dataset."""
        raise NotImplementedError

    @classmethod
    def _read_sbtsv(cls, input_file, quotechar=None):
        """Parse ``input_file`` and return ``sbreadfile``'s result unchanged.

        ``quotechar`` is accepted for signature compatibility and is unused.
        """
        parsed = sbreadfile(input_file)
        return parsed
|
|
|
|
|
class MNERProcessor_2021(DataProcessor):
    """Processor for the multimodal NER data read by ``_read_sbtsv``.

    Reads ``train.txt``, ``dev.txt`` and ``test.txt`` from ``data_dir`` and
    wraps every sentence in an ``SBInputExample`` holding its space-joined
    words, per-word labels, image file name and per-word auxiliary labels.
    Label ids used by this class are 1-based in ``get_labels()`` order.
    """

    # Fine-grained label inventory. ORDER MATTERS: ids are derived from list
    # positions, so reordering changes every label id.
    _LABELS = (
        "O",
        "I-PRODUCT-AWARD",
        "B-MISCELLANEOUS",
        "B-QUANTITY-NUM",
        "B-ORGANIZATION-SPORTS",
        "B-DATETIME",
        "I-ADDRESS",
        "I-PERSON",
        "I-EVENT-SPORT",
        "B-ADDRESS",
        "B-EVENT-NATURAL",
        "I-LOCATION-GPE",
        "B-EVENT-GAMESHOW",
        "B-DATETIME-TIMERANGE",
        "I-QUANTITY-NUM",
        "I-QUANTITY-AGE",
        "B-EVENT-CUL",
        "I-QUANTITY-TEM",
        "I-PRODUCT-LEGAL",
        "I-LOCATION-STRUC",
        "I-ORGANIZATION",
        "B-PHONENUMBER",
        "B-IP",
        "B-QUANTITY-AGE",
        "I-DATETIME-TIME",
        "I-DATETIME",
        "B-ORGANIZATION-MED",
        "B-DATETIME-SET",
        "I-EVENT-CUL",
        "B-QUANTITY-DIM",
        "I-QUANTITY-DIM",
        "B-EVENT",
        "B-DATETIME-DATERANGE",
        "I-EVENT-GAMESHOW",
        "B-PRODUCT-AWARD",
        "B-LOCATION-STRUC",
        "B-LOCATION",
        "B-PRODUCT",
        "I-MISCELLANEOUS",
        "B-SKILL",
        "I-QUANTITY-ORD",
        "I-ORGANIZATION-STOCK",
        "I-LOCATION-GEO",
        "B-PERSON",
        "B-PRODUCT-COM",
        "B-PRODUCT-LEGAL",
        "I-LOCATION",
        "B-QUANTITY-TEM",
        "I-PRODUCT",
        "B-QUANTITY-CUR",
        "I-QUANTITY-CUR",
        "B-LOCATION-GPE",
        "I-PHONENUMBER",
        "I-ORGANIZATION-MED",
        "I-EVENT-NATURAL",
        "I-EMAIL",
        "B-ORGANIZATION",
        "B-URL",
        "I-DATETIME-TIMERANGE",
        "I-QUANTITY",
        "I-IP",
        "B-EVENT-SPORT",
        "B-PERSONTYPE",
        "B-QUANTITY-PER",
        "I-QUANTITY-PER",
        "I-PRODUCT-COM",
        "I-DATETIME-DURATION",
        "B-LOCATION-GPE-GEO",
        "B-QUANTITY-ORD",
        "I-EVENT",
        "B-DATETIME-TIME",
        "B-QUANTITY",
        "I-DATETIME-SET",
        "I-LOCATION-GPE-GEO",
        "B-ORGANIZATION-STOCK",
        "I-ORGANIZATION-SPORTS",
        "I-SKILL",
        "I-URL",
        "B-DATETIME-DURATION",
        "I-DATETIME-DATE",
        "I-PERSONTYPE",
        "B-DATETIME-DATE",
        "I-DATETIME-DATERANGE",
        "B-LOCATION-GEO",
        "B-EMAIL",
        "X",
        "<s>",
        "</s>",
    )

    # Coarse auxiliary tags plus the same special tags as ``_LABELS``.
    _AUXLABELS = ("O", "B", "I", "X", "<s>", "</s>")

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._get_examples(data_dir, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._get_examples(data_dir, "dev")

    def get_test_examples(self, data_dir):
        """Return the test examples found under ``data_dir``."""
        return self._get_examples(data_dir, "test")

    def get_labels(self):
        """Return a fresh list of the fine-grained labels (incl. "X", "<s>", "</s>")."""
        return list(self._LABELS)

    def get_auxlabels(self):
        """Return a fresh list of the auxiliary labels."""
        return list(self._AUXLABELS)

    def _label_map(self):
        """Map each label of ``get_labels()`` to its 1-based id."""
        return {label: i for i, label in enumerate(self.get_labels(), 1)}

    def get_start_label_id(self):
        """Return the id of the ``<s>`` label."""
        return self._label_map()['<s>']

    def get_stop_label_id(self):
        """Return the id of the ``</s>`` label."""
        return self._label_map()['</s>']

    def _get_examples(self, data_dir, set_type):
        """Read ``<set_type>.txt`` from ``data_dir`` and build its examples."""
        data, imgs, auxlabels = self._read_sbtsv(os.path.join(data_dir, set_type + ".txt"))
        return self._create_examples(data, imgs, auxlabels, set_type)

    def _create_examples(self, lines, imgs, auxlabels, set_type):
        """Wrap parsed sentences into ``SBInputExample`` objects.

        Args:
            lines: List of ``(words, labels)`` pairs.
            imgs: Image file name per sentence (indexed like ``lines``).
            auxlabels: Auxiliary labels per sentence (indexed like ``lines``).
            set_type: Split name used as the guid prefix, e.g. ``"train"``.
        """
        return [
            SBInputExample(
                guid="%s-%s" % (set_type, i),
                text_a=' '.join(sentence),
                text_b=None,
                img_id=imgs[i],
                label=label,
                auxlabel=auxlabels[i],
            )
            for i, (sentence, label) in enumerate(lines)
        ]
|
|
|
|
|
def image_process(image_path, transform):
    """Load an image file as RGB and apply ``transform`` to it.

    Args:
        image_path: Path of the image file.
        transform: Callable applied to the RGB ``PIL.Image``.

    Returns:
        Whatever ``transform`` returns for the image.

    Raises:
        Any error from ``Image.open`` (e.g. a missing or unreadable file) or
        from ``transform``.
    """
    # Open via a context manager so the file handle is released promptly;
    # ``convert`` returns a new, fully loaded image independent of the file.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    return transform(image)
|
|
|
|
|
def convert_mm_examples_to_features(examples, label_list, auxlabel_list,
                                    max_seq_length, tokenizer, crop_size, path_img):
    """Convert ``SBInputExample`` objects into padded ``SBInputFeatures``.

    For each example the space-separated words of ``text_a`` are tokenized;
    the first sub-token of a word keeps the word's label and the remaining
    sub-tokens get ``"X"``. The tokens are truncated to
    ``max_seq_length - 2``, wrapped in ``<s>`` / ``</s>`` and zero-padded to
    ``max_seq_length``. The example's image (``path_img/img_id``) is loaded
    and transformed; if that fails, ``path_img/background.jpg`` is used.

    Args:
        examples: Iterable of ``SBInputExample``; ``label`` and ``auxlabel``
            must have one entry per space-separated word of ``text_a``.
        label_list: Label inventory; ids are 1-based in list order and 0 is
            used for padding. Must contain "X", "<s>" and "</s>".
        auxlabel_list: Auxiliary label inventory, ids assigned the same way.
        max_seq_length: Length every id / mask sequence is padded to.
        tokenizer: Object with ``tokenize(word)`` and
            ``convert_tokens_to_ids(tokens)``.
        crop_size: Size for ``transforms.RandomCrop`` after resizing to 256x256.
        path_img: Directory containing the images and ``background.jpg``.

    Returns:
        List of ``SBInputFeatures``, one per example.

    Raises:
        KeyError: if a label is missing from ``label_list`` / ``auxlabel_list``.
        Exception: whatever loading the fallback ``background.jpg`` raises.
    """
    label_map = {label: i for i, label in enumerate(label_list, 1)}
    auxlabel_map = {label: i for i, label in enumerate(auxlabel_list, 1)}

    # NOTE(review): random crop/flip augmentation is applied regardless of the
    # split, so dev/test features are randomized too -- confirm this is intended.
    transform = transforms.Compose([
        transforms.Resize([256, 256]),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Extra always-on positions appended to ``added_input_mask``; presumably
    # one per visual region of the image encoder (e.g. 7x7) -- TODO confirm.
    num_visual_positions = 49

    features = []
    count = 0  # examples whose own image could not be loaded
    for ex_index, example in enumerate(examples):
        tokens, labels, auxlabels = _tokenize_and_align(
            example.text_a.split(' '), example.label, example.auxlabel, tokenizer)
        if len(tokens) >= max_seq_length - 1:
            # Leave room for the <s> and </s> markers added below.
            keep = max_seq_length - 2
            tokens, labels, auxlabels = tokens[0:keep], labels[0:keep], auxlabels[0:keep]

        ntokens = ["<s>"] + tokens + ["</s>"]
        segment_ids = [0] * len(ntokens)
        label_ids = ([label_map["<s>"]]
                     + [label_map[lab] for lab in labels]
                     + [label_map["</s>"]])
        auxlabel_ids = ([auxlabel_map["<s>"]]
                        + [auxlabel_map[lab] for lab in auxlabels]
                        + [auxlabel_map["</s>"]])

        input_ids = tokenizer.convert_tokens_to_ids(ntokens)
        input_mask = [1] * len(input_ids)
        added_input_mask = [1] * (len(input_ids) + num_visual_positions)

        # Zero-pad every sequence by the same amount (0 is the padding id).
        padding = [0] * max(0, max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        added_input_mask += padding
        segment_ids += padding
        label_ids += padding
        auxlabel_ids += padding

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(label_ids) == max_seq_length
        assert len(auxlabel_ids) == max_seq_length

        image_path = os.path.join(path_img, example.img_id)
        image, loaded = _load_image_with_fallback(image_path, path_img, transform)
        if not loaded:
            count += 1
        elif ex_index < 2:
            # Log the first two examples, only when their own image loaded.
            _log_example(example, tokens, input_ids, input_mask, segment_ids,
                         label_ids, auxlabel_ids)

        features.append(
            SBInputFeatures(input_ids=input_ids, input_mask=input_mask, added_input_mask=added_input_mask,
                            segment_ids=segment_ids, img_feat=image, label_id=label_ids, auxlabel_id=auxlabel_ids))

    print('the number of problematic samples: ' + str(count))
    return features


def _tokenize_and_align(words, word_labels, word_auxlabels, tokenizer):
    """Tokenize ``words`` and give the first sub-token the word's labels, the rest "X"."""
    tokens, labels, auxlabels = [], [], []
    for i, word in enumerate(words):
        pieces = tokenizer.tokenize(word)
        tokens.extend(pieces)
        label, auxlabel = word_labels[i], word_auxlabels[i]
        if pieces:
            n_extra = len(pieces) - 1
            labels.append(label)
            labels.extend(["X"] * n_extra)
            auxlabels.append(auxlabel)
            auxlabels.extend(["X"] * n_extra)
    return tokens, labels, auxlabels


def _load_image_with_fallback(image_path, path_img, transform):
    """Return ``(image, True)`` for ``image_path``, or ``(background, False)`` on failure."""
    if not os.path.exists(image_path) and 'NaN' not in image_path:
        # Report missing images, skipping paths containing 'NaN'
        # (presumably placeholders for examples without an image id).
        print(image_path)
    try:
        return image_process(image_path, transform), True
    except Exception:
        # Narrower than a bare ``except:`` so Ctrl-C / SystemExit still propagate.
        image_path_fail = os.path.join(path_img, 'background.jpg')
        return image_process(image_path_fail, transform), False


def _log_example(example, tokens, input_ids, input_mask, segment_ids, label_ids, auxlabel_ids):
    """Log one example's tokens and padded id sequences at INFO level."""
    def joined(values):
        return " ".join(str(x) for x in values)

    logger.info("*** Example ***")
    logger.info("guid: %s", example.guid)
    logger.info("tokens: %s", joined(tokens))
    logger.info("input_ids: %s", joined(input_ids))
    logger.info("input_mask: %s", joined(input_mask))
    logger.info("segment_ids: %s", joined(segment_ids))
    logger.info("label: %s", joined(label_ids))
    logger.info("auxlabel: %s", joined(auxlabel_ids))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| |