Merge pull request #3 from hetaoBackend/feat/interrupt

feat: support interrupt protocol
This commit is contained in:
Li Xin 2025-04-16 10:57:00 +08:00 committed by GitHub
commit 3cd84e1ec7
4 changed files with 37 additions and 22 deletions

View File

@ -49,26 +49,20 @@ curl --location 'http://localhost:8000/api/chat/stream' \
"messages": [ "messages": [
{ {
"role": "user", "role": "user",
"content": "what is mcp?" "content": "make the last step be comprehensive"
} }
], ],
"thread_id": "test_thread_2", "thread_id": "test_thread_2",
"auto_accepted_plan": false, "auto_accepted_plan": false,
"feedback": "[EDIT PLAN] make the last step be comprehensive" "interrupt_feedback": "edit_plan"
} }
``` ```
### Accepted the plan ### Accept the plan
``` ```
{ {
"messages": [
{
"role": "user",
"content": "what is mcp?"
}
],
"thread_id": "test_thread_2", "thread_id": "test_thread_2",
"auto_accepted_plan": false, "auto_accepted_plan": false,
"feedback": "[ACCEPTED]" "interrupt_feedback": "accepted"
} }
``` ```

View File

@ -75,10 +75,10 @@ def human_feedback_node(
# check if the plan is auto accepted # check if the plan is auto accepted
auto_accepted_plan = state.get("auto_accepted_plan", False) auto_accepted_plan = state.get("auto_accepted_plan", False)
if not auto_accepted_plan: if not auto_accepted_plan:
feedback = interrupt(current_plan) feedback = interrupt("Please Review the Plan.")
# if the feedback is not accepted, return the planner node # if the feedback is not accepted, return the planner node
if feedback and str(feedback).upper() != "[ACCEPTED]": if feedback and str(feedback).upper().startswith("[EDIT_PLAN]"):
return Command( return Command(
update={ update={
"messages": [ "messages": [
@ -87,7 +87,7 @@ def human_feedback_node(
}, },
goto="planner", goto="planner",
) )
elif feedback and str(feedback).upper() == "[ACCEPTED]": elif feedback and str(feedback).upper().startswith("[ACCEPTED]"):
logger.info("Plan is accepted by user.") logger.info("Plan is accepted by user.")
else: else:
raise TypeError(f"Interrupt value of {feedback} is not supported.") raise TypeError(f"Interrupt value of {feedback} is not supported.")

View File

@ -44,7 +44,7 @@ async def chat_stream(request: ChatRequest):
request.max_plan_iterations, request.max_plan_iterations,
request.max_step_num, request.max_step_num,
request.auto_accepted_plan, request.auto_accepted_plan,
request.feedback, request.interrupt_feedback,
), ),
media_type="text/event-stream", media_type="text/event-stream",
) )
@ -56,11 +56,15 @@ async def _astream_workflow_generator(
max_plan_iterations: int, max_plan_iterations: int,
max_step_num: int, max_step_num: int,
auto_accepted_plan: bool, auto_accepted_plan: bool,
feedback: str, interrupt_feedback: str,
): ):
input_ = {"messages": messages, "auto_accepted_plan": auto_accepted_plan} input_ = {"messages": messages, "auto_accepted_plan": auto_accepted_plan}
if not auto_accepted_plan and feedback: if not auto_accepted_plan and interrupt_feedback:
input_ = Command(resume=feedback) resume_msg = f"[{interrupt_feedback}]"
# add the last message to the resume message
if messages:
resume_msg += f" {messages[-1]["content"]}"
input_ = Command(resume=resume_msg)
async for agent, _, event_data in graph.astream( async for agent, _, event_data in graph.astream(
input_, input_,
config={ config={
@ -68,9 +72,26 @@ async def _astream_workflow_generator(
"max_plan_iterations": max_plan_iterations, "max_plan_iterations": max_plan_iterations,
"max_step_num": max_step_num, "max_step_num": max_step_num,
}, },
stream_mode=["messages"], stream_mode=["messages", "updates"],
subgraphs=True, subgraphs=True,
): ):
if isinstance(event_data, dict):
if "__interrupt__" in event_data:
yield _make_event(
"interrupt",
{
"thread_id": thread_id,
"id": event_data["__interrupt__"][0].ns[0],
"role": "assistant",
"content": event_data["__interrupt__"][0].value,
"finish_reason": "interrupt",
"options": [
{"text": "Accept", "value": "accepted"},
{"text": "Edit", "value": "edit_plan"},
],
},
)
continue
message_chunk, message_metadata = cast( message_chunk, message_metadata = cast(
tuple[AIMessageChunk, dict[str, any]], event_data tuple[AIMessageChunk, dict[str, any]], event_data
) )

View File

@ -22,8 +22,8 @@ class ChatMessage(BaseModel):
class ChatRequest(BaseModel): class ChatRequest(BaseModel):
messages: List[ChatMessage] = Field( messages: Optional[List[ChatMessage]] = Field(
..., description="History of messages between the user and the assistant" [], description="History of messages between the user and the assistant"
) )
debug: Optional[bool] = Field(False, description="Whether to enable debug logging") debug: Optional[bool] = Field(False, description="Whether to enable debug logging")
thread_id: Optional[str] = Field( thread_id: Optional[str] = Field(
@ -38,6 +38,6 @@ class ChatRequest(BaseModel):
auto_accepted_plan: Optional[bool] = Field( auto_accepted_plan: Optional[bool] = Field(
False, description="Whether to automatically accept the plan" False, description="Whether to automatically accept the plan"
) )
feedback: Optional[str] = Field( interrupt_feedback: Optional[str] = Field(
None, description="Feedback from the user on the plan" None, description="Interrupt feedback from the user on the plan"
) )