Feat: add gpustack model provider (#4469)

### What problem does this PR solve?

Add GPUStack as a new model provider.
[GPUStack](https://github.com/gpustack/gpustack) is an open-source GPU
cluster manager for running LLMs. Currently, locally deployed models in
GPUStack cannot integrate well with RAGFlow. GPUStack provides both
OpenAI compatible APIs (Models / Chat Completions / Embeddings /
Speech2Text / TTS) and other APIs like Rerank. We would like to use
GPUStack as a model provider in ragflow.

[GPUStack Docs](https://docs.gpustack.ai/latest/quickstart/)

Related issue: https://github.com/infiniflow/ragflow/issues/4064.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)



### Testing Instructions
1. Install GPUStack and deploy the `llama-3.2-1b-instruct` llm, `bge-m3`
text embedding model, `bge-reranker-v2-m3` rerank model,
`faster-whisper-medium` Speech-to-Text model, `cosyvoice-300m-sft` in
GPUStack.
2. Add provider in ragflow settings.
3. Testing in ragflow.
This commit is contained in:
Alex Chen 2025-01-15 14:15:58 +08:00 committed by GitHub
parent e478586a8e
commit 7944aacafa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 159 additions and 3 deletions

View File

@ -329,7 +329,7 @@ def my_llms():
@manager.route('/list', methods=['GET']) # noqa: F821 @manager.route('/list', methods=['GET']) # noqa: F821
@login_required @login_required
def list_app(): def list_app():
self_deploied = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"] self_deployed = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
weighted = ["Youdao", "FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else [] weighted = ["Youdao", "FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
model_type = request.args.get("model_type") model_type = request.args.get("model_type")
try: try:
@ -339,7 +339,7 @@ def list_app():
llms = [m.to_dict() llms = [m.to_dict()
for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted] for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted]
for m in llms: for m in llms:
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deploied m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deployed
llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms]) llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms])
for o in objs: for o in objs:

View File

@ -2543,6 +2543,13 @@
"tags": "TEXT EMBEDDING", "tags": "TEXT EMBEDDING",
"status": "1", "status": "1",
"llm": [] "llm": []
},
{
"name": "GPUStack",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,TTS,SPEECH2TEXT,TEXT RE-RANK",
"status": "1",
"llm": []
} }
] ]
} }

View File

@ -42,6 +42,7 @@ from .embedding_model import (
VoyageEmbed, VoyageEmbed,
HuggingFaceEmbed, HuggingFaceEmbed,
VolcEngineEmbed, VolcEngineEmbed,
GPUStackEmbed,
) )
from .chat_model import ( from .chat_model import (
GptTurbo, GptTurbo,
@ -80,6 +81,7 @@ from .chat_model import (
AnthropicChat, AnthropicChat,
GoogleChat, GoogleChat,
HuggingFaceChat, HuggingFaceChat,
GPUStackChat,
) )
from .cv_model import ( from .cv_model import (
@ -116,6 +118,7 @@ from .rerank_model import (
BaiduYiyanRerank, BaiduYiyanRerank,
VoyageRerank, VoyageRerank,
QWenRerank, QWenRerank,
GPUStackRerank,
) )
from .sequence2txt_model import ( from .sequence2txt_model import (
GPTSeq2txt, GPTSeq2txt,
@ -123,6 +126,7 @@ from .sequence2txt_model import (
AzureSeq2txt, AzureSeq2txt,
XinferenceSeq2txt, XinferenceSeq2txt,
TencentCloudSeq2txt, TencentCloudSeq2txt,
GPUStackSeq2txt,
) )
from .tts_model import ( from .tts_model import (
FishAudioTTS, FishAudioTTS,
@ -130,6 +134,7 @@ from .tts_model import (
OpenAITTS, OpenAITTS,
SparkTTS, SparkTTS,
XinferenceTTS, XinferenceTTS,
GPUStackTTS,
) )
EmbeddingModel = { EmbeddingModel = {
@ -161,6 +166,7 @@ EmbeddingModel = {
"Voyage AI": VoyageEmbed, "Voyage AI": VoyageEmbed,
"HuggingFace": HuggingFaceEmbed, "HuggingFace": HuggingFaceEmbed,
"VolcEngine": VolcEngineEmbed, "VolcEngine": VolcEngineEmbed,
"GPUStack": GPUStackEmbed,
} }
CvModel = { CvModel = {
@ -220,6 +226,7 @@ ChatModel = {
"Anthropic": AnthropicChat, "Anthropic": AnthropicChat,
"Google Cloud": GoogleChat, "Google Cloud": GoogleChat,
"HuggingFace": HuggingFaceChat, "HuggingFace": HuggingFaceChat,
"GPUStack": GPUStackChat,
} }
RerankModel = { RerankModel = {
@ -237,6 +244,7 @@ RerankModel = {
"BaiduYiyan": BaiduYiyanRerank, "BaiduYiyan": BaiduYiyanRerank,
"Voyage AI": VoyageRerank, "Voyage AI": VoyageRerank,
"Tongyi-Qianwen": QWenRerank, "Tongyi-Qianwen": QWenRerank,
"GPUStack": GPUStackRerank,
} }
Seq2txtModel = { Seq2txtModel = {
@ -245,6 +253,7 @@ Seq2txtModel = {
"Azure-OpenAI": AzureSeq2txt, "Azure-OpenAI": AzureSeq2txt,
"Xinference": XinferenceSeq2txt, "Xinference": XinferenceSeq2txt,
"Tencent Cloud": TencentCloudSeq2txt, "Tencent Cloud": TencentCloudSeq2txt,
"GPUStack": GPUStackSeq2txt,
} }
TTSModel = { TTSModel = {
@ -253,4 +262,5 @@ TTSModel = {
"OpenAI": OpenAITTS, "OpenAI": OpenAITTS,
"XunFei Spark": SparkTTS, "XunFei Spark": SparkTTS,
"Xinference": XinferenceTTS, "Xinference": XinferenceTTS,
"GPUStack": GPUStackTTS,
} }

View File

@ -1514,3 +1514,11 @@ class GoogleChat(Base):
yield ans + "\n**ERROR**: " + str(e) yield ans + "\n**ERROR**: " + str(e)
yield response._chunks[-1].usage_metadata.total_token_count yield response._chunks[-1].usage_metadata.total_token_count
class GPUStackChat(Base):
def __init__(self, key=None, model_name="", base_url=""):
if not base_url:
raise ValueError("Local llm url cannot be None")
if base_url.split("/")[-1] != "v1-openai":
base_url = os.path.join(base_url, "v1-openai")
super().__init__(key, model_name, base_url)

View File

@ -30,7 +30,7 @@ import asyncio
from api import settings from api import settings
from api.utils.file_utils import get_home_cache_dir from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate from rag.utils import num_tokens_from_string, truncate
import google.generativeai as genai import google.generativeai as genai
import json import json
@ -799,3 +799,14 @@ class VolcEngineEmbed(OpenAIEmbed):
ark_api_key = json.loads(key).get('ark_api_key', '') ark_api_key = json.loads(key).get('ark_api_key', '')
model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '') model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '')
super().__init__(ark_api_key,model_name,base_url) super().__init__(ark_api_key,model_name,base_url)
class GPUStackEmbed(OpenAIEmbed):
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
if base_url.split("/")[-1] != "v1-openai":
base_url = os.path.join(base_url, "v1-openai")
print(key,base_url)
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name

View File

@ -18,10 +18,12 @@ import threading
from urllib.parse import urljoin from urllib.parse import urljoin
import requests import requests
import httpx
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
import os import os
from abc import ABC from abc import ABC
import numpy as np import numpy as np
from yarl import URL
from api import settings from api import settings
from api.utils.file_utils import get_home_cache_dir from api.utils.file_utils import get_home_cache_dir
@ -457,3 +459,53 @@ class QWenRerank(Base):
return rank, resp.usage.total_tokens return rank, resp.usage.total_tokens
else: else:
raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}") raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")
class GPUStackRerank(Base):
def __init__(
self, key, model_name, base_url
):
if not base_url:
raise ValueError("url cannot be None")
self.model_name = model_name
self.base_url = str(URL(base_url)/ "v1" / "rerank")
self.headers = {
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {key}",
}
def similarity(self, query: str, texts: list):
payload = {
"model": self.model_name,
"query": query,
"documents": texts,
"top_n": len(texts),
}
try:
response = requests.post(
self.base_url, json=payload, headers=self.headers
)
response.raise_for_status()
response_json = response.json()
rank = np.zeros(len(texts), dtype=float)
if "results" not in response_json:
return rank, 0
token_count = 0
for t in texts:
token_count += num_tokens_from_string(t)
for result in response_json["results"]:
rank[result["index"]] = result["relevance_score"]
return (
rank,
token_count,
)
except httpx.HTTPStatusError as e:
raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import os
import requests import requests
from openai.lib.azure import AzureOpenAI from openai.lib.azure import AzureOpenAI
import io import io
@ -191,3 +192,14 @@ class TencentCloudSeq2txt(Base):
return "**ERROR**: " + str(e), 0 return "**ERROR**: " + str(e), 0
except Exception as e: except Exception as e:
return "**ERROR**: " + str(e), 0 return "**ERROR**: " + str(e), 0
class GPUStackSeq2txt(Base):
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
if base_url.split("/")[-1] != "v1-openai":
base_url = os.path.join(base_url, "v1-openai")
self.base_url = base_url
self.model_name = model_name
self.key = key

View File

@ -355,3 +355,35 @@ class OllamaTTS(Base):
for chunk in response.iter_content(): for chunk in response.iter_content():
if chunk: if chunk:
yield chunk yield chunk
class GPUStackTTS:
def __init__(self, key, model_name, **kwargs):
self.base_url = kwargs.get("base_url", None)
self.api_key = key
self.model_name = model_name
self.headers = {
"accept": "application/json",
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
def tts(self, text, voice="Chinese Female", stream=True):
payload = {
"model": self.model_name,
"input": text,
"voice": voice
}
response = requests.post(
f"{self.base_url}/v1-openai/audio/speech",
headers=self.headers,
json=payload,
stream=stream
)
if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
for chunk in response.iter_content(chunk_size=1024):
if chunk:
yield chunk

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 20 KiB

View File

@ -72,6 +72,7 @@ export const IconMap = {
'nomic-ai': 'nomic-ai', 'nomic-ai': 'nomic-ai',
jinaai: 'jina', jinaai: 'jina',
'sentence-transformers': 'sentence-transformers', 'sentence-transformers': 'sentence-transformers',
GPUStack: 'gpustack',
}; };
export const TimezoneList = [ export const TimezoneList = [

View File

@ -31,6 +31,7 @@ export const LocalLlmFactories = [
'Replicate', 'Replicate',
'OpenRouter', 'OpenRouter',
'HuggingFace', 'HuggingFace',
'GPUStack',
]; ];
export enum TenantRole { export enum TenantRole {

View File

@ -29,6 +29,7 @@ const llmFactoryToUrlMap = {
OpenRouter: 'https://openrouter.ai/docs', OpenRouter: 'https://openrouter.ai/docs',
HuggingFace: HuggingFace:
'https://huggingface.co/docs/text-embeddings-inference/quick_tour', 'https://huggingface.co/docs/text-embeddings-inference/quick_tour',
GPUStack: 'https://docs.gpustack.ai/latest/quickstart',
}; };
type LlmFactory = keyof typeof llmFactoryToUrlMap; type LlmFactory = keyof typeof llmFactoryToUrlMap;
@ -76,6 +77,13 @@ const OllamaModal = ({
{ value: 'speech2text', label: 'sequence2text' }, { value: 'speech2text', label: 'sequence2text' },
{ value: 'tts', label: 'tts' }, { value: 'tts', label: 'tts' },
], ],
GPUStack: [
{ value: 'chat', label: 'chat' },
{ value: 'embedding', label: 'embedding' },
{ value: 'rerank', label: 'rerank' },
{ value: 'speech2text', label: 'sequence2text' },
{ value: 'tts', label: 'tts' },
],
Default: [ Default: [
{ value: 'chat', label: 'chat' }, { value: 'chat', label: 'chat' },
{ value: 'embedding', label: 'embedding' }, { value: 'embedding', label: 'embedding' },