//! User Token Database Module //! UserToken 数据库操作模块 #![allow(dead_code)] // 用户令牌存储,部分接口留作后续扩展 use rusqlite::{params, Connection, OptionalExtension}; use serde::{Deserialize, Serialize}; use std::path::PathBuf; use uuid::Uuid; use chrono::{Utc, Local, Timelike, FixedOffset}; /// 用户令牌结构体 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct UserToken { pub id: String, pub token: String, pub username: String, pub description: Option, pub enabled: bool, pub expires_type: String, // "day", "week", "month", "never" pub expires_at: Option, pub max_ips: i32, // 0 = unlimited pub curfew_start: Option, // "HH:MM" 宵禁开始时间 pub curfew_end: Option, // "HH:MM" 宵禁结束时间 pub created_at: i64, pub updated_at: i64, pub last_used_at: Option, pub total_requests: i64, pub total_tokens_used: i64, } /// 令牌 IP 绑定结构体 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TokenIpBinding { pub id: String, pub token_id: String, pub ip_address: String, pub first_seen_at: i64, pub last_seen_at: i64, pub request_count: i64, pub user_agent: Option, } /// 令牌使用日志结构体 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TokenUsageLog { pub id: String, pub token_id: String, pub ip_address: String, pub model: String, pub input_tokens: i32, pub output_tokens: i32, pub request_time: i64, pub status: u16, } /// 获取数据库路径 pub fn get_db_path() -> Result { let mut path = crate::modules::account::get_data_dir()?; path.push("user_tokens.db"); Ok(path) } /// 连接数据库 pub fn connect_db() -> Result { let path = get_db_path()?; let conn = Connection::open(&path) .map_err(|e| format!("Failed to open database: {}", e))?; Ok(conn) } /// 初始化数据库 pub fn init_db() -> Result<(), String> { let conn = connect_db()?; // 创建 user_tokens 表 conn.execute( "CREATE TABLE IF NOT EXISTS user_tokens ( id TEXT PRIMARY KEY, token TEXT UNIQUE NOT NULL, username TEXT NOT NULL, description TEXT, enabled BOOLEAN NOT NULL DEFAULT 1, expires_type TEXT NOT NULL, expires_at INTEGER, max_ips INTEGER NOT NULL DEFAULT 0, created_at INTEGER NOT NULL, updated_at INTEGER NOT NULL, last_used_at INTEGER, total_requests INTEGER NOT NULL DEFAULT 0, total_tokens_used INTEGER NOT NULL DEFAULT 0, curfew_start TEXT, curfew_end TEXT )", [], ).map_err(|e| format!("Failed to create user_tokens table: {}", e))?; // 尝试添加新列 (用于旧数据库迁移,忽略已存在的错误) let _ = conn.execute("ALTER TABLE user_tokens ADD COLUMN expires_type TEXT", []); let _ = conn.execute("ALTER TABLE user_tokens ADD COLUMN expires_at INTEGER", []); let _ = conn.execute("ALTER TABLE user_tokens ADD COLUMN max_ips INTEGER DEFAULT 0", []); let _ = conn.execute("ALTER TABLE user_tokens ADD COLUMN total_requests INTEGER DEFAULT 0", []); let _ = conn.execute("ALTER TABLE user_tokens ADD COLUMN total_tokens_used INTEGER DEFAULT 0", []); let _ = conn.execute("ALTER TABLE user_tokens ADD COLUMN last_used_at INTEGER", []); let _ = conn.execute("ALTER TABLE user_tokens ADD COLUMN curfew_start TEXT", []); let _ = conn.execute("ALTER TABLE user_tokens ADD COLUMN curfew_end TEXT", []); // 创建 token_ip_bindings 表 conn.execute( "CREATE TABLE IF NOT EXISTS token_ip_bindings ( id TEXT PRIMARY KEY, token_id TEXT NOT NULL, ip_address TEXT NOT NULL, first_seen_at INTEGER NOT NULL, last_seen_at INTEGER NOT NULL, request_count INTEGER NOT NULL DEFAULT 0, user_agent TEXT, FOREIGN KEY(token_id) REFERENCES user_tokens(id) ON DELETE CASCADE, UNIQUE(token_id, ip_address) )", [], ).map_err(|e| format!("Failed to create token_ip_bindings table: {}", e))?; // 创建 token_usage_logs 表 conn.execute( "CREATE TABLE IF NOT EXISTS token_usage_logs ( id TEXT PRIMARY KEY, token_id TEXT NOT NULL, ip_address TEXT, model TEXT, input_tokens INTEGER, output_tokens INTEGER, request_time INTEGER NOT NULL, status INTEGER, FOREIGN KEY(token_id) REFERENCES user_tokens(id) ON DELETE CASCADE )", [], ).map_err(|e| format!("Failed to create token_usage_logs table: {}", e))?; // 创建索引 let _ = conn.execute("CREATE INDEX IF NOT EXISTS idx_token_usage_logs_token_id ON token_usage_logs(token_id)", []); let _ = conn.execute("CREATE INDEX IF NOT EXISTS idx_token_usage_logs_request_time ON token_usage_logs(request_time)", []); // [FIX Issue #1719] 数据清洗:修复旧版本升级导致的 NULL 字段 // 这些字段在旧版本中可能不存在,ALTER TABLE 添加后默认为 NULL,导致反序列化失败 let _ = conn.execute("UPDATE user_tokens SET expires_type = 'never' WHERE expires_type IS NULL OR expires_type = ''", []); let _ = conn.execute("UPDATE user_tokens SET max_ips = 0 WHERE max_ips IS NULL", []); let _ = conn.execute("UPDATE user_tokens SET total_requests = 0 WHERE total_requests IS NULL", []); let _ = conn.execute("UPDATE user_tokens SET total_tokens_used = 0 WHERE total_tokens_used IS NULL", []); let _ = conn.execute("UPDATE user_tokens SET enabled = 1 WHERE enabled IS NULL", []); Ok(()) } /// 创建新令牌 pub fn create_token( username: String, expires_type: String, description: Option, max_ips: i32, curfew_start: Option, curfew_end: Option, custom_expires_at: Option // 自定义过期时间戳 (秒) ) -> Result { let conn = connect_db()?; let id = Uuid::new_v4().to_string(); let token = format!("sk-{}", Uuid::new_v4().to_string().replace("-", "")); let now = Utc::now().timestamp(); let expires_at = match expires_type.as_str() { "day" => Some(Utc::now().checked_add_signed(chrono::Duration::days(1)).unwrap().timestamp()), "week" => Some(Utc::now().checked_add_signed(chrono::Duration::weeks(1)).unwrap().timestamp()), "month" => Some(Utc::now().checked_add_signed(chrono::Duration::days(30)).unwrap().timestamp()), "custom" => custom_expires_at, // 使用自定义时间戳 _ => None, // "never" or other }; let user_token = UserToken { id: id.clone(), token: token.clone(), username: username.clone(), description: description.clone(), enabled: true, expires_type: expires_type.clone(), expires_at, max_ips, curfew_start: curfew_start.clone(), curfew_end: curfew_end.clone(), created_at: now, updated_at: now, last_used_at: None, total_requests: 0, total_tokens_used: 0, }; conn.execute( "INSERT INTO user_tokens ( id, token, username, description, enabled, expires_type, expires_at, max_ips, curfew_start, curfew_end, created_at, updated_at, total_requests, total_tokens_used ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)", params![ user_token.id, user_token.token, user_token.username, user_token.description, user_token.enabled, user_token.expires_type, user_token.expires_at, user_token.max_ips, user_token.curfew_start, user_token.curfew_end, user_token.created_at, user_token.updated_at, user_token.total_requests, user_token.total_tokens_used, ], ).map_err(|e| format!("Failed to insert user token: {}", e))?; Ok(user_token) } /// 列出所有令牌 pub fn list_tokens() -> Result, String> { let conn = connect_db()?; let mut stmt = conn.prepare("SELECT * FROM user_tokens ORDER BY created_at DESC") .map_err(|e| format!("Failed to prepare query: {}", e))?; let token_iter = stmt.query_map([], |row| { Ok(UserToken { id: row.get("id")?, token: row.get("token")?, username: row.get("username")?, description: row.get("description")?, enabled: row.get("enabled").unwrap_or(true), // 防御性默认值 expires_type: row.get("expires_type").unwrap_or("never".to_string()), // 防御性默认值 expires_at: row.get("expires_at").unwrap_or(None), max_ips: row.get("max_ips").unwrap_or(0), curfew_start: row.get("curfew_start").unwrap_or(None), curfew_end: row.get("curfew_end").unwrap_or(None), created_at: row.get("created_at")?, updated_at: row.get("updated_at")?, last_used_at: row.get("last_used_at").unwrap_or(None), total_requests: row.get("total_requests").unwrap_or(0), total_tokens_used: row.get("total_tokens_used").unwrap_or(0), }) }).map_err(|e| format!("Failed to query tokens: {}", e))?; let mut tokens = Vec::new(); for token in token_iter { tokens.push(token.map_err(|e| format!("Failed to parse token row: {}", e))?); } Ok(tokens) } /// 获取单个令牌信息 pub fn get_token_by_id(id: &str) -> Result, String> { let conn = connect_db()?; let mut stmt = conn.prepare("SELECT * FROM user_tokens WHERE id = ?1") .map_err(|e| format!("Failed to prepare query: {}", e))?; let token = stmt.query_row(params![id], |row| { Ok(UserToken { id: row.get("id")?, token: row.get("token")?, username: row.get("username")?, description: row.get("description")?, enabled: row.get("enabled")?, expires_type: row.get("expires_type")?, expires_at: row.get("expires_at")?, max_ips: row.get("max_ips")?, curfew_start: row.get("curfew_start").unwrap_or(None), curfew_end: row.get("curfew_end").unwrap_or(None), created_at: row.get("created_at")?, updated_at: row.get("updated_at")?, last_used_at: row.get("last_used_at")?, total_requests: row.get("total_requests")?, total_tokens_used: row.get("total_tokens_used")?, }) }).optional().map_err(|e| format!("Failed to query token: {}", e))?; Ok(token) } /// 根据 Token 值获取令牌信息 pub fn get_token_by_value(token: &str) -> Result, String> { let conn = connect_db()?; let mut stmt = conn.prepare("SELECT * FROM user_tokens WHERE token = ?1") .map_err(|e| format!("Failed to prepare query: {}", e))?; let token = stmt.query_row(params![token], |row| { Ok(UserToken { id: row.get("id")?, token: row.get("token")?, username: row.get("username")?, description: row.get("description")?, enabled: row.get("enabled")?, expires_type: row.get("expires_type")?, expires_at: row.get("expires_at")?, max_ips: row.get("max_ips")?, curfew_start: row.get("curfew_start").unwrap_or(None), curfew_end: row.get("curfew_end").unwrap_or(None), created_at: row.get("created_at")?, updated_at: row.get("updated_at")?, last_used_at: row.get("last_used_at")?, total_requests: row.get("total_requests")?, total_tokens_used: row.get("total_tokens_used")?, }) }).optional().map_err(|e| format!("Failed to query token: {}", e))?; Ok(token) } /// 更新令牌状态/备注等 pub fn update_token( id: &str, username: Option, description: Option, enabled: Option, max_ips: Option, curfew_start: Option>, curfew_end: Option> ) -> Result<(), String> { let conn = connect_db()?; let now = Utc::now().timestamp(); let mut query = "UPDATE user_tokens SET updated_at = ?1".to_string(); let mut params_vec: Vec> = vec![Box::new(now)]; let mut param_idx = 2; if let Some(user) = username { query.push_str(&format!(", username = ?{}", param_idx)); params_vec.push(Box::new(user)); param_idx += 1; } if let Some(desc) = description { query.push_str(&format!(", description = ?{}", param_idx)); params_vec.push(Box::new(desc)); param_idx += 1; } if let Some(en) = enabled { query.push_str(&format!(", enabled = ?{}", param_idx)); params_vec.push(Box::new(en)); param_idx += 1; } if let Some(ips) = max_ips { query.push_str(&format!(", max_ips = ?{}", param_idx)); params_vec.push(Box::new(ips)); param_idx += 1; } if let Some(start) = curfew_start { query.push_str(&format!(", curfew_start = ?{}", param_idx)); params_vec.push(Box::new(start)); param_idx += 1; } if let Some(end) = curfew_end { query.push_str(&format!(", curfew_end = ?{}", param_idx)); params_vec.push(Box::new(end)); param_idx += 1; } query.push_str(&format!(" WHERE id = ?{}", param_idx)); params_vec.push(Box::new(id.to_string())); // 将 Vec> 转换为 &[&dyn ToSql] let params_refs: Vec<&dyn rusqlite::ToSql> = params_vec.iter().map(|p| p.as_ref()).collect(); conn.execute(&query, params_refs.as_slice()) .map_err(|e| format!("Failed to update user token: {}", e))?; Ok(()) } /// 续期令牌 pub fn renew_token(id: &str, expires_type: &str) -> Result<(), String> { let conn = connect_db()?; let now = Utc::now().timestamp(); let expires_at = match expires_type { "day" => Some(Utc::now().checked_add_signed(chrono::Duration::days(1)).unwrap().timestamp()), "week" => Some(Utc::now().checked_add_signed(chrono::Duration::weeks(1)).unwrap().timestamp()), "month" => Some(Utc::now().checked_add_signed(chrono::Duration::days(30)).unwrap().timestamp()), _ => None, // "never" or other }; conn.execute( "UPDATE user_tokens SET expires_type = ?1, expires_at = ?2, updated_at = ?3, enabled = 1 WHERE id = ?4", params![expires_type, expires_at, now, id], ).map_err(|e| format!("Failed to renew token: {}", e))?; Ok(()) } /// 删除令牌 pub fn delete_token(id: &str) -> Result<(), String> { let conn = connect_db()?; conn.execute("DELETE FROM user_tokens WHERE id = ?1", params![id]) .map_err(|e| format!("Failed to delete token: {}", e))?; Ok(()) } /// 获取令牌的所有 IP 绑定 pub fn get_token_ips(token_id: &str) -> Result, String> { let conn = connect_db()?; let mut stmt = conn.prepare("SELECT * FROM token_ip_bindings WHERE token_id = ?1 ORDER BY last_seen_at DESC") .map_err(|e| format!("Failed to prepare query: {}", e))?; let iter = stmt.query_map(params![token_id], |row| { Ok(TokenIpBinding { id: row.get("id")?, token_id: row.get("token_id")?, ip_address: row.get("ip_address")?, first_seen_at: row.get("first_seen_at")?, last_seen_at: row.get("last_seen_at")?, request_count: row.get("request_count")?, user_agent: row.get("user_agent")?, }) }).map_err(|e| format!("Failed to query token IPs: {}", e))?; let mut bindings = Vec::new(); for b in iter { bindings.push(b.map_err(|e| format!("Failed to parse binding row: {}", e))?); } Ok(bindings) } /// 记录/更新令牌使用情况 (同时处理 user_tokens 和 token_ip_bindings) pub fn record_token_usage_and_ip( token_id: &str, ip: &str, model: &str, input_tokens: i32, output_tokens: i32, status: u16, user_agent: Option ) -> Result<(), String> { let mut conn = connect_db()?; let tx = conn.transaction().map_err(|e| format!("Failed to create transaction: {}", e))?; let now = Utc::now().timestamp(); // 1. 更新 user_tokens 主表 tx.execute( "UPDATE user_tokens SET last_used_at = ?1, total_requests = total_requests + 1, total_tokens_used = total_tokens_used + ?2 WHERE id = ?3", params![now, input_tokens + output_tokens, token_id], ).map_err(|e| format!("Failed to update user_tokens stats: {}", e))?; // 2. 更新或插入 token_ip_bindings 表 let binding_exists: bool = tx.query_row( "SELECT EXISTS(SELECT 1 FROM token_ip_bindings WHERE token_id = ?1 AND ip_address = ?2)", params![token_id, ip], |row| row.get(0), ).unwrap_or(false); if binding_exists { tx.execute( "UPDATE token_ip_bindings SET last_seen_at = ?1, request_count = request_count + 1, user_agent = COALESCE(?2, user_agent) WHERE token_id = ?3 AND ip_address = ?4", params![now, user_agent, token_id, ip], ).map_err(|e| format!("Failed to update ip binding: {}", e))?; } else { let binding_id = Uuid::new_v4().to_string(); tx.execute( "INSERT INTO token_ip_bindings ( id, token_id, ip_address, first_seen_at, last_seen_at, request_count, user_agent ) VALUES (?1, ?2, ?3, ?4, ?5, 1, ?6)", params![binding_id, token_id, ip, now, now, user_agent], ).map_err(|e| format!("Failed to insert ip binding: {}", e))?; } // 3. 插入 token_usage_logs 表 let log_id = Uuid::new_v4().to_string(); tx.execute( "INSERT INTO token_usage_logs ( id, token_id, ip_address, model, input_tokens, output_tokens, request_time, status ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", params![ log_id, token_id, ip, model, input_tokens, output_tokens, now, status ], ).map_err(|e| format!("Failed to insert usage log: {}", e))?; tx.commit().map_err(|e| format!("Failed to commit transaction: {}", e))?; Ok(()) } /// 检查 Token 是否有效 (包含过期时间检查和 IP 限制检查) /// 返回: (是否有效, 拒绝原因) pub fn validate_token(token_str: &str, ip: &str) -> Result<(bool, Option), String> { let token_opt = get_token_by_value(token_str)?; if let Some(token) = token_opt { // 1. 检查过期时间 if let Some(expires_at) = token.expires_at { if expires_at < Utc::now().timestamp() { return Ok((false, Some("Your token has expired. Please contact the administrator to renew it.".to_string()))); } } // 2. 检查 IP 限制 if token.max_ips > 0 { let conn = connect_db()?; // 检查当前 IP 是否已绑定 let is_bound: bool = conn.query_row( "SELECT EXISTS(SELECT 1 FROM token_ip_bindings WHERE token_id = ?1 AND ip_address = ?2)", params![token.id, ip], |row| row.get(0) ).unwrap_or(false); if !is_bound { // 如果未绑定,检查是否达到上限 let current_ip_count: i32 = conn.query_row( "SELECT COUNT(*) FROM token_ip_bindings WHERE token_id = ?1", params![token.id], |row| row.get(0) ).unwrap_or(0); if current_ip_count >= token.max_ips { return Ok((false, Some(format!("IP limit reached ({}/{}). Please contact the administrator to increase the limit.", current_ip_count, token.max_ips)))); } } } // 3. 检查宵禁时间 (Curfew) // 逻辑:如果当前北京时间在 start 和 end 之间,则拒绝 // 格式:HH:MM // 使用固定 UTC+8 (北京时间),不依赖服务器本地时区 if let (Some(start_str), Some(end_str)) = (&token.curfew_start, &token.curfew_end) { if !start_str.is_empty() && !end_str.is_empty() { let beijing_offset = FixedOffset::east_opt(8 * 3600).unwrap(); let now_beijing = Utc::now().with_timezone(&beijing_offset); let current_time_str = format!("{:02}:{:02}", now_beijing.hour(), now_beijing.minute()); // 跨午夜处理: start > end (e.g. 23:00 to 06:00) // 正常: start < end (e.g. 09:00 to 18:00) let is_curfew = if start_str > end_str { current_time_str >= *start_str || current_time_str < *end_str } else { current_time_str >= *start_str && current_time_str < *end_str }; if is_curfew { return Ok((false, Some(format!("Service is not available between {} and {} Beijing Time (Curfew enabled). Current Beijing time: {}", start_str, end_str, current_time_str)))); } } } // 一切正常,Token 有效 Ok((true, None)) } else { Ok((false, Some("Invalid token. Please check your API key.".to_string()))) } } /// 获取 IP 关联的用户名 (用于 IP 管理页面) /// 返回最近一次使用该 IP 的 Token 所属的用户名 pub fn get_username_for_ip(ip: &str) -> Result, String> { let conn = connect_db()?; let result: Option = conn.query_row( "SELECT t.username FROM token_ip_bindings b JOIN user_tokens t ON b.token_id = t.id WHERE b.ip_address = ?1 ORDER BY b.last_seen_at DESC LIMIT 1", params![ip], |row| row.get(0), ).optional().map_err(|e| format!("Failed to query username by ip: {}", e))?; Ok(result) } #[cfg(test)] mod tests { use super::*; #[test] fn test_create_and_query_token() { let _ = init_db(); // Ensure DB is initialized // Use a random username to avoid collisions in existing DB runs during dev let username = format!("TestUser_{}", Uuid::new_v4()); let token_res = create_token(username.clone(), "day".to_string(), Some("Test token".to_string()), 0, None, None, None); assert!(token_res.is_ok()); let token = token_res.unwrap(); assert_eq!(token.username, username); assert!(token.token.starts_with("sk-")); let fetched = get_token_by_id(&token.id); assert!(fetched.is_ok()); assert_eq!(fetched.unwrap().unwrap().username, username); } }