diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index e021b9dda..303be41c6 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -28,7 +28,6 @@ import os
 import json
 import requests
 import asyncio
-from rag.svr.jina_server import Prompt,Generation
 
 class Base(ABC):
     def __init__(self, key, model_name, base_url):
@@ -413,6 +412,7 @@ class LocalLLM(Base):
         self.client = Client(port=12345, protocol="grpc", asyncio=True)
 
     def _prepare_prompt(self, system, history, gen_conf):
+        from rag.svr.jina_server import Prompt,Generation
         if system:
             history.insert(0, {"role": "system", "content": system})
         if "max_tokens" in gen_conf:
@@ -420,6 +420,7 @@ class LocalLLM(Base):
         return Prompt(message=history, gen_conf=gen_conf)
 
     def _stream_response(self, endpoint, prompt):
+        from rag.svr.jina_server import Prompt,Generation
         answer = ""
         try:
             res = self.client.stream_doc(
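For context, the diff applies the standard deferred-import pattern: the module-level import of `rag.svr.jina_server` is removed and re-introduced inside the two `LocalLLM` methods that actually use it, so merely importing `chat_model.py` no longer pulls in the Jina server module. Below is a minimal, self-contained sketch of the same pattern, not code from this repository; the names `Worker` and `heavy_backend` are hypothetical stand-ins (with `json` substituting for the heavy module so the sketch runs as-is).

```python
# Sketch of the deferred-import pattern used in the diff, assuming the
# goal is to avoid loading a heavy or optional dependency at module
# import time. "Worker" and "heavy_backend" are hypothetical stand-ins.

class Worker:
    def run(self, payload):
        # Deferred import: the cost of loading the module (and any
        # ImportError from a missing optional dependency) is paid only
        # when this method runs, not when this file is imported. Python
        # caches the module in sys.modules, so repeated calls are cheap.
        import json as heavy_backend  # stand-in for the heavy module
        return heavy_backend.dumps(payload)


if __name__ == "__main__":
    print(Worker().run({"role": "user", "content": "hi"}))
```

The trade-off: a top-level import fails for every consumer of the module at import time, while a function-level import confines both the load cost and any failure to the code path that needs the dependency, at the price of slightly less visible dependencies and a small overhead on the first call.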