diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index 2e117cadf..24dfda0b3 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -20,7 +20,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
 from api.db import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
 from api.utils.api_utils import get_json_result
-from rag.llm import EmbeddingModel, ChatModel, RerankModel
+from rag.llm import EmbeddingModel, ChatModel, RerankModel,CvModel
 
 
 @manager.route('/factories', methods=['GET'])
@@ -126,6 +126,9 @@ def add_llm():
         api_key = '{' + f'"bedrock_ak": "{req.get("bedrock_ak", "")}", ' \
             f'"bedrock_sk": "{req.get("bedrock_sk", "")}", ' \
             f'"bedrock_region": "{req.get("bedrock_region", "")}", ' + '}'
+    elif factory == "LocalAI":
+        llm_name = req["llm_name"]+"___LocalAI"
+        api_key = "xxxxxxxxxxxxxxx"
     else:
         llm_name = req["llm_name"]
         api_key = "xxxxxxxxxxxxxxx"
@@ -176,6 +179,21 @@ def add_llm():
         except Exception as e:
             msg += f"\nFail to access model({llm['llm_name']})." + str(
                 e)
+    elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
+        mdl = CvModel[factory](
+            key=None, model_name=llm["llm_name"], base_url=llm["api_base"]
+        )
+        try:
+            img_url = (
+                "https://upload.wikimedia.org/wikipedia/comm"
+                "ons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/256"
+                "0px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+            )
+            m, tc = mdl.describe(img_url)
+            if not tc:
+                raise Exception(m)
+        except Exception as e:
+            msg += f"\nFail to access model({llm['llm_name']})." + str(e)
     else:
         # TODO: check other type of models
         pass
diff --git a/conf/llm_factories.json b/conf/llm_factories.json
index 0a438cba5..ff774f6d9 100644
--- a/conf/llm_factories.json
+++ b/conf/llm_factories.json
@@ -157,6 +157,13 @@
             "status": "1",
             "llm": []
         },
+        {
+            "name": "LocalAI",
+            "logo": "",
+            "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+            "status": "1",
+            "llm": []
+        },
         {
             "name": "Moonshot",
             "logo": "",
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index 2833319a4..ede0736b7 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -21,6 +21,7 @@ from .rerank_model import *
 
 EmbeddingModel = {
     "Ollama": OllamaEmbed,
+    "LocalAI": LocalAIEmbed,
     "OpenAI": OpenAIEmbed,
     "Azure-OpenAI": AzureEmbed,
     "Xinference": XinferenceEmbed,
@@ -46,7 +47,8 @@ CvModel = {
     "ZHIPU-AI": Zhipu4V,
     "Moonshot": LocalCV,
     'Gemini':GeminiCV,
-    'OpenRouter':OpenRouterCV
+    'OpenRouter':OpenRouterCV,
+    "LocalAI":LocalAICV
 }
 
 
@@ -56,6 +58,7 @@ ChatModel = {
     "ZHIPU-AI": ZhipuChat,
     "Tongyi-Qianwen": QWenChat,
     "Ollama": OllamaChat,
+    "LocalAI": LocalAIChat,
     "Xinference": XinferenceChat,
     "Moonshot": MoonshotChat,
     "DeepSeek": DeepSeekChat,
@@ -67,7 +70,7 @@ ChatModel = {
     'Gemini' : GeminiChat,
     "Bedrock": BedrockChat,
     "Groq": GroqChat,
-    'OpenRouter':OpenRouterChat
+    'OpenRouter':OpenRouterChat,
 }
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index cf00a7fc5..fdcc2db9a 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -348,6 +348,82 @@ class OllamaChat(Base):
         yield 0
 
 
+class LocalAIChat(Base):
+    def __init__(self, key, model_name, base_url):
+        if base_url[-1] == "/":
+            base_url = base_url[:-1]
+        self.base_url = base_url + "/v1/chat/completions"
+        self.model_name = model_name.split("___")[0]
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        headers = {
+            "Content-Type": "application/json",
+        }
+        payload = json.dumps(
+            {"model": self.model_name, "messages": history, **gen_conf}
+        )
+        try:
+            response = requests.request(
+                "POST", url=self.base_url, headers=headers, data=payload
+            )
+            response = response.json()
+            ans = response["choices"][0]["message"]["content"].strip()
+            if response["choices"][0]["finish_reason"] == "length":
+                ans += (
+                    "...\nFor the content length reason, it stopped, continue?"
+                    if is_english([ans])
+                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                )
+            return ans, response["usage"]["total_tokens"]
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        total_tokens = 0
+        try:
+            headers = {
+                "Content-Type": "application/json",
+            }
+            payload = json.dumps(
+                {
+                    "model": self.model_name,
+                    "messages": history,
+                    "stream": True,
+                    **gen_conf,
+                }
+            )
+            response = requests.request(
+                "POST",
+                url=self.base_url,
+                headers=headers,
+                data=payload,
+            )
+            for resp in response.content.decode("utf-8").split("\n\n"):
+                if "choices" not in resp:
+                    continue
+                resp = json.loads(resp[6:])
+                if "delta" in resp["choices"][0]:
+                    text = resp["choices"][0]["delta"]["content"]
+                else:
+                    continue
+                ans += text
+                total_tokens += 1
+                yield ans
+
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
+
+
 class LocalLLM(Base):
     class RPCProxy:
         def __init__(self, host, port):
diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index 09b7347f4..e63d6a0a3 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -189,6 +189,38 @@ class OllamaCV(Base):
             return "**ERROR**: " + str(e), 0
 
 
+class LocalAICV(Base):
+    def __init__(self, key, model_name, base_url, lang="Chinese"):
+        self.client = OpenAI(api_key="empty", base_url=base_url)
+        self.model_name = model_name.split("___")[0]
+        self.lang = lang
+
+    def describe(self, image, max_tokens=300):
+        if not isinstance(image, bytes) and not isinstance(
+            image, BytesIO
+        ):  # if url string
+            prompt = self.prompt(image)
+            # "content" is a list of message parts; set the URL on the image part
+            for i in range(len(prompt)):
+                for c in prompt[i]["content"]:
+                    if "image_url" in c:
+                        c["image_url"]["url"] = image
+        else:
+            b64 = self.image2base64(image)
+            prompt = self.prompt(b64)
+            for i in range(len(prompt)):
+                for c in prompt[i]["content"]:
+                    if "text" in c:
+                        c["type"] = "text"
+
+        res = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=prompt,
+            max_tokens=max_tokens,
+        )
+        return res.choices[0].message.content.strip(), res.usage.total_tokens
+
+
 class XinferenceCV(Base):
     def __init__(self, key, model_name="", lang="Chinese", base_url=""):
         self.client = OpenAI(api_key="xxx", base_url=base_url)
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index 1774bc285..d1290981d 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -111,6 +111,24 @@ class OpenAIEmbed(Base):
         return np.array(res.data[0].embedding), res.usage.total_tokens
 
 
+class LocalAIEmbed(Base):
+    def __init__(self, key, model_name, base_url):
+        self.base_url = base_url + "/embeddings"
+        self.headers = {
+            "Content-Type": "application/json",
+        }
+        self.model_name = model_name.split("___")[0]
+
+    def encode(self, texts: list, batch_size=None):
+        data = {"model": self.model_name, "input": texts, "encoding_type": "float"}
+        res = requests.post(self.base_url, headers=self.headers, json=data).json()
+
+        return np.array([d["embedding"] for d in res["data"]]), 1024
+
+    def encode_queries(self, text):
+        embds, cnt = self.encode([text])
+        return np.array(embds[0]), cnt
+
 class AzureEmbed(OpenAIEmbed):
     def __init__(self, key, model_name, **kwargs):
         self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
@@ -443,4 +461,4 @@ class GeminiEmbed(Base):
             task_type="retrieval_document",
             title="Embedding of single string")
         token_count = num_tokens_from_string(text)
-        return np.array(result['embedding']),token_count
\ No newline at end of file
+        return np.array(result['embedding']),token_count
diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py
index 11bd314ca..4cf23bc4c 100644
--- a/rag/llm/rerank_model.py
+++ b/rag/llm/rerank_model.py
@@ -135,7 +135,7 @@ class YoudaoRerank(DefaultRerank):
                 if isinstance(scores, float): res.append(scores)
                 else: res.extend(scores)
         return np.array(res), token_count
-    
+
 
 class XInferenceRerank(Base):
     def __init__(self, key="xxxxxxx", model_name="", base_url=""):
@@ -156,3 +156,11 @@ class XInferenceRerank(Base):
         }
         res = requests.post(self.base_url, headers=self.headers, json=data).json()
         return np.array([d["relevance_score"] for d in res["results"]]), res["meta"]["tokens"]["input_tokens"]+res["meta"]["tokens"]["output_tokens"]
+
+
+class LocalAIRerank(Base):
+    def __init__(self, key, model_name, base_url):
+        pass
+
+    def similarity(self, query: str, texts: list):
+        raise NotImplementedError("The LocalAIRerank has not been implemented")
diff --git a/web/src/pages/user-setting/constants.tsx b/web/src/pages/user-setting/constants.tsx
index e812ebaf5..5eb74b9a0 100644
--- a/web/src/pages/user-setting/constants.tsx
+++ b/web/src/pages/user-setting/constants.tsx
@@ -17,4 +17,4 @@ export const UserSettingIconMap = {
 
 export * from '@/constants/setting';
 
-export const LocalLlmFactories = ['Ollama', 'Xinference'];
+export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI'];
diff --git a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
index ece087c1e..ffa71241e 100644
--- a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
+++ b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
@@ -75,6 +75,7 @@ const OllamaModal = ({
           <Select placeholder={t('modelTypeMessage')}>
             <Option value="chat">chat</Option>
             <Option value="embedding">embedding</Option>
+            <Option value="image2text">image2text</Option>
           </Select>
         </Form.Item>
         <Form.Item<FieldType>
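
A quick end-to-end exercise of the classes this patch adds, for anyone testing it locally. This is a minimal sketch: the host, port, and model names are placeholder assumptions (LocalAI listens on 8080 by default), not values taken from the patch. It also works around an asymmetry in the patch itself: LocalAIChat appends "/v1/chat/completions" to its base URL, while LocalAIEmbed appends only "/embeddings", so the embedding base URL must already end in "/v1".

# Assumes a LocalAI server at localhost:8080 serving the two placeholder
# models named below; swap in whatever your instance actually hosts.
from rag.llm.chat_model import LocalAIChat
from rag.llm.embedding_model import LocalAIEmbed

# add_llm() stores LocalAI model names with a "___LocalAI" suffix;
# both classes strip it via model_name.split("___")[0].
chat = LocalAIChat(
    key=None,
    model_name="ggml-gpt4all-j___LocalAI",
    base_url="http://localhost:8080",  # class appends /v1/chat/completions
)
answer, used_tokens = chat.chat(
    system="You are a helpful assistant.",
    history=[{"role": "user", "content": "Say hello in one sentence."}],
    gen_conf={"temperature": 0.7, "max_tokens": 64},
)
print(answer, used_tokens)

emb = LocalAIEmbed(
    key=None,
    model_name="bert-embeddings___LocalAI",
    base_url="http://localhost:8080/v1",  # class appends only /embeddings
)
vectors, count = emb.encode(["hello", "world"])
print(vectors.shape, count)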