mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-14 14:35:52 +08:00
add support for LM Studio (#1663)
### What problem does this PR solve? #1602 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
This commit is contained in:
parent
100b3165d8
commit
d96348eb22
@ -21,7 +21,7 @@ from api.db import StatusEnum, LLMType
|
|||||||
from api.db.db_models import TenantLLM
|
from api.db.db_models import TenantLLM
|
||||||
from api.utils.api_utils import get_json_result
|
from api.utils.api_utils import get_json_result
|
||||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel,CvModel
|
from rag.llm import EmbeddingModel, ChatModel, RerankModel,CvModel
|
||||||
|
import requests
|
||||||
|
|
||||||
@manager.route('/factories', methods=['GET'])
|
@manager.route('/factories', methods=['GET'])
|
||||||
@login_required
|
@login_required
|
||||||
@ -189,9 +189,13 @@ def add_llm():
|
|||||||
"ons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/256"
|
"ons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/256"
|
||||||
"0px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
"0px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||||
)
|
)
|
||||||
m, tc = mdl.describe(img_url)
|
res = requests.get(img_url)
|
||||||
if not tc:
|
if res.status_code == 200:
|
||||||
raise Exception(m)
|
m, tc = mdl.describe(res.content)
|
||||||
|
if not tc:
|
||||||
|
raise Exception(m)
|
||||||
|
else:
|
||||||
|
raise ConnectionError("fail to download the test picture")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg += f"\nFail to access model({llm['llm_name']})." + str(e)
|
msg += f"\nFail to access model({llm['llm_name']})." + str(e)
|
||||||
else:
|
else:
|
||||||
|
@ -2208,6 +2208,13 @@
|
|||||||
"model_type": "image2text"
|
"model_type": "image2text"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "LM-Studio",
|
||||||
|
"logo": "",
|
||||||
|
"tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT",
|
||||||
|
"status": "1",
|
||||||
|
"llm": []
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
@ -34,8 +34,9 @@ EmbeddingModel = {
|
|||||||
"BAAI": DefaultEmbedding,
|
"BAAI": DefaultEmbedding,
|
||||||
"Mistral": MistralEmbed,
|
"Mistral": MistralEmbed,
|
||||||
"Bedrock": BedrockEmbed,
|
"Bedrock": BedrockEmbed,
|
||||||
"Gemini":GeminiEmbed,
|
"Gemini": GeminiEmbed,
|
||||||
"NVIDIA":NvidiaEmbed
|
"NVIDIA": NvidiaEmbed,
|
||||||
|
"LM-Studio": LmStudioEmbed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -47,10 +48,11 @@ CvModel = {
|
|||||||
"Tongyi-Qianwen": QWenCV,
|
"Tongyi-Qianwen": QWenCV,
|
||||||
"ZHIPU-AI": Zhipu4V,
|
"ZHIPU-AI": Zhipu4V,
|
||||||
"Moonshot": LocalCV,
|
"Moonshot": LocalCV,
|
||||||
'Gemini':GeminiCV,
|
"Gemini": GeminiCV,
|
||||||
'OpenRouter':OpenRouterCV,
|
"OpenRouter": OpenRouterCV,
|
||||||
"LocalAI":LocalAICV,
|
"LocalAI": LocalAICV,
|
||||||
"NVIDIA":NvidiaCV
|
"NVIDIA": NvidiaCV,
|
||||||
|
"LM-Studio": LmStudioCV
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -69,12 +71,13 @@ ChatModel = {
|
|||||||
"MiniMax": MiniMaxChat,
|
"MiniMax": MiniMaxChat,
|
||||||
"Minimax": MiniMaxChat,
|
"Minimax": MiniMaxChat,
|
||||||
"Mistral": MistralChat,
|
"Mistral": MistralChat,
|
||||||
'Gemini' : GeminiChat,
|
"Gemini": GeminiChat,
|
||||||
"Bedrock": BedrockChat,
|
"Bedrock": BedrockChat,
|
||||||
"Groq": GroqChat,
|
"Groq": GroqChat,
|
||||||
'OpenRouter':OpenRouterChat,
|
"OpenRouter": OpenRouterChat,
|
||||||
"StepFun":StepFunChat,
|
"StepFun": StepFunChat,
|
||||||
"NVIDIA":NvidiaChat
|
"NVIDIA": NvidiaChat,
|
||||||
|
"LM-Studio": LmStudioChat
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -83,7 +86,8 @@ RerankModel = {
|
|||||||
"Jina": JinaRerank,
|
"Jina": JinaRerank,
|
||||||
"Youdao": YoudaoRerank,
|
"Youdao": YoudaoRerank,
|
||||||
"Xinference": XInferenceRerank,
|
"Xinference": XInferenceRerank,
|
||||||
"NVIDIA":NvidiaRerank
|
"NVIDIA": NvidiaRerank,
|
||||||
|
"LM-Studio": LmStudioRerank
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -976,3 +976,15 @@ class NvidiaChat(Base):
|
|||||||
yield ans + "\n**ERROR**: " + str(e)
|
yield ans + "\n**ERROR**: " + str(e)
|
||||||
|
|
||||||
yield total_tokens
|
yield total_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class LmStudioChat(Base):
|
||||||
|
def __init__(self, key, model_name, base_url):
|
||||||
|
from os.path import join
|
||||||
|
|
||||||
|
if not base_url:
|
||||||
|
raise ValueError("Local llm url cannot be None")
|
||||||
|
if base_url.split("/")[-1] != "v1":
|
||||||
|
self.base_url = join(base_url, "v1")
|
||||||
|
self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
|
||||||
|
self.model_name = model_name
|
||||||
|
@ -440,15 +440,8 @@ class LocalAICV(Base):
|
|||||||
self.lang = lang
|
self.lang = lang
|
||||||
|
|
||||||
def describe(self, image, max_tokens=300):
|
def describe(self, image, max_tokens=300):
|
||||||
if not isinstance(image, bytes) and not isinstance(
|
b64 = self.image2base64(image)
|
||||||
image, BytesIO
|
prompt = self.prompt(b64)
|
||||||
): # if url string
|
|
||||||
prompt = self.prompt(image)
|
|
||||||
for i in range(len(prompt)):
|
|
||||||
prompt[i]["content"]["image_url"]["url"] = image
|
|
||||||
else:
|
|
||||||
b64 = self.image2base64(image)
|
|
||||||
prompt = self.prompt(b64)
|
|
||||||
for i in range(len(prompt)):
|
for i in range(len(prompt)):
|
||||||
for c in prompt[i]["content"]:
|
for c in prompt[i]["content"]:
|
||||||
if "text" in c:
|
if "text" in c:
|
||||||
@ -680,3 +673,14 @@ class NvidiaCV(Base):
|
|||||||
"content": text + f' <img src="data:image/jpeg;base64,{b64}"/>',
|
"content": text + f' <img src="data:image/jpeg;base64,{b64}"/>',
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class LmStudioCV(LocalAICV):
|
||||||
|
def __init__(self, key, model_name, base_url, lang="Chinese"):
|
||||||
|
if not base_url:
|
||||||
|
raise ValueError("Local llm url cannot be None")
|
||||||
|
if base_url.split('/')[-1] != 'v1':
|
||||||
|
self.base_url = os.path.join(base_url,'v1')
|
||||||
|
self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
|
||||||
|
self.model_name = model_name
|
||||||
|
self.lang = lang
|
||||||
|
@ -500,3 +500,24 @@ class NvidiaEmbed(Base):
|
|||||||
def encode_queries(self, text):
|
def encode_queries(self, text):
|
||||||
embds, cnt = self.encode([text])
|
embds, cnt = self.encode([text])
|
||||||
return np.array(embds[0]), cnt
|
return np.array(embds[0]), cnt
|
||||||
|
|
||||||
|
|
||||||
|
class LmStudioEmbed(Base):
|
||||||
|
def __init__(self, key, model_name, base_url):
|
||||||
|
if not base_url:
|
||||||
|
raise ValueError("Local llm url cannot be None")
|
||||||
|
if base_url.split("/")[-1] != "v1":
|
||||||
|
self.base_url = os.path.join(base_url, "v1")
|
||||||
|
self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
|
||||||
|
self.model_name = model_name
|
||||||
|
|
||||||
|
def encode(self, texts: list, batch_size=32):
|
||||||
|
res = self.client.embeddings.create(input=texts, model=self.model_name)
|
||||||
|
return (
|
||||||
|
np.array([d.embedding for d in res.data]),
|
||||||
|
1024,
|
||||||
|
) # local embedding for LmStudio donot count tokens
|
||||||
|
|
||||||
|
def encode_queries(self, text):
|
||||||
|
res = self.client.embeddings.create(text, model=self.model_name)
|
||||||
|
return np.array(res.data[0].embedding), 1024
|
||||||
|
@ -202,3 +202,11 @@ class NvidiaRerank(Base):
|
|||||||
}
|
}
|
||||||
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
||||||
return (np.array([d["logit"] for d in res["rankings"]]), token_count)
|
return (np.array([d["logit"] for d in res["rankings"]]), token_count)
|
||||||
|
|
||||||
|
|
||||||
|
class LmStudioRerank(Base):
|
||||||
|
def __init__(self, key, model_name, base_url):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def similarity(self, query: str, texts: list):
|
||||||
|
raise NotImplementedError("The LmStudioRerank has not been implement")
|
||||||
|
9704
web/src/assets/svg/llm/lm-studio.svg
Normal file
9704
web/src/assets/svg/llm/lm-studio.svg
Normal file
File diff suppressed because it is too large
Load Diff
After Width: | Height: | Size: 730 KiB |
@ -17,4 +17,4 @@ export const UserSettingIconMap = {
|
|||||||
|
|
||||||
export * from '@/constants/setting';
|
export * from '@/constants/setting';
|
||||||
|
|
||||||
export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI'];
|
export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio'];
|
||||||
|
@ -20,7 +20,8 @@ export const IconMap = {
|
|||||||
OpenRouter: 'open-router',
|
OpenRouter: 'open-router',
|
||||||
LocalAI: 'local-ai',
|
LocalAI: 'local-ai',
|
||||||
StepFun: 'stepfun',
|
StepFun: 'stepfun',
|
||||||
NVIDIA:'nvidia'
|
NVIDIA:'nvidia',
|
||||||
|
'LM-Studio':'lm-studio'
|
||||||
};
|
};
|
||||||
|
|
||||||
export const BedrockRegionList = [
|
export const BedrockRegionList = [
|
||||||
|
Loading…
x
Reference in New Issue
Block a user