SS3M committed on
Commit
e2a67cd
·
verified ·
1 Parent(s): b934157

Upload 8_entities_top_span_self_ensemble_no_weight_20's state dict

Browse files
8_entities_top_span_self_ensemble_no_weight_20/8_entities_top_span_self_ensemble_no_weight_20.py CHANGED
@@ -467,227 +467,6 @@ class SpanErrorAnalyzer:
467
  "details": detailed_errors
468
  }
469
 
470
- # %% [code]
471
- class DataParallelProxy(nn.DataParallel):
472
-
473
- def __getattr__(self, name):
474
- try:
475
- return super().__getattr__(name)
476
-
477
- except AttributeError:
478
-
479
- attr = getattr(self.module, name)
480
-
481
- if callable(attr):
482
-
483
- def wrapper(*args, **kwargs):
484
- return self._parallel_apply_method(
485
- name,
486
- *args,
487
- **kwargs
488
- )
489
-
490
- return wrapper
491
-
492
- return attr
493
-
494
- # =========================================================
495
- # parallel custom method
496
- # =========================================================
497
-
498
- def _parallel_apply_method(self, method_name, *inputs, **kwargs):
499
-
500
- if not self.device_ids:
501
- return getattr(self.module, method_name)(*inputs, **kwargs)
502
-
503
- inputs_scattered, kwargs_scattered = self.scatter(
504
- inputs,
505
- kwargs,
506
- self.device_ids
507
- )
508
-
509
- replicas = self.replicate(
510
- self.module,
511
- self.device_ids[:len(inputs_scattered)]
512
- )
513
-
514
- outputs = self.parallel_apply(
515
- [getattr(replica, method_name) for replica in replicas],
516
- inputs_scattered,
517
- kwargs_scattered
518
- )
519
-
520
- return self._custom_gather(
521
- outputs,
522
- self.output_device
523
- )
524
-
525
- # =========================================================
526
- # OVERRIDE FORWARD GATHER
527
- # =========================================================
528
-
529
- def gather(self, outputs, output_device):
530
-
531
- return self._custom_gather(
532
- outputs,
533
- output_device
534
- )
535
-
536
- # =========================================================
537
- # recursive gather
538
- # =========================================================
539
-
540
- def _custom_gather(self, outputs, output_device):
541
-
542
- first = outputs[0]
543
-
544
- # =====================================================
545
- # tensor
546
- # =====================================================
547
-
548
- if torch.is_tensor(first):
549
-
550
- return self._gather_tensor(
551
- outputs,
552
- output_device
553
- )
554
-
555
- # =====================================================
556
- # tuple
557
- # =====================================================
558
-
559
- if isinstance(first, tuple):
560
-
561
- return tuple(
562
- self._custom_gather(
563
- list(items),
564
- output_device
565
- )
566
- for items in zip(*outputs)
567
- )
568
-
569
- # =====================================================
570
- # list
571
- # =====================================================
572
-
573
- if isinstance(first, list):
574
-
575
- # list[tensor]
576
- if len(first) > 0 and torch.is_tensor(first[0]):
577
-
578
- return self._gather_tensor_list(
579
- outputs,
580
- output_device
581
- )
582
-
583
- merged = []
584
-
585
- for out in outputs:
586
- merged.extend(out)
587
-
588
- return merged
589
-
590
- # =====================================================
591
- # dict
592
- # =====================================================
593
-
594
- if isinstance(first, dict):
595
-
596
- return {
597
- k: self._custom_gather(
598
- [o[k] for o in outputs],
599
- output_device
600
- )
601
- for k in first.keys()
602
- }
603
-
604
- # =====================================================
605
- # fallback
606
- # =====================================================
607
-
608
- return outputs
609
-
610
- # =========================================================
611
- # gather tensor with auto pad
612
- # =========================================================
613
-
614
- def _gather_tensor(self, tensors, output_device):
615
-
616
- # move same device first
617
- tensors = [
618
- t.to(output_device)
619
- for t in tensors
620
- ]
621
-
622
- # =====================================================
623
- # fast path
624
- # =====================================================
625
-
626
- try:
627
- return torch.cat(tensors, dim=0)
628
-
629
- except RuntimeError:
630
- pass
631
-
632
- # =====================================================
633
- # auto max shape
634
- # =====================================================
635
-
636
- max_shape = list(tensors[0].shape)
637
-
638
- for t in tensors[1:]:
639
-
640
- for d in range(len(max_shape)):
641
-
642
- max_shape[d] = max(
643
- max_shape[d],
644
- t.shape[d]
645
- )
646
-
647
- # =====================================================
648
- # pad tensors
649
- # =====================================================
650
-
651
- padded = []
652
-
653
- for t in tensors:
654
-
655
- pad = []
656
-
657
- # reverse order for F.pad
658
- for d in reversed(range(len(max_shape))):
659
-
660
- # never pad batch dim
661
- if d == 0:
662
- pad.extend([0, 0])
663
- continue
664
-
665
- diff = max_shape[d] - t.shape[d]
666
-
667
- pad.extend([0, diff])
668
-
669
- t = F.pad(t, pad)
670
-
671
- padded.append(t)
672
-
673
- return torch.cat(padded, dim=0)
674
-
675
- # =========================================================
676
- # list[tensor]
677
- # =========================================================
678
-
679
- def _gather_tensor_list(self, outputs, output_device):
680
-
681
- merged = []
682
-
683
- for out in outputs:
684
- merged.extend(out)
685
-
686
- return self._gather_tensor(
687
- merged,
688
- output_device
689
- )
690
-
691
  # %% [code]
692
  ## Viết cấu trúc model vào đây
693
  def get_span_reprs(hidden, spans):
@@ -885,7 +664,7 @@ class IEModel(nn.Module):
885
  return start_logits, end_logits, logits, spans
886
 
887
  def test_model():
888
- model = DataParallelProxy(IEModel(backbone_model_name, 7, 10, 100, 2)).to(device)
889
  model.eval()
890
  total_params = sum(p.numel() for p in model.parameters())
891
  print(f"Total params: {total_params:,}")
@@ -984,18 +763,41 @@ def fmt(x):
984
  return x
985
 
986
  class ModelEmaV3Proxy(ModelEmaV3):
987
-
988
  def __getattr__(self, name):
989
-
990
  try:
991
  return super().__getattr__(name)
992
-
993
  except AttributeError:
 
 
 
 
 
 
 
 
994
 
995
- # tránh recursion
996
- module = object.__getattribute__(self, "module")
 
 
997
 
998
- return getattr(module, name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
999
 
1000
  def align(
1001
  all_spans, # (B, N, 2)
@@ -1277,7 +1079,7 @@ class Trainer:
1277
  self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device)
1278
 
1279
  try:
1280
- teaching_rate = math.cos(math.pi / 2 * epoch / epochs)
1281
  train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn, teaching_rate)
1282
  logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch}
1283
  logging_dict.update(train_loss_epoch_dict)
@@ -1455,11 +1257,17 @@ class Trainer:
1455
  start_labels = batch['start_labels'].to(self.device)
1456
  end_labels = batch['end_labels'].to(self.device)
1457
 
 
 
 
1458
  choice = random.random()
1459
  if choice < teaching_rate:
1460
- start_logits, end_logits, logits, pred_spans = network(input_ids, attention_mask, all_spans)
1461
  else:
1462
- start_logits, end_logits, logits, pred_spans = network(input_ids, attention_mask)
 
 
 
1463
 
1464
  align_labels = align(all_spans, pred_spans, all_labels, -100)
1465
  align_weights = align(all_spans, pred_spans, all_weights, 0)
@@ -1481,7 +1289,11 @@ class Trainer:
1481
 
1482
  B, _, _ = input_ids.shape
1483
 
1484
- start_logits, end_logits, logits, pred_spans = network(input_ids, attention_mask)
 
 
 
 
1485
 
1486
  gold_list, pred_list = extract_spans(all_spans, all_labels, pred_spans)
1487
  gold_list = list_to_tuple(gold_list)
 
467
  "details": detailed_errors
468
  }
469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
  # %% [code]
471
  ## Viết cấu trúc model vào đây
472
  def get_span_reprs(hidden, spans):
 
664
  return start_logits, end_logits, logits, spans
665
 
666
  def test_model():
667
+ model = nn.DataParallel(IEModel(backbone_model_name, 7, 10, 100, 0)).to(device)
668
  model.eval()
669
  total_params = sum(p.numel() for p in model.parameters())
670
  print(f"Total params: {total_params:,}")
 
763
  return x
764
 
765
  class ModelEmaV3Proxy(ModelEmaV3):
 
766
  def __getattr__(self, name):
 
767
  try:
768
  return super().__getattr__(name)
 
769
  except AttributeError:
770
+ return getattr(self.module, name)
771
+
772
+ class DataParallelProxy(nn.DataParallel):
773
+ def __getattr__(self, name):
774
+ try:
775
+ return super().__getattr__(name)
776
+ except AttributeError:
777
+ attr = getattr(self.module, name)
778
 
779
+ if callable(attr):
780
+ def wrapper(*args, **kwargs):
781
+ return self._parallel_apply_method(name, *args, **kwargs)
782
+ return wrapper
783
 
784
+ return attr
785
+
786
+ def _parallel_apply_method(self, method_name, *inputs, **kwargs):
787
+ if not self.device_ids:
788
+ return getattr(self.module, method_name)(*inputs, **kwargs)
789
+
790
+ inputs_scattered, kwargs_scattered = self.scatter(inputs, kwargs, self.device_ids)
791
+
792
+ replicas = self.replicate(self.module, self.device_ids)
793
+
794
+ outputs = self.parallel_apply(
795
+ [getattr(replica, method_name) for replica in replicas],
796
+ inputs_scattered,
797
+ kwargs_scattered
798
+ )
799
+
800
+ return self.gather(outputs, self.output_device)
801
 
802
  def align(
803
  all_spans, # (B, N, 2)
 
1079
  self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device)
1080
 
1081
  try:
1082
+ teaching_rate = math.cos(math.pi / 2 * (epoch - 2) / (epochs - 2)) if epoch - 2 > 0 else 1.0
1083
  train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn, teaching_rate)
1084
  logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch}
1085
  logging_dict.update(train_loss_epoch_dict)
 
1257
  start_labels = batch['start_labels'].to(self.device)
1258
  end_labels = batch['end_labels'].to(self.device)
1259
 
1260
+ hidden_states, attention_mask = network.encode(input_ids, attention_mask)
1261
+ start_logits, end_logits = network.get_token_logits(hidden_states)
1262
+
1263
  choice = random.random()
1264
  if choice < teaching_rate:
1265
+ pred_spans = all_spans
1266
  else:
1267
+ pred_spans = filter_spans(start_logits, end_logits, attention_mask, network.max_span_len, network.topk_spans, network.keep_neighbor)
1268
+
1269
+ span_reprs = get_span_reprs(hidden_states, pred_spans)
1270
+ logits = network.get_logits(span_reprs)
1271
 
1272
  align_labels = align(all_spans, pred_spans, all_labels, -100)
1273
  align_weights = align(all_spans, pred_spans, all_weights, 0)
 
1289
 
1290
  B, _, _ = input_ids.shape
1291
 
1292
+ hidden_states, attention_mask = network.encode(input_ids, attention_mask)
1293
+ start_logits, end_logits = network.get_token_logits(hidden_states)
1294
+ pred_spans = filter_spans(start_logits, end_logits, attention_mask, network.max_span_len, network.topk_spans, network.keep_neighbor)
1295
+ span_reprs = get_span_reprs(hidden_states, pred_spans)
1296
+ logits = network.get_logits(span_reprs)
1297
 
1298
  gold_list, pred_list = extract_spans(all_spans, all_labels, pred_spans)
1299
  gold_list = list_to_tuple(gold_list)
8_entities_top_span_self_ensemble_no_weight_20/lasts/8_entities_top_span_self_ensemble_no_weight_20_s26092004_f0_last_ema.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2c52d0aa8377ce14c4dc4daf9db1c971b886711672876e1cc2f88bd4d80d18b8
3
  size 554305230
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d8c6d8cdb0343a8286e705230b1c6c65d137b9bdfb533daecdb805d22dcc40d
3
  size 554305230
8_entities_top_span_self_ensemble_no_weight_20/logs/8_entities_top_span_self_ensemble_no_weight_20_log_plot.jpg CHANGED

Git LFS Details

  • SHA256: fe67ca8c0085699d76124508ff20813b3839fc3b5683fa38f329f994a40e4a9d
  • Pointer size: 131 Bytes
  • Size of remote file: 555 kB

Git LFS Details

  • SHA256: f8035be3c272db30e6225ee3b7eabc177febbdcd216cefab24d2ae4e2f94b1e2
  • Pointer size: 131 Bytes
  • Size of remote file: 547 kB
8_entities_top_span_self_ensemble_no_weight_20/r1s/8_entities_top_span_self_ensemble_no_weight_20_s26092004_f0_r1_vs1.69613_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a044d1c538e5ae2c18d2452a9f44819803179f0f5628a82370cb36341392abb3
3
+ size 554320278
8_entities_top_span_self_ensemble_no_weight_20/results/8_entities_top_span_self_ensemble_no_weight_20_test_df_.xlsx CHANGED
Binary files a/8_entities_top_span_self_ensemble_no_weight_20/results/8_entities_top_span_self_ensemble_no_weight_20_test_df_.xlsx and b/8_entities_top_span_self_ensemble_no_weight_20/results/8_entities_top_span_self_ensemble_no_weight_20_test_df_.xlsx differ