Merge pull request #3 from hetaoBackend/feat/interrupt

feat: support interrupt protocol
This commit is contained in:
Li Xin 2025-04-16 10:57:00 +08:00 committed by GitHub
commit 3cd84e1ec7
4 changed files with 37 additions and 22 deletions

View File

@ -49,26 +49,20 @@ curl --location 'http://localhost:8000/api/chat/stream' \
"messages": [
{
"role": "user",
"content": "what is mcp?"
"content": "make the last step be comprehensive"
}
],
"thread_id": "test_thread_2",
"auto_accepted_plan": false,
"feedback": "[EDIT PLAN] make the last step be comprehensive"
"interrupt_feedback": "edit_plan"
}
```
### Accepted the plan
### Accept the plan
```
{
"messages": [
{
"role": "user",
"content": "what is mcp?"
}
],
"thread_id": "test_thread_2",
"auto_accepted_plan": false,
"feedback": "[ACCEPTED]"
"interrupt_feedback": "accepted"
}
```

View File

@ -75,10 +75,10 @@ def human_feedback_node(
# check if the plan is auto accepted
auto_accepted_plan = state.get("auto_accepted_plan", False)
if not auto_accepted_plan:
feedback = interrupt(current_plan)
feedback = interrupt("Please Review the Plan.")
# if the feedback is not accepted, return the planner node
if feedback and str(feedback).upper() != "[ACCEPTED]":
if feedback and str(feedback).upper().startswith("[EDIT_PLAN]"):
return Command(
update={
"messages": [
@ -87,7 +87,7 @@ def human_feedback_node(
},
goto="planner",
)
elif feedback and str(feedback).upper() == "[ACCEPTED]":
elif feedback and str(feedback).upper().startswith("[ACCEPTED]"):
logger.info("Plan is accepted by user.")
else:
raise TypeError(f"Interrupt value of {feedback} is not supported.")

View File

@ -44,7 +44,7 @@ async def chat_stream(request: ChatRequest):
request.max_plan_iterations,
request.max_step_num,
request.auto_accepted_plan,
request.feedback,
request.interrupt_feedback,
),
media_type="text/event-stream",
)
@ -56,11 +56,15 @@ async def _astream_workflow_generator(
max_plan_iterations: int,
max_step_num: int,
auto_accepted_plan: bool,
feedback: str,
interrupt_feedback: str,
):
input_ = {"messages": messages, "auto_accepted_plan": auto_accepted_plan}
if not auto_accepted_plan and feedback:
input_ = Command(resume=feedback)
if not auto_accepted_plan and interrupt_feedback:
resume_msg = f"[{interrupt_feedback}]"
# add the last message to the resume message
if messages:
resume_msg += f" {messages[-1]["content"]}"
input_ = Command(resume=resume_msg)
async for agent, _, event_data in graph.astream(
input_,
config={
@ -68,9 +72,26 @@ async def _astream_workflow_generator(
"max_plan_iterations": max_plan_iterations,
"max_step_num": max_step_num,
},
stream_mode=["messages"],
stream_mode=["messages", "updates"],
subgraphs=True,
):
if isinstance(event_data, dict):
if "__interrupt__" in event_data:
yield _make_event(
"interrupt",
{
"thread_id": thread_id,
"id": event_data["__interrupt__"][0].ns[0],
"role": "assistant",
"content": event_data["__interrupt__"][0].value,
"finish_reason": "interrupt",
"options": [
{"text": "Accept", "value": "accepted"},
{"text": "Edit", "value": "edit_plan"},
],
},
)
continue
message_chunk, message_metadata = cast(
tuple[AIMessageChunk, dict[str, any]], event_data
)

View File

@ -22,8 +22,8 @@ class ChatMessage(BaseModel):
class ChatRequest(BaseModel):
messages: List[ChatMessage] = Field(
..., description="History of messages between the user and the assistant"
messages: Optional[List[ChatMessage]] = Field(
[], description="History of messages between the user and the assistant"
)
debug: Optional[bool] = Field(False, description="Whether to enable debug logging")
thread_id: Optional[str] = Field(
@ -38,6 +38,6 @@ class ChatRequest(BaseModel):
auto_accepted_plan: Optional[bool] = Field(
False, description="Whether to automatically accept the plan"
)
feedback: Optional[str] = Field(
None, description="Feedback from the user on the plan"
interrupt_feedback: Optional[str] = Field(
None, description="Interrupt feedback from the user on the plan"
)