| """Tests for the core FmriEncoder model.""" |
|
|
| from unittest.mock import MagicMock |
|
|
| import pytest |
| import torch |
|
|
|
|
def _make_model(hidden=256, n_outputs=100, n_timesteps=10, modalities=None):
    """Construct a compact model from an ``FmriEncoder`` config for tests.

    All dropout settings are zeroed, the linear baseline is disabled, and the
    encoder config is a ``TransformerEncoder`` with ``depth=2`` and ``heads=4``.

    Args:
        hidden: Value forwarded as the config's ``hidden`` setting.
        n_outputs: Forwarded to ``build`` as ``n_outputs``.
        n_timesteps: Forwarded to ``build`` as ``n_output_timesteps``.
        modalities: Mapping of modality name to a 2-tuple, forwarded to
            ``build`` as ``feature_dims``. When ``None``, "text", "audio" and
            "video" are used, each mapped to ``(2, 32)``.

    Returns:
        Whatever ``FmriEncoder(...).build(...)`` returns.
    """
    # Fall back to the three-modality default before touching project imports.
    feature_dims = modalities
    if feature_dims is None:
        feature_dims = {name: (2, 32) for name in ("text", "audio", "video")}

    # Imported lazily so that merely collecting this module does not require
    # the project packages to be importable.
    from neuraltrain.models.transformer import TransformerEncoder

    from cortexlab.core.model import FmriEncoder

    settings = {
        "hidden": hidden,
        "max_seq_len": 128,
        "dropout": 0.0,
        "modality_dropout": 0.0,
        "temporal_dropout": 0.0,
        "linear_baseline": False,
        "encoder": TransformerEncoder(depth=2, heads=4),
    }
    return FmriEncoder(**settings).build(
        feature_dims=feature_dims,
        n_outputs=n_outputs,
        n_output_timesteps=n_timesteps,
    )
|
|
|
|
def _make_segments(n):
    """Return ``n`` placeholder ``Segment`` objects for a ``SegmentData`` batch.

    Segment ``i`` starts at ``float(i)``, lasts ``1.0`` and uses the timeline
    name ``"test"``.
    """
    from neuralset.segments import Segment

    dummy_segments = [
        Segment(start=float(index), duration=1.0, timeline="test")
        for index in range(n)
    ]
    return dummy_segments
|
|
|
|
def _make_batch(modalities, batch_size=2, seq_len=20):
    """Assemble a random ``SegmentData`` batch matching ``modalities``.

    Args:
        modalities: Mapping of modality name to ``(n_layers, feat_dim)``.
        batch_size: Number of samples (and dummy segments) in the batch.
        seq_len: Size of the last axis of every feature tensor.

    Returns:
        A ``SegmentData`` whose ``data`` holds one random tensor of shape
        ``(batch_size, n_layers, feat_dim, seq_len)`` per modality plus an
        all-zero integer ``"subject_id"`` tensor of length ``batch_size``.
    """
    from neuralset.dataloader import SegmentData

    # Tensors are drawn in the mapping's iteration order.
    tensors = {
        modality: torch.randn(batch_size, layers, width, seq_len)
        for modality, (layers, width) in modalities.items()
    }
    tensors["subject_id"] = torch.zeros(batch_size, dtype=torch.long)

    dummy_segments = _make_segments(batch_size)
    return SegmentData(data=tensors, segments=dummy_segments)
|
|
|
|
class TestFmriEncoderModel:
    """Shape-level checks on the forward pass of models built by ``FmriEncoder``.

    Every test uses a batch of 2 samples and a model built with
    ``n_outputs=100`` and ``n_output_timesteps=10``.
    """

    def test_forward_shape(self):
        """A default forward pass on text + audio returns shape (2, 100, 10)."""
        feature_dims = {"text": (2, 32), "audio": (2, 32)}
        encoder = _make_model(modalities=feature_dims)
        out = encoder(_make_batch(feature_dims))
        assert out.shape == (2, 100, 10), f"Expected (2, 100, 10), got {out.shape}"

    def test_forward_no_pool(self):
        """With ``pool_outputs=False`` only the first two output axes are checked."""
        feature_dims = {"text": (2, 32)}
        encoder = _make_model(modalities=feature_dims)
        unpooled = encoder(_make_batch(feature_dims), pool_outputs=False)
        # The size of the trailing axis is deliberately left unchecked here.
        assert unpooled.shape[0] == 2
        assert unpooled.shape[1] == 100

    def test_return_attn(self):
        """``return_attn=True`` yields a tuple of (output, list of attention maps)."""
        feature_dims = {"text": (2, 32)}
        encoder = _make_model(modalities=feature_dims)
        returned = encoder(_make_batch(feature_dims), return_attn=True)
        assert isinstance(returned, tuple)

        prediction, attention = returned
        assert prediction.shape == (2, 100, 10)
        assert isinstance(attention, list)

    def test_missing_modality_zeros(self):
        """A batch lacking a configured modality still produces the full shape.

        The model is built for text and audio, but the batch carries only
        text. NOTE(review): the test name suggests the model substitutes zeros
        for the absent modality; only the output shape is verified here.
        """
        encoder = _make_model(modalities={"text": (2, 32), "audio": (2, 32)})

        from neuralset.dataloader import SegmentData

        # "audio" is intentionally left out of the batch.
        text_only = {
            "text": torch.randn(2, 2, 32, 20),
            "subject_id": torch.zeros(2, dtype=torch.long),
        }
        partial_batch = SegmentData(data=text_only, segments=_make_segments(2))
        prediction = encoder(partial_batch)
        assert prediction.shape == (2, 100, 10)

    def test_modality_dropout_training(self):
        """Training-mode forward with ``modality_dropout=0.5`` keeps the shape."""
        feature_dims = {"text": (2, 32), "audio": (2, 32)}

        from neuraltrain.models.transformer import TransformerEncoder

        from cortexlab.core.model import FmriEncoder

        encoder = FmriEncoder(
            hidden=256,
            max_seq_len=128,
            modality_dropout=0.5,
            encoder=TransformerEncoder(depth=2, heads=4),
        ).build(feature_dims=feature_dims, n_outputs=100, n_output_timesteps=10)
        # Switch to training mode before running the forward pass.
        encoder.train()

        prediction = encoder(_make_batch(feature_dims))
        assert prediction.shape == (2, 100, 10)

    def test_linear_baseline(self):
        """A config with ``linear_baseline=True`` builds a model with the same output shape."""
        feature_dims = {"text": (2, 32)}

        from cortexlab.core.model import FmriEncoder

        baseline = FmriEncoder(hidden=256, linear_baseline=True).build(
            feature_dims=feature_dims,
            n_outputs=100,
            n_output_timesteps=10,
        )
        prediction = baseline(_make_batch(feature_dims))
        assert prediction.shape == (2, 100, 10)
|
|