videoNote / backend /app /services /article.py
zhoujiaangyao
feat(db): 配置与笔记迁入 Postgres,重启不丢
aa08cd6
Raw
History Blame Contribute Delete
13.4 kB
from __future__ import annotations
import json
import os
import uuid
from pathlib import Path
from typing import Callable
from app.article_fetchers.base import ArticleContent, ArticleFetcher
from app.article_fetchers.generic import GenericArticleFetcher
from app.article_fetchers.wechat import WechatArticleFetcher
from app.article_fetchers.xiaohongshu import XiaohongshuArticleFetcher
from app.db.note_dao import load_note, save_note, set_status
from app.db.article_dao import (
create_subscription,
get_article_item,
get_subscription,
link_subscription_item,
list_article_items,
list_subscriptions,
mark_article_summarized,
update_subscription_refresh,
upsert_article_item,
)
from app.enmus.task_status_enums import TaskStatus
from app.gpt.gpt_factory import GPTFactory
from app.models.gpt_model import GPTSource
from app.models.model_config import ModelConfig
from app.models.transcriber_model import TranscriptSegment
from app.services.provider import ProviderService
def _note_output_dir() -> Path:
path = Path(os.getenv("NOTE_OUTPUT_DIR", "note_results"))
path.mkdir(parents=True, exist_ok=True)
return path
class ArticleService:
def __init__(
self,
fetchers: dict[str, ArticleFetcher] | None = None,
gpt_factory: Callable[[str, str], object] | None = None,
):
self.fetchers = fetchers or {
"wechat_mp": WechatArticleFetcher(),
"xiaohongshu": XiaohongshuArticleFetcher(),
"generic_web": GenericArticleFetcher(),
}
self.gpt_factory = gpt_factory or self._create_gpt
def generate_from_url(
self,
url: str,
platform: str,
provider_id: str,
model_name: str,
style: str = "",
extras: str = "",
task_id: str | None = None,
) -> dict:
task_id = task_id or str(uuid.uuid4())
try:
self._update_status(task_id, TaskStatus.PARSING)
article = self._fetcher(platform).fetch(url)
item = upsert_article_item(article)
self._update_status(task_id, TaskStatus.TRANSCRIBING)
gpt = self.gpt_factory(model_name, provider_id)
markdown = gpt.summarize(
GPTSource(
segment=self._segments(article),
title=article.title,
tags="article",
style=style,
extras=extras,
)
)
self._update_status(task_id, TaskStatus.SAVING)
self._write_note_json(
task_id,
article,
markdown,
int(getattr(gpt, "total_tokens", 0) or 0),
)
mark_article_summarized(item.id, task_id)
self._update_status(task_id, TaskStatus.SUCCESS)
self._index_task(task_id)
return {"task_id": task_id, "article_item_id": item.id}
except Exception:
self._update_status(task_id, TaskStatus.FAILED)
raise
def generate_from_content(
self,
url: str,
platform: str,
title: str,
content_text: str,
provider_id: str,
model_name: str,
style: str = "",
extras: str = "",
author_name: str = "",
task_id: str | None = None,
) -> dict:
body = (content_text or "").strip()
if len(body) < 20:
raise ValueError("导入正文过短,无法生成总结")
task_id = task_id or str(uuid.uuid4())
try:
self._update_status(task_id, TaskStatus.PARSING)
article = ArticleContent(
platform=platform or "generic_web",
url=url or f"manual://{task_id}",
article_id=url or task_id,
title=(title or "").strip() or "导入文章",
author_name=author_name,
content_text=body,
raw_metadata={"source": "manual_import"},
)
item = upsert_article_item(article)
self._update_status(task_id, TaskStatus.TRANSCRIBING)
gpt = self.gpt_factory(model_name, provider_id)
markdown = gpt.summarize(
GPTSource(
segment=self._segments(article),
title=article.title,
tags="article",
style=style,
extras=extras,
)
)
self._update_status(task_id, TaskStatus.SAVING)
self._write_note_json(
task_id,
article,
markdown,
int(getattr(gpt, "total_tokens", 0) or 0),
)
mark_article_summarized(item.id, task_id)
self._update_status(task_id, TaskStatus.SUCCESS)
self._index_task(task_id)
return {"task_id": task_id, "article_item_id": item.id}
except Exception:
self._update_status(task_id, TaskStatus.FAILED)
raise
def fetch_only_from_url(self, url: str, platform: str) -> dict:
article = self._fetcher(platform).fetch(url)
item = upsert_article_item(article)
return self._item_payload(item, include_content=True)
def import_only_content(
self,
url: str,
platform: str,
title: str,
content_text: str,
author_name: str = "",
) -> dict:
body = (content_text or "").strip()
if len(body) < 20:
raise ValueError("导入正文过短")
article_id = url or str(uuid.uuid4())
article = ArticleContent(
platform=platform or "generic_web",
url=url or f"manual://{article_id}",
article_id=article_id,
title=(title or "").strip() or "导入文章",
author_name=author_name,
content_text=body,
raw_metadata={"source": "manual_import"},
)
item = upsert_article_item(article)
return self._item_payload(item, include_content=True)
def search(self, platform: str, keyword: str, limit: int = 20) -> dict:
articles = self._fetcher(platform).search(keyword, limit)
items = [upsert_article_item(article) for article in articles]
return {
"platform": platform,
"keyword": keyword,
"status": "ok",
"message": "",
"items": [self._item_payload(item) for item in items],
}
def refresh_subscription(self, subscription_id: int, limit: int = 20) -> dict:
subscription = get_subscription(subscription_id)
if not subscription:
raise ValueError("订阅不存在")
fetcher = self._fetcher(subscription.platform)
if subscription.type == "publisher":
articles = fetcher.fetch_publisher(subscription.query, limit)
reason = f"publisher:{subscription.query}"
else:
articles = fetcher.search(subscription.query, limit)
reason = f"keyword:{subscription.query}"
items = []
for article in articles:
item = upsert_article_item(article)
link_subscription_item(subscription.id, item.id, reason)
items.append(item)
update_subscription_refresh(subscription.id)
return {
"subscription_id": subscription.id,
"count": len(items),
"items": [self._item_payload(item) for item in items],
}
def summarize_item(
self,
item_id: int,
provider_id: str,
model_name: str,
style: str = "",
extras: str = "",
) -> dict:
item = get_article_item(item_id)
if not item:
raise ValueError("文章不存在")
if item.task_id and item.summary_status == "summarized":
return {"task_id": item.task_id, "article_item_id": item.id}
return self.generate_from_url(
url=item.url,
platform=item.platform,
provider_id=provider_id,
model_name=model_name,
style=style,
extras=extras,
)
def list_items(self, subscription_id: int | None = None) -> list[dict]:
return [self._item_payload(item) for item in list_article_items(subscription_id)]
def get_item(self, item_id: int) -> dict:
item = get_article_item(item_id)
if not item:
raise ValueError("文章不存在")
return self._item_payload(item, include_content=True)
def create_subscription(
self,
platform: str,
subscription_type: str,
query: str,
label: str = "",
) -> dict:
subscription = create_subscription(platform, subscription_type, query, label)
return self._subscription_payload(subscription)
def list_subscriptions(self) -> list[dict]:
return [self._subscription_payload(item) for item in list_subscriptions()]
def _fetcher(self, platform: str) -> ArticleFetcher:
if platform not in self.fetchers:
raise ValueError(f"不支持的文章平台:{platform}")
return self.fetchers[platform]
def _item_payload(self, item, include_content: bool = False) -> dict:
payload = {
"id": item.id,
"platform": item.platform,
"title": item.title,
"url": item.url,
"author_name": item.author_name,
"author_id": item.author_id,
"cover_url": item.cover_url,
"published_at": item.published_at,
"summary_status": item.summary_status,
"task_id": item.task_id,
}
if include_content:
payload["content_text"] = (getattr(item, "content_text", "") or "").strip()
if not payload["content_text"] and item.task_id:
payload["content_text"] = self._content_from_note_result(item.task_id)
return payload
def _content_from_note_result(self, task_id: str) -> str:
if not task_id:
return ""
payload = load_note(task_id)
if not payload:
return ""
transcript = payload.get("transcript") or {}
return str(transcript.get("full_text") or "").strip()
def _subscription_payload(self, item) -> dict:
return {
"id": item.id,
"platform": item.platform,
"type": item.type,
"query": item.query,
"label": item.label,
"enabled": item.enabled,
"last_error": item.last_error,
}
def _create_gpt(self, model_name: str, provider_id: str):
provider = ProviderService.get_provider_by_id(provider_id)
if not provider:
raise ValueError("请选择模型和提供者")
return GPTFactory().from_config(
ModelConfig(
api_key=provider["api_key"],
base_url=provider["base_url"],
model_name=model_name,
provider=provider["type"],
name=provider["name"],
)
)
def _segments(self, article: ArticleContent) -> list[TranscriptSegment]:
paragraphs = [p.strip() for p in article.content_text.splitlines() if p.strip()]
if not paragraphs and article.content_text.strip():
paragraphs = [article.content_text.strip()]
return [
TranscriptSegment(start=float(index), end=float(index + 1), text=text)
for index, text in enumerate(paragraphs)
]
def _write_note_json(
self,
task_id: str,
article: ArticleContent,
markdown: str,
total_tokens: int,
) -> None:
segments = self._segments(article)
payload = {
"markdown": markdown,
"transcript": {
"language": "zh",
"full_text": article.content_text,
"segments": [
{"start": segment.start, "end": segment.end, "text": segment.text}
for segment in segments
],
},
"audio_meta": {
"file_path": "",
"title": article.title,
"duration": 0,
"cover_url": article.cover_url,
"platform": article.platform,
"video_id": article.article_id,
"raw_info": {
"source_type": "article",
"url": article.url,
"author_name": article.author_name,
"author_id": article.author_id,
"published_at": article.published_at,
"image_urls": article.image_urls,
**(article.raw_metadata or {}),
},
"video_path": None,
},
"total_tokens": total_tokens,
}
save_note(task_id, payload)
def _update_status(self, task_id: str, status: TaskStatus) -> None:
set_status(task_id, {"status": status.value, "paused": False})
def _index_task(self, task_id: str) -> None:
try:
from app.services.vector_store import VectorStoreManager
VectorStoreManager().index_task(task_id)
except Exception:
pass