Spaces:
Sleeping
Sleeping
File size: 2,827 Bytes
b65f9e0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 | from __future__ import annotations
from pathlib import Path
from fastapi.testclient import TestClient
from zai2api.config import Settings
from zai2api.db import Database
from zai2api.server import create_app
from zai2api.zai_client import UpstreamChunk
class ErroringStreamPool:
    """Fake prompt pool whose streaming path always fails.

    ``mode`` picks the failure flavour:

    * ``"chunk_error"`` - ``stream_prompt`` yields exactly one ``UpstreamChunk``
      whose ``error`` field is set, then finishes.
    * anything else - ``stream_prompt`` raises ``RuntimeError`` as soon as it
      is iterated.

    ``collect_prompt`` always raises, so a test fails loudly if the
    non-streaming path is taken by mistake.
    """

    def __init__(self, *, mode: str):
        self.mode = mode

    async def collect_prompt(self, **_: object):
        """Refuse to serve non-streaming requests; streaming tests must not call this."""
        raise AssertionError("collect_prompt should not be used in streaming tests")

    async def stream_prompt(self, **_: object):
        """Async generator that fails according to ``self.mode``.

        Because the body contains ``yield``, the ``RuntimeError`` branch fires
        on the first iteration, not when ``stream_prompt`` is called.
        """
        if self.mode != "chunk_error":
            raise RuntimeError("upstream stream exploded")
        yield UpstreamChunk(phase=None, text="", error="upstream said no")
def make_settings(tmp_path: Path) -> Settings:
    """Return a test ``Settings`` whose database file lives inside *tmp_path*.

    Upstream credentials (``zai_jwt``, ``zai_session_token``) and both
    password-env settings are ``None``. ``admin_cookie_secure`` is ``False``,
    presumably so the admin cookie is usable over the test client's plain
    HTTP - confirm against the auth code.
    """
    database_file = tmp_path / "state.db"
    return Settings(
        # server
        host="127.0.0.1",
        port=8000,
        log_level="info",
        # upstream z.ai
        zai_base_url="https://chat.z.ai",
        zai_jwt=None,
        zai_session_token=None,
        default_model="glm-5",
        request_timeout=120.0,
        # storage: one fresh file per test
        database_path=str(database_file),
        # auth / admin panel
        panel_password_env=None,
        api_password_env=None,
        admin_cookie_name="zai2api_admin_session",
        admin_session_ttl_hours=24,
        admin_cookie_secure=False,
    )
def _sse_events(body: str) -> list[object]:
    """Decode the ``data:`` fields of a server-sent-events body, in order.

    JSON payloads are parsed, so assertions do not depend on how the server
    formats its JSON (separators, whitespace). The ``[DONE]`` terminator is
    not JSON, so it is kept as the literal string ``"[DONE]"``.

    Raises:
        json.JSONDecodeError: if a ``data:`` line is neither ``[DONE]`` nor JSON.
    """
    import json  # local import: only these test helpers need it

    events: list[object] = []
    for line in body.splitlines():
        if not line.startswith("data:"):
            continue  # skip ``event:`` lines, comments and blank event separators
        payload = line[len("data:"):].strip()
        events.append(payload if payload == "[DONE]" else json.loads(payload))
    return events


def _has_field(node: object, key: str, value: object) -> bool:
    """Return True if any dict nested anywhere inside *node* maps *key* to *value*.

    The search is structure-agnostic, so the tests do not hard-code where the
    server nests its error object.
    """
    if isinstance(node, dict):
        if key in node and node[key] == value:
            return True
        children = node.values()
    elif isinstance(node, list):
        children = node
    else:
        return False
    return any(_has_field(child, key, value) for child in children)


def test_chat_completion_stream_reports_error_in_band(tmp_path: Path) -> None:
    """A streamed chat completion surfaces an upstream error chunk in-band.

    The fake pool yields a chunk whose ``error`` is set. The endpoint must
    still answer 200, emit an event carrying ``type == "upstream_error"`` and
    the upstream message, terminate with ``[DONE]``, and record the failure
    in the logs returned by ``Database.list_logs``.
    """
    settings = make_settings(tmp_path)
    app = create_app(settings, prompt_pool=ErroringStreamPool(mode="chunk_error"))
    with TestClient(app) as client:
        response = client.post(
            "/v1/chat/completions",
            json={"model": "glm-5", "stream": True, "messages": [{"role": "user", "content": "hi"}]},
        )

    assert response.status_code == 200
    events = _sse_events(response.text)
    assert any(_has_field(event, "type", "upstream_error") for event in events)
    assert any(_has_field(event, "message", "upstream said no") for event in events)
    assert "[DONE]" in events

    # Read after the TestClient context has exited (app shut down).
    logs = Database(settings.database_path).list_logs(limit=20)
    assert any(item.message == "Streaming chat completion request failed" for item in logs)


def test_responses_stream_reports_runtime_error_in_band(tmp_path: Path) -> None:
    """A streamed Responses request reports a mid-stream exception in-band.

    The fake pool raises ``RuntimeError`` while streaming. The endpoint must
    still answer 200, emit a ``response.failed`` event with ``status ==
    "failed"`` and the exception message, terminate with ``[DONE]``, and
    record the failure in the logs returned by ``Database.list_logs``.
    """
    settings = make_settings(tmp_path)
    app = create_app(settings, prompt_pool=ErroringStreamPool(mode="runtime_error"))
    with TestClient(app) as client:
        response = client.post(
            "/v1/responses",
            json={"model": "glm-5", "stream": True, "input": "hi"},
        )

    assert response.status_code == 200
    events = _sse_events(response.text)
    assert any(_has_field(event, "type", "response.failed") for event in events)
    assert any(_has_field(event, "status", "failed") for event in events)
    assert any(_has_field(event, "message", "upstream stream exploded") for event in events)
    assert "[DONE]" in events

    # Read after the TestClient context has exited (app shut down).
    logs = Database(settings.database_path).list_logs(limit=20)
    assert any(item.message == "Streaming responses request failed" for item in logs)
|