Spaces:
Sleeping
Sleeping
| """ | |
| 安全模块 - 提供输入验证和安全检查功能 | |
| """ | |
| import os | |
| import re | |
| from typing import List, Tuple, Optional | |
class SecurityValidator:
    """Security checks for uploaded files and free-form text input.

    The validation entry points are class methods that return an
    ``(is_valid, error_message)`` tuple (``error_message`` is ``""`` on
    success); ``sanitize_filename`` returns a cleaned file name instead.
    """

    # File extensions that are rejected outright (executables and scripts).
    DANGEROUS_EXTENSIONS = {
        '.exe', '.bat', '.cmd', '.com', '.pif', '.scr', '.vbs', '.js', '.jar',
        '.php', '.asp', '.aspx', '.jsp', '.py', '.pl', '.sh', '.ps1'
    }

    # Maximum accepted file size in bytes (10 MB).
    MAX_FILE_SIZE = 10 * 1024 * 1024

    # Maximum accepted file name length, in characters.
    MAX_FILENAME_LENGTH = 255

    # Regex patterns (matched case-insensitively, with DOTALL) whose presence
    # marks text input as dangerous.
    # NOTE(review): this is a deny-list; it also rejects legitimate text such as
    # pasted source code containing "eval(" or "exec(" -- confirm that is intended.
    # NOTE(review): with DOTALL, '<script[^>]*>.*?</script>' rescans to the end of
    # the text for every unclosed <script> tag, so adversarial input near the
    # length limit can take quadratic time; consider a linear-time check.
    DANGEROUS_PATTERNS = [
        r'<script[^>]*>.*?</script>',
        r'javascript:',
        r'data:text/html',
        r'vbscript:',
        r'onload\s*=',
        r'onerror\s*=',
        r'eval\s*\(',
        r'exec\s*\(',
        r'system\s*\(',
    ]

    # Characters not allowed in file names: path separators, Windows-reserved
    # punctuation and ASCII control characters.
    _DANGEROUS_FILENAME_CHARS = re.compile(r'[<>:"/\\|?*\x00-\x1f]')

    # Device names reserved by Windows.
    _RESERVED_NAMES = frozenset({
        'CON', 'PRN', 'AUX', 'NUL',
        'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7', 'COM8', 'COM9',
        'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9'
    })

    # CJK ideographs, CJK symbols/punctuation, kana, Hangul and full-width forms.
    _CJK_RE = re.compile(
        r"[\u4e00-\u9fff\u3400-\u4dbf\u3000-\u303f\uff00-\uffef\u3040-\u30ff\uac00-\ud7af]"
    )

    # Common ASCII and CJK punctuation treated as "safe" characters.
    _SAFE_PUNCT_RE = re.compile(
        r"[\.,;:!\?\-_'\"(){}\[\]\\/@#%&\+=<>~\^\|$,。;:!?、()【】《》—…·]"
    )

    @staticmethod
    def _is_cjk(char: str) -> bool:
        """Return True if *char* starts with a common East Asian character.

        Covers Chinese, Japanese and Korean scripts plus full-width symbols.
        """
        return SecurityValidator._CJK_RE.match(char) is not None

    @staticmethod
    def _reserved_stem(filename: str) -> str:
        """Return the upper-cased part of *filename* that Windows checks for device names.

        Windows treats ``NUL.txt`` and ``NUL.tar.gz`` alike, so everything from
        the first dot on is ignored, as are trailing spaces before it.
        """
        return filename.split('.', 1)[0].rstrip(' ').upper()

    @classmethod
    def validate_file_upload(cls, file_path: str) -> Tuple[bool, str]:
        """Check that an uploaded file is safe to accept.

        Args:
            file_path: Path of the uploaded file on disk.

        Returns:
            ``(is_valid, error_message)``; ``error_message`` is ``""`` when valid.
        """
        try:
            if not os.path.exists(file_path):
                return False, "文件不存在"
            # Directories and other non-regular entries pass os.path.exists
            # but are not uploadable files.
            if not os.path.isfile(file_path):
                return False, "不是有效的文件"

            file_size = os.path.getsize(file_path)
            if file_size > cls.MAX_FILE_SIZE:
                return False, f"文件过大 ({file_size} bytes, 最大: {cls.MAX_FILE_SIZE})"

            filename = os.path.basename(file_path)
            if len(filename) > cls.MAX_FILENAME_LENGTH:
                return False, f"文件名过长 ({len(filename)} 字符, 最大: {cls.MAX_FILENAME_LENGTH})"

            # Windows silently drops trailing dots/spaces ("evil.exe." is stored
            # as "evil.exe"), so strip them before extracting the extension.
            _, ext = os.path.splitext(filename.lower().rstrip('. '))
            if ext in cls.DANGEROUS_EXTENSIONS:
                return False, f"不支持的文件类型: {ext}"

            if not cls._is_safe_filename(filename):
                return False, "文件名包含危险字符"

            return True, ""
        except Exception as e:
            # Deliberately broad: any unexpected failure rejects the file (fail closed).
            return False, f"文件验证失败: {str(e)}"

    @classmethod
    def _is_safe_char(cls, c: str) -> bool:
        """Return True for letters/digits, whitespace, common punctuation and CJK characters."""
        if not c:
            return True
        if c.isalnum() or c.isspace():
            return True
        if cls._SAFE_PUNCT_RE.match(c):
            return True
        return cls._is_cjk(c)

    @classmethod
    def validate_text_input(
        cls,
        text: str,
        max_length: int = 10000,
        max_unsafe_ratio: float = 0.5,
    ) -> Tuple[bool, str]:
        """Check free-form text for excessive length, dangerous patterns and odd characters.

        Args:
            text: Input text; empty or falsy input is accepted.
            max_length: Maximum number of characters allowed.
            max_unsafe_ratio: Reject the text when the fraction of characters
                that are not letters/digits, whitespace, common punctuation or
                CJK characters exceeds this value.

        Returns:
            ``(is_valid, error_message)``; ``error_message`` is ``""`` when valid.
        """
        if not text:
            return True, ""

        if len(text) > max_length:
            return False, f"输入过长 ({len(text)} 字符, 最大: {max_length})"

        for pattern in cls.DANGEROUS_PATTERNS:
            if re.search(pattern, text, re.IGNORECASE | re.DOTALL):
                return False, "输入包含潜在危险内容"

        # text is non-empty here, so the division is safe.
        is_safe = cls._is_safe_char
        unsafe_count = sum(1 for ch in text if not is_safe(ch))
        if unsafe_count / len(text) > max_unsafe_ratio:
            return False, "输入包含大量非常规字符"

        return True, ""

    @classmethod
    def sanitize_filename(cls, filename: str) -> str:
        """Return *filename* with dangerous characters replaced and its length capped.

        Dangerous characters become ``_``; Windows device names get a ``_``
        prefix; names longer than ``MAX_FILENAME_LENGTH`` are truncated,
        keeping the extension when it fits within the limit.

        Args:
            filename: Original file name.

        Returns:
            The sanitized file name.
        """
        sanitized = cls._DANGEROUS_FILENAME_CHARS.sub('_', filename)

        # Keep the result acceptable to _is_safe_filename.
        if cls._reserved_stem(sanitized) in cls._RESERVED_NAMES:
            sanitized = '_' + sanitized

        if len(sanitized) > cls.MAX_FILENAME_LENGTH:
            name, ext = os.path.splitext(sanitized)
            if len(ext) > cls.MAX_FILENAME_LENGTH:
                # The "extension" alone exceeds the limit; a plain cut is the
                # only way to honour it (a negative slice would not).
                sanitized = sanitized[:cls.MAX_FILENAME_LENGTH]
            else:
                sanitized = name[:cls.MAX_FILENAME_LENGTH - len(ext)] + ext

        return sanitized

    @classmethod
    def _is_safe_filename(cls, filename: str) -> bool:
        """Return True if *filename* has no dangerous characters and is not a Windows device name.

        Args:
            filename: File name (no directory part).

        Returns:
            Whether the name is safe.
        """
        if cls._DANGEROUS_FILENAME_CHARS.search(filename):
            return False
        return cls._reserved_stem(filename) not in cls._RESERVED_NAMES
class InputValidator:
    """Validators for user-supplied inputs: chat messages, custom prompts and file lists.

    Each validator returns an ``(is_valid, error_message)`` tuple and
    delegates the actual content checks to ``SecurityValidator``.
    """

    # Relaxed local length limit for messages and custom prompts.
    MAX_TEXT_LENGTH = 200000

    @staticmethod
    def validate_message(message: str) -> Tuple[bool, str]:
        """Validate a user message; empty or whitespace-only messages are rejected.

        Args:
            message: The user's message.

        Returns:
            ``(is_valid, error_message)``.
        """
        if not message or not message.strip():
            return False, "消息不能为空"
        return SecurityValidator.validate_text_input(
            message, max_length=InputValidator.MAX_TEXT_LENGTH
        )

    @staticmethod
    def validate_custom_prompt(prompt: str) -> Tuple[bool, str]:
        """Validate a custom prompt; an empty prompt is accepted.

        Args:
            prompt: The custom prompt text.

        Returns:
            ``(is_valid, error_message)``.
        """
        if not prompt:
            return True, ""
        return SecurityValidator.validate_text_input(
            prompt, max_length=InputValidator.MAX_TEXT_LENGTH
        )

    @staticmethod
    def _file_path_of(file_obj) -> str:
        """Return the filesystem path represented by one entry of a file list."""
        # pathlib.Path has a .name attribute that is only the basename, so
        # path-like objects (and bytes paths) must be converted directly.
        if isinstance(file_obj, (bytes, os.PathLike)):
            return os.fsdecode(file_obj)
        # Upload wrappers expose the full path via .name; plain strings are paths.
        return getattr(file_obj, "name", None) or str(file_obj)

    @staticmethod
    def validate_file_list(files: List) -> Tuple[bool, str]:
        """Validate every file in *files*, stopping at the first failure.

        Args:
            files: A list/tuple of files, or a single file. Entries may be
                path strings, path-like objects, or objects exposing ``.name``.

        Returns:
            ``(is_valid, error_message)`` of the first invalid file, or
            ``(True, "")`` when all files (or none) are given.
        """
        if not files:
            return True, ""

        file_list = files if isinstance(files, (list, tuple)) else [files]
        for file_obj in file_list:
            file_path = InputValidator._file_path_of(file_obj)
            is_valid, error_msg = SecurityValidator.validate_file_upload(file_path)
            if not is_valid:
                return False, error_msg

        return True, ""
# Global validator instances.
security_validator, input_validator = SecurityValidator(), InputValidator()