File size: 397 Bytes
df6cf36 | 1 2 3 4 5 6 7 8 9 10 11 12 | import torch
from core.fluid_transformer import FluidTransformer
class SyncManager(torch.nn.Module):
    """Ensemble of ``FluidTransformer`` heads whose outputs are averaged.

    Subclassing ``torch.nn.Module`` and keeping the heads in a
    ``torch.nn.ModuleList`` registers every head as a submodule, so
    ``parameters()``, ``state_dict()``, ``.to(device)`` and
    ``train()``/``eval()`` reach all of them. A plain Python list would
    hide the heads from all of those.

    NOTE(review): assumes ``FluidTransformer`` is a ``torch.nn.Module``
    subclass; ``torch.nn.ModuleList`` raises ``TypeError`` otherwise.
    """

    def __init__(self, vocab_size: int = 256, num_heads: int = 3):
        """Build ``num_heads`` independent heads.

        Args:
            vocab_size: Passed positionally to each ``FluidTransformer``.
            num_heads: Number of heads to average. Defaults to 3, the
                previously hard-coded count.

        Raises:
            ValueError: If ``num_heads`` is less than 1. Without this check,
                ``forward`` would divide by zero.
        """
        super().__init__()
        if num_heads < 1:
            raise ValueError(f"num_heads must be >= 1, got {num_heads}")
        self.heads = torch.nn.ModuleList(
            [FluidTransformer(vocab_size) for _ in range(num_heads)]
        )

    def forward(self, input_ids):
        """Return the element-wise mean of every head's output on ``input_ids``.

        Consensus: every head sees the same input and their logits are
        averaged. The outputs are combined with ``+``, so they must be
        addable. Presumably they are logits tensors of identical shape;
        TODO confirm.
        """
        logits_list = [head(input_ids) for head in self.heads]
        return sum(logits_list) / len(logits_list)
|