File size: 6,375 Bytes
a2cbcac
d5ba3a3
 
 
a2cbcac
 
 
 
 
 
 
d5ba3a3
 
a2cbcac
 
d5ba3a3
 
 
 
 
a2cbcac
d5ba3a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2cbcac
d5ba3a3
 
 
a2cbcac
 
d5ba3a3
 
 
 
 
a2cbcac
 
d5ba3a3
a2cbcac
d5ba3a3
 
 
a2cbcac
 
d5ba3a3
a2cbcac
d5ba3a3
 
 
a2cbcac
 
 
 
 
 
 
 
 
 
d5ba3a3
 
a2cbcac
 
d5ba3a3
 
 
 
a2cbcac
 
d5ba3a3
a2cbcac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5ba3a3
a2cbcac
 
 
 
 
d5ba3a3
 
a2cbcac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
from typing import Dict, Any, Optional
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.types import RetryPolicy
from .state import AgentState, create_initial_state
from ..agents.data_collection_agent import data_collection_agent_node
from ..agents.technical_analysis_agent import technical_analysis_agent_node
from ..agents.news_intelligence_agent import news_intelligence_agent_node
from ..agents.portfolio_manager_agent import portfolio_manager_agent_node


def _log_partial(updates: dict, agent_name: str) -> None:
    """Print a short debug summary of one agent's partial-state update.

    Only the fields relevant to ``agent_name`` are printed; any other agent
    name produces just the header line (plus the error line, if present).

    Args:
        updates: Partial-state update dict produced by an agent node. ``None``
            is treated like an empty dict.
        agent_name: Display name of the agent ("Data Collection",
            "Technical Analysis", "News Intelligence" or "Portfolio Manager").
    """
    print(f"\n{agent_name} Agent Complete:")

    # Tolerate a node that produced no update at all.
    updates = updates or {}

    if agent_name == "Data Collection":
        data_results = updates.get('data_collection_results')
        if data_results:
            # `or {}` covers a key that is present but explicitly None,
            # which the `.get(key, {})` default does not.
            market_data = data_results.get('market_data') or {}
            print(f"Current Price: ${market_data.get('current_price', 'N/A')}")

    elif agent_name == "Technical Analysis":
        tech_results = updates.get('technical_analysis_results')
        if tech_results:
            print(f"Technical Success: {tech_results.get('success', False)}")

    elif agent_name == "News Intelligence":
        news_results = updates.get('news_intelligence_results')
        if news_results:
            print(f"News Success: {news_results.get('success', False)}")

    elif agent_name == "Portfolio Manager":
        # Same None-guard as above: a present-but-None value must not crash
        # the debug logger.
        portfolio_results = updates.get('portfolio_manager_results') or {}
        # Only the per-entry payloads are needed, not their keys.
        for sym_data in portfolio_results.values():
            if sym_data and sym_data.get('success'):
                print(f"Signal: {sym_data.get('trading_signal', 'N/A')} | "
                      f"Confidence: {sym_data.get('confidence_level', 'N/A')}")

    if updates.get('error'):
        print(f"Error: {updates['error']}")


async def debug_data_collection_node(state: AgentState) -> dict:
    """Await the data collection agent on ``state`` and log its update.

    The value returned by ``data_collection_agent_node`` is passed to
    ``_log_partial`` and then returned unchanged.
    """
    partial_update = await data_collection_agent_node(state)
    _log_partial(partial_update, "Data Collection")
    return partial_update


async def debug_technical_analysis_node(state: AgentState) -> dict:
    """Await the technical analysis agent on ``state`` and log its update.

    The value returned by ``technical_analysis_agent_node`` is passed to
    ``_log_partial`` and then returned unchanged.
    """
    partial_update = await technical_analysis_agent_node(state)
    _log_partial(partial_update, "Technical Analysis")
    return partial_update


async def debug_news_intelligence_node(state: AgentState) -> dict:
    """Await the news intelligence agent on ``state`` and log its update.

    The value returned by ``news_intelligence_agent_node`` is passed to
    ``_log_partial`` and then returned unchanged.
    """
    partial_update = await news_intelligence_agent_node(state)
    _log_partial(partial_update, "News Intelligence")
    return partial_update


async def debug_portfolio_manager_node(state: AgentState) -> dict:
    """Await the portfolio manager agent on ``state`` and log its update.

    The value returned by ``portfolio_manager_agent_node`` is passed to
    ``_log_partial`` and then returned unchanged.
    """
    partial_update = await portfolio_manager_agent_node(state)
    _log_partial(partial_update, "Portfolio Manager")
    return partial_update


def create_workflow() -> StateGraph:
    """
    Build the LangGraph workflow that chains the four debug agent nodes.

    Returns:
        StateGraph: Uncompiled graph wired as a fixed linear pipeline
        (START -> data_collection -> technical_analysis ->
        news_intelligence -> portfolio_manager -> END)
    """
    # One retry policy shared by every node.
    retry_policy = RetryPolicy(max_attempts=3, initial_interval=2.0)

    graph = StateGraph(AgentState)

    # Pipeline stages in execution order: (node name, node callable).
    stages = (
        ("data_collection", debug_data_collection_node),
        ("technical_analysis", debug_technical_analysis_node),
        ("news_intelligence", debug_news_intelligence_node),
        ("portfolio_manager", debug_portfolio_manager_node),
    )
    for stage_name, stage_node in stages:
        graph.add_node(stage_name, stage_node, retry=retry_policy)

    # Link each stop on the route to the next one, bracketed by START and END.
    route = [START, *(stage_name for stage_name, _ in stages), END]
    for source, target in zip(route, route[1:]):
        graph.add_edge(source, target)

    return graph


async def run_analysis(symbols: list[str], session_id: str = "default", analysis_date: Optional[str] = None) -> Dict[str, Any]:
    """
    Build, compile and execute the full agent workflow for the given symbols.

    Args:
        symbols: List of stock symbols to analyze
        session_id: Session identifier; also used as the checkpointer thread id
        analysis_date: Date for analysis in YYYY-MM-DD format (optional)

    Returns:
        Dict with analysis results. On normal completion 'success' is True and
        'results' holds each agent's output taken from the final state; if any
        exception is raised, 'success' is False and 'error' carries its message.
    """
    try:
        # Compile a fresh graph backed by an in-memory checkpointer.
        app = create_workflow().compile(checkpointer=InMemorySaver())

        # Seed the state, including the analysis date.
        initial_state = create_initial_state(session_id, symbols, analysis_date)

        # Execute the workflow on the session's thread.
        run_config = {"configurable": {"thread_id": session_id}, "recursion_limit": 30}
        final_state = await app.ainvoke(initial_state, run_config)

        # Each agent stores its output under "<agent>_results" in the state.
        agent_outputs = {
            agent: final_state.get(f"{agent}_results")
            for agent in (
                'data_collection',
                'technical_analysis',
                'news_intelligence',
                'portfolio_manager',
            )
        }
        final_step = final_state.get('current_step')
        workflow_error = final_state.get('error')

        return {
            'success': True,
            'session_id': session_id,
            'symbols': symbols,
            'analysis_date': analysis_date,
            'results': agent_outputs,
            'final_step': final_step,
            'error': workflow_error,
        }

    except Exception as exc:
        # Top-level boundary: report the failure to the caller as data.
        print(f"Workflow error: {exc}")
        return {
            'success': False,
            'error': str(exc),
            'symbols': symbols,
            'session_id': session_id,
            'analysis_date': analysis_date,
        }


def should_continue(state: AgentState) -> str:
    """
    Choose the next workflow node from the state's 'error' and 'current_step'.

    NOTE(review): create_workflow wires fixed edges and does not reference
    this router.

    Args:
        state: Current workflow state

    Returns:
        Name of the next node, or END when an error is set, the pipeline has
        finished, or the current step is not recognized
    """
    # Any recorded error stops the workflow immediately.
    if state.get('error'):
        return END

    # Completed step -> node that should run next.
    next_node_by_step = {
        'data_collection_complete': 'technical_analysis',
        'technical_analysis_complete': 'news_intelligence',
        'news_intelligence_complete': 'portfolio_manager',
        'portfolio_management_complete': END,
    }

    # Unknown or missing steps fall through to END.
    return next_node_by_step.get(state.get('current_step', ''), END)