Merge pull request #3 from hetaoBackend/feat/interrupt

feat: support interrupt protocol

Commit 3cd84e1ec7
````diff
@@ -49,26 +49,20 @@ curl --location 'http://localhost:8000/api/chat/stream' \
     "messages": [
         {
             "role": "user",
-            "content": "what is mcp?"
+            "content": "make the last step be comprehensive"
         }
     ],
     "thread_id": "test_thread_2",
     "auto_accepted_plan": false,
-    "feedback": "[EDIT PLAN] make the last step be comprehensive"
+    "interrupt_feedback": "edit_plan"
 }
 ```

-### Accepted the plan
+### Accept the plan

 ```
 {
     "messages": [
         {
             "role": "user",
             "content": "what is mcp?"
         }
     ],
     "thread_id": "test_thread_2",
     "auto_accepted_plan": false,
-    "feedback": "[ACCEPTED]"
+    "interrupt_feedback": "accepted"
 }
 ```
````
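For reference, here is a Python equivalent of the edit-plan request above. This is a minimal sketch: `requests` is not a project dependency, and the payload simply mirrors the README example.

```python
# Hypothetical client call mirroring the README's edit-plan example.
import requests

resp = requests.post(
    "http://localhost:8000/api/chat/stream",
    json={
        "messages": [
            {"role": "user", "content": "make the last step be comprehensive"}
        ],
        "thread_id": "test_thread_2",
        "auto_accepted_plan": False,
        # A structured option value replaces the old freeform "[EDIT PLAN] ..." feedback.
        "interrupt_feedback": "edit_plan",
    },
    stream=True,  # the endpoint responds with text/event-stream
)
```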
```diff
@@ -75,10 +75,10 @@ def human_feedback_node(
     # check if the plan is auto accepted
     auto_accepted_plan = state.get("auto_accepted_plan", False)
     if not auto_accepted_plan:
-        feedback = interrupt(current_plan)
+        feedback = interrupt("Please Review the Plan.")

         # if the feedback is not accepted, return the planner node
-        if feedback and str(feedback).upper() != "[ACCEPTED]":
+        if feedback and str(feedback).upper().startswith("[EDIT_PLAN]"):
             return Command(
                 update={
                     "messages": [
@@ -87,7 +87,7 @@ def human_feedback_node(
                 },
                 goto="planner",
             )
-        elif feedback and str(feedback).upper() == "[ACCEPTED]":
+        elif feedback and str(feedback).upper().startswith("[ACCEPTED]"):
             logger.info("Plan is accepted by user.")
         else:
             raise TypeError(f"Interrupt value of {feedback} is not supported.")
```
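The node builds on LangGraph's interrupt/resume protocol. Below is a self-contained sketch of those mechanics; the state, node name, and plan values are illustrative, not deer-flow's.

```python
# Minimal sketch of LangGraph's interrupt()/Command(resume=...) round trip.
from typing import TypedDict

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
from langgraph.types import Command, interrupt


class State(TypedDict):
    plan: str


def review(state: State) -> State:
    # Pauses the graph; the argument is surfaced to the caller as the
    # interrupt payload, and the resume value comes back as the return value.
    feedback = interrupt("Please Review the Plan.")
    if str(feedback).upper().startswith("[ACCEPTED]"):
        return {"plan": state["plan"]}
    raise TypeError(f"Interrupt value of {feedback} is not supported.")


builder = StateGraph(State)
builder.add_node("review", review)
builder.add_edge(START, "review")
builder.add_edge("review", END)
# A checkpointer is required so the thread can be resumed later.
graph = builder.compile(checkpointer=MemorySaver())

config = {"configurable": {"thread_id": "test_thread_2"}}
# First run stops at interrupt(); in "updates" mode the pause shows up
# as an "__interrupt__" entry, which is exactly what the server checks for.
for update in graph.stream({"plan": "draft"}, config=config, stream_mode="updates"):
    print(update)
# Second run resumes the same thread; interrupt() now returns "[ACCEPTED]".
for update in graph.stream(Command(resume="[ACCEPTED]"), config=config, stream_mode="updates"):
    print(update)
```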
```diff
@@ -44,7 +44,7 @@ async def chat_stream(request: ChatRequest):
             request.max_plan_iterations,
             request.max_step_num,
             request.auto_accepted_plan,
-            request.feedback,
+            request.interrupt_feedback,
         ),
         media_type="text/event-stream",
     )
@@ -56,11 +56,15 @@ async def _astream_workflow_generator(
     max_plan_iterations: int,
     max_step_num: int,
     auto_accepted_plan: bool,
-    feedback: str,
+    interrupt_feedback: str,
 ):
     input_ = {"messages": messages, "auto_accepted_plan": auto_accepted_plan}
-    if not auto_accepted_plan and feedback:
-        input_ = Command(resume=feedback)
+    if not auto_accepted_plan and interrupt_feedback:
+        resume_msg = f"[{interrupt_feedback}]"
+        # add the last message to the resume message
+        if messages:
+            resume_msg += f" {messages[-1]['content']}"
+        input_ = Command(resume=resume_msg)
     async for agent, _, event_data in graph.astream(
         input_,
         config={
```
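To make the resume wiring concrete, here is how the payload is assembled for the edit-plan case, using the values from the README example:

```python
# Worked example of the resume payload, using the README's edit-plan request.
interrupt_feedback = "edit_plan"
messages = [{"role": "user", "content": "make the last step be comprehensive"}]

resume_msg = f"[{interrupt_feedback}]"
if messages:
    resume_msg += f" {messages[-1]['content']}"

assert resume_msg == "[edit_plan] make the last step be comprehensive"
# After .upper(), this hits the startswith("[EDIT_PLAN]") branch in human_feedback_node.
assert resume_msg.upper().startswith("[EDIT_PLAN]")
```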
```diff
@@ -68,9 +72,26 @@ async def _astream_workflow_generator(
             "max_plan_iterations": max_plan_iterations,
             "max_step_num": max_step_num,
         },
-        stream_mode=["messages"],
+        stream_mode=["messages", "updates"],
+        subgraphs=True,
     ):
+        if isinstance(event_data, dict):
+            if "__interrupt__" in event_data:
+                yield _make_event(
+                    "interrupt",
+                    {
+                        "thread_id": thread_id,
+                        "id": event_data["__interrupt__"][0].ns[0],
+                        "role": "assistant",
+                        "content": event_data["__interrupt__"][0].value,
+                        "finish_reason": "interrupt",
+                        "options": [
+                            {"text": "Accept", "value": "accepted"},
+                            {"text": "Edit", "value": "edit_plan"},
+                        ],
+                    },
+                )
+                continue
         message_chunk, message_metadata = cast(
             tuple[AIMessageChunk, dict[str, any]], event_data
         )
```
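On the wire, the new event gives a client everything it needs to render the review step. A minimal consumer sketch follows, assuming `_make_event` emits standard SSE frames (`event:` / `data:` lines); `requests` and the parsing loop are illustrative, not part of the project.

```python
# Hypothetical SSE consumer: watch for the "interrupt" event, show its
# Accept/Edit options, then re-POST with the chosen interrupt_feedback.
import json
import requests

payload = {
    "messages": [{"role": "user", "content": "what is mcp?"}],
    "thread_id": "test_thread_2",
    "auto_accepted_plan": False,
}

with requests.post(
    "http://localhost:8000/api/chat/stream", json=payload, stream=True
) as resp:
    event_name = None
    for line in resp.iter_lines(decode_unicode=True):
        if line.startswith("event:"):
            event_name = line.split(":", 1)[1].strip()
        elif line.startswith("data:"):
            data = json.loads(line.split(":", 1)[1])
            if event_name == "interrupt":
                # data["options"] is [{"text": "Accept", "value": "accepted"}, ...];
                # the chosen "value" becomes the next request's interrupt_feedback.
                print(data["content"], data["options"])
```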
```diff
@@ -22,8 +22,8 @@ class ChatMessage(BaseModel):


 class ChatRequest(BaseModel):
-    messages: List[ChatMessage] = Field(
-        ..., description="History of messages between the user and the assistant"
+    messages: Optional[List[ChatMessage]] = Field(
+        [], description="History of messages between the user and the assistant"
     )
     debug: Optional[bool] = Field(False, description="Whether to enable debug logging")
     thread_id: Optional[str] = Field(
@@ -38,6 +38,6 @@ class ChatRequest(BaseModel):
     auto_accepted_plan: Optional[bool] = Field(
         False, description="Whether to automatically accept the plan"
     )
-    feedback: Optional[str] = Field(
-        None, description="Feedback from the user on the plan"
+    interrupt_feedback: Optional[str] = Field(
+        None, description="Interrupt feedback from the user on the plan"
     )
```
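The net effect on the request model: a resume-only request no longer has to carry a message history. A condensed sketch of the loosened model (other fields omitted for brevity):

```python
# Condensed sketch of the loosened request model: a pure "resume"
# request may validate with no new messages at all.
from typing import List, Optional
from pydantic import BaseModel, Field


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: Optional[List[ChatMessage]] = Field(
        [], description="History of messages between the user and the assistant"
    )
    auto_accepted_plan: Optional[bool] = Field(
        False, description="Whether to automatically accept the plan"
    )
    interrupt_feedback: Optional[str] = Field(
        None, description="Interrupt feedback from the user on the plan"
    )


# Accepting a plan without adding any new message now validates cleanly.
req = ChatRequest(auto_accepted_plan=False, interrupt_feedback="accepted")
assert req.messages == [] and req.interrupt_feedback == "accepted"
```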