from __future__ import annotations import json import os import uuid from pathlib import Path from typing import Callable from app.article_fetchers.base import ArticleContent, ArticleFetcher from app.article_fetchers.generic import GenericArticleFetcher from app.article_fetchers.wechat import WechatArticleFetcher from app.article_fetchers.xiaohongshu import XiaohongshuArticleFetcher from app.db.note_dao import load_note, save_note, set_status from app.db.article_dao import ( create_subscription, get_article_item, get_subscription, link_subscription_item, list_article_items, list_subscriptions, mark_article_summarized, update_subscription_refresh, upsert_article_item, ) from app.enmus.task_status_enums import TaskStatus from app.gpt.gpt_factory import GPTFactory from app.models.gpt_model import GPTSource from app.models.model_config import ModelConfig from app.models.transcriber_model import TranscriptSegment from app.services.provider import ProviderService def _note_output_dir() -> Path: path = Path(os.getenv("NOTE_OUTPUT_DIR", "note_results")) path.mkdir(parents=True, exist_ok=True) return path class ArticleService: def __init__( self, fetchers: dict[str, ArticleFetcher] | None = None, gpt_factory: Callable[[str, str], object] | None = None, ): self.fetchers = fetchers or { "wechat_mp": WechatArticleFetcher(), "xiaohongshu": XiaohongshuArticleFetcher(), "generic_web": GenericArticleFetcher(), } self.gpt_factory = gpt_factory or self._create_gpt def generate_from_url( self, url: str, platform: str, provider_id: str, model_name: str, style: str = "", extras: str = "", task_id: str | None = None, ) -> dict: task_id = task_id or str(uuid.uuid4()) try: self._update_status(task_id, TaskStatus.PARSING) article = self._fetcher(platform).fetch(url) item = upsert_article_item(article) self._update_status(task_id, TaskStatus.TRANSCRIBING) gpt = self.gpt_factory(model_name, provider_id) markdown = gpt.summarize( GPTSource( segment=self._segments(article), title=article.title, tags="article", style=style, extras=extras, ) ) self._update_status(task_id, TaskStatus.SAVING) self._write_note_json( task_id, article, markdown, int(getattr(gpt, "total_tokens", 0) or 0), ) mark_article_summarized(item.id, task_id) self._update_status(task_id, TaskStatus.SUCCESS) self._index_task(task_id) return {"task_id": task_id, "article_item_id": item.id} except Exception: self._update_status(task_id, TaskStatus.FAILED) raise def generate_from_content( self, url: str, platform: str, title: str, content_text: str, provider_id: str, model_name: str, style: str = "", extras: str = "", author_name: str = "", task_id: str | None = None, ) -> dict: body = (content_text or "").strip() if len(body) < 20: raise ValueError("导入正文过短,无法生成总结") task_id = task_id or str(uuid.uuid4()) try: self._update_status(task_id, TaskStatus.PARSING) article = ArticleContent( platform=platform or "generic_web", url=url or f"manual://{task_id}", article_id=url or task_id, title=(title or "").strip() or "导入文章", author_name=author_name, content_text=body, raw_metadata={"source": "manual_import"}, ) item = upsert_article_item(article) self._update_status(task_id, TaskStatus.TRANSCRIBING) gpt = self.gpt_factory(model_name, provider_id) markdown = gpt.summarize( GPTSource( segment=self._segments(article), title=article.title, tags="article", style=style, extras=extras, ) ) self._update_status(task_id, TaskStatus.SAVING) self._write_note_json( task_id, article, markdown, int(getattr(gpt, "total_tokens", 0) or 0), ) mark_article_summarized(item.id, task_id) self._update_status(task_id, TaskStatus.SUCCESS) self._index_task(task_id) return {"task_id": task_id, "article_item_id": item.id} except Exception: self._update_status(task_id, TaskStatus.FAILED) raise def fetch_only_from_url(self, url: str, platform: str) -> dict: article = self._fetcher(platform).fetch(url) item = upsert_article_item(article) return self._item_payload(item, include_content=True) def import_only_content( self, url: str, platform: str, title: str, content_text: str, author_name: str = "", ) -> dict: body = (content_text or "").strip() if len(body) < 20: raise ValueError("导入正文过短") article_id = url or str(uuid.uuid4()) article = ArticleContent( platform=platform or "generic_web", url=url or f"manual://{article_id}", article_id=article_id, title=(title or "").strip() or "导入文章", author_name=author_name, content_text=body, raw_metadata={"source": "manual_import"}, ) item = upsert_article_item(article) return self._item_payload(item, include_content=True) def search(self, platform: str, keyword: str, limit: int = 20) -> dict: articles = self._fetcher(platform).search(keyword, limit) items = [upsert_article_item(article) for article in articles] return { "platform": platform, "keyword": keyword, "status": "ok", "message": "", "items": [self._item_payload(item) for item in items], } def refresh_subscription(self, subscription_id: int, limit: int = 20) -> dict: subscription = get_subscription(subscription_id) if not subscription: raise ValueError("订阅不存在") fetcher = self._fetcher(subscription.platform) if subscription.type == "publisher": articles = fetcher.fetch_publisher(subscription.query, limit) reason = f"publisher:{subscription.query}" else: articles = fetcher.search(subscription.query, limit) reason = f"keyword:{subscription.query}" items = [] for article in articles: item = upsert_article_item(article) link_subscription_item(subscription.id, item.id, reason) items.append(item) update_subscription_refresh(subscription.id) return { "subscription_id": subscription.id, "count": len(items), "items": [self._item_payload(item) for item in items], } def summarize_item( self, item_id: int, provider_id: str, model_name: str, style: str = "", extras: str = "", ) -> dict: item = get_article_item(item_id) if not item: raise ValueError("文章不存在") if item.task_id and item.summary_status == "summarized": return {"task_id": item.task_id, "article_item_id": item.id} return self.generate_from_url( url=item.url, platform=item.platform, provider_id=provider_id, model_name=model_name, style=style, extras=extras, ) def list_items(self, subscription_id: int | None = None) -> list[dict]: return [self._item_payload(item) for item in list_article_items(subscription_id)] def get_item(self, item_id: int) -> dict: item = get_article_item(item_id) if not item: raise ValueError("文章不存在") return self._item_payload(item, include_content=True) def create_subscription( self, platform: str, subscription_type: str, query: str, label: str = "", ) -> dict: subscription = create_subscription(platform, subscription_type, query, label) return self._subscription_payload(subscription) def list_subscriptions(self) -> list[dict]: return [self._subscription_payload(item) for item in list_subscriptions()] def _fetcher(self, platform: str) -> ArticleFetcher: if platform not in self.fetchers: raise ValueError(f"不支持的文章平台:{platform}") return self.fetchers[platform] def _item_payload(self, item, include_content: bool = False) -> dict: payload = { "id": item.id, "platform": item.platform, "title": item.title, "url": item.url, "author_name": item.author_name, "author_id": item.author_id, "cover_url": item.cover_url, "published_at": item.published_at, "summary_status": item.summary_status, "task_id": item.task_id, } if include_content: payload["content_text"] = (getattr(item, "content_text", "") or "").strip() if not payload["content_text"] and item.task_id: payload["content_text"] = self._content_from_note_result(item.task_id) return payload def _content_from_note_result(self, task_id: str) -> str: if not task_id: return "" payload = load_note(task_id) if not payload: return "" transcript = payload.get("transcript") or {} return str(transcript.get("full_text") or "").strip() def _subscription_payload(self, item) -> dict: return { "id": item.id, "platform": item.platform, "type": item.type, "query": item.query, "label": item.label, "enabled": item.enabled, "last_error": item.last_error, } def _create_gpt(self, model_name: str, provider_id: str): provider = ProviderService.get_provider_by_id(provider_id) if not provider: raise ValueError("请选择模型和提供者") return GPTFactory().from_config( ModelConfig( api_key=provider["api_key"], base_url=provider["base_url"], model_name=model_name, provider=provider["type"], name=provider["name"], ) ) def _segments(self, article: ArticleContent) -> list[TranscriptSegment]: paragraphs = [p.strip() for p in article.content_text.splitlines() if p.strip()] if not paragraphs and article.content_text.strip(): paragraphs = [article.content_text.strip()] return [ TranscriptSegment(start=float(index), end=float(index + 1), text=text) for index, text in enumerate(paragraphs) ] def _write_note_json( self, task_id: str, article: ArticleContent, markdown: str, total_tokens: int, ) -> None: segments = self._segments(article) payload = { "markdown": markdown, "transcript": { "language": "zh", "full_text": article.content_text, "segments": [ {"start": segment.start, "end": segment.end, "text": segment.text} for segment in segments ], }, "audio_meta": { "file_path": "", "title": article.title, "duration": 0, "cover_url": article.cover_url, "platform": article.platform, "video_id": article.article_id, "raw_info": { "source_type": "article", "url": article.url, "author_name": article.author_name, "author_id": article.author_id, "published_at": article.published_at, "image_urls": article.image_urls, **(article.raw_metadata or {}), }, "video_path": None, }, "total_tokens": total_tokens, } save_note(task_id, payload) def _update_status(self, task_id: str, status: TaskStatus) -> None: set_status(task_id, {"status": status.value, "paused": False}) def _index_task(self, task_id: str) -> None: try: from app.services.vector_store import VectorStoreManager VectorStoreManager().index_task(task_id) except Exception: pass