The Context State Passing Mechanism in DAG Task Flows

In a DAG-based task orchestration system, passing context state between nodes is what allows them to work together. The following describes the design of that state-passing mechanism in detail:

Core Concepts of Context Passing

State passing in a DAG mainly solves three problems:

  1. How to pass upstream node outputs to downstream nodes
  2. How to maintain the global context and task state
  3. How to synchronize state during parallel execution

1. Data-Flow Definition Mechanism

Explicit Data-Dependency Definition

# Example DAG definition
dag = DAG(name="power_report_generation")

# Node definitions
text_analysis = Node(
    id="text_analysis",
    model="text_understanding_model",
    outputs=["key_points", "entities", "sentiment"]
)

data_processing = Node(
    id="data_processing",
    model="power_data_processor",
    outputs=["charts", "metrics", "trends"]
)

report_writing = Node(
    id="report_writing",
    model="report_generator",
    inputs={
        "key_points": "text_analysis.key_points",
        "entities": "text_analysis.entities",
        "charts": "data_processing.charts",
        "metrics": "data_processing.metrics"
    },
    outputs=["report_sections", "executive_summary"]
)

# Add nodes to the DAG
dag.add_node(text_analysis)
dag.add_node(data_processing)
dag.add_node(report_writing)

# Define edges (optional: they are already implied by the inputs declarations)
dag.add_edge(text_analysis, report_writing)
dag.add_edge(data_processing, report_writing)

In this example, report_writing declares each input as a reference of the form "node_id.output_name", so the engine can resolve its inputs from upstream outputs at runtime; the explicit add_edge calls are optional because the same dependencies are already implied by those references, as the sketch below shows.
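
Because each input reference names its source node, the edge set can also be derived mechanically from the inputs declarations; a minimal sketch:

def infer_edges(nodes):
    """Derive (upstream, downstream) edges from each node's input references."""
    by_id = {node.id: node for node in nodes}
    edges = []
    for node in nodes:
        for input_ref in getattr(node, "inputs", {}).values():
            if "." in input_ref:
                upstream_id = input_ref.split(".", 1)[0]
                if upstream_id in by_id:
                    edges.append((by_id[upstream_id], node))
    return edges

# infer_edges([text_analysis, data_processing, report_writing]) yields the
# same two edges that are added explicitly above.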

2. Context Management System

Layered Context Structure

- Global context (GlobalContext)
  ├── Task-level context (TaskContext)
  │   ├── Node context 1 (NodeContext)
  │   ├── Node context 2 (NodeContext)
  │   └── ...
  └── System-level context (SystemContext)
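
One way to realize this hierarchy is as nested containers whose lookups fall back from node scope through task scope to global scope. A minimal illustrative sketch (the class names mirror the diagram; the WorkflowContext implementation below flattens the same idea into a single object):

from dataclasses import dataclass, field

@dataclass
class NodeContext:
    vars: dict = field(default_factory=dict)

@dataclass
class TaskContext:
    nodes: dict = field(default_factory=dict)   # node_id -> NodeContext
    vars: dict = field(default_factory=dict)

@dataclass
class GlobalContext:
    task: TaskContext = field(default_factory=TaskContext)
    system: dict = field(default_factory=dict)  # SystemContext: env/config values
    vars: dict = field(default_factory=dict)

    def resolve(self, node_id, key):
        """Look key up from node scope outward to global scope."""
        node = self.task.nodes.get(node_id)
        if node and key in node.vars:
            return node.vars[key]
        if key in self.task.vars:
            return self.task.vars[key]
        return self.vars.get(key)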

Implementation

import time

class WorkflowContext:
    def __init__(self, workflow_id):
        self.workflow_id = workflow_id
        self.global_vars = {}  # global variables
        self.node_outputs = {}  # node output results
        self.execution_state = {}  # execution state per node
        self.metadata = {}  # metadata
        
    def set_node_output(self, node_id, output_name, value):
        """Store a specific output of a node."""
        if node_id not in self.node_outputs:
            self.node_outputs[node_id] = {}
        self.node_outputs[node_id][output_name] = value
        
    def get_node_output(self, node_id, output_name):
        """Fetch a specific output of a node."""
        return self.node_outputs.get(node_id, {}).get(output_name)
        
    def get_input_value(self, input_ref):
        """Resolve an input value from a reference path,
        e.g. "text_analysis.key_points"
        """
        if "." not in input_ref:
            return self.global_vars.get(input_ref)
            
        node_id, output_name = input_ref.split(".", 1)
        return self.get_node_output(node_id, output_name)
        
    def prepare_node_inputs(self, node):
        """Assemble all inputs a node needs."""
        inputs = {}
        for input_name, input_ref in node.inputs.items():
            inputs[input_name] = self.get_input_value(input_ref)
        return inputs
        
    def set_global(self, key, value):
        """Set a global variable."""
        self.global_vars[key] = value
        
    def get_global(self, key, default=None):
        """Read a global variable (used by the conditional pattern below)."""
        return self.global_vars.get(key, default)
        
    def update_execution_state(self, node_id, state):
        """Update a node's execution state."""
        self.execution_state[node_id] = {
            "state": state,
            "updated_at": time.time()
        }
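
For example, once an upstream node's output has been recorded, downstream references resolve automatically:

context = WorkflowContext(workflow_id="demo")
context.set_node_output("text_analysis", "key_points", ["load peaked at 19:00"])
context.set_global("report_type", "monthly")

# Resolves the "node_id.output_name" reference style from the DAG definition
assert context.get_input_value("text_analysis.key_points") == ["load peaked at 19:00"]
# A bare name falls back to the global variables
assert context.get_input_value("report_type") == "monthly"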

3. State Passing in the Execution Engine

class DAGExecutionEngine:
    def __init__(self, model_registry, storage_service):
        self.model_registry = model_registry
        self.storage = storage_service
        
    def execute_node(self, workflow_context, node):
        # 1. Prepare the node's inputs
        inputs = workflow_context.prepare_node_inputs(node)
        
        # 2. Look up the model
        model = self.model_registry.get_model(node.model)
        
        # 3. Execute the model
        workflow_context.update_execution_state(node.id, "RUNNING")
        try:
            outputs = model.execute(inputs)
            
            # 4. Store the outputs
            for output_name in node.outputs:
                if output_name in outputs:
                    workflow_context.set_node_output(
                        node.id, output_name, outputs[output_name]
                    )
                else:
                    raise ValueError(f"Model did not produce expected output: {output_name}")
                    
            workflow_context.update_execution_state(node.id, "COMPLETED")
            
            # 5. Offload large outputs to external storage
            self._handle_large_outputs(workflow_context, node, outputs)
            
            return True
            
        except Exception as e:
            workflow_context.update_execution_state(node.id, "FAILED")
            workflow_context.set_node_output(node.id, "error", str(e))
            return False
            
    def _handle_large_outputs(self, workflow_context, node, outputs):
        """Offload large outputs, e.g. generated images or documents."""
        for output_name in node.outputs:
            if output_name in outputs and self._is_large_output(outputs[output_name]):
                # Persist to the external storage service
                storage_path = f"workflow/{workflow_context.workflow_id}/node/{node.id}/{output_name}"
                reference = self.storage.store(storage_path, outputs[output_name])
                
                # Replace the in-context value with a reference
                workflow_context.set_node_output(node.id, output_name, {
                    "type": "reference",
                    "uri": reference
                })
                
    def _is_large_output(self, value, threshold_bytes=1_000_000):
        """Heuristic size check (the 1 MB threshold is illustrative)."""
        if isinstance(value, (bytes, str)):
            return len(value) > threshold_bytes
        return False
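
A quick way to exercise the engine is with stub objects; the stubs below are illustrative, and only the get_model/execute/store call shapes come from the engine code above:

from types import SimpleNamespace

class EchoModel:
    def execute(self, inputs):
        return {"summary": f"summary of {sorted(inputs)}"}

class DictRegistry:
    def __init__(self, models):
        self.models = models
    def get_model(self, name):
        return self.models[name]

class NullStorage:
    def store(self, path, value):
        return f"stub://{path}"

engine = DAGExecutionEngine(DictRegistry({"echo": EchoModel()}), NullStorage())
node = SimpleNamespace(id="demo_node", model="echo", inputs={}, outputs=["summary"])
context = WorkflowContext(workflow_id="demo")

assert engine.execute_node(context, node) is True
assert context.get_node_output("demo_node", "summary") == "summary of []"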

4. Complex Data Types and Structured Context

For power-industry-specific needs, a structured context schema can be defined:

from typing import Dict, List, Optional, TypedDict

class PowerAnalysisContext(TypedDict):
    """Context data structure for power analysis."""
    power_station_id: str
    time_period: Dict[str, str]  # start_date, end_date
    metrics: List[str]
    weather_conditions: Optional[Dict]
    generation_data: Optional[Dict]
    consumption_data: Optional[Dict]
    anomalies: Optional[List[Dict]]
    recommendations: Optional[List[Dict]]

Usage:

# Type checking and validation
def validate_power_context(context: PowerAnalysisContext) -> bool:
    required_fields = ['power_station_id', 'time_period', 'metrics']
    return all(field in context and context[field] for field in required_fields)

# A node consuming the typed context
class PowerAnalysisNode(Node):
    def process(self, context: WorkflowContext):
        # Fetch the power-analysis data from the context
        # (get_typed/set_typed are assumed typed accessors over the context store)
        power_data: PowerAnalysisContext = context.get_typed('power_analysis')
        
        # Validate the data
        if not validate_power_context(power_data):
            raise ValueError("Invalid power analysis context")
            
        # Process and update the context
        analysis_results = self.model.analyze(power_data)
        
        # Update the relevant fields in the context
        updated_power_data = {**power_data, **analysis_results}
        context.set_typed('power_analysis', updated_power_data)
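
For instance, a well-formed context passes validation while one missing a required field does not (the values below are illustrative):

ctx: PowerAnalysisContext = {
    "power_station_id": "POWER-STATION-123",
    "time_period": {"start_date": "2025-03-01", "end_date": "2025-03-31"},
    "metrics": ["generation", "efficiency"],
    "weather_conditions": None,
    "generation_data": None,
    "consumption_data": None,
    "anomalies": None,
    "recommendations": None,
}
assert validate_power_context(ctx)
assert not validate_power_context({"power_station_id": "POWER-STATION-123"})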

5. Context-Passing Patterns

Direct-Dependency Pattern

The most basic form of context passing:

A → B → C

Node B depends directly on node A's output, and node C depends directly on node B's output.
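
In the Node API used above, the chain looks like this (the model names and the single result output are illustrative):

node_a = Node(id="A", model="model_a", outputs=["result"])
node_b = Node(id="B", model="model_b",
              inputs={"data": "A.result"}, outputs=["result"])
node_c = Node(id="C", model="model_c",
              inputs={"data": "B.result"}, outputs=["result"])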

Selective-Dependency Pattern

A downstream node can selectively consume only part of an upstream node's output:

report_generator = Node(
    id="report_generator",
    inputs={
        # Use only the chart output of the data-processing node; ignore the rest
        "charts": "data_processing.charts"
    }
)

Merged-Dependency Pattern

A node can depend on the outputs of several upstream nodes at once:

    A
   / \
  B   C
   \ /
    D

Node D merges the outputs of B and C:

summary_node = Node(
    id="summary_generation",
    inputs={
        "text_insights": "text_analysis.insights",
        "data_insights": "data_analysis.insights",
        "recommended_actions": "recommendation_engine.actions"
    }
)

Conditional-Dependency Pattern

Which upstream outputs to use can be decided dynamically based on conditions:

class ConditionalNode(Node):
    def prepare_inputs(self, context):
        # Choose the input source dynamically from the context
        if context.get_global("report_type") == "executive":
            return {
                "content": context.get_node_output("executive_summary", "content")
            }
        else:
            return {
                "content": context.get_node_output("technical_analysis", "content")
            }

6. State Synchronization in Parallel Execution

In parallel execution mode, state synchronization is especially important:

import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

class ParallelExecutionManager:
    def __init__(self, model_registry, max_workers=10):
        self.model_registry = model_registry
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.context_lock = threading.RLock()
        
    def execute_parallel_nodes(self, workflow_context, nodes):
        futures = {}
        results = {}
        
        # Submit every node that is eligible for parallel execution
        for node in nodes:
            future = self.executor.submit(
                self._execute_node_safe, workflow_context, node
            )
            futures[future] = node.id
            
        # Wait for all executions to finish
        for future in as_completed(futures):
            node_id = futures[future]
            try:
                result = future.result()
                results[node_id] = result
            except Exception as e:
                results[node_id] = {"status": "error", "error": str(e)}
                
        return results
        
    def _execute_node_safe(self, workflow_context, node):
        # Prepare the node's inputs (read the context under the lock)
        with self.context_lock:
            inputs = workflow_context.prepare_node_inputs(node)
            
        # Execute the model (no lock needed)
        model = self.model_registry.get_model(node.model)
        outputs = model.execute(inputs)
        
        # Store the outputs (write the context under the lock)
        with self.context_lock:
            for output_name in node.outputs:
                if output_name in outputs:
                    workflow_context.set_node_output(
                        node.id, output_name, outputs[output_name]
                    )
                    
        return {"status": "success", "outputs": list(outputs.keys())}

7. A Worked Example: A Power-Report-Generation DAG

Below is a complete DAG for generating a power-station report, showing how context flows between the nodes:

# 1. Define the DAG
power_report_dag = DAG(name="power_station_monthly_report")

# 2. Define the nodes
data_collection = Node(
    id="data_collection",
    model="power_data_collector",
    params={
        "time_range": "last_month",
        "metrics": ["generation", "consumption", "efficiency", "outages"]
    },
    outputs=["raw_data", "station_metadata"]
)

data_processing = Node(
    id="data_processing",
    model="data_processor",
    inputs={
        "raw_data": "data_collection.raw_data"
    },
    outputs=["processed_data", "anomalies", "trends"]
)

chart_generation = Node(
    id="chart_generation",
    model="chart_generator",
    inputs={
        "data": "data_processing.processed_data",
        "trends": "data_processing.trends"
    },
    outputs=["charts"]
)

policy_analysis = Node(
    id="policy_analysis",
    model="policy_analyzer",
    inputs={
        "station_data": "data_collection.station_metadata",
        "performance_data": "data_processing.processed_data"
    },
    outputs=["policy_insights", "compliance_status"]
)

report_writing = Node(
    id="report_writing",
    model="report_writer",
    inputs={
        "processed_data": "data_processing.processed_data",
        "anomalies": "data_processing.anomalies",
        "charts": "chart_generation.charts",
        "policy_insights": "policy_analysis.policy_insights",
        "compliance_status": "policy_analysis.compliance_status"
    },
    outputs=["report_content", "executive_summary"]
)

ppt_generation = Node(
    id="ppt_generation",
    model="ppt_generator",
    inputs={
        "content": "report_writing.report_content",
        "summary": "report_writing.executive_summary",
        "charts": "chart_generation.charts"
    },
    outputs=["presentation"]
)

# 3. Build the DAG
power_report_dag.add_nodes([
    data_collection,
    data_processing,
    chart_generation,
    policy_analysis,
    report_writing,
    ppt_generation
])

# 4. Execute the DAG (execute_dag is assumed to visit the nodes in topological
# order, calling execute_node from section 3 on each)
executor = DAGExecutionEngine(model_registry, storage_service)
workflow_context = WorkflowContext(workflow_id="monthly_report_2025_03")

# Set global parameters
workflow_context.set_global("station_id", "POWER-STATION-123")
workflow_context.set_global("report_date", "2025-03-01")
workflow_context.set_global("report_type", "monthly")

# Run the DAG
result = executor.execute_dag(power_report_dag, workflow_context)

# 5. Retrieve the final result
presentation = workflow_context.get_node_output("ppt_generation", "presentation")
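
Note that if the presentation was large, _handle_large_outputs (section 3) will have replaced it in the context with a {"type": "reference", "uri": ...} stub, so the consumer may need to dereference it. A small sketch, assuming the storage service exposes a load counterpart to store:

def resolve_output(value, storage):
    """Fetch the real payload if the context holds only a storage reference."""
    if isinstance(value, dict) and value.get("type") == "reference":
        # storage.load is assumed as the read-side counterpart of storage.store
        return storage.load(value["uri"])
    return value

presentation = resolve_output(presentation, storage_service)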

8. Advanced Context-Passing Features

Context Inheritance and Overrides

class ContextInheritanceManager:
    def create_child_context(self, parent_context, override_values=None):
        """Create a child context that inherits from a parent context."""
        child_context = WorkflowContext(f"{parent_context.workflow_id}:child")
        
        # Inherit the global variables
        child_context.global_vars = dict(parent_context.global_vars)
        
        # Apply the override values
        if override_values:
            for key, value in override_values.items():
                child_context.set_global(key, value)
                
        return child_context
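
For example, a sub-workflow can reuse the parent's globals while narrowing the reporting scope:

parent = WorkflowContext(workflow_id="monthly_report_2025_03")
parent.set_global("station_id", "POWER-STATION-123")
parent.set_global("report_type", "monthly")

child = ContextInheritanceManager().create_child_context(
    parent, override_values={"report_type": "weekly"}
)
assert child.global_vars["station_id"] == "POWER-STATION-123"  # inherited
assert child.global_vars["report_type"] == "weekly"            # overridden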

Context Snapshots and Rollback

import copy
import time
import uuid

class ContextVersionManager:
    def __init__(self):
        self.snapshots = {}
        
    def create_snapshot(self, context, snapshot_id=None):
        """Create a snapshot of the context."""
        if snapshot_id is None:
            snapshot_id = f"{context.workflow_id}:snapshot:{uuid.uuid4()}"
            
        # Deep-copy the context state
        self.snapshots[snapshot_id] = {
            "global_vars": copy.deepcopy(context.global_vars),
            "node_outputs": copy.deepcopy(context.node_outputs),
            "timestamp": time.time()
        }
        
        return snapshot_id
        
    def restore_snapshot(self, context, snapshot_id):
        """Restore the context from a snapshot."""
        if snapshot_id not in self.snapshots:
            raise ValueError(f"Snapshot not found: {snapshot_id}")
            
        snapshot = self.snapshots[snapshot_id]
        
        # Restore the context state
        context.global_vars = copy.deepcopy(snapshot["global_vars"])
        context.node_outputs = copy.deepcopy(snapshot["node_outputs"])
        
        return context
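
A typical use is to snapshot before a risky node and roll back on failure so that a retry starts from clean state (reusing executor, workflow_context, and report_writing from section 7):

versions = ContextVersionManager()
snapshot_id = versions.create_snapshot(workflow_context)

if not executor.execute_node(workflow_context, report_writing):
    # Drop any partial outputs the failed node wrote, then retry once
    versions.restore_snapshot(workflow_context, snapshot_id)
    executor.execute_node(workflow_context, report_writing)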

Context Propagation Policies

class ContextPropagationPolicy:
    def __init__(self):
        self.propagation_rules = {}
        
    def add_rule(self, source_node, target_node, field_mapping):
        """Register a context-propagation rule."""
        key = (source_node, target_node)
        self.propagation_rules[key] = field_mapping
        
    def apply_rules(self, workflow_context, completed_node_id):
        """Apply the propagation rules for a node that just completed."""
        for (source, target), mapping in self.propagation_rules.items():
            if source == completed_node_id:
                # Apply the mapping rules
                for source_field, target_field in mapping.items():
                    value = workflow_context.get_node_output(source, source_field)
                    
                    # A mapping entry may carry a custom transform
                    if isinstance(target_field, dict) and "transform" in target_field:
                        transform_func = target_field["transform"]
                        value = transform_func(value)
                        target_field = target_field["field"]
                        
                    # Stage the value as an input of the target node
                    # (set_node_input is an assumed extension; see the sketch below)
                    workflow_context.set_node_input(target, target_field, value)
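
Note that apply_rules calls workflow_context.set_node_input, which the WorkflowContext in section 2 does not define. A minimal sketch of that assumed extension, staging values per target node so that prepare_node_inputs can apply them:

class WorkflowContextWithStagedInputs(WorkflowContext):
    """Extension sketch: staged inputs written by propagation rules."""
    def __init__(self, workflow_id):
        super().__init__(workflow_id)
        self.staged_inputs = {}  # node_id -> {input_name: value}

    def set_node_input(self, node_id, input_name, value):
        self.staged_inputs.setdefault(node_id, {})[input_name] = value

    def prepare_node_inputs(self, node):
        inputs = super().prepare_node_inputs(node)
        # Staged values override reference-resolved ones
        inputs.update(self.staged_inputs.get(node.id, {}))
        return inputs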

Taken together, these mechanisms can meet the complex context-passing needs of an agent system for the power industry, letting the individual models cooperate while keeping task state consistent and reliable.