LangGraph 图结构设计
概述
TradingAgents 基于 LangGraph 框架构建,采用有向无环图(DAG)结构来组织智能体的工作流程。这种设计确保了智能体之间的有序协作和信息的正确流转。
图结构架构
整体工作流图
graph TD
START([开始]) --> INIT[初始化状态]
INIT --> PARALLEL_ANALYSIS{并行分析阶段}
PARALLEL_ANALYSIS --> FA[基本面分析师]
PARALLEL_ANALYSIS --> TA[技术分析师]
PARALLEL_ANALYSIS --> NA[新闻分析师]
PARALLEL_ANALYSIS --> SA[社交媒体分析师]
FA --> COLLECT_ANALYSIS[收集分析结果]
TA --> COLLECT_ANALYSIS
NA --> COLLECT_ANALYSIS
SA --> COLLECT_ANALYSIS
COLLECT_ANALYSIS --> RESEARCH_MANAGER[研究经理]
RESEARCH_MANAGER --> DEBATE_INIT[初始化辩论]
DEBATE_INIT --> BULL_RESEARCHER[看涨研究员]
DEBATE_INIT --> BEAR_RESEARCHER[看跌研究员]
BULL_RESEARCHER --> DEBATE_ROUND{辩论轮次}
BEAR_RESEARCHER --> DEBATE_ROUND
DEBATE_ROUND -->|继续辩论| BULL_RESEARCHER
DEBATE_ROUND -->|辩论结束| RESEARCH_CONSENSUS[研究共识]
RESEARCH_CONSENSUS --> TRADER[交易员]
TRADER --> TRADING_DECISION[交易决策]
TRADING_DECISION --> RISK_PARALLEL{并行风险评估}
RISK_PARALLEL --> AGGRESSIVE_RISK[激进风险评估]
RISK_PARALLEL --> CONSERVATIVE_RISK[保守风险评估]
RISK_PARALLEL --> NEUTRAL_RISK[中性风险评估]
AGGRESSIVE_RISK --> RISK_DEBATE[风险辩论]
CONSERVATIVE_RISK --> RISK_DEBATE
NEUTRAL_RISK --> RISK_DEBATE
RISK_DEBATE --> RISK_MANAGER[风险经理]
RISK_MANAGER --> PORTFOLIO_DECISION[投资组合决策]
PORTFOLIO_DECISION --> END([结束])
核心组件设计
1. 图构建器 (GraphSetup)
class GraphSetup:
"""LangGraph 图结构设置"""
def build_graph(self) -> StateGraph:
"""构建完整的交易决策图"""
# 创建状态图
workflow = StateGraph(AgentState)
# 添加节点
self._add_analysis_nodes(workflow)
self._add_research_nodes(workflow)
self._add_trading_nodes(workflow)
self._add_risk_nodes(workflow)
# 添加边和条件逻辑
self._add_edges(workflow)
self._add_conditional_edges(workflow)
# 设置入口和出口
workflow.set_entry_point("initialize")
workflow.set_finish_point("portfolio_decision")
return workflow.compile()
def _add_analysis_nodes(self, workflow: StateGraph):
"""添加分析师节点"""
workflow.add_node("fundamentals_analyst", self.fundamentals_analyst)
workflow.add_node("technical_analyst", self.technical_analyst)
workflow.add_node("news_analyst", self.news_analyst)
workflow.add_node("social_analyst", self.social_analyst)
def _add_research_nodes(self, workflow: StateGraph):
"""添加研究员节点"""
workflow.add_node("research_manager", self.research_manager)
workflow.add_node("bull_researcher", self.bull_researcher)
workflow.add_node("bear_researcher", self.bear_researcher)
def _add_trading_nodes(self, workflow: StateGraph):
"""添加交易节点"""
workflow.add_node("trader", self.trader)
def _add_risk_nodes(self, workflow: StateGraph):
"""添加风险管理节点"""
workflow.add_node("aggressive_risk", self.aggressive_risk)
workflow.add_node("conservative_risk", self.conservative_risk)
workflow.add_node("neutral_risk", self.neutral_risk)
workflow.add_node("risk_manager", self.risk_manager)
2. 条件逻辑 (ConditionalLogic)
class ConditionalLogic:
"""图的条件逻辑控制"""
def should_continue_debate(self, state: AgentState) -> str:
"""判断是否继续研究员辩论"""
current_round = state.get("debate_round", 0)
max_rounds = self.config.get("max_debate_rounds", 3)
# 检查辩论轮次
if current_round >= max_rounds:
return "end_debate"
# 检查是否达成共识
if self._has_consensus(state):
return "end_debate"
# 检查分歧是否足够大
if self._has_significant_disagreement(state):
return "continue_debate"
return "end_debate"
def route_to_risk_assessment(self, state: AgentState) -> List[str]:
"""路由到风险评估节点"""
trading_decision = state.get("trader_decision", {})
risk_level = trading_decision.get("risk_level", "medium")
# 根据风险级别决定评估路径
if risk_level == "high":
return ["aggressive_risk", "conservative_risk", "neutral_risk"]
elif risk_level == "low":
return ["conservative_risk", "neutral_risk"]
else:
return ["neutral_risk"]
def should_approve_trade(self, state: AgentState) -> str:
"""判断是否批准交易"""
risk_assessment = state.get("risk_assessment", {})
risk_score = risk_assessment.get("overall_risk_score", 0.5)
# 风险阈值检查
if risk_score > self.config.get("risk_threshold", 0.8):
return "reject_trade"
# 一致性检查
if self._risk_assessments_consistent(state):
return "approve_trade"
return "request_review"
3. 状态传播 (Propagator)
class Propagator:
"""状态传播管理器"""
def propagate(self, symbol: str, date: str) -> Tuple[AgentState, Dict]:
"""执行完整的传播流程"""
# 初始化状态
initial_state = self._initialize_state(symbol, date)
# 执行图传播
final_state = self.graph.invoke(initial_state)
# 提取决策结果
decision = self._extract_decision(final_state)
return final_state, decision
def _initialize_state(self, symbol: str, date: str) -> AgentState:
"""初始化智能体状态"""
return AgentState(
ticker=symbol,
date=date,
analyst_reports={},
research_reports={},
trader_decision={},
risk_assessment={},
portfolio_decision={},
messages=[],
metadata={}
)
def _extract_decision(self, state: AgentState) -> Dict:
"""从最终状态提取决策信息"""
return {
"action": state.portfolio_decision.get("action", "hold"),
"quantity": state.portfolio_decision.get("quantity", 0),
"confidence": state.portfolio_decision.get("confidence", 0.5),
"reasoning": state.portfolio_decision.get("reasoning", ""),
"risk_score": state.risk_assessment.get("overall_risk_score", 0.5)
}
节点类型详解
1. 分析节点 (Analysis Nodes)
def fundamentals_analyst_node(state: AgentState) -> AgentState:
"""基本面分析师节点"""
# 获取数据
data = get_fundamental_data(state.ticker, state.date)
# 执行分析
analysis = fundamentals_analyst.analyze(data)
# 更新状态
state.analyst_reports["fundamentals"] = analysis
return state
2. 决策节点 (Decision Nodes)
def trader_node(state: AgentState) -> AgentState:
"""交易员决策节点"""
# 收集所有分析报告
all_reports = {
**state.analyst_reports,
**state.research_reports
}
# 制定交易决策
decision = trader.make_decision(all_reports)
# 更新状态
state.trader_decision = decision
return state
3. 并行节点 (Parallel Nodes)
def parallel_analysis_node(state: AgentState) -> AgentState:
"""并行分析节点"""
# 并行执行多个分析师
with ThreadPoolExecutor() as executor:
futures = {
executor.submit(fundamentals_analyst.analyze, state): "fundamentals",
executor.submit(technical_analyst.analyze, state): "technical",
executor.submit(news_analyst.analyze, state): "news",
executor.submit(social_analyst.analyze, state): "social"
}
# 收集结果
for future in as_completed(futures):
analyst_type = futures[future]
result = future.result()
state.analyst_reports[analyst_type] = result
return state
边和路由设计
1. 顺序边 (Sequential Edges)
# 简单的顺序连接
workflow.add_edge("initialize", "parallel_analysis")
workflow.add_edge("parallel_analysis", "research_manager")
workflow.add_edge("research_manager", "trader")
2. 条件边 (Conditional Edges)
# 基于条件的路由
workflow.add_conditional_edges(
"debate_round",
conditional_logic.should_continue_debate,
{
"continue_debate": "bull_researcher",
"end_debate": "research_consensus"
}
)
3. 并行边 (Parallel Edges)
# 并行执行多个节点
workflow.add_conditional_edges(
"trading_decision",
conditional_logic.route_to_risk_assessment,
{
"aggressive_risk": "aggressive_risk_node",
"conservative_risk": "conservative_risk_node",
"neutral_risk": "neutral_risk_node"
}
)
状态管理
1. 状态结构
@dataclass
class AgentState:
"""智能体状态数据结构"""
# 基本信息
ticker: str
date: str
# 分析结果
analyst_reports: Dict[str, Any]
research_reports: Dict[str, Any]
# 决策信息
trader_decision: Dict[str, Any]
risk_assessment: Dict[str, Any]
portfolio_decision: Dict[str, Any]
# 通信记录
messages: List[BaseMessage]
# 元数据
metadata: Dict[str, Any]
# 控制信息
debate_round: int = 0
risk_round: int = 0
2. 状态更新
class StateManager:
"""状态管理器"""
def update_state(self, state: AgentState, updates: Dict) -> AgentState:
"""安全地更新状态"""
# 深拷贝状态
new_state = copy.deepcopy(state)
# 应用更新
for key, value in updates.items():
if hasattr(new_state, key):
setattr(new_state, key, value)
# 验证状态一致性
self._validate_state(new_state)
return new_state
def _validate_state(self, state: AgentState):
"""验证状态一致性"""
# 检查必需字段
required_fields = ["ticker", "date"]
for field in required_fields:
if not getattr(state, field):
raise ValueError(f"Required field {field} is missing")
# 检查数据类型
if not isinstance(state.analyst_reports, dict):
raise TypeError("analyst_reports must be a dictionary")
错误处理和恢复
1. 节点级错误处理
def safe_node_execution(node_func):
"""节点执行的安全包装器"""
def wrapper(state: AgentState) -> AgentState:
try:
return node_func(state)
except Exception as e:
# 记录错误
logger.error(f"Node {node_func.__name__} failed: {e}")
# 添加错误信息到状态
state.metadata["errors"] = state.metadata.get("errors", [])
state.metadata["errors"].append({
"node": node_func.__name__,
"error": str(e),
"timestamp": datetime.now().isoformat()
})
return state
return wrapper
2. 图级错误恢复
class GraphRecovery:
"""图执行恢复机制"""
def execute_with_recovery(self, graph, initial_state):
"""带恢复机制的图执行"""
try:
return graph.invoke(initial_state)
except Exception as e:
# 尝试从检查点恢复
if checkpoint := self._find_last_checkpoint(initial_state):
logger.info("Recovering from checkpoint")
return self._recover_from_checkpoint(graph, checkpoint)
# 降级执行
logger.warning("Falling back to degraded execution")
return self._degraded_execution(initial_state)
这种图结构设计确保了智能体工作流的清晰性、可维护性和容错性,同时提供了灵活的扩展机制。