radinplaid commited on
Commit
b6b0c93
·
1 Parent(s): e732115

Initial commit

Browse files
Dockerfile ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM nvidia/cuda:12.6.3-cudnn-runtime-ubuntu24.04

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    python3-pip \
    python3 \
    && rm -rf /var/lib/apt/lists/*

# Create a non-root user for security
RUN useradd -m -u 1001 user

ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

# Copy the source tree ONCE, owned by the non-root user.
# (Previously the tree was copied twice -- to /app and to $HOME/app --
# which doubled the image layers without changing runtime behavior.)
WORKDIR $HOME/app
COPY --chown=user . $HOME/app

# Install the package and dependencies (runs as root, before USER is set).
# This also installs the quickmt cli scripts.
RUN pip install --break-system-packages --no-cache-dir $HOME/app/

# Expose the default FastAPI port
EXPOSE 7860

USER user

# Hf Spaces expect the app on port 7860 usually
# We override the port via env var or CLI arg
CMD ["uvicorn", "quickmt.rest_server:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,65 @@
1
- ---
2
- title: Quickmt Gui
3
- emoji: 🐨
4
- colorFrom: green
5
- colorTo: red
6
- sdk: docker
7
- pinned: false
8
- short_description: 'QuickMT Web Application '
9
- ---
10
-
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # `quickmt` Neural Machine Translation Inference Library
2
+
3
+ ## REST Server Features
4
+
5
+ - **Dynamic Batching**: Multiple concurrent HTTP requests are pooled together to maximize GPU utilization.
6
+ - **Multi-Model Support**: Requests are routed to specific models based on `src_lang` and `tgt_lang`.
7
+ - **LRU Cache**: Automatically loads and unloads models based on usage to manage memory.
8
+
9
+
10
+ ## Installation
11
+
12
+ ```bash
13
+ pip install .
14
+ ```
15
+
16
+ ## Running the Web Application
17
+
18
+ ```bash
19
+ export MAX_LOADED_MODELS=3
20
+ export MAX_BATCH_SIZE=32
21
+ export DEVICE=cuda # or cpu
22
+ export COMPUTE_TYPE=int8 # default, auto, int8, float16, etc.
23
+ quickmt-gui
24
+ ```
25
+
26
+
27
+
28
+ ## Running the REST Server
29
+
30
+ ```bash
31
+ export MAX_LOADED_MODELS=3
32
+ export MAX_BATCH_SIZE=32
33
+ export DEVICE=cuda # or cpu
34
+ export COMPUTE_TYPE=int8 # default, auto, int8, float16, etc.
35
+ quickmt-serve
36
+ ```
37
+
38
+
39
+ ## API Usage
40
+
41
+ ### Translate
42
+ ```bash
43
+ curl -X POST http://localhost:8000/translate \
44
+ -H "Content-Type: application/json" \
45
+ -d '{"src":"Hello world","src_lang":null,"tgt_lang":"fr","beam_size":2,"patience":1,"length_penalty":1,"coverage_penalty":0,"repetition_penalty":1}'
46
+ ```
47
+
48
+ Returns:
49
+ ```json
50
+ {
51
+ "translation":"Bonjour tout le monde !",
52
+ "src_lang":"en",
53
+ "src_lang_score":0.16532786190509796,
54
+ "tgt_lang":"fr",
55
+ "processing_time":2.2334513664245605,
56
+ "model_used":"quickmt/quickmt-en-fr"
57
+ }
58
+ ```
59
+
60
+ ## Load Testing with Locust
61
+ To simulate a multi-user load:
62
+ ```bash
63
+ locust -f locustfile.py --host http://localhost:8000
64
+ ```
65
+ Then open http://localhost:8089 in your browser.
locustfile.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from locust import FastHttpUser, task, between
3
+
4
+
5
class TranslationUser(FastHttpUser):
    """Simulated user that load-tests the quickmt translation REST API.

    Exercises /translate (single string, list, and auto-detect modes),
    /identify-language and /health. Models are discovered once per user
    via /models in :meth:`on_start`.
    """

    # No think time between tasks: generate maximum request pressure.
    wait_time = between(0, 0)

    # Sample sentences for translation and identification
    sample_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "Can we translate this correctly and quickly?",
        "هذا نص تجريبي باللغة العربية.",  # Arabic
        "الذكاء الاصطناعي هو المستقبل.",  # Arabic (AI is the future)
        "أحب تعلم لغات جديدة.",  # Arabic (I love learning new languages)
        "这是一段中文测试文本。",  # Chinese
        "人工智能正在改变世界。",  # Chinese (AI is changing the world)
        "今天天气真好,去公园散步。",  # Chinese (Weather is nice, let's walk)
        "Bonjour, comment allez-vous ?",  # French
        "L'intelligence artificielle transforme notre vie quotidienne.",  # French (AI transforms daily life)
        "Ceci est un exemple de phrase en français.",  # French
    ]

    def on_start(self):
        """Discover available models on startup."""
        try:
            response = self.client.get("/models")
            if response.status_code == 200:
                self.available_models = response.json().get("models", [])
                if not self.available_models:
                    print("No models found. Load test might fail.")
            else:
                self.available_models = []
        except Exception as e:
            print(f"Error discovering models: {e}")
            self.available_models = []

    def get_random_model(self):
        """
        Return a model, favoring the first 3 (hot set) 99.99% of the time,
        and others (cold set) 0.01% of the time to trigger LRU eviction.

        NOTE: the hot/cold split is 0.9999 / 0.0001 (the previous docstring
        said 99% / 1%, which did not match the implemented probabilities).
        """
        if not self.available_models:
            return None

        # If we have 4 or more models, we can simulate eviction cycles
        if len(self.available_models) >= 4:
            # 99.99% chance to pick from the first 3
            if random.random() < 0.9999:
                return random.choice(self.available_models[:3])
            else:
                # 0.01% chance to pick from the rest
                return random.choice(self.available_models[3:])

        return random.choice(self.available_models)

    @task(1)
    def translate_single(self):
        """Translate one sentence with an explicit source language."""
        model = self.get_random_model()
        if not model:
            return

        self.client.post(
            "/translate",
            json={
                # Random suffix defeats any server-side response caching.
                "src": random.choice(self.sample_texts) + str(random.random()),
                "src_lang": model["src_lang"],
                "tgt_lang": model["tgt_lang"],
                "beam_size": 2,
            },
            name="/translate [single, manual]",
        )

    @task(1)
    def translate_auto_detect(self):
        """Translate without specifying src_lang to trigger LangID."""
        ret = self.client.post(
            "/translate",
            json={
                "src": random.choice(self.sample_texts) + str(random.random()),
                "tgt_lang": "en",
                "beam_size": 2,
            },
            name="/translate [single, auto-detect]",
        )
        ret_json = ret.json()
        assert "src_lang" in ret_json
        assert "tgt_lang" in ret_json
        assert "translation" in ret_json
        assert "src_lang_score" in ret_json
        assert "model_used" in ret_json
        assert ret_json["tgt_lang"] == "en"

    @task(1)
    def translate_list(self):
        """Translate a small batch of sentences in one request."""
        model = self.get_random_model()
        if not model:
            return

        num_sentences = random.randint(2, 5)
        texts = random.sample(self.sample_texts, num_sentences)
        texts = [i + str(random.random()) for i in texts]
        ret = self.client.post(
            "/translate",
            json={
                "src": texts,
                "src_lang": model["src_lang"],
                "tgt_lang": model["tgt_lang"],
                "beam_size": 2,
            },
            name="/translate [list, manual]",
        )
        ret_json = ret.json()
        for i in ret_json["src_lang"]:
            assert i == model["src_lang"]
        assert ret_json["tgt_lang"] == model["tgt_lang"]
        assert len(ret_json["translation"]) == num_sentences

    @task(1)
    def identify_language(self):
        """Directly benchmark the identification endpoint."""
        num_sentences = random.randint(1, 4)
        texts = random.sample(self.sample_texts, num_sentences)
        src = texts[0] if num_sentences == 1 else texts

        self.client.post(
            "/identify-language", json={"src": src}, name="/identify-language"
        )

    @task(1)
    def health_check(self):
        """Lightweight liveness probe."""
        self.client.get("/health", name="/health")
pyproject.toml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "quickmt"
7
+ version = "0.1.0"
8
+ description = "A fast, multi-model translation API based on CTranslate2 and FastAPI"
9
+ readme = "README.md"
10
+ requires-python = ">=3.9"
11
+ license = {text = "MIT"}
12
+ authors = [
13
+ {name = "QuickMT Team", email = "hello@quickmt.ai"},
14
+ ]
15
+ dependencies = [
16
+ "blingfire",
17
+ "cachetools",
18
+ "fastapi",
19
+ "uvicorn[standard]",
20
+ "ctranslate2>=3.20.0",
21
+ "sentencepiece",
22
+ "huggingface-hub",
23
+ "fasttext-wheel",
24
+ "orjson",
25
+ "uvloop",
26
+ "httptools",
27
+ "pydantic",
28
+ "pydantic-settings"
29
+ ]
30
+
31
+ [project.optional-dependencies]
32
+ dev = [
33
+ "pytest",
34
+ "pytest-asyncio",
35
+ "httpx",
36
+ "sacrebleu",
37
+ "locust"
38
+ ]
39
+
40
+ [project.scripts]
41
+ quickmt-serve = "quickmt.rest_server:start"
42
+ quickmt-gui = "quickmt.rest_server:start_gui"
43
+
44
+ [tool.hatch.build.targets.wheel]
45
+ packages = ["quickmt"]
pytest.ini ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [pytest]
2
+ asyncio_mode = auto
3
+ asyncio_default_fixture_loop_scope = function
quickmt/__init__.py ADDED
File without changes
quickmt/gui/static/app.js ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
// QuickMT GUI controller: wires up theme/sidebar/settings persistence,
// language discovery, debounced line-by-line translation, model listing,
// and navigation. Runs once the DOM is ready.
document.addEventListener('DOMContentLoaded', () => {
    // Elements
    const srcText = document.getElementById('src-text');
    const tgtText = document.getElementById('tgt-text');
    const srcLangSelect = document.getElementById('src-lang-select');
    const tgtLangSelect = document.getElementById('tgt-lang-select');
    const charCount = document.getElementById('char-count');
    const timingInfo = document.getElementById('timing-info');
    const loader = document.getElementById('translation-loader');
    const detectedBadge = document.getElementById('detected-badge');
    const navLinks = document.querySelectorAll('.nav-links a');
    const views = document.querySelectorAll('.view');
    const healthIndicator = document.getElementById('health-indicator');
    const modelsList = document.getElementById('models-list');
    const copyBtn = document.getElementById('copy-btn');
    const themeToggle = document.getElementById('theme-toggle');
    const sidebarToggle = document.getElementById('sidebar-toggle');
    const sidebar = document.querySelector('.sidebar');

    let debounceTimer;
    let languages = {};
    let activeController = null;

    let settings = {
        beam_size: 2,
        patience: 1,
        length_penalty: 1.0,
        coverage_penalty: 0.0,
        repetition_penalty: 1.0
    };

    // 0. Theme Logic
    function initTheme() {
        const savedTheme = localStorage.getItem('theme') || 'dark';
        if (savedTheme === 'light') {
            document.body.classList.add('light-mode');
            updateThemeUI(true);
        }
    }

    function updateThemeUI(isLight) {
        const text = themeToggle.querySelector('.mode-text');
        text.textContent = isLight ? 'Light Mode' : 'Dark Mode';
    }

    themeToggle.addEventListener('click', () => {
        const isLight = document.body.classList.toggle('light-mode');
        localStorage.setItem('theme', isLight ? 'light' : 'dark');
        updateThemeUI(isLight);
    });

    // 0.1 Sidebar Logic
    function initSidebar() {
        const isCollapsed = localStorage.getItem('sidebar-collapsed') === 'true';
        if (isCollapsed) sidebar.classList.add('collapsed');
    }

    sidebarToggle.addEventListener('click', () => {
        const isCollapsed = sidebar.classList.toggle('collapsed');
        localStorage.setItem('sidebar-collapsed', isCollapsed);
    });

    // 0.2 Inference Settings Logic
    function initSettings() {
        const saved = localStorage.getItem('inference-settings');
        if (saved) {
            try {
                const parsed = JSON.parse(saved);
                settings = { ...settings, ...parsed };
            } catch (e) { console.error("Failed to parse settings", e); }
        }
        updateSettingsUI();
    }

    function updateSettingsUI() {
        // Sync values to inputs
        Object.keys(settings).forEach(key => {
            const input = document.getElementById(`setting-${key.replace('_', '-')}`);
            if (input) {
                input.value = settings[key];
                const valDisplay = input.nextElementSibling;
                if (valDisplay && valDisplay.classList.contains('setting-val')) {
                    valDisplay.textContent = settings[key];
                }
            }
        });
    }

    function saveSettings() {
        localStorage.setItem('inference-settings', JSON.stringify(settings));
    }

    // Add listeners to all settings inputs
    const settingsInputs = [
        'setting-beam-size', 'setting-patience', 'setting-length-penalty',
        'setting-coverage-penalty', 'setting-repetition-penalty'
    ];

    settingsInputs.forEach(id => {
        const input = document.getElementById(id);
        const key = id.replace('setting-', '').replace(/-/g, '_');

        input.addEventListener('input', () => {
            let val = parseFloat(input.value);
            if (id === 'setting-beam-size' || id === 'setting-patience') val = parseInt(input.value);

            settings[key] = val;

            // Enforcement: patience <= beam_size
            if (id === 'setting-beam-size') {
                if (settings.patience > settings.beam_size) {
                    settings.patience = settings.beam_size;
                    const patienceInput = document.getElementById('setting-patience');
                    patienceInput.value = settings.patience;
                    patienceInput.nextElementSibling.textContent = settings.patience;
                }
                // Update patience max slider to match beam_size for better UX?
                // User said "maximum 10", so let's stick to that but cap the value.
            } else if (id === 'setting-patience') {
                if (val > settings.beam_size) {
                    val = settings.beam_size;
                    input.value = val;
                    settings.patience = val;
                }
            }

            const valDisplay = input.nextElementSibling;
            if (valDisplay && valDisplay.classList.contains('setting-val')) {
                valDisplay.textContent = val;
            }
            saveSettings();
        });
    });

    document.getElementById('reset-settings').addEventListener('click', () => {
        settings = {
            beam_size: 2,
            patience: 1,
            length_penalty: 1.0,
            coverage_penalty: 0.0,
            repetition_penalty: 1.0
        };
        updateSettingsUI();
        saveSettings();
    });

    // 1. Fetch available languages and populate selects
    async function initLanguages() {
        try {
            const res = await fetch('/api/languages');
            if (res.ok) {
                languages = await res.json();
                populateSelects();
                updateHealth(true);
            }
        } catch (e) {
            console.error("Failed to load languages", e);
            updateHealth(false);
        }
    }

    function populateSelects() {
        const currentSrc = srcLangSelect.value;
        // Keep only the first "Auto-detect" option
        srcLangSelect.innerHTML = '<option value="">Auto-detect</option>';

        const sources = Object.keys(languages);

        // Populate Source Languages
        sources.forEach(lang => {
            const opt = document.createElement('option');
            opt.value = lang;
            opt.textContent = lang.toUpperCase();
            srcLangSelect.appendChild(opt);
        });

        // Restore selection if it still exists
        if (currentSrc && languages[currentSrc]) {
            srcLangSelect.value = currentSrc;
        }

        // Trigger target population for default selection
        updateTargetOptions();
    }

    function updateTargetOptions() {
        const src = srcLangSelect.value;
        const currentTgt = tgtLangSelect.value;

        // Clear targets
        tgtLangSelect.innerHTML = '';

        let availableTgts = [];
        if (src) {
            availableTgts = languages[src] || [];
        } else {
            // If auto-detect, union of all targets
            const allTgts = new Set();
            Object.values(languages).forEach(list => list.forEach(l => allTgts.add(l)));
            availableTgts = Array.from(allTgts).sort();
        }

        availableTgts.forEach(lang => {
            const opt = document.createElement('option');
            opt.value = lang;
            opt.textContent = lang.toUpperCase();
            if (lang === currentTgt || (availableTgts.length === 1)) opt.selected = true;
            tgtLangSelect.appendChild(opt);
        });
    }

    // 2. Translation Logic
    async function performTranslation() {
        const fullText = srcText.value;
        if (!fullText.trim()) {
            tgtText.value = '';
            timingInfo.textContent = 'Ready';
            detectedBadge.classList.remove('visible');
            return;
        }

        // Abort previous requests
        if (activeController) activeController.abort();
        activeController = new AbortController();
        const { signal } = activeController;

        const lines = fullText.split('\n');
        const translatedLines = new Array(lines.length).fill('');
        let srcLang = srcLangSelect.value || null;
        const tgtLang = tgtLangSelect.value;

        loader.classList.remove('hidden');
        let completedLines = 0;
        let totalToTranslate = lines.filter(l => l.trim()).length;

        try {
            // Step 1: If auto-detect mode, detect language for entire input first
            if (!srcLang && fullText.trim()) {
                const detectResponse = await fetch('/api/identify-language', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({
                        src: fullText,
                        k: 1,
                        threshold: 0.0
                    }),
                    signal
                });

                if (detectResponse.ok) {
                    const detectData = await detectResponse.json();
                    // Get the detected language from the response
                    if (detectData.results && detectData.results.length > 0) {
                        srcLang = detectData.results[0].lang;
                        detectedBadge.textContent = `Detected: ${srcLang.toUpperCase()}`;
                        detectedBadge.classList.add('visible');
                    }
                }
            }

            // Step 2: Translate all lines with known source language
            const updateTgtUI = () => {
                tgtText.value = translatedLines.join('\n');
            };

            const translateParagraph = async (line, index) => {
                if (!line.trim()) {
                    translatedLines[index] = line;
                    updateTgtUI();
                    return;
                }

                try {
                    const response = await fetch('/api/translate', {
                        method: 'POST',
                        headers: { 'Content-Type': 'application/json' },
                        body: JSON.stringify({
                            src: line,
                            src_lang: srcLang, // Now we always have a source language
                            tgt_lang: tgtLang,
                            ...settings
                        }),
                        signal
                    });

                    if (response.ok) {
                        const data = await response.json();
                        translatedLines[index] = data.translation;
                    } else {
                        // BUGFIX: a non-OK response used to leave the line blank
                        // and never count it, so the loader spun forever.
                        translatedLines[index] = `[[Error: ${line}]]`;
                    }

                    completedLines++;
                    updateTgtUI();
                    timingInfo.textContent = `Translating: ${Math.round((completedLines / totalToTranslate) * 100)}%`;
                } catch (e) {
                    if (e.name !== 'AbortError') {
                        console.error("Line translation error", e);
                        translatedLines[index] = `[[Error: ${line}]]`;
                        // BUGFIX: count failed lines too, otherwise
                        // completedLines never reaches totalToTranslate and the
                        // loader is never hidden. Aborted requests belong to a
                        // superseded run and are deliberately NOT counted.
                        completedLines++;
                        updateTgtUI();
                    }
                } finally {
                    if (completedLines === totalToTranslate) {
                        loader.classList.add('hidden');
                        timingInfo.textContent = 'Done';
                    }
                }
            };

            // Fire all translation requests in parallel
            lines.forEach((line, i) => translateParagraph(line, i));

        } catch (e) {
            if (e.name !== 'AbortError') {
                console.error("Translation error", e);
                loader.classList.add('hidden');
                timingInfo.textContent = 'Error';
            }
        }
    }

    // 3. Models View
    async function fetchModels() {
        try {
            const res = await fetch('/api/models');
            const data = await res.json();

            modelsList.innerHTML = '';

            // Use DocumentFragment for better performance
            const fragment = document.createDocumentFragment();

            data.models.forEach(m => {
                const card = document.createElement('div');
                card.className = 'model-card';
                card.innerHTML = `
                    <div class="model-lang-pair">
                        <span>${m.src_lang.toUpperCase()}</span>
                        <span>→</span>
                        <span>${m.tgt_lang.toUpperCase()}</span>
                    </div>
                    <div class="model-id">${m.model_id}</div>
                    ${m.loaded ? '<span class="loaded-badge">Currently Loaded</span>' : ''}
                `;
                fragment.appendChild(card);
            });

            // Single DOM update instead of multiple
            modelsList.appendChild(fragment);
        } catch (e) {
            modelsList.innerHTML = '<p>Error loading models</p>';
        }
    }

    // 4. UI Helpers
    function updateHealth(isOnline) {
        if (isOnline) {
            healthIndicator.className = 'status-pill status-online';
            healthIndicator.querySelector('.status-text').textContent = 'Online';
        } else {
            healthIndicator.className = 'status-pill status-loading';
            healthIndicator.querySelector('.status-text').textContent = 'Offline';
        }
    }

    // Event Listeners
    srcText.addEventListener('input', () => {
        charCount.textContent = `${srcText.value.length} characters`;
        clearTimeout(debounceTimer);
        debounceTimer = setTimeout(performTranslation, 250);
    });

    srcLangSelect.addEventListener('change', () => {
        updateTargetOptions();
        performTranslation();
    });

    tgtLangSelect.addEventListener('change', performTranslation);

    copyBtn.addEventListener('click', () => {
        navigator.clipboard.writeText(tgtText.value);
        const originalText = copyBtn.textContent;
        copyBtn.textContent = 'Copied!';
        setTimeout(() => copyBtn.textContent = originalText, 2000);
    });

    // Navigation
    navLinks.forEach(link => {
        link.addEventListener('click', (e) => {
            e.preventDefault();
            const targetId = link.getAttribute('href').substring(1);

            navLinks.forEach(l => l.parentElement.classList.remove('active'));
            link.parentElement.classList.add('active');

            views.forEach(v => {
                v.classList.remove('active');
                if (v.id === `${targetId}-view`) v.classList.add('active');
            });

            if (targetId === 'models') fetchModels();
        });
    });

    // Start
    initTheme();
    initSidebar();
    initSettings();
    initLanguages();
    setInterval(initLanguages, 10000); // Pulse health check
});
quickmt/gui/static/index.html ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <meta charset="UTF-8">
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
+ <title>QuickMT Machine Translation</title>
8
+ <link rel="preconnect" href="https://fonts.googleapis.com">
9
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
10
+ <link
11
+ href="https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600;700&family=Inter:wght@300;400;500;600&display=swap"
12
+ rel="stylesheet">
13
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.1/css/all.min.css">
14
+ <link rel="stylesheet" href="style.css">
15
+ </head>
16
+
17
+ <body>
18
+ <div class="bg-blur"></div>
19
+
20
+ <div class="top-nav-links">
21
+ <a href="https://huggingface.co/quickmt" target="_blank" title="Hugging Face Models" class="glass-btn">
22
+ <span class="btn-icon">🤗</span>
23
+ </a>
24
+ <a href="https://github.com/quickmt/quickmt" target="_blank" title="GitHub Repository"
25
+ class="glass-btn icon-only">
26
+ <span class="btn-icon"><i class="fa-brands fa-github"></i></span>
27
+ </a>
28
+ </div>
29
+
30
+ <main class="app-container">
31
+ <!-- Sidebar Navigation -->
32
+ <nav class="sidebar glass">
33
+ <div class="logo">
34
+ <div class="logo-icon">Q</div>
35
+ <span>QuickMT</span>
36
+ </div>
37
+ <ul class="nav-links">
38
+ <li class="active">
39
+ <a href="#translate">
40
+ <span class="nav-icon">🔁</span>
41
+ <span class="nav-text">Translate</span>
42
+ </a>
43
+ </li>
44
+ <li>
45
+ <a href="#models">
46
+ <span class="nav-icon">🧩</span>
47
+ <span class="nav-text">Models</span>
48
+ </a>
49
+ </li>
50
+ <li>
51
+ <a href="#settings">
52
+ <span class="nav-icon">⚙️</span>
53
+ <span class="nav-text">Settings</span>
54
+ </a>
55
+ </li>
56
+ </ul>
57
+ <div class="sidebar-footer">
58
+ <div id="health-indicator" class="status-pill status-loading">
59
+ <span class="dot"></span>
60
+ <span class="status-text">Connecting...</span>
61
+ </div>
62
+ <div class="theme-toggle-container">
63
+ <button id="theme-toggle" title="Toggle Light/Dark Mode">
64
+ <span class="mode-icon">◑</span>
65
+ <span class="mode-text">Dark Mode</span>
66
+ </button>
67
+ </div>
68
+ <button id="sidebar-toggle" class="icon-btn collapse-btn" title="Toggle Sidebar">«</button>
69
+ </div>
70
+ </nav>
71
+
72
+ <!-- Main Content -->
73
+ <section class="content">
74
+ <!-- Translate View -->
75
+ <div id="translate-view" class="view active">
76
+ <header class="view-header">
77
+ <h1>QuickMT Neural Machine Translation</h1>
78
+ </header>
79
+
80
+ <div class="translation-grid">
81
+ <!-- Source Column -->
82
+ <div class="card glass translation-card">
83
+ <div class="card-header">
84
+ <div class="lang-group">
85
+ <span class="lang-label">From</span>
86
+ <select id="src-lang-select" class="lang-select">
87
+ <option value="">Auto-detect</option>
88
+ </select>
89
+ </div>
90
+ <div id="detected-badge" class="detected-badge"></div>
91
+ </div>
92
+ <div class="card-body">
93
+ <textarea id="src-text" placeholder="Enter text to translate..." autofocus></textarea>
94
+ </div>
95
+ <div class="card-footer">
96
+ <span id="char-count">0 characters</span>
97
+ </div>
98
+ </div>
99
+
100
+ <!-- Target Column -->
101
+ <div class="card glass translation-card target-card">
102
+ <div class="card-header">
103
+ <div class="lang-group">
104
+ <span class="lang-label">To</span>
105
+ <select id="tgt-lang-select" class="lang-select">
106
+ <option value="en">English</option>
107
+ </select>
108
+ </div>
109
+ </div>
110
+ <div class="card-body">
111
+ <textarea id="tgt-text" readonly placeholder="Translation will appear here..."></textarea>
112
+ <div id="translation-loader" class="loader-overlay hidden">
113
+ <div class="spinner"></div>
114
+ </div>
115
+ </div>
116
+ <div class="card-footer">
117
+ <span id="timing-info">Ready</span>
118
+ <button id="copy-btn" class="action-btn">Copy</button>
119
+ </div>
120
+ </div>
121
+ </div>
122
+ </div>
123
+
124
+ <!-- Models View -->
125
+ <div id="models-view" class="view">
126
+ <header class="view-header">
127
+ <h1>Available Models</h1>
128
+ <p>Browse models from the quickmt Hugging Face collection</p>
129
+ </header>
130
+ <div id="models-list" class="models-grid">
131
+ <!-- Model cards will be injected here -->
132
+ </div>
133
+ </div>
134
+ <!-- Settings View -->
135
+ <div id="settings-view" class="view">
136
+ <header class="view-header">
137
+ <h1>Inference Settings</h1>
138
+ <p>Fine-tune the translation engine for your needs</p>
139
+ </header>
140
+
141
+ <div class="settings-container glass">
142
+ <div class="settings-grid">
143
+ <!-- Beam Size -->
144
+ <div class="setting-item">
145
+ <div class="setting-info">
146
+ <label>Beam Size</label>
147
+ <span class="setting-desc">Number of hypotheses to explore (1-10)</span>
148
+ </div>
149
+ <div class="setting-control">
150
+ <input type="range" id="setting-beam-size" min="1" max="10" step="1" value="2">
151
+ <span class="setting-val">2</span>
152
+ </div>
153
+ </div>
154
+
155
+ <!-- Patience -->
156
+ <div class="setting-item">
157
+ <div class="setting-info">
158
+ <label>Patience</label>
159
+ <span class="setting-desc">Stopping criterion factor (1-10)</span>
160
+ </div>
161
+ <div class="setting-control">
162
+ <input type="range" id="setting-patience" min="1" max="10" step="1" value="1">
163
+ <span class="setting-val">1</span>
164
+ </div>
165
+ </div>
166
+
167
+ <!-- Length Penalty -->
168
+ <div class="setting-item">
169
+ <div class="setting-info">
170
+ <label>Length Penalty</label>
171
+ <span class="setting-desc">Favour longer or shorter sentences (default 1.0)</span>
172
+ </div>
173
+ <div class="setting-control">
174
+ <input type="number" id="setting-length-penalty" step="0.1" value="1.0">
175
+ </div>
176
+ </div>
177
+
178
+ <!-- Coverage Penalty -->
179
+ <div class="setting-item">
180
+ <div class="setting-info">
181
+ <label>Coverage Penalty</label>
182
+ <span class="setting-desc">Ensure all source words are translated (default 0.0)</span>
183
+ </div>
184
+ <div class="setting-control">
185
+ <input type="number" id="setting-coverage-penalty" step="0.1" value="0.0">
186
+ </div>
187
+ </div>
188
+
189
+ <!-- Repetition Penalty -->
190
+ <div class="setting-item">
191
+ <div class="setting-info">
192
+ <label>Repetition Penalty</label>
193
+ <span class="setting-desc">Prevent repeating words (default 1.0)</span>
194
+ </div>
195
+ <div class="setting-control">
196
+ <input type="number" id="setting-repetition-penalty" step="0.1" value="1.0">
197
+ </div>
198
+ </div>
199
+ </div>
200
+
201
+ <div class="settings-actions">
202
+ <button id="reset-settings" class="action-btn secondary">Reset to Defaults</button>
203
+ </div>
204
+ </div>
205
+ </div>
206
+ </section>
207
+ </main>
208
+
209
+ <script src="app.js"></script>
210
+ </body>
211
+
212
+ </html>
quickmt/gui/static/style.css ADDED
@@ -0,0 +1,668 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ :root {
2
+ --primary: #6366f1;
3
+ --primary-glow: rgba(99, 102, 241, 0.5);
4
+ --bg-gradient: linear-gradient(135deg, #0f172a 0%, #1e1b4b 100%);
5
+ --glass-bg: rgba(255, 255, 255, 0.03);
6
+ --glass-border: rgba(255, 255, 255, 0.1);
7
+ --text-main: #f8fafc;
8
+ --text-muted: #94a3b8;
9
+ --card-shadow: 0 8px 32px 0 rgba(0, 0, 0, 0.37);
10
+ --transition: none;
11
+ --sidebar-active-bg: rgba(255, 255, 255, 0.05);
12
+ --input-bg: rgba(255, 255, 255, 0.05);
13
+ --btn-hover-bg: rgba(255, 255, 255, 0.1);
14
+ }
15
+
16
+ body.light-mode {
17
+ --bg-gradient: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
18
+ --glass-bg: rgba(255, 255, 255, 0.7);
19
+ --glass-border: rgba(99, 102, 241, 0.1);
20
+ --text-main: #1e293b;
21
+ --text-muted: #64748b;
22
+ --card-shadow: 0 8px 32px 0 rgba(31, 38, 135, 0.07);
23
+ --sidebar-active-bg: rgba(99, 102, 241, 0.05);
24
+ --input-bg: rgba(0, 0, 0, 0.02);
25
+ --btn-hover-bg: rgba(99, 102, 241, 0.1);
26
+ }
27
+
28
+ * {
29
+ margin: 0;
30
+ padding: 0;
31
+ box-sizing: border-box;
32
+ font-family: 'Inter', sans-serif;
33
+ }
34
+
35
+ h1,
36
+ h2,
37
+ h3,
38
+ .logo {
39
+ font-family: 'Outfit', sans-serif;
40
+ }
41
+
42
+ body {
43
+ background: var(--bg-gradient);
44
+ color: var(--text-main);
45
+ min-height: 100vh;
46
+ overflow: hidden;
47
+ }
48
+
49
+ .bg-blur {
50
+ position: fixed;
51
+ top: 0;
52
+ left: 0;
53
+ width: 100%;
54
+ height: 100%;
55
+ z-index: -1;
56
+ background: radial-gradient(circle at 20% 30%, rgba(99, 102, 241, 0.15) 0%, transparent 40%),
57
+ radial-gradient(circle at 80% 70%, rgba(168, 85, 247, 0.15) 0%, transparent 40%);
58
+ }
59
+
60
+ .top-nav-links {
61
+ position: fixed;
62
+ top: 1.5rem;
63
+ right: 2.5rem;
64
+ display: flex;
65
+ gap: 0.75rem;
66
+ z-index: 100;
67
+ }
68
+
69
+ .glass-btn {
70
+ display: flex;
71
+ align-items: center;
72
+ justify-content: center;
73
+ padding: 0.6rem 1rem;
74
+ min-width: 44px;
75
+ height: 44px;
76
+ background: var(--glass-bg);
77
+ backdrop-filter: blur(12px);
78
+ -webkit-backdrop-filter: blur(12px);
79
+ border: 1px solid var(--glass-border);
80
+ border-radius: 0.75rem;
81
+ color: var(--text-main);
82
+ text-decoration: none;
83
+ font-size: 0.85rem;
84
+ font-weight: 600;
85
+ transition: none;
86
+ }
87
+
88
+ .glass-btn.icon-only {
89
+ padding: 0;
90
+ width: 44px;
91
+ }
92
+
93
+ .glass-btn .btn-icon i {
94
+ font-size: 1.25rem;
95
+ }
96
+
97
+ .glass-btn:hover {
98
+ background: var(--btn-hover-bg);
99
+ }
100
+
101
+ .glass-btn .btn-icon {
102
+ display: flex;
103
+ align-items: center;
104
+ gap: 0.5rem;
105
+ }
106
+
107
+ .app-container {
108
+ display: flex;
109
+ height: 100vh;
110
+ padding: 1.5rem;
111
+ gap: 1.5rem;
112
+ }
113
+
114
+ /* Glass Effect */
115
+ .glass {
116
+ background: var(--glass-bg);
117
+ backdrop-filter: blur(12px);
118
+ -webkit-backdrop-filter: blur(12px);
119
+ border: 1px solid var(--glass-border);
120
+ border-radius: 1.25rem;
121
+ }
122
+
123
+ /* Sidebar */
124
+ .sidebar {
125
+ width: 280px;
126
+ display: flex;
127
+ flex-direction: column;
128
+ padding: 2rem;
129
+ transition: none;
130
+ overflow: hidden;
131
+ }
132
+
133
+ .sidebar.collapsed {
134
+ width: 90px;
135
+ padding: 2rem 1.25rem;
136
+ }
137
+
138
+ .logo {
139
+ display: flex;
140
+ align-items: center;
141
+ gap: 1rem;
142
+ font-size: 1.5rem;
143
+ font-weight: 700;
144
+ margin-bottom: 3rem;
145
+ position: relative;
146
+ }
147
+
148
+ .logo span {
149
+ transition: none;
150
+ white-space: nowrap;
151
+ }
152
+
153
+ .sidebar.collapsed .logo span {
154
+ opacity: 0;
155
+ pointer-events: none;
156
+ }
157
+
158
+ .logo-icon {
159
+ width: 40px;
160
+ height: 40px;
161
+ min-width: 40px;
162
+ background: var(--primary);
163
+ border-radius: 10px;
164
+ display: flex;
165
+ align-items: center;
166
+ justify-content: center;
167
+ color: white;
168
+ box-shadow: 0 0 20px var(--primary-glow);
169
+ }
170
+
171
+ .nav-links {
172
+ list-style: none;
173
+ flex: 1;
174
+ }
175
+
176
+ .nav-links li {
177
+ margin-bottom: 0.5rem;
178
+ }
179
+
180
+ .nav-links a {
181
+ display: flex;
182
+ align-items: center;
183
+ padding: 0.75rem 1rem;
184
+ color: var(--text-muted);
185
+ text-decoration: none;
186
+ border-radius: 0.75rem;
187
+ transition: none;
188
+ }
189
+
190
+ .nav-links .nav-text {
191
+ transition: none;
192
+ }
193
+
194
+ .sidebar.collapsed .nav-links .nav-text {
195
+ opacity: 0;
196
+ pointer-events: none;
197
+ width: 0;
198
+ }
199
+
200
+ .nav-icon {
201
+ font-size: 1.25rem;
202
+ min-width: 24px;
203
+ display: flex;
204
+ align-items: center;
205
+ justify-content: center;
206
+ margin-right: 0.75rem;
207
+ transition: none;
208
+ }
209
+
210
+ .sidebar.collapsed .nav-icon {
211
+ margin-right: 0;
212
+ width: 100%;
213
+ }
214
+
215
+ .sidebar.collapsed .nav-links a {
216
+ justify-content: center;
217
+ padding: 0.75rem 0;
218
+ }
219
+
220
+ .nav-links li.active a,
221
+ .nav-links a:hover {
222
+ background: var(--sidebar-active-bg);
223
+ color: var(--text-main);
224
+ }
225
+
226
+ .nav-links li.active a {
227
+ border-left: 3px solid var(--primary);
228
+ }
229
+
230
+ .sidebar.collapsed .nav-links li.active a {
231
+ border-left: none;
232
+ background: var(--sidebar-active-bg);
233
+ box-shadow: inset 0 0 10px rgba(99, 102, 241, 0.2);
234
+ }
235
+
236
+ /* Content Area */
237
+ .content {
238
+ flex: 1;
239
+ overflow-y: auto;
240
+ padding-right: 0.5rem;
241
+ }
242
+
243
+ .view {
244
+ display: none;
245
+ }
246
+
247
+ .view.active {
248
+ display: flex;
249
+ flex-direction: column;
250
+ height: 100%;
251
+ }
252
+
253
+
254
+
255
+ .view-header {
256
+ margin-bottom: 2rem;
257
+ }
258
+
259
+ .view-header h1 {
260
+ font-size: 2.25rem;
261
+ margin-bottom: 0.5rem;
262
+ }
263
+
264
+ .view-header p {
265
+ color: var(--text-muted);
266
+ }
267
+
268
+ /* Prediction / Translation Grid */
269
+ .translation-grid {
270
+ display: grid;
271
+ grid-template-columns: 1fr 1fr;
272
+ gap: 1.5rem;
273
+ flex: 1;
274
+ min-height: 0;
275
+ }
276
+
277
+ .card {
278
+ display: flex;
279
+ flex-direction: column;
280
+ box-shadow: var(--card-shadow);
281
+ }
282
+
283
+ .card-header {
284
+ padding: 1rem 1.5rem;
285
+ border-bottom: 1px solid var(--glass-border);
286
+ display: flex;
287
+ align-items: center;
288
+ justify-content: space-between;
289
+ }
290
+
291
+ .card-body {
292
+ flex: 1;
293
+ position: relative;
294
+ }
295
+
296
+ .card-footer {
297
+ padding: 1rem 1.5rem;
298
+ border-top: 1px solid var(--glass-border);
299
+ font-size: 0.85rem;
300
+ color: var(--text-muted);
301
+ display: flex;
302
+ align-items: center;
303
+ justify-content: space-between;
304
+ }
305
+
306
+ textarea {
307
+ width: 100%;
308
+ height: 100%;
309
+ background: transparent;
310
+ border: none;
311
+ resize: none;
312
+ padding: 1.5rem;
313
+ color: var(--text-main);
314
+ font-size: 1.1rem;
315
+ line-height: 1.6;
316
+ outline: none;
317
+ }
318
+
319
+ .lang-group {
320
+ display: flex;
321
+ align-items: center;
322
+ gap: 0.75rem;
323
+ }
324
+
325
+ .lang-label {
326
+ font-size: 0.7rem;
327
+ font-weight: 700;
328
+ text-transform: uppercase;
329
+ letter-spacing: 0.05em;
330
+ color: var(--text-muted);
331
+ }
332
+
333
+ .lang-select {
334
+ background: var(--input-bg);
335
+ border: 1px solid var(--glass-border);
336
+ color: var(--text-main);
337
+ padding: 0.5rem 1rem;
338
+ border-radius: 0.5rem;
339
+ outline: none;
340
+ cursor: pointer;
341
+ }
342
+
343
+ .detected-badge {
344
+ font-size: 0.75rem;
345
+ background: var(--primary);
346
+ padding: 0.2rem 0.6rem;
347
+ border-radius: 1rem;
348
+ color: white;
349
+ opacity: 0;
350
+ transition: none;
351
+ }
352
+
353
+ .detected-badge.visible {
354
+ opacity: 1;
355
+ transition: none;
356
+ }
357
+
358
+ /* Loader */
359
+ .loader-overlay {
360
+ position: absolute;
361
+ top: 0;
362
+ left: 0;
363
+ width: 100%;
364
+ height: 100%;
365
+ background: rgba(15, 23, 42, 0.5);
366
+ display: flex;
367
+ align-items: center;
368
+ justify-content: center;
369
+ border-radius: 0 0 1.25rem 1.25rem;
370
+ }
371
+
372
+ .hidden {
373
+ display: none !important;
374
+ }
375
+
376
.spinner {
    width: 30px;
    height: 30px;
    border: 3px solid rgba(255, 255, 255, 0.1);
    border-top-color: var(--primary);
    border-radius: 50%;
    /* BUGFIX: was `animation: none`, which left the loading spinner as a
       static arc while the `spin` keyframes below went unused. */
    animation: spin 0.8s linear infinite;
}

@keyframes spin {
    to {
        transform: rotate(360deg);
    }
}
390
+
391
+ /* Success Pills */
392
+ .status-pill {
393
+ display: flex;
394
+ align-items: center;
395
+ gap: 0.5rem;
396
+ padding: 0.5rem 1rem;
397
+ border-radius: 2rem;
398
+ background: rgba(0, 0, 0, 0.2);
399
+ font-size: 0.85rem;
400
+ }
401
+
402
+ .dot {
403
+ width: 8px;
404
+ height: 8px;
405
+ border-radius: 50%;
406
+ }
407
+
408
+ .status-online .dot {
409
+ background: #10b981;
410
+ box-shadow: 0 0 10px #10b981;
411
+ }
412
+
413
.status-loading .dot {
    background: #f59e0b;
    /* BUGFIX: was `animation: none`, leaving the "loading" indicator static
       while the `pulse` keyframes below went unused. */
    animation: pulse 1.5s ease-in-out infinite;
}

@keyframes pulse {
    0% {
        transform: scale(1);
        opacity: 1;
    }

    50% {
        transform: scale(1.2);
        opacity: 0.5;
    }

    100% {
        transform: scale(1);
        opacity: 1;
    }
}
434
+
435
+ /* Models Grid */
436
+ .models-grid {
437
+ display: grid;
438
+ grid-template-columns: repeat(auto-fill, minmax(300px, 1fr));
439
+ gap: 1.5rem;
440
+ }
441
+
442
+ .model-card {
443
+ padding: 1.5rem;
444
+ transition: none;
445
+ /* Performance optimizations */
446
+ contain: layout style paint;
447
+ will-change: background;
448
+ /* Remove expensive backdrop-filter for better scrolling */
449
+ background: var(--glass-bg);
450
+ border: 1px solid var(--glass-border);
451
+ border-radius: 1.25rem;
452
+ box-shadow: var(--card-shadow);
453
+ }
454
+
455
+ .model-card:hover {
456
+ background: var(--sidebar-active-bg);
457
+ }
458
+
459
+ .model-lang-pair {
460
+ display: flex;
461
+ align-items: center;
462
+ gap: 0.5rem;
463
+ font-weight: 600;
464
+ margin-bottom: 0.5rem;
465
+ }
466
+
467
+ .model-id {
468
+ font-size: 0.8rem;
469
+ color: var(--text-muted);
470
+ word-break: break-all;
471
+ }
472
+
473
+ .loaded-badge {
474
+ display: inline-block;
475
+ padding: 0.2rem 0.5rem;
476
+ background: rgba(16, 185, 129, 0.1);
477
+ color: #10b981;
478
+ border: 1px solid rgba(16, 185, 129, 0.2);
479
+ border-radius: 0.4rem;
480
+ font-size: 0.7rem;
481
+ margin-top: 1rem;
482
+ }
483
+
484
+ /* Buttons */
485
+ .action-btn {
486
+ background: var(--primary);
487
+ color: white;
488
+ border: none;
489
+ padding: 0.5rem 1rem;
490
+ border-radius: 0.5rem;
491
+ cursor: pointer;
492
+ font-weight: 500;
493
+ transition: none;
494
+ }
495
+
496
+ .action-btn:hover {
497
+ background: #4f46e5;
498
+ }
499
+
500
+ .icon-btn {
501
+ background: transparent;
502
+ border: none;
503
+ color: var(--text-muted);
504
+ font-size: 1.25rem;
505
+ cursor: pointer;
506
+ transition: none;
507
+ }
508
+
509
+ .icon-btn:hover {
510
+ color: var(--text-main);
511
+ }
512
+
513
+ /* Settings View */
514
+ .settings-container {
515
+ max-width: 800px;
516
+ padding: 2rem;
517
+ margin-top: 1rem;
518
+ }
519
+
520
+ .settings-grid {
521
+ display: flex;
522
+ flex-direction: column;
523
+ gap: 2rem;
524
+ }
525
+
526
+ .setting-item {
527
+ display: grid;
528
+ grid-template-columns: 1fr 200px;
529
+ align-items: center;
530
+ gap: 2rem;
531
+ padding-bottom: 2rem;
532
+ border-bottom: 1px solid var(--glass-border);
533
+ }
534
+
535
+ .setting-item:last-child {
536
+ border-bottom: none;
537
+ }
538
+
539
+ .setting-info {
540
+ display: flex;
541
+ flex-direction: column;
542
+ gap: 0.25rem;
543
+ }
544
+
545
+ .setting-info label {
546
+ font-weight: 600;
547
+ font-size: 1.1rem;
548
+ }
549
+
550
+ .setting-desc {
551
+ color: var(--text-muted);
552
+ font-size: 0.85rem;
553
+ }
554
+
555
+ .setting-control {
556
+ display: flex;
557
+ align-items: center;
558
+ gap: 1rem;
559
+ }
560
+
561
+ .setting-val {
562
+ min-width: 30px;
563
+ font-family: monospace;
564
+ font-weight: 600;
565
+ color: var(--primary);
566
+ }
567
+
568
+ input[type="range"] {
569
+ flex: 1;
570
+ cursor: pointer;
571
+ accent-color: var(--primary);
572
+ }
573
+
574
+ input[type="number"] {
575
+ width: 100%;
576
+ background: var(--input-bg);
577
+ border: 1px solid var(--glass-border);
578
+ color: var(--text-main);
579
+ padding: 0.5rem;
580
+ border-radius: 0.5rem;
581
+ outline: none;
582
+ font-family: inherit;
583
+ }
584
+
585
+ .settings-actions {
586
+ margin-top: 3rem;
587
+ display: flex;
588
+ justify-content: flex-end;
589
+ }
590
+
591
+ .action-btn.secondary {
592
+ background: rgba(255, 255, 255, 0.05);
593
+ border: 1px solid var(--glass-border);
594
+ }
595
+
596
+ .action-btn.secondary:hover {
597
+ background: rgba(255, 255, 255, 0.1);
598
+ }
599
+
600
+ /* Theme Toggle Button */
601
+ .theme-toggle-container {
602
+ margin-top: 1rem;
603
+ }
604
+
605
+ #theme-toggle {
606
+ width: 100%;
607
+ display: flex;
608
+ align-items: center;
609
+ gap: 0.75rem;
610
+ background: var(--input-bg);
611
+ border: 1px solid var(--glass-border);
612
+ color: var(--text-main);
613
+ padding: 0.75rem 1rem;
614
+ border-radius: 0.75rem;
615
+ cursor: pointer;
616
+ font-size: 0.9rem;
617
+ transition: none;
618
+ }
619
+
620
+ #theme-toggle:hover {
621
+ background: var(--sidebar-active-bg);
622
+ }
623
+
624
+ .mode-icon {
625
+ font-size: 1.1rem;
626
+ }
627
+
628
+ .sidebar.collapsed .mode-text {
629
+ display: none;
630
+ }
631
+
632
+ .sidebar.collapsed .theme-toggle-container {
633
+ display: flex;
634
+ justify-content: center;
635
+ }
636
+
637
+ #sidebar-toggle {
638
+ margin-top: 1rem;
639
+ width: 40px;
640
+ height: 40px;
641
+ display: flex;
642
+ align-items: center;
643
+ justify-content: center;
644
+ background: var(--input-bg);
645
+ border-radius: 50%;
646
+ margin-left: auto;
647
+ font-size: 1.25rem;
648
+ transition: none;
649
+ z-index: 10;
650
+ }
651
+
652
+ #sidebar-toggle:hover {
653
+ background: var(--sidebar-active-bg);
654
+ }
655
+
656
+ .sidebar.collapsed #sidebar-toggle {
657
+ transform: rotate(180deg);
658
+ margin-left: 0.5rem;
659
+ }
660
+
661
+ .sidebar.collapsed .sidebar-footer .status-text {
662
+ display: none;
663
+ }
664
+
665
+ .sidebar.collapsed .status-pill {
666
+ padding: 0.5rem;
667
+ justify-content: center;
668
+ }
quickmt/langid.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple, Union, Optional
2
+ from pathlib import Path
3
+ import os
4
+ import urllib.request
5
+ import fasttext
6
+
7
+ # Suppress fasttext's warning about being loaded in a way that doesn't
8
+ # allow querying its version (common in some environments)
9
+ fasttext.FastText.eprint = lambda x: None
10
+
11
+
12
class LanguageIdentification:
    """Detect language using a pre-trained FastText langid model.

    Thin wrapper around the ``fasttext`` library supporting both
    single-string and batched language identification.
    """

    def __init__(self, model_path: Optional[Union[str, Path]] = None):
        """Load the FastText model, downloading it first if necessary.

        Args:
            model_path: Location of the FastText model file on disk.
                When None, a per-user cache directory is used
                (``$XDG_CACHE_HOME/fasttext_language_id/lid.176.bin``).
        """
        if model_path is None:
            cache_root = Path(os.getenv("XDG_CACHE_HOME", Path.home() / ".cache"))
            model_path = cache_root / "fasttext_language_id" / "lid.176.bin"
        else:
            model_path = Path(model_path)

        if not model_path.exists():
            model_path.parent.mkdir(parents=True, exist_ok=True)
            url = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"
            print(f"Downloading FastText model from {url} to {model_path}...")
            urllib.request.urlretrieve(url, str(model_path))
            print("Download complete.")

        self.ft = fasttext.load_model(str(model_path))

    def predict(
        self,
        text: Union[str, List[str]],
        k: int = 1,
        threshold: float = 0.0
    ) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]:
        """Predict the language(s) for a string or a batch of strings.

        Args:
            text: A single string or a list of strings to identify.
            k: Number of most likely languages to return per input.
            threshold: Minimum score for a language to be included.

        Returns:
            For a single string: a list of ``(lang, score)`` tuples.
            For a list of strings: one such list per input, in order.
        """
        single = isinstance(text, str)
        raw_items = [text] if single else text

        # FastText raises on embedded newlines, so flatten them to spaces.
        cleaned = [s.replace("\n", " ") for s in raw_items]

        # Batch call: fasttext accepts a list natively and returns
        # parallel lists of label rows and score rows.
        labels, scores = self.ft.predict(cleaned, k=k, threshold=threshold)

        batched = [
            [
                (lbl.replace("__label__", ""), float(p))
                for lbl, p in zip(row_labels, row_scores)
            ]
            for row_labels, row_scores in zip(labels, scores)
        ]

        return batched[0] if single else batched

    def predict_best(
        self,
        text: Union[str, List[str]],
        threshold: float = 0.0
    ) -> Union[Optional[str], List[Optional[str]]]:
        """Return only the top-scoring language label for each input.

        Convenience wrapper around :meth:`predict` with ``k=1``.

        Args:
            text: A single string or a list of strings to identify.
            threshold: Minimum score for a language to be selected.

        Returns:
            For a single string: the language code (e.g. ``'en'``) or None
            when nothing clears the threshold. For a list: one code (or
            None) per input.
        """
        predictions = self.predict(text, k=1, threshold=threshold)

        if isinstance(text, str):
            # predictions is a flat List[Tuple[str, float]]
            return predictions[0][0] if predictions else None

        # predictions is a List[List[Tuple[str, float]]]
        return [row[0][0] if row else None for row in predictions]
109
+
110
+
111
def ensure_model_exists(model_path: Optional[Union[str, Path]] = None):
    """Ensure the FastText model exists on disk, downloading if necessary.

    Intended to be called from the main process before starting worker
    pools, so workers never race each other to download the same file.

    Args:
        model_path: Target location of the model file. When None, the
            default per-user cache location is used.
    """
    if model_path is None:
        cache_root = Path(os.getenv("XDG_CACHE_HOME", Path.home() / ".cache"))
        model_path = cache_root / "fasttext_language_id" / "lid.176.bin"
    else:
        model_path = Path(model_path)

    if model_path.exists():
        return

    model_path.parent.mkdir(parents=True, exist_ok=True)
    url = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"
    print(f"Downloading FastText model from {url} to {model_path}...")
    urllib.request.urlretrieve(url, str(model_path))
    print("Download complete.")
128
+
129
+
130
# Process-local singleton used by process-pool workers.
_detector: Optional[LanguageIdentification] = None


def init_worker(model_path: Optional[Union[str, Path]] = None):
    """Create the process-local detector (run once per worker process).

    Assumes ``ensure_model_exists`` was already called in the main
    process, so no download should be needed here.
    """
    global _detector
    _detector = LanguageIdentification(model_path)


def predict_worker(
    text: Union[str, List[str]],
    k: int = 1,
    threshold: float = 0.0
) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]:
    """Run a language prediction inside a worker process.

    Lazily initializes the detector if the pool initializer did not run
    (or failed), then delegates to :meth:`LanguageIdentification.predict`.
    """
    if _detector is None:
        # Recover from a missing/failed init_worker call.
        init_worker()
    return _detector.predict(text, k=k, threshold=threshold)
quickmt/manager.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ import time
4
+ from pathlib import Path
5
+ from typing import Dict, List, Optional
6
+ from collections import OrderedDict
7
+ from functools import lru_cache
8
+
9
+ from fastapi import HTTPException
10
+ from huggingface_hub import HfApi, snapshot_download
11
+ from cachetools import TTLCache, cached, LRUCache
12
+
13
+ from quickmt.translator import Translator
14
+ from quickmt.settings import settings
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class BatchTranslator:
    """Wraps a single Translator behind an asyncio queue that dynamically
    batches concurrent translation requests sharing identical parameters.

    Completed translations are memoized in an LRU cache keyed by
    (src, src_lang, tgt_lang, kwargs).
    """

    def __init__(
        self,
        model_id: str,
        model_path: str,
        device: str = "cpu",
        compute_type: str = "default",
        inter_threads: int = 1,
        intra_threads: int = 0,
    ):
        self.model_id = model_id
        self.model_path = model_path
        self.device = device
        self.compute_type = compute_type
        self.inter_threads = inter_threads
        self.intra_threads = intra_threads
        self.translator: Optional[Translator] = None
        self.queue: asyncio.Queue = asyncio.Queue()
        self.worker_task: Optional[asyncio.Task] = None
        # LRU cache for completed translations
        self.translation_cache: LRUCache = LRUCache(
            maxsize=settings.translation_cache_size
        )

    async def start_worker(self):
        """Load the model and start the batching worker task (idempotent)."""
        if self.worker_task:
            return

        # Load model in main process (or worker thread if needed)
        # For now, Translator handles its own loading
        self.translator = Translator(
            Path(self.model_path),
            device=self.device,
            compute_type=self.compute_type,
            inter_threads=self.inter_threads,
            intra_threads=self.intra_threads,
        )
        self.worker_task = asyncio.create_task(self._worker())
        logger.info(f"Started translation worker for model: {self.model_id}")

    async def stop_worker(self):
        """Stop the worker (via a queue sentinel) and unload the model."""
        if not self.worker_task:
            return

        # Send sentinel to stop worker
        await self.queue.put(None)
        await self.worker_task
        self.worker_task = None
        if self.translator:
            self.translator.unload()
            self.translator = None
        logger.info(f"Stopped translation worker for model: {self.model_id}")

    async def _collect_batch(self, src, src_lang, tgt_lang, kwargs, future):
        """Collect up to settings.max_batch_size compatible queued requests.

        Only requests with identical (src_lang, tgt_lang, kwargs) are batched
        together; an incompatible item (or the shutdown sentinel) is re-queued
        for a later cycle. Returns (batch_texts, futures), which always
        include the initial request.
        """
        batch_texts = [src]
        futures = [future]
        start_time = time.time()
        while len(batch_texts) < settings.max_batch_size:
            wait_time = (settings.batch_timeout_ms / 1000.0) - (
                time.time() - start_time
            )
            if wait_time <= 0:
                break
            try:
                next_item = await asyncio.wait_for(self.queue.get(), timeout=wait_time)
            except asyncio.TimeoutError:
                break
            if next_item is None:
                # Re-add sentinel so the main loop can shut down later
                await self.queue.put(None)
                break
            n_src, n_sl, n_tl, n_kw, n_fut = next_item
            if n_sl == src_lang and n_tl == tgt_lang and n_kw == kwargs:
                batch_texts.append(n_src)
                futures.append(n_fut)
            else:
                # Re-queue incompatible item for a later batch/worker cycle
                await self.queue.put(next_item)
                break
        return batch_texts, futures

    async def _worker(self):
        """Queue consumer: batch compatible requests, run inference off-loop,
        and resolve (or fail) every future in the batch."""
        while True:
            item = await self.queue.get()
            if item is None:
                # Shutdown sentinel
                self.queue.task_done()
                break

            src, src_lang, tgt_lang, kwargs, future = item
            # Initialize up front so the except/finally blocks always see
            # a consistent view of what was dequeued.
            batch_texts = [src]
            futures = [future]
            try:
                batch_texts, futures = await self._collect_batch(
                    src, src_lang, tgt_lang, kwargs, future
                )

                # Run in executor to avoid blocking the asyncio loop during inference
                loop = asyncio.get_running_loop()
                results = await loop.run_in_executor(
                    None,
                    lambda: self.translator(
                        batch_texts, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs
                    ),
                )

                # result can be string or list
                if isinstance(results, str):
                    results = [results]

                for res, fut in zip(results, futures):
                    if not fut.done():
                        fut.set_result(res)

                # Defensive: if the translator returned fewer results than
                # inputs, fail the leftovers instead of leaving them pending.
                for fut in futures:
                    if not fut.done():
                        fut.set_exception(
                            RuntimeError(
                                "translator returned fewer results than inputs"
                            )
                        )
            except Exception as e:
                logger.error(f"Error in translation worker for {self.model_id}: {e}")
                # BUGFIX: fail every request in the batch, not just the first
                # one — otherwise batched callers would await forever.
                for fut in futures:
                    if not fut.done():
                        fut.set_exception(e)
            finally:
                # One task_done per dequeued request that joined the batch,
                # on both the success and the error path.
                for _ in range(len(batch_texts)):
                    self.queue.task_done()

    async def translate(
        self,
        src: str,
        src_lang: Optional[str] = None,
        tgt_lang: Optional[str] = None,
        **kwargs,
    ) -> str:
        """Translate a single string, transparently batching and caching.

        Args:
            src: Source text.
            src_lang: Source language code (model-dependent; may be None).
            tgt_lang: Target language code.
            **kwargs: Extra decoding options forwarded to the translator.

        Returns:
            The translated text.

        Raises:
            Whatever exception the worker raised for this request.
        """
        if not self.worker_task:
            await self.start_worker()

        # kwargs as a sorted tuple so the cache key is hashable and stable
        kwargs_tuple = tuple(sorted(kwargs.items()))
        cache_key = (src, src_lang, tgt_lang, kwargs_tuple)

        # Check cache first
        if cache_key in self.translation_cache:
            return self.translation_cache[cache_key]

        # Cache miss - enqueue and wait for the worker to resolve the future
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((src, src_lang, tgt_lang, kwargs, future))
        result = await future

        # Store in cache
        self.translation_cache[cache_key] = result
        return result
166
+
167
+
168
class ModelManager:
    """LRU manager for BatchTranslator instances keyed by ``"src-tgt"``.

    Models are discovered from the Hugging Face quickmt collection, loaded
    on demand, and the least-recently-used model is evicted once
    ``max_loaded`` is reached.
    """

    # How long (seconds) the Hugging Face collection listing stays fresh.
    _HF_MODELS_TTL = 3600

    def __init__(
        self,
        max_loaded: int,
        device: str,
        compute_type: str = "default",
        inter_threads: int = 1,
        intra_threads: int = 0,
    ):
        self.max_loaded = max_loaded
        self.device = device
        self.compute_type = compute_type
        self.inter_threads = inter_threads
        self.intra_threads = intra_threads
        # cache key: src-tgt string; OrderedDict front == least recently used
        self.models: OrderedDict[str, BatchTranslator] = OrderedDict()
        self.pending_loads: Dict[str, asyncio.Event] = {}
        self.lock = asyncio.Lock()
        self.hf_collection_models: List[Dict] = []
        # monotonic timestamp of the last successful collection fetch
        self._hf_models_fetched_at: float = 0.0
        self.api = HfApi()

    async def fetch_hf_models(self):
        """Fetch available models from the quickmt collection on Hugging Face.

        The listing is cached for ``_HF_MODELS_TTL`` seconds via a simple
        timestamp check. BUGFIX: the previous ``@cached(TTLCache(...))``
        decorator does not work on a coroutine function — it caches the
        coroutine object itself, and a coroutine can only be awaited once,
        so every cache hit raised RuntimeError.
        """
        now = time.monotonic()
        if (
            self.hf_collection_models
            and now - self._hf_models_fetched_at < self._HF_MODELS_TTL
        ):
            return
        try:
            loop = asyncio.get_running_loop()
            collection = await loop.run_in_executor(
                None, lambda: self.api.get_collection("quickmt/quickmt-models")
            )

            hf_models = []
            for item in collection.items:
                if item.item_type == "model":
                    model_id = item.item_id
                    # Expecting format: quickmt/quickmt-en-fr
                    parts = model_id.split("/")[-1].replace("quickmt-", "").split("-")
                    if len(parts) == 2:
                        src, tgt = parts
                        hf_models.append(
                            {"model_id": model_id, "src_lang": src, "tgt_lang": tgt}
                        )
            self.hf_collection_models = hf_models
            self._hf_models_fetched_at = now
            logger.info(
                f"Discovered {len(hf_models)} models from Hugging Face collection"
            )
        except Exception as e:
            logger.error(f"Failed to fetch models from Hugging Face: {e}")

    async def get_model(self, src_lang: str, tgt_lang: str) -> BatchTranslator:
        """Return the (possibly freshly loaded) model for a language pair.

        Raises:
            HTTPException(404): if the pair is not in the HF collection.
            HTTPException(500): if the load task failed.
        """
        model_name = f"{src_lang}-{tgt_lang}"

        async with self.lock:
            # 1. Check if loaded
            if model_name in self.models:
                self.models.move_to_end(model_name)
                return self.models[model_name]

            # 2. Check if currently loading
            if model_name in self.pending_loads:
                event = self.pending_loads[model_name]
            else:
                # Pre-check existence before starting a task so a missing
                # pair produces a clean 404 instead of a failed load.
                hf_model = next(
                    (
                        m
                        for m in self.hf_collection_models
                        if m["src_lang"] == src_lang and m["tgt_lang"] == tgt_lang
                    ),
                    None,
                )
                if not hf_model:
                    raise HTTPException(
                        status_code=404,
                        detail=f"Model for {src_lang}->{tgt_lang} not found in Hugging Face collection",
                    )

                event = asyncio.Event()
                self.pending_loads[model_name] = event
                # This task will do the actual loading
                asyncio.create_task(self._load_model_task(src_lang, tgt_lang, event))

        # 3. Wait for load
        await event.wait()

        # 4. Return from cache. BUGFIX: if the load task failed the model is
        # absent — previously this raised a bare KeyError to the caller.
        async with self.lock:
            model = self.models.get(model_name)
        if model is None:
            raise HTTPException(
                status_code=500,
                detail=f"Failed to load model for {src_lang}->{tgt_lang}",
            )
        return model

    async def _load_model_task(
        self, src_lang: str, tgt_lang: str, new_event: asyncio.Event
    ):
        """Download (if needed) and start a model, evicting the LRU entry.

        Always sets ``new_event`` on exit so waiters in get_model unblock,
        even on failure (get_model then reports the missing model).
        """
        model_name = f"{src_lang}-{tgt_lang}"
        try:
            try:
                # Find matching model from HF collection (already checked in get_model)
                hf_model = next(
                    m
                    for m in self.hf_collection_models
                    if m["src_lang"] == src_lang and m["tgt_lang"] == tgt_lang
                )

                logger.info(f"Accessing Hugging Face model: {hf_model['model_id']}")
                loop = asyncio.get_running_loop()
                # snapshot_download returns the local path in the HF cache.
                # Try local only first to speed up loading
                try:
                    cached_path = await loop.run_in_executor(
                        None,
                        lambda: snapshot_download(
                            repo_id=hf_model["model_id"],
                            ignore_patterns=["eole-model/*", "eole_model/*"],
                            local_files_only=True,
                        ),
                    )
                except Exception:
                    # Fallback to checking online
                    logger.info(
                        f"Model {hf_model['model_id']} not fully cached, checking online..."
                    )
                    cached_path = await loop.run_in_executor(
                        None,
                        lambda: snapshot_download(
                            repo_id=hf_model["model_id"],
                            ignore_patterns=["eole-model/*", "eole_model/*"],
                        ),
                    )
                model_path = Path(cached_path)

                # Evict the least-recently-used model if the cache is full.
                evicted_model = None
                async with self.lock:
                    if len(self.models) >= self.max_loaded:
                        oldest_name, evicted_model = self.models.popitem(last=False)
                        logger.info(f"Evicting model: {oldest_name}")

                # Stop the evicted worker outside the lock (it may be slow).
                if evicted_model:
                    await evicted_model.stop_worker()

                # Load new model (SLOW, outside lock)
                logger.info(
                    f"Loading model: {hf_model['model_id']} (device: {self.device}, compute: {self.compute_type})"
                )
                new_model = BatchTranslator(
                    model_id=hf_model["model_id"],
                    model_path=str(model_path),
                    device=self.device,
                    compute_type=self.compute_type,
                    inter_threads=self.inter_threads,
                    intra_threads=self.intra_threads,
                )
                await new_model.start_worker()

                # Add to cache
                async with self.lock:
                    self.models[model_name] = new_model

            except Exception as e:
                logger.error(f"Error loading model {model_name}: {e}")
                # Re-raise so the outer finally still runs; waiters detect
                # the failure via the model's absence in get_model.
                raise e
        finally:
            async with self.lock:
                if model_name in self.pending_loads:
                    del self.pending_loads[model_name]
            new_event.set()

    def list_available_models(self) -> List[Dict]:
        """List all models discovered from Hugging Face, with load status."""
        available = []
        for m in self.hf_collection_models:
            lang_pair = f"{m['src_lang']}-{m['tgt_lang']}"
            available.append(
                {
                    "model_id": m["model_id"],
                    "src_lang": m["src_lang"],
                    "tgt_lang": m["tgt_lang"],
                    "loaded": lang_pair in self.models,
                }
            )
        return available

    def get_language_pairs(self) -> Dict[str, List[str]]:
        """Return a mapping of source language -> sorted target languages.

        BUGFIX: previously decorated with ``@lru_cache(maxsize=1)``, which
        (a) keeps the instance alive for the cache's lifetime and (b) froze
        the first result forever — an empty mapping if called before
        fetch_hf_models completed. The computation is cheap; just redo it.
        """
        pairs: Dict[str, set] = {}
        for m in self.hf_collection_models:
            pairs.setdefault(m["src_lang"], set()).add(m["tgt_lang"])

        # Convert sets to sorted lists
        return {src: sorted(tgts) for src, tgts in sorted(pairs.items())}

    async def shutdown(self):
        """Stop all workers and drop every loaded model."""
        for model in self.models.values():
            await model.stop_worker()
        self.models.clear()
quickmt/rest_server.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ import os
4
+ import time
5
+ from contextlib import asynccontextmanager
6
+ from typing import List, Optional, Union, Dict
7
+ from concurrent.futures import ProcessPoolExecutor
8
+
9
+ from fastapi import FastAPI, HTTPException, APIRouter
10
+ from fastapi.responses import ORJSONResponse
11
+ from fastapi.staticfiles import StaticFiles
12
+ from pydantic import BaseModel, model_validator
13
+
14
+ from quickmt.langid import init_worker, predict_worker, ensure_model_exists
15
+ from quickmt.manager import ModelManager
16
+ from quickmt.settings import settings
17
+
18
+ logger = logging.getLogger("uvicorn.error")
19
+
20
+
21
class TranslationRequest(BaseModel):
    """Request body for /api/translate.

    ``src`` may be a single string or a batch; ``src_lang`` is optional
    (language identification is used when it is omitted) and may be a
    per-item list when ``src`` is a list.  The remaining fields are
    CTranslate2 decoding parameters.
    """

    src: Union[str, List[str]]
    src_lang: Optional[Union[str, List[str]]] = None
    tgt_lang: str = "en"
    beam_size: int = 5
    patience: int = 1
    length_penalty: float = 1.0
    coverage_penalty: float = 0.0
    repetition_penalty: float = 1.0
    max_decoding_length: int = 256

    @model_validator(mode="after")
    def validate_patience(self):
        """Reject configurations where patience exceeds the beam width."""
        if self.beam_size < self.patience:
            raise ValueError("patience cannot be greater than beam_size")
        return self
37
+
38
+
39
class TranslationResponse(BaseModel):
    """Response body for /api/translate.

    Fields mirror the request shape: scalars when ``src`` was a single
    string, lists aligned with the input when ``src`` was a list.
    """

    translation: Union[str, List[str]]
    src_lang: Union[str, List[str]]  # detected or caller-supplied source language(s)
    src_lang_score: Union[float, List[float]]  # langid confidence; 1.0 when src_lang was supplied
    tgt_lang: str
    processing_time: float  # wall-clock seconds for the whole request
    model_used: Union[str, List[str]]  # HF model id(s); "identity" when src == tgt, "none" for empty input
46
+
47
+
48
class DetectionRequest(BaseModel):
    """Request body for /api/identify-language."""

    src: Union[str, List[str]]
    k: int = 1  # number of top predictions to return per input
    threshold: float = 0.0  # minimum score for a prediction to be returned
52
+
53
+
54
class DetectionResult(BaseModel):
    """One language prediction: a language label and its confidence score."""

    lang: str
    score: float
57
+
58
+
59
class DetectionResponse(BaseModel):
    """Response body for /api/identify-language.

    ``results`` is a flat list for a single input string, or a list of
    lists (one inner list per input) for batched input.
    """

    results: Union[List[DetectionResult], List[List[DetectionResult]]]
    processing_time: float  # wall-clock seconds for the whole request
62
+
63
+
64
class BatchItem:
    """A single queued translation job: its inputs, decoding knobs, and the
    future that the batching worker resolves with the result."""

    def __init__(
        self,
        src: List[str],
        src_lang: str,
        tgt_lang: str,
        beam_size: int,
        max_decoding_length: int,
        future: asyncio.Future,
    ):
        # Plain attribute capture; assignments are independent of each other.
        self.future = future
        self.max_decoding_length = max_decoding_length
        self.beam_size = beam_size
        self.tgt_lang = tgt_lang
        self.src_lang = src_lang
        self.src = src
80
+
81
+
82
# Global singletons created during application startup (see lifespan below).
model_manager: Optional[ModelManager] = None
langid_executor: Optional[ProcessPoolExecutor] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: build shared resources on startup, release on exit."""
    global model_manager, langid_executor

    model_manager = ModelManager(
        max_loaded=settings.max_loaded_models,
        device=settings.device,
        compute_type=settings.compute_type,
        inter_threads=settings.inter_threads,
        intra_threads=settings.intra_threads,
    )

    # 1. Fetch available models from Hugging Face
    await model_manager.fetch_hf_models()

    # 2. Ensure langid model is downloaded in main process before starting
    #    workers, so each pool worker only reads the already-cached file.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, ensure_model_exists, settings.langid_model_path)

    # Initialize langid process pool; init_worker loads the model once per worker.
    langid_executor = ProcessPoolExecutor(
        max_workers=settings.langid_workers,
        initializer=init_worker,
        initargs=(settings.langid_model_path,),
    )

    yield

    # Teardown: release the process pool first, then unload every translator.
    if langid_executor:
        langid_executor.shutdown()
    await model_manager.shutdown()
118
+
119
+
120
# Application object; ORJSONResponse serializes responses with orjson.
app = FastAPI(
    title="quickmt Multi-Model API",
    lifespan=lifespan,
    default_response_class=ORJSONResponse,
)
# All JSON endpoints live under the /api prefix; the static GUI is mounted at /.
api_router = APIRouter(prefix="/api")
126
+
127
+
128
@api_router.post("/translate", response_model=TranslationResponse)
async def translate_endpoint(request: TranslationRequest):
    """Translate one string or a batch of strings.

    Pipeline: detect source language(s) when not supplied, group inputs by
    source language, translate each group with the matching model (groups
    run concurrently), then reassemble results in the original input order.
    """
    if not model_manager:
        raise HTTPException(status_code=503, detail="Model manager not initialized")

    start_time = time.time()
    src_list = [request.src] if isinstance(request.src, str) else request.src
    if not src_list:
        # Empty input: return an empty payload of the matching shape.
        return TranslationResponse(
            translation="" if isinstance(request.src, str) else [],
            src_lang="" if isinstance(request.src, str) else [],
            src_lang_score=0.0 if isinstance(request.src, str) else [],
            tgt_lang=request.tgt_lang,
            processing_time=time.time() - start_time,
            model_used="none",
        )

    try:
        loop = asyncio.get_running_loop()

        # 1. Determine source languages and confidence scores.
        #    Caller-supplied languages get a fixed confidence of 1.0.
        if request.src_lang:
            if isinstance(request.src_lang, list):
                if not isinstance(src_list, list) or len(request.src_lang) != len(
                    src_list
                ):
                    raise HTTPException(
                        status_code=422,
                        detail="src_lang list length must match src list length",
                    )
                src_langs = request.src_lang
                src_lang_scores = [1.0] * len(src_list)
            else:
                src_langs = [request.src_lang] * len(src_list)
                src_lang_scores = [1.0] * len(src_list)
        else:
            if not langid_executor:
                raise HTTPException(
                    status_code=503, detail="Language identification not initialized"
                )
            # Batch detect languages in the process pool
            raw_langid_results = await loop.run_in_executor(
                langid_executor,
                predict_worker,
                src_list,
                1,  # k=1 (best guess)
                0.0,  # threshold
            )
            # results are List[List[Tuple[str, float]]], extract labels and scores
            src_langs = [r[0][0] if r else "unknown" for r in raw_langid_results]
            src_lang_scores = [float(r[0][1]) if r else 0.0 for r in raw_langid_results]

        # 2. Group indices by source language
        # groups: { "fr": [0, 2, ...], "es": [1, ...] }
        groups: Dict[str, List[int]] = {}
        for idx, lang in enumerate(src_langs):
            if lang not in groups:
                groups[lang] = []
            groups[lang].append(idx)

        # 3. Process each group; results are written back by original index.
        final_translations = [""] * len(src_list)
        final_models = [""] * len(src_list)
        tasks = []

        # We need a way to track which lang pairs were actually used for the 'model_used' string
        used_pairs = set()

        for lang, indices in groups.items():
            group_src = [src_list[i] for i in indices]

            # Optimization: If src == tgt, skip translation
            if lang == request.tgt_lang:
                for src_idx, idx in enumerate(indices):
                    final_translations[idx] = group_src[src_idx]
                    final_models[idx] = "identity"
                continue

            # Load model and translate for this group.
            # NOTE: default arguments bind the loop variables at definition
            # time, avoiding the classic late-binding closure bug.
            async def process_group_task(l=lang, i_list=indices, g_src=group_src):
                try:
                    translator = await model_manager.get_model(l, request.tgt_lang)
                    used_pairs.add(translator.model_id)
                    # Call translate for each sentence; BatchTranslator will handle opportunistic batching
                    translation_tasks = [
                        translator.translate(
                            s,
                            src_lang=l,
                            tgt_lang=request.tgt_lang,
                            beam_size=request.beam_size,
                            patience=request.patience,
                            length_penalty=request.length_penalty,
                            coverage_penalty=request.coverage_penalty,
                            repetition_penalty=request.repetition_penalty,
                            max_decoding_length=request.max_decoding_length,
                        )
                        for s in g_src
                    ]
                    results = await asyncio.gather(*translation_tasks)
                    for result_idx, original_idx in enumerate(i_list):
                        final_translations[original_idx] = results[result_idx]
                        final_models[original_idx] = translator.model_id
                except HTTPException as e:
                    # If a specific model is missing, we could either fail the whole batch
                    # or keep original text. Here we fail for consistency with previous behavior.
                    raise e
                except Exception as e:
                    logger.error(f"Error translating {l} to {request.tgt_lang}: {e}")
                    raise e

            tasks.append(process_group_task())

        if tasks:
            await asyncio.gather(*tasks)

        # 4. Prepare response: collapse to scalars when the input was a string.
        if isinstance(request.src, str):
            result = final_translations[0]
            src_lang_res = src_langs[0]
            src_lang_score_res = src_lang_scores[0]
            model_used_res = final_models[0]
        else:
            result = final_translations
            src_lang_res = src_langs
            src_lang_score_res = src_lang_scores
            model_used_res = final_models

        return TranslationResponse(
            translation=result,
            src_lang=src_lang_res,
            src_lang_score=src_lang_score_res,
            tgt_lang=request.tgt_lang,
            processing_time=time.time() - start_time,
            model_used=model_used_res,
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Unexpected error in translate_endpoint")
        raise HTTPException(status_code=500, detail=str(e))
269
+
270
+
271
@api_router.post("/identify-language", response_model=DetectionResponse)
async def identify_language_endpoint(request: DetectionRequest):
    """Detect the language(s) of one string or a batch of strings."""
    if not langid_executor:
        raise HTTPException(
            status_code=503, detail="Language identification not initialized"
        )

    start_time = time.time()
    try:
        loop = asyncio.get_running_loop()
        # Offload detection to process pool to avoid GIL issues
        raw_results = await loop.run_in_executor(
            langid_executor, predict_worker, request.src, request.k, request.threshold
        )

        # Convert raw tuples to Pydantic models; nested lists for batched input
        if isinstance(request.src, str):
            results = [
                DetectionResult(lang=lang, score=score) for lang, score in raw_results
            ]
        else:
            results = [
                [
                    DetectionResult(lang=lang, score=score)
                    for lang, score in item_results
                ]
                for item_results in raw_results
            ]

        return DetectionResponse(
            results=results, processing_time=time.time() - start_time
        )
    except Exception as e:
        # NOTE(review): maps every failure (including bad input) to 500 without
        # logging — consider narrowing the exception types.
        raise HTTPException(status_code=500, detail=str(e))
305
+
306
+
307
@api_router.get("/models")
async def get_models():
    """List every discovered model together with its load status."""
    if not model_manager:
        raise HTTPException(status_code=503, detail="Model manager not initialized")
    return {"models": model_manager.list_available_models()}
312
+
313
+
314
@api_router.get("/languages")
async def get_languages():
    """Map each supported source language to its list of target languages."""
    if not model_manager:
        raise HTTPException(status_code=503, detail="Model manager not initialized")
    return model_manager.get_language_pairs()
319
+
320
+
321
@api_router.get("/health")
async def health_check():
    """Liveness probe: report resident models and the configured cache limit."""
    if model_manager:
        loaded = list(model_manager.models)
    else:
        loaded = []
    return {
        "status": "ok",
        "loaded_models": loaded,
        "max_models": settings.max_loaded_models,
    }
329
+
330
+
331
app.include_router(api_router)

# Serve static files for the GUI.
# html=True makes StaticFiles answer "/" with index.html.
static_dir = os.path.join(os.path.dirname(__file__), "gui", "static")
if os.path.exists(static_dir):
    app.mount("/", StaticFiles(directory=static_dir, html=True), name="static")
337
+
338
+
339
def start():
    """Entry point for the quickmt-serve CLI."""
    # Imported lazily so importing this module does not require uvicorn.
    import uvicorn

    uvicorn.run("quickmt.rest_server:app", host="0.0.0.0", port=8000, reload=False)
344
+
345
+
346
def start_gui():
    """Entry point for the quickmt-gui CLI."""
    # Lazy imports keep module import light; uvicorn is only needed here.
    import threading
    import time
    import webbrowser

    import uvicorn

    def open_browser():
        # Give uvicorn a moment to bind before pointing the browser at it.
        time.sleep(1.5)
        webbrowser.open("http://127.0.0.1:8000")

    threading.Thread(target=open_browser, daemon=True).start()
    uvicorn.run("quickmt.rest_server:app", host="0.0.0.0", port=8000, reload=False)
quickmt/settings.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Centralized configuration management using pydantic-settings.
2
+
3
+ This module provides a type-safe, centralized way to manage all configuration
4
+ settings for the quickmt library. Settings can be configured via:
5
+ - Environment variables (e.g., MAX_LOADED_MODELS=10)
6
+ - .env file in the project root
7
+ - Runtime modification of the global settings object
8
+
9
+ All environment variables are case-insensitive.
10
+ """
11
+
12
+ from typing import Optional
13
+ from pydantic_settings import BaseSettings, SettingsConfigDict
14
+
15
+
16
class Settings(BaseSettings):
    """Application settings with environment variable support.

    All settings can be overridden via environment variables.
    For example, to set max_loaded_models, use MAX_LOADED_MODELS=10
    """

    # Model Manager Settings
    max_loaded_models: int = 5
    """Maximum number of translation models to keep loaded in memory"""

    device: str = "cpu"
    """Device to use for inference: 'cpu', 'cuda', or 'auto'"""

    compute_type: str = "default"
    """CTranslate2 compute type: 'default', 'int8', 'int8_float16', 'int16', 'float16', 'float32'"""

    inter_threads: int = 1
    """Number of threads to use for inter-op parallelism (simultaneous translations)"""

    intra_threads: int = 4
    """Number of threads to use for intra-op parallelism (within each translation)"""

    # Batch Processing Settings
    max_batch_size: int = 32
    """Maximum batch size for translation requests"""

    batch_timeout_ms: int = 5
    """Timeout in milliseconds to wait for batching additional requests"""

    # Language Identification Settings
    langid_model_path: Optional[str] = None
    """Path to FastText language identification model. If None, uses default cache location"""

    langid_workers: int = 2
    """Number of worker processes for language identification"""

    # Translation Cache Settings
    translation_cache_size: int = 10000
    """Maximum number of translations to cache (LRU eviction)"""

    # pydantic-settings configuration: read from the environment and an
    # optional .env file; unknown environment variables are ignored.
    model_config = SettingsConfigDict(
        env_prefix="",
        case_sensitive=False,
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )


# Global settings instance
# This can be imported and used throughout the application
# Settings can be modified at runtime: settings.max_loaded_models = 10
settings = Settings()
quickmt/translator.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from pathlib import Path
3
+ from time import time
4
+ from typing import List, Optional, Union
5
+
6
+ import ctranslate2
7
+ import sentencepiece
8
+ from blingfire import text_to_sentences
9
+ from pydantic import DirectoryPath, validate_call
10
+
11
+
12
class TranslatorABC(ABC):
    """Abstract base for quickmt translators: wraps a CTranslate2 Translator
    and handles sentence splitting/joining; subclasses provide tokenization."""

    def __init__(self, model_path: DirectoryPath, **kwargs):
        """Create quickmt translation object

        Args:
            model_path (DirectoryPath): Path to quickmt model folder
            **kwargs: CTranslate2 Translator arguments - see https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html
        """
        self.model_path = Path(model_path)
        self.translator = ctranslate2.Translator(str(model_path), **kwargs)

    @staticmethod
    @validate_call
    def _sentence_split(src: List[str]):
        """Split sentences with Blingfire

        Args:
            src (List[str]): Input list of strings to split by sentences

        Returns:
            List[int], List[int], List[str]: List of input ids, list of paragraph ids and sentences
        """
        input_ids = []
        paragraph_ids = []
        sentences = []
        for idx, i in enumerate(src):
            for paragraph, j in enumerate(i.splitlines(keepends=True)):
                sents = text_to_sentences(j).splitlines()
                for sent in sents:
                    stripped_sent = sent.strip()
                    if len(stripped_sent) > 0:
                        # Very short fragments (<5 chars) are glued onto the
                        # previous sentence of the same input/paragraph to
                        # avoid degenerate translation inputs.
                        if (
                            len(stripped_sent) < 5
                            and len(paragraph_ids) > 0
                            and paragraph == paragraph_ids[-1]
                            and len(input_ids) > 0
                            and input_ids[-1] == idx
                        ):
                            sentences[-1] += " " + stripped_sent
                        else:
                            input_ids.append(idx)
                            paragraph_ids.append(paragraph)
                            sentences.append(stripped_sent)

        return input_ids, paragraph_ids, sentences

    @staticmethod
    @validate_call
    def _sentence_join(
        input_ids: List[int],
        paragraph_ids: List[int],
        sentences: List[str],
        paragraph_join_str: str = "\n",
        sent_join_str: str = " ",
        length: Optional[int] = None,
    ):
        """Sentence joiner

        Args:
            input_ids (List[int]): List of input IDs
            paragraph_ids (List[int]): List of paragraph IDs
            sentences (List[str]): List of sentences to join up by input and paragraph ids
            paragraph_join_str (str, optional): str to use to join paragraphs. Defaults to "\n".
            sent_join_str (str, optional): str to join up sentences. Defaults to " ".
            length (Optional[int], optional): Number of output slots; defaults to max(input_ids) + 1.

        Returns:
            List[str]: Joined up sentences
        """
        if not input_ids:
            return [""] * (length or 0)

        target_len = length if length is not None else (max(input_ids) + 1)
        ret = [""] * target_len
        last_paragraph = 0
        for idx, paragraph, text in zip(input_ids, paragraph_ids, sentences):
            if len(ret[idx]) > 0:
                # Same paragraph continues with a space; a new paragraph id
                # re-inserts the paragraph separator.
                if paragraph == last_paragraph:
                    ret[idx] += sent_join_str + text
                else:
                    ret[idx] += paragraph_join_str + text
                last_paragraph = paragraph
            else:
                ret[idx] = text
                last_paragraph = paragraph
        return ret

    @abstractmethod
    def tokenize(
        self,
        sentences: List[str],
        src_lang: Optional[str] = None,
        tgt_lang: Optional[str] = None,
    ): ...

    @abstractmethod
    def detokenize(
        self,
        sentences: List[List[str]],
        src_lang: Optional[str] = None,
        tgt_lang: Optional[str] = None,
    ): ...

    @abstractmethod
    def translate_batch(
        self,
        sentences: List[List[str]],
        src_lang: Optional[str] = None,
        tgt_lang: Optional[str] = None,
    ): ...

    @abstractmethod
    def unload(self): ...

    @validate_call
    def __call__(
        self,
        src: Union[str, List[str]],
        max_batch_size: int = 32,
        max_decoding_length: int = 256,
        beam_size: int = 2,
        patience: int = 1,
        length_penalty: float = 1.0,
        coverage_penalty: float = 0.0,
        repetition_penalty: float = 1.0,
        verbose: bool = False,
        src_lang: Union[None, str] = None,
        tgt_lang: Union[None, str] = None,
        **kwargs,
    ) -> Union[str, List[str]]:
        """Translate a list of strings with quickmt model

        Args:
            src (List[str]): Input list of strings to translate
            max_batch_size (int, optional): Maximum batch size, to constrain RAM utilization. Defaults to 32.
            beam_size (int, optional): CTranslate2 Beam size. Defaults to 2.
            patience (int, optional): CTranslate2 Patience. Defaults to 1.
            max_decoding_length (int, optional): Maximum length of translation
            **kwargs: Other CTranslate2 translate_batch args, see https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html#ctranslate2.Translator.translate_batch

        Returns:
            Union[str, List[str]]: Translation of the input
        """
        # A scalar input produces a scalar output.
        if isinstance(src, str):
            return_string = True
            src = [src]
        else:
            return_string = False

        indices, paragraphs, sentences = self._sentence_split(src)

        if not sentences:
            return "" if return_string else [""] * len(src)

        if verbose:
            print(f"Split sentences: {sentences}")

        input_text = self.tokenize(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
        if verbose:
            print(f"Tokenized input: {input_text}")

        t1 = time()
        results = self.translate_batch(
            input_text,
            beam_size=beam_size,
            patience=patience,
            length_penalty=length_penalty,
            coverage_penalty=coverage_penalty,
            repetition_penalty=repetition_penalty,
            max_decoding_length=max_decoding_length,
            max_batch_size=max_batch_size,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
            **kwargs,
        )
        t2 = time()
        if verbose:
            print(f"Translation time: {t2 - t1}")

        # Keep only the best hypothesis per sentence.
        output_tokens = [i.hypotheses[0] for i in results]

        if verbose:
            print(f"Tokenized output: {output_tokens}")

        translated_sents = self.detokenize(
            output_tokens, src_lang=src_lang, tgt_lang=tgt_lang
        )

        # Reassemble sentences back into the original inputs/paragraphs.
        ret = self._sentence_join(
            indices, paragraphs, translated_sents, length=len(src)
        )

        if return_string:
            return ret[0]
        else:
            return ret

    @validate_call
    def translate_file(self, input_file: str, output_file: str, **kwargs) -> None:
        """Translate a file with a quickmt model

        Args:
            input_file (str): Path to plain-text file to translate
            output_file (str): Path to write the line-aligned translation to
        """
        with open(input_file, "rt") as myfile:
            src = myfile.readlines()

        # Remove newlines
        src = [i.strip() for i in src]

        # Translate
        mt = self(src, **kwargs)

        # Replace newlines to ensure output is the same number of lines
        mt = [i.replace("\n", "\t") for i in mt]

        with open(output_file, "wt") as myfile:
            myfile.write("".join([i + "\n" for i in mt]))

    @validate_call
    def translate_stream(
        self,
        src: Union[str, List[str]],
        max_batch_size: int = 32,
        max_decoding_length: int = 256,
        beam_size: int = 5,
        patience: int = 1,
        src_lang: Union[None, str] = None,
        tgt_lang: Union[None, str] = None,
        **kwargs,
    ):
        """Translate a list of strings with quickmt model, yielding one
        result dict per sentence as translations become available.

        Args:
            src (List[str]): Input list of strings to translate
            max_batch_size (int, optional): Maximum batch size, to constrain RAM utilization. Defaults to 32.
            beam_size (int, optional): CTranslate2 Beam size. Defaults to 5.
            patience (int, optional): CTranslate2 Patience. Defaults to 1.
            max_decoding_length (int, optional): Maximum length of translation
            **kwargs: Other CTranslate2 translate_batch args, see https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html#ctranslate2.Translator.translate_batch
        """
        if isinstance(src, str):
            src = [src]

        indices, paragraphs, sentences = self._sentence_split(src)

        input_text = self.tokenize(sentences, src_lang=src_lang, tgt_lang=tgt_lang)

        translations_iterator = self.translator.translate_iterable(
            input_text,
            beam_size=beam_size,
            patience=patience,
            max_decoding_length=max_decoding_length,
            max_batch_size=max_batch_size,
            **kwargs,
        )

        for idx, para, sent, output in zip(
            indices, paragraphs, sentences, translations_iterator
        ):
            yield {
                "input_idx": idx,
                "sentence_idx": para,
                "input_text": sent,
                "translation": self.detokenize([output.hypotheses[0]])[0],
            }
+
278
+
279
class Translator(TranslatorABC):
    """Concrete quickmt translator backed by SentencePiece tokenizers."""

    def __init__(
        self,
        model_path: DirectoryPath,
        inter_threads: int = 1,
        intra_threads: int = 0,
        **kwargs,
    ):
        """Create quickmt translation object

        Args:
            model_path (DirectoryPath): Path to quickmt model folder
            inter_threads (int): Number of simultaneous translations
            intra_threads (int): Number of threads for each translation
            **kwargs: CTranslate2 Translator arguments - see https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html
        """
        super().__init__(
            model_path,
            inter_threads=inter_threads,
            intra_threads=intra_threads,
            **kwargs,
        )
        joint_tokenizer_path = self.model_path / "joint.spm.model"
        if joint_tokenizer_path.exists():
            # Joint vocabulary: parse the SentencePiece model file once and
            # share the processor for both directions (encode/decode calls do
            # not mutate it), instead of loading the same file twice.
            shared_tokenizer = sentencepiece.SentencePieceProcessor(
                model_file=str(joint_tokenizer_path)
            )
            self.source_tokenizer = shared_tokenizer
            self.target_tokenizer = shared_tokenizer
        else:
            self.source_tokenizer = sentencepiece.SentencePieceProcessor(
                model_file=str(self.model_path / "src.spm.model")
            )
            self.target_tokenizer = sentencepiece.SentencePieceProcessor(
                model_file=str(self.model_path / "tgt.spm.model")
            )

    def __del__(self):
        # Best-effort release of native resources at GC time; unload() guards
        # against partially-initialized instances.
        self.unload()

    def tokenize(
        self,
        sentences: List[str],
        src_lang: Optional[str] = None,
        tgt_lang: Optional[str] = None,
    ):
        """Encode sentences to SentencePiece tokens, appending </s>.

        Language tags are accepted for interface compatibility but ignored
        unless explicitly handled by a subclass.
        """
        return [
            tokens + ["</s>"]
            for tokens in self.source_tokenizer.encode(sentences, out_type=str)
        ]

    def detokenize(
        self,
        sentences: List[List[str]],
        src_lang: Optional[str] = None,
        tgt_lang: Optional[str] = None,
    ):
        """Decode SentencePiece token sequences back into plain strings."""
        return self.target_tokenizer.decode(sentences)

    def unload(self):
        """Explicitly release CTranslate2 translator resources"""
        if hasattr(self, "translator"):
            del self.translator

    def translate_batch(
        self,
        input_text: List[List[str]],
        beam_size: int = 5,
        patience: int = 1,
        max_decoding_length: int = 256,
        max_batch_size: int = 32,
        disable_unk: bool = True,
        replace_unknowns: bool = False,
        length_penalty: float = 1.0,
        coverage_penalty: float = 0.0,
        repetition_penalty: float = 1.0,
        src_lang: Optional[str] = None,
        tgt_lang: Optional[str] = None,
        **kwargs,
    ):
        """Translate a list of strings

        Args:
            input_text (List[List[str]]): Input text to be translated
            beam_size (int, optional): Beam size for beam search. Defaults to 5.
            patience (int, optional): Stop beam search when `patience` beams finish. Defaults to 1.
            max_decoding_length (int, optional): Max decoding length for model. Defaults to 256.
            max_batch_size (int, optional): Max batch size. Reduce to limit RAM usage. Increase for faster speed. Defaults to 32.
            disable_unk (bool, optional): Disable generating unk token. Defaults to True.
            replace_unknowns (bool, optional): Replace unk tokens with src token that has the highest attention value. Defaults to False.
            length_penalty (float, optional): Length penalty. Defaults to 1.0.
            coverage_penalty (float, optional): Coverage penalty. Defaults to 0.0.
            repetition_penalty (float, optional): Repetition penalty. Defaults to 1.0.
            src_lang (str, optional): Source language. Only needed for multilingual models. Defaults to None.
            tgt_lang (str, optional): Target language. Only needed for multilingual models. Defaults to None.

        Returns:
            List[str]: Translated text
        """
        # src_lang/tgt_lang are accepted for interface parity but not
        # forwarded: this bilingual model needs no language tags.
        return self.translator.translate_batch(
            input_text,
            beam_size=beam_size,
            patience=patience,
            max_decoding_length=max_decoding_length,
            max_batch_size=max_batch_size,
            disable_unk=disable_unk,
            replace_unknowns=replace_unknowns,
            length_penalty=length_penalty,
            coverage_penalty=coverage_penalty,
            repetition_penalty=repetition_penalty,
            **kwargs,
        )
requirements-dev.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ pytest
2
+ pytest-asyncio
3
+ httpx
4
+ locust
5
+ sacrebleu
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ blingfire
2
+ cachetools
3
+ fastapi
4
+ uvicorn[standard]
5
+ ctranslate2
6
+ sentencepiece
7
+ huggingface_hub
8
+ fasttext-wheel
9
+ orjson
10
+ uvloop
11
+ httptools
12
+ pydantic
13
+ pydantic-settings
14
+
tests/__init__.py ADDED
File without changes
tests/conftest.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import os
3
+ from typing import AsyncGenerator
4
+ from quickmt.rest_server import app
5
+ from httpx import AsyncClient
6
+
7
+
8
@pytest.fixture(scope="session")
def base_url() -> str:
    """Base URL of the server under test; override with TEST_BASE_URL."""
    return os.getenv("TEST_BASE_URL", "http://127.0.0.1:8000")
11
+
12
+
13
@pytest.fixture
async def client(base_url: str) -> AsyncGenerator[AsyncClient, None]:
    """Async HTTP client against the running server (60s timeout for model loads)."""
    async with AsyncClient(base_url=base_url, timeout=60.0) as client:
        yield client
tests/test_api.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import asyncio
3
+ from httpx import AsyncClient
4
+
5
+
6
@pytest.mark.asyncio
async def test_health_check(client: AsyncClient):
    """The health endpoint reports OK status and lists loaded models."""
    response = await client.get("/api/health")
    assert response.status_code == 200
    data = response.json()
    assert data["status"] == "ok"
    assert "loaded_models" in data
13
+
14
+
15
@pytest.mark.asyncio
async def test_get_models(client: AsyncClient):
    """/api/models returns a JSON object carrying a list of models."""
    response = await client.get("/api/models")
    assert response.status_code == 200
    data = response.json()
    assert "models" in data
    assert isinstance(data["models"], list)
22
+
23
+
24
@pytest.mark.asyncio
async def test_get_languages(client: AsyncClient):
    """/api/languages maps each source language to a list of targets."""
    response = await client.get("/api/languages")
    assert response.status_code == 200
    data = response.json()
    assert isinstance(data, dict)
    # Check structure if models exist
    if data:
        src = list(data.keys())[0]
        assert isinstance(data[src], list)
34
+
35
+
36
@pytest.mark.asyncio
async def test_translate_single(client: AsyncClient):
    """Translating a single string returns scalar metadata fields."""
    # First, find an available model
    models_res = await client.get("/api/models")
    models = models_res.json()["models"]
    if not models:
        pytest.skip("No models available in MODELS_DIR")

    model = models[0]
    payload = {
        "src": "Hello world",
        "src_lang": model["src_lang"],
        "tgt_lang": model["tgt_lang"],
    }

    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()
    assert "translation" in data
    assert "processing_time" in data
    assert data["src_lang"] == model["src_lang"]
    # Explicit src_lang skips detection, so the score is exactly 1.0.
    assert data["src_lang_score"] == 1.0
    assert data["tgt_lang"] == model["tgt_lang"]
    assert data["model_used"] == model["model_id"]
60
+
61
+
62
@pytest.mark.asyncio
async def test_translate_list(client: AsyncClient):
    """Translating a list returns lists aligned with the input order."""
    models_res = await client.get("/api/models")
    models = models_res.json()["models"]
    if not models:
        pytest.skip("No models available")

    model = models[0]
    payload = {
        "src": ["Hello", "World"],
        "src_lang": model["src_lang"],
        "tgt_lang": model["tgt_lang"],
    }

    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()
    assert isinstance(data["translation"], list)
    assert len(data["translation"]) == 2
    assert data["src_lang"] == [model["src_lang"], model["src_lang"]]
    assert data["src_lang_score"] == [1.0, 1.0]
    assert data["tgt_lang"] == model["tgt_lang"]
    assert data["model_used"] == [model["model_id"], model["model_id"]]
85
+
86
+
87
@pytest.mark.asyncio
async def test_dynamic_batching(client: AsyncClient):
    """Verify that multiple concurrent requests work correctly (triggering batching logic)."""
    models_res = await client.get("/api/models")
    models = models_res.json()["models"]
    if not models:
        pytest.skip("No models available")

    model = models[0]
    src, tgt = model["src_lang"], model["tgt_lang"]

    texts = [f"Sentence number {i}" for i in range(5)]
    tasks = []

    # Fire all requests without awaiting so they overlap server-side.
    for text in texts:
        payload = {"src": text, "src_lang": src, "tgt_lang": tgt}
        tasks.append(client.post("/api/translate", json=payload))

    responses = await asyncio.gather(*tasks)

    for response in responses:
        assert response.status_code == 200
        assert "translation" in response.json()
tests/test_auto_translate.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pytest
from httpx import AsyncClient


def _has_pair(models: list, src: str, tgt: str) -> bool:
    """Return True if the catalog lists a model for the src->tgt pair."""
    return any(m["src_lang"] == src and m["tgt_lang"] == tgt for m in models)


@pytest.mark.asyncio
async def test_auto_detect_src_lang(client: AsyncClient):
    """Verify that src_lang is auto-detected if missing."""
    # Ensure some models are available
    models_res = await client.get("/api/models")
    available_models = models_res.json()["models"]
    if not _has_pair(available_models, "fr", "en"):
        pytest.skip("fr-en model needed for this test")

    payload = {"src": "Bonjour tout le monde", "tgt_lang": "en"}
    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()
    assert "translation" in data
    assert data["src_lang"] == "fr"
    # Detection confidence is a probability, so it must lie in (0, 1].
    assert 0.0 < data["src_lang_score"] <= 1.0
    assert data["tgt_lang"] == "en"
    assert "quickmt/quickmt-fr-en" in data["model_used"]


@pytest.mark.asyncio
async def test_default_tgt_lang(client: AsyncClient):
    """Verify that tgt_lang defaults to 'en'."""
    models_res = await client.get("/api/models")
    available_models = models_res.json()["models"]
    if not _has_pair(available_models, "fr", "en"):
        pytest.skip("fr-en model needed for this test")

    payload = {"src": "Bonjour", "src_lang": "fr"}
    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()
    assert data["src_lang"] == "fr"
    # An explicitly provided src_lang is reported with full confidence.
    assert data["src_lang_score"] == 1.0
    assert data["tgt_lang"] == "en"
    assert "quickmt/quickmt-fr-en" in data["model_used"]


@pytest.mark.asyncio
async def test_mixed_language_batch(client: AsyncClient):
    """Verify that a batch with mixed languages is handled correctly."""
    models_res = await client.get("/api/models")
    available_models = models_res.json()["models"]

    for src, tgt in [("fr", "en"), ("es", "en")]:
        if not _has_pair(available_models, src, tgt):
            # Fixed: this was an f-string with no placeholders (extraneous f).
            pytest.skip("Mixed batch test needs both fr-en and es-en models")

    payload = {"src": ["Bonjour tout le monde", "Hola amigos"], "tgt_lang": "en"}
    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()
    assert isinstance(data["translation"], list)
    assert len(data["translation"]) == 2
    # Each item is routed to the model for its own detected language.
    assert data["src_lang"] == ["fr", "es"]
    assert len(data["src_lang_score"]) == 2
    assert all(0.0 < s <= 1.0 for s in data["src_lang_score"])
    assert data["tgt_lang"] == "en"
    assert "quickmt/quickmt-fr-en" in data["model_used"]
    assert "quickmt/quickmt-es-en" in data["model_used"]


@pytest.mark.asyncio
async def test_identity_translation(client: AsyncClient):
    """Verify that translation is skipped if src_lang == tgt_lang."""
    payload = {"src": "This is already English", "src_lang": "en", "tgt_lang": "en"}
    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()
    # The text passes through untouched and no translation model is used.
    assert data["translation"] == "This is already English"
    assert data["src_lang"] == "en"
    assert data["src_lang_score"] == 1.0
    assert data["tgt_lang"] == "en"
    assert data["model_used"] == "identity"


@pytest.mark.asyncio
async def test_auto_detect_mixed_identity(client: AsyncClient):
    """Verify mixed batch with some items needing translation and some remaining as-is."""
    models_res = await client.get("/api/models")
    available_models = models_res.json()["models"]
    if not _has_pair(available_models, "fr", "en"):
        pytest.skip("fr-en model needed for this test")

    payload = {"src": ["Bonjour", "Hello world"], "tgt_lang": "en"}
    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()
    assert len(data["translation"]) == 2
    assert data["src_lang"] == ["fr", "en"]
    assert len(data["src_lang_score"]) == 2
    # First should be auto-detected, second should be auto-detected (and high confidence)
    assert all(0.0 < s <= 1.0 for s in data["src_lang_score"])
    assert data["tgt_lang"] == "en"
    assert data["model_used"] == ["quickmt/quickmt-fr-en", "identity"]
tests/test_cache.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple test to demonstrate the translation cache functionality.
3
+ Run this to verify cache hits provide instant responses.
4
+ """
5
+
6
+ import asyncio
7
+ import time
8
+ from quickmt.manager import BatchTranslator
9
+ from quickmt.settings import settings
10
+
11
+
12
+ async def test_translation_cache():
13
+ print("=== Translation Cache Test ===\n")
14
+
15
+ # Create a mock BatchTranslator (would normally be created by ModelManager)
16
+ # For this test, we'll just verify the cache mechanism
17
+ print(f"Cache size configured: {settings.translation_cache_size}")
18
+
19
+ # Simulate cache behavior
20
+ from cachetools import LRUCache
21
+
22
+ cache = LRUCache(maxsize=settings.translation_cache_size)
23
+
24
+ # Test data
25
+ test_text = "Hello, world!"
26
+ src_lang = "en"
27
+ tgt_lang = "fr"
28
+ kwargs_tuple = tuple(sorted({"beam_size": 5, "patience": 1}.items()))
29
+
30
+ cache_key = (test_text, src_lang, tgt_lang, kwargs_tuple)
31
+
32
+ # First request - cache miss
33
+ print("\n1. First translation (cache miss):")
34
+ print(f" Key: {cache_key}")
35
+ if cache_key in cache:
36
+ print(" ✓ Cache HIT")
37
+ else:
38
+ print(" ✗ Cache MISS (expected)")
39
+ # Simulate translation and caching
40
+ cache[cache_key] = "Bonjour, monde!"
41
+ print(" → Cached result")
42
+
43
+ # Second request - cache hit
44
+ print("\n2. Repeated translation (cache hit):")
45
+ print(f" Key: {cache_key}")
46
+ if cache_key in cache:
47
+ print(" ✓ Cache HIT (instant!)")
48
+ print(f" → Result: {cache[cache_key]}")
49
+ else:
50
+ print(" ✗ Cache MISS (unexpected)")
51
+
52
+ # Different parameters - cache miss
53
+ different_kwargs = tuple(sorted({"beam_size": 10, "patience": 2}.items()))
54
+ different_key = (test_text, src_lang, tgt_lang, different_kwargs)
55
+
56
+ print("\n3. Same text, different parameters (cache miss):")
57
+ print(f" Key: {different_key}")
58
+ if different_key in cache:
59
+ print(" ✓ Cache HIT")
60
+ else:
61
+ print(" ✗ Cache MISS (expected - different params)")
62
+
63
+ print("\n✅ Cache test complete!")
64
+ print(f"Cache size: {len(cache)}/{settings.translation_cache_size}")
65
+
66
+
67
+ if __name__ == "__main__":
68
+ asyncio.run(test_translation_cache())
tests/test_identify_language.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pytest
from httpx import AsyncClient


@pytest.mark.asyncio
async def test_identify_language_single(client: AsyncClient):
    """Verify single string language identification."""
    resp = await client.post(
        "/api/identify-language", json={"src": "Hello, how are you?", "k": 1}
    )
    assert resp.status_code == 200
    body = resp.json()
    assert "results" in body
    assert "processing_time" in body
    assert isinstance(body["results"], list)
    # FastText should identify this as English
    assert body["results"][0]["lang"] == "en"


@pytest.mark.asyncio
async def test_identify_language_batch(client: AsyncClient):
    """Verify batch language identification."""
    resp = await client.post(
        "/api/identify-language",
        json={"src": ["Bonjour tout le monde", "Hola amigos"], "k": 1},
    )
    assert resp.status_code == 200
    body = resp.json()
    # One ranked detection list per batch item.
    assert len(body["results"]) == 2
    assert body["results"][0][0]["lang"] == "fr"
    assert body["results"][1][0]["lang"] == "es"


@pytest.mark.asyncio
async def test_identify_language_threshold(client: AsyncClient):
    """Verify threshold filtering in the endpoint."""
    resp = await client.post(
        "/api/identify-language",
        json={"src": "This is definitely English", "k": 5, "threshold": 0.9},
    )
    assert resp.status_code == 200
    body = resp.json()
    # Only 'en' should probably be above 0.9
    assert len(body["results"]) == 1
    assert body["results"][0]["lang"] == "en"
tests/test_langid.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pytest
from unittest.mock import MagicMock, patch

from quickmt.langid import LanguageIdentification


@pytest.fixture
def mock_fasttext():
    """Patch fasttext.load_model and yield the mocked model instance."""
    with patch("fasttext.load_model") as mock_load:
        mock_model = MagicMock()
        mock_load.return_value = mock_model

        # Mimic fasttext's batched return shape:
        # ([['__label__en', ...], ...], [[0.9, ...], ...])
        def mock_predict(items, k=1, threshold=0.0):
            labels = [["__label__en"] * k for _ in items]
            scores = [[0.9] * k for _ in items]
            return labels, scores

        mock_model.predict.side_effect = mock_predict
        yield mock_model


@pytest.fixture
def langid_model(mock_fasttext, tmp_path):
    """LanguageIdentification wired to the mocked fasttext backend."""
    # Create a dummy model file so the existence check passes
    model_path = tmp_path / "model.bin"
    model_path.write_text("dummy content")
    return LanguageIdentification(model_path)


def test_predict_single(langid_model, mock_fasttext):
    """A single string yields one (lang, score) list of length k."""
    result = langid_model.predict("Hello world")

    assert isinstance(result, list)
    assert len(result) == 1
    assert result[0] == ("en", 0.9)
    mock_fasttext.predict.assert_called_once_with(["Hello world"], k=1, threshold=0.0)


def test_predict_batch(langid_model, mock_fasttext):
    """A list input yields one ranked list of k candidates per item."""
    texts = ["Hello", "Bonjour"]
    results = langid_model.predict(texts, k=2)

    assert isinstance(results, list)
    assert len(results) == 2
    for r in results:
        assert len(r) == 2
        assert r[0] == ("en", 0.9)

    mock_fasttext.predict.assert_called_once_with(texts, k=2, threshold=0.0)


def test_predict_best_single(langid_model):
    """predict_best collapses a single input to its top language code."""
    result = langid_model.predict_best("Hello")
    assert result == "en"


def test_predict_best_batch(langid_model):
    """predict_best maps a list input to a list of top language codes."""
    results = langid_model.predict_best(["Hello", "World"])
    assert results == ["en", "en"]


def test_predict_threshold(langid_model, mock_fasttext):
    """predict_best returns None when no candidate survives the threshold."""
    # Configure mock to return nothing if threshold is high (simulated)
    def mock_predict_low_score(items, k=1, threshold=0.0):
        if threshold > 0.9:
            return [[] for _ in items], [[] for _ in items]
        return [["__label__en"] for _ in items], [[0.9] for _ in items]

    mock_fasttext.predict.side_effect = mock_predict_low_score

    result = langid_model.predict_best("Hello", threshold=0.95)
    assert result is None

    result = langid_model.predict_best("Hello", threshold=0.5)
    assert result == "en"
tests/test_langid_batch.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pytest
from httpx import AsyncClient


@pytest.mark.asyncio
async def test_langid_batch(client: AsyncClient):
    """Verify that language identification works for a list of strings."""
    request_body = {"src": ["This is English text.", "Ceci est un texte français."]}
    resp = await client.post("/api/identify-language", json=request_body)
    assert resp.status_code == 200

    # Expect a list of lists of DetectionResult
    detections = resp.json()["results"]
    assert isinstance(detections, list)
    assert len(detections) == 2

    # First item: English
    assert len(detections[0]) >= 1
    assert detections[0][0]["lang"] == "en"

    # Second item: French
    assert len(detections[1]) >= 1
    assert detections[1][0]["lang"] == "fr"


@pytest.mark.asyncio
async def test_langid_newline_handling(client: AsyncClient):
    """Verify that inputs with newlines are handled gracefully (no 500 error)."""
    # Single string with newline
    resp = await client.post(
        "/api/identify-language", json={"src": "This text\nhas a newline."}
    )
    assert resp.status_code == 200
    assert resp.json()["results"][0]["lang"] == "en"

    # Batch with newlines
    resp = await client.post(
        "/api/identify-language", json={"src": ["Line 1\nLine 2", "Another\nline"]}
    )
    assert resp.status_code == 200
    assert len(resp.json()["results"]) == 2
tests/test_langid_path.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from unittest.mock import patch

from quickmt.langid import ensure_model_exists, LanguageIdentification


def test_langid_default_path():
    """Verify that LanguageIdentification uses the XDG cache path by default."""
    # Patch out loading, downloading and filesystem checks so no real I/O
    # happens; we only inspect which path the loader receives.
    with patch("quickmt.langid.fasttext.load_model") as mock_load, \
            patch("quickmt.langid.urllib.request.urlretrieve"), \
            patch("pathlib.Path.exists") as mock_exists, \
            patch("pathlib.Path.mkdir"):

        # Simulate model cached and exists
        mock_exists.return_value = True

        LanguageIdentification(model_path=None)

        # Verify load_model was called with a path in the cache
        args, _ = mock_load.call_args
        loaded_path = str(args[0])

        expected_part = os.path.join(".cache", "fasttext_language_id", "lid.176.bin")
        assert expected_part in loaded_path

        # Old path should not be used
        assert "models/lid.176.ftz" not in loaded_path


def test_ensure_model_exists_path():
    """Verify ensure_model_exists resolves to cache path."""
    with patch("quickmt.langid.urllib.request.urlretrieve") as mock_retrieve, \
            patch("pathlib.Path.exists") as mock_exists, \
            patch("pathlib.Path.mkdir"):

        # Simulate model missing to trigger download logic path check
        mock_exists.return_value = False

        ensure_model_exists(None)

        # Check download target (second positional arg to urlretrieve)
        args, _ = mock_retrieve.call_args
        download_target = str(args[1])

        expected_part = os.path.join(".cache", "fasttext_language_id", "lid.176.bin")
        assert expected_part in download_target
tests/test_lru.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pytest
from httpx import AsyncClient


@pytest.mark.asyncio
async def test_lru_eviction(client: AsyncClient):
    """
    Test that the server correctly unloads the least recently used model
    when MAX_LOADED_MODELS is exceeded.
    """
    # 1. Get available models
    catalog = (await client.get("/api/models")).json()["models"]

    # 2. Get MAX_LOADED_MODELS from health
    max_models = (await client.get("/api/health")).json()["max_models"]

    if len(catalog) <= max_models:
        pytest.skip(
            f"Not enough models in MODELS_DIR to test eviction (need > {max_models})"
        )

    # 3. Load max_models + 1 models sequentially
    requested_pairs = []
    for entry in catalog[: max_models + 1]:
        await client.post(
            "/api/translate",
            json={
                "src": "test",
                "src_lang": entry["src_lang"],
                "tgt_lang": entry["tgt_lang"],
            },
        )
        requested_pairs.append(f"{entry['src_lang']}-{entry['tgt_lang']}")

    # 4. Check currently loaded models
    loaded_now = (await client.get("/api/health")).json()["loaded_models"]

    # The first (least recently used) model should have been evicted
    assert requested_pairs[0] not in loaded_now
    assert len(loaded_now) == max_models

    # The most recently requested model should be there
    assert requested_pairs[-1] in loaded_now
tests/test_manager.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pytest
from unittest.mock import MagicMock, patch

from quickmt.manager import ModelManager, BatchTranslator


@pytest.fixture
def mock_translator():
    """Patch quickmt.manager.Translator; yield the mocked model instance."""
    with patch("quickmt.manager.Translator") as mock:
        instance = MagicMock()
        mock.return_value = instance
        yield instance


@pytest.fixture
def mock_hf():
    """Patch Hugging Face Hub access (collection listing + snapshot download)."""
    with patch("quickmt.manager.snapshot_download") as mock_dl, \
            patch("quickmt.manager.HfApi") as mock_api:

        # Mock collection fetch with two language-pair model entries.
        coll = MagicMock()
        coll.items = [
            MagicMock(item_id="quickmt/quickmt-en-fr", item_type="model"),
            MagicMock(item_id="quickmt/quickmt-fr-en", item_type="model"),
        ]
        mock_api.return_value.get_collection.return_value = coll
        mock_dl.return_value = "/tmp/mock-model-path"

        yield mock_api, mock_dl


class TestBatchTranslator:
    @pytest.mark.asyncio
    async def test_translate_single(self, mock_translator):
        """translate() starts a worker lazily; stop_worker() tears it down."""
        bt = BatchTranslator("test-id", "/tmp/path")

        # Mock translator call
        mock_translator.return_value = "Hola"

        result = await bt.translate("Hello", src_lang="en", tgt_lang="es")
        assert result == "Hola"
        assert bt.worker_task is not None

        await bt.stop_worker()
        assert bt.worker_task is None


class TestModelManager:
    @pytest.mark.asyncio
    async def test_fetch_hf_models(self, mock_hf):
        """The HF collection is parsed into src/tgt language-pair entries."""
        mm = ModelManager(max_loaded=2, device="cpu")
        await mm.fetch_hf_models()

        assert len(mm.hf_collection_models) == 2
        assert mm.hf_collection_models[0]["src_lang"] == "en"
        assert mm.hf_collection_models[0]["tgt_lang"] == "fr"

    @pytest.mark.asyncio
    async def test_get_model_lazy_load(self, mock_hf, mock_translator):
        """First get_model() call downloads the model and registers it."""
        mm = ModelManager(max_loaded=2, device="cpu")
        await mm.fetch_hf_models()

        # This should trigger download and start worker
        bt = await mm.get_model("en", "fr")
        assert isinstance(bt, BatchTranslator)
        assert "en-fr" in mm.models
        assert bt.model_id == "quickmt/quickmt-en-fr"

    @pytest.mark.asyncio
    async def test_lru_eviction(self, mock_hf, mock_translator):
        """Loading past max_loaded evicts the least recently used model."""
        # Set max_loaded to 1 to trigger eviction immediately
        mm = ModelManager(max_loaded=1, device="cpu")
        await mm.fetch_hf_models()

        # Load first
        await mm.get_model("en", "fr")
        assert len(mm.models) == 1

        # Load second (should evict first)
        await mm.get_model("fr", "en")
        assert len(mm.models) == 1
        assert "fr-en" in mm.models
        assert "en-fr" not in mm.models

    @pytest.mark.asyncio
    async def test_get_model_cache_first(self, mock_hf, mock_translator):
        """get_model() tries the local HF cache before going online."""
        mock_api, mock_dl = mock_hf
        mm = ModelManager(max_loaded=2, device="cpu")
        await mm.fetch_hf_models()

        # Scenario 1: Local cache hit
        # Reset mock to track new calls
        mock_dl.reset_mock()
        mock_dl.return_value = "/tmp/mock-model-path"

        await mm.get_model("en", "fr")

        # Verify it tried local_files_only=True first
        assert mock_dl.call_count == 1
        args, kwargs = mock_dl.call_args
        assert kwargs.get("local_files_only") is True

    @pytest.mark.asyncio
    async def test_get_model_fallback(self, mock_hf, mock_translator):
        """On a local cache miss, get_model() retries with an online download."""
        mock_api, mock_dl = mock_hf
        mm = ModelManager(max_loaded=2, device="cpu")
        await mm.fetch_hf_models()

        # Scenario 2: Local cache miss, fallback to online
        # First call fails, second succeeds
        mock_dl.side_effect = [Exception("Not found locally"), "/tmp/mock-model-path"]

        await mm.get_model("fr", "en")

        assert mock_dl.call_count == 2
        # First call was local only
        args1, kwargs1 = mock_dl.call_args_list[0]
        assert kwargs1.get("local_files_only") is True
        # Second call was online (no local_files_only or False)
        args2, kwargs2 = mock_dl.call_args_list[1]
        assert not kwargs2.get("local_files_only")
tests/test_mixed_src.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pytest
from httpx import AsyncClient


@pytest.mark.asyncio
async def test_explicit_mixed_languages(client: AsyncClient):
    """Verify explicit src_lang list for a mixed batch."""
    # Ensure needed models are available
    models_res = await client.get("/api/models")
    available_models = models_res.json()["models"]

    needed = [("fr", "en"), ("es", "en")]
    for src, tgt in needed:
        if not any(
            m["src_lang"] == src and m["tgt_lang"] == tgt for m in available_models
        ):
            # Fixed: was an f-string with no placeholders (extraneous f prefix).
            pytest.skip("Mixed batch test needs both fr-en and es-en models")

    # Explicitly specify languages for each input
    payload = {"src": ["Bonjour", "Hola"], "src_lang": ["fr", "es"], "tgt_lang": "en"}

    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()

    # Each input is routed to the model matching its declared language.
    assert data["src_lang"] == ["fr", "es"]
    assert data["model_used"] == ["quickmt/quickmt-fr-en", "quickmt/quickmt-es-en"]
    assert len(data["translation"]) == 2


@pytest.mark.asyncio
async def test_src_lang_length_mismatch(client: AsyncClient):
    """Verify 422 error when src and src_lang lengths differ."""
    payload = {
        "src": ["Hello", "World"],
        "src_lang": ["en"],  # Only 1 language for 2 inputs
        "tgt_lang": "es",
    }

    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 422
    assert (
        "src_lang list length must match src list length" in response.json()["detail"]
    )


@pytest.mark.asyncio
async def test_src_lang_list_with_single_src(client: AsyncClient):
    """Verify single src string with single-item src_lang list is not allowed or handled gracefully."""
    # The Pydantic model allows this, but our logic checks lengths.
    # If src is str, src_list has len 1. If src_lang is list, it must have len 1.

    # Needs a model
    models_res = await client.get("/api/models")
    models = models_res.json()["models"]
    if not models:
        pytest.skip("No models available")
    model = models[0]

    payload = {
        "src": "Hello",
        "src_lang": [model["src_lang"]],
        "tgt_lang": model["tgt_lang"],
    }

    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()
    # A scalar src yields a scalar src_lang in the response.
    assert data["src_lang"] == model["src_lang"]
tests/test_robustness.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pytest
import asyncio
from httpx import AsyncClient


@pytest.mark.asyncio
async def test_model_not_found(client: AsyncClient):
    """Verify that requesting a non-existent model returns 404."""
    payload = {
        "src": "Hello",
        "src_lang": "en",
        "tgt_lang": "zz",  # Non-existent
    }
    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 404
    assert "not found" in response.json()["detail"]


@pytest.mark.asyncio
async def test_empty_input_string(client: AsyncClient):
    """Verify handling of empty string input."""
    models_res = await client.get("/api/models")
    models = models_res.json()["models"]
    if not models:
        # Guard: indexing models[0] on an empty catalog raised IndexError.
        pytest.skip("No translation models available")
    model = models[0]

    payload = {"src": "", "src_lang": model["src_lang"], "tgt_lang": model["tgt_lang"]}
    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    assert response.json()["translation"] == ""


@pytest.mark.asyncio
async def test_empty_input_list(client: AsyncClient):
    """Verify handling of empty list input."""
    models_res = await client.get("/api/models")
    models = models_res.json()["models"]
    if not models:
        pytest.skip("No translation models available")
    model = models[0]

    payload = {"src": [], "src_lang": model["src_lang"], "tgt_lang": model["tgt_lang"]}
    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    assert response.json()["translation"] == []


@pytest.mark.asyncio
async def test_invalid_input_type(client: AsyncClient):
    """Verify that invalid input types are rejected by Pydantic."""
    payload = {
        "src": 123,  # Should be string or list of strings
        "src_lang": "en",
        "tgt_lang": "fr",
    }
    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 422  # Unprocessable Entity (Validation Error)


@pytest.mark.asyncio
async def test_concurrent_model_load(client: AsyncClient):
    """
    Test that concurrent requests for a new model are handled correctly
    (only one load should happen, others wait on the event).
    """
    # Find a model that is definitely NOT loaded
    health_res = await client.get("/api/health")
    loaded = health_res.json()["loaded_models"]

    models_res = await client.get("/api/models")
    available = models_res.json()["models"]

    target_model = None
    for m in available:
        lang_pair = f"{m['src_lang']}-{m['tgt_lang']}"
        if lang_pair not in loaded:
            target_model = m
            break

    if not target_model:
        pytest.skip("No unloaded models available to test concurrent loading")

    # Send multiple concurrent requests for the same new model
    payload = {
        "src": "Concurrent test",
        "src_lang": target_model["src_lang"],
        "tgt_lang": target_model["tgt_lang"],
    }

    tasks = [client.post("/api/translate", json=payload) for _ in range(3)]
    responses = await asyncio.gather(*tasks)

    for resp in responses:
        assert resp.status_code == 200
        assert "translation" in resp.json()


@pytest.mark.asyncio
async def test_parameter_overrides(client: AsyncClient):
    """Verify that request-level parameters are respected."""
    models_res = await client.get("/api/models")
    models = models_res.json()["models"]
    if not models:
        pytest.skip("No translation models available")
    model = models[0]

    payload = {
        "src": "This is a test of parameter overrides.",
        "src_lang": model["src_lang"],
        "tgt_lang": model["tgt_lang"],
        "beam_size": 1,
        "max_decoding_length": 5,
    }

    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    # With max_decoding_length=5, the translation should be very short
    # Note: tokens != words, but usually it translates to 1-3 words
    trans = response.json()["translation"]
    # We can't strictly assert word count but we can check it's non-empty
    assert len(trans) > 0


@pytest.mark.asyncio
async def test_large_batch_processing(client: AsyncClient):
    """Verify processing of a batch larger than MAX_BATCH_SIZE."""
    models_res = await client.get("/api/models")
    models = models_res.json()["models"]
    if not models:
        pytest.skip("No translation models available")
    model = models[0]

    # Send 50 sentences (default MAX_BATCH_SIZE is 32)
    sentences = [f"This is sentence {i}" for i in range(50)]
    payload = {
        "src": sentences,
        "src_lang": model["src_lang"],
        "tgt_lang": model["tgt_lang"],
    }

    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()
    assert len(data["translation"]) == 50
tests/test_threading_config.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pytest
from unittest.mock import patch
from quickmt.manager import ModelManager


@pytest.mark.asyncio
async def test_threading_config_propagation():
    """Verify that inter_threads and intra_threads are passed to CTranslate2."""

    # Mocking components to prevent actual model loading
    with patch("quickmt.manager.Translator") as translator_cls:
        expected_inter, expected_intra = 2, 4

        manager = ModelManager(
            max_loaded=1,
            device="cpu",
            compute_type="int8",
            inter_threads=expected_inter,
            intra_threads=expected_intra,
        )

        # Inject a dummy model to collection
        manager.hf_collection_models = [
            {"model_id": "test/model", "src_lang": "en", "tgt_lang": "fr"}
        ]

        # Mock snapshot_download and trigger a model load
        with patch("quickmt.manager.snapshot_download", return_value="/tmp/model"):
            await manager.get_model("en", "fr")

        # Verify Translator was instantiated with correct parameters
        _, kwargs = translator_cls.call_args
        assert kwargs["inter_threads"] == expected_inter
        assert kwargs["intra_threads"] == expected_intra
        assert kwargs["device"] == "cpu"
        assert kwargs["compute_type"] == "int8"
tests/test_translation_quality.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pytest
from httpx import AsyncClient
import sacrebleu

# 10 Diverse English -> French pairs
EN_FR_PAIRS = [
    ("Hello world", "Bonjour le monde"),
    ("The cat sits on the mat.", "Le chat est assis sur le tapis."),
    ("I would like a coffee, please.", "Je voudrais un café, s'il vous plaît."),
    ("Where is the nearest train station?", "Où est la gare la plus proche ?"),
    (
        "Artificial intelligence is fascinating.",
        "L'intelligence artificielle est fascinante.",
    ),
    ("Can you help me translate this?", "Pouvez-vous m'aider à traduire ceci ?"),
    ("It is raining today.", "Il pleut aujourd'hui."),
    ("Programming is fun.", "La programmation est amusante."),
    ("I am learning French.", "J'apprends le français."),
    ("Have a nice day.", "Bonne journée."),
]

# 10 Diverse French -> English pairs
FR_EN_PAIRS = [
    ("Bonjour tout le monde", "Hello everyone"),
    ("La vie est belle", "Life is beautiful"),
    ("Je suis fatigué", "I am tired"),
    ("Quelle heure est-il ?", "What time is it?"),
    ("J'aime manger des croissants", "I like eating croissants"),
    ("Merci beaucoup", "Thank you very much"),
    ("À bientôt", "See you soon"),
    ("Le livre est sur la table", "The book is on the table"),
    ("Je ne comprends pas", "I do not understand"),
    ("C'est magnifique", "It is magnificent"),
]


async def translate_batch(client, texts, src, tgt):
    """POST one batch to the translate endpoint; return [] on any failure."""
    resp = await client.post(
        "/api/translate", json={"src": texts, "src_lang": src, "tgt_lang": tgt}
    )
    return resp.json()["translation"] if resp.status_code == 200 else []


@pytest.mark.asyncio
async def test_quality_en_fr(client: AsyncClient):
    """Assess translation quality for English to French."""
    # Check model availability first
    catalog = (await client.get("/api/models")).json()["models"]
    if not any(m["src_lang"] == "en" and m["tgt_lang"] == "fr" for m in catalog):
        pytest.skip("en-fr model needed for this test")

    sources = [pair[0] for pair in EN_FR_PAIRS]
    # sacrebleu expects list of lists of references
    references = [[pair[1] for pair in EN_FR_PAIRS]]

    hypotheses = await translate_batch(client, sources, "en", "fr")
    assert len(hypotheses) == len(sources)

    # Corpus-level metrics over the whole batch
    bleu = sacrebleu.corpus_bleu(hypotheses, references)
    chrf = sacrebleu.corpus_chrf(hypotheses, references)

    print(f"\nEN->FR Quality: BLEU={bleu.score:.2f}, CHRF={chrf.score:.2f}")

    # Assert minimum quality (adjust baselines based on model capability)
    # Generic models should at least get > 10 BLEU on simple sentences
    assert bleu.score > 40.0
    assert chrf.score > 70.0


@pytest.mark.asyncio
async def test_quality_fr_en(client: AsyncClient):
    """Assess translation quality for French to English."""
    # Check model availability first
    catalog = (await client.get("/api/models")).json()["models"]
    if not any(m["src_lang"] == "fr" and m["tgt_lang"] == "en" for m in catalog):
        pytest.skip("fr-en model needed for this test")

    sources = [pair[0] for pair in FR_EN_PAIRS]
    references = [[pair[1] for pair in FR_EN_PAIRS]]

    hypotheses = await translate_batch(client, sources, "fr", "en")
    assert len(hypotheses) == len(sources)

    bleu = sacrebleu.corpus_bleu(hypotheses, references)
    chrf = sacrebleu.corpus_chrf(hypotheses, references)

    print(f"\nFR->EN Quality: BLEU={bleu.score:.2f}, CHRF={chrf.score:.2f}")

    # Assert minimum quality
    assert bleu.score > 40.0
    assert chrf.score > 70.0
tests/test_translator.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from pathlib import Path
3
+ from unittest.mock import MagicMock, patch
4
+ from quickmt.translator import Translator, TranslatorABC
5
+
6
+
7
+ # Mock objects
8
@pytest.fixture
def mock_ctranslate2():
    """Patch the ctranslate2.Translator class for the duration of a test."""
    with patch("ctranslate2.Translator") as mocked_cls:
        yield mocked_cls
12
+
13
+
14
@pytest.fixture
def mock_sentencepiece():
    """Patch the sentencepiece.SentencePieceProcessor class during a test."""
    with patch("sentencepiece.SentencePieceProcessor") as mocked_cls:
        yield mocked_cls
18
+
19
+
20
@pytest.fixture
def temp_model_dir(tmp_path):
    """Create a dummy model directory with required files."""
    directory = tmp_path / "dummy-model"
    directory.mkdir()
    # Separate source/target SPM files mark this as a non-joint-vocab model.
    for spm_name in ("src.spm.model", "tgt.spm.model"):
        (directory / spm_name).write_text("dummy")
    return directory
28
+
29
+
30
@pytest.fixture
def translator_instance(temp_model_dir, mock_ctranslate2, mock_sentencepiece):
    """Build a Translator whose ct2/SPM backends are fully mocked."""
    return Translator(temp_model_dir)
33
+
34
+
35
class TestTranslatorABC:
    """Unit tests for the sentence split/join helpers on TranslatorABC."""

    def test_sentence_split(self):
        """Splitting flattens inputs to sentences plus index bookkeeping."""
        paragraphs = ["Hello world. This is a test.", "Another paragraph."]
        doc_ids, para_ids, sents = TranslatorABC._sentence_split(paragraphs)

        # Two sentences come from the first input, one from the second.
        assert len(sents) == 3
        assert doc_ids == [0, 0, 1]
        assert para_ids == [0, 0, 0]
        assert sents == ["Hello world.", "This is a test.", "Another paragraph."]

    def test_sentence_join(self):
        """Joining reverses the split, regrouping sentences by input id."""
        joined = TranslatorABC._sentence_join(
            [0, 0, 1],
            [0, 0, 0],
            ["Hello world.", "This is a test.", "Another paragraph."],
        )
        assert joined == ["Hello world. This is a test.", "Another paragraph."]

    def test_sentence_join_empty(self):
        """Joining nothing yields `length` empty strings."""
        result = TranslatorABC._sentence_join([], [], [], length=5)
        assert result == ["", "", "", "", ""]
59
+
60
+
61
class TestTranslator:
    """Tests for Translator behaviour with mocked ctranslate2/sentencepiece."""

    def test_init_joint_tokens(self, tmp_path, mock_ctranslate2, mock_sentencepiece):
        """A single joint SPM file should back both source and target tokenizers."""
        model_dir = tmp_path / "joint-model"
        model_dir.mkdir()
        (model_dir / "joint.spm.model").write_text("dummy")

        # The instance itself is not needed (fixes unused-variable F841);
        # constructing it is what triggers the tokenizer loads we assert on.
        Translator(model_dir)
        assert mock_sentencepiece.call_count == 2
        # Verify it used the joint model for both
        args, kwargs = mock_sentencepiece.call_args_list[0]
        assert "joint.spm.model" in kwargs["model_file"]

    def test_tokenize(self, translator_instance):
        """tokenize() appends the </s> sentinel to each encoded sentence."""
        translator_instance.source_tokenizer.encode.return_value = [
            ["token1", "token2"]
        ]
        result = translator_instance.tokenize(["Hello"])
        assert result == [["token1", "token2", "</s>"]]
        translator_instance.source_tokenizer.encode.assert_called_with(
            ["Hello"], out_type=str
        )

    def test_detokenize(self, translator_instance):
        """detokenize() delegates directly to the target tokenizer's decode."""
        translator_instance.target_tokenizer.decode.return_value = ["Hello"]
        result = translator_instance.detokenize([["token1", "token2"]])
        assert result == ["Hello"]
        translator_instance.target_tokenizer.decode.assert_called_with(
            [["token1", "token2"]]
        )

    def test_unload(self, translator_instance):
        """unload() must tolerate the backend translator already being gone."""
        del translator_instance.translator
        # Should not raise
        translator_instance.unload()

    def test_call_full_pipeline(self, translator_instance):
        """__call__ chains tokenize -> translate_batch -> detokenize."""
        # Mock the steps
        with (
            patch.object(Translator, "tokenize") as mock_tok,
            patch.object(Translator, "translate_batch") as mock_trans,
            patch.object(Translator, "detokenize") as mock_detok,
        ):
            mock_tok.return_value = [["tok"]]
            mock_res = MagicMock()
            mock_res.hypotheses = [["hypo"]]
            mock_trans.return_value = [mock_res]
            mock_detok.return_value = ["Translated sentence."]

            result = translator_instance("Source text.")
            assert result == "Translated sentence."

            mock_tok.assert_called_once()
            mock_trans.assert_called_once()
            mock_detok.assert_called_once()

    def test_translate_stream(self, translator_instance):
        """translate_stream yields one detokenized dict per input sentence."""
        translator_instance.translator.translate_iterable = MagicMock(
            return_value=[
                MagicMock(hypotheses=[["hypo1"]]),
                MagicMock(hypotheses=[["hypo2"]]),
            ]
        )

        with (
            patch.object(Translator, "tokenize") as mock_tok,
            patch.object(Translator, "detokenize") as mock_detok,
        ):
            mock_tok.return_value = [["tok1"], ["tok2"]]
            # Echo the first token so each yield is distinguishable.
            mock_detok.side_effect = lambda x: [f"Detok {x[0][0]}"]

            results = list(translator_instance.translate_stream(["Sent 1.", "Sent 2."]))
            assert len(results) == 2
            assert results[0]["translation"] == "Detok hypo1"
            assert results[1]["translation"] == "Detok hypo2"

    def test_translate_file(self, translator_instance, tmp_path):
        """translate_file writes one translated line per input line."""
        input_file = tmp_path / "input.txt"
        output_file = tmp_path / "output.txt"
        input_file.write_text("Line 1\nLine 2")

        with patch.object(Translator, "__call__") as mock_call:
            mock_call.return_value = ["Trans 1", "Trans 2"]
            translator_instance.translate_file(str(input_file), str(output_file))

        content = output_file.read_text()
        assert content == "Trans 1\nTrans 2\n"

    def test_translate_batch(self, translator_instance):
        """Named decode options and extra kwargs must reach ctranslate2."""
        translator_instance.translate_batch(
            [["tok"]],
            beam_size=10,
            patience=2,
            max_batch_size=16,
            num_hypotheses=5,  # kwargs
        )
        translator_instance.translator.translate_batch.assert_called_once()
        args, kwargs = translator_instance.translator.translate_batch.call_args
        assert kwargs["beam_size"] == 10
        assert kwargs["patience"] == 2
        assert kwargs["max_batch_size"] == 16
        assert kwargs["num_hypotheses"] == 5