# VersionRAG / graph_manager.py
# shahbazdev0's picture
# Update graph_manager.py
# 640ec41 verified
# graph_manager.py - Version Graph Management (FIXED CHANGE DETECTION)
import difflib
import json
import re
from datetime import datetime
from typing import List, Dict, Optional, Set

import networkx as nx
class GraphManager:
    """Manage a version graph of documents, their versions and changes.

    The graph is a ``networkx.DiGraph`` whose nodes are told apart by their
    ``node_type`` attribute:

    * ``'document'`` -- keyed by the document name;
    * ``'version'``  -- keyed by ``"<document>:<version>"``;
    * ``'changes'``  -- keyed by ``"<document>:<version>:changes"``.

    Edges carry an ``edge_type`` attribute: ``'has_version'``
    (document -> version), ``'next_version'`` (version -> following version)
    and ``'has_changes'`` (version -> changes).
    """

    def __init__(self, user_id: str):
        """Create an empty graph associated with ``user_id``."""
        self.user_id = user_id
        self.graph = nx.DiGraph()
        # document_name -> list of version strings in natural version order
        self.document_versions = {}
        # (document_name, version) -> text content stored for that version
        self.version_content = {}

    @staticmethod
    def _version_sort_key(version) -> List:
        """Return a sort key that orders version strings naturally.

        Digit runs compare as integers and everything else as text, so
        ``"1.9"`` sorts before ``"1.10"`` (a plain string sort gets this
        wrong). Every key element is a ``(kind, number, text)`` tuple, which
        keeps numeric and textual parts comparable with each other.
        """
        key = []
        # Splitting on a capturing group alternates text, digits, text, ...
        # so odd positions are always the digit runs.
        for position, token in enumerate(re.split(r'(\d+)', str(version))):
            if not token:
                continue
            if position % 2:
                key.append((0, int(token), ''))
            else:
                key.append((1, 0, token))
        return key

    def _link_neighbours(self, document_name: str, version: str) -> None:
        """Wire ``next_version`` edges around ``version`` in the ordered list."""
        versions = self.document_versions[document_name]
        position = versions.index(version)
        node = f"{document_name}:{version}"

        # Only positions > 0 have a predecessor; index -1 would wrap around
        # to the *latest* version.
        prev_node = None
        if position > 0:
            prev_node = f"{document_name}:{versions[position - 1]}"
        next_node = None
        if position + 1 < len(versions):
            next_node = f"{document_name}:{versions[position + 1]}"

        # A version inserted between two existing ones replaces their
        # direct link.
        if (prev_node is not None and next_node is not None
                and self.graph.has_edge(prev_node, next_node)):
            edge_data = self.graph.edges[prev_node, next_node]
            if edge_data.get('edge_type') == 'next_version':
                self.graph.remove_edge(prev_node, next_node)

        if prev_node is not None:
            self.graph.add_edge(prev_node, node, edge_type='next_version')
        if next_node is not None:
            self.graph.add_edge(node, next_node, edge_type='next_version')

    def add_document_version(self, document_name: str, version: str,
                             content: str, metadata: Dict = None):
        """Add (or refresh) a version of a document in the graph.

        Creates the document node on first use, creates/updates the version
        node, stores ``content`` for later diffing, and links the version to
        its neighbours in natural version order.

        Args:
            document_name: Name of the document; also its node key.
            version: Version label, e.g. ``"1.10"``.
            content: Text of this version.
            metadata: Optional attributes stored on the created nodes. Each
                node receives its own copy.
        """
        # Create document node if it doesn't exist
        if document_name not in self.graph:
            self.graph.add_node(document_name, node_type='document',
                                metadata=dict(metadata) if metadata else {})
        # setdefault also covers a document node that exists without a
        # version list (e.g. imported data lacking 'document_versions').
        versions = self.document_versions.setdefault(document_name, [])

        # Create version node
        version_node = f"{document_name}:{version}"
        self.graph.add_node(
            version_node,
            node_type='version',
            version=version,
            document=document_name,
            timestamp=datetime.now().isoformat(),
            metadata=dict(metadata) if metadata else {}
        )

        # Link document to version
        self.graph.add_edge(document_name, version_node,
                            edge_type='has_version')

        # Store content
        self.version_content[(document_name, version)] = content

        # Add to version list, keeping natural order ("1.9" before "1.10");
        # the raw string breaks ties between equal keys deterministically.
        if version not in versions:
            versions.append(version)
            versions.sort(
                key=lambda item: (self._version_sort_key(item), str(item)))

        # Link to previous / next version if they exist
        self._link_neighbours(document_name, version)

    def add_version_with_changes(self, document_name: str, version: str,
                                 changes: Dict):
        """Attach an explicit change record to a version.

        Args:
            document_name: Name of the document.
            version: Version label the changes belong to.
            changes: Mapping that may contain ``'additions'``,
                ``'deletions'`` and ``'modifications'``; missing keys
                default to empty lists.

        Note:
            The version node is not validated here. networkx adds missing
            edge endpoints implicitly, so an unknown version ends up as a
            node without attributes.
        """
        version_node = f"{document_name}:{version}"

        # Create change node
        change_node = f"{version_node}:changes"
        self.graph.add_node(
            change_node,
            node_type='changes',
            additions=changes.get('additions', []),
            deletions=changes.get('deletions', []),
            modifications=changes.get('modifications', []),
            timestamp=datetime.now().isoformat()
        )

        # Link version to changes
        self.graph.add_edge(version_node, change_node,
                            edge_type='has_changes')

    def get_all_documents(self) -> List[str]:
        """Return the names of all nodes whose ``node_type`` is ``'document'``."""
        return [node for node, data in self.graph.nodes(data=True)
                if data.get('node_type') == 'document']

    def get_document_versions(self, document_name: str) -> List[str]:
        """Return the ordered versions of a document (empty if unknown).

        A copy is returned so callers cannot disturb the internal ordering.
        """
        return list(self.document_versions.get(document_name, []))

    def get_version_info(self, document_name: str, version: str) -> Dict:
        """Return the attribute dict of a version node, or ``{}`` if absent."""
        version_node = f"{document_name}:{version}"
        if version_node in self.graph:
            return self.graph.nodes[version_node]
        return {}

    def get_changes_between_versions(self, document_name: str,
                                     version1: str, version2: str,
                                     max_display: int = 100) -> Dict:
        """Compute the line-level changes from ``version1`` to ``version2``.

        Args:
            document_name: Name of the document.
            version1: Version label of the "before" content.
            version2: Version label of the "after" content.
            max_display: Maximum number of entries returned per list; the
                ``total_*`` counts always reflect the full diff.

        Returns:
            A dict with ``additions``, ``deletions`` and ``modifications``
            (each capped at ``max_display``), the matching ``total_*``
            counts, ``showing_limit`` and ``truncated``. If either version
            has no stored content, every list is empty, ``showing_limit``
            is 0 and ``truncated`` is False.

            ``modifications`` is always empty: the diff is line based, so a
            changed line is reported as one deletion plus one addition.
        """
        content1 = self.version_content.get((document_name, version1))
        content2 = self.version_content.get((document_name, version2))

        # Only a *missing* version short-circuits; an empty document is a
        # valid version whose diff is all additions or all deletions.
        if content1 is None or content2 is None:
            return {
                'additions': [],
                'deletions': [],
                'modifications': [],
                'total_additions': 0,
                'total_deletions': 0,
                'total_modifications': 0,
                'showing_limit': 0,
                'truncated': False
            }

        lines1 = content1.split('\n') if content1 else []
        lines2 = content2.split('\n') if content2 else []

        # Read the opcodes directly instead of parsing unified-diff text:
        # filtering on '+++' / '---' prefixes also discards real content
        # lines that start with "++" or "--".
        additions = []
        deletions = []
        modifications = []
        matcher = difflib.SequenceMatcher(None, lines1, lines2)
        for tag, start1, end1, start2, end2 in matcher.get_opcodes():
            if tag in ('replace', 'delete'):
                deletions.extend(lines1[start1:end1])
            if tag in ('replace', 'insert'):
                additions.extend(lines2[start2:end2])

        return {
            # Lists are capped for UI performance; totals are not.
            'additions': additions[:max_display],
            'deletions': deletions[:max_display],
            'modifications': modifications[:max_display],
            'total_additions': len(additions),
            'total_deletions': len(deletions),
            'total_modifications': len(modifications),
            'showing_limit': max_display,
            'truncated': (len(additions) > max_display
                          or len(deletions) > max_display
                          or len(modifications) > max_display)
        }

    def query_version_graph(self, query: str) -> List[Dict]:
        """Return version nodes whose attributes mention any query term.

        Matching is a case-insensitive substring test of each
        whitespace-separated term against ``str(data)`` of the node.

        Returns:
            A list of ``{'node': <node key>, 'data': <attribute dict>}``.
        """
        terms = [term.lower() for term in query.split()]
        if not terms:
            return []

        results = []
        for node, data in self.graph.nodes(data=True):
            if data.get('node_type') != 'version':
                continue
            # Simple keyword matching (can be enhanced with embeddings);
            # the haystack is built once per node, not once per term.
            haystack = str(data).lower()
            if any(term in haystack for term in terms):
                results.append({
                    'node': node,
                    'data': data
                })
        return results

    def export_graph(self) -> Dict:
        """Export the graph structure as plain containers.

        Returns:
            A dict with ``nodes`` (node -> attributes), ``edges`` (list of
            ``(source, target, attributes)``), ``document_versions`` and
            ``version_content`` (list of ``[document, version, content]``;
            a list because tuple keys do not survive JSON).
        """
        return {
            'nodes': dict(self.graph.nodes(data=True)),
            'edges': list(self.graph.edges(data=True)),
            'document_versions': self.document_versions,
            'version_content': [
                [document, version, content]
                for (document, version), content
                in self.version_content.items()
            ]
        }

    def import_graph(self, graph_data: Dict):
        """Replace the current graph with previously exported data.

        Args:
            graph_data: Mapping with ``nodes`` and ``edges`` (required) and
                optionally ``document_versions`` and ``version_content``,
                in the shape produced by :meth:`export_graph`. When
                ``version_content`` is absent the stored contents are left
                untouched.
        """
        # Build into a local graph so a malformed payload cannot leave
        # ``self.graph`` half populated.
        rebuilt = nx.DiGraph()
        for node, data in graph_data['nodes'].items():
            rebuilt.add_node(node, **data)
        for source, target, data in graph_data['edges']:
            rebuilt.add_edge(source, target, **data)
        self.graph = rebuilt

        # Copy the lists so later sorting does not mutate the caller's data.
        self.document_versions = {
            document: list(versions)
            for document, versions
            in graph_data.get('document_versions', {}).items()
        }

        if 'version_content' in graph_data:
            self.version_content = {
                (document, version): content
                for document, version, content
                in graph_data['version_content']
            }