Spaces:
Runtime error
Runtime error
| """ | |
| Message validation and sanitization for WebSocket communications. | |
| This module provides validation and sanitization for incoming WebSocket messages | |
| to ensure security and data integrity. | |
| """ | |
| import re | |
| import html | |
| import logging | |
| from typing import Dict, Any, List, Optional | |
| from datetime import datetime, timedelta | |
| logger = logging.getLogger(__name__) | |
class MessageValidator:
    """Validates and sanitizes WebSocket messages for security.

    Every ``validate_*`` method returns a plain dict with at least the keys
    ``valid`` (bool) and ``errors`` (list of human-readable strings).
    Rate limiting is tracked in memory on the instance, keyed by the
    ``session_id`` string found in the incoming payload.
    """

    # Maximum message length (characters)
    MAX_MESSAGE_LENGTH = 10000

    # Maximum accepted calls per calendar minute for one session ID
    MAX_MESSAGES_PER_MINUTE = 30

    # Supported programming languages
    SUPPORTED_LANGUAGES = {
        'python', 'javascript', 'java', 'cpp', 'c', 'csharp', 'go',
        'rust', 'typescript', 'php', 'ruby', 'swift', 'kotlin', 'scala'
    }

    # Patterns for potentially malicious content
    MALICIOUS_PATTERNS = [
        r'<script[^>]*>.*?</script>',  # Script tags
        r'javascript:',                # JavaScript URLs
        # Event handlers. The leading \b keeps ordinary identifiers that merely
        # contain "on" (``response =``, ``connection =``) from being flagged.
        r'\bon\w+\s*=',
        r'<iframe[^>]*>.*?</iframe>',  # Iframes
        r'<object[^>]*>.*?</object>',  # Objects
        r'<embed[^>]*>.*?</embed>',    # Embeds
    ]

    def __init__(self):
        """Initialize the message validator."""
        # session_id -> {"YYYY-MM-DD-HH-MM" minute key -> calls counted}
        self.rate_limit_tracker: Dict[str, Dict[str, int]] = {}
        # Compile once; IGNORECASE/DOTALL so tags match across case and lines.
        self.compiled_patterns = [
            re.compile(pattern, re.IGNORECASE | re.DOTALL)
            for pattern in self.MALICIOUS_PATTERNS
        ]

    def validate_message(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Validate and sanitize a chat message.

        Checks run in stages (shape, field types, length/emptiness/rate limit,
        malicious patterns); the first stage that produces errors ends
        validation and all errors of that stage are reported together.

        Args:
            data: Message data from WebSocket; must be a dict holding string
                values under ``content`` and ``session_id``

        Returns:
            Dict with ``valid``, ``errors`` and ``sanitized_content``
            (``None`` whenever ``valid`` is False)
        """
        if not isinstance(data, dict):
            return {
                'valid': False,
                'errors': ['Message data must be a dictionary'],
                'sanitized_content': None
            }

        errors: List[str] = []

        # Required fields
        if 'content' not in data:
            errors.append('Message content is required')
        if 'session_id' not in data:
            errors.append('Session ID is required')
        if errors:
            return {'valid': False, 'errors': errors, 'sanitized_content': None}

        content = data['content']
        session_id = data['session_id']

        # Field types
        if not isinstance(content, str):
            errors.append('Message content must be a string')
        if not isinstance(session_id, str):
            errors.append('Session ID must be a string')
        if errors:
            return {'valid': False, 'errors': errors, 'sanitized_content': None}

        # Length, emptiness and rate limit are reported together.
        if len(content) > self.MAX_MESSAGE_LENGTH:
            errors.append(f'Message too long (max {self.MAX_MESSAGE_LENGTH} characters)')
        if not content.strip():
            errors.append('Message content cannot be empty')

        # The rate limiter is consulted even when the checks above failed, so
        # oversized or empty messages still count towards the session's quota.
        rate_limit_error = self._check_rate_limit(session_id)
        if rate_limit_error:
            errors.append(rate_limit_error)
        if errors:
            return {'valid': False, 'errors': errors, 'sanitized_content': None}

        # Pattern check runs on the raw text, before HTML escaping hides tags.
        malicious_patterns = self._check_malicious_patterns(content)
        if malicious_patterns:
            errors.extend(malicious_patterns)
            return {'valid': False, 'errors': errors, 'sanitized_content': None}

        return {
            'valid': True,
            'errors': [],
            'sanitized_content': self._sanitize_content(content)
        }

    def validate_language_switch(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Validate a language switch request.

        Args:
            data: Language switch data from WebSocket; must be a dict holding
                string values under ``language`` and ``session_id``

        Returns:
            Dict with ``valid``, ``errors`` and ``language`` (the lower-cased,
            stripped language name on success, otherwise ``None``)
        """
        if not isinstance(data, dict):
            return {
                'valid': False,
                'errors': ['Language switch data must be a dictionary'],
                'language': None
            }

        errors: List[str] = []

        # Required fields
        if 'language' not in data:
            errors.append('Language is required')
        if 'session_id' not in data:
            errors.append('Session ID is required')
        if errors:
            return {'valid': False, 'errors': errors, 'language': None}

        language = data['language']
        session_id = data['session_id']

        # Field types
        if not isinstance(language, str):
            errors.append('Language must be a string')
        if not isinstance(session_id, str):
            errors.append('Session ID must be a string')
        if errors:
            return {'valid': False, 'errors': errors, 'language': None}

        normalized_language = language.lower().strip()

        if normalized_language not in self.SUPPORTED_LANGUAGES:
            # The rejected value is caller-controlled; escape it so the error
            # text cannot carry markup back to whoever renders it.
            errors.append(
                f'Unsupported language: {html.escape(language)}. '
                f'Supported languages: {", ".join(sorted(self.SUPPORTED_LANGUAGES))}'
            )

        # Language switches draw from the same per-session quota as messages.
        rate_limit_error = self._check_rate_limit(session_id)
        if rate_limit_error:
            errors.append(rate_limit_error)
        if errors:
            return {'valid': False, 'errors': errors, 'language': None}

        return {'valid': True, 'errors': [], 'language': normalized_language}

    def validate_typing_event(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Validate a typing event.

        Typing events are not rate limited and need no fields; a
        ``session_id`` is only type-checked when it is present.

        Args:
            data: Typing event data from WebSocket

        Returns:
            Dict with ``valid`` and ``errors``
        """
        if not isinstance(data, dict):
            return {
                'valid': False,
                'errors': ['Typing event data must be a dictionary']
            }

        if 'session_id' in data and not isinstance(data['session_id'], str):
            return {'valid': False, 'errors': ['Session ID must be a string']}

        return {'valid': True, 'errors': []}

    def _sanitize_content(self, content: str) -> str:
        """
        Sanitize message content to prevent XSS and other attacks.

        HTML-escapes the text, drops NUL characters, strips trailing
        whitespace from each line (leading indentation is kept so code stays
        readable) and caps runs of blank lines at three.

        Args:
            content: Raw message content

        Returns:
            str: Sanitized content
        """
        sanitized = html.escape(content).replace('\x00', '')

        result_lines = []
        empty_count = 0
        for raw_line in sanitized.split('\n'):
            line = raw_line.rstrip()
            if line:
                empty_count = 0
                result_lines.append(line)
                continue
            # Blank line: keep at most three in a row.
            empty_count += 1
            if empty_count <= 3:
                result_lines.append(line)
        return '\n'.join(result_lines)

    def _check_malicious_patterns(self, content: str) -> List[str]:
        """
        Check for potentially malicious patterns in content.

        Args:
            content: Content to check

        Returns:
            List[str]: A single generic error if any pattern matches, else an
            empty list. The matching pattern is deliberately not disclosed.
        """
        if any(pattern.search(content) for pattern in self.compiled_patterns):
            return ['Potentially malicious content detected']
        return []

    def _check_rate_limit(self, session_id: str) -> Optional[str]:
        """
        Check if the session is exceeding rate limits.

        Uses a fixed window per UTC calendar minute: each accepted call
        increments the counter for the current minute; once the counter
        reaches ``MAX_MESSAGES_PER_MINUTE`` further calls in that minute are
        rejected without being counted.

        NOTE(review): the limit is keyed on the ``session_id`` passed in by
        the validate methods, which take it from the message payload. If
        clients can pick that value freely, the limit can be sidestepped and
        the tracker can grow without bound. Confirm it is bound to the
        connection upstream.

        Args:
            session_id: Session identifier

        Returns:
            Optional[str]: Error message if rate limit exceeded, None otherwise
        """
        now = datetime.utcnow()
        current_minute = now.strftime('%Y-%m-%d-%H-%M')
        previous_minute = (
            now.replace(second=0, microsecond=0) - timedelta(minutes=1)
        ).strftime('%Y-%m-%d-%H-%M')

        session_tracker = self.rate_limit_tracker.setdefault(session_id, {})

        # Drop stale buckets in place (only current and previous minute stay).
        keys_to_keep = {current_minute, previous_minute}
        for key in [k for k in session_tracker if k not in keys_to_keep]:
            del session_tracker[key]

        current_count = session_tracker.get(current_minute, 0)
        if current_count >= self.MAX_MESSAGES_PER_MINUTE:
            return f'Rate limit exceeded. Maximum {self.MAX_MESSAGES_PER_MINUTE} messages per minute.'

        session_tracker[current_minute] = current_count + 1
        return None

    def get_supported_languages(self) -> List[str]:
        """
        Get list of supported programming languages.

        Returns:
            List[str]: Sorted list of supported languages
        """
        return sorted(self.SUPPORTED_LANGUAGES)

    def cleanup_rate_limit_tracker(self) -> None:
        """Clean up old rate limit tracking data.

        Removes minute buckets older than five minutes and then forgets any
        session whose tracker has become empty.
        """
        cutoff_time = datetime.utcnow() - timedelta(minutes=5)
        # Minute keys are zero-padded, so string order equals time order.
        cutoff_key = cutoff_time.strftime('%Y-%m-%d-%H-%M')

        sessions_to_clean = []
        for session_id, session_tracker in self.rate_limit_tracker.items():
            for key in [k for k in session_tracker if k < cutoff_key]:
                del session_tracker[key]
            if not session_tracker:
                sessions_to_clean.append(session_id)

        # Deleted after the loop so the dict is not mutated while iterating.
        for session_id in sessions_to_clean:
            del self.rate_limit_tracker[session_id]

        logger.debug(
            "Cleaned up rate limit tracker, removed %d empty sessions",
            len(sessions_to_clean),
        )
def create_message_validator() -> MessageValidator:
    """
    Build a new MessageValidator.

    Takes no arguments and constructs the validator with its defaults.

    Returns:
        MessageValidator: A freshly constructed validator instance
    """
    validator = MessageValidator()
    return validator