diff --git a/server.py b/server.py
new file mode 100644
index 0000000..7e1a39c
--- /dev/null
+++ b/server.py
@@ -0,0 +1,29 @@
+"""
+Server script for running the Lite Deep Research API.
+"""
+
+import logging
+import sys
+
+import uvicorn
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+if __name__ == "__main__":
+    logger.info("Starting Lite Deep Research API server")
+    reload = True
+    if sys.platform.startswith("win"):
+        reload = False
+    uvicorn.run(
+        "src.server:app",
+        host="0.0.0.0",
+        port=8000,
+        reload=reload,
+        log_level="info",
+    )
diff --git a/src/server/__init__.py b/src/server/__init__.py
new file mode 100644
index 0000000..34f275e
--- /dev/null
+++ b/src/server/__init__.py
@@ -0,0 +1,3 @@
+from .app import app
+
+__all__ = ["app"]
diff --git a/src/server/app.py b/src/server/app.py
new file mode 100644
index 0000000..853f4a9
--- /dev/null
+++ b/src/server/app.py
@@ -0,0 +1,107 @@
+import json
+import logging
+from typing import Any, List, cast
+from uuid import uuid4
+
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
+from langchain_core.messages import AIMessageChunk, ToolMessage
+
+from src.graph.builder import build_graph
+from src.server.chat_request import ChatMessage, ChatRequest
+
+logger = logging.getLogger(__name__)
+
+app = FastAPI(
+    title="Lite Deep Research API",
+    description="API for Lite Deep Research",
+    version="0.1.0",
+)
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Allows all origins
+    allow_credentials=True,
+    allow_methods=["*"],  # Allows all methods
+    allow_headers=["*"],  # Allows all headers
+)
+
+graph = build_graph()
+
+
+@app.post("/api/chat/stream")
+async def chat_stream(request: ChatRequest):
+    thread_id = request.thread_id
+    if thread_id == "__default__":
+        thread_id = str(uuid4())
+    return StreamingResponse(
+        _astream_workflow_generator(
+            request.model_dump()["messages"],
+            thread_id,
+            request.max_plan_iterations,
+            request.max_step_num,
+        ),
+        media_type="text/event-stream",
+    )
+
+
+async def _astream_workflow_generator(
+    messages: List[ChatMessage],
+    thread_id: str,
+    max_plan_iterations: int,
+    max_step_num: int,
+):
+    async for agent, _, event_data in graph.astream(
+        {"messages": messages},
+        config={
+            "thread_id": thread_id,
+            "max_plan_iterations": max_plan_iterations,
+            "max_step_num": max_step_num,
+        },
+        stream_mode=["messages"],
+        subgraphs=True,
+    ):
+        message_chunk, message_metadata = cast(
+            tuple[AIMessageChunk, dict[str, Any]], event_data
+        )
+        event_stream_message: dict[str, Any] = {
+            "thread_id": thread_id,
+            "agent": agent[0].split(":")[0],
+            "id": message_chunk.id,
+            "role": "assistant",
+            "content": message_chunk.content,
+        }
+        if message_chunk.response_metadata.get("finish_reason"):
+            event_stream_message["finish_reason"] = message_chunk.response_metadata.get(
+                "finish_reason"
+            )
+        if isinstance(message_chunk, ToolMessage):
+            # Tool Message - Return the result of the tool call
+            event_stream_message["tool_call_id"] = message_chunk.tool_call_id
+            yield _make_event("tool_call_result", event_stream_message)
+        else:
+            # AI Message - Raw message tokens
+            if message_chunk.tool_calls:
+                # AI Message - Tool Call
+                event_stream_message["tool_calls"] = message_chunk.tool_calls
+                event_stream_message["tool_call_chunks"] = (
+                    message_chunk.tool_call_chunks
+                )
+                yield _make_event("tool_calls", event_stream_message)
+            elif message_chunk.tool_call_chunks:
+                # AI Message - Tool Call Chunks
+                event_stream_message["tool_call_chunks"] = (
+                    message_chunk.tool_call_chunks
+                )
+                yield _make_event("tool_call_chunks", event_stream_message)
+            else:
+                # AI Message - Raw message tokens
+                yield _make_event("message_chunk", event_stream_message)
+
+
+def _make_event(event_type: str, data: dict[str, Any]):
+    if data.get("content") == "":
+        data.pop("content")
+    return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
diff --git a/src/server/chat_request.py b/src/server/chat_request.py
new file mode 100644
index 0000000..e14a088
--- /dev/null
+++ b/src/server/chat_request.py
@@ -0,0 +1,37 @@
+from typing import List, Optional, Union
+
+from pydantic import BaseModel, Field
+
+
+class ContentItem(BaseModel):
+    type: str = Field(..., description="The type of content (text, image, etc.)")
+    text: Optional[str] = Field(None, description="The text content if type is 'text'")
+    image_url: Optional[str] = Field(
+        None, description="The image URL if type is 'image'"
+    )
+
+
+class ChatMessage(BaseModel):
+    role: str = Field(
+        ..., description="The role of the message sender (user or assistant)"
+    )
+    content: Union[str, List[ContentItem]] = Field(
+        ...,
+        description="The content of the message, either a string or a list of content items",
+    )
+
+
+class ChatRequest(BaseModel):
+    messages: List[ChatMessage] = Field(
+        ..., description="History of messages between the user and the assistant"
+    )
+    debug: Optional[bool] = Field(False, description="Whether to enable debug logging")
+    thread_id: Optional[str] = Field(
+        "__default__", description="A specific conversation identifier"
+    )
+    max_plan_iterations: Optional[int] = Field(
+        1, description="The maximum number of plan iterations"
+    )
+    max_step_num: Optional[int] = Field(
+        3, description="The maximum number of steps in a plan"
+    )