mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-05-14 07:28:17 +08:00
174 lines
6.2 KiB
Python
174 lines
6.2 KiB
Python
import os
|
|
import re
|
|
from typing import Union
|
|
|
|
import pytest
|
|
from _pytest.monkeypatch import MonkeyPatch
|
|
from requests import Response
|
|
from requests.exceptions import ConnectionError
|
|
from requests.sessions import Session
|
|
from xinference_client.client.restful.restful_client import (
|
|
Client,
|
|
RESTfulChatglmCppChatModelHandle,
|
|
RESTfulChatModelHandle,
|
|
RESTfulEmbeddingModelHandle,
|
|
RESTfulGenerateModelHandle,
|
|
RESTfulRerankModelHandle,
|
|
)
|
|
from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
|
|
|
|
|
|
class MockXinferenceClass:
    """Monkeypatch targets that fake the Xinference RESTful client and its HTTP layer.

    Each method carries an explicit ``self: <PatchedClass>`` annotation because it is
    installed onto that class via ``monkeypatch.setattr`` (see the
    ``setup_xinference_mock`` fixture below), not called on a ``MockXinferenceClass``
    instance.
    """

    def get_chat_model(
        self: Client, model_uid: str
    ) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
        """Replacement for ``Client.get_model``: return a model handle keyed by ``model_uid``."""
        # Simulate the server 404 when the client was built with a malformed base URL.
        if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url):
            raise RuntimeError("404 Not Found")

        if "generate" == model_uid:
            return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
        if "chat" == model_uid:
            return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
        if "embedding" == model_uid:
            return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
        if "rerank" == model_uid:
            return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
        # Any other uid behaves like a missing model on the server.
        raise RuntimeError("404 Not Found")

    def get(self: Session, url: str, **kwargs):
        """Replacement for ``requests.Session.get``: serve canned JSON for the mocked endpoints.

        NOTE(review): URLs matching neither ``v1/models/`` nor ``v1/cluster/auth`` —
        and the ``rerank`` uid inside the models branch — fall through every ``return``
        and yield an implicit ``None``; confirm callers never exercise those paths.
        """
        response = Response()
        if "v1/models/" in url:
            # get model uid (last path segment of the request URL)
            model_uid = url.split("/")[-1] or ""
            # Accept either a UUID-shaped uid or one of the four well-known test uids.
            if not re.match(
                r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid
            ) and model_uid not in {"generate", "chat", "embedding", "rerank"}:
                response.status_code = 404
                response._content = b"{}"
                return response

            # check if url is valid
            if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url):
                response.status_code = 404
                response._content = b"{}"
                return response

            if model_uid in {"generate", "chat"}:
                # Canned LLM model description (mirrors Xinference's model-info payload).
                response.status_code = 200
                response._content = b"""{
    "model_type": "LLM",
    "address": "127.0.0.1:43877",
    "accelerators": [
        "0",
        "1"
    ],
    "model_name": "chatglm3-6b",
    "model_lang": [
        "en"
    ],
    "model_ability": [
        "generate",
        "chat"
    ],
    "model_description": "latest chatglm3",
    "model_format": "pytorch",
    "model_size_in_billions": 7,
    "quantization": "none",
    "model_hub": "huggingface",
    "revision": null,
    "context_length": 2048,
    "replica": 1
}"""
                return response

            elif model_uid == "embedding":
                # Canned embedding model description.
                response.status_code = 200
                response._content = b"""{
    "model_type": "embedding",
    "address": "127.0.0.1:43877",
    "accelerators": [
        "0",
        "1"
    ],
    "model_name": "bge",
    "model_lang": [
        "en"
    ],
    "revision": null,
    "max_tokens": 512
}"""
                return response

        elif "v1/cluster/auth" in url:
            # The auth probe always succeeds under the mock.
            response.status_code = 200
            response._content = b"""{
    "auth": true
}"""
            return response

    def _check_cluster_authenticated(self):
        # Replacement for Client._check_cluster_authenticated: skip the real auth round-trip.
        self._cluster_authed = True

    def rerank(
        self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool
    ) -> dict:
        """Replacement for ``RESTfulRerankModelHandle.rerank``: constant-score pass-through.

        ``query`` and ``return_documents`` are accepted for signature compatibility but
        are not used by the mock.
        """
        # check if self._model_uid is a valid uuid
        if (
            not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
            and self._model_uid != "rerank"
        ):
            raise RuntimeError("404 Not Found")

        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url):
            raise RuntimeError("404 Not Found")

        if top_n is None:
            top_n = 1

        # First top_n documents, original order, identical relevance scores.
        return {
            "results": [
                {"index": i, "document": doc, "relevance_score": 0.9} for i, doc in enumerate(documents[:top_n])
            ]
        }

    def create_embedding(self: RESTfulGenerateModelHandle, input: Union[str, list[str]], **kwargs) -> dict:
        """Replacement for ``create_embedding``: one constant 768-dim vector per input string."""
        # check if self._model_uid is a valid uuid
        if (
            not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
            and self._model_uid != "embedding"
        ):
            raise RuntimeError("404 Not Found")

        # Normalize a single string to a one-element batch.
        if isinstance(input, str):
            input = [input]
        ipt_len = len(input)

        embedding = Embedding(
            object="list",
            model=self._model_uid,
            data=[
                EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)])
                for i in range(ipt_len)
            ],
            usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len),
        )

        return embedding
|
|
|
|
|
|
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
|
|
|
|
|
@pytest.fixture
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
    """Install the Xinference mocks for the duration of a test, then restore the originals.

    Does nothing when MOCK_SWITCH is not enabled, so tests hit the real service.
    """
    if MOCK:
        # (target class, attribute name, mock replacement)
        patch_table = (
            (Client, "get_model", MockXinferenceClass.get_chat_model),
            (Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated),
            (Session, "get", MockXinferenceClass.get),
            (RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding),
            (RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank),
        )
        for target, attr_name, replacement in patch_table:
            monkeypatch.setattr(target, attr_name, replacement)

    yield

    if MOCK:
        monkeypatch.undo()