Feat: add GPUStack model provider (#4469)
### What problem does this PR solve?

Add GPUStack as a new model provider. [GPUStack](https://github.com/gpustack/gpustack) is an open-source GPU cluster manager for running LLMs. Currently, models deployed locally in GPUStack cannot integrate well with RAGFlow. GPUStack exposes OpenAI-compatible APIs (Models / Chat Completions / Embeddings / Speech2Text / TTS) as well as additional APIs such as Rerank, so it can serve as a model provider in RAGFlow. [GPUStack Docs](https://docs.gpustack.ai/latest/quickstart/)

Related issue: https://github.com/infiniflow/ragflow/issues/4064

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

### Testing Instructions

1. Install GPUStack and deploy the `llama-3.2-1b-instruct` LLM, the `bge-m3` text embedding model, the `bge-reranker-v2-m3` rerank model, the `faster-whisper-medium` speech-to-text model, and the `cosyvoice-300m-sft` text-to-speech model in GPUStack.
2. Add GPUStack as a provider in the RAGFlow settings.
3. Test the models in RAGFlow.
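For reference, once the models are deployed, GPUStack's OpenAI-compatible endpoints (served under the `/v1-openai` prefix that this PR's URL normalization assumes) can be smoke-tested with the standard `openai` Python client. A minimal sketch — the host, port, and API key below are placeholders, not values from this PR:

```python
from openai import OpenAI

# Placeholder host, port, and API key for a local GPUStack deployment.
client = OpenAI(
    api_key="your-gpustack-api-key",
    base_url="http://localhost:80/v1-openai",  # GPUStack's OpenAI-compatible prefix
)

resp = client.chat.completions.create(
    model="llama-3.2-1b-instruct",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(resp.choices[0].message.content)
```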
This commit is contained in: parent e478586a8e, commit 7944aacafa
@@ -329,7 +329,7 @@ def my_llms():
 @manager.route('/list', methods=['GET'])  # noqa: F821
 @login_required
 def list_app():
-    self_deploied = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
+    self_deployed = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
     weighted = ["Youdao", "FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
     model_type = request.args.get("model_type")
     try:
@@ -339,7 +339,7 @@ def list_app():
         llms = [m.to_dict()
                 for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted]
         for m in llms:
-            m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deploied
+            m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deployed

         llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms])
         for o in objs:
@@ -2543,6 +2543,13 @@
             "tags": "TEXT EMBEDDING",
             "status": "1",
             "llm": []
-        }
+        },
+        {
+            "name": "GPUStack",
+            "logo": "",
+            "tags": "LLM,TEXT EMBEDDING,TTS,SPEECH2TEXT,TEXT RE-RANK",
+            "status": "1",
+            "llm": []
+        }
     ]
 }
@@ -42,6 +42,7 @@ from .embedding_model import (
     VoyageEmbed,
     HuggingFaceEmbed,
     VolcEngineEmbed,
+    GPUStackEmbed,
 )
 from .chat_model import (
     GptTurbo,
@@ -80,6 +81,7 @@ from .chat_model import (
     AnthropicChat,
     GoogleChat,
     HuggingFaceChat,
+    GPUStackChat,
 )

 from .cv_model import (
@@ -116,6 +118,7 @@ from .rerank_model import (
     BaiduYiyanRerank,
     VoyageRerank,
     QWenRerank,
+    GPUStackRerank,
 )
 from .sequence2txt_model import (
     GPTSeq2txt,
@@ -123,6 +126,7 @@ from .sequence2txt_model import (
     AzureSeq2txt,
     XinferenceSeq2txt,
     TencentCloudSeq2txt,
+    GPUStackSeq2txt,
 )
 from .tts_model import (
     FishAudioTTS,
@@ -130,6 +134,7 @@ from .tts_model import (
     OpenAITTS,
     SparkTTS,
     XinferenceTTS,
+    GPUStackTTS,
 )

 EmbeddingModel = {
@@ -161,6 +166,7 @@ EmbeddingModel = {
     "Voyage AI": VoyageEmbed,
     "HuggingFace": HuggingFaceEmbed,
     "VolcEngine": VolcEngineEmbed,
+    "GPUStack": GPUStackEmbed,
 }

 CvModel = {
@@ -220,6 +226,7 @@ ChatModel = {
     "Anthropic": AnthropicChat,
     "Google Cloud": GoogleChat,
     "HuggingFace": HuggingFaceChat,
+    "GPUStack": GPUStackChat,
 }

 RerankModel = {
@@ -237,6 +244,7 @@ RerankModel = {
     "BaiduYiyan": BaiduYiyanRerank,
     "Voyage AI": VoyageRerank,
     "Tongyi-Qianwen": QWenRerank,
+    "GPUStack": GPUStackRerank,
 }

 Seq2txtModel = {
@@ -245,6 +253,7 @@ Seq2txtModel = {
     "Azure-OpenAI": AzureSeq2txt,
     "Xinference": XinferenceSeq2txt,
     "Tencent Cloud": TencentCloudSeq2txt,
+    "GPUStack": GPUStackSeq2txt,
 }

 TTSModel = {
@@ -253,4 +262,5 @@ TTSModel = {
     "OpenAI": OpenAITTS,
     "XunFei Spark": SparkTTS,
     "Xinference": XinferenceTTS,
+    "GPUStack": GPUStackTTS,
 }
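For context, these registry dicts are what maps a factory name to a concrete model class. A hedged sketch of the resulting dispatch (the `rag.llm` module path and the constructor arguments are inferred from this diff, not quoted from the repo):

```python
# Hedged sketch: resolve the factory name registered above to a concrete
# class. The "rag.llm" module path and constructor shape are assumptions
# inferred from this diff, not verbatim RAGFlow code.
from rag.llm import ChatModel

chat_cls = ChatModel["GPUStack"]        # -> GPUStackChat
chat_mdl = chat_cls(
    key="your-api-key",
    model_name="llama-3.2-1b-instruct",
    base_url="http://localhost:80",     # normalized to .../v1-openai internally
)
```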
@@ -1514,3 +1514,11 @@ class GoogleChat(Base):
             yield ans + "\n**ERROR**: " + str(e)

         yield response._chunks[-1].usage_metadata.total_token_count
+
+class GPUStackChat(Base):
+    def __init__(self, key=None, model_name="", base_url=""):
+        if not base_url:
+            raise ValueError("Local llm url cannot be None")
+        if base_url.split("/")[-1] != "v1-openai":
+            base_url = os.path.join(base_url, "v1-openai")
+        super().__init__(key, model_name, base_url)
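One detail worth noting in the constructor above: `os.path.join` is path-oriented rather than URL-oriented, so it would join with backslashes on Windows. A hedged, platform-independent sketch of the same `v1-openai` normalization (the helper name is invented here):

```python
# Hypothetical helper (the name is invented): append the "v1-openai"
# suffix with URL-safe string handling instead of os.path.join, which
# is path-oriented and would insert backslashes on Windows.
def ensure_v1_openai(base_url: str) -> str:
    base_url = base_url.rstrip("/")
    if not base_url.endswith("/v1-openai"):
        base_url = f"{base_url}/v1-openai"
    return base_url

assert ensure_v1_openai("http://localhost:80") == "http://localhost:80/v1-openai"
assert ensure_v1_openai("http://localhost:80/v1-openai/") == "http://localhost:80/v1-openai"
```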
@@ -799,3 +799,14 @@ class VolcEngineEmbed(OpenAIEmbed):
         ark_api_key = json.loads(key).get('ark_api_key', '')
         model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '')
         super().__init__(ark_api_key,model_name,base_url)
+
+class GPUStackEmbed(OpenAIEmbed):
+    def __init__(self, key, model_name, base_url):
+        if not base_url:
+            raise ValueError("url cannot be None")
+        if base_url.split("/")[-1] != "v1-openai":
+            base_url = os.path.join(base_url, "v1-openai")
+
+        print(key,base_url)
+        self.client = OpenAI(api_key=key, base_url=base_url)
+        self.model_name = model_name
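`GPUStackEmbed` inherits its request logic from `OpenAIEmbed`, so only construction differs. A hedged usage sketch (the module path, the `encode` signature, and its `(vectors, token_count)` return shape follow the base class by assumption, since this diff does not show them):

```python
# Hedged sketch; the module path is an assumption based on this diff.
from rag.llm.embedding_model import GPUStackEmbed

embed_mdl = GPUStackEmbed(
    key="your-api-key",
    model_name="bge-m3",
    base_url="http://localhost:80",  # normalized to .../v1-openai in __init__
)
# encode()'s exact signature and return follow the OpenAIEmbed base class,
# which this diff does not show; (vectors, token_count) is assumed.
vectors, token_count = embed_mdl.encode(["hello world", "goodbye world"])
```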
@@ -18,10 +18,12 @@ import threading
 from urllib.parse import urljoin

 import requests
+import httpx
 from huggingface_hub import snapshot_download
 import os
 from abc import ABC
 import numpy as np
+from yarl import URL

 from api import settings
 from api.utils.file_utils import get_home_cache_dir
@@ -457,3 +459,53 @@ class QWenRerank(Base):
             return rank, resp.usage.total_tokens
         else:
             raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")
+
+class GPUStackRerank(Base):
+    def __init__(
+            self, key, model_name, base_url
+    ):
+        if not base_url:
+            raise ValueError("url cannot be None")
+
+        self.model_name = model_name
+        self.base_url = str(URL(base_url) / "v1" / "rerank")
+        self.headers = {
+            "accept": "application/json",
+            "content-type": "application/json",
+            "authorization": f"Bearer {key}",
+        }
+
+    def similarity(self, query: str, texts: list):
+        payload = {
+            "model": self.model_name,
+            "query": query,
+            "documents": texts,
+            "top_n": len(texts),
+        }
+
+        try:
+            response = requests.post(
+                self.base_url, json=payload, headers=self.headers
+            )
+            response.raise_for_status()
+            response_json = response.json()
+
+            rank = np.zeros(len(texts), dtype=float)
+            if "results" not in response_json:
+                return rank, 0
+
+            token_count = 0
+            for t in texts:
+                token_count += num_tokens_from_string(t)
+
+            for result in response_json["results"]:
+                rank[result["index"]] = result["relevance_score"]
+
+            return (
+                rank,
+                token_count,
+            )
+
+        except httpx.HTTPStatusError as e:
+            raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
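Note that `similarity` sends with `requests` but catches `httpx.HTTPStatusError`; `response.raise_for_status()` actually raises `requests.exceptions.HTTPError`, so a caller may want to handle that too. A hedged usage sketch (the module path is an assumption based on this diff):

```python
import requests

# Hedged sketch; the module path is an assumption based on this diff.
from rag.llm.rerank_model import GPUStackRerank

rerank_mdl = GPUStackRerank(
    key="your-api-key",
    model_name="bge-reranker-v2-m3",
    base_url="http://localhost:80",  # similarity() posts to .../v1/rerank
)

try:
    scores, token_count = rerank_mdl.similarity(
        query="What is a GPU cluster manager?",
        texts=["GPUStack manages GPU clusters.", "Bananas are yellow."],
    )
except requests.exceptions.HTTPError as e:  # raised by raise_for_status()
    print("rerank failed:", e)
```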
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 import requests
 from openai.lib.azure import AzureOpenAI
 import io
@@ -191,3 +192,14 @@ class TencentCloudSeq2txt(Base):
             return "**ERROR**: " + str(e), 0
         except Exception as e:
             return "**ERROR**: " + str(e), 0
+
+
+class GPUStackSeq2txt(Base):
+    def __init__(self, key, model_name, base_url):
+        if not base_url:
+            raise ValueError("url cannot be None")
+        if base_url.split("/")[-1] != "v1-openai":
+            base_url = os.path.join(base_url, "v1-openai")
+        self.base_url = base_url
+        self.model_name = model_name
+        self.key = key
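`GPUStackSeq2txt.__init__` only stores connection details; the transcription call itself is not part of this hunk. For reference, the OpenAI-compatible endpoint can be exercised directly. A hedged sketch — the `/v1-openai/audio/transcriptions` path is inferred from the `v1-openai` prefix used throughout this PR and the shape of OpenAI's audio API, which GPUStack mirrors:

```python
import requests

# Hypothetical direct call; the path and multipart payload follow OpenAI's
# audio transcription API, which GPUStack mirrors under /v1-openai.
with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:80/v1-openai/audio/transcriptions",
        headers={"Authorization": "Bearer your-api-key"},
        files={"file": ("sample.wav", f, "audio/wav")},
        data={"model": "faster-whisper-medium"},
    )
resp.raise_for_status()
print(resp.json()["text"])
```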
@@ -355,3 +355,35 @@ class OllamaTTS(Base):
         for chunk in response.iter_content():
             if chunk:
                 yield chunk
+
+class GPUStackTTS:
+    def __init__(self, key, model_name, **kwargs):
+        self.base_url = kwargs.get("base_url", None)
+        self.api_key = key
+        self.model_name = model_name
+        self.headers = {
+            "accept": "application/json",
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {self.api_key}"
+        }
+
+    def tts(self, text, voice="Chinese Female", stream=True):
+        payload = {
+            "model": self.model_name,
+            "input": text,
+            "voice": voice
+        }
+
+        response = requests.post(
+            f"{self.base_url}/v1-openai/audio/speech",
+            headers=self.headers,
+            json=payload,
+            stream=stream
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"**Error**: {response.status_code}, {response.text}")
+
+        for chunk in response.iter_content(chunk_size=1024):
+            if chunk:
+                yield chunk
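A hedged usage sketch of the streaming generator above — the module path, host, and API key are placeholders; the model and default voice follow this diff:

```python
# Hedged sketch; the module path is an assumption based on this diff.
from rag.llm.tts_model import GPUStackTTS

tts_mdl = GPUStackTTS(
    key="your-api-key",
    model_name="cosyvoice-300m-sft",
    base_url="http://localhost:80",  # placeholder GPUStack server address
)

# tts() yields raw audio chunks; write them to a file as they stream in.
with open("hello.wav", "wb") as out:
    for chunk in tts_mdl.tts("Hello from GPUStack!", voice="Chinese Female"):
        out.write(chunk)
```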
web/src/assets/svg/llm/gpustack.svg (new file, 14 lines, 20 KiB; diff suppressed because one or more lines are too long)
@@ -72,6 +72,7 @@ export const IconMap = {
   'nomic-ai': 'nomic-ai',
   jinaai: 'jina',
   'sentence-transformers': 'sentence-transformers',
+  GPUStack: 'gpustack',
 };

 export const TimezoneList = [
@@ -31,6 +31,7 @@ export const LocalLlmFactories = [
   'Replicate',
   'OpenRouter',
   'HuggingFace',
+  'GPUStack',
 ];

 export enum TenantRole {
@@ -29,6 +29,7 @@ const llmFactoryToUrlMap = {
   OpenRouter: 'https://openrouter.ai/docs',
   HuggingFace:
     'https://huggingface.co/docs/text-embeddings-inference/quick_tour',
+  GPUStack: 'https://docs.gpustack.ai/latest/quickstart',
 };
 type LlmFactory = keyof typeof llmFactoryToUrlMap;

@@ -76,6 +77,13 @@ const OllamaModal = ({
     { value: 'speech2text', label: 'sequence2text' },
     { value: 'tts', label: 'tts' },
   ],
+  GPUStack: [
+    { value: 'chat', label: 'chat' },
+    { value: 'embedding', label: 'embedding' },
+    { value: 'rerank', label: 'rerank' },
+    { value: 'speech2text', label: 'sequence2text' },
+    { value: 'tts', label: 'tts' },
+  ],
   Default: [
     { value: 'chat', label: 'chat' },
     { value: 'embedding', label: 'embedding' },