Mirror of https://git.mirrors.martin98.com/https://github.com/bytedance/deer-flow
Synced 2025-08-18 03:35:53 +08:00
Merge pull request #1 from hetaoBackend/feat/server
feat: implement basic server logic
commit a759c168fa
server.py (new file, 29 lines)
@@ -0,0 +1,29 @@
"""
Server script for running the Lite Deep Research API.
"""

import logging
import sys

import uvicorn

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

logger = logging.getLogger(__name__)

if __name__ == "__main__":
    logger.info("Starting Lite Deep Research API server")
    reload = True
    if sys.platform.startswith("win"):
        reload = False
    uvicorn.run(
        "src.server:app",
        host="0.0.0.0",
        port=8000,
        reload=reload,
        log_level="info",
    )
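
For a quick smoke test once "python server.py" is running, FastAPI's interactive docs (served at /docs by default, and not disabled in this commit) can be fetched. A minimal sketch, assuming the host and port configured above:

# smoke_test.py - illustrative only, not part of this commit
import urllib.request

# FastAPI serves interactive docs at /docs unless explicitly disabled
with urllib.request.urlopen("http://localhost:8000/docs") as resp:
    print(resp.status)  # expect 200 when the server is up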
src/server/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .app import app

__all__ = ["app"]
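
This re-export is what makes the "src.server:app" import string in server.py resolve without callers needing to know the internal module layout. For example (illustrative, not part of this commit):

# Resolves to the FastAPI instance defined in src/server/app.py,
# the same object uvicorn loads via the "src.server:app" import string.
from src.server import app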
src/server/app.py (new file, 107 lines)
@@ -0,0 +1,107 @@
import json
import logging
from typing import List, cast
from uuid import uuid4

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from langchain_core.messages import AIMessageChunk, ToolMessage

from src.graph.builder import build_graph
from src.server.chat_request import ChatMessage, ChatRequest

logger = logging.getLogger(__name__)

app = FastAPI(
    title="Lite Deep Research API",
    description="API for Lite Deep Research",
    version="0.1.0",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

graph = build_graph()


@app.post("/api/chat/stream")
async def chat_stream(request: ChatRequest):
    thread_id = request.thread_id
    if thread_id == "__default__":
        thread_id = str(uuid4())
    return StreamingResponse(
        _astream_workflow_generator(
            request.model_dump()["messages"],
            thread_id,
            request.max_plan_iterations,
            request.max_step_num,
        ),
        media_type="text/event-stream",
    )


async def _astream_workflow_generator(
    messages: List[ChatMessage],
    thread_id: str,
    max_plan_iterations: int,
    max_step_num: int,
):
    async for agent, _, event_data in graph.astream(
        {"messages": messages},
        config={
            "thread_id": thread_id,
            "max_plan_iterations": max_plan_iterations,
            "max_step_num": max_step_num,
        },
        stream_mode=["messages"],
        subgraphs=True,
    ):
        message_chunk, message_metadata = cast(
            tuple[AIMessageChunk, dict[str, any]], event_data
        )
        event_stream_message: dict[str, any] = {
            "thread_id": thread_id,
            "agent": agent[0].split(":")[0],
            "id": message_chunk.id,
            "role": "assistant",
            "content": message_chunk.content,
        }
        if message_chunk.response_metadata.get("finish_reason"):
            event_stream_message["finish_reason"] = message_chunk.response_metadata.get(
                "finish_reason"
            )
        if isinstance(message_chunk, ToolMessage):
            # Tool Message - Return the result of the tool call
            event_stream_message["tool_call_id"] = message_chunk.tool_call_id
            yield _make_event("tool_call_result", event_stream_message)
        else:
            # AI Message - Raw message tokens
            if message_chunk.tool_calls:
                # AI Message - Tool Call
                event_stream_message["tool_calls"] = message_chunk.tool_calls
                event_stream_message["tool_call_chunks"] = (
                    message_chunk.tool_call_chunks
                )
                yield _make_event("tool_calls", event_stream_message)
            elif message_chunk.tool_call_chunks:
                # AI Message - Tool Call Chunks
                event_stream_message["tool_call_chunks"] = (
                    message_chunk.tool_call_chunks
                )
                yield _make_event("tool_call_chunks", event_stream_message)
            else:
                # AI Message - Raw message tokens
                yield _make_event("message_chunk", event_stream_message)


def _make_event(event_type: str, data: dict[str, any]):
    if data.get("content") == "":
        data.pop("content")
    return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
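
Each chunk yielded by _make_event is a text/event-stream record: an "event: <type>" line, a "data: <json>" line, then a blank line. A minimal client sketch for /api/chat/stream, assuming the server above is running on localhost:8000 and that the httpx package is installed (httpx is not added by this pull request):

# sse_client.py - illustrative only, not part of this commit
import json

import httpx

payload = {
    "messages": [{"role": "user", "content": "What is deep research?"}],
    "thread_id": "__default__",  # the server swaps this for a fresh uuid4
    "max_plan_iterations": 1,
    "max_step_num": 3,
}

with httpx.Client(timeout=None) as client:
    with client.stream(
        "POST", "http://localhost:8000/api/chat/stream", json=payload
    ) as response:
        event_type = None
        for line in response.iter_lines():
            if line.startswith("event: "):
                event_type = line[len("event: "):]
            elif line.startswith("data: "):
                data = json.loads(line[len("data: "):])
                print(event_type, data.get("agent"), data.get("content", ""))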
src/server/chat_request.py (new file, 37 lines)
@@ -0,0 +1,37 @@
from typing import List, Optional, Union

from pydantic import BaseModel, Field


class ContentItem(BaseModel):
    type: str = Field(..., description="The type of content (text, image, etc.)")
    text: Optional[str] = Field(None, description="The text content if type is 'text'")
    image_url: Optional[str] = Field(
        None, description="The image URL if type is 'image'"
    )


class ChatMessage(BaseModel):
    role: str = Field(
        ..., description="The role of the message sender (user or assistant)"
    )
    content: Union[str, List[ContentItem]] = Field(
        ...,
        description="The content of the message, either a string or a list of content items",
    )


class ChatRequest(BaseModel):
    messages: List[ChatMessage] = Field(
        ..., description="History of messages between the user and the assistant"
    )
    debug: Optional[bool] = Field(False, description="Whether to enable debug logging")
    thread_id: Optional[str] = Field(
        "__default__", description="A specific conversation identifier"
    )
    max_plan_iterations: Optional[int] = Field(
        1, description="The maximum number of plan iterations"
    )
    max_step_num: Optional[int] = Field(
        3, description="The maximum number of steps in a plan"
    )
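
For illustration (not part of this commit), a request body for /api/chat/stream can be built and serialized from these models; optional fields fall back to the defaults declared above:

# build_request.py - illustrative only
from src.server.chat_request import ChatMessage, ChatRequest

request = ChatRequest(
    messages=[ChatMessage(role="user", content="Summarize recent work on research agents")]
)

print(request.thread_id)            # "__default__" until the server assigns a uuid4
print(request.max_plan_iterations)  # 1
print(request.max_step_num)         # 3
print(request.model_dump_json())    # JSON suitable for POST /api/chat/stream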