From 96f56a3c43de54b3315770f2a42bd80efc693307 Mon Sep 17 00:00:00 2001 From: JobSmithManipulation <143315462+JobSmithManipulation@users.noreply.github.com> Date: Fri, 27 Sep 2024 19:15:38 +0800 Subject: [PATCH] add huggingface model (#2624) ### What problem does this PR solve? #2469 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Kevin Hu --- api/apps/llm_app.py | 4 ++ conf/llm_factories.json | 9 ++++- rag/llm/__init__.py | 5 ++- rag/llm/chat_model.py | 1 + rag/llm/embedding_model.py | 37 ++++++++++++++++++ web/src/assets/svg/llm/huggingface.svg | 37 ++++++++++++++++++ web/src/pages/user-setting/constants.tsx | 1 + .../user-setting/setting-model/constant.ts | 1 + .../setting-model/ollama-modal/index.tsx | 38 ++++++++++++++----- 9 files changed, 120 insertions(+), 13 deletions(-) create mode 100644 web/src/assets/svg/llm/huggingface.svg diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index ca0333ac5..0ddfcc38c 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -155,6 +155,10 @@ def add_llm(): elif factory == "LocalAI": llm_name = req["llm_name"]+"___LocalAI" api_key = "xxxxxxxxxxxxxxx" + + elif factory == "HuggingFace": + llm_name = req["llm_name"]+"___HuggingFace" + api_key = "xxxxxxxxxxxxxxx" elif factory == "OpenAI-API-Compatible": llm_name = req["llm_name"]+"___OpenAI-API" diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 6aece225e..4daa014b0 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -2344,6 +2344,13 @@ "tags": "LLM", "status": "1", "llm": [] - } + }, + { + "name": "HuggingFace", + "logo": "", + "tags": "TEXT EMBEDDING", + "status": "1", + "llm": [] + } ] } diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 84c1adf01..441a2a553 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -18,7 +18,7 @@ from .chat_model import * from .cv_model import * from .rerank_model import * from .sequence2txt_model import * -from .tts_model import * +from .tts_model import * EmbeddingModel = { "Ollama": OllamaEmbed, @@ -46,7 +46,8 @@ EmbeddingModel = { "SILICONFLOW": SILICONFLOWEmbed, "Replicate": ReplicateEmbed, "BaiduYiyan": BaiduYiyanEmbed, - "Voyage AI": VoyageEmbed + "Voyage AI": VoyageEmbed, + "HuggingFace":HuggingFaceEmbed, } diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index bfce43499..479f6fe6d 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -1414,3 +1414,4 @@ class GoogleChat(Base): yield ans + "\n**ERROR**: " + str(e) yield response._chunks[-1].usage_metadata.total_token_count + \ No newline at end of file diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 498b57426..4189a022f 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -678,3 +678,40 @@ class VoyageEmbed(Base): texts=text, model=self.model_name, input_type="query" ) return np.array(res.embeddings), res.total_tokens + + +class HuggingFaceEmbed(Base): + def __init__(self, key, model_name, base_url=None): + if not model_name: + raise ValueError("Model name cannot be None") + self.key = key + self.model_name = model_name + self.base_url = base_url or "http://127.0.0.1:8080" + + def encode(self, texts: list, batch_size=32): + embeddings = [] + for text in texts: + response = requests.post( + f"{self.base_url}/embed", + json={"inputs": text}, + headers={'Content-Type': 'application/json'} + ) + if response.status_code == 200: + embedding = response.json() + embeddings.append(embedding[0]) + else: + raise Exception(f"Error: {response.status_code} - {response.text}") + return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts]) + + def encode_queries(self, text): + response = requests.post( + f"{self.base_url}/embed", + json={"inputs": text}, + headers={'Content-Type': 'application/json'} + ) + if response.status_code == 200: + embedding = response.json() + return np.array(embedding[0]), num_tokens_from_string(text) + else: + raise Exception(f"Error: {response.status_code} - {response.text}") + diff --git a/web/src/assets/svg/llm/huggingface.svg b/web/src/assets/svg/llm/huggingface.svg new file mode 100644 index 000000000..43c5d3c0c --- /dev/null +++ b/web/src/assets/svg/llm/huggingface.svg @@ -0,0 +1,37 @@ + + + + + + + + + + + diff --git a/web/src/pages/user-setting/constants.tsx b/web/src/pages/user-setting/constants.tsx index c65e27103..e8360487e 100644 --- a/web/src/pages/user-setting/constants.tsx +++ b/web/src/pages/user-setting/constants.tsx @@ -26,4 +26,5 @@ export const LocalLlmFactories = [ 'TogetherAI', 'Replicate', 'OpenRouter', + 'HuggingFace', ]; diff --git a/web/src/pages/user-setting/setting-model/constant.ts b/web/src/pages/user-setting/setting-model/constant.ts index 8ac93e782..7be3c4d23 100644 --- a/web/src/pages/user-setting/setting-model/constant.ts +++ b/web/src/pages/user-setting/setting-model/constant.ts @@ -40,6 +40,7 @@ export const IconMap = { Anthropic: 'anthropic', 'Voyage AI': 'voyage', 'Google Cloud': 'google-cloud', + HuggingFace: 'huggingface', }; export const BedrockRegionList = [ diff --git a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx index 1cf961139..489c0f255 100644 --- a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx +++ b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx @@ -8,6 +8,20 @@ type FieldType = IAddLlmRequestBody & { vision: boolean }; const { Option } = Select; +const llmFactoryToUrlMap = { + Ollama: 'https://huggingface.co/docs/text-embeddings-inference/quick_tour', + Xinference: 'https://inference.readthedocs.io/en/latest/user_guide', + LocalAI: 'https://localai.io/docs/getting-started/models/', + 'LM-Studio': 'https://lmstudio.ai/docs/basics', + 'OpenAI-API-Compatible': 'https://platform.openai.com/docs/models/gpt-4', + TogetherAI: 'https://docs.together.ai/docs/deployment-options', + Replicate: 'https://replicate.com/docs/topics/deployments', + OpenRouter: 'https://openrouter.ai/docs', + HuggingFace: + 'https://huggingface.co/docs/text-embeddings-inference/quick_tour', +}; +type LlmFactory = keyof typeof llmFactoryToUrlMap; + const OllamaModal = ({ visible, hideModal, @@ -35,7 +49,9 @@ const OllamaModal = ({ onOk?.(data); }; - + const url = + llmFactoryToUrlMap[llmFactory as LlmFactory] || + 'https://huggingface.co/docs/text-embeddings-inference/quick_tour'; return ( { return ( - + {t('ollamaLink', { name: llmFactory })} {originNode} @@ -72,10 +84,16 @@ const OllamaModal = ({ rules={[{ required: true, message: t('modelTypeMessage') }]} >