From 84defafc14f3569d9c1169b18a8e629dba108981 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 9 Jun 2024 13:17:44 -0700 Subject: [PATCH] feat: unified chat completions endpoint --- backend/apps/ollama/main.py | 7 +++++- backend/main.py | 34 +++++++++++++++++++++++++++-- src/lib/components/chat/Chat.svelte | 4 +--- 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index d406f0670..9ada17262 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -849,9 +849,14 @@ async def generate_chat_completion( # TODO: we should update this part once Ollama supports other types +class OpenAIChatMessageContent(BaseModel): + type: str + model_config = ConfigDict(extra="allow") + + class OpenAIChatMessage(BaseModel): role: str - content: str + content: Union[str, OpenAIChatMessageContent] model_config = ConfigDict(extra="allow") diff --git a/backend/main.py b/backend/main.py index d7fa940ff..ff87b3da7 100644 --- a/backend/main.py +++ b/backend/main.py @@ -25,8 +25,17 @@ from starlette.responses import StreamingResponse, Response from apps.socket.main import app as socket_app -from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models -from apps.openai.main import app as openai_app, get_all_models as get_openai_models +from apps.ollama.main import ( + app as ollama_app, + OpenAIChatCompletionForm, + get_all_models as get_ollama_models, + generate_openai_chat_completion as generate_ollama_chat_completion, +) +from apps.openai.main import ( + app as openai_app, + get_all_models as get_openai_models, + generate_chat_completion as generate_openai_chat_completion, +) from apps.audio.main import app as audio_app from apps.images.main import app as images_app @@ -485,6 +494,27 @@ async def get_models(user=Depends(get_verified_user)): return {"data": models} +@app.post("/api/chat/completions") +async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + model = app.state.MODELS[model_id] + + print(model) + + if model["owned_by"] == "ollama": + return await generate_ollama_chat_completion( + OpenAIChatCompletionForm(**form_data), user=user + ) + else: + return await generate_openai_chat_completion(form_data, user=user) + + @app.post("/api/chat/completed") async def chat_completed(form_data: dict, user=Depends(get_verified_user)): data = form_data diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 939e4fc24..8aad3ff48 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -1134,9 +1134,7 @@ titleModelId, userPrompt, $chatId, - titleModel?.owned_by === 'openai' ?? false - ? `${OPENAI_API_BASE_URL}` - : `${OLLAMA_API_BASE_URL}/v1` + `${WEBUI_BASE_URL}/api` ); return title;