add support for LM Studio (#1663)

### What problem does this PR solve?

#1602 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
This commit is contained in:
黄腾 2024-07-24 12:46:43 +08:00 committed by GitHub
parent 100b3165d8
commit d96348eb22
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 9791 additions and 26 deletions

View File

@ -21,7 +21,7 @@ from api.db import StatusEnum, LLMType
from api.db.db_models import TenantLLM from api.db.db_models import TenantLLM
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
from rag.llm import EmbeddingModel, ChatModel, RerankModel,CvModel from rag.llm import EmbeddingModel, ChatModel, RerankModel,CvModel
import requests
@manager.route('/factories', methods=['GET']) @manager.route('/factories', methods=['GET'])
@login_required @login_required
@ -189,9 +189,13 @@ def add_llm():
"ons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/256" "ons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/256"
"0px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" "0px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
) )
m, tc = mdl.describe(img_url) res = requests.get(img_url)
if not tc: if res.status_code == 200:
raise Exception(m) m, tc = mdl.describe(res.content)
if not tc:
raise Exception(m)
else:
raise ConnectionError("fail to download the test picture")
except Exception as e: except Exception as e:
msg += f"\nFail to access model({llm['llm_name']})." + str(e) msg += f"\nFail to access model({llm['llm_name']})." + str(e)
else: else:

View File

@ -2208,6 +2208,13 @@
"model_type": "image2text" "model_type": "image2text"
} }
] ]
},
{
"name": "LM-Studio",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT",
"status": "1",
"llm": []
} }
] ]
} }

View File

@ -34,8 +34,9 @@ EmbeddingModel = {
"BAAI": DefaultEmbedding, "BAAI": DefaultEmbedding,
"Mistral": MistralEmbed, "Mistral": MistralEmbed,
"Bedrock": BedrockEmbed, "Bedrock": BedrockEmbed,
"Gemini":GeminiEmbed, "Gemini": GeminiEmbed,
"NVIDIA":NvidiaEmbed "NVIDIA": NvidiaEmbed,
"LM-Studio": LmStudioEmbed
} }
@ -47,10 +48,11 @@ CvModel = {
"Tongyi-Qianwen": QWenCV, "Tongyi-Qianwen": QWenCV,
"ZHIPU-AI": Zhipu4V, "ZHIPU-AI": Zhipu4V,
"Moonshot": LocalCV, "Moonshot": LocalCV,
'Gemini':GeminiCV, "Gemini": GeminiCV,
'OpenRouter':OpenRouterCV, "OpenRouter": OpenRouterCV,
"LocalAI":LocalAICV, "LocalAI": LocalAICV,
"NVIDIA":NvidiaCV "NVIDIA": NvidiaCV,
"LM-Studio": LmStudioCV
} }
@ -69,12 +71,13 @@ ChatModel = {
"MiniMax": MiniMaxChat, "MiniMax": MiniMaxChat,
"Minimax": MiniMaxChat, "Minimax": MiniMaxChat,
"Mistral": MistralChat, "Mistral": MistralChat,
'Gemini' : GeminiChat, "Gemini": GeminiChat,
"Bedrock": BedrockChat, "Bedrock": BedrockChat,
"Groq": GroqChat, "Groq": GroqChat,
'OpenRouter':OpenRouterChat, "OpenRouter": OpenRouterChat,
"StepFun":StepFunChat, "StepFun": StepFunChat,
"NVIDIA":NvidiaChat "NVIDIA": NvidiaChat,
"LM-Studio": LmStudioChat
} }
@ -83,7 +86,8 @@ RerankModel = {
"Jina": JinaRerank, "Jina": JinaRerank,
"Youdao": YoudaoRerank, "Youdao": YoudaoRerank,
"Xinference": XInferenceRerank, "Xinference": XInferenceRerank,
"NVIDIA":NvidiaRerank "NVIDIA": NvidiaRerank,
"LM-Studio": LmStudioRerank
} }

View File

@ -976,3 +976,15 @@ class NvidiaChat(Base):
yield ans + "\n**ERROR**: " + str(e) yield ans + "\n**ERROR**: " + str(e)
yield total_tokens yield total_tokens
class LmStudioChat(Base):
    """Chat model served by a local LM Studio instance (OpenAI-compatible API)."""

    def __init__(self, key, model_name, base_url):
        """
        :param key: ignored — LM Studio does not check API keys; a placeholder is sent.
        :param model_name: model identifier as shown in LM Studio.
        :param base_url: root URL of the local server; "/v1" is appended if missing.
        :raises ValueError: if base_url is empty/None.
        """
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        # Normalize the endpoint: the OpenAI-compatible API lives under /v1.
        # NOTE: use string concatenation, not os.path.join — join would use
        # "\\" on Windows and corrupt the URL.  Also always assign
        # self.base_url (the original only set it inside the if-branch,
        # leaving it undefined when the caller already supplied .../v1).
        if base_url.split("/")[-1] != "v1":
            base_url = base_url.rstrip("/") + "/v1"
        self.base_url = base_url
        self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
        self.model_name = model_name

View File

@ -440,15 +440,8 @@ class LocalAICV(Base):
self.lang = lang self.lang = lang
def describe(self, image, max_tokens=300): def describe(self, image, max_tokens=300):
if not isinstance(image, bytes) and not isinstance( b64 = self.image2base64(image)
image, BytesIO prompt = self.prompt(b64)
): # if url string
prompt = self.prompt(image)
for i in range(len(prompt)):
prompt[i]["content"]["image_url"]["url"] = image
else:
b64 = self.image2base64(image)
prompt = self.prompt(b64)
for i in range(len(prompt)): for i in range(len(prompt)):
for c in prompt[i]["content"]: for c in prompt[i]["content"]:
if "text" in c: if "text" in c:
@ -680,3 +673,14 @@ class NvidiaCV(Base):
"content": text + f' <img src="data:image/jpeg;base64,{b64}"/>', "content": text + f' <img src="data:image/jpeg;base64,{b64}"/>',
} }
] ]
class LmStudioCV(LocalAICV):
    """Image-to-text (vision) model served by a local LM Studio instance."""

    def __init__(self, key, model_name, base_url, lang="Chinese"):
        """
        :param key: ignored — LM Studio does not check API keys; a placeholder is sent.
        :param model_name: model identifier as shown in LM Studio.
        :param base_url: root URL of the local server; "/v1" is appended if missing.
        :param lang: language hint used by the prompt builder in the base class.
        :raises ValueError: if base_url is empty/None.
        """
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        # Normalize to the OpenAI-compatible /v1 endpoint.  Always assign
        # self.base_url (the original skipped the assignment when the URL
        # already ended in "v1", crashing on the next line), and avoid
        # os.path.join, which would insert "\\" on Windows.
        if base_url.split("/")[-1] != "v1":
            base_url = base_url.rstrip("/") + "/v1"
        self.base_url = base_url
        self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
        self.model_name = model_name
        self.lang = lang

View File

@ -500,3 +500,24 @@ class NvidiaEmbed(Base):
def encode_queries(self, text): def encode_queries(self, text):
embds, cnt = self.encode([text]) embds, cnt = self.encode([text])
return np.array(embds[0]), cnt return np.array(embds[0]), cnt
class LmStudioEmbed(Base):
    """Embedding model served by a local LM Studio instance (OpenAI-compatible API)."""

    def __init__(self, key, model_name, base_url):
        """
        :param key: ignored — LM Studio does not check API keys; a placeholder is sent.
        :param model_name: model identifier as shown in LM Studio.
        :param base_url: root URL of the local server; "/v1" is appended if missing.
        :raises ValueError: if base_url is empty/None.
        """
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        # Normalize to the OpenAI-compatible /v1 endpoint.  Always assign
        # self.base_url (the original only set it inside the if-branch,
        # leaving it undefined when the URL already ended in "v1"), and
        # avoid os.path.join, which would use "\\" on Windows.
        if base_url.split("/")[-1] != "v1":
            base_url = base_url.rstrip("/") + "/v1"
        self.base_url = base_url
        self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32):
        """Embed a batch of texts; returns (embeddings ndarray, token count).

        Token count is a fixed placeholder — the local LM Studio server
        does not report token usage.
        """
        res = self.client.embeddings.create(input=texts, model=self.model_name)
        return (
            np.array([d.embedding for d in res.data]),
            1024,
        )  # local embedding for LmStudio does not count tokens

    def encode_queries(self, text):
        """Embed a single query string; returns (embedding ndarray, token count).

        The OpenAI SDK's ``embeddings.create`` takes keyword-only arguments,
        so ``input=`` must be named (the original passed it positionally,
        which raises TypeError at call time).
        """
        res = self.client.embeddings.create(input=[text], model=self.model_name)
        return np.array(res.data[0].embedding), 1024

View File

@ -202,3 +202,11 @@ class NvidiaRerank(Base):
} }
res = requests.post(self.base_url, headers=self.headers, json=data).json() res = requests.post(self.base_url, headers=self.headers, json=data).json()
return (np.array([d["logit"] for d in res["rankings"]]), token_count) return (np.array([d["logit"] for d in res["rankings"]]), token_count)
class LmStudioRerank(Base):
    """Placeholder rerank entry for LM Studio.

    LM Studio exposes no rerank endpoint, so this class exists only to keep
    the factory table uniform; calling it always fails.
    """

    def __init__(self, key, model_name, base_url):
        # Intentionally a no-op: nothing to configure for an unsupported model.
        pass

    def similarity(self, query: str, texts: list):
        """Always raises — reranking is not supported by LM Studio."""
        # Fixed grammar of the original message ("has not been implement").
        raise NotImplementedError("The LmStudioRerank has not been implemented")

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 730 KiB

View File

@ -17,4 +17,4 @@ export const UserSettingIconMap = {
export * from '@/constants/setting'; export * from '@/constants/setting';
export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI']; export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio'];

View File

@ -20,7 +20,8 @@ export const IconMap = {
OpenRouter: 'open-router', OpenRouter: 'open-router',
LocalAI: 'local-ai', LocalAI: 'local-ai',
StepFun: 'stepfun', StepFun: 'stepfun',
NVIDIA:'nvidia' NVIDIA:'nvidia',
'LM-Studio':'lm-studio'
}; };
export const BedrockRegionList = [ export const BedrockRegionList = [