在基于DAG的任务编排系统中,上下文状态传递是确保各节点协同工作的关键。以下是详细的状态传递机制设计:
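DAG中的状态传递主要解决三个问题:下游节点如何获取上游节点的输出(数据依赖)、全局参数与执行状态如何在节点间共享(上下文管理),以及并行执行时如何保证上下文的一致性(状态同步)。下面先看一个通过 `inputs` 声明数据依赖的示例: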
# Example DAG: two independent analysis nodes feed a single report-writing node.
dag = DAG(name="power_report_generation")
# Node definitions. `outputs` lists the names a node publishes; `inputs` maps a
# local input name to an "<upstream_node_id>.<output_name>" reference.
text_analysis = Node(
    id="text_analysis",
    model="text_understanding_model",
    outputs=["key_points", "entities", "sentiment"]
)
data_processing = Node(
    id="data_processing",
    model="power_data_processor",
    outputs=["charts", "metrics", "trends"]
)
# Consumes a subset of both upstream nodes' outputs ("sentiment" and "trends"
# are not used here).
report_writing = Node(
    id="report_writing",
    model="report_generator",
    inputs={
        "key_points": "text_analysis.key_points",
        "entities": "text_analysis.entities",
        "charts": "data_processing.charts",
        "metrics": "data_processing.metrics"
    },
    outputs=["report_sections", "executive_summary"]
)
# Register the nodes with the DAG.
dag.add_node(text_analysis)
dag.add_node(data_processing)
dag.add_node(report_writing)
# Explicit edges. The original note calls these optional because the `inputs`
# references already imply them; whether DAG really infers edges from `inputs`
# is not visible here (TODO confirm).
dag.add_edge(text_analysis, report_writing)
dag.add_edge(data_processing, report_writing)
在这个例子中:
- `report_writing` 节点的 `inputs` 字典明确声明了它需要的上游节点输出,DAG 的依赖边因此可以由 `inputs` 隐式推导,显式调用 `add_edge` 是可选的。

上下文按以下层级组织:
全局上下文 (GlobalContext)
├── 任务级上下文 (TaskContext)
│ ├── 节点上下文1 (NodeContext)
│ ├── 节点上下文2 (NodeContext)
│ └── ...
└── 系统级上下文 (SystemContext)
class WorkflowContext:
    """Shared state for one DAG workflow run.

    Holds workflow-wide variables, the outputs produced by each node, inputs
    pushed to specific nodes (e.g. by a propagation policy), per-node
    execution state and free-form metadata.
    """

    def __init__(self, workflow_id):
        self.workflow_id = workflow_id
        self.global_vars = {}  # workflow-wide variables, addressed by bare name
        self.node_outputs = {}  # {node_id: {output_name: value}}
        self.node_inputs = {}  # {node_id: {input_name: value}} pushed via set_node_input
        self.execution_state = {}  # {node_id: {"state": ..., "updated_at": epoch seconds}}
        self.metadata = {}  # free-form metadata

    def set_node_output(self, node_id, output_name, value):
        """Store *value* as output *output_name* of node *node_id*."""
        self.node_outputs.setdefault(node_id, {})[output_name] = value

    def get_node_output(self, node_id, output_name):
        """Return a node's output, or ``None`` if the node or the output is unknown."""
        return self.node_outputs.get(node_id, {}).get(output_name)

    def get_input_value(self, input_ref):
        """Resolve an input reference.

        ``"node_id.output_name"`` (e.g. ``"text_analysis.key_points"``) resolves
        to that node's output. A reference without a dot resolves to a global
        variable. Only the first dot splits, so output names may contain dots.
        Returns ``None`` when nothing is found.
        """
        if "." not in input_ref:
            return self.global_vars.get(input_ref)
        node_id, output_name = input_ref.split(".", 1)
        return self.get_node_output(node_id, output_name)

    def prepare_node_inputs(self, node):
        """Build the input dict for *node*.

        Each entry of ``node.inputs`` ({input_name: reference}) is resolved with
        :meth:`get_input_value`. A node without an ``inputs`` attribute, or with
        ``inputs=None``, is a source node with no declared inputs. Values pushed
        for this node's ``id`` via :meth:`set_node_input` are overlaid last, so
        they win over a declared reference with the same input name.
        """
        declared = getattr(node, "inputs", None) or {}
        inputs = {name: self.get_input_value(ref) for name, ref in declared.items()}
        pushed = self.node_inputs.get(getattr(node, "id", None))
        if pushed:
            inputs.update(pushed)
        return inputs

    def set_global(self, key, value):
        """Set workflow-wide variable *key*."""
        self.global_vars[key] = value

    def get_global(self, key, default=None):
        """Return workflow-wide variable *key*, or *default* if it is not set."""
        return self.global_vars.get(key, default)

    def set_node_input(self, node_id, input_name, value):
        """Push *value* as input *input_name* of node *node_id*.

        Pushed values are merged into the result of :meth:`prepare_node_inputs`
        for that node.
        """
        self.node_inputs.setdefault(node_id, {})[input_name] = value

    def get_typed(self, key, default=None):
        """Return a structured record (e.g. a TypedDict) stored under *key*.

        Records live in ``global_vars``, so they are also resolvable as bare
        input references. No runtime type checking is performed; callers
        validate the structure themselves.
        """
        return self.global_vars.get(key, default)

    def set_typed(self, key, value):
        """Store a structured record (e.g. a TypedDict) under *key* in ``global_vars``."""
        self.global_vars[key] = value

    def update_execution_state(self, node_id, state):
        """Record *state* for *node_id* together with the current wall-clock time."""
        import time  # local import keeps this class self-contained

        self.execution_state[node_id] = {
            "state": state,
            "updated_at": time.time(),
        }
class DAGExecutionEngine:
    """Execute individual DAG nodes, offloading large outputs to external storage."""

    # Size above which an output is moved to storage: bytes for bytes-like
    # values, characters for str. Arbitrary default; tune per deployment.
    DEFAULT_LARGE_OUTPUT_THRESHOLD = 1024 * 1024

    def __init__(self, model_registry, storage_service, large_output_threshold=None):
        """
        Args:
            model_registry: object exposing ``get_model(name)``.
            storage_service: object exposing ``store(path, value)`` that returns a URI.
            large_output_threshold: size limit for inline outputs; ``None``
                selects DEFAULT_LARGE_OUTPUT_THRESHOLD.
        """
        self.model_registry = model_registry
        self.storage = storage_service
        self.large_output_threshold = (
            self.DEFAULT_LARGE_OUTPUT_THRESHOLD
            if large_output_threshold is None
            else large_output_threshold
        )

    def execute_node(self, workflow_context, node):
        """Run *node*'s model and record its declared outputs in *workflow_context*.

        Returns True on success. Any exception raised while executing the model
        or storing its outputs has the following effects:
        - the node is marked FAILED;
        - the message is recorded as the node's "error" output;
        - the method returns False.

        Input preparation and model lookup run before the try block, so their
        exceptions propagate to the caller.
        """
        inputs = workflow_context.prepare_node_inputs(node)
        model = self.model_registry.get_model(node.model)
        workflow_context.update_execution_state(node.id, "RUNNING")
        try:
            outputs = model.execute(inputs)
            # Check every declared output first, so a failing node never leaves
            # a partial set of outputs in the context.
            for output_name in node.outputs:
                if output_name not in outputs:
                    raise ValueError(f"Model did not produce expected output: {output_name}")
            # Offload large values *before* writing and before reporting
            # COMPLETED. A storage error then yields FAILED without ever
            # claiming success, and the context only ever holds the reference.
            final_values = {}
            for output_name in node.outputs:
                value = outputs[output_name]
                if self._is_large_output(value):
                    value = self._offload(workflow_context, node, output_name, value)
                final_values[output_name] = value
            for output_name, value in final_values.items():
                workflow_context.set_node_output(node.id, output_name, value)
            workflow_context.update_execution_state(node.id, "COMPLETED")
            return True
        except Exception as e:  # execution boundary: failure is reported via state + return value
            workflow_context.update_execution_state(node.id, "FAILED")
            workflow_context.set_node_output(node.id, "error", str(e))
            return False

    def _is_large_output(self, value):
        """Return True when *value* exceeds ``large_output_threshold``.

        Only bytes-like values (sized in bytes) and str (sized in characters)
        are measured. Every other type is treated as small and kept inline.
        """
        if isinstance(value, memoryview):
            size = value.nbytes
        elif isinstance(value, (bytes, bytearray, str)):
            size = len(value)
        else:
            return False
        return size > self.large_output_threshold

    def _offload(self, workflow_context, node, output_name, value):
        """Store *value* externally and return the reference dict kept in the context."""
        storage_path = f"workflow/{workflow_context.workflow_id}/node/{node.id}/{output_name}"
        uri = self.storage.store(storage_path, value)
        return {"type": "reference", "uri": uri}

    def _handle_large_outputs(self, workflow_context, node, outputs):
        """Offload large declared outputs and replace them in the context with references.

        Kept for compatibility. execute_node itself offloads before writing to
        the context.
        """
        for output_name in node.outputs:
            if output_name in outputs and self._is_large_output(outputs[output_name]):
                workflow_context.set_node_output(
                    node.id,
                    output_name,
                    self._offload(workflow_context, node, output_name, outputs[output_name]),
                )
针对电力行业的特定需求,可以基于 `TypedDict` 定义结构化的上下文模式:
class _PowerAnalysisRequired(TypedDict):
    """Keys every power-analysis context must carry (the fields validate_power_context requires)."""
    power_station_id: str
    time_period: Dict[str, str]  # expected keys: start_date, end_date
    metrics: List[str]


class PowerAnalysisContext(_PowerAnalysisRequired, total=False):
    """Structured context for power-station analysis.

    The three identifying fields are required. The remaining fields may be
    absent (declared with ``total=False``) because they are filled in later,
    e.g. by merging analysis results into the record. ``Optional`` on them
    additionally allows an explicit ``None``.
    """
    weather_conditions: Optional[Dict]
    generation_data: Optional[Dict]
    consumption_data: Optional[Dict]
    anomalies: Optional[List[Dict]]
    recommendations: Optional[List[Dict]]
使用时:
# Type checking and validation
def validate_power_context(context: "Optional[PowerAnalysisContext]") -> bool:
    """Return True if *context* is a mapping whose required fields are present and truthy.

    Required fields: power_station_id, time_period, metrics. An empty value
    (``""``, ``{}``, ``[]``) counts as missing. ``None`` or any non-mapping
    value yields False instead of raising, since callers may pass the result
    of a lookup that found nothing.
    """
    from collections.abc import Mapping  # local import: only needed here

    if not isinstance(context, Mapping):
        return False
    required_fields = ('power_station_id', 'time_period', 'metrics')
    return all(field in context and context[field] for field in required_fields)
# A node that works on a typed context record
class PowerAnalysisNode(Node):
    def process(self, context: WorkflowContext):
        """Validate the 'power_analysis' record, analyze it, and write back the merged record.

        Keys returned by the model overwrite same-named keys in the stored record.

        Raises:
            ValueError: if the stored record fails validate_power_context.
        """
        # NOTE(review): relies on the context exposing get_typed/set_typed —
        # confirm the WorkflowContext in use provides them.
        current: PowerAnalysisContext = context.get_typed('power_analysis')
        if not validate_power_context(current):
            raise ValueError("Invalid power analysis context")
        context.set_typed('power_analysis', {**current, **self.model.analyze(current)})
顺序传递:最基本的上下文传递方式,输出沿依赖链逐级向下游传递:
A → B → C
节点B直接依赖节点A的输出,节点C直接依赖节点B的输出。
选择性传递:下游节点可以只引用上游节点的部分输出:
# Selective consumption: only the "charts" output of data_processing is used;
# its other outputs are ignored.
report_generator = Node(
    id="report_generator",
    inputs=dict(charts="data_processing.charts"),
)
多源合并:一个节点可以同时依赖多个上游节点的输出:
A
/ \
B C
\ /
D
节点D合并来自多个上游分支(如图中的B和C)的输出。下例中的 `summary_generation` 节点同时引用了三个上游节点的输出:
# Fan-in: one node merges outputs from several upstream nodes.
summary_node = Node(
    id="summary_generation",
    inputs=dict(
        text_insights="text_analysis.insights",
        data_insights="data_analysis.insights",
        recommended_actions="recommendation_engine.actions",
    ),
)
条件传递:基于运行时条件动态决定使用哪些上游输出:
class ConditionalNode(Node):
    """Node whose "content" input source is chosen at run time from the report type."""

    def prepare_inputs(self, context):
        """Return {"content": ...} taken from executive_summary for executive reports.

        Any other report_type, including an unset one, selects technical_analysis.
        NOTE(review): relies on the context exposing get_global — confirm the
        WorkflowContext in use provides it.
        """
        is_executive = context.get_global("report_type") == "executive"
        source_node = "executive_summary" if is_executive else "technical_analysis"
        return {"content": context.get_node_output(source_node, "content")}
在并行执行模式下,状态同步尤为重要:
class ParallelExecutionManager:
    """Run independent DAG nodes concurrently against a shared workflow context.

    Context reads and writes are serialized with one re-entrant lock. Model
    execution runs outside the lock so nodes overlap. Use as a context manager
    (or call :meth:`shutdown`) to release the worker threads.
    """

    def __init__(self, max_workers=10, model_registry=None):
        """
        Args:
            max_workers: size of the thread pool.
            model_registry: object exposing ``get_model(name)``. It is required
                before any node can run; without it, each node result is an error.
        """
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.context_lock = threading.RLock()
        self.model_registry = model_registry

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.shutdown()
        return False

    def shutdown(self, wait=True):
        """Stop accepting work and release the pool's threads."""
        self.executor.shutdown(wait=wait)

    def execute_parallel_nodes(self, workflow_context, nodes):
        """Execute *nodes* concurrently and return ``{node_id: result}``.

        A success result is ``{"status": "success", "outputs": [...]}``.
        An exception raised by a node becomes
        ``{"status": "error", "error": message}``.
        """
        futures = {
            self.executor.submit(self._execute_node_safe, workflow_context, node): node.id
            for node in nodes
        }
        results = {}
        for future in as_completed(futures):
            node_id = futures[future]
            try:
                results[node_id] = future.result()
            except Exception as e:  # per-node boundary: record and keep collecting
                results[node_id] = {"status": "error", "error": str(e)}
        return results

    def _execute_node_safe(self, workflow_context, node):
        """Run one node.

        Inputs are read under the lock and the model runs unlocked. All
        declared outputs are checked before any is written, under the lock.
        """
        if self.model_registry is None:
            raise RuntimeError("ParallelExecutionManager has no model_registry configured")
        with self.context_lock:
            inputs = workflow_context.prepare_node_inputs(node)
        model = self.model_registry.get_model(node.model)
        outputs = model.execute(inputs)
        # Same contract as the single-node engine: a missing declared output is
        # an error, and nothing is written for a failing node.
        for output_name in node.outputs:
            if output_name not in outputs:
                raise ValueError(f"Model did not produce expected output: {output_name}")
        with self.context_lock:
            for output_name in node.outputs:
                workflow_context.set_node_output(node.id, output_name, outputs[output_name])
        return {"status": "success", "outputs": list(outputs.keys())}
下面是一个完整的电力报告生成DAG示例,展示上下文如何在各节点间传递:
# 1. Define the DAG for a power station's monthly report.
power_report_dag = DAG(name="power_station_monthly_report")
# 2. Define the nodes. Each `inputs` entry maps a local input name to an
#    "<upstream_node_id>.<output_name>" reference.
# Source node: declares no `inputs`; it is configured through static `params`
# (how Node/model interpret `params` is not shown here).
data_collection = Node(
    id="data_collection",
    model="power_data_collector",
    params={
        "time_range": "last_month",
        "metrics": ["generation", "consumption", "efficiency", "outages"]
    },
    outputs=["raw_data", "station_metadata"]
)
data_processing = Node(
    id="data_processing",
    model="data_processor",
    inputs={
        "raw_data": "data_collection.raw_data"
    },
    outputs=["processed_data", "anomalies", "trends"]
)
chart_generation = Node(
    id="chart_generation",
    model="chart_generator",
    inputs={
        "data": "data_processing.processed_data",
        "trends": "data_processing.trends"
    },
    outputs=["charts"]
)
# Depends on two upstream nodes: station metadata plus processed performance data.
policy_analysis = Node(
    id="policy_analysis",
    model="policy_analyzer",
    inputs={
        "station_data": "data_collection.station_metadata",
        "performance_data": "data_processing.processed_data"
    },
    outputs=["policy_insights", "compliance_status"]
)
# Fan-in node combining data, chart and policy outputs.
report_writing = Node(
    id="report_writing",
    model="report_writer",
    inputs={
        "processed_data": "data_processing.processed_data",
        "anomalies": "data_processing.anomalies",
        "charts": "chart_generation.charts",
        "policy_insights": "policy_analysis.policy_insights",
        "compliance_status": "policy_analysis.compliance_status"
    },
    outputs=["report_content", "executive_summary"]
)
# Final node; reuses chart_generation.charts directly as well as the report text.
ppt_generation = Node(
    id="ppt_generation",
    model="ppt_generator",
    inputs={
        "content": "report_writing.report_content",
        "summary": "report_writing.executive_summary",
        "charts": "chart_generation.charts"
    },
    outputs=["presentation"]
)
# 3. Build the DAG by registering all nodes in one call.
power_report_dag.add_nodes([
    data_collection,
    data_processing,
    chart_generation,
    policy_analysis,
    report_writing,
    ppt_generation
])
# 4. Execute the DAG.
# NOTE(review): `DAGExecutor` and `model_registry` are not defined in this
# document. The engine class it does define is
# `DAGExecutionEngine(model_registry, storage_service)`, which has no
# `execute_dag` method — confirm which executor API this example targets.
executor = DAGExecutor(model_registry)
workflow_context = WorkflowContext(workflow_id="monthly_report_2025_03")
# Workflow-wide parameters. WorkflowContext resolves a bare name (no ".") in an
# input reference to a global; NOTE(review): none of the nodes above reference
# these globals in `inputs` — presumably consumed elsewhere, confirm.
workflow_context.set_global("station_id", "POWER-STATION-123")
workflow_context.set_global("report_date", "2025-03-01")
workflow_context.set_global("report_type", "monthly")
# Run the whole DAG against the prepared context.
result = executor.execute_dag(power_report_dag, workflow_context)
# 5. Read the final artifact from the context.
presentation = workflow_context.get_node_output("ppt_generation", "presentation")
class ContextInheritanceManager:
    """Create child workflow contexts that inherit their parent's global variables."""

    def create_child_context(self, parent_context, override_values=None, *, child_id="child"):
        """Return a new WorkflowContext whose globals start as a copy of the parent's.

        Only global variables are inherited. Node outputs and execution state
        start empty.

        Args:
            parent_context: context to inherit from.
            override_values: optional mapping of globals to set on the child
                after copying; these win over inherited values.
            child_id: suffix for the child's ``workflow_id``
                (``"<parent_id>:<child_id>"``). The default ``"child"`` gives
                every child of the same parent the same id. Pass a distinct
                value when creating sibling children, because workflow_id is
                used to key things such as external storage paths and
                snapshot ids.

        Returns:
            The new child context.
        """
        child_context = WorkflowContext(f"{parent_context.workflow_id}:{child_id}")
        # Shallow copy. The child gets its own dict, so setting a global on the
        # child does not affect the parent, but mutable values (nested dicts or
        # lists) are still shared with the parent.
        child_context.global_vars = dict(parent_context.global_vars)
        for key, value in (override_values or {}).items():
            child_context.set_global(key, value)
        return child_context
class ContextVersionManager:
    """Keep in-memory, deep-copied snapshots of workflow contexts for rollback."""

    def __init__(self):
        # {snapshot_id: {"global_vars", "node_outputs", ["execution_state"], "timestamp"}}
        self.snapshots = {}

    def create_snapshot(self, context, snapshot_id=None):
        """Deep-copy *context*'s state and return the snapshot id.

        Captures global_vars, node_outputs and, when the context has it,
        execution_state. Execution state must be captured too: otherwise a
        restore would roll outputs back while states still claim COMPLETED.
        An existing snapshot with the same id is overwritten.
        """
        if snapshot_id is None:
            snapshot_id = f"{context.workflow_id}:snapshot:{uuid.uuid4()}"
        snapshot = {
            "global_vars": copy.deepcopy(context.global_vars),
            "node_outputs": copy.deepcopy(context.node_outputs),
            "timestamp": time.time(),
        }
        if hasattr(context, "execution_state"):
            snapshot["execution_state"] = copy.deepcopy(context.execution_state)
        self.snapshots[snapshot_id] = snapshot
        return snapshot_id

    def restore_snapshot(self, context, snapshot_id):
        """Restore *context* in place from a snapshot and return it.

        The snapshot is copied again, so it can be restored repeatedly.

        Raises:
            ValueError: if *snapshot_id* is unknown.
        """
        if snapshot_id not in self.snapshots:
            raise ValueError(f"Snapshot not found: {snapshot_id}")
        snapshot = self.snapshots[snapshot_id]
        context.global_vars = copy.deepcopy(snapshot["global_vars"])
        context.node_outputs = copy.deepcopy(snapshot["node_outputs"])
        if "execution_state" in snapshot:
            context.execution_state = copy.deepcopy(snapshot["execution_state"])
        return context

    def delete_snapshot(self, snapshot_id):
        """Drop a snapshot so its deep copies can be freed.

        Returns True if the snapshot existed.
        """
        return self.snapshots.pop(snapshot_id, None) is not None
class ContextPropagationPolicy:
    """Push selected outputs of a completed node into the inputs of other nodes."""

    def __init__(self):
        # {(source_node_id, target_node_id): {source_output_name: target_spec}}
        self.propagation_rules = {}

    def add_rule(self, source_node, target_node, field_mapping):
        """Register (or replace) the mapping for the source -> target pair.

        Each value in *field_mapping* is either a target input name, or a dict
        ``{"field": target_input_name, "transform": callable}``. In the dict
        form, "transform" is optional.
        """
        self.propagation_rules[(source_node, target_node)] = field_mapping

    def apply_rules(self, workflow_context, completed_node_id):
        """Apply every rule whose source is *completed_node_id*.

        Each mapped source output is read (``None`` if absent) and passed
        through the optional transform. The result is then pushed into the
        target node's inputs.
        NOTE(review): relies on workflow_context exposing set_node_input —
        confirm the context in use provides it.
        """
        for (source, target), mapping in self.propagation_rules.items():
            if source != completed_node_id:
                continue
            for source_field, target_spec in mapping.items():
                target_field, transform = self._resolve_target(target_spec)
                value = workflow_context.get_node_output(source, source_field)
                if transform is not None:
                    value = transform(value)
                workflow_context.set_node_input(target, target_field, value)

    @staticmethod
    def _resolve_target(target_spec):
        """Split a target spec into (target_input_name, transform or None).

        Any dict spec is normalized here, so a dict without "transform" yields
        its "field" rather than the dict itself as the input name.
        """
        if isinstance(target_spec, dict):
            return target_spec["field"], target_spec.get("transform")
        return target_spec, None
这些机制结合起来,可以实现电力行业智能体系统中复杂的上下文传递需求,确保各个模型之间协同工作,并保持任务状态的一致性和可靠性。