| | import torch
|
| | from safetensors.torch import load_file
|
| |
|
def load_model(path='model.safetensors'):
    """Load the tensors stored in the safetensors file at *path*.

    Args:
        path: Location of the ``.safetensors`` file to read.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for *path*
        (a name -> tensor mapping).
    """
    state = load_file(path)
    return state
|
| |
|
def xor2_from_weights(a, b, w, or_w, or_b, nand_w, nand_b, and_w, and_b):
    """Compute XOR(a,b) using threshold gates.

    A two-layer threshold network is evaluated on the input pair
    ``(a, b)``. Each neuron outputs 1 when ``sum(input * weight) + bias >= 0``
    and 0 otherwise. The first layer has two neurons, nominally OR and NAND.
    Their outputs feed a second-layer neuron, nominally AND. The gate each
    neuron actually implements is decided by the supplied weights, not by
    this function.

    Args:
        a, b: Input bits; anything ``float()`` accepts (typically 0/1).
        w: Not used by this function.
        or_w, or_b: Weight tensor and bias of the first-layer "OR" neuron.
        nand_w, nand_b: Weight tensor and bias of the first-layer "NAND" neuron.
        and_w, and_b: Weight tensor and bias of the second-layer "AND" neuron.

    Returns:
        int: 0 or 1, the output of the second-layer neuron.
    """
    x = torch.tensor([float(a), float(b)])

    def fires(vec, weight, bias):
        # Heaviside step on the affine pre-activation (ties count as firing).
        return (vec * weight).sum() + bias >= 0

    hidden = torch.tensor([
        float(fires(x, or_w, or_b)),
        float(fires(x, nand_w, nand_b)),
    ])
    return int(fires(hidden, and_w, and_b))
|
| |
|
def _neuron(x, w, name):
    """Evaluate threshold neuron *name*: 1 if sum(x * weight) + bias >= 0, else 0."""
    return int((x * w[f'{name}.weight']).sum() + w[f'{name}.bias'] >= 0)


def _xor_net(x, w, name):
    """Evaluate the two-layer (layer1.or, layer1.nand) -> layer2 sub-network stored under *name*."""
    hidden = torch.tensor([
        float(_neuron(x, w, f'{name}.layer1.or')),
        float(_neuron(x, w, f'{name}.layer1.nand')),
    ])
    return _neuron(hidden, w, f'{name}.layer2')


def hamming74_encode(d1, d2, d3, d4, w):
    """Hamming(7,4) encoder: 4 data bits -> 7 coded bits.

    Every output bit is produced by threshold neurons whose parameters are
    read from *w*.

    Each parity bit ``pK`` uses two chained XOR sub-networks:

    * ``pK.xor12`` / ``pK.xor13`` / ``pK.xor23`` is applied to the full 4-bit
      input ``[d1, d2, d3, d4]``.
    * ``pK.xor_final`` is applied to ``[stage-1 result, d4]``.

    Each data position ``c3, c5, c6, c7`` is a single neuron
    (``d1`` .. ``d4``) over the 4-bit input.

    With correctly trained weights the intended result is
    ``[p1, p2, d1, p3, d2, d3, d4]``, where ``p1 = d1^d2^d4``,
    ``p2 = d1^d3^d4`` and ``p3 = d2^d3^d4``.

    Args:
        d1, d2, d3, d4: Data bits; anything ``float()`` accepts (typically 0/1).
        w: Mapping from parameter name (e.g. ``'p1.xor12.layer1.or.weight'``)
            to tensor, as returned by ``load_model``.

    Returns:
        list[int]: ``[p1, p2, c3, p3, c5, c6, c7]``, each 0 or 1.

    Raises:
        KeyError: if *w* is a dict lacking one of the required parameters.
    """
    # NOTE(review): stage-1 and data-neuron weights are multiplied elementwise
    # with this 4-element vector, so they are presumably length-4 masks over
    # (d1..d4). Confirm against the model file.
    inp = torch.tensor([float(d1), float(d2), float(d3), float(d4)])

    def parity(bit, first_stage):
        # The stage-1 name (xor12/xor13/xor23) suggests which two data bits it
        # combines. The bits actually used are set by the stored weights.
        partial = _xor_net(inp, w, f'{bit}.{first_stage}')
        return _xor_net(torch.tensor([float(partial), float(d4)]), w, f'{bit}.xor_final')

    p1 = parity('p1', 'xor12')
    p2 = parity('p2', 'xor13')
    p3 = parity('p3', 'xor23')

    c3, c5, c6, c7 = (_neuron(inp, w, key) for key in ('d1', 'd2', 'd3', 'd4'))

    return [p1, p2, c3, p3, c5, c6, c7]
|
| |
|
if __name__ == '__main__':
    weights = load_model()

    def ref_encode(d1, d2, d3, d4):
        """Pure-Python reference encoding: [p1, p2, d1, p3, d2, d3, d4]."""
        return [d1 ^ d2 ^ d4, d1 ^ d3 ^ d4, d1, d2 ^ d3 ^ d4, d2, d3, d4]

    print('Hamming(7,4) Encoder')
    print('Input (d1d2d3d4) -> Output (c1c2c3c4c5c6c7)')

    # Exhaustively check all 16 data words; d1 is the least-significant bit
    # of the loop counter.
    n_ok = 0
    for word in range(16):
        bits = [(word >> k) & 1 for k in range(4)]
        got = hamming74_encode(*bits, weights)
        ok = got == ref_encode(*bits)
        n_ok += ok
        data_str = ''.join(map(str, bits))
        code_str = ''.join(map(str, got))
        print(f'{data_str} -> {code_str} {"OK" if ok else "FAIL"}')

    print(f'\n{n_ok}/16 correct')
|
| |
|