Upload 8_entities_top_span_self_ensemble_no_weight_20's state dict
Browse files- 8_entities_top_span_self_ensemble_no_weight_20/8_entities_top_span_self_ensemble_no_weight_20.py +44 -232
- 8_entities_top_span_self_ensemble_no_weight_20/lasts/8_entities_top_span_self_ensemble_no_weight_20_s26092004_f0_last_ema.pth +1 -1
- 8_entities_top_span_self_ensemble_no_weight_20/logs/8_entities_top_span_self_ensemble_no_weight_20_log_plot.jpg +2 -2
- 8_entities_top_span_self_ensemble_no_weight_20/r1s/8_entities_top_span_self_ensemble_no_weight_20_s26092004_f0_r1_vs1.69613_ema.pth +3 -0
- 8_entities_top_span_self_ensemble_no_weight_20/results/8_entities_top_span_self_ensemble_no_weight_20_test_df_.xlsx +0 -0
8_entities_top_span_self_ensemble_no_weight_20/8_entities_top_span_self_ensemble_no_weight_20.py
CHANGED
|
@@ -467,227 +467,6 @@ class SpanErrorAnalyzer:
|
|
| 467 |
"details": detailed_errors
|
| 468 |
}
|
| 469 |
|
| 470 |
-
# %% [code]
|
| 471 |
-
class DataParallelProxy(nn.DataParallel):
|
| 472 |
-
|
| 473 |
-
def __getattr__(self, name):
|
| 474 |
-
try:
|
| 475 |
-
return super().__getattr__(name)
|
| 476 |
-
|
| 477 |
-
except AttributeError:
|
| 478 |
-
|
| 479 |
-
attr = getattr(self.module, name)
|
| 480 |
-
|
| 481 |
-
if callable(attr):
|
| 482 |
-
|
| 483 |
-
def wrapper(*args, **kwargs):
|
| 484 |
-
return self._parallel_apply_method(
|
| 485 |
-
name,
|
| 486 |
-
*args,
|
| 487 |
-
**kwargs
|
| 488 |
-
)
|
| 489 |
-
|
| 490 |
-
return wrapper
|
| 491 |
-
|
| 492 |
-
return attr
|
| 493 |
-
|
| 494 |
-
# =========================================================
|
| 495 |
-
# parallel custom method
|
| 496 |
-
# =========================================================
|
| 497 |
-
|
| 498 |
-
def _parallel_apply_method(self, method_name, *inputs, **kwargs):
|
| 499 |
-
|
| 500 |
-
if not self.device_ids:
|
| 501 |
-
return getattr(self.module, method_name)(*inputs, **kwargs)
|
| 502 |
-
|
| 503 |
-
inputs_scattered, kwargs_scattered = self.scatter(
|
| 504 |
-
inputs,
|
| 505 |
-
kwargs,
|
| 506 |
-
self.device_ids
|
| 507 |
-
)
|
| 508 |
-
|
| 509 |
-
replicas = self.replicate(
|
| 510 |
-
self.module,
|
| 511 |
-
self.device_ids[:len(inputs_scattered)]
|
| 512 |
-
)
|
| 513 |
-
|
| 514 |
-
outputs = self.parallel_apply(
|
| 515 |
-
[getattr(replica, method_name) for replica in replicas],
|
| 516 |
-
inputs_scattered,
|
| 517 |
-
kwargs_scattered
|
| 518 |
-
)
|
| 519 |
-
|
| 520 |
-
return self._custom_gather(
|
| 521 |
-
outputs,
|
| 522 |
-
self.output_device
|
| 523 |
-
)
|
| 524 |
-
|
| 525 |
-
# =========================================================
|
| 526 |
-
# OVERRIDE FORWARD GATHER
|
| 527 |
-
# =========================================================
|
| 528 |
-
|
| 529 |
-
def gather(self, outputs, output_device):
|
| 530 |
-
|
| 531 |
-
return self._custom_gather(
|
| 532 |
-
outputs,
|
| 533 |
-
output_device
|
| 534 |
-
)
|
| 535 |
-
|
| 536 |
-
# =========================================================
|
| 537 |
-
# recursive gather
|
| 538 |
-
# =========================================================
|
| 539 |
-
|
| 540 |
-
def _custom_gather(self, outputs, output_device):
|
| 541 |
-
|
| 542 |
-
first = outputs[0]
|
| 543 |
-
|
| 544 |
-
# =====================================================
|
| 545 |
-
# tensor
|
| 546 |
-
# =====================================================
|
| 547 |
-
|
| 548 |
-
if torch.is_tensor(first):
|
| 549 |
-
|
| 550 |
-
return self._gather_tensor(
|
| 551 |
-
outputs,
|
| 552 |
-
output_device
|
| 553 |
-
)
|
| 554 |
-
|
| 555 |
-
# =====================================================
|
| 556 |
-
# tuple
|
| 557 |
-
# =====================================================
|
| 558 |
-
|
| 559 |
-
if isinstance(first, tuple):
|
| 560 |
-
|
| 561 |
-
return tuple(
|
| 562 |
-
self._custom_gather(
|
| 563 |
-
list(items),
|
| 564 |
-
output_device
|
| 565 |
-
)
|
| 566 |
-
for items in zip(*outputs)
|
| 567 |
-
)
|
| 568 |
-
|
| 569 |
-
# =====================================================
|
| 570 |
-
# list
|
| 571 |
-
# =====================================================
|
| 572 |
-
|
| 573 |
-
if isinstance(first, list):
|
| 574 |
-
|
| 575 |
-
# list[tensor]
|
| 576 |
-
if len(first) > 0 and torch.is_tensor(first[0]):
|
| 577 |
-
|
| 578 |
-
return self._gather_tensor_list(
|
| 579 |
-
outputs,
|
| 580 |
-
output_device
|
| 581 |
-
)
|
| 582 |
-
|
| 583 |
-
merged = []
|
| 584 |
-
|
| 585 |
-
for out in outputs:
|
| 586 |
-
merged.extend(out)
|
| 587 |
-
|
| 588 |
-
return merged
|
| 589 |
-
|
| 590 |
-
# =====================================================
|
| 591 |
-
# dict
|
| 592 |
-
# =====================================================
|
| 593 |
-
|
| 594 |
-
if isinstance(first, dict):
|
| 595 |
-
|
| 596 |
-
return {
|
| 597 |
-
k: self._custom_gather(
|
| 598 |
-
[o[k] for o in outputs],
|
| 599 |
-
output_device
|
| 600 |
-
)
|
| 601 |
-
for k in first.keys()
|
| 602 |
-
}
|
| 603 |
-
|
| 604 |
-
# =====================================================
|
| 605 |
-
# fallback
|
| 606 |
-
# =====================================================
|
| 607 |
-
|
| 608 |
-
return outputs
|
| 609 |
-
|
| 610 |
-
# =========================================================
|
| 611 |
-
# gather tensor with auto pad
|
| 612 |
-
# =========================================================
|
| 613 |
-
|
| 614 |
-
def _gather_tensor(self, tensors, output_device):
|
| 615 |
-
|
| 616 |
-
# move same device first
|
| 617 |
-
tensors = [
|
| 618 |
-
t.to(output_device)
|
| 619 |
-
for t in tensors
|
| 620 |
-
]
|
| 621 |
-
|
| 622 |
-
# =====================================================
|
| 623 |
-
# fast path
|
| 624 |
-
# =====================================================
|
| 625 |
-
|
| 626 |
-
try:
|
| 627 |
-
return torch.cat(tensors, dim=0)
|
| 628 |
-
|
| 629 |
-
except RuntimeError:
|
| 630 |
-
pass
|
| 631 |
-
|
| 632 |
-
# =====================================================
|
| 633 |
-
# auto max shape
|
| 634 |
-
# =====================================================
|
| 635 |
-
|
| 636 |
-
max_shape = list(tensors[0].shape)
|
| 637 |
-
|
| 638 |
-
for t in tensors[1:]:
|
| 639 |
-
|
| 640 |
-
for d in range(len(max_shape)):
|
| 641 |
-
|
| 642 |
-
max_shape[d] = max(
|
| 643 |
-
max_shape[d],
|
| 644 |
-
t.shape[d]
|
| 645 |
-
)
|
| 646 |
-
|
| 647 |
-
# =====================================================
|
| 648 |
-
# pad tensors
|
| 649 |
-
# =====================================================
|
| 650 |
-
|
| 651 |
-
padded = []
|
| 652 |
-
|
| 653 |
-
for t in tensors:
|
| 654 |
-
|
| 655 |
-
pad = []
|
| 656 |
-
|
| 657 |
-
# reverse order for F.pad
|
| 658 |
-
for d in reversed(range(len(max_shape))):
|
| 659 |
-
|
| 660 |
-
# never pad batch dim
|
| 661 |
-
if d == 0:
|
| 662 |
-
pad.extend([0, 0])
|
| 663 |
-
continue
|
| 664 |
-
|
| 665 |
-
diff = max_shape[d] - t.shape[d]
|
| 666 |
-
|
| 667 |
-
pad.extend([0, diff])
|
| 668 |
-
|
| 669 |
-
t = F.pad(t, pad)
|
| 670 |
-
|
| 671 |
-
padded.append(t)
|
| 672 |
-
|
| 673 |
-
return torch.cat(padded, dim=0)
|
| 674 |
-
|
| 675 |
-
# =========================================================
|
| 676 |
-
# list[tensor]
|
| 677 |
-
# =========================================================
|
| 678 |
-
|
| 679 |
-
def _gather_tensor_list(self, outputs, output_device):
|
| 680 |
-
|
| 681 |
-
merged = []
|
| 682 |
-
|
| 683 |
-
for out in outputs:
|
| 684 |
-
merged.extend(out)
|
| 685 |
-
|
| 686 |
-
return self._gather_tensor(
|
| 687 |
-
merged,
|
| 688 |
-
output_device
|
| 689 |
-
)
|
| 690 |
-
|
| 691 |
# %% [code]
|
| 692 |
## Viết cấu trúc model vào đây
|
| 693 |
def get_span_reprs(hidden, spans):
|
|
@@ -885,7 +664,7 @@ class IEModel(nn.Module):
|
|
| 885 |
return start_logits, end_logits, logits, spans
|
| 886 |
|
| 887 |
def test_model():
|
| 888 |
-
model =
|
| 889 |
model.eval()
|
| 890 |
total_params = sum(p.numel() for p in model.parameters())
|
| 891 |
print(f"Total params: {total_params:,}")
|
|
@@ -984,18 +763,41 @@ def fmt(x):
|
|
| 984 |
return x
|
| 985 |
|
| 986 |
class ModelEmaV3Proxy(ModelEmaV3):
|
| 987 |
-
|
| 988 |
def __getattr__(self, name):
|
| 989 |
-
|
| 990 |
try:
|
| 991 |
return super().__getattr__(name)
|
| 992 |
-
|
| 993 |
except AttributeError:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 994 |
|
| 995 |
-
|
| 996 |
-
|
|
|
|
|
|
|
| 997 |
|
| 998 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 999 |
|
| 1000 |
def align(
|
| 1001 |
all_spans, # (B, N, 2)
|
|
@@ -1277,7 +1079,7 @@ class Trainer:
|
|
| 1277 |
self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device)
|
| 1278 |
|
| 1279 |
try:
|
| 1280 |
-
teaching_rate = math.cos(math.pi / 2 * epoch / epochs)
|
| 1281 |
train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn, teaching_rate)
|
| 1282 |
logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch}
|
| 1283 |
logging_dict.update(train_loss_epoch_dict)
|
|
@@ -1455,11 +1257,17 @@ class Trainer:
|
|
| 1455 |
start_labels = batch['start_labels'].to(self.device)
|
| 1456 |
end_labels = batch['end_labels'].to(self.device)
|
| 1457 |
|
|
|
|
|
|
|
|
|
|
| 1458 |
choice = random.random()
|
| 1459 |
if choice < teaching_rate:
|
| 1460 |
-
|
| 1461 |
else:
|
| 1462 |
-
start_logits, end_logits,
|
|
|
|
|
|
|
|
|
|
| 1463 |
|
| 1464 |
align_labels = align(all_spans, pred_spans, all_labels, -100)
|
| 1465 |
align_weights = align(all_spans, pred_spans, all_weights, 0)
|
|
@@ -1481,7 +1289,11 @@ class Trainer:
|
|
| 1481 |
|
| 1482 |
B, _, _ = input_ids.shape
|
| 1483 |
|
| 1484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1485 |
|
| 1486 |
gold_list, pred_list = extract_spans(all_spans, all_labels, pred_spans)
|
| 1487 |
gold_list = list_to_tuple(gold_list)
|
|
|
|
| 467 |
"details": detailed_errors
|
| 468 |
}
|
| 469 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
# %% [code]
|
| 471 |
## Viết cấu trúc model vào đây
|
| 472 |
def get_span_reprs(hidden, spans):
|
|
|
|
| 664 |
return start_logits, end_logits, logits, spans
|
| 665 |
|
| 666 |
def test_model():
|
| 667 |
+
model = nn.DataParallel(IEModel(backbone_model_name, 7, 10, 100, 0)).to(device)
|
| 668 |
model.eval()
|
| 669 |
total_params = sum(p.numel() for p in model.parameters())
|
| 670 |
print(f"Total params: {total_params:,}")
|
|
|
|
| 763 |
return x
|
| 764 |
|
| 765 |
class ModelEmaV3Proxy(ModelEmaV3):
|
|
|
|
| 766 |
def __getattr__(self, name):
|
|
|
|
| 767 |
try:
|
| 768 |
return super().__getattr__(name)
|
|
|
|
| 769 |
except AttributeError:
|
| 770 |
+
return getattr(self.module, name)
|
| 771 |
+
|
| 772 |
+
class DataParallelProxy(nn.DataParallel):
|
| 773 |
+
def __getattr__(self, name):
|
| 774 |
+
try:
|
| 775 |
+
return super().__getattr__(name)
|
| 776 |
+
except AttributeError:
|
| 777 |
+
attr = getattr(self.module, name)
|
| 778 |
|
| 779 |
+
if callable(attr):
|
| 780 |
+
def wrapper(*args, **kwargs):
|
| 781 |
+
return self._parallel_apply_method(name, *args, **kwargs)
|
| 782 |
+
return wrapper
|
| 783 |
|
| 784 |
+
return attr
|
| 785 |
+
|
| 786 |
+
def _parallel_apply_method(self, method_name, *inputs, **kwargs):
|
| 787 |
+
if not self.device_ids:
|
| 788 |
+
return getattr(self.module, method_name)(*inputs, **kwargs)
|
| 789 |
+
|
| 790 |
+
inputs_scattered, kwargs_scattered = self.scatter(inputs, kwargs, self.device_ids)
|
| 791 |
+
|
| 792 |
+
replicas = self.replicate(self.module, self.device_ids)
|
| 793 |
+
|
| 794 |
+
outputs = self.parallel_apply(
|
| 795 |
+
[getattr(replica, method_name) for replica in replicas],
|
| 796 |
+
inputs_scattered,
|
| 797 |
+
kwargs_scattered
|
| 798 |
+
)
|
| 799 |
+
|
| 800 |
+
return self.gather(outputs, self.output_device)
|
| 801 |
|
| 802 |
def align(
|
| 803 |
all_spans, # (B, N, 2)
|
|
|
|
| 1079 |
self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device)
|
| 1080 |
|
| 1081 |
try:
|
| 1082 |
+
teaching_rate = math.cos(math.pi / 2 * (epoch - 2) / (epochs - 2)) if epoch - 2 > 0 else 1.0
|
| 1083 |
train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn, teaching_rate)
|
| 1084 |
logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch}
|
| 1085 |
logging_dict.update(train_loss_epoch_dict)
|
|
|
|
| 1257 |
start_labels = batch['start_labels'].to(self.device)
|
| 1258 |
end_labels = batch['end_labels'].to(self.device)
|
| 1259 |
|
| 1260 |
+
hidden_states, attention_mask = network.encode(input_ids, attention_mask)
|
| 1261 |
+
start_logits, end_logits = network.get_token_logits(hidden_states)
|
| 1262 |
+
|
| 1263 |
choice = random.random()
|
| 1264 |
if choice < teaching_rate:
|
| 1265 |
+
pred_spans = all_spans
|
| 1266 |
else:
|
| 1267 |
+
pred_spans = filter_spans(start_logits, end_logits, attention_mask, network.max_span_len, network.topk_spans, network.keep_neighbor)
|
| 1268 |
+
|
| 1269 |
+
span_reprs = get_span_reprs(hidden_states, pred_spans)
|
| 1270 |
+
logits = network.get_logits(span_reprs)
|
| 1271 |
|
| 1272 |
align_labels = align(all_spans, pred_spans, all_labels, -100)
|
| 1273 |
align_weights = align(all_spans, pred_spans, all_weights, 0)
|
|
|
|
| 1289 |
|
| 1290 |
B, _, _ = input_ids.shape
|
| 1291 |
|
| 1292 |
+
hidden_states, attention_mask = network.encode(input_ids, attention_mask)
|
| 1293 |
+
start_logits, end_logits = network.get_token_logits(hidden_states)
|
| 1294 |
+
pred_spans = filter_spans(start_logits, end_logits, attention_mask, network.max_span_len, network.topk_spans, network.keep_neighbor)
|
| 1295 |
+
span_reprs = get_span_reprs(hidden_states, pred_spans)
|
| 1296 |
+
logits = network.get_logits(span_reprs)
|
| 1297 |
|
| 1298 |
gold_list, pred_list = extract_spans(all_spans, all_labels, pred_spans)
|
| 1299 |
gold_list = list_to_tuple(gold_list)
|
8_entities_top_span_self_ensemble_no_weight_20/lasts/8_entities_top_span_self_ensemble_no_weight_20_s26092004_f0_last_ema.pth
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 554305230
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6d8c6d8cdb0343a8286e705230b1c6c65d137b9bdfb533daecdb805d22dcc40d
|
| 3 |
size 554305230
|
8_entities_top_span_self_ensemble_no_weight_20/logs/8_entities_top_span_self_ensemble_no_weight_20_log_plot.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
8_entities_top_span_self_ensemble_no_weight_20/r1s/8_entities_top_span_self_ensemble_no_weight_20_s26092004_f0_r1_vs1.69613_ema.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a044d1c538e5ae2c18d2452a9f44819803179f0f5628a82370cb36341392abb3
|
| 3 |
+
size 554320278
|
8_entities_top_span_self_ensemble_no_weight_20/results/8_entities_top_span_self_ensemble_no_weight_20_test_df_.xlsx
CHANGED
|
Binary files a/8_entities_top_span_self_ensemble_no_weight_20/results/8_entities_top_span_self_ensemble_no_weight_20_test_df_.xlsx and b/8_entities_top_span_self_ensemble_no_weight_20/results/8_entities_top_span_self_ensemble_no_weight_20_test_df_.xlsx differ
|
|
|