# File size: 6,494 Bytes
# Commit: 92bfe31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
from __future__ import annotations

import os
import sys
from unittest.mock import patch

import pytest

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from services import inference_client as inf_client
from services.inference_client import (
    _MODEL_PROFILES,
    get_current_runtime_config,
    get_model_for_task,
    is_sequential_model,
    model_supports_thinking,
    reset_runtime_overrides,
    set_runtime_model_override,
    set_runtime_model_profile,
)


# The exact key set every entry in _MODEL_PROFILES is expected to define;
# TestModelProfiles compares each profile's keys against it for equality.
REQUIRED_PROFILE_KEYS = {
    "INFERENCE_MODEL_ID",
    "INFERENCE_CHAT_MODEL_ID",
    "HF_QUIZ_MODEL_ID",
    "HF_RAG_MODEL_ID",
    "INFERENCE_LOCK_MODEL_ID",
}


class TestModelProfiles:
    """Static checks on the ``_MODEL_PROFILES`` table (no runtime state involved)."""

    def test_profiles_have_all_keys(self):
        """Every profile defines exactly REQUIRED_PROFILE_KEYS -- no more, no fewer."""
        # Guard against a vacuous pass: an empty table would skip the loop entirely.
        assert _MODEL_PROFILES, "No model profiles defined"
        for name, profile in _MODEL_PROFILES.items():
            keys = set(profile)
            # Report the concrete offending keys so a failure is actionable.
            missing = sorted(REQUIRED_PROFILE_KEYS - keys)
            extra = sorted(keys - REQUIRED_PROFILE_KEYS)
            assert not missing and not extra, (
                f"Profile '{name}' missing keys {missing}, extra keys {extra}"
            )

    def test_dev_uses_chat_model(self):
        """Every model ID in the ``dev`` profile names deepseek-chat."""
        dev = _MODEL_PROFILES["dev"]
        for key, value in dev.items():
            assert "deepseek-chat" in value, f"dev/{key} = {value}, expected deepseek-chat"

    def test_prod_chat_is_chat_model(self):
        """The ``prod`` chat slot uses deepseek-chat."""
        assert "deepseek-chat" in _MODEL_PROFILES["prod"]["INFERENCE_CHAT_MODEL_ID"]

    def test_prod_rag_is_reasoner(self):
        """The ``prod`` RAG slot uses deepseek-reasoner (unlike dev/budget)."""
        assert "deepseek-reasoner" in _MODEL_PROFILES["prod"]["HF_RAG_MODEL_ID"]

    def test_budget_uses_chat_model_everywhere(self):
        """Every model ID in the ``budget`` profile names deepseek-chat."""
        budget = _MODEL_PROFILES["budget"]
        for key, value in budget.items():
            assert "deepseek-chat" in value, f"budget/{key} = {value}"


class TestRuntimeOverrides:
    """Behaviour of the module-level runtime profile/override state."""

    def setup_method(self):
        # The override state lives at module level in inference_client and is
        # shared across tests, so every test starts and ends from a clean slate.
        reset_runtime_overrides()

    def teardown_method(self):
        reset_runtime_overrides()

    def test_set_profile_populates_overrides(self):
        set_runtime_model_profile("dev")
        assert inf_client._RUNTIME_PROFILE == "dev"
        assert inf_client._RUNTIME_OVERRIDES["INFERENCE_MODEL_ID"] == "deepseek-chat"
        assert inf_client._RUNTIME_OVERRIDES["INFERENCE_CHAT_MODEL_ID"] == "deepseek-chat"

    def test_set_profile_replaces_all_overrides(self):
        """Switching profiles must overwrite every value from the previous profile."""
        set_runtime_model_profile("dev")
        set_runtime_model_profile("prod")
        assert inf_client._RUNTIME_PROFILE == "prod"
        assert inf_client._RUNTIME_OVERRIDES["INFERENCE_CHAT_MODEL_ID"] == "deepseek-chat"
        assert inf_client._RUNTIME_OVERRIDES["INFERENCE_LOCK_MODEL_ID"] == "deepseek-chat"
        # Every dev value is deepseek-chat, so only a prod-specific value can
        # prove that the dev overrides were actually replaced.
        assert "deepseek-reasoner" in inf_client._RUNTIME_OVERRIDES["HF_RAG_MODEL_ID"]
        for key, value in _MODEL_PROFILES["prod"].items():
            assert inf_client._RUNTIME_OVERRIDES[key] == value, (
                f"{key} = {inf_client._RUNTIME_OVERRIDES[key]}, expected prod value {value}"
            )

    def test_set_profile_unknown_raises(self):
        with pytest.raises(ValueError, match="Unknown profile"):
            set_runtime_model_profile("nonexistent")

    def test_single_override_sets_key(self):
        set_runtime_model_override("HF_RAG_MODEL_ID", "custom/model")
        assert inf_client._RUNTIME_OVERRIDES["HF_RAG_MODEL_ID"] == "custom/model"

    def test_reset_clears_overrides(self):
        set_runtime_model_profile("dev")
        reset_runtime_overrides()
        assert inf_client._RUNTIME_PROFILE == ""
        assert inf_client._RUNTIME_OVERRIDES == {}

    def test_override_layers_on_profile(self):
        """A single-key override replaces only that key, leaving profile values intact."""
        set_runtime_model_profile("dev")
        set_runtime_model_override("HF_RAG_MODEL_ID", "custom/model")
        assert inf_client._RUNTIME_OVERRIDES["HF_RAG_MODEL_ID"] == "custom/model"
        assert inf_client._RUNTIME_OVERRIDES["INFERENCE_MODEL_ID"] == "deepseek-chat"


class TestGetCurrentRuntimeConfig:
    """get_current_runtime_config() reports the active profile and resolved model IDs."""

    def setup_method(self):
        # Runtime overrides are module-global; isolate each test from the others.
        reset_runtime_overrides()

    def teardown_method(self):
        reset_runtime_overrides()

    def test_returns_resolved_dict_with_all_keys(self):
        set_runtime_model_profile("dev")
        snapshot = get_current_runtime_config()
        assert snapshot["profile"] == "dev"
        resolved = snapshot["resolved"]
        for required in REQUIRED_PROFILE_KEYS:
            assert required in resolved, f"Missing {required}"

    def test_override_takes_priority_over_profile(self):
        set_runtime_model_profile("dev")
        set_runtime_model_override("INFERENCE_CHAT_MODEL_ID", "custom/chat")
        resolved = get_current_runtime_config()["resolved"]
        assert resolved["INFERENCE_CHAT_MODEL_ID"] == "custom/chat"


class TestGetModelForTask:
    """Task-to-model routing via get_model_for_task()."""

    def setup_method(self):
        # Start each test without leftover profile/override state.
        reset_runtime_overrides()

    def teardown_method(self):
        reset_runtime_overrides()

    @patch.dict(os.environ, {"INFERENCE_ENFORCE_LOCK_MODEL": "false"})
    def test_returns_profile_default_for_rag(self):
        """With the lock disabled, prod routes RAG lessons to the reasoner."""
        set_runtime_model_profile("prod")
        assert "deepseek-reasoner" in get_model_for_task("rag_lesson")

    @patch.dict(os.environ, {"INFERENCE_ENFORCE_LOCK_MODEL": "false"})
    def test_returns_profile_default_for_chat(self):
        """With the lock disabled, prod routes chat to the chat model."""
        set_runtime_model_profile("prod")
        assert "deepseek-chat" in get_model_for_task("chat")

    @patch.dict(os.environ, {"INFERENCE_ENFORCE_LOCK_MODEL": "false"})
    def test_returns_runtime_override_for_chat(self):
        """An explicit chat override is returned verbatim."""
        set_runtime_model_override("INFERENCE_CHAT_MODEL_ID", "custom/chat")
        chosen = get_model_for_task("chat")
        assert chosen == "custom/chat"

    @patch.dict(os.environ, {"INFERENCE_ENFORCE_LOCK_MODEL": "true"})
    def test_enforce_qwen_overrides_task(self):
        """With the lock enforced, a RAG task gets deepseek-chat instead of the reasoner.

        NOTE: the "qwen" in the name looks stale -- the expected model here is
        deepseek-chat. The name is kept so existing test selections keep working.
        """
        set_runtime_model_profile("prod")
        assert "deepseek-chat" in get_model_for_task("rag_lesson")


class TestIsSequentialModel:
    """is_sequential_model() flags the reasoner; an empty name falls back to the environment."""

    def setup_method(self):
        # Clear module-level overrides so they cannot leak between tests.
        reset_runtime_overrides()

    def teardown_method(self):
        reset_runtime_overrides()

    def test_reasoner_is_sequential(self):
        verdict = is_sequential_model("deepseek-reasoner")
        assert verdict is True

    def test_chat_is_not_sequential(self):
        verdict = is_sequential_model("deepseek-chat")
        assert verdict is False

    def test_empty_string_checks_env(self):
        # The ambient environment is not pinned here, so only the return type
        # can be asserted; the two tests below pin the env-driven outcomes.
        assert isinstance(is_sequential_model(""), bool)

    @patch.dict(os.environ, {"INFERENCE_MODEL_ID": "deepseek-reasoner"})
    def test_env_model_reasoner_is_sequential(self):
        verdict = is_sequential_model("")
        assert verdict is True

    @patch.dict(os.environ, {"INFERENCE_MODEL_ID": "deepseek-chat"})
    def test_env_model_chat_is_not_sequential(self):
        verdict = is_sequential_model("")
        assert verdict is False


class TestModelSupportsThinking:
    """model_supports_thinking() is True only for the reasoner among the models tried here."""

    def test_reasoner_supports_thinking(self):
        supports = model_supports_thinking("deepseek-reasoner")
        assert supports is True

    def test_chat_does_not_support_thinking(self):
        supports = model_supports_thinking("deepseek-chat")
        assert supports is False

    def test_unknown_does_not_support_thinking(self):
        # A model outside the DeepSeek family must not be treated as thinking-capable.
        supports = model_supports_thinking("meta-llama/Llama-3.1-8B-Instruct")
        assert supports is False