Spaces:
Sleeping
Sleeping
radinplaid commited on
Commit ·
b6b0c93
1
Parent(s): e732115
Initial commit
Browse files- Dockerfile +37 -0
- README.md +65 -11
- locustfile.py +131 -0
- pyproject.toml +45 -0
- pytest.ini +3 -0
- quickmt/__init__.py +0 -0
- quickmt/gui/static/app.js +408 -0
- quickmt/gui/static/index.html +212 -0
- quickmt/gui/static/style.css +668 -0
- quickmt/langid.py +150 -0
- quickmt/manager.py +367 -0
- quickmt/rest_server.py +358 -0
- quickmt/settings.py +69 -0
- quickmt/translator.py +390 -0
- requirements-dev.txt +5 -0
- requirements.txt +14 -0
- tests/__init__.py +0 -0
- tests/conftest.py +16 -0
- tests/test_api.py +109 -0
- tests/test_auto_translate.py +108 -0
- tests/test_cache.py +68 -0
- tests/test_identify_language.py +40 -0
- tests/test_langid.py +71 -0
- tests/test_langid_batch.py +42 -0
- tests/test_langid_path.py +47 -0
- tests/test_lru.py +47 -0
- tests/test_manager.py +118 -0
- tests/test_mixed_src.py +69 -0
- tests/test_robustness.py +137 -0
- tests/test_threading_config.py +39 -0
- tests/test_translation_quality.py +101 -0
- tests/test_translator.py +161 -0
Dockerfile
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM nvidia/cuda:12.6.3-cudnn-runtime-ubuntu24.04
|
| 2 |
+
|
| 3 |
+
# Set working directory
|
| 4 |
+
WORKDIR /app
|
| 5 |
+
|
| 6 |
+
# Install system dependencies
|
| 7 |
+
RUN apt-get update && apt-get install -y \
|
| 8 |
+
build-essential \
|
| 9 |
+
curl \
|
| 10 |
+
python3-pip \
|
| 11 |
+
python3 \
|
| 12 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 13 |
+
|
| 14 |
+
# Copy local code to the container
|
| 15 |
+
COPY . /app
|
| 16 |
+
|
| 17 |
+
# Create a non-root user for security
|
| 18 |
+
RUN useradd -m -u 1001 user
|
| 19 |
+
|
| 20 |
+
ENV HOME=/home/user \
|
| 21 |
+
PATH=/home/user/.local/bin:$PATH
|
| 22 |
+
|
| 23 |
+
WORKDIR $HOME/app
|
| 24 |
+
COPY --chown=user . $HOME/app
|
| 25 |
+
|
| 26 |
+
# Install the package and dependencies
|
| 27 |
+
# This also installs the quickmt cli scripts
|
| 28 |
+
RUN pip install --break-system-packages --no-cache-dir /app/
|
| 29 |
+
|
| 30 |
+
# Expose the default FastAPI port
|
| 31 |
+
EXPOSE 7860
|
| 32 |
+
|
| 33 |
+
USER user
|
| 34 |
+
|
| 35 |
+
# Hf Spaces expect the app on port 7860 usually
|
| 36 |
+
# We override the port via env var or CLI arg
|
| 37 |
+
CMD ["uvicorn", "quickmt.rest_server:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,11 +1,65 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# `quickmt` Neural Machine Translation Inference Library
|
| 2 |
+
|
| 3 |
+
## REST Server Features
|
| 4 |
+
|
| 5 |
+
- **Dynamic Batching**: Multiple concurrent HTTP requests are pooled together to maximize GPU utilization.
|
| 6 |
+
- **Multi-Model Support**: Requests are routed to specific models based on `src_lang` and `tgt_lang`.
|
| 7 |
+
- **LRU Cache**: Automatically loads and unloads models based on usage to manage memory.
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
## Installation
|
| 11 |
+
|
| 12 |
+
```bash
|
| 13 |
+
pip install -r requirements.txt
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
## Running the Web Application
|
| 17 |
+
|
| 18 |
+
```bash
|
| 19 |
+
export MAX_LOADED_MODELS=3
|
| 20 |
+
export MAX_BATCH_SIZE=32
|
| 21 |
+
export DEVICE=cuda # or cpu
|
| 22 |
+
export COMPUTE_TYPE=int8 # default, auto, int8, float16, etc.
|
| 23 |
+
quickmt-gui
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
## Running the REST Server
|
| 29 |
+
|
| 30 |
+
```bash
|
| 31 |
+
export MAX_LOADED_MODELS=3
|
| 32 |
+
export MAX_BATCH_SIZE=32
|
| 33 |
+
export DEVICE=cuda # or cpu
|
| 34 |
+
export COMPUTE_TYPE=int8 # default, auto, int8, float16, etc.
|
| 35 |
+
quickmt-api
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
## API Usage
|
| 40 |
+
|
| 41 |
+
### Translate
|
| 42 |
+
```bash
|
| 43 |
+
curl -X POST http://localhost:8000/translate \
|
| 44 |
+
-H "Content-Type: application/json" \
|
| 45 |
+
-d '{"src":"Hello world","src_lang":null,"tgt_lang":"fr","beam_size":2,"patience":1,"length_penalty":1,"coverage_penalty":0,"repetition_penalty":1}'
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
```json
|
| 50 |
+
{
|
| 51 |
+
"translation":"Bonjour tout le monde !",
|
| 52 |
+
"src_lang":"en",
|
| 53 |
+
"src_lang_score":0.16532786190509796,
|
| 54 |
+
"tgt_lang":"fr",
|
| 55 |
+
"processing_time":2.2334513664245605,
|
| 56 |
+
"model_used":"quickmt/quickmt-en-fr"
|
| 57 |
+
}
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
## Load Testing with Locust
|
| 61 |
+
To simulate a multi-user load:
|
| 62 |
+
```bash
|
| 63 |
+
locust -f locustfile.py --host http://localhost:8000
|
| 64 |
+
```
|
| 65 |
+
Then open http://localhost:8089 in your browser.
|
locustfile.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from locust import FastHttpUser, task, between
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class TranslationUser(FastHttpUser):
|
| 6 |
+
wait_time = between(0, 0)
|
| 7 |
+
|
| 8 |
+
# Sample sentences for translation and identification
|
| 9 |
+
sample_texts = [
|
| 10 |
+
"The quick brown fox jumps over the lazy dog.",
|
| 11 |
+
"Can we translate this correctly and quickly?",
|
| 12 |
+
"هذا نص تجريبي باللغة العربية.", # Arabic
|
| 13 |
+
"الذكاء الاصطناعي هو المستقبل.", # Arabic (AI is the future)
|
| 14 |
+
"أحب تعلم لغات جديدة.", # Arabic (I love learning new languages)
|
| 15 |
+
"这是一段中文测试文本。", # Chinese
|
| 16 |
+
"人工智能正在改变世界。", # Chinese (AI is changing the world)
|
| 17 |
+
"今天天气真好,去公园散步。", # Chinese (Weather is nice, let's walk)
|
| 18 |
+
"Bonjour, comment allez-vous ?", # French
|
| 19 |
+
"L'intelligence artificielle transforme notre vie quotidienne.", # French (AI transforms daily life)
|
| 20 |
+
"Ceci est un exemple de phrase en français.", # French
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
def on_start(self):
|
| 24 |
+
"""Discover available models on startup."""
|
| 25 |
+
try:
|
| 26 |
+
response = self.client.get("/models")
|
| 27 |
+
if response.status_code == 200:
|
| 28 |
+
self.available_models = response.json().get("models", [])
|
| 29 |
+
if not self.available_models:
|
| 30 |
+
print("No models found. Load test might fail.")
|
| 31 |
+
else:
|
| 32 |
+
self.available_models = []
|
| 33 |
+
except Exception as e:
|
| 34 |
+
print(f"Error discovering models: {e}")
|
| 35 |
+
self.available_models = []
|
| 36 |
+
|
| 37 |
+
def get_random_model(self):
|
| 38 |
+
"""
|
| 39 |
+
Return a model, favoring the first 3 (hot set) 99% of the time,
|
| 40 |
+
and others (cold set) 1% of the time to trigger LRU eviction.
|
| 41 |
+
"""
|
| 42 |
+
if not self.available_models:
|
| 43 |
+
return None
|
| 44 |
+
|
| 45 |
+
# If we have 4 or more models, we can simulate eviction cycles
|
| 46 |
+
if len(self.available_models) >= 4:
|
| 47 |
+
# 99.99% chance to pick from the first 3
|
| 48 |
+
if random.random() < 0.9999:
|
| 49 |
+
return random.choice(self.available_models[:3])
|
| 50 |
+
else:
|
| 51 |
+
# 0.01% chance to pick from the rest
|
| 52 |
+
return random.choice(self.available_models[3:])
|
| 53 |
+
|
| 54 |
+
return random.choice(self.available_models)
|
| 55 |
+
|
| 56 |
+
@task(1)
|
| 57 |
+
def translate_single(self):
|
| 58 |
+
model = self.get_random_model()
|
| 59 |
+
if not model:
|
| 60 |
+
return
|
| 61 |
+
|
| 62 |
+
self.client.post(
|
| 63 |
+
"/translate",
|
| 64 |
+
json={
|
| 65 |
+
"src": random.choice(self.sample_texts) + str(random.random()),
|
| 66 |
+
"src_lang": model["src_lang"],
|
| 67 |
+
"tgt_lang": model["tgt_lang"],
|
| 68 |
+
"beam_size": 2,
|
| 69 |
+
},
|
| 70 |
+
name="/translate [single, manual]",
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
@task(1)
|
| 74 |
+
def translate_auto_detect(self):
|
| 75 |
+
"""Translate without specifying src_lang to trigger LangID."""
|
| 76 |
+
ret = self.client.post(
|
| 77 |
+
"/translate",
|
| 78 |
+
json={
|
| 79 |
+
"src": random.choice(self.sample_texts) + str(random.random()),
|
| 80 |
+
"tgt_lang": "en",
|
| 81 |
+
"beam_size": 2,
|
| 82 |
+
},
|
| 83 |
+
name="/translate [single, auto-detect]",
|
| 84 |
+
)
|
| 85 |
+
ret_json = ret.json()
|
| 86 |
+
assert "src_lang" in ret_json
|
| 87 |
+
assert "tgt_lang" in ret_json
|
| 88 |
+
assert "translation" in ret_json
|
| 89 |
+
assert "src_lang_score" in ret_json
|
| 90 |
+
assert "model_used" in ret_json
|
| 91 |
+
assert ret_json["tgt_lang"] == "en"
|
| 92 |
+
|
| 93 |
+
@task(1)
|
| 94 |
+
def translate_list(self):
|
| 95 |
+
model = self.get_random_model()
|
| 96 |
+
if not model:
|
| 97 |
+
return
|
| 98 |
+
|
| 99 |
+
num_sentences = random.randint(2, 5)
|
| 100 |
+
texts = random.sample(self.sample_texts, num_sentences)
|
| 101 |
+
texts = [i + str(random.random()) for i in texts]
|
| 102 |
+
ret = self.client.post(
|
| 103 |
+
"/translate",
|
| 104 |
+
json={
|
| 105 |
+
"src": texts,
|
| 106 |
+
"src_lang": model["src_lang"],
|
| 107 |
+
"tgt_lang": model["tgt_lang"],
|
| 108 |
+
"beam_size": 2,
|
| 109 |
+
},
|
| 110 |
+
name="/translate [list, manual]",
|
| 111 |
+
)
|
| 112 |
+
ret_json = ret.json()
|
| 113 |
+
for i in ret_json["src_lang"]:
|
| 114 |
+
assert i == model["src_lang"]
|
| 115 |
+
assert ret_json["tgt_lang"] == model["tgt_lang"]
|
| 116 |
+
assert len(ret_json["translation"]) == num_sentences
|
| 117 |
+
|
| 118 |
+
@task(1)
|
| 119 |
+
def identify_language(self):
|
| 120 |
+
"""Directly benchmark the identification endpoint."""
|
| 121 |
+
num_sentences = random.randint(1, 4)
|
| 122 |
+
texts = random.sample(self.sample_texts, num_sentences)
|
| 123 |
+
src = texts[0] if num_sentences == 1 else texts
|
| 124 |
+
|
| 125 |
+
self.client.post(
|
| 126 |
+
"/identify-language", json={"src": src}, name="/identify-language"
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
@task(1)
|
| 130 |
+
def health_check(self):
|
| 131 |
+
self.client.get("/health", name="/health")
|
pyproject.toml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["hatchling"]
|
| 3 |
+
build-backend = "hatchling.build"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "quickmt"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "A fast, multi-model translation API based on CTranslate2 and FastAPI"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.9"
|
| 11 |
+
license = {text = "MIT"}
|
| 12 |
+
authors = [
|
| 13 |
+
{name = "QuickMT Team", email = "hello@quickmt.ai"},
|
| 14 |
+
]
|
| 15 |
+
dependencies = [
|
| 16 |
+
"blingfire",
|
| 17 |
+
"cachetools",
|
| 18 |
+
"fastapi",
|
| 19 |
+
"uvicorn[standard]",
|
| 20 |
+
"ctranslate2>=3.20.0",
|
| 21 |
+
"sentencepiece",
|
| 22 |
+
"huggingface-hub",
|
| 23 |
+
"fasttext-wheel",
|
| 24 |
+
"orjson",
|
| 25 |
+
"uvloop",
|
| 26 |
+
"httptools",
|
| 27 |
+
"pydantic",
|
| 28 |
+
"pydantic-settings"
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
[project.optional-dependencies]
|
| 32 |
+
dev = [
|
| 33 |
+
"pytest",
|
| 34 |
+
"pytest-asyncio",
|
| 35 |
+
"httpx",
|
| 36 |
+
"sacrebleu",
|
| 37 |
+
"locust"
|
| 38 |
+
]
|
| 39 |
+
|
| 40 |
+
[project.scripts]
|
| 41 |
+
quickmt-serve = "quickmt.rest_server:start"
|
| 42 |
+
quickmt-gui = "quickmt.rest_server:start_gui"
|
| 43 |
+
|
| 44 |
+
[tool.hatch.build.targets.wheel]
|
| 45 |
+
packages = ["quickmt"]
|
pytest.ini
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[pytest]
|
| 2 |
+
asyncio_mode = auto
|
| 3 |
+
asyncio_default_fixture_loop_scope = function
|
quickmt/__init__.py
ADDED
|
File without changes
|
quickmt/gui/static/app.js
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
document.addEventListener('DOMContentLoaded', () => {
|
| 2 |
+
// Elements
|
| 3 |
+
const srcText = document.getElementById('src-text');
|
| 4 |
+
const tgtText = document.getElementById('tgt-text');
|
| 5 |
+
const srcLangSelect = document.getElementById('src-lang-select');
|
| 6 |
+
const tgtLangSelect = document.getElementById('tgt-lang-select');
|
| 7 |
+
const charCount = document.getElementById('char-count');
|
| 8 |
+
const timingInfo = document.getElementById('timing-info');
|
| 9 |
+
const loader = document.getElementById('translation-loader');
|
| 10 |
+
const detectedBadge = document.getElementById('detected-badge');
|
| 11 |
+
const navLinks = document.querySelectorAll('.nav-links a');
|
| 12 |
+
const views = document.querySelectorAll('.view');
|
| 13 |
+
const healthIndicator = document.getElementById('health-indicator');
|
| 14 |
+
const modelsList = document.getElementById('models-list');
|
| 15 |
+
const copyBtn = document.getElementById('copy-btn');
|
| 16 |
+
const themeToggle = document.getElementById('theme-toggle');
|
| 17 |
+
const sidebarToggle = document.getElementById('sidebar-toggle');
|
| 18 |
+
const sidebar = document.querySelector('.sidebar');
|
| 19 |
+
|
| 20 |
+
let debounceTimer;
|
| 21 |
+
let languages = {};
|
| 22 |
+
let activeController = null;
|
| 23 |
+
|
| 24 |
+
let settings = {
|
| 25 |
+
beam_size: 2,
|
| 26 |
+
patience: 1,
|
| 27 |
+
length_penalty: 1.0,
|
| 28 |
+
coverage_penalty: 0.0,
|
| 29 |
+
repetition_penalty: 1.0
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
// 0. Theme Logic
|
| 33 |
+
function initTheme() {
|
| 34 |
+
const savedTheme = localStorage.getItem('theme') || 'dark';
|
| 35 |
+
if (savedTheme === 'light') {
|
| 36 |
+
document.body.classList.add('light-mode');
|
| 37 |
+
updateThemeUI(true);
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
function updateThemeUI(isLight) {
|
| 42 |
+
const text = themeToggle.querySelector('.mode-text');
|
| 43 |
+
text.textContent = isLight ? 'Light Mode' : 'Dark Mode';
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
themeToggle.addEventListener('click', () => {
|
| 47 |
+
const isLight = document.body.classList.toggle('light-mode');
|
| 48 |
+
localStorage.setItem('theme', isLight ? 'light' : 'dark');
|
| 49 |
+
updateThemeUI(isLight);
|
| 50 |
+
});
|
| 51 |
+
|
| 52 |
+
// 0.1 Sidebar Logic
|
| 53 |
+
function initSidebar() {
|
| 54 |
+
const isCollapsed = localStorage.getItem('sidebar-collapsed') === 'true';
|
| 55 |
+
if (isCollapsed) sidebar.classList.add('collapsed');
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
sidebarToggle.addEventListener('click', () => {
|
| 59 |
+
const isCollapsed = sidebar.classList.toggle('collapsed');
|
| 60 |
+
localStorage.setItem('sidebar-collapsed', isCollapsed);
|
| 61 |
+
});
|
| 62 |
+
|
| 63 |
+
// 0.2 Inference Settings Logic
|
| 64 |
+
function initSettings() {
|
| 65 |
+
const saved = localStorage.getItem('inference-settings');
|
| 66 |
+
if (saved) {
|
| 67 |
+
try {
|
| 68 |
+
const parsed = JSON.parse(saved);
|
| 69 |
+
settings = { ...settings, ...parsed };
|
| 70 |
+
} catch (e) { console.error("Failed to parse settings", e); }
|
| 71 |
+
}
|
| 72 |
+
updateSettingsUI();
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
function updateSettingsUI() {
|
| 76 |
+
// Sync values to inputs
|
| 77 |
+
Object.keys(settings).forEach(key => {
|
| 78 |
+
const input = document.getElementById(`setting-${key.replace('_', '-')}`);
|
| 79 |
+
if (input) {
|
| 80 |
+
input.value = settings[key];
|
| 81 |
+
const valDisplay = input.nextElementSibling;
|
| 82 |
+
if (valDisplay && valDisplay.classList.contains('setting-val')) {
|
| 83 |
+
valDisplay.textContent = settings[key];
|
| 84 |
+
}
|
| 85 |
+
}
|
| 86 |
+
});
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
function saveSettings() {
|
| 90 |
+
localStorage.setItem('inference-settings', JSON.stringify(settings));
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
// Add listeners to all settings inputs
|
| 94 |
+
const settingsInputs = [
|
| 95 |
+
'setting-beam-size', 'setting-patience', 'setting-length-penalty',
|
| 96 |
+
'setting-coverage-penalty', 'setting-repetition-penalty'
|
| 97 |
+
];
|
| 98 |
+
|
| 99 |
+
settingsInputs.forEach(id => {
|
| 100 |
+
const input = document.getElementById(id);
|
| 101 |
+
const key = id.replace('setting-', '').replace(/-/g, '_');
|
| 102 |
+
|
| 103 |
+
input.addEventListener('input', () => {
|
| 104 |
+
let val = parseFloat(input.value);
|
| 105 |
+
if (id === 'setting-beam-size' || id === 'setting-patience') val = parseInt(input.value);
|
| 106 |
+
|
| 107 |
+
settings[key] = val;
|
| 108 |
+
|
| 109 |
+
// Enforcement: patience <= beam_size
|
| 110 |
+
if (id === 'setting-beam-size') {
|
| 111 |
+
if (settings.patience > settings.beam_size) {
|
| 112 |
+
settings.patience = settings.beam_size;
|
| 113 |
+
const patienceInput = document.getElementById('setting-patience');
|
| 114 |
+
patienceInput.value = settings.patience;
|
| 115 |
+
patienceInput.nextElementSibling.textContent = settings.patience;
|
| 116 |
+
}
|
| 117 |
+
// Update patience max slider to match beam_size for better UX?
|
| 118 |
+
// User said "maximum 10", so let's stick to that but cap the value.
|
| 119 |
+
} else if (id === 'setting-patience') {
|
| 120 |
+
if (val > settings.beam_size) {
|
| 121 |
+
val = settings.beam_size;
|
| 122 |
+
input.value = val;
|
| 123 |
+
settings.patience = val;
|
| 124 |
+
}
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
const valDisplay = input.nextElementSibling;
|
| 128 |
+
if (valDisplay && valDisplay.classList.contains('setting-val')) {
|
| 129 |
+
valDisplay.textContent = val;
|
| 130 |
+
}
|
| 131 |
+
saveSettings();
|
| 132 |
+
});
|
| 133 |
+
});
|
| 134 |
+
|
| 135 |
+
document.getElementById('reset-settings').addEventListener('click', () => {
|
| 136 |
+
settings = {
|
| 137 |
+
beam_size: 2,
|
| 138 |
+
patience: 1,
|
| 139 |
+
length_penalty: 1.0,
|
| 140 |
+
coverage_penalty: 0.0,
|
| 141 |
+
repetition_penalty: 1.0
|
| 142 |
+
};
|
| 143 |
+
updateSettingsUI();
|
| 144 |
+
saveSettings();
|
| 145 |
+
});
|
| 146 |
+
|
| 147 |
+
// 1. Fetch available languages and populate selects
|
| 148 |
+
async function initLanguages() {
|
| 149 |
+
try {
|
| 150 |
+
const res = await fetch('/api/languages');
|
| 151 |
+
if (res.ok) {
|
| 152 |
+
languages = await res.json();
|
| 153 |
+
populateSelects();
|
| 154 |
+
updateHealth(true);
|
| 155 |
+
}
|
| 156 |
+
} catch (e) {
|
| 157 |
+
console.error("Failed to load languages", e);
|
| 158 |
+
updateHealth(false);
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
function populateSelects() {
|
| 163 |
+
const currentSrc = srcLangSelect.value;
|
| 164 |
+
// Keep only the first "Auto-detect" option
|
| 165 |
+
srcLangSelect.innerHTML = '<option value="">Auto-detect</option>';
|
| 166 |
+
|
| 167 |
+
const sources = Object.keys(languages);
|
| 168 |
+
|
| 169 |
+
// Populate Source Languages
|
| 170 |
+
sources.forEach(lang => {
|
| 171 |
+
const opt = document.createElement('option');
|
| 172 |
+
opt.value = lang;
|
| 173 |
+
opt.textContent = lang.toUpperCase();
|
| 174 |
+
srcLangSelect.appendChild(opt);
|
| 175 |
+
});
|
| 176 |
+
|
| 177 |
+
// Restore selection if it still exists
|
| 178 |
+
if (currentSrc && languages[currentSrc]) {
|
| 179 |
+
srcLangSelect.value = currentSrc;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
// Trigger target population for default selection
|
| 183 |
+
updateTargetOptions();
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
function updateTargetOptions() {
|
| 187 |
+
const src = srcLangSelect.value;
|
| 188 |
+
const currentTgt = tgtLangSelect.value;
|
| 189 |
+
|
| 190 |
+
// Clear targets
|
| 191 |
+
tgtLangSelect.innerHTML = '';
|
| 192 |
+
|
| 193 |
+
let availableTgts = [];
|
| 194 |
+
if (src) {
|
| 195 |
+
availableTgts = languages[src] || [];
|
| 196 |
+
} else {
|
| 197 |
+
// If auto-detect, union of all targets
|
| 198 |
+
const allTgts = new Set();
|
| 199 |
+
Object.values(languages).forEach(list => list.forEach(l => allTgts.add(l)));
|
| 200 |
+
availableTgts = Array.from(allTgts).sort();
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
availableTgts.forEach(lang => {
|
| 204 |
+
const opt = document.createElement('option');
|
| 205 |
+
opt.value = lang;
|
| 206 |
+
opt.textContent = lang.toUpperCase();
|
| 207 |
+
if (lang === currentTgt || (availableTgts.length === 1)) opt.selected = true;
|
| 208 |
+
tgtLangSelect.appendChild(opt);
|
| 209 |
+
});
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
// 2. Translation Logic
|
| 213 |
+
async function performTranslation() {
|
| 214 |
+
const fullText = srcText.value;
|
| 215 |
+
if (!fullText.trim()) {
|
| 216 |
+
tgtText.value = '';
|
| 217 |
+
timingInfo.textContent = 'Ready';
|
| 218 |
+
detectedBadge.classList.remove('visible');
|
| 219 |
+
return;
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
// Abort previous requests
|
| 223 |
+
if (activeController) activeController.abort();
|
| 224 |
+
activeController = new AbortController();
|
| 225 |
+
const { signal } = activeController;
|
| 226 |
+
|
| 227 |
+
const lines = fullText.split('\n');
|
| 228 |
+
const translatedLines = new Array(lines.length).fill('');
|
| 229 |
+
let srcLang = srcLangSelect.value || null;
|
| 230 |
+
const tgtLang = tgtLangSelect.value;
|
| 231 |
+
|
| 232 |
+
loader.classList.remove('hidden');
|
| 233 |
+
let completedLines = 0;
|
| 234 |
+
let totalToTranslate = lines.filter(l => l.trim()).length;
|
| 235 |
+
|
| 236 |
+
try {
|
| 237 |
+
// Step 1: If auto-detect mode, detect language for entire input first
|
| 238 |
+
if (!srcLang && fullText.trim()) {
|
| 239 |
+
const detectResponse = await fetch('/api/identify-language', {
|
| 240 |
+
method: 'POST',
|
| 241 |
+
headers: { 'Content-Type': 'application/json' },
|
| 242 |
+
body: JSON.stringify({
|
| 243 |
+
src: fullText,
|
| 244 |
+
k: 1,
|
| 245 |
+
threshold: 0.0
|
| 246 |
+
}),
|
| 247 |
+
signal
|
| 248 |
+
});
|
| 249 |
+
|
| 250 |
+
if (detectResponse.ok) {
|
| 251 |
+
const detectData = await detectResponse.json();
|
| 252 |
+
// Get the detected language from the response
|
| 253 |
+
if (detectData.results && detectData.results.length > 0) {
|
| 254 |
+
srcLang = detectData.results[0].lang;
|
| 255 |
+
detectedBadge.textContent = `Detected: ${srcLang.toUpperCase()}`;
|
| 256 |
+
detectedBadge.classList.add('visible');
|
| 257 |
+
}
|
| 258 |
+
}
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
// Step 2: Translate all lines with known source language
|
| 262 |
+
const updateTgtUI = () => {
|
| 263 |
+
tgtText.value = translatedLines.join('\n');
|
| 264 |
+
};
|
| 265 |
+
|
| 266 |
+
const translateParagraph = async (line, index) => {
|
| 267 |
+
if (!line.trim()) {
|
| 268 |
+
translatedLines[index] = line;
|
| 269 |
+
updateTgtUI();
|
| 270 |
+
return;
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
try {
|
| 274 |
+
const response = await fetch('/api/translate', {
|
| 275 |
+
method: 'POST',
|
| 276 |
+
headers: { 'Content-Type': 'application/json' },
|
| 277 |
+
body: JSON.stringify({
|
| 278 |
+
src: line,
|
| 279 |
+
src_lang: srcLang, // Now we always have a source language
|
| 280 |
+
tgt_lang: tgtLang,
|
| 281 |
+
...settings
|
| 282 |
+
}),
|
| 283 |
+
signal
|
| 284 |
+
});
|
| 285 |
+
|
| 286 |
+
if (response.ok) {
|
| 287 |
+
const data = await response.json();
|
| 288 |
+
translatedLines[index] = data.translation;
|
| 289 |
+
|
| 290 |
+
completedLines++;
|
| 291 |
+
updateTgtUI();
|
| 292 |
+
timingInfo.textContent = `Translating: ${Math.round((completedLines / totalToTranslate) * 100)}%`;
|
| 293 |
+
}
|
| 294 |
+
} catch (e) {
|
| 295 |
+
if (e.name !== 'AbortError') {
|
| 296 |
+
console.error("Line translation error", e);
|
| 297 |
+
translatedLines[index] = `[[Error: ${line}]]`;
|
| 298 |
+
}
|
| 299 |
+
} finally {
|
| 300 |
+
if (completedLines === totalToTranslate) {
|
| 301 |
+
loader.classList.add('hidden');
|
| 302 |
+
timingInfo.textContent = 'Done';
|
| 303 |
+
}
|
| 304 |
+
}
|
| 305 |
+
};
|
| 306 |
+
|
| 307 |
+
// Fire all translation requests in parallel
|
| 308 |
+
lines.forEach((line, i) => translateParagraph(line, i));
|
| 309 |
+
|
| 310 |
+
} catch (e) {
|
| 311 |
+
if (e.name !== 'AbortError') {
|
| 312 |
+
console.error("Translation error", e);
|
| 313 |
+
loader.classList.add('hidden');
|
| 314 |
+
timingInfo.textContent = 'Error';
|
| 315 |
+
}
|
| 316 |
+
}
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
// 3. Models View
|
| 320 |
+
async function fetchModels() {
|
| 321 |
+
try {
|
| 322 |
+
const res = await fetch('/api/models');
|
| 323 |
+
const data = await res.json();
|
| 324 |
+
|
| 325 |
+
modelsList.innerHTML = '';
|
| 326 |
+
|
| 327 |
+
// Use DocumentFragment for better performance
|
| 328 |
+
const fragment = document.createDocumentFragment();
|
| 329 |
+
|
| 330 |
+
data.models.forEach(m => {
|
| 331 |
+
const card = document.createElement('div');
|
| 332 |
+
card.className = 'model-card';
|
| 333 |
+
card.innerHTML = `
|
| 334 |
+
<div class="model-lang-pair">
|
| 335 |
+
<span>${m.src_lang.toUpperCase()}</span>
|
| 336 |
+
<span>→</span>
|
| 337 |
+
<span>${m.tgt_lang.toUpperCase()}</span>
|
| 338 |
+
</div>
|
| 339 |
+
<div class="model-id">${m.model_id}</div>
|
| 340 |
+
${m.loaded ? '<span class="loaded-badge">Currently Loaded</span>' : ''}
|
| 341 |
+
`;
|
| 342 |
+
fragment.appendChild(card);
|
| 343 |
+
});
|
| 344 |
+
|
| 345 |
+
// Single DOM update instead of multiple
|
| 346 |
+
modelsList.appendChild(fragment);
|
| 347 |
+
} catch (e) {
|
| 348 |
+
modelsList.innerHTML = '<p>Error loading models</p>';
|
| 349 |
+
}
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
// 4. UI Helpers
|
| 353 |
+
function updateHealth(isOnline) {
|
| 354 |
+
if (isOnline) {
|
| 355 |
+
healthIndicator.className = 'status-pill status-online';
|
| 356 |
+
healthIndicator.querySelector('.status-text').textContent = 'Online';
|
| 357 |
+
} else {
|
| 358 |
+
healthIndicator.className = 'status-pill status-loading';
|
| 359 |
+
healthIndicator.querySelector('.status-text').textContent = 'Offline';
|
| 360 |
+
}
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
// Event Listeners
|
| 364 |
+
srcText.addEventListener('input', () => {
|
| 365 |
+
charCount.textContent = `${srcText.value.length} characters`;
|
| 366 |
+
clearTimeout(debounceTimer);
|
| 367 |
+
debounceTimer = setTimeout(performTranslation, 250);
|
| 368 |
+
});
|
| 369 |
+
|
| 370 |
+
srcLangSelect.addEventListener('change', () => {
|
| 371 |
+
updateTargetOptions();
|
| 372 |
+
performTranslation();
|
| 373 |
+
});
|
| 374 |
+
|
| 375 |
+
tgtLangSelect.addEventListener('change', performTranslation);
|
| 376 |
+
|
| 377 |
+
copyBtn.addEventListener('click', () => {
|
| 378 |
+
navigator.clipboard.writeText(tgtText.value);
|
| 379 |
+
const originalText = copyBtn.textContent;
|
| 380 |
+
copyBtn.textContent = 'Copied!';
|
| 381 |
+
setTimeout(() => copyBtn.textContent = originalText, 2000);
|
| 382 |
+
});
|
| 383 |
+
|
| 384 |
+
// Navigation
|
| 385 |
+
navLinks.forEach(link => {
|
| 386 |
+
link.addEventListener('click', (e) => {
|
| 387 |
+
e.preventDefault();
|
| 388 |
+
const targetId = link.getAttribute('href').substring(1);
|
| 389 |
+
|
| 390 |
+
navLinks.forEach(l => l.parentElement.classList.remove('active'));
|
| 391 |
+
link.parentElement.classList.add('active');
|
| 392 |
+
|
| 393 |
+
views.forEach(v => {
|
| 394 |
+
v.classList.remove('active');
|
| 395 |
+
if (v.id === `${targetId}-view`) v.classList.add('active');
|
| 396 |
+
});
|
| 397 |
+
|
| 398 |
+
if (targetId === 'models') fetchModels();
|
| 399 |
+
});
|
| 400 |
+
});
|
| 401 |
+
|
| 402 |
+
// Start
|
| 403 |
+
initTheme();
|
| 404 |
+
initSidebar();
|
| 405 |
+
initSettings();
|
| 406 |
+
initLanguages();
|
| 407 |
+
setInterval(initLanguages, 10000); // Pulse health check
|
| 408 |
+
});
|
quickmt/gui/static/index.html
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
|
| 4 |
+
<head>
|
| 5 |
+
<meta charset="UTF-8">
|
| 6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 7 |
+
<title>QuickMT Machine Translation</title>
|
| 8 |
+
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 9 |
+
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
| 10 |
+
<link
|
| 11 |
+
href="https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600;700&family=Inter:wght@300;400;500;600&display=swap"
|
| 12 |
+
rel="stylesheet">
|
| 13 |
+
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.1/css/all.min.css">
|
| 14 |
+
<link rel="stylesheet" href="style.css">
|
| 15 |
+
</head>
|
| 16 |
+
|
| 17 |
+
<body>
|
| 18 |
+
<div class="bg-blur"></div>
|
| 19 |
+
|
| 20 |
+
<div class="top-nav-links">
|
| 21 |
+
<a href="https://huggingface.co/quickmt" target="_blank" title="Hugging Face Models" class="glass-btn">
|
| 22 |
+
<span class="btn-icon">🤗</span>
|
| 23 |
+
</a>
|
| 24 |
+
<a href="https://github.com/quickmt/quickmt" target="_blank" title="GitHub Repository"
|
| 25 |
+
class="glass-btn icon-only">
|
| 26 |
+
<span class="btn-icon"><i class="fa-brands fa-github"></i></span>
|
| 27 |
+
</a>
|
| 28 |
+
</div>
|
| 29 |
+
|
| 30 |
+
<main class="app-container">
|
| 31 |
+
<!-- Sidebar Navigation -->
|
| 32 |
+
<nav class="sidebar glass">
|
| 33 |
+
<div class="logo">
|
| 34 |
+
<div class="logo-icon">Q</div>
|
| 35 |
+
<span>QuickMT</span>
|
| 36 |
+
</div>
|
| 37 |
+
<ul class="nav-links">
|
| 38 |
+
<li class="active">
|
| 39 |
+
<a href="#translate">
|
| 40 |
+
<span class="nav-icon">🔁</span>
|
| 41 |
+
<span class="nav-text">Translate</span>
|
| 42 |
+
</a>
|
| 43 |
+
</li>
|
| 44 |
+
<li>
|
| 45 |
+
<a href="#models">
|
| 46 |
+
<span class="nav-icon">🧩</span>
|
| 47 |
+
<span class="nav-text">Models</span>
|
| 48 |
+
</a>
|
| 49 |
+
</li>
|
| 50 |
+
<li>
|
| 51 |
+
<a href="#settings">
|
| 52 |
+
<span class="nav-icon">⚙️</span>
|
| 53 |
+
<span class="nav-text">Settings</span>
|
| 54 |
+
</a>
|
| 55 |
+
</li>
|
| 56 |
+
</ul>
|
| 57 |
+
<div class="sidebar-footer">
|
| 58 |
+
<div id="health-indicator" class="status-pill status-loading">
|
| 59 |
+
<span class="dot"></span>
|
| 60 |
+
<span class="status-text">Connecting...</span>
|
| 61 |
+
</div>
|
| 62 |
+
<div class="theme-toggle-container">
|
| 63 |
+
<button id="theme-toggle" title="Toggle Light/Dark Mode">
|
| 64 |
+
<span class="mode-icon">◑</span>
|
| 65 |
+
<span class="mode-text">Dark Mode</span>
|
| 66 |
+
</button>
|
| 67 |
+
</div>
|
| 68 |
+
<button id="sidebar-toggle" class="icon-btn collapse-btn" title="Toggle Sidebar">«</button>
|
| 69 |
+
</div>
|
| 70 |
+
</nav>
|
| 71 |
+
|
| 72 |
+
<!-- Main Content -->
|
| 73 |
+
<section class="content">
|
| 74 |
+
<!-- Translate View -->
|
| 75 |
+
<div id="translate-view" class="view active">
|
| 76 |
+
<header class="view-header">
|
| 77 |
+
            <h1>QuickMT Neural Machine Translation</h1>
|
| 78 |
+
</header>
|
| 79 |
+
|
| 80 |
+
<div class="translation-grid">
|
| 81 |
+
<!-- Source Column -->
|
| 82 |
+
<div class="card glass translation-card">
|
| 83 |
+
<div class="card-header">
|
| 84 |
+
<div class="lang-group">
|
| 85 |
+
<span class="lang-label">From</span>
|
| 86 |
+
<select id="src-lang-select" class="lang-select">
|
| 87 |
+
<option value="">Auto-detect</option>
|
| 88 |
+
</select>
|
| 89 |
+
</div>
|
| 90 |
+
<div id="detected-badge" class="detected-badge"></div>
|
| 91 |
+
</div>
|
| 92 |
+
<div class="card-body">
|
| 93 |
+
<textarea id="src-text" placeholder="Enter text to translate..." autofocus></textarea>
|
| 94 |
+
</div>
|
| 95 |
+
<div class="card-footer">
|
| 96 |
+
<span id="char-count">0 characters</span>
|
| 97 |
+
</div>
|
| 98 |
+
</div>
|
| 99 |
+
|
| 100 |
+
<!-- Target Column -->
|
| 101 |
+
<div class="card glass translation-card target-card">
|
| 102 |
+
<div class="card-header">
|
| 103 |
+
<div class="lang-group">
|
| 104 |
+
<span class="lang-label">To</span>
|
| 105 |
+
<select id="tgt-lang-select" class="lang-select">
|
| 106 |
+
<option value="en">English</option>
|
| 107 |
+
</select>
|
| 108 |
+
</div>
|
| 109 |
+
</div>
|
| 110 |
+
<div class="card-body">
|
| 111 |
+
<textarea id="tgt-text" readonly placeholder="Translation will appear here..."></textarea>
|
| 112 |
+
<div id="translation-loader" class="loader-overlay hidden">
|
| 113 |
+
<div class="spinner"></div>
|
| 114 |
+
</div>
|
| 115 |
+
</div>
|
| 116 |
+
<div class="card-footer">
|
| 117 |
+
<span id="timing-info">Ready</span>
|
| 118 |
+
<button id="copy-btn" class="action-btn">Copy</button>
|
| 119 |
+
</div>
|
| 120 |
+
</div>
|
| 121 |
+
</div>
|
| 122 |
+
</div>
|
| 123 |
+
|
| 124 |
+
<!-- Models View -->
|
| 125 |
+
<div id="models-view" class="view">
|
| 126 |
+
<header class="view-header">
|
| 127 |
+
<h1>Available Models</h1>
|
| 128 |
+
<p>Browse models from the quickmt Hugging Face collection</p>
|
| 129 |
+
</header>
|
| 130 |
+
<div id="models-list" class="models-grid">
|
| 131 |
+
<!-- Model cards will be injected here -->
|
| 132 |
+
</div>
|
| 133 |
+
</div>
|
| 134 |
+
<!-- Settings View -->
|
| 135 |
+
<div id="settings-view" class="view">
|
| 136 |
+
<header class="view-header">
|
| 137 |
+
<h1>Inference Settings</h1>
|
| 138 |
+
<p>Fine-tune the translation engine for your needs</p>
|
| 139 |
+
</header>
|
| 140 |
+
|
| 141 |
+
<div class="settings-container glass">
|
| 142 |
+
<div class="settings-grid">
|
| 143 |
+
<!-- Beam Size -->
|
| 144 |
+
<div class="setting-item">
|
| 145 |
+
<div class="setting-info">
|
| 146 |
+
<label>Beam Size</label>
|
| 147 |
+
<span class="setting-desc">Number of hypotheses to explore (1-10)</span>
|
| 148 |
+
</div>
|
| 149 |
+
<div class="setting-control">
|
| 150 |
+
<input type="range" id="setting-beam-size" min="1" max="10" step="1" value="2">
|
| 151 |
+
<span class="setting-val">2</span>
|
| 152 |
+
</div>
|
| 153 |
+
</div>
|
| 154 |
+
|
| 155 |
+
<!-- Patience -->
|
| 156 |
+
<div class="setting-item">
|
| 157 |
+
<div class="setting-info">
|
| 158 |
+
<label>Patience</label>
|
| 159 |
+
<span class="setting-desc">Stopping criterion factor (1-10)</span>
|
| 160 |
+
</div>
|
| 161 |
+
<div class="setting-control">
|
| 162 |
+
<input type="range" id="setting-patience" min="1" max="10" step="1" value="1">
|
| 163 |
+
<span class="setting-val">1</span>
|
| 164 |
+
</div>
|
| 165 |
+
</div>
|
| 166 |
+
|
| 167 |
+
<!-- Length Penalty -->
|
| 168 |
+
<div class="setting-item">
|
| 169 |
+
<div class="setting-info">
|
| 170 |
+
<label>Length Penalty</label>
|
| 171 |
+
<span class="setting-desc">Favour longer or shorter sentences (default 1.0)</span>
|
| 172 |
+
</div>
|
| 173 |
+
<div class="setting-control">
|
| 174 |
+
<input type="number" id="setting-length-penalty" step="0.1" value="1.0">
|
| 175 |
+
</div>
|
| 176 |
+
</div>
|
| 177 |
+
|
| 178 |
+
<!-- Coverage Penalty -->
|
| 179 |
+
<div class="setting-item">
|
| 180 |
+
<div class="setting-info">
|
| 181 |
+
<label>Coverage Penalty</label>
|
| 182 |
+
<span class="setting-desc">Ensure all source words are translated (default 0.0)</span>
|
| 183 |
+
</div>
|
| 184 |
+
<div class="setting-control">
|
| 185 |
+
<input type="number" id="setting-coverage-penalty" step="0.1" value="0.0">
|
| 186 |
+
</div>
|
| 187 |
+
</div>
|
| 188 |
+
|
| 189 |
+
<!-- Repetition Penalty -->
|
| 190 |
+
<div class="setting-item">
|
| 191 |
+
<div class="setting-info">
|
| 192 |
+
<label>Repetition Penalty</label>
|
| 193 |
+
<span class="setting-desc">Prevent repeating words (default 1.0)</span>
|
| 194 |
+
</div>
|
| 195 |
+
<div class="setting-control">
|
| 196 |
+
<input type="number" id="setting-repetition-penalty" step="0.1" value="1.0">
|
| 197 |
+
</div>
|
| 198 |
+
</div>
|
| 199 |
+
</div>
|
| 200 |
+
|
| 201 |
+
<div class="settings-actions">
|
| 202 |
+
<button id="reset-settings" class="action-btn secondary">Reset to Defaults</button>
|
| 203 |
+
</div>
|
| 204 |
+
</div>
|
| 205 |
+
</div>
|
| 206 |
+
</section>
|
| 207 |
+
</main>
|
| 208 |
+
|
| 209 |
+
<script src="app.js"></script>
|
| 210 |
+
</body>
|
| 211 |
+
|
| 212 |
+
</html>
|
quickmt/gui/static/style.css
ADDED
|
@@ -0,0 +1,668 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
:root {
|
| 2 |
+
--primary: #6366f1;
|
| 3 |
+
--primary-glow: rgba(99, 102, 241, 0.5);
|
| 4 |
+
--bg-gradient: linear-gradient(135deg, #0f172a 0%, #1e1b4b 100%);
|
| 5 |
+
--glass-bg: rgba(255, 255, 255, 0.03);
|
| 6 |
+
--glass-border: rgba(255, 255, 255, 0.1);
|
| 7 |
+
--text-main: #f8fafc;
|
| 8 |
+
--text-muted: #94a3b8;
|
| 9 |
+
--card-shadow: 0 8px 32px 0 rgba(0, 0, 0, 0.37);
|
| 10 |
+
--transition: none;
|
| 11 |
+
--sidebar-active-bg: rgba(255, 255, 255, 0.05);
|
| 12 |
+
--input-bg: rgba(255, 255, 255, 0.05);
|
| 13 |
+
--btn-hover-bg: rgba(255, 255, 255, 0.1);
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
body.light-mode {
|
| 17 |
+
--bg-gradient: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
|
| 18 |
+
--glass-bg: rgba(255, 255, 255, 0.7);
|
| 19 |
+
--glass-border: rgba(99, 102, 241, 0.1);
|
| 20 |
+
--text-main: #1e293b;
|
| 21 |
+
--text-muted: #64748b;
|
| 22 |
+
--card-shadow: 0 8px 32px 0 rgba(31, 38, 135, 0.07);
|
| 23 |
+
--sidebar-active-bg: rgba(99, 102, 241, 0.05);
|
| 24 |
+
--input-bg: rgba(0, 0, 0, 0.02);
|
| 25 |
+
--btn-hover-bg: rgba(99, 102, 241, 0.1);
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
* {
|
| 29 |
+
margin: 0;
|
| 30 |
+
padding: 0;
|
| 31 |
+
box-sizing: border-box;
|
| 32 |
+
font-family: 'Inter', sans-serif;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
h1,
|
| 36 |
+
h2,
|
| 37 |
+
h3,
|
| 38 |
+
.logo {
|
| 39 |
+
font-family: 'Outfit', sans-serif;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
body {
|
| 43 |
+
background: var(--bg-gradient);
|
| 44 |
+
color: var(--text-main);
|
| 45 |
+
min-height: 100vh;
|
| 46 |
+
overflow: hidden;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
.bg-blur {
|
| 50 |
+
position: fixed;
|
| 51 |
+
top: 0;
|
| 52 |
+
left: 0;
|
| 53 |
+
width: 100%;
|
| 54 |
+
height: 100%;
|
| 55 |
+
z-index: -1;
|
| 56 |
+
background: radial-gradient(circle at 20% 30%, rgba(99, 102, 241, 0.15) 0%, transparent 40%),
|
| 57 |
+
radial-gradient(circle at 80% 70%, rgba(168, 85, 247, 0.15) 0%, transparent 40%);
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
.top-nav-links {
|
| 61 |
+
position: fixed;
|
| 62 |
+
top: 1.5rem;
|
| 63 |
+
right: 2.5rem;
|
| 64 |
+
display: flex;
|
| 65 |
+
gap: 0.75rem;
|
| 66 |
+
z-index: 100;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
.glass-btn {
|
| 70 |
+
display: flex;
|
| 71 |
+
align-items: center;
|
| 72 |
+
justify-content: center;
|
| 73 |
+
padding: 0.6rem 1rem;
|
| 74 |
+
min-width: 44px;
|
| 75 |
+
height: 44px;
|
| 76 |
+
background: var(--glass-bg);
|
| 77 |
+
backdrop-filter: blur(12px);
|
| 78 |
+
-webkit-backdrop-filter: blur(12px);
|
| 79 |
+
border: 1px solid var(--glass-border);
|
| 80 |
+
border-radius: 0.75rem;
|
| 81 |
+
color: var(--text-main);
|
| 82 |
+
text-decoration: none;
|
| 83 |
+
font-size: 0.85rem;
|
| 84 |
+
font-weight: 600;
|
| 85 |
+
transition: none;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
.glass-btn.icon-only {
|
| 89 |
+
padding: 0;
|
| 90 |
+
width: 44px;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
.glass-btn .btn-icon i {
|
| 94 |
+
font-size: 1.25rem;
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
.glass-btn:hover {
|
| 98 |
+
background: var(--btn-hover-bg);
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
.glass-btn .btn-icon {
|
| 102 |
+
display: flex;
|
| 103 |
+
align-items: center;
|
| 104 |
+
gap: 0.5rem;
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
.app-container {
|
| 108 |
+
display: flex;
|
| 109 |
+
height: 100vh;
|
| 110 |
+
padding: 1.5rem;
|
| 111 |
+
gap: 1.5rem;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
/* Glass Effect */
|
| 115 |
+
.glass {
|
| 116 |
+
background: var(--glass-bg);
|
| 117 |
+
backdrop-filter: blur(12px);
|
| 118 |
+
-webkit-backdrop-filter: blur(12px);
|
| 119 |
+
border: 1px solid var(--glass-border);
|
| 120 |
+
border-radius: 1.25rem;
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
/* Sidebar */
|
| 124 |
+
.sidebar {
|
| 125 |
+
width: 280px;
|
| 126 |
+
display: flex;
|
| 127 |
+
flex-direction: column;
|
| 128 |
+
padding: 2rem;
|
| 129 |
+
transition: none;
|
| 130 |
+
overflow: hidden;
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
.sidebar.collapsed {
|
| 134 |
+
width: 90px;
|
| 135 |
+
padding: 2rem 1.25rem;
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
.logo {
|
| 139 |
+
display: flex;
|
| 140 |
+
align-items: center;
|
| 141 |
+
gap: 1rem;
|
| 142 |
+
font-size: 1.5rem;
|
| 143 |
+
font-weight: 700;
|
| 144 |
+
margin-bottom: 3rem;
|
| 145 |
+
position: relative;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
.logo span {
|
| 149 |
+
transition: none;
|
| 150 |
+
white-space: nowrap;
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
.sidebar.collapsed .logo span {
|
| 154 |
+
opacity: 0;
|
| 155 |
+
pointer-events: none;
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
.logo-icon {
|
| 159 |
+
width: 40px;
|
| 160 |
+
height: 40px;
|
| 161 |
+
min-width: 40px;
|
| 162 |
+
background: var(--primary);
|
| 163 |
+
border-radius: 10px;
|
| 164 |
+
display: flex;
|
| 165 |
+
align-items: center;
|
| 166 |
+
justify-content: center;
|
| 167 |
+
color: white;
|
| 168 |
+
box-shadow: 0 0 20px var(--primary-glow);
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
.nav-links {
|
| 172 |
+
list-style: none;
|
| 173 |
+
flex: 1;
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
.nav-links li {
|
| 177 |
+
margin-bottom: 0.5rem;
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
.nav-links a {
|
| 181 |
+
display: flex;
|
| 182 |
+
align-items: center;
|
| 183 |
+
padding: 0.75rem 1rem;
|
| 184 |
+
color: var(--text-muted);
|
| 185 |
+
text-decoration: none;
|
| 186 |
+
border-radius: 0.75rem;
|
| 187 |
+
transition: none;
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
.nav-links .nav-text {
|
| 191 |
+
transition: none;
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
.sidebar.collapsed .nav-links .nav-text {
|
| 195 |
+
opacity: 0;
|
| 196 |
+
pointer-events: none;
|
| 197 |
+
width: 0;
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
.nav-icon {
|
| 201 |
+
font-size: 1.25rem;
|
| 202 |
+
min-width: 24px;
|
| 203 |
+
display: flex;
|
| 204 |
+
align-items: center;
|
| 205 |
+
justify-content: center;
|
| 206 |
+
margin-right: 0.75rem;
|
| 207 |
+
transition: none;
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
.sidebar.collapsed .nav-icon {
|
| 211 |
+
margin-right: 0;
|
| 212 |
+
width: 100%;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
.sidebar.collapsed .nav-links a {
|
| 216 |
+
justify-content: center;
|
| 217 |
+
padding: 0.75rem 0;
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
.nav-links li.active a,
|
| 221 |
+
.nav-links a:hover {
|
| 222 |
+
background: var(--sidebar-active-bg);
|
| 223 |
+
color: var(--text-main);
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
.nav-links li.active a {
|
| 227 |
+
border-left: 3px solid var(--primary);
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
.sidebar.collapsed .nav-links li.active a {
|
| 231 |
+
border-left: none;
|
| 232 |
+
background: var(--sidebar-active-bg);
|
| 233 |
+
box-shadow: inset 0 0 10px rgba(99, 102, 241, 0.2);
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
/* Content Area */
|
| 237 |
+
.content {
|
| 238 |
+
flex: 1;
|
| 239 |
+
overflow-y: auto;
|
| 240 |
+
padding-right: 0.5rem;
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
.view {
|
| 244 |
+
display: none;
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
.view.active {
|
| 248 |
+
display: flex;
|
| 249 |
+
flex-direction: column;
|
| 250 |
+
height: 100%;
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
.view-header {
|
| 256 |
+
margin-bottom: 2rem;
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
.view-header h1 {
|
| 260 |
+
font-size: 2.25rem;
|
| 261 |
+
margin-bottom: 0.5rem;
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
.view-header p {
|
| 265 |
+
color: var(--text-muted);
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
/* Prediction / Translation Grid */
|
| 269 |
+
.translation-grid {
|
| 270 |
+
display: grid;
|
| 271 |
+
grid-template-columns: 1fr 1fr;
|
| 272 |
+
gap: 1.5rem;
|
| 273 |
+
flex: 1;
|
| 274 |
+
min-height: 0;
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
.card {
|
| 278 |
+
display: flex;
|
| 279 |
+
flex-direction: column;
|
| 280 |
+
box-shadow: var(--card-shadow);
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
.card-header {
|
| 284 |
+
padding: 1rem 1.5rem;
|
| 285 |
+
border-bottom: 1px solid var(--glass-border);
|
| 286 |
+
display: flex;
|
| 287 |
+
align-items: center;
|
| 288 |
+
justify-content: space-between;
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
.card-body {
|
| 292 |
+
flex: 1;
|
| 293 |
+
position: relative;
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
.card-footer {
|
| 297 |
+
padding: 1rem 1.5rem;
|
| 298 |
+
border-top: 1px solid var(--glass-border);
|
| 299 |
+
font-size: 0.85rem;
|
| 300 |
+
color: var(--text-muted);
|
| 301 |
+
display: flex;
|
| 302 |
+
align-items: center;
|
| 303 |
+
justify-content: space-between;
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
textarea {
|
| 307 |
+
width: 100%;
|
| 308 |
+
height: 100%;
|
| 309 |
+
background: transparent;
|
| 310 |
+
border: none;
|
| 311 |
+
resize: none;
|
| 312 |
+
padding: 1.5rem;
|
| 313 |
+
color: var(--text-main);
|
| 314 |
+
font-size: 1.1rem;
|
| 315 |
+
line-height: 1.6;
|
| 316 |
+
outline: none;
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
.lang-group {
|
| 320 |
+
display: flex;
|
| 321 |
+
align-items: center;
|
| 322 |
+
gap: 0.75rem;
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
.lang-label {
|
| 326 |
+
font-size: 0.7rem;
|
| 327 |
+
font-weight: 700;
|
| 328 |
+
text-transform: uppercase;
|
| 329 |
+
letter-spacing: 0.05em;
|
| 330 |
+
color: var(--text-muted);
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
.lang-select {
|
| 334 |
+
background: var(--input-bg);
|
| 335 |
+
border: 1px solid var(--glass-border);
|
| 336 |
+
color: var(--text-main);
|
| 337 |
+
padding: 0.5rem 1rem;
|
| 338 |
+
border-radius: 0.5rem;
|
| 339 |
+
outline: none;
|
| 340 |
+
cursor: pointer;
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
.detected-badge {
|
| 344 |
+
font-size: 0.75rem;
|
| 345 |
+
background: var(--primary);
|
| 346 |
+
padding: 0.2rem 0.6rem;
|
| 347 |
+
border-radius: 1rem;
|
| 348 |
+
color: white;
|
| 349 |
+
opacity: 0;
|
| 350 |
+
transition: none;
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
.detected-badge.visible {
|
| 354 |
+
opacity: 1;
|
| 355 |
+
transition: none;
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
/* Loader */
|
| 359 |
+
.loader-overlay {
|
| 360 |
+
position: absolute;
|
| 361 |
+
top: 0;
|
| 362 |
+
left: 0;
|
| 363 |
+
width: 100%;
|
| 364 |
+
height: 100%;
|
| 365 |
+
background: rgba(15, 23, 42, 0.5);
|
| 366 |
+
display: flex;
|
| 367 |
+
align-items: center;
|
| 368 |
+
justify-content: center;
|
| 369 |
+
border-radius: 0 0 1.25rem 1.25rem;
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
.hidden {
|
| 373 |
+
display: none !important;
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
/* Indeterminate spinner shown in the loader overlay while a translation
   request is in flight. The `spin` keyframes are defined in this
   stylesheet; without applying them the loader renders as a static ring,
   so the animation is wired up here. */
.spinner {
    width: 30px;
    height: 30px;
    border: 3px solid rgba(255, 255, 255, 0.1);
    border-top-color: var(--primary);
    border-radius: 50%;
    animation: spin 1s linear infinite;
}
|
| 384 |
+
|
| 385 |
+
@keyframes spin {
|
| 386 |
+
to {
|
| 387 |
+
transform: rotate(360deg);
|
| 388 |
+
}
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
/* Success Pills */
|
| 392 |
+
.status-pill {
|
| 393 |
+
display: flex;
|
| 394 |
+
align-items: center;
|
| 395 |
+
gap: 0.5rem;
|
| 396 |
+
padding: 0.5rem 1rem;
|
| 397 |
+
border-radius: 2rem;
|
| 398 |
+
background: rgba(0, 0, 0, 0.2);
|
| 399 |
+
font-size: 0.85rem;
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
.dot {
|
| 403 |
+
width: 8px;
|
| 404 |
+
height: 8px;
|
| 405 |
+
border-radius: 50%;
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
.status-online .dot {
|
| 409 |
+
background: #10b981;
|
| 410 |
+
box-shadow: 0 0 10px #10b981;
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
/* Amber status dot for the "Connecting…/Offline" health pill. The `pulse`
   keyframes are defined in this stylesheet; applying them here makes the
   dot throb to signal an in-progress health check instead of sitting
   static. */
.status-loading .dot {
    background: #f59e0b;
    animation: pulse 2s ease-in-out infinite;
}
|
| 417 |
+
|
| 418 |
+
@keyframes pulse {
|
| 419 |
+
0% {
|
| 420 |
+
transform: scale(1);
|
| 421 |
+
opacity: 1;
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
50% {
|
| 425 |
+
transform: scale(1.2);
|
| 426 |
+
opacity: 0.5;
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
100% {
|
| 430 |
+
transform: scale(1);
|
| 431 |
+
opacity: 1;
|
| 432 |
+
}
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
/* Models Grid */
|
| 436 |
+
.models-grid {
|
| 437 |
+
display: grid;
|
| 438 |
+
grid-template-columns: repeat(auto-fill, minmax(300px, 1fr));
|
| 439 |
+
gap: 1.5rem;
|
| 440 |
+
}
|
| 441 |
+
|
| 442 |
+
.model-card {
|
| 443 |
+
padding: 1.5rem;
|
| 444 |
+
transition: none;
|
| 445 |
+
/* Performance optimizations */
|
| 446 |
+
contain: layout style paint;
|
| 447 |
+
will-change: background;
|
| 448 |
+
/* Remove expensive backdrop-filter for better scrolling */
|
| 449 |
+
background: var(--glass-bg);
|
| 450 |
+
border: 1px solid var(--glass-border);
|
| 451 |
+
border-radius: 1.25rem;
|
| 452 |
+
box-shadow: var(--card-shadow);
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
.model-card:hover {
|
| 456 |
+
background: var(--sidebar-active-bg);
|
| 457 |
+
}
|
| 458 |
+
|
| 459 |
+
.model-lang-pair {
|
| 460 |
+
display: flex;
|
| 461 |
+
align-items: center;
|
| 462 |
+
gap: 0.5rem;
|
| 463 |
+
font-weight: 600;
|
| 464 |
+
margin-bottom: 0.5rem;
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
.model-id {
|
| 468 |
+
font-size: 0.8rem;
|
| 469 |
+
color: var(--text-muted);
|
| 470 |
+
word-break: break-all;
|
| 471 |
+
}
|
| 472 |
+
|
| 473 |
+
.loaded-badge {
|
| 474 |
+
display: inline-block;
|
| 475 |
+
padding: 0.2rem 0.5rem;
|
| 476 |
+
background: rgba(16, 185, 129, 0.1);
|
| 477 |
+
color: #10b981;
|
| 478 |
+
border: 1px solid rgba(16, 185, 129, 0.2);
|
| 479 |
+
border-radius: 0.4rem;
|
| 480 |
+
font-size: 0.7rem;
|
| 481 |
+
margin-top: 1rem;
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
/* Buttons */
|
| 485 |
+
.action-btn {
|
| 486 |
+
background: var(--primary);
|
| 487 |
+
color: white;
|
| 488 |
+
border: none;
|
| 489 |
+
padding: 0.5rem 1rem;
|
| 490 |
+
border-radius: 0.5rem;
|
| 491 |
+
cursor: pointer;
|
| 492 |
+
font-weight: 500;
|
| 493 |
+
transition: none;
|
| 494 |
+
}
|
| 495 |
+
|
| 496 |
+
.action-btn:hover {
|
| 497 |
+
background: #4f46e5;
|
| 498 |
+
}
|
| 499 |
+
|
| 500 |
+
.icon-btn {
|
| 501 |
+
background: transparent;
|
| 502 |
+
border: none;
|
| 503 |
+
color: var(--text-muted);
|
| 504 |
+
font-size: 1.25rem;
|
| 505 |
+
cursor: pointer;
|
| 506 |
+
transition: none;
|
| 507 |
+
}
|
| 508 |
+
|
| 509 |
+
.icon-btn:hover {
|
| 510 |
+
color: var(--text-main);
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
/* Settings View */
|
| 514 |
+
.settings-container {
|
| 515 |
+
max-width: 800px;
|
| 516 |
+
padding: 2rem;
|
| 517 |
+
margin-top: 1rem;
|
| 518 |
+
}
|
| 519 |
+
|
| 520 |
+
.settings-grid {
|
| 521 |
+
display: flex;
|
| 522 |
+
flex-direction: column;
|
| 523 |
+
gap: 2rem;
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
.setting-item {
|
| 527 |
+
display: grid;
|
| 528 |
+
grid-template-columns: 1fr 200px;
|
| 529 |
+
align-items: center;
|
| 530 |
+
gap: 2rem;
|
| 531 |
+
padding-bottom: 2rem;
|
| 532 |
+
border-bottom: 1px solid var(--glass-border);
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
.setting-item:last-child {
|
| 536 |
+
border-bottom: none;
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
.setting-info {
|
| 540 |
+
display: flex;
|
| 541 |
+
flex-direction: column;
|
| 542 |
+
gap: 0.25rem;
|
| 543 |
+
}
|
| 544 |
+
|
| 545 |
+
.setting-info label {
|
| 546 |
+
font-weight: 600;
|
| 547 |
+
font-size: 1.1rem;
|
| 548 |
+
}
|
| 549 |
+
|
| 550 |
+
.setting-desc {
|
| 551 |
+
color: var(--text-muted);
|
| 552 |
+
font-size: 0.85rem;
|
| 553 |
+
}
|
| 554 |
+
|
| 555 |
+
.setting-control {
|
| 556 |
+
display: flex;
|
| 557 |
+
align-items: center;
|
| 558 |
+
gap: 1rem;
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
.setting-val {
|
| 562 |
+
min-width: 30px;
|
| 563 |
+
font-family: monospace;
|
| 564 |
+
font-weight: 600;
|
| 565 |
+
color: var(--primary);
|
| 566 |
+
}
|
| 567 |
+
|
| 568 |
+
input[type="range"] {
|
| 569 |
+
flex: 1;
|
| 570 |
+
cursor: pointer;
|
| 571 |
+
accent-color: var(--primary);
|
| 572 |
+
}
|
| 573 |
+
|
| 574 |
+
input[type="number"] {
|
| 575 |
+
width: 100%;
|
| 576 |
+
background: var(--input-bg);
|
| 577 |
+
border: 1px solid var(--glass-border);
|
| 578 |
+
color: var(--text-main);
|
| 579 |
+
padding: 0.5rem;
|
| 580 |
+
border-radius: 0.5rem;
|
| 581 |
+
outline: none;
|
| 582 |
+
font-family: inherit;
|
| 583 |
+
}
|
| 584 |
+
|
| 585 |
+
.settings-actions {
|
| 586 |
+
margin-top: 3rem;
|
| 587 |
+
display: flex;
|
| 588 |
+
justify-content: flex-end;
|
| 589 |
+
}
|
| 590 |
+
|
| 591 |
+
.action-btn.secondary {
|
| 592 |
+
background: rgba(255, 255, 255, 0.05);
|
| 593 |
+
border: 1px solid var(--glass-border);
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
+
.action-btn.secondary:hover {
|
| 597 |
+
background: rgba(255, 255, 255, 0.1);
|
| 598 |
+
}
|
| 599 |
+
|
| 600 |
+
/* Theme Toggle Button */
|
| 601 |
+
.theme-toggle-container {
|
| 602 |
+
margin-top: 1rem;
|
| 603 |
+
}
|
| 604 |
+
|
| 605 |
+
#theme-toggle {
|
| 606 |
+
width: 100%;
|
| 607 |
+
display: flex;
|
| 608 |
+
align-items: center;
|
| 609 |
+
gap: 0.75rem;
|
| 610 |
+
background: var(--input-bg);
|
| 611 |
+
border: 1px solid var(--glass-border);
|
| 612 |
+
color: var(--text-main);
|
| 613 |
+
padding: 0.75rem 1rem;
|
| 614 |
+
border-radius: 0.75rem;
|
| 615 |
+
cursor: pointer;
|
| 616 |
+
font-size: 0.9rem;
|
| 617 |
+
transition: none;
|
| 618 |
+
}
|
| 619 |
+
|
| 620 |
+
#theme-toggle:hover {
|
| 621 |
+
background: var(--sidebar-active-bg);
|
| 622 |
+
}
|
| 623 |
+
|
| 624 |
+
.mode-icon {
|
| 625 |
+
font-size: 1.1rem;
|
| 626 |
+
}
|
| 627 |
+
|
| 628 |
+
.sidebar.collapsed .mode-text {
|
| 629 |
+
display: none;
|
| 630 |
+
}
|
| 631 |
+
|
| 632 |
+
.sidebar.collapsed .theme-toggle-container {
|
| 633 |
+
display: flex;
|
| 634 |
+
justify-content: center;
|
| 635 |
+
}
|
| 636 |
+
|
| 637 |
+
#sidebar-toggle {
|
| 638 |
+
margin-top: 1rem;
|
| 639 |
+
width: 40px;
|
| 640 |
+
height: 40px;
|
| 641 |
+
display: flex;
|
| 642 |
+
align-items: center;
|
| 643 |
+
justify-content: center;
|
| 644 |
+
background: var(--input-bg);
|
| 645 |
+
border-radius: 50%;
|
| 646 |
+
margin-left: auto;
|
| 647 |
+
font-size: 1.25rem;
|
| 648 |
+
transition: none;
|
| 649 |
+
z-index: 10;
|
| 650 |
+
}
|
| 651 |
+
|
| 652 |
+
#sidebar-toggle:hover {
|
| 653 |
+
background: var(--sidebar-active-bg);
|
| 654 |
+
}
|
| 655 |
+
|
| 656 |
+
.sidebar.collapsed #sidebar-toggle {
|
| 657 |
+
transform: rotate(180deg);
|
| 658 |
+
margin-left: 0.5rem;
|
| 659 |
+
}
|
| 660 |
+
|
| 661 |
+
.sidebar.collapsed .sidebar-footer .status-text {
|
| 662 |
+
display: none;
|
| 663 |
+
}
|
| 664 |
+
|
| 665 |
+
.sidebar.collapsed .status-pill {
|
| 666 |
+
padding: 0.5rem;
|
| 667 |
+
justify-content: center;
|
| 668 |
+
}
|
quickmt/langid.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple, Union, Optional
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import os
|
| 4 |
+
import urllib.request
|
| 5 |
+
import fasttext
|
| 6 |
+
|
| 7 |
+
# Suppress fasttext's warning about being loaded in a way that doesn't
# allow querying its version (common in some environments).
# NOTE: this monkeypatches the library globally for the whole process.
fasttext.FastText.eprint = lambda x: None
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class LanguageIdentification:
    """Detect language using a FastText langid model.

    This class provides a wrapper around the FastText library for efficient
    language identification, supporting both single-string and batch
    processing.
    """

    # Canonical download location of the 176-language langid model.
    DOWNLOAD_URL = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"

    def __init__(self, model_path: Optional[Union[str, Path]] = None):
        """Initialize the LanguageIdentification model.

        Args:
            model_path: Path to the pre-trained FastText model file.
                If None, defaults to
                $XDG_CACHE_HOME/fasttext_language_id/lid.176.bin (falling
                back to ~/.cache) and downloads the model if missing.
        """
        model_path = self._resolve_and_download(model_path)
        self.ft = fasttext.load_model(str(model_path))

    @classmethod
    def _resolve_and_download(
        cls, model_path: Optional[Union[str, Path]]
    ) -> Path:
        """Resolve the default model path and download the model if absent.

        The download is written to a temporary ".part" file and atomically
        renamed into place, so an interrupted download can never leave a
        truncated file that would later pass the exists() check.
        """
        if model_path is None:
            cache_dir = Path(os.getenv("XDG_CACHE_HOME", Path.home() / ".cache"))
            model_dir = cache_dir / "fasttext_language_id"
            model_path = model_dir / "lid.176.bin"

        model_path = Path(model_path)

        if not model_path.exists():
            model_path.parent.mkdir(parents=True, exist_ok=True)
            url = cls.DOWNLOAD_URL
            print(f"Downloading FastText model from {url} to {model_path}...")
            tmp_path = model_path.with_suffix(".part")
            urllib.request.urlretrieve(url, str(tmp_path))
            # Atomic rename: either the full file appears or nothing does.
            os.replace(tmp_path, model_path)
            print("Download complete.")
        return model_path

    def predict(
        self,
        text: Union[str, List[str]],
        k: int = 1,
        threshold: float = 0.0
    ) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]:
        """Predict the language(s) for the given text or list of texts.

        Args:
            text: A single string or a list of strings to identify.
            k: Number of most likely languages to return. Defaults to 1.
            threshold: Minimum score for a language to be included in the
                results. Defaults to 0.0 (return all k results regardless
                of score).

        Returns:
            If input is a string: A list of (lang, score) tuples.
            If input is a list of strings: A list of lists of (lang, score)
            tuples, maintaining the input order.
        """
        is_single = isinstance(text, str)
        items = [text] if is_single else text

        # Sanitize inputs: FastText raises on embedded newlines.
        items = [t.replace("\n", " ") for t in items]

        # FastText predict handles lists natively and is faster than looping.
        ft_output = self.ft.predict(items, k=k, threshold=threshold)

        # FastText returns ([['__label__en', ...], ...], [[0.9, ...], ...])
        labels, scores = ft_output

        results = []
        for item_labels, item_scores in zip(labels, scores):
            item_results = [
                (label.replace("__label__", ""), float(score))
                for label, score in zip(item_labels, item_scores)
            ]
            results.append(item_results)

        return results[0] if is_single else results

    def predict_best(
        self,
        text: Union[str, List[str]],
        threshold: float = 0.0
    ) -> Union[Optional[str], List[Optional[str]]]:
        """Predict the most likely language for the given text or texts.

        Convenience wrapper around `predict` that returns only the
        top-scoring language label (or None if no language exceeds the
        threshold).

        Args:
            text: A single string or a list of strings to identify.
            threshold: Minimum score for a language to be selected.

        Returns:
            If input is a string: The language code (e.g., 'en') or None.
            If input is a list: A list of language codes or None.
        """
        results = self.predict(text, k=1, threshold=threshold)

        if isinstance(text, str):
            # results is List[Tuple[str, float]]
            return results[0][0] if results else None
        else:
            # results is List[List[Tuple[str, float]]]
            return [r[0][0] if r else None for r in results]
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def ensure_model_exists(model_path: Optional[Union[str, Path]] = None) -> Path:
    """Ensure the FastText model exists on disk, downloading if necessary.

    This should be called from the main process before starting worker
    pools, so workers never race each other on the download.

    Args:
        model_path: Target location of the model file. If None, defaults to
            $XDG_CACHE_HOME/fasttext_language_id/lid.176.bin (falling back
            to ~/.cache).

    Returns:
        The resolved Path to the model file.
    """
    if model_path is None:
        cache_dir = Path(os.getenv("XDG_CACHE_HOME", Path.home() / ".cache"))
        model_dir = cache_dir / "fasttext_language_id"
        model_path = model_dir / "lid.176.bin"

    model_path = Path(model_path)

    if not model_path.exists():
        model_path.parent.mkdir(parents=True, exist_ok=True)
        url = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"
        print(f"Downloading FastText model from {url} to {model_path}...")
        # Download to a temporary name, then atomically rename: a crash
        # mid-download must never leave a truncated file that later passes
        # the exists() check.
        tmp_path = model_path.with_suffix(".part")
        urllib.request.urlretrieve(url, str(tmp_path))
        os.replace(tmp_path, model_path)
        print("Download complete.")
    return model_path
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# Global detector instance for process pool workers; each worker process
# gets its own copy via the ProcessPoolExecutor initializer.
_detector: Optional[LanguageIdentification] = None


def init_worker(model_path: Optional[Union[str, Path]] = None) -> None:
    """Initialize the global detector instance for a worker process.

    Intended for use as a ProcessPoolExecutor ``initializer`` so each
    worker loads the FastText model exactly once.
    """
    global _detector
    # We assume ensure_model_exists was already called in the main process
    _detector = LanguageIdentification(model_path)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def predict_worker(
    text: Union[str, List[str]],
    k: int = 1,
    threshold: float = 0.0
) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]:
    """Prediction function to be run in a worker process.

    Delegates to the process-global detector created by ``init_worker``.

    Args:
        text: A single string or a list of strings to identify.
        k: Number of top languages to return per input.
        threshold: Minimum score for a language to be included.

    Returns:
        Same shape as ``LanguageIdentification.predict``.
    """
    if _detector is None:
        # Fallback if init_worker failed or wasn't called; loads the model
        # with the default path on first use.
        init_worker()
    return _detector.predict(text, k=k, threshold=threshold)
|
quickmt/manager.py
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import logging
|
| 3 |
+
import time
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Dict, List, Optional
|
| 6 |
+
from collections import OrderedDict
|
| 7 |
+
from functools import lru_cache
|
| 8 |
+
|
| 9 |
+
from fastapi import HTTPException
|
| 10 |
+
from huggingface_hub import HfApi, snapshot_download
|
| 11 |
+
from cachetools import TTLCache, cached, LRUCache
|
| 12 |
+
|
| 13 |
+
from quickmt.translator import Translator
|
| 14 |
+
from quickmt.settings import settings
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class BatchTranslator:
    """Owns one loaded translation model and an asyncio queue that
    serializes requests to it, opportunistically batching requests whose
    language pair and decoding parameters match exactly.

    Results are memoized in a per-model LRU cache keyed on
    (src, src_lang, tgt_lang, sorted kwargs).
    """

    def __init__(
        self,
        model_id: str,
        model_path: str,
        device: str = "cpu",
        compute_type: str = "default",
        inter_threads: int = 1,
        intra_threads: int = 0,
    ):
        self.model_id = model_id
        self.model_path = model_path
        self.device = device
        self.compute_type = compute_type
        self.inter_threads = inter_threads
        self.intra_threads = intra_threads
        self.translator: Optional[Translator] = None
        self.queue: asyncio.Queue = asyncio.Queue()
        self.worker_task: Optional[asyncio.Task] = None
        # LRU cache for translations
        self.translation_cache: LRUCache = LRUCache(
            maxsize=settings.translation_cache_size
        )

    async def start_worker(self):
        """Load the underlying Translator and start the batching worker."""
        if self.worker_task:
            return

        # Load model in main process (or worker thread if needed);
        # Translator handles its own loading.
        self.translator = Translator(
            Path(self.model_path),
            device=self.device,
            compute_type=self.compute_type,
            inter_threads=self.inter_threads,
            intra_threads=self.intra_threads,
        )
        self.worker_task = asyncio.create_task(self._worker())
        logger.info(f"Started translation worker for model: {self.model_id}")

    async def stop_worker(self):
        """Stop the worker task (via a ``None`` sentinel) and unload the model."""
        if not self.worker_task:
            return

        await self.queue.put(None)
        await self.worker_task
        self.worker_task = None
        if self.translator:
            self.translator.unload()
            self.translator = None
        logger.info(f"Stopped translation worker for model: {self.model_id}")

    async def _worker(self):
        """Queue consumer: batch compatible requests and run inference.

        Terminates when the ``None`` sentinel (sent by stop_worker) is
        dequeued at the top of the loop.
        """
        while True:
            item = await self.queue.get()
            if item is None:
                self.queue.task_done()
                break

            src, src_lang, tgt_lang, kwargs, future = item
            batch_texts = [src]
            futures = [future]

            # 1. Opportunistically collect compatible items until the batch
            # is full or the batching timeout elapses.
            start_time = time.time()
            while len(batch_texts) < settings.max_batch_size:
                wait_time = (settings.batch_timeout_ms / 1000.0) - (
                    time.time() - start_time
                )
                if wait_time <= 0:
                    break
                try:
                    next_item = await asyncio.wait_for(
                        self.queue.get(), timeout=wait_time
                    )
                except asyncio.TimeoutError:
                    break
                if next_item is None:
                    # Re-queue the sentinel so the outer loop sees it after
                    # this batch is flushed; balance the extra get().
                    await self.queue.put(None)
                    self.queue.task_done()
                    break
                n_src, n_sl, n_tl, n_kw, n_fut = next_item
                if n_sl == src_lang and n_tl == tgt_lang and n_kw == kwargs:
                    # Identical parameters: safe to decode together.
                    batch_texts.append(n_src)
                    futures.append(n_fut)
                else:
                    # Incompatible parameters: defer to a later cycle and
                    # balance this get() against the re-put.
                    await self.queue.put(next_item)
                    self.queue.task_done()
                    break

            try:
                # 2. Run inference in the default executor so the event loop
                # is not blocked during translation.
                loop = asyncio.get_running_loop()
                results = await loop.run_in_executor(
                    None,
                    lambda: self.translator(
                        batch_texts, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs
                    ),
                )

                # Translator may return a bare string for single inputs.
                if isinstance(results, str):
                    results = [results]

                # 3. Resolve every future in the batch.
                for res, fut in zip(results, futures):
                    if not fut.done():
                        fut.set_result(res)
            except Exception as e:
                logger.error(f"Error in translation worker for {self.model_id}: {e}")
                # Fail the *whole* batch: previously only the first request's
                # future received the exception, leaving the rest of the
                # batched callers awaiting forever.
                for fut in futures:
                    if not fut.done():
                        fut.set_exception(e)
            finally:
                # Balance queue.get() for every item kept in this batch,
                # on success and on failure alike.
                for _ in range(len(batch_texts)):
                    self.queue.task_done()

    async def translate(
        self,
        src: str,
        src_lang: Optional[str] = None,
        tgt_lang: Optional[str] = None,
        **kwargs,
    ) -> str:
        """Translate a single sentence, consulting the LRU cache first.

        Starts the worker lazily on first use; extra keyword arguments are
        decoding parameters forwarded to the Translator.
        """
        if not self.worker_task:
            await self.start_worker()

        # Cache key: kwargs sorted into a tuple so the key is hashable.
        kwargs_tuple = tuple(sorted(kwargs.items()))
        cache_key = (src, src_lang, tgt_lang, kwargs_tuple)

        # Check cache first.
        if cache_key in self.translation_cache:
            return self.translation_cache[cache_key]

        # Cache miss - enqueue and await the worker's result.
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((src, src_lang, tgt_lang, kwargs, future))
        result = await future

        # Store in cache.
        self.translation_cache[cache_key] = result
        return result
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class ModelManager:
    """Lazily loads translation models on demand with LRU eviction.

    At most ``max_loaded`` models stay resident; concurrent requests for
    the same pair share a single load via ``pending_loads`` events.
    """

    # Seconds to trust a previously fetched Hugging Face collection listing.
    HF_MODELS_TTL = 3600.0

    def __init__(
        self,
        max_loaded: int,
        device: str,
        compute_type: str = "default",
        inter_threads: int = 1,
        intra_threads: int = 0,
    ):
        self.max_loaded = max_loaded
        self.device = device
        self.compute_type = compute_type
        self.inter_threads = inter_threads
        self.intra_threads = intra_threads
        # LRU order: most recently used model at the end; key is "src-tgt".
        self.models: "OrderedDict[str, BatchTranslator]" = OrderedDict()
        self.pending_loads: Dict[str, asyncio.Event] = {}
        self.lock = asyncio.Lock()
        self.hf_collection_models: List[Dict] = []
        # Timestamp of the last successful collection fetch (0.0 = never).
        self._hf_models_fetched_at: float = 0.0
        self.api = HfApi()

    async def fetch_hf_models(self):
        """Fetch available models from the quickmt collection on Hugging Face.

        The listing is cached for HF_MODELS_TTL seconds via a manual
        timestamp. NOTE: the previous implementation decorated this
        coroutine with ``cachetools.cached``, which caches the *coroutine
        object* itself; awaiting that cached object a second time raises
        ``RuntimeError: cannot reuse already awaited coroutine``.
        """
        if (
            self.hf_collection_models
            and time.time() - self._hf_models_fetched_at < self.HF_MODELS_TTL
        ):
            return
        try:
            loop = asyncio.get_running_loop()
            collection = await loop.run_in_executor(
                None, lambda: self.api.get_collection("quickmt/quickmt-models")
            )

            hf_models = []
            for item in collection.items:
                if item.item_type == "model":
                    model_id = item.item_id
                    # Expecting format: quickmt/quickmt-en-fr
                    parts = model_id.split("/")[-1].replace("quickmt-", "").split("-")
                    if len(parts) == 2:
                        src, tgt = parts
                        hf_models.append(
                            {"model_id": model_id, "src_lang": src, "tgt_lang": tgt}
                        )
            self.hf_collection_models = hf_models
            self._hf_models_fetched_at = time.time()
            logger.info(
                f"Discovered {len(hf_models)} models from Hugging Face collection"
            )
        except Exception as e:
            logger.error(f"Failed to fetch models from Hugging Face: {e}")

    async def get_model(self, src_lang: str, tgt_lang: str) -> "BatchTranslator":
        """Return a running BatchTranslator for the language pair.

        Raises:
            HTTPException: 404 if the pair is not in the HF collection,
                500 if the background load failed.
        """
        model_name = f"{src_lang}-{tgt_lang}"

        async with self.lock:
            # 1. Already loaded: refresh its LRU position.
            if model_name in self.models:
                self.models.move_to_end(model_name)
                return self.models[model_name]

            # 2. Join an in-flight load, or start a new one.
            if model_name in self.pending_loads:
                event = self.pending_loads[model_name]
            else:
                # Pre-check existence before starting the task to ensure a
                # clean 404 for unknown language pairs.
                hf_model = next(
                    (
                        m
                        for m in self.hf_collection_models
                        if m["src_lang"] == src_lang and m["tgt_lang"] == tgt_lang
                    ),
                    None,
                )
                if not hf_model:
                    raise HTTPException(
                        status_code=404,
                        detail=f"Model for {src_lang}->{tgt_lang} not found in Hugging Face collection",
                    )

                event = asyncio.Event()
                self.pending_loads[model_name] = event
                # This task will do the actual loading.
                asyncio.create_task(self._load_model_task(src_lang, tgt_lang, event))

        # 3. Wait for the load to finish (outside the lock so the load task
        # can acquire it).
        await event.wait()

        # 4. Return from cache; surface a failed load as a clean 500 rather
        # than leaking a KeyError to the caller.
        async with self.lock:
            if model_name not in self.models:
                raise HTTPException(
                    status_code=500,
                    detail=f"Failed to load model for {src_lang}->{tgt_lang}",
                )
            return self.models[model_name]

    async def _load_model_task(
        self, src_lang: str, tgt_lang: str, new_event: asyncio.Event
    ):
        """Background task: download, evict if needed, load one model, then
        signal all waiters via ``new_event``."""
        model_name = f"{src_lang}-{tgt_lang}"
        try:
            try:
                # Find matching model from HF collection (existence already
                # verified in get_model).
                hf_model = next(
                    m
                    for m in self.hf_collection_models
                    if m["src_lang"] == src_lang and m["tgt_lang"] == tgt_lang
                )

                logger.info(f"Accessing Hugging Face model: {hf_model['model_id']}")
                loop = asyncio.get_running_loop()
                # snapshot_download returns the local path in the HF cache.
                # Try local-only first to avoid a network round-trip.
                try:
                    cached_path = await loop.run_in_executor(
                        None,
                        lambda: snapshot_download(
                            repo_id=hf_model["model_id"],
                            ignore_patterns=["eole-model/*", "eole_model/*"],
                            local_files_only=True,
                        ),
                    )
                except Exception:
                    # Fallback to checking online.
                    logger.info(
                        f"Model {hf_model['model_id']} not fully cached, checking online..."
                    )
                    cached_path = await loop.run_in_executor(
                        None,
                        lambda: snapshot_download(
                            repo_id=hf_model["model_id"],
                            ignore_patterns=["eole-model/*", "eole_model/*"],
                        ),
                    )
                model_path = Path(cached_path)

                # Evict the least-recently-used model if at capacity; the
                # (slow) stop happens outside the lock.
                evicted_model = None
                async with self.lock:
                    if len(self.models) >= self.max_loaded:
                        oldest_name, evicted_model = self.models.popitem(last=False)
                        logger.info(f"Evicting model: {oldest_name}")

                if evicted_model:
                    await evicted_model.stop_worker()

                # Load new model (SLOW, outside lock).
                logger.info(
                    f"Loading model: {hf_model['model_id']} (device: {self.device}, compute: {self.compute_type})"
                )
                new_model = BatchTranslator(
                    model_id=hf_model["model_id"],
                    model_path=str(model_path),
                    device=self.device,
                    compute_type=self.compute_type,
                    inter_threads=self.inter_threads,
                    intra_threads=self.intra_threads,
                )
                await new_model.start_worker()

                # Add to cache.
                async with self.lock:
                    self.models[model_name] = new_model

            except Exception as e:
                logger.error(f"Error loading model {model_name}: {e}")
                # Waiters are unblocked in the finally block below and will
                # observe the missing cache entry as an HTTP 500 in get_model.
                raise e
        finally:
            async with self.lock:
                if model_name in self.pending_loads:
                    del self.pending_loads[model_name]
            new_event.set()

    def list_available_models(self) -> List[Dict]:
        """List all models discovered from Hugging Face, with load status."""
        available = []
        for m in self.hf_collection_models:
            lang_pair = f"{m['src_lang']}-{m['tgt_lang']}"
            available.append(
                {
                    "model_id": m["model_id"],
                    "src_lang": m["src_lang"],
                    "tgt_lang": m["tgt_lang"],
                    "loaded": lang_pair in self.models,
                }
            )
        return available

    def get_language_pairs(self) -> Dict[str, List[str]]:
        """Map each source language to its sorted list of target languages.

        Computed fresh on every call: the previous ``@lru_cache(maxsize=1)``
        froze the first result forever (empty if called before
        fetch_hf_models) and kept ``self`` alive inside the cache.
        """
        pairs: Dict[str, set] = {}
        for m in self.hf_collection_models:
            pairs.setdefault(m["src_lang"], set()).add(m["tgt_lang"])
        # Convert sets to sorted lists for a deterministic response.
        return {src: sorted(tgts) for src, tgts in sorted(pairs.items())}

    async def shutdown(self):
        """Stop all loaded models' workers and clear the registry."""
        for name, model in self.models.items():
            await model.stop_worker()
        self.models.clear()
|
quickmt/rest_server.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
from contextlib import asynccontextmanager
|
| 6 |
+
from typing import List, Optional, Union, Dict
|
| 7 |
+
from concurrent.futures import ProcessPoolExecutor
|
| 8 |
+
|
| 9 |
+
from fastapi import FastAPI, HTTPException, APIRouter
|
| 10 |
+
from fastapi.responses import ORJSONResponse
|
| 11 |
+
from fastapi.staticfiles import StaticFiles
|
| 12 |
+
from pydantic import BaseModel, model_validator
|
| 13 |
+
|
| 14 |
+
from quickmt.langid import init_worker, predict_worker, ensure_model_exists
|
| 15 |
+
from quickmt.manager import ModelManager
|
| 16 |
+
from quickmt.settings import settings
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger("uvicorn.error")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class TranslationRequest(BaseModel):
    """Request body for the translate endpoint.

    ``src`` may be a single string or a batch (list); ``src_lang`` may be
    omitted (auto-detected), a single code applied to every item, or a
    per-item list matching ``src``.
    """
    src: Union[str, List[str]]
    src_lang: Optional[Union[str, List[str]]] = None
    tgt_lang: str = "en"
    # Beam-search decoding parameters forwarded to the translator.
    beam_size: int = 5
    patience: int = 1
    length_penalty: float = 1.0
    coverage_penalty: float = 0.0
    repetition_penalty: float = 1.0
    max_decoding_length: int = 256

    @model_validator(mode="after")
    def validate_patience(self):
        """Reject patience values larger than the beam size."""
        if self.patience > self.beam_size:
            raise ValueError("patience cannot be greater than beam_size")
        return self
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class TranslationResponse(BaseModel):
    """Response body for translation; field shapes mirror the request's src
    (scalar in, scalar out; list in, list out)."""
    translation: Union[str, List[str]]
    # Detected (or caller-supplied) source language per input item.
    src_lang: Union[str, List[str]]
    # Langid confidence; 1.0 when src_lang was supplied by the caller.
    src_lang_score: Union[float, List[float]]
    tgt_lang: str
    # Wall-clock seconds spent handling the request.
    processing_time: float
    # Hugging Face model id(s) used ("identity" when src lang == tgt lang).
    model_used: Union[str, List[str]]


class DetectionRequest(BaseModel):
    """Request body for language detection."""
    src: Union[str, List[str]]
    # Number of candidate languages to return per input.
    k: int = 1
    # Minimum score for a candidate to be included.
    threshold: float = 0.0


class DetectionResult(BaseModel):
    """A single (language, confidence) candidate."""
    lang: str
    score: float


class DetectionResponse(BaseModel):
    """Response body for language detection; nested lists when the request's
    src was a list."""
    results: Union[List[DetectionResult], List[List[DetectionResult]]]
    processing_time: float
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class BatchItem:
    """Bundle pairing one batch of source sentences with its decoding
    parameters and the future that will receive the results.

    Plain value container: no validation is performed here. NOTE(review):
    this class is not referenced anywhere in the visible code — confirm it
    is still needed.
    """

    def __init__(
        self,
        src: List[str],
        src_lang: str,
        tgt_lang: str,
        beam_size: int,
        max_decoding_length: int,
        future: asyncio.Future,
    ):
        self.future = future
        self.src = src
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.beam_size = beam_size
        self.max_decoding_length = max_decoding_length
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# Global instances initialized in lifespan
|
| 83 |
+
model_manager: Optional[ModelManager] = None
|
| 84 |
+
langid_executor: Optional[ProcessPoolExecutor] = None
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: build shared services on startup, tear down on exit.

    Order matters: the langid model must be fully present on disk in the
    main process *before* the ProcessPoolExecutor spawns its workers, so
    the workers never race each other on the download.
    """
    global model_manager, langid_executor

    model_manager = ModelManager(
        max_loaded=settings.max_loaded_models,
        device=settings.device,
        compute_type=settings.compute_type,
        inter_threads=settings.inter_threads,
        intra_threads=settings.intra_threads,
    )

    # 1. Fetch available models from Hugging Face
    await model_manager.fetch_hf_models()

    # 2. Ensure langid model is downloaded in main process before starting workers
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, ensure_model_exists, settings.langid_model_path)

    # Initialize langid process pool; each worker loads the model once via
    # init_worker.
    langid_executor = ProcessPoolExecutor(
        max_workers=settings.langid_workers,
        initializer=init_worker,
        initargs=(settings.langid_model_path,),
    )

    yield

    # Shutdown: stop the process pool, then unload all translation models.
    if langid_executor:
        langid_executor.shutdown()
    await model_manager.shutdown()
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
app = FastAPI(
|
| 121 |
+
title="quickmt Multi-Model API",
|
| 122 |
+
lifespan=lifespan,
|
| 123 |
+
default_response_class=ORJSONResponse,
|
| 124 |
+
)
|
| 125 |
+
api_router = APIRouter(prefix="/api")
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@api_router.post("/translate", response_model=TranslationResponse)
|
| 129 |
+
async def translate_endpoint(request: TranslationRequest):
|
| 130 |
+
if not model_manager:
|
| 131 |
+
raise HTTPException(status_code=503, detail="Model manager not initialized")
|
| 132 |
+
|
| 133 |
+
start_time = time.time()
|
| 134 |
+
src_list = [request.src] if isinstance(request.src, str) else request.src
|
| 135 |
+
if not src_list:
|
| 136 |
+
return TranslationResponse(
|
| 137 |
+
translation="" if isinstance(request.src, str) else [],
|
| 138 |
+
src_lang="" if isinstance(request.src, str) else [],
|
| 139 |
+
src_lang_score=0.0 if isinstance(request.src, str) else [],
|
| 140 |
+
tgt_lang=request.tgt_lang,
|
| 141 |
+
processing_time=time.time() - start_time,
|
| 142 |
+
model_used="none",
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
try:
|
| 146 |
+
loop = asyncio.get_running_loop()
|
| 147 |
+
|
| 148 |
+
# 1. Determine source languages and confidence scores
|
| 149 |
+
if request.src_lang:
|
| 150 |
+
if isinstance(request.src_lang, list):
|
| 151 |
+
if not isinstance(src_list, list) or len(request.src_lang) != len(
|
| 152 |
+
src_list
|
| 153 |
+
):
|
| 154 |
+
raise HTTPException(
|
| 155 |
+
status_code=422,
|
| 156 |
+
detail="src_lang list length must match src list length",
|
| 157 |
+
)
|
| 158 |
+
src_langs = request.src_lang
|
| 159 |
+
src_lang_scores = [1.0] * len(src_list)
|
| 160 |
+
else:
|
| 161 |
+
src_langs = [request.src_lang] * len(src_list)
|
| 162 |
+
src_lang_scores = [1.0] * len(src_list)
|
| 163 |
+
else:
|
| 164 |
+
if not langid_executor:
|
| 165 |
+
raise HTTPException(
|
| 166 |
+
status_code=503, detail="Language identification not initialized"
|
| 167 |
+
)
|
| 168 |
+
# Batch detect languages
|
| 169 |
+
raw_langid_results = await loop.run_in_executor(
|
| 170 |
+
langid_executor,
|
| 171 |
+
predict_worker,
|
| 172 |
+
src_list,
|
| 173 |
+
1, # k=1 (best guess)
|
| 174 |
+
0.0, # threshold
|
| 175 |
+
)
|
| 176 |
+
# results are List[List[Tuple[str, float]]], extract labels and scores
|
| 177 |
+
src_langs = [r[0][0] if r else "unknown" for r in raw_langid_results]
|
| 178 |
+
src_lang_scores = [float(r[0][1]) if r else 0.0 for r in raw_langid_results]
|
| 179 |
+
|
| 180 |
+
# 2. Group indices by source language
|
| 181 |
+
# groups: { "fr": [0, 2, ...], "es": [1, ...] }
|
| 182 |
+
groups: Dict[str, List[int]] = {}
|
| 183 |
+
for idx, lang in enumerate(src_langs):
|
| 184 |
+
if lang not in groups:
|
| 185 |
+
groups[lang] = []
|
| 186 |
+
groups[lang].append(idx)
|
| 187 |
+
|
| 188 |
+
# 3. Process each group
|
| 189 |
+
final_translations = [""] * len(src_list)
|
| 190 |
+
final_models = [""] * len(src_list)
|
| 191 |
+
tasks = []
|
| 192 |
+
|
| 193 |
+
# We need a way to track which lang pairs were actually used for the 'model_used' string
|
| 194 |
+
used_pairs = set()
|
| 195 |
+
|
| 196 |
+
for lang, indices in groups.items():
|
| 197 |
+
group_src = [src_list[i] for i in indices]
|
| 198 |
+
|
| 199 |
+
# Optimization: If src == tgt, skip translation
|
| 200 |
+
if lang == request.tgt_lang:
|
| 201 |
+
for src_idx, idx in enumerate(indices):
|
| 202 |
+
final_translations[idx] = group_src[src_idx]
|
| 203 |
+
final_models[idx] = "identity"
|
| 204 |
+
continue
|
| 205 |
+
|
| 206 |
+
# Load model and translate for this group
|
| 207 |
+
async def process_group_task(l=lang, i_list=indices, g_src=group_src):
|
| 208 |
+
try:
|
| 209 |
+
translator = await model_manager.get_model(l, request.tgt_lang)
|
| 210 |
+
used_pairs.add(translator.model_id)
|
| 211 |
+
# Call translate for each sentence; BatchTranslator will handle opportunistic batching
|
| 212 |
+
translation_tasks = [
|
| 213 |
+
translator.translate(
|
| 214 |
+
s,
|
| 215 |
+
src_lang=l,
|
| 216 |
+
tgt_lang=request.tgt_lang,
|
| 217 |
+
beam_size=request.beam_size,
|
| 218 |
+
patience=request.patience,
|
| 219 |
+
length_penalty=request.length_penalty,
|
| 220 |
+
coverage_penalty=request.coverage_penalty,
|
| 221 |
+
repetition_penalty=request.repetition_penalty,
|
| 222 |
+
max_decoding_length=request.max_decoding_length,
|
| 223 |
+
)
|
| 224 |
+
for s in g_src
|
| 225 |
+
]
|
| 226 |
+
results = await asyncio.gather(*translation_tasks)
|
| 227 |
+
for result_idx, original_idx in enumerate(i_list):
|
| 228 |
+
final_translations[original_idx] = results[result_idx]
|
| 229 |
+
final_models[original_idx] = translator.model_id
|
| 230 |
+
except HTTPException as e:
|
| 231 |
+
# If a specific model is missing, we could either fail the whole batch
|
| 232 |
+
# or keep original text. Here we fail for consistency with previous behavior.
|
| 233 |
+
raise e
|
| 234 |
+
except Exception as e:
|
| 235 |
+
logger.error(f"Error translating {l} to {request.tgt_lang}: {e}")
|
| 236 |
+
raise e
|
| 237 |
+
|
| 238 |
+
tasks.append(process_group_task())
|
| 239 |
+
|
| 240 |
+
if tasks:
|
| 241 |
+
await asyncio.gather(*tasks)
|
| 242 |
+
|
| 243 |
+
# 4. Prepare response
|
| 244 |
+
if isinstance(request.src, str):
|
| 245 |
+
result = final_translations[0]
|
| 246 |
+
src_lang_res = src_langs[0]
|
| 247 |
+
src_lang_score_res = src_lang_scores[0]
|
| 248 |
+
model_used_res = final_models[0]
|
| 249 |
+
else:
|
| 250 |
+
result = final_translations
|
| 251 |
+
src_lang_res = src_langs
|
| 252 |
+
src_lang_score_res = src_lang_scores
|
| 253 |
+
model_used_res = final_models
|
| 254 |
+
|
| 255 |
+
return TranslationResponse(
|
| 256 |
+
translation=result,
|
| 257 |
+
src_lang=src_lang_res,
|
| 258 |
+
src_lang_score=src_lang_score_res,
|
| 259 |
+
tgt_lang=request.tgt_lang,
|
| 260 |
+
processing_time=time.time() - start_time,
|
| 261 |
+
model_used=model_used_res,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
except HTTPException:
|
| 265 |
+
raise
|
| 266 |
+
except Exception as e:
|
| 267 |
+
logger.exception("Unexpected error in translate_endpoint")
|
| 268 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
@api_router.post("/identify-language", response_model=DetectionResponse)
|
| 272 |
+
async def identify_language_endpoint(request: DetectionRequest):
|
| 273 |
+
if not langid_executor:
|
| 274 |
+
raise HTTPException(
|
| 275 |
+
status_code=503, detail="Language identification not initialized"
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
start_time = time.time()
|
| 279 |
+
try:
|
| 280 |
+
loop = asyncio.get_running_loop()
|
| 281 |
+
# Offload detection to process pool to avoid GIL issues
|
| 282 |
+
raw_results = await loop.run_in_executor(
|
| 283 |
+
langid_executor, predict_worker, request.src, request.k, request.threshold
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
# Convert raw tuples to Pydantic models
|
| 287 |
+
if isinstance(request.src, str):
|
| 288 |
+
results = [
|
| 289 |
+
DetectionResult(lang=lang, score=score) for lang, score in raw_results
|
| 290 |
+
]
|
| 291 |
+
else:
|
| 292 |
+
results = [
|
| 293 |
+
[
|
| 294 |
+
DetectionResult(lang=lang, score=score)
|
| 295 |
+
for lang, score in item_results
|
| 296 |
+
]
|
| 297 |
+
for item_results in raw_results
|
| 298 |
+
]
|
| 299 |
+
|
| 300 |
+
return DetectionResponse(
|
| 301 |
+
results=results, processing_time=time.time() - start_time
|
| 302 |
+
)
|
| 303 |
+
except Exception as e:
|
| 304 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
@api_router.get("/models")
|
| 308 |
+
async def get_models():
|
| 309 |
+
if not model_manager:
|
| 310 |
+
raise HTTPException(status_code=503, detail="Model manager not initialized")
|
| 311 |
+
return {"models": model_manager.list_available_models()}
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
@api_router.get("/languages")
|
| 315 |
+
async def get_languages():
|
| 316 |
+
if not model_manager:
|
| 317 |
+
raise HTTPException(status_code=503, detail="Model manager not initialized")
|
| 318 |
+
return model_manager.get_language_pairs()
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
@api_router.get("/health")
|
| 322 |
+
async def health_check():
|
| 323 |
+
loaded_models = list(model_manager.models.keys()) if model_manager else []
|
| 324 |
+
return {
|
| 325 |
+
"status": "ok",
|
| 326 |
+
"loaded_models": loaded_models,
|
| 327 |
+
"max_models": settings.max_loaded_models,
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
app.include_router(api_router)

# Serve static files for the GUI.
# Mounted after the API router (and only if the directory exists) so that
# the /api routes keep precedence over the catch-all static mount.
static_dir = os.path.join(os.path.dirname(__file__), "gui", "static")
if os.path.exists(static_dir):
    app.mount("/", StaticFiles(directory=static_dir, html=True), name="static")
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def start(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Entry point for the quickmt-serve CLI.

    Args:
        host: Interface to bind the server to. Defaults to all interfaces.
        port: TCP port to listen on. Defaults to 8000.
    """
    # Imported lazily so merely importing this module does not require uvicorn.
    import uvicorn

    uvicorn.run("quickmt.rest_server:app", host=host, port=port, reload=False)
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def start_gui(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Entry point for the quickmt-gui CLI.

    Starts the REST server and opens the local GUI in the default web
    browser once the server has had a moment to bind.

    Args:
        host: Interface to bind the server to. Defaults to all interfaces.
        port: TCP port to listen on. Defaults to 8000.
    """
    # Imported lazily so merely importing this module does not require uvicorn.
    import threading
    import time
    import webbrowser

    import uvicorn

    def open_browser() -> None:
        # Give uvicorn a moment to start before pointing the browser at it.
        time.sleep(1.5)
        # Always browse via loopback; derive the URL from the bound port so
        # a non-default port still opens the right page.
        webbrowser.open(f"http://127.0.0.1:{port}")

    threading.Thread(target=open_browser, daemon=True).start()
    uvicorn.run("quickmt.rest_server:app", host=host, port=port, reload=False)
|
quickmt/settings.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Centralized configuration management using pydantic-settings.
|
| 2 |
+
|
| 3 |
+
This module provides a type-safe, centralized way to manage all configuration
|
| 4 |
+
settings for the quickmt library. Settings can be configured via:
|
| 5 |
+
- Environment variables (e.g., MAX_LOADED_MODELS=10)
|
| 6 |
+
- .env file in the project root
|
| 7 |
+
- Runtime modification of the global settings object
|
| 8 |
+
|
| 9 |
+
All environment variables are case-insensitive.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from typing import Optional
|
| 13 |
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Settings(BaseSettings):
    """Application settings with environment variable support.

    All settings can be overridden via environment variables.
    For example, to set max_loaded_models, use MAX_LOADED_MODELS=10
    """

    # Model Manager Settings
    max_loaded_models: int = 5
    """Maximum number of translation models to keep loaded in memory"""

    device: str = "cpu"
    """Device to use for inference: 'cpu', 'cuda', or 'auto'"""

    compute_type: str = "default"
    """CTranslate2 compute type: 'default', 'int8', 'int8_float16', 'int16', 'float16', 'float32'"""

    inter_threads: int = 1
    """Number of threads to use for inter-op parallelism (simultaneous translations)"""

    intra_threads: int = 4
    """Number of threads to use for intra-op parallelism (within each translation)"""

    # Batch Processing Settings
    max_batch_size: int = 32
    """Maximum batch size for translation requests"""

    batch_timeout_ms: int = 5
    """Timeout in milliseconds to wait for batching additional requests"""

    # Language Identification Settings
    langid_model_path: Optional[str] = None
    """Path to FastText language identification model. If None, uses default cache location"""

    langid_workers: int = 2
    """Number of worker processes for language identification"""

    # Translation Cache Settings
    translation_cache_size: int = 10000
    """Maximum number of translations to cache (LRU eviction)"""

    # Pydantic-settings configuration: read case-insensitive env vars and an
    # optional .env file, ignoring unrelated keys so unknown env entries
    # don't raise validation errors.
    model_config = SettingsConfigDict(
        env_prefix="",
        case_sensitive=False,
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )


# Global settings instance
# This can be imported and used throughout the application
# Settings can be modified at runtime: settings.max_loaded_models = 10
settings = Settings()
|
quickmt/translator.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from time import time
|
| 4 |
+
from typing import List, Optional, Union
|
| 5 |
+
|
| 6 |
+
import ctranslate2
|
| 7 |
+
import sentencepiece
|
| 8 |
+
from blingfire import text_to_sentences
|
| 9 |
+
from pydantic import DirectoryPath, validate_call
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TranslatorABC(ABC):
    """Base class for quickmt translators.

    Handles sentence splitting/joining and the high-level ``__call__``
    pipeline; subclasses supply tokenization, detokenization and the
    actual batch translation.
    """

    def __init__(self, model_path: DirectoryPath, **kwargs):
        """Create quickmt translation object

        Args:
            model_path (DirectoryPath): Path to quickmt model folder
            **kwargs: CTranslate2 Translator arguments - see https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html
        """
        self.model_path = Path(model_path)
        self.translator = ctranslate2.Translator(str(model_path), **kwargs)

    @staticmethod
    @validate_call
    def _sentence_split(src: List[str]):
        """Split sentences with Blingfire

        Args:
            src (List[str]): Input list of strings to split by sentences

        Returns:
            List[int], List[int], List[str]: List of input ids, list of paragraph ids and sentences
        """
        input_ids = []
        paragraph_ids = []
        sentences = []
        for idx, i in enumerate(src):
            for paragraph, j in enumerate(i.splitlines(keepends=True)):
                sents = text_to_sentences(j).splitlines()
                for sent in sents:
                    stripped_sent = sent.strip()
                    if len(stripped_sent) > 0:
                        # Merge very short fragments (< 5 chars) into the
                        # previous sentence of the same input/paragraph so
                        # tiny fragments are not translated in isolation.
                        if (
                            len(stripped_sent) < 5
                            and len(paragraph_ids) > 0
                            and paragraph == paragraph_ids[-1]
                            and len(input_ids) > 0
                            and input_ids[-1] == idx
                        ):
                            sentences[-1] += " " + stripped_sent
                        else:
                            input_ids.append(idx)
                            paragraph_ids.append(paragraph)
                            sentences.append(stripped_sent)

        return input_ids, paragraph_ids, sentences

    @staticmethod
    @validate_call
    def _sentence_join(
        input_ids: List[int],
        paragraph_ids: List[int],
        sentences: List[str],
        paragraph_join_str: str = "\n",
        sent_join_str: str = " ",
        length: Optional[int] = None,
    ):
        """Sentence joiner

        Args:
            input_ids (List[int]): List of input IDs
            paragraph_ids (List[int]): List of paragraph IDs
            sentences (List[str]): List of sentences to join up by input and paragraph ids
            paragraph_join_str (str, optional): str to use to join paragraphs. Defaults to "\n".
            sent_join_str (str, optional): str to join up sentences. Defaults to " ".
            length (Optional[int], optional): Total number of outputs to
                produce; inputs that yielded no sentences become "". If
                None, inferred from the largest input id. Defaults to None.

        Returns:
            List[str]: Joined up sentences
        """
        if not input_ids:
            return [""] * (length or 0)

        target_len = length if length is not None else (max(input_ids) + 1)
        ret = [""] * target_len
        last_paragraph = 0
        for idx, paragraph, text in zip(input_ids, paragraph_ids, sentences):
            if len(ret[idx]) > 0:
                if paragraph == last_paragraph:
                    ret[idx] += sent_join_str + text
                else:
                    ret[idx] += paragraph_join_str + text
                    last_paragraph = paragraph
            else:
                ret[idx] = text
                last_paragraph = paragraph
        return ret

    @abstractmethod
    def tokenize(
        self,
        sentences: List[str],
        src_lang: Optional[str] = None,
        tgt_lang: Optional[str] = None,
    ): ...

    @abstractmethod
    def detokenize(
        self,
        sentences: List[List[str]],
        src_lang: Optional[str] = None,
        tgt_lang: Optional[str] = None,
    ): ...

    @abstractmethod
    def translate_batch(
        self,
        sentences: List[List[str]],
        src_lang: Optional[str] = None,
        tgt_lang: Optional[str] = None,
    ): ...

    @abstractmethod
    def unload(self): ...

    @validate_call
    def __call__(
        self,
        src: Union[str, List[str]],
        max_batch_size: int = 32,
        max_decoding_length: int = 256,
        beam_size: int = 2,
        patience: int = 1,
        length_penalty: float = 1.0,
        coverage_penalty: float = 0.0,
        repetition_penalty: float = 1.0,
        verbose: bool = False,
        src_lang: Union[None, str] = None,
        tgt_lang: Union[None, str] = None,
        **kwargs,
    ) -> Union[str, List[str]]:
        """Translate a list of strings with quickmt model

        Args:
            src (Union[str, List[str]]): Input string or list of strings to translate
            max_batch_size (int, optional): Maximum batch size, to constrain RAM utilization. Defaults to 32.
            beam_size (int, optional): CTranslate2 Beam size. Defaults to 2.
            patience (int, optional): CTranslate2 Patience. Defaults to 1.
            max_decoding_length (int, optional): Maximum length of translation. Defaults to 256.
            **kwargs: Other CTranslate2 translate_batch args, see https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html#ctranslate2.Translator.translate_batch

        Returns:
            Union[str, List[str]]: Translation of the input; a string is
            returned iff ``src`` was a string.
        """
        if isinstance(src, str):
            return_string = True
            src = [src]
        else:
            return_string = False

        indices, paragraphs, sentences = self._sentence_split(src)

        # Whitespace-only input produces no sentences; short-circuit.
        if not sentences:
            return "" if return_string else [""] * len(src)

        if verbose:
            print(f"Split sentences: {sentences}")

        input_text = self.tokenize(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
        if verbose:
            print(f"Tokenized input: {input_text}")

        t1 = time()
        results = self.translate_batch(
            input_text,
            beam_size=beam_size,
            patience=patience,
            length_penalty=length_penalty,
            coverage_penalty=coverage_penalty,
            repetition_penalty=repetition_penalty,
            max_decoding_length=max_decoding_length,
            max_batch_size=max_batch_size,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
            **kwargs,
        )
        t2 = time()
        if verbose:
            print(f"Translation time: {t2 - t1}")

        # Keep only the best hypothesis for each sentence.
        output_tokens = [i.hypotheses[0] for i in results]

        if verbose:
            print(f"Tokenized output: {output_tokens}")

        translated_sents = self.detokenize(
            output_tokens, src_lang=src_lang, tgt_lang=tgt_lang
        )

        # Reassemble translated sentences back into the original inputs,
        # preserving paragraph structure and input count.
        ret = self._sentence_join(
            indices, paragraphs, translated_sents, length=len(src)
        )

        if return_string:
            return ret[0]
        else:
            return ret

    @validate_call
    def translate_file(self, input_file: str, output_file: str, **kwargs) -> None:
        """Translate a plain-text file with a quickmt model

        Args:
            input_file (str): Path to plain-text file to translate
            output_file (str): Path to write the translated text to
            **kwargs: Arguments forwarded to ``__call__``
        """
        # Read/write explicitly as UTF-8 so results don't depend on the
        # platform's locale-dependent default encoding.
        with open(input_file, "rt", encoding="utf-8") as myfile:
            src = myfile.readlines()

        # Remove newlines
        src = [i.strip() for i in src]

        # Translate
        mt = self(src, **kwargs)

        # Replace newlines to ensure output is the same number of lines
        mt = [i.replace("\n", "\t") for i in mt]

        with open(output_file, "wt", encoding="utf-8") as myfile:
            myfile.write("".join([i + "\n" for i in mt]))

    @validate_call
    def translate_stream(
        self,
        src: Union[str, List[str]],
        max_batch_size: int = 32,
        max_decoding_length: int = 256,
        beam_size: int = 5,
        patience: int = 1,
        src_lang: Union[None, str] = None,
        tgt_lang: Union[None, str] = None,
        **kwargs,
    ):
        """Translate and yield results sentence-by-sentence as they complete

        Args:
            src (Union[str, List[str]]): Input string or list of strings to translate
            max_batch_size (int, optional): Maximum batch size, to constrain RAM utilization. Defaults to 32.
            beam_size (int, optional): CTranslate2 Beam size. Defaults to 5.
            patience (int, optional): CTranslate2 Patience. Defaults to 1.
            max_decoding_length (int, optional): Maximum length of translation. Defaults to 256.
            **kwargs: Other CTranslate2 translate_batch args, see https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html#ctranslate2.Translator.translate_batch

        Yields:
            dict: input_idx, sentence_idx, input_text and translation for
            each translated sentence.
        """
        if isinstance(src, str):
            src = [src]

        indices, paragraphs, sentences = self._sentence_split(src)

        input_text = self.tokenize(sentences, src_lang=src_lang, tgt_lang=tgt_lang)

        translations_iterator = self.translator.translate_iterable(
            input_text,
            beam_size=beam_size,
            patience=patience,
            max_decoding_length=max_decoding_length,
            max_batch_size=max_batch_size,
            **kwargs,
        )

        for idx, para, sent, output in zip(
            indices, paragraphs, sentences, translations_iterator
        ):
            yield {
                "input_idx": idx,
                "sentence_idx": para,
                "input_text": sent,
                "translation": self.detokenize([output.hypotheses[0]])[0],
            }
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class Translator(TranslatorABC):
    """SentencePiece-based quickmt translator backed by CTranslate2."""

    def __init__(
        self,
        model_path: DirectoryPath,
        inter_threads: int = 1,
        intra_threads: int = 0,
        **kwargs,
    ):
        """Create quickmt translation object

        Args:
            model_path (DirectoryPath): Path to quickmt model folder
            inter_threads (int): Number of simultaneous translations
            intra_threads (int): Number of threads for each translation
            **kwargs: CTranslate2 Translator arguments - see https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html
        """
        super().__init__(
            model_path,
            inter_threads=inter_threads,
            intra_threads=intra_threads,
            **kwargs,
        )
        # Prefer a shared (joint) SentencePiece vocabulary if the model ships
        # one; otherwise fall back to separate source/target models.
        joint_tokenizer_path = self.model_path / "joint.spm.model"
        if joint_tokenizer_path.exists():
            self.source_tokenizer = sentencepiece.SentencePieceProcessor(
                model_file=str(self.model_path / "joint.spm.model")
            )
            self.target_tokenizer = sentencepiece.SentencePieceProcessor(
                model_file=str(self.model_path / "joint.spm.model")
            )
        else:
            self.source_tokenizer = sentencepiece.SentencePieceProcessor(
                model_file=str(self.model_path / "src.spm.model")
            )
            self.target_tokenizer = sentencepiece.SentencePieceProcessor(
                model_file=str(self.model_path / "tgt.spm.model")
            )

    def __del__(self):
        # Best-effort release of CTranslate2 resources on garbage collection.
        self.unload()

    def tokenize(
        self,
        sentences: List[str],
        src_lang: Optional[str] = None,
        tgt_lang: Optional[str] = None,
    ):
        """Encode sentences to SentencePiece tokens, appending </s> to each."""
        # Default implementation ignores lang tags unless explicitly handled
        return [
            i + ["</s>"] for i in self.source_tokenizer.encode(sentences, out_type=str)
        ]

    def detokenize(
        self,
        sentences: List[List[str]],
        src_lang: Optional[str] = None,
        tgt_lang: Optional[str] = None,
    ):
        """Decode SentencePiece token lists back into plain strings."""
        return self.target_tokenizer.decode(sentences)

    def unload(self):
        """Explicitly release CTranslate2 translator resources"""
        # hasattr guard: __init__ may have failed before self.translator existed.
        if hasattr(self, "translator"):
            del self.translator

    def translate_batch(
        self,
        input_text: List[List[str]],
        beam_size: int = 5,
        patience: int = 1,
        max_decoding_length: int = 256,
        max_batch_size: int = 32,
        disable_unk: bool = True,
        replace_unknowns: bool = False,
        length_penalty: float = 1.0,
        coverage_penalty: float = 0.0,
        repetition_penalty: float = 1.0,
        src_lang: Optional[str] = None,
        tgt_lang: Optional[str] = None,
        **kwargs,
    ):
        """Translate a list of strings

        Args:
            input_text (List[List[str]]): Input text to be translated
            beam_size (int, optional): Beam size for beam search. Defaults to 5.
            patience (int, optional): Stop beam search when `patience` beams finish. Defaults to 1.
            max_decoding_length (int, optional): Max decoding length for model. Defaults to 256.
            max_batch_size (int, optional): Max batch size. Reduce to limit RAM usage. Increase for faster speed. Defaults to 32.
            disable_unk (bool, optional): Disable generating unk token. Defaults to True.
            replace_unknowns (bool, optional): Replace unk tokens with src token that has the highest attention value. Defaults to False.
            length_penalty (float, optional): Length penalty. Defaults to 1.0.
            coverage_penalty (float, optional): Coverage penalty. Defaults to 0.0.
            src_lang (str, optional): Source language. Only needed for multilingual models. Defaults to None.
            tgt_lang (str, optional): Target language. Only needed for multilingual models. Defaults to None.

        Returns:
            List[str]: Translated text
        """
        # NOTE: src_lang/tgt_lang are accepted for interface compatibility but
        # are not forwarded to CTranslate2 here.
        return self.translator.translate_batch(
            input_text,
            beam_size=beam_size,
            patience=patience,
            max_decoding_length=max_decoding_length,
            max_batch_size=max_batch_size,
            disable_unk=disable_unk,
            replace_unknowns=replace_unknowns,
            length_penalty=length_penalty,
            coverage_penalty=coverage_penalty,
            repetition_penalty=repetition_penalty,
            **kwargs,
        )
|
requirements-dev.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pytest
|
| 2 |
+
pytest-asyncio
|
| 3 |
+
httpx
|
| 4 |
+
locust
|
| 5 |
+
sacrebleu
|
requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
blingfire
|
| 2 |
+
cachetools
|
| 3 |
+
fastapi
|
| 4 |
+
uvicorn[standard]
|
| 5 |
+
ctranslate2
|
| 6 |
+
sentencepiece
|
| 7 |
+
huggingface_hub
|
| 8 |
+
fasttext-wheel
|
| 9 |
+
orjson
|
| 10 |
+
uvloop
|
| 11 |
+
httptools
|
| 12 |
+
pydantic
|
| 13 |
+
pydantic-settings
|
| 14 |
+
|
tests/__init__.py
ADDED
|
File without changes
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
import os
|
| 3 |
+
from typing import AsyncGenerator
|
| 4 |
+
from quickmt.rest_server import app
|
| 5 |
+
from httpx import AsyncClient
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@pytest.fixture(scope="session")
|
| 9 |
+
def base_url() -> str:
|
| 10 |
+
return os.getenv("TEST_BASE_URL", "http://127.0.0.1:8000")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@pytest.fixture
|
| 14 |
+
async def client(base_url: str) -> AsyncGenerator[AsyncClient, None]:
|
| 15 |
+
async with AsyncClient(base_url=base_url, timeout=60.0) as client:
|
| 16 |
+
yield client
|
tests/test_api.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
import asyncio
|
| 3 |
+
from httpx import AsyncClient
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@pytest.mark.asyncio
async def test_health_check(client: AsyncClient):
    """The health endpoint responds ok and reports loaded models."""
    response = await client.get("/api/health")
    assert response.status_code == 200
    data = response.json()
    assert data["status"] == "ok"
    assert "loaded_models" in data


@pytest.mark.asyncio
async def test_get_models(client: AsyncClient):
    """The models endpoint returns a list under the 'models' key."""
    response = await client.get("/api/models")
    assert response.status_code == 200
    data = response.json()
    assert "models" in data
    assert isinstance(data["models"], list)


@pytest.mark.asyncio
async def test_get_languages(client: AsyncClient):
    """The languages endpoint returns a mapping of language lists."""
    response = await client.get("/api/languages")
    assert response.status_code == 200
    data = response.json()
    assert isinstance(data, dict)
    # Check structure if models exist
    if data:
        src = list(data.keys())[0]
        assert isinstance(data[src], list)


@pytest.mark.asyncio
async def test_translate_single(client: AsyncClient):
    """Translating a single string echoes back metadata for the model used."""
    # First, find an available model
    models_res = await client.get("/api/models")
    models = models_res.json()["models"]
    if not models:
        pytest.skip("No models available in MODELS_DIR")

    model = models[0]
    payload = {
        "src": "Hello world",
        "src_lang": model["src_lang"],
        "tgt_lang": model["tgt_lang"],
    }

    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()
    assert "translation" in data
    assert "processing_time" in data
    assert data["src_lang"] == model["src_lang"]
    # Explicit src_lang means detection is skipped and the score is 1.0
    assert data["src_lang_score"] == 1.0
    assert data["tgt_lang"] == model["tgt_lang"]
    assert data["model_used"] == model["model_id"]


@pytest.mark.asyncio
async def test_translate_list(client: AsyncClient):
    """Translating a list returns per-item translations and metadata lists."""
    models_res = await client.get("/api/models")
    models = models_res.json()["models"]
    if not models:
        pytest.skip("No models available")

    model = models[0]
    payload = {
        "src": ["Hello", "World"],
        "src_lang": model["src_lang"],
        "tgt_lang": model["tgt_lang"],
    }

    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()
    assert isinstance(data["translation"], list)
    assert len(data["translation"]) == 2
    assert data["src_lang"] == [model["src_lang"], model["src_lang"]]
    assert data["src_lang_score"] == [1.0, 1.0]
    assert data["tgt_lang"] == model["tgt_lang"]
    assert data["model_used"] == [model["model_id"], model["model_id"]]


@pytest.mark.asyncio
async def test_dynamic_batching(client: AsyncClient):
    """Verify that multiple concurrent requests work correctly (triggering batching logic)."""
    models_res = await client.get("/api/models")
    models = models_res.json()["models"]
    if not models:
        pytest.skip("No models available")

    model = models[0]
    src, tgt = model["src_lang"], model["tgt_lang"]

    texts = [f"Sentence number {i}" for i in range(5)]
    tasks = []

    # Fire all requests concurrently so the server can batch them together.
    for text in texts:
        payload = {"src": text, "src_lang": src, "tgt_lang": tgt}
        tasks.append(client.post("/api/translate", json=payload))

    responses = await asyncio.gather(*tasks)

    for response in responses:
        assert response.status_code == 200
        assert "translation" in response.json()
|
tests/test_auto_translate.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for automatic source-language detection and default-target behaviour
of the /api/translate endpoint."""

import pytest
from httpx import AsyncClient


@pytest.mark.asyncio
async def test_auto_detect_src_lang(client: AsyncClient):
    """Verify that src_lang is auto-detected if missing."""
    # Ensure some models are available
    models_res = await client.get("/api/models")
    available_models = models_res.json()["models"]
    if not any(
        m["src_lang"] == "fr" and m["tgt_lang"] == "en" for m in available_models
    ):
        pytest.skip("fr-en model needed for this test")

    payload = {"src": "Bonjour tout le monde", "tgt_lang": "en"}
    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()
    assert "translation" in data
    # Detected language comes with a confidence score in (0, 1].
    assert data["src_lang"] == "fr"
    assert 0.0 < data["src_lang_score"] <= 1.0
    assert data["tgt_lang"] == "en"
    assert "quickmt/quickmt-fr-en" in data["model_used"]


@pytest.mark.asyncio
async def test_default_tgt_lang(client: AsyncClient):
    """Verify that tgt_lang defaults to 'en'."""
    models_res = await client.get("/api/models")
    available_models = models_res.json()["models"]
    if not any(
        m["src_lang"] == "fr" and m["tgt_lang"] == "en" for m in available_models
    ):
        pytest.skip("fr-en model needed for this test")

    payload = {"src": "Bonjour", "src_lang": "fr"}
    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()
    # Explicitly-provided src_lang is trusted with score 1.0.
    assert data["src_lang"] == "fr"
    assert data["src_lang_score"] == 1.0
    assert data["tgt_lang"] == "en"
    assert "quickmt/quickmt-fr-en" in data["model_used"]


@pytest.mark.asyncio
async def test_mixed_language_batch(client: AsyncClient):
    """Verify that a batch with mixed languages is handled correctly."""
    models_res = await client.get("/api/models")
    available_models = models_res.json()["models"]

    needed = [("fr", "en"), ("es", "en")]
    for src, tgt in needed:
        if not any(
            m["src_lang"] == src and m["tgt_lang"] == tgt for m in available_models
        ):
            # Plain string: was an f-string with no placeholders (F541).
            pytest.skip("Mixed batch test needs both fr-en and es-en models")

    payload = {"src": ["Bonjour tout le monde", "Hola amigos"], "tgt_lang": "en"}
    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()
    assert isinstance(data["translation"], list)
    assert len(data["translation"]) == 2
    assert data["src_lang"] == ["fr", "es"]
    assert len(data["src_lang_score"]) == 2
    assert all(0.0 < s <= 1.0 for s in data["src_lang_score"])
    assert data["tgt_lang"] == "en"
    assert "quickmt/quickmt-fr-en" in data["model_used"]
    assert "quickmt/quickmt-es-en" in data["model_used"]


@pytest.mark.asyncio
async def test_identity_translation(client: AsyncClient):
    """Verify that translation is skipped if src_lang == tgt_lang."""
    payload = {"src": "This is already English", "src_lang": "en", "tgt_lang": "en"}
    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()
    # Identity path: text is passed through untouched, no model invoked.
    assert data["translation"] == "This is already English"
    assert data["src_lang"] == "en"
    assert data["src_lang_score"] == 1.0
    assert data["tgt_lang"] == "en"
    assert data["model_used"] == "identity"


@pytest.mark.asyncio
async def test_auto_detect_mixed_identity(client: AsyncClient):
    """Verify mixed batch with some items needing translation and some remaining as-is."""
    models_res = await client.get("/api/models")
    available_models = models_res.json()["models"]
    if not any(
        m["src_lang"] == "fr" and m["tgt_lang"] == "en" for m in available_models
    ):
        pytest.skip("fr-en model needed for this test")

    payload = {"src": ["Bonjour", "Hello world"], "tgt_lang": "en"}
    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()
    assert len(data["translation"]) == 2
    assert data["src_lang"] == ["fr", "en"]
    assert len(data["src_lang_score"]) == 2
    # First should be auto-detected, second should be auto-detected (and high confidence)
    assert all(0.0 < s <= 1.0 for s in data["src_lang_score"])
    assert data["tgt_lang"] == "en"
    assert data["model_used"] == ["quickmt/quickmt-fr-en", "identity"]
|
tests/test_cache.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple test to demonstrate the translation cache functionality.
|
| 3 |
+
Run this to verify cache hits provide instant responses.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import asyncio
|
| 7 |
+
import time
|
| 8 |
+
from quickmt.manager import BatchTranslator
|
| 9 |
+
from quickmt.settings import settings
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
async def test_translation_cache():
|
| 13 |
+
print("=== Translation Cache Test ===\n")
|
| 14 |
+
|
| 15 |
+
# Create a mock BatchTranslator (would normally be created by ModelManager)
|
| 16 |
+
# For this test, we'll just verify the cache mechanism
|
| 17 |
+
print(f"Cache size configured: {settings.translation_cache_size}")
|
| 18 |
+
|
| 19 |
+
# Simulate cache behavior
|
| 20 |
+
from cachetools import LRUCache
|
| 21 |
+
|
| 22 |
+
cache = LRUCache(maxsize=settings.translation_cache_size)
|
| 23 |
+
|
| 24 |
+
# Test data
|
| 25 |
+
test_text = "Hello, world!"
|
| 26 |
+
src_lang = "en"
|
| 27 |
+
tgt_lang = "fr"
|
| 28 |
+
kwargs_tuple = tuple(sorted({"beam_size": 5, "patience": 1}.items()))
|
| 29 |
+
|
| 30 |
+
cache_key = (test_text, src_lang, tgt_lang, kwargs_tuple)
|
| 31 |
+
|
| 32 |
+
# First request - cache miss
|
| 33 |
+
print("\n1. First translation (cache miss):")
|
| 34 |
+
print(f" Key: {cache_key}")
|
| 35 |
+
if cache_key in cache:
|
| 36 |
+
print(" ✓ Cache HIT")
|
| 37 |
+
else:
|
| 38 |
+
print(" ✗ Cache MISS (expected)")
|
| 39 |
+
# Simulate translation and caching
|
| 40 |
+
cache[cache_key] = "Bonjour, monde!"
|
| 41 |
+
print(" → Cached result")
|
| 42 |
+
|
| 43 |
+
# Second request - cache hit
|
| 44 |
+
print("\n2. Repeated translation (cache hit):")
|
| 45 |
+
print(f" Key: {cache_key}")
|
| 46 |
+
if cache_key in cache:
|
| 47 |
+
print(" ✓ Cache HIT (instant!)")
|
| 48 |
+
print(f" → Result: {cache[cache_key]}")
|
| 49 |
+
else:
|
| 50 |
+
print(" ✗ Cache MISS (unexpected)")
|
| 51 |
+
|
| 52 |
+
# Different parameters - cache miss
|
| 53 |
+
different_kwargs = tuple(sorted({"beam_size": 10, "patience": 2}.items()))
|
| 54 |
+
different_key = (test_text, src_lang, tgt_lang, different_kwargs)
|
| 55 |
+
|
| 56 |
+
print("\n3. Same text, different parameters (cache miss):")
|
| 57 |
+
print(f" Key: {different_key}")
|
| 58 |
+
if different_key in cache:
|
| 59 |
+
print(" ✓ Cache HIT")
|
| 60 |
+
else:
|
| 61 |
+
print(" ✗ Cache MISS (expected - different params)")
|
| 62 |
+
|
| 63 |
+
print("\n✅ Cache test complete!")
|
| 64 |
+
print(f"Cache size: {len(cache)}/{settings.translation_cache_size}")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
if __name__ == "__main__":
|
| 68 |
+
asyncio.run(test_translation_cache())
|
tests/test_identify_language.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
from httpx import AsyncClient


@pytest.mark.asyncio
async def test_identify_language_single(client: AsyncClient):
    """Verify single string language identification."""
    res = await client.post(
        "/api/identify-language", json={"src": "Hello, how are you?", "k": 1}
    )
    assert res.status_code == 200
    body = res.json()
    assert "results" in body
    assert "processing_time" in body
    assert isinstance(body["results"], list)
    # FastText should identify this as English
    assert body["results"][0]["lang"] == "en"


@pytest.mark.asyncio
async def test_identify_language_batch(client: AsyncClient):
    """Verify batch language identification."""
    res = await client.post(
        "/api/identify-language",
        json={"src": ["Bonjour tout le monde", "Hola amigos"], "k": 1},
    )
    assert res.status_code == 200
    body = res.json()
    assert len(body["results"]) == 2
    # Top prediction per input: French, then Spanish.
    assert body["results"][0][0]["lang"] == "fr"
    assert body["results"][1][0]["lang"] == "es"


@pytest.mark.asyncio
async def test_identify_language_threshold(client: AsyncClient):
    """Verify threshold filtering in the endpoint."""
    res = await client.post(
        "/api/identify-language",
        json={"src": "This is definitely English", "k": 5, "threshold": 0.9},
    )
    assert res.status_code == 200
    body = res.json()
    # Only 'en' should probably be above 0.9
    assert len(body["results"]) == 1
    assert body["results"][0]["lang"] == "en"
|
tests/test_langid.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
from unittest.mock import MagicMock, patch
from pathlib import Path
from quickmt.langid import LanguageIdentification

@pytest.fixture
def mock_fasttext():
    """Patch fasttext.load_model with a MagicMock whose predict() returns
    '__label__en' with score 0.9 for every input item."""
    with patch("fasttext.load_model") as mock_load:
        mock_model = MagicMock()
        mock_load.return_value = mock_model

        # Configure default behavior for predict
        def mock_predict(items, k=1, threshold=0.0):
            # Return ([['__label__en', ...]], [[0.9, ...]])
            labels = [["__label__en"] * k for _ in items]
            scores = [[0.9] * k for _ in items]
            return labels, scores

        mock_model.predict.side_effect = mock_predict
        # Yield (not return) so the patch stays active for the whole test.
        yield mock_model

@pytest.fixture
def langid_model(mock_fasttext, tmp_path):
    """Build a LanguageIdentification backed by the mocked fasttext model."""
    # Create a dummy model file so the existence check passes
    model_path = tmp_path / "model.bin"
    model_path.write_text("dummy content")
    return LanguageIdentification(model_path)


def test_predict_single(langid_model, mock_fasttext):
    # A single string is expected to be wrapped in a list before fasttext.
    result = langid_model.predict("Hello world")

    assert isinstance(result, list)
    assert len(result) == 1
    # The test expects the '__label__' prefix to be stripped from results.
    assert result[0] == ("en", 0.9)
    mock_fasttext.predict.assert_called_once_with(["Hello world"], k=1, threshold=0.0)

def test_predict_batch(langid_model, mock_fasttext):
    texts = ["Hello", "Bonjour"]
    results = langid_model.predict(texts, k=2)

    assert isinstance(results, list)
    assert len(results) == 2
    for r in results:
        assert len(r) == 2
        assert r[0] == ("en", 0.9)

    # The whole batch must be forwarded to fasttext in one call.
    mock_fasttext.predict.assert_called_once_with(texts, k=2, threshold=0.0)

def test_predict_best_single(langid_model):
    result = langid_model.predict_best("Hello")
    assert result == "en"

def test_predict_best_batch(langid_model):
    results = langid_model.predict_best(["Hello", "World"])
    assert results == ["en", "en"]

def test_predict_threshold(langid_model, mock_fasttext):
    # Configure mock to return nothing if threshold is high (simulated)
    def mock_predict_low_score(items, k=1, threshold=0.0):
        if threshold > 0.9:
            return [[] for _ in items], [[] for _ in items]
        return [["__label__en"] for _ in items], [[0.9] for _ in items]

    mock_fasttext.predict.side_effect = mock_predict_low_score

    # Above the mocked 0.9 score: no prediction survives the threshold.
    result = langid_model.predict_best("Hello", threshold=0.95)
    assert result is None

    # Below the mocked score: the best label is returned.
    result = langid_model.predict_best("Hello", threshold=0.5)
    assert result == "en"
|
tests/test_langid_batch.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
from httpx import AsyncClient


@pytest.mark.asyncio
async def test_langid_batch(client: AsyncClient):
    """Verify that language identification works for a list of strings."""
    res = await client.post(
        "/api/identify-language",
        json={"src": ["This is English text.", "Ceci est un texte français."]},
    )
    assert res.status_code == 200
    detections = res.json()["results"]

    # Expect a list of lists of DetectionResult
    assert isinstance(detections, list)
    assert len(detections) == 2

    # First item: English
    assert len(detections[0]) >= 1
    assert detections[0][0]["lang"] == "en"

    # Second item: French
    assert len(detections[1]) >= 1
    assert detections[1][0]["lang"] == "fr"


@pytest.mark.asyncio
async def test_langid_newline_handling(client: AsyncClient):
    """Verify that inputs with newlines are handled gracefully (no 500 error)."""
    # Single string with newline
    res = await client.post(
        "/api/identify-language", json={"src": "This text\nhas a newline."}
    )
    assert res.status_code == 200
    assert res.json()["results"][0]["lang"] == "en"

    # Batch with newlines
    res = await client.post(
        "/api/identify-language",
        json={"src": ["Line 1\nLine 2", "Another\nline"]},
    )
    assert res.status_code == 200
    assert len(res.json()["results"]) == 2
|
tests/test_langid_path.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
import os
from unittest.mock import patch
from quickmt.langid import ensure_model_exists, LanguageIdentification

def test_langid_default_path():
    """Verify that LanguageIdentification uses the XDG cache path by default."""
    # Mock os.getenv to ensure we test the default behavior, but respect XDG_CACHE_HOME if we want to mock it.
    # Here we simulate no explicit model path provided.

    # NOTE(review): patching "pathlib.Path.exists"/"Path.mkdir" is global for
    # the duration of the test — it affects every Path in the process, not
    # just langid's; keep patched work minimal inside this block.
    with patch("quickmt.langid.fasttext.load_model") as mock_load, \
        patch("quickmt.langid.urllib.request.urlretrieve") as mock_retrieve, \
        patch("pathlib.Path.exists") as mock_exists, \
        patch("pathlib.Path.mkdir") as mock_mkdir:

        # Simulate model cached and exists
        mock_exists.return_value = True

        lid = LanguageIdentification(model_path=None)

        # Verify load_model was called with a path in the cache
        args, _ = mock_load.call_args
        loaded_path = str(args[0])

        # os.path.join keeps the assertion portable across separators.
        expected_part = os.path.join(".cache", "fasttext_language_id", "lid.176.bin")
        assert expected_part in loaded_path

        # Old path should not be used
        assert "models/lid.176.ftz" not in loaded_path

def test_ensure_model_exists_path():
    """Verify ensure_model_exists resolves to cache path."""
    with patch("quickmt.langid.urllib.request.urlretrieve") as mock_retrieve, \
        patch("pathlib.Path.exists") as mock_exists, \
        patch("pathlib.Path.mkdir") as mock_mkdir:

        # Simulate model missing to trigger download logic path check
        mock_exists.return_value = False

        ensure_model_exists(None)

        # Check download target (second positional arg of urlretrieve)
        args, _ = mock_retrieve.call_args
        download_target = str(args[1])

        expected_part = os.path.join(".cache", "fasttext_language_id", "lid.176.bin")
        assert expected_part in download_target
|
tests/test_lru.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
from httpx import AsyncClient


@pytest.mark.asyncio
async def test_lru_eviction(client: AsyncClient):
    """
    Test that the server correctly unloads the least recently used model
    when MAX_LOADED_MODELS is exceeded.
    """
    # 1. Get available models
    catalog = (await client.get("/api/models")).json()["models"]

    # 2. Get MAX_LOADED_MODELS from health
    capacity = (await client.get("/api/health")).json()["max_models"]

    if len(catalog) <= capacity:
        pytest.skip(
            f"Not enough models in MODELS_DIR to test eviction (need > {capacity})"
        )

    # 3. Load capacity + 1 models sequentially, recording usage order
    usage_order = []
    for entry in catalog[: capacity + 1]:
        await client.post(
            "/api/translate",
            json={
                "src": "test",
                "src_lang": entry["src_lang"],
                "tgt_lang": entry["tgt_lang"],
            },
        )
        usage_order.append(f"{entry['src_lang']}-{entry['tgt_lang']}")

    # 4. Check currently loaded models
    loaded_now = (await client.get("/api/health")).json()["loaded_models"]

    # The first (least recently used) model should have been evicted
    assert usage_order[0] not in loaded_now
    assert len(loaded_now) == capacity

    # The most recently requested model should be there
    assert usage_order[-1] in loaded_now
|
tests/test_manager.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
import asyncio
from pathlib import Path
from unittest.mock import MagicMock, patch, AsyncMock
from quickmt.manager import ModelManager, BatchTranslator
from quickmt.translator import Translator

@pytest.fixture
def mock_translator():
    """Replace quickmt.manager.Translator with a MagicMock instance."""
    with patch("quickmt.manager.Translator") as mock:
        instance = MagicMock()
        mock.return_value = instance
        yield instance

@pytest.fixture
def mock_hf():
    """Stub Hugging Face Hub access: collection listing and snapshot download."""
    with patch("quickmt.manager.snapshot_download") as mock_dl, \
        patch("quickmt.manager.HfApi") as mock_api:

        # Mock collection fetch
        coll = MagicMock()
        coll.items = [
            MagicMock(item_id="quickmt/quickmt-en-fr", item_type="model"),
            MagicMock(item_id="quickmt/quickmt-fr-en", item_type="model")
        ]
        mock_api.return_value.get_collection.return_value = coll
        mock_dl.return_value = "/tmp/mock-model-path"

        yield mock_api, mock_dl

class TestBatchTranslator:
    @pytest.mark.asyncio
    async def test_translate_single(self, mock_translator):
        bt = BatchTranslator("test-id", "/tmp/path")

        # Mock translator call
        mock_translator.return_value = "Hola"

        result = await bt.translate("Hello", src_lang="en", tgt_lang="es")
        assert result == "Hola"
        # The first translate() call is expected to spawn the batch worker.
        assert bt.worker_task is not None

        await bt.stop_worker()
        assert bt.worker_task is None

class TestModelManager:
    @pytest.mark.asyncio
    async def test_fetch_hf_models(self, mock_hf):
        mm = ModelManager(max_loaded=2, device="cpu")
        await mm.fetch_hf_models()

        # The language pair is presumably parsed from the
        # "quickmt-<src>-<tgt>" model id — confirm in ModelManager.
        assert len(mm.hf_collection_models) == 2
        assert mm.hf_collection_models[0]["src_lang"] == "en"
        assert mm.hf_collection_models[0]["tgt_lang"] == "fr"

    @pytest.mark.asyncio
    async def test_get_model_lazy_load(self, mock_hf, mock_translator):
        mm = ModelManager(max_loaded=2, device="cpu")
        await mm.fetch_hf_models()

        # This should trigger download and start worker
        bt = await mm.get_model("en", "fr")
        assert isinstance(bt, BatchTranslator)
        assert "en-fr" in mm.models
        assert bt.model_id == "quickmt/quickmt-en-fr"

    @pytest.mark.asyncio
    async def test_lru_eviction(self, mock_hf, mock_translator):
        # Set max_loaded to 1 to trigger eviction immediately
        mm = ModelManager(max_loaded=1, device="cpu")
        await mm.fetch_hf_models()

        # Load first
        bt1 = await mm.get_model("en", "fr")
        assert len(mm.models) == 1

        # Load second (should evict first)
        bt2 = await mm.get_model("fr", "en")
        assert len(mm.models) == 1
        assert "fr-en" in mm.models
        assert "en-fr" not in mm.models

    @pytest.mark.asyncio
    async def test_get_model_cache_first(self, mock_hf, mock_translator):
        mock_api, mock_dl = mock_hf
        mm = ModelManager(max_loaded=2, device="cpu")
        await mm.fetch_hf_models()

        # Scenario 1: Local cache hit
        # Reset mock to track new calls
        mock_dl.reset_mock()
        mock_dl.return_value = "/tmp/mock-model-path"

        await mm.get_model("en", "fr")

        # Verify it tried local_files_only=True first
        assert mock_dl.call_count == 1
        args, kwargs = mock_dl.call_args
        assert kwargs.get("local_files_only") is True

    @pytest.mark.asyncio
    async def test_get_model_fallback(self, mock_hf, mock_translator):
        mock_api, mock_dl = mock_hf
        mm = ModelManager(max_loaded=2, device="cpu")
        await mm.fetch_hf_models()

        # Scenario 2: Local cache miss, fallback to online
        # First call fails, second succeeds
        mock_dl.side_effect = [Exception("Not found locally"), "/tmp/mock-model-path"]

        await mm.get_model("fr", "en")

        assert mock_dl.call_count == 2
        # First call was local only
        args1, kwargs1 = mock_dl.call_args_list[0]
        assert kwargs1.get("local_files_only") is True
        # Second call was online (no local_files_only or False)
        args2, kwargs2 = mock_dl.call_args_list[1]
        assert not kwargs2.get("local_files_only")
|
tests/test_mixed_src.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for per-item src_lang lists on the /api/translate endpoint."""

import pytest
from httpx import AsyncClient


@pytest.mark.asyncio
async def test_explicit_mixed_languages(client: AsyncClient):
    """Verify explicit src_lang list for a mixed batch."""
    # Ensure needed models are available
    models_res = await client.get("/api/models")
    available_models = models_res.json()["models"]

    needed = [("fr", "en"), ("es", "en")]
    for src, tgt in needed:
        if not any(
            m["src_lang"] == src and m["tgt_lang"] == tgt for m in available_models
        ):
            # Plain string: was an f-string with no placeholders (F541).
            pytest.skip("Mixed batch test needs both fr-en and es-en models")

    # Explicitly specify languages for each input
    payload = {"src": ["Bonjour", "Hola"], "src_lang": ["fr", "es"], "tgt_lang": "en"}

    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()

    assert data["src_lang"] == ["fr", "es"]
    assert data["model_used"] == ["quickmt/quickmt-fr-en", "quickmt/quickmt-es-en"]
    assert len(data["translation"]) == 2


@pytest.mark.asyncio
async def test_src_lang_length_mismatch(client: AsyncClient):
    """Verify 422 error when src and src_lang lengths differ."""
    payload = {
        "src": ["Hello", "World"],
        "src_lang": ["en"],  # Only 1 language for 2 inputs
        "tgt_lang": "es",
    }

    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 422
    assert (
        "src_lang list length must match src list length" in response.json()["detail"]
    )


@pytest.mark.asyncio
async def test_src_lang_list_with_single_src(client: AsyncClient):
    """Verify single src string with single-item src_lang list is not allowed or handled gracefully."""
    # The Pydantic model allows this, but our logic checks lengths.
    # If src is str, src_list has len 1. If src_lang is list, it must have len 1.

    # Needs a model
    models_res = await client.get("/api/models")
    models = models_res.json()["models"]
    if not models:
        pytest.skip("No models available")
    model = models[0]

    payload = {
        "src": "Hello",
        "src_lang": [model["src_lang"]],
        "tgt_lang": model["tgt_lang"],
    }

    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    data = response.json()
    # A scalar src should come back with a scalar src_lang.
    assert data["src_lang"] == model["src_lang"]
tests/test_robustness.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
import asyncio
|
| 3 |
+
import time
|
| 4 |
+
from httpx import AsyncClient
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@pytest.mark.asyncio
async def test_model_not_found(client: AsyncClient):
    """A request for an unavailable language pair must yield HTTP 404."""
    request_body = {
        "src": "Hello",
        "src_lang": "en",
        "tgt_lang": "zz",  # deliberately non-existent target language
    }
    resp = await client.post("/api/translate", json=request_body)
    assert resp.status_code == 404
    body = resp.json()
    assert "not found" in body["detail"]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@pytest.mark.asyncio
async def test_empty_input_string(client: AsyncClient):
    """Verify handling of empty string input.

    An empty source string should round-trip to an empty translation
    rather than raising an error.
    """
    models_res = await client.get("/api/models")
    models = models_res.json()["models"]
    if not models:
        # Guard against IndexError so a missing model registers as a skip,
        # not a spurious failure — consistent with the other API tests.
        pytest.skip("No models available")
    model = models[0]

    payload = {"src": "", "src_lang": model["src_lang"], "tgt_lang": model["tgt_lang"]}
    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    assert response.json()["translation"] == ""
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@pytest.mark.asyncio
async def test_empty_input_list(client: AsyncClient):
    """Verify handling of empty list input.

    An empty batch should produce an empty result list, not an error.
    """
    models_res = await client.get("/api/models")
    models = models_res.json()["models"]
    if not models:
        # Guard against IndexError so a missing model registers as a skip,
        # not a spurious failure — consistent with the other API tests.
        pytest.skip("No models available")
    model = models[0]

    payload = {"src": [], "src_lang": model["src_lang"], "tgt_lang": model["tgt_lang"]}
    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    assert response.json()["translation"] == []
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@pytest.mark.asyncio
async def test_invalid_input_type(client: AsyncClient):
    """A non-string, non-list `src` must be rejected by request validation."""
    bad_payload = {
        "src": 123,  # must be a string or a list of strings
        "src_lang": "en",
        "tgt_lang": "fr",
    }
    resp = await client.post("/api/translate", json=bad_payload)
    # Pydantic validation failures surface as 422 Unprocessable Entity.
    assert resp.status_code == 422
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@pytest.mark.asyncio
async def test_concurrent_model_load(client: AsyncClient):
    """
    Concurrent requests for a not-yet-loaded model should all succeed:
    one request performs the load, the rest wait on the load event.
    """
    # Determine which models are currently resident.
    loaded = (await client.get("/api/health")).json()["loaded_models"]
    available = (await client.get("/api/models")).json()["models"]

    # Pick the first advertised model that is not loaded yet.
    target_model = next(
        (
            m
            for m in available
            if f"{m['src_lang']}-{m['tgt_lang']}" not in loaded
        ),
        None,
    )

    if not target_model:
        pytest.skip("No unloaded models available to test concurrent loading")

    request_body = {
        "src": "Concurrent test",
        "src_lang": target_model["src_lang"],
        "tgt_lang": target_model["tgt_lang"],
    }

    # Fire three identical requests at once to race the model load.
    responses = await asyncio.gather(
        *(client.post("/api/translate", json=request_body) for _ in range(3))
    )

    for resp in responses:
        assert resp.status_code == 200
        assert "translation" in resp.json()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@pytest.mark.asyncio
async def test_parameter_overrides(client: AsyncClient):
    """Verify that request-level parameters are respected.

    Sends per-request decoding options (beam_size, max_decoding_length)
    and checks the request still completes successfully.
    """
    models_res = await client.get("/api/models")
    models = models_res.json()["models"]
    if not models:
        # Guard against IndexError so a missing model registers as a skip,
        # not a spurious failure — consistent with the other API tests.
        pytest.skip("No models available")
    model = models[0]

    payload = {
        "src": "This is a test of parameter overrides.",
        "src_lang": model["src_lang"],
        "tgt_lang": model["tgt_lang"],
        "beam_size": 1,
        "max_decoding_length": 5,
    }

    response = await client.post("/api/translate", json=payload)
    assert response.status_code == 200
    # With max_decoding_length=5, the translation should be very short.
    # Note: tokens != words, so we cannot strictly assert word count;
    # we only check that a non-empty result came back.
    trans = response.json()["translation"]
    assert len(trans) > 0
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@pytest.mark.asyncio
async def test_large_batch_processing(client: AsyncClient):
    """A batch larger than the server's MAX_BATCH_SIZE must be fully translated."""
    listing = (await client.get("/api/models")).json()["models"]
    if not listing:
        pytest.skip("No translation models available")
    chosen = listing[0]

    # 50 inputs exceeds the default MAX_BATCH_SIZE of 32.
    batch = [f"This is sentence {i}" for i in range(50)]
    resp = await client.post(
        "/api/translate",
        json={
            "src": batch,
            "src_lang": chosen["src_lang"],
            "tgt_lang": chosen["tgt_lang"],
        },
    )
    assert resp.status_code == 200
    assert len(resp.json()["translation"]) == 50
|
tests/test_threading_config.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
from unittest.mock import patch
|
| 3 |
+
from quickmt.manager import ModelManager
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@pytest.mark.asyncio
async def test_threading_config_propagation():
    """The manager must forward inter/intra thread counts to CTranslate2."""
    expected_inter = 2
    expected_intra = 4

    # Patch out the real Translator so no model is actually loaded.
    with patch("quickmt.manager.Translator") as translator_cls:
        mgr = ModelManager(
            max_loaded=1,
            device="cpu",
            compute_type="int8",
            inter_threads=expected_inter,
            intra_threads=expected_intra,
        )

        # Register a fake model so get_model() can resolve the en-fr pair.
        mgr.hf_collection_models = [
            {"model_id": "test/model", "src_lang": "en", "tgt_lang": "fr"}
        ]

        # Avoid hitting the Hugging Face Hub during the load.
        with patch("quickmt.manager.snapshot_download", return_value="/tmp/model"):
            await mgr.get_model("en", "fr")

        # The Translator constructor must receive the configured values.
        _, kwargs = translator_cls.call_args
        assert kwargs["inter_threads"] == expected_inter
        assert kwargs["intra_threads"] == expected_intra
        assert kwargs["device"] == "cpu"
        assert kwargs["compute_type"] == "int8"
|
tests/test_translation_quality.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
from httpx import AsyncClient
|
| 3 |
+
import sacrebleu
|
| 4 |
+
|
| 5 |
+
# 10 Diverse English -> French pairs
# Tiny reference corpus for the BLEU/CHRF quality checks below:
# each tuple is (source sentence, human reference translation).
EN_FR_PAIRS = [
    ("Hello world", "Bonjour le monde"),
    ("The cat sits on the mat.", "Le chat est assis sur le tapis."),
    ("I would like a coffee, please.", "Je voudrais un café, s'il vous plaît."),
    ("Where is the nearest train station?", "Où est la gare la plus proche ?"),
    (
        "Artificial intelligence is fascinating.",
        "L'intelligence artificielle est fascinante.",
    ),
    ("Can you help me translate this?", "Pouvez-vous m'aider à traduire ceci ?"),
    ("It is raining today.", "Il pleut aujourd'hui."),
    ("Programming is fun.", "La programmation est amusante."),
    ("I am learning French.", "J'apprends le français."),
    ("Have a nice day.", "Bonne journée."),
]

# 10 Diverse French -> English pairs
# Same structure for the reverse direction.
FR_EN_PAIRS = [
    ("Bonjour tout le monde", "Hello everyone"),
    ("La vie est belle", "Life is beautiful"),
    ("Je suis fatigué", "I am tired"),
    ("Quelle heure est-il ?", "What time is it?"),
    ("J'aime manger des croissants", "I like eating croissants"),
    ("Merci beaucoup", "Thank you very much"),
    ("À bientôt", "See you soon"),
    ("Le livre est sur la table", "The book is on the table"),
    ("Je ne comprends pas", "I do not understand"),
    ("C'est magnifique", "It is magnificent"),
]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
async def translate_batch(client, texts, src, tgt):
    """POST *texts* to the translate endpoint; return [] on any non-200 reply."""
    request_body = {"src": texts, "src_lang": src, "tgt_lang": tgt}
    resp = await client.post("/api/translate", json=request_body)
    return resp.json()["translation"] if resp.status_code == 200 else []
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@pytest.mark.asyncio
async def test_quality_en_fr(client: AsyncClient):
    """Score EN->FR output with BLEU/CHRF against the small reference corpus."""
    # Skip unless an en-fr model is advertised.
    catalogue = (await client.get("/api/models")).json()["models"]
    directions = {(m["src_lang"], m["tgt_lang"]) for m in catalogue}
    if ("en", "fr") not in directions:
        pytest.skip("en-fr model needed for this test")

    sources, references = zip(*EN_FR_PAIRS)
    # sacrebleu expects a list of reference streams (one list per reference set).
    refs = [list(references)]

    hyps = await translate_batch(client, list(sources), "en", "fr")
    assert len(hyps) == len(sources)

    # Corpus-level metrics over the whole batch.
    bleu = sacrebleu.corpus_bleu(hyps, refs)
    chrf = sacrebleu.corpus_chrf(hyps, refs)

    print(f"\nEN->FR Quality: BLEU={bleu.score:.2f}, CHRF={chrf.score:.2f}")

    # Minimum quality floor for simple sentences; tune per model capability.
    assert bleu.score > 40.0
    assert chrf.score > 70.0
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@pytest.mark.asyncio
async def test_quality_fr_en(client: AsyncClient):
    """Score FR->EN output with BLEU/CHRF against the small reference corpus."""
    # Skip unless an fr-en model is advertised.
    catalogue = (await client.get("/api/models")).json()["models"]
    directions = {(m["src_lang"], m["tgt_lang"]) for m in catalogue}
    if ("fr", "en") not in directions:
        pytest.skip("fr-en model needed for this test")

    sources, references = zip(*FR_EN_PAIRS)
    # sacrebleu expects a list of reference streams (one list per reference set).
    refs = [list(references)]

    hyps = await translate_batch(client, list(sources), "fr", "en")
    assert len(hyps) == len(sources)

    # Corpus-level metrics over the whole batch.
    bleu = sacrebleu.corpus_bleu(hyps, refs)
    chrf = sacrebleu.corpus_chrf(hyps, refs)

    print(f"\nFR->EN Quality: BLEU={bleu.score:.2f}, CHRF={chrf.score:.2f}")

    # Minimum quality floor for simple sentences; tune per model capability.
    assert bleu.score > 40.0
    assert chrf.score > 70.0
|
tests/test_translator.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from unittest.mock import MagicMock, patch
|
| 4 |
+
from quickmt.translator import Translator, TranslatorABC
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Mock objects
|
| 8 |
+
@pytest.fixture
def mock_ctranslate2():
    """Replace ctranslate2.Translator with a mock for the duration of a test."""
    with patch("ctranslate2.Translator") as translator_mock:
        yield translator_mock
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@pytest.fixture
def mock_sentencepiece():
    """Replace sentencepiece.SentencePieceProcessor with a mock."""
    with patch("sentencepiece.SentencePieceProcessor") as spm_mock:
        yield spm_mock
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@pytest.fixture
def temp_model_dir(tmp_path):
    """Build a throwaway model directory containing the expected SPM files."""
    directory = tmp_path / "dummy-model"
    directory.mkdir()
    for spm_name in ("src.spm.model", "tgt.spm.model"):
        (directory / spm_name).write_text("dummy")
    return directory
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@pytest.fixture
def translator_instance(temp_model_dir, mock_ctranslate2, mock_sentencepiece):
    """A Translator built over the dummy model dir with all backends mocked."""
    return Translator(temp_model_dir)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class TestTranslatorABC:
    """Unit tests for the sentence split/join helpers on TranslatorABC."""

    def test_sentence_split(self):
        paragraphs = ["Hello world. This is a test.", "Another paragraph."]
        doc_ids, para_ids, sents = TranslatorABC._sentence_split(paragraphs)

        # First input yields two sentences, second yields one.
        assert doc_ids == [0, 0, 1]
        assert para_ids == [0, 0, 0]
        assert sents == ["Hello world.", "This is a test.", "Another paragraph."]

    def test_sentence_join(self):
        merged = TranslatorABC._sentence_join(
            [0, 0, 1],
            [0, 0, 0],
            ["Hello world.", "This is a test.", "Another paragraph."],
        )
        # Sentences sharing an input id are rejoined into one string.
        assert merged == ["Hello world. This is a test.", "Another paragraph."]

    def test_sentence_join_empty(self):
        # With no sentences at all, `length` pads the result with empties.
        expected = ["", "", "", "", ""]
        assert TranslatorABC._sentence_join([], [], [], length=5) == expected
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class TestTranslator:
    """Unit tests for Translator with CTranslate2 and SentencePiece mocked out."""

    def test_init_joint_tokens(self, tmp_path, mock_ctranslate2, mock_sentencepiece):
        """A directory with only joint.spm.model should back both tokenizers."""
        model_dir = tmp_path / "joint-model"
        model_dir.mkdir()
        (model_dir / "joint.spm.model").write_text("dummy")

        translator = Translator(model_dir)
        # Two SentencePieceProcessor instances: source and target tokenizer.
        assert mock_sentencepiece.call_count == 2
        # Verify it used the joint model for both
        args, kwargs = mock_sentencepiece.call_args_list[0]
        assert "joint.spm.model" in kwargs["model_file"]

    def test_tokenize(self, translator_instance):
        """tokenize() must append the </s> end-of-sentence token to each piece list."""
        translator_instance.source_tokenizer.encode.return_value = [
            ["token1", "token2"]
        ]
        result = translator_instance.tokenize(["Hello"])
        assert result == [["token1", "token2", "</s>"]]
        # Encoding must request string-typed pieces.
        translator_instance.source_tokenizer.encode.assert_called_with(
            ["Hello"], out_type=str
        )

    def test_detokenize(self, translator_instance):
        """detokenize() must pass token lists straight to the target tokenizer."""
        translator_instance.target_tokenizer.decode.return_value = ["Hello"]
        result = translator_instance.detokenize([["token1", "token2"]])
        assert result == ["Hello"]
        translator_instance.target_tokenizer.decode.assert_called_with(
            [["token1", "token2"]]
        )

    def test_unload(self, translator_instance):
        """unload() must tolerate the translator attribute being absent."""
        del translator_instance.translator
        # Should not raise
        translator_instance.unload()

    def test_call_full_pipeline(self, translator_instance):
        """__call__ must run tokenize -> translate_batch -> detokenize exactly once each."""
        # Mock the steps
        with (
            patch.object(Translator, "tokenize") as mock_tok,
            patch.object(Translator, "translate_batch") as mock_trans,
            patch.object(Translator, "detokenize") as mock_detok,
        ):
            mock_tok.return_value = [["tok"]]
            mock_res = MagicMock()
            mock_res.hypotheses = [["hypo"]]
            mock_trans.return_value = [mock_res]
            mock_detok.return_value = ["Translated sentence."]

            # str input -> str output (a single translated sentence).
            result = translator_instance("Source text.")
            assert result == "Translated sentence."

            mock_tok.assert_called_once()
            mock_trans.assert_called_once()
            mock_detok.assert_called_once()

    def test_translate_stream(self, translator_instance):
        """translate_stream must yield one result dict per input sentence, in order."""
        translator_instance.translator.translate_iterable = MagicMock(
            return_value=[
                MagicMock(hypotheses=[["hypo1"]]),
                MagicMock(hypotheses=[["hypo2"]]),
            ]
        )

        with (
            patch.object(Translator, "tokenize") as mock_tok,
            patch.object(Translator, "detokenize") as mock_detok,
        ):
            mock_tok.return_value = [["tok1"], ["tok2"]]
            # Echo the first token of each hypothesis so ordering is observable.
            mock_detok.side_effect = lambda x: [f"Detok {x[0][0]}"]

            results = list(translator_instance.translate_stream(["Sent 1.", "Sent 2."]))
            assert len(results) == 2
            assert results[0]["translation"] == "Detok hypo1"
            assert results[1]["translation"] == "Detok hypo2"

    def test_translate_file(self, translator_instance, tmp_path):
        """translate_file must write one translated line per input line, newline-terminated."""
        input_file = tmp_path / "input.txt"
        output_file = tmp_path / "output.txt"
        input_file.write_text("Line 1\nLine 2")

        with patch.object(Translator, "__call__") as mock_call:
            mock_call.return_value = ["Trans 1", "Trans 2"]
            translator_instance.translate_file(str(input_file), str(output_file))

        content = output_file.read_text()
        assert content == "Trans 1\nTrans 2\n"

    def test_translate_batch(self, translator_instance):
        """Named options and extra kwargs must all reach the CT2 translate_batch call."""
        translator_instance.translate_batch(
            [["tok"]],
            beam_size=10,
            patience=2,
            max_batch_size=16,
            num_hypotheses=5,  # kwargs
        )
        translator_instance.translator.translate_batch.assert_called_once()
        args, kwargs = translator_instance.translator.translate_batch.call_args
        assert kwargs["beam_size"] == 10
        assert kwargs["patience"] == 2
        assert kwargs["max_batch_size"] == 16
        assert kwargs["num_hypotheses"] == 5
|