| import torch
|
| import logging
|
| import os
|
|
|
| logger = logging.getLogger(__name__)
|
| from torchvision import transforms
|
| from PIL import Image
|
|
|
|
|
class SBInputExample(object):
    """A single text example together with the id of its associated image."""

    def __init__(self, guid, text_a, text_b, img_id, label=None, auxlabel=None):
        """Construct an input example.

        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For
                single sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second
                sequence. Only must be specified for sequence pair tasks.
            img_id: Identifier of the image that belongs to the example;
                stored unchanged.
            label: (Optional) The label of the example. This should be
                specified for train and dev examples, but not for test
                examples.
            auxlabel: (Optional) Auxiliary label of the example; stored
                unchanged.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.img_id = img_id
        # Keep the optional labels on the instance instead of silently
        # dropping arguments the caller explicitly passed in.
        self.label = label
        self.auxlabel = auxlabel
|
|
|
|
|
|
|
|
|
class SBInputFeatures:
    """Container for the model inputs computed for one example."""

    def __init__(self, input_ids, input_mask, added_input_mask, segment_ids, img_feat):
        """Store every argument unchanged as a public attribute of the same name."""
        # Text-side inputs.
        self.input_ids, self.input_mask = input_ids, input_mask
        self.added_input_mask, self.segment_ids = added_input_mask, segment_ids
        # Image-side input.
        self.img_feat = img_feat
|
|
|
|
|
def sbreadfile(filename):
    """Read a data file and return its sentences and their image file names.

    Lines starting with ``IMGID:`` set the image name (``.jpg`` is appended)
    for the sentence that follows. A blank line ends the current sentence.
    Every other line contributes its first tab-separated field as a word.

    Returns:
        A tuple ``(data, imgs)``: ``data`` is a list of word lists and
        ``imgs`` the parallel list of image file names ('' when no
        ``IMGID:`` line preceded the sentence).
    """
    print("Chuẩn bị dữ liệu từ", filename)

    sentences, image_names = [], []
    current_words, current_image = [], ''

    with open(filename, encoding='utf8') as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text.startswith('IMGID:'):
                current_image = text.split('IMGID:')[1] + '.jpg'
            elif text:
                current_words.append(text.split('\t')[0])
            elif current_words:
                # Blank line after at least one word: close the sentence.
                sentences.append(current_words)
                image_names.append(current_image)
                current_words, current_image = [], ''

    # The file may end without a trailing blank line.
    if current_words:
        sentences.append(current_words)
        image_names.append(current_image)

    print("Số lượng mẫu: " + str(len(sentences)))
    print("Số lượng hình ảnh: " + str(len(image_names)))
    return sentences, image_names
|
|
|
|
|
class DataProcessor:
    """Abstract interface for classes that turn data files into examples."""

    def get_train_examples(self, data_dir):
        """Return the examples of the train set; subclasses must override."""
        raise NotImplementedError

    def get_dev_examples(self, data_dir):
        """Return the examples of the dev set; subclasses must override."""
        raise NotImplementedError

    def get_labels(self):
        """Return the list of labels of the data set; subclasses must override."""
        raise NotImplementedError

    @classmethod
    def _read_sbtsv(cls, input_file, quotechar=None):
        """Read ``input_file`` with ``sbreadfile``; ``quotechar`` is not used."""
        parsed = sbreadfile(input_file)
        return parsed
|
|
|
|
|
class MNERProcessor(DataProcessor):
    """Processor that builds examples from ``train.txt``, ``dev.txt`` and ``test.txt``."""

    def _load_split(self, data_dir, file_name, set_type):
        """Read ``file_name`` inside ``data_dir`` and wrap its content as examples."""
        sentences, image_names = self._read_sbtsv(os.path.join(data_dir, file_name))
        return self._create_examples(sentences, image_names, set_type)

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._load_split(data_dir, "train.txt", "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._load_split(data_dir, "dev.txt", "dev")

    def get_test_examples(self, data_dir):
        """Return the examples read from ``test.txt`` in ``data_dir``."""
        return self._load_split(data_dir, "test.txt", "test")

    def get_labels(self):
        """Return the tag set; the order defines the ids and must not change."""
        return [
            "I-LOC", "B-MISC", "I-PER", "I-ORG",
            "B-LOC", "I-MISC", "B-ORG", "O",
            "B-PER", "X", "<s>", "</s>",
        ]

    def get_auxlabels(self):
        """Return the auxiliary tag set."""
        return ["O", "B", "I", "X", "<s>", "</s>"]

    def _label_id(self, label):
        """Return the 1-based position of ``label`` within ``get_labels()``."""
        ids_by_label = {
            name: position
            for position, name in enumerate(self.get_labels(), start=1)
        }
        return ids_by_label[label]

    def get_start_label_id(self):
        """Return the 1-based id of the ``<s>`` label."""
        return self._label_id('<s>')

    def get_stop_label_id(self):
        """Return the 1-based id of the ``</s>`` label."""
        return self._label_id('</s>')

    def _create_examples(self, lines, imgs, set_type):
        """Build one ``SBInputExample`` per sentence, paired with ``imgs`` by index."""
        return [
            SBInputExample(
                guid="%s-%s" % (set_type, position),
                text_a=' '.join(words),
                text_b=None,
                img_id=imgs[position],
            )
            for position, words in enumerate(lines)
        ]
|
|
|
|
|
def create_examples(lines, imgs, set_type):
    """Build one ``SBInputExample`` per sentence.

    Args:
        lines: Sequence of word lists; each is joined with single spaces
            to form ``text_a``.
        imgs: Sequence of image ids, indexed in parallel with ``lines``.
        set_type: Prefix used for the ``guid`` (``"<set_type>-<index>"``).

    Returns:
        The list of examples, in the order of ``lines``.
    """
    return [
        SBInputExample(
            guid="%s-%s" % (set_type, position),
            text_a=' '.join(words),
            text_b=None,
            img_id=imgs[position],
        )
        for position, words in enumerate(lines)
    ]
|
|
|
|
|
def get_test_examples_predict(data_dir):
    """Read ``test.txt`` from ``data_dir`` and return its examples with set type ``"test"``."""
    test_file = os.path.join(data_dir, "test.txt")
    sentences, image_names = sbreadfile(test_file)
    return create_examples(sentences, image_names, "test")
|
|
|
|
|
def image_process(image_path, transform):
    """Load an image as RGB and apply ``transform`` to it.

    Args:
        image_path: Path of the image file to open.
        transform: Callable applied to the RGB image.

    Returns:
        Whatever ``transform`` returns for the converted image.

    Raises:
        Any exception raised by ``Image.open``, ``convert`` or ``transform``
        (for example when the file is missing or unreadable).
    """
    # The context manager releases the underlying file handle as soon as the
    # pixel data has been copied by ``convert``, instead of relying on
    # garbage collection to close it.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    return transform(image)
|
|
|
|
|
def convert_mm_examples_to_features_predict(examples,
                                            max_seq_length, tokenizer, crop_size, path_img):
    """Convert examples into fixed-length text features plus a transformed image.

    Args:
        examples: Iterable of objects exposing ``guid``, ``text_a`` and
            ``img_id``.
        max_seq_length: Length of the text inputs after padding, including
            the ``<s>`` and ``</s>`` tokens.
        tokenizer: Object providing ``tokenize(word)`` and
            ``convert_tokens_to_ids(tokens)``.
        crop_size: Size handed to ``transforms.RandomCrop``.
        path_img: Directory containing the images and the
            ``background.jpg`` fallback image.

    Returns:
        A list with one ``SBInputFeatures`` per example, in input order.

    Raises:
        Any exception raised while loading the fallback image, since there
        is no further fallback.
    """
    features = []
    count = 0  # examples whose own image could not be loaded

    # NOTE(review): RandomCrop / RandomHorizontalFlip are stochastic, so the
    # image features differ between runs; confirm this is intended for
    # prediction.
    transform = transforms.Compose([
        transforms.Resize([256, 256]),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    for ex_index, example in enumerate(examples):
        # Tokenize word by word and concatenate the resulting pieces.
        tokens = []
        for word in example.text_a.split(' '):
            tokens.extend(tokenizer.tokenize(word))
        # Leave room for the two special tokens added below.
        if len(tokens) >= max_seq_length - 1:
            tokens = tokens[0:(max_seq_length - 2)]

        ntokens = ["<s>"] + tokens + ["</s>"]
        segment_ids = [0] * len(ntokens)
        input_ids = list(tokenizer.convert_tokens_to_ids(ntokens))
        input_mask = [1] * len(input_ids)
        # NOTE(review): the 49 extra positions presumably cover image
        # regions (e.g. a 7x7 grid) — confirm against the model.
        added_input_mask = [1] * (len(input_ids) + 49)

        # Zero-pad every sequence by the same amount up to max_seq_length.
        pad_len = max(0, max_seq_length - len(input_ids))
        input_ids += [0] * pad_len
        input_mask += [0] * pad_len
        added_input_mask += [0] * pad_len
        segment_ids += [0] * pad_len

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        image_path = os.path.join(path_img, example.img_id)

        # Report missing files, except for paths containing 'NaN'.
        if not os.path.exists(image_path) and 'NaN' not in image_path:
            print(image_path)
        try:
            image = image_process(image_path, transform)
        except Exception:
            # Catch Exception rather than everything, so KeyboardInterrupt
            # and SystemExit still propagate. Substitute the fallback image
            # so the example is kept.
            count += 1
            image_path_fail = os.path.join(path_img, 'background.jpg')
            image = image_process(image_path_fail, transform)
        else:
            # Log the first example only when its own image loaded.
            if ex_index < 1:
                logger.info("*** Example ***")
                logger.info("guid: %s", example.guid)
                logger.info("tokens: %s", " ".join(str(x) for x in tokens))
                logger.info("input_ids: %s", " ".join(str(x) for x in input_ids))
                logger.info("input_mask: %s", " ".join(str(x) for x in input_mask))
                logger.info("segment_ids: %s", " ".join(str(x) for x in segment_ids))

        features.append(
            SBInputFeatures(input_ids=input_ids, input_mask=input_mask,
                            added_input_mask=added_input_mask,
                            segment_ids=segment_ids, img_feat=image))

    print('the number of problematic samples: ' + str(count))
    return features