mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 06:05:51 +08:00
feat: Add support for TEI API key authentication (#11006)
Signed-off-by: kenwoodjw <blackxin55+@gmail.com> Co-authored-by: crazywoola <427733928@qq.com>
This commit is contained in:
parent
16c41585e1
commit
096c0ad564
@ -34,3 +34,11 @@ model_credential_schema:
|
|||||||
placeholder:
|
placeholder:
|
||||||
zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080
|
zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080
|
||||||
en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080
|
en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080
|
||||||
|
- variable: api_key
|
||||||
|
label:
|
||||||
|
en_US: API Key
|
||||||
|
type: secret-input
|
||||||
|
required: false
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 在此输入您的 API Key
|
||||||
|
en_US: Enter your API Key
|
||||||
|
@ -51,8 +51,13 @@ class HuggingfaceTeiRerankModel(RerankModel):
|
|||||||
|
|
||||||
server_url = server_url.removesuffix("/")
|
server_url = server_url.removesuffix("/")
|
||||||
|
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
api_key = credentials.get("api_key")
|
||||||
|
if api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
results = TeiHelper.invoke_rerank(server_url, query, docs)
|
results = TeiHelper.invoke_rerank(server_url, query, docs, headers)
|
||||||
|
|
||||||
rerank_documents = []
|
rerank_documents = []
|
||||||
for result in results:
|
for result in results:
|
||||||
@ -80,7 +85,11 @@ class HuggingfaceTeiRerankModel(RerankModel):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
server_url = credentials["server_url"]
|
server_url = credentials["server_url"]
|
||||||
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
|
headers = {"Content-Type": "application/json"}
|
||||||
|
api_key = credentials.get("api_key")
|
||||||
|
if api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
|
||||||
if extra_args.model_type != "reranker":
|
if extra_args.model_type != "reranker":
|
||||||
raise CredentialsValidateFailedError("Current model is not a rerank model")
|
raise CredentialsValidateFailedError("Current model is not a rerank model")
|
||||||
|
|
||||||
|
@ -26,13 +26,15 @@ cache_lock = Lock()
|
|||||||
|
|
||||||
class TeiHelper:
|
class TeiHelper:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
|
def get_tei_extra_parameter(
|
||||||
|
server_url: str, model_name: str, headers: Optional[dict] = None
|
||||||
|
) -> TeiModelExtraParameter:
|
||||||
TeiHelper._clean_cache()
|
TeiHelper._clean_cache()
|
||||||
with cache_lock:
|
with cache_lock:
|
||||||
if model_name not in cache:
|
if model_name not in cache:
|
||||||
cache[model_name] = {
|
cache[model_name] = {
|
||||||
"expires": time() + 300,
|
"expires": time() + 300,
|
||||||
"value": TeiHelper._get_tei_extra_parameter(server_url),
|
"value": TeiHelper._get_tei_extra_parameter(server_url, headers),
|
||||||
}
|
}
|
||||||
return cache[model_name]["value"]
|
return cache[model_name]["value"]
|
||||||
|
|
||||||
@ -47,7 +49,7 @@ class TeiHelper:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
|
def _get_tei_extra_parameter(server_url: str, headers: Optional[dict] = None) -> TeiModelExtraParameter:
|
||||||
"""
|
"""
|
||||||
get tei model extra parameter like model_type, max_input_length, max_batch_requests
|
get tei model extra parameter like model_type, max_input_length, max_batch_requests
|
||||||
"""
|
"""
|
||||||
@ -61,7 +63,7 @@ class TeiHelper:
|
|||||||
session.mount("https://", HTTPAdapter(max_retries=3))
|
session.mount("https://", HTTPAdapter(max_retries=3))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = session.get(url, timeout=10)
|
response = session.get(url, headers=headers, timeout=10)
|
||||||
except (MissingSchema, ConnectionError, Timeout) as e:
|
except (MissingSchema, ConnectionError, Timeout) as e:
|
||||||
raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}")
|
raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}")
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
@ -86,7 +88,7 @@ class TeiHelper:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
|
def invoke_tokenize(server_url: str, texts: list[str], headers: Optional[dict] = None) -> list[list[dict]]:
|
||||||
"""
|
"""
|
||||||
Invoke tokenize endpoint
|
Invoke tokenize endpoint
|
||||||
|
|
||||||
@ -114,15 +116,15 @@ class TeiHelper:
|
|||||||
:param server_url: server url
|
:param server_url: server url
|
||||||
:param texts: texts to tokenize
|
:param texts: texts to tokenize
|
||||||
"""
|
"""
|
||||||
resp = httpx.post(
|
url = f"{server_url}/tokenize"
|
||||||
f"{server_url}/tokenize",
|
json_data = {"inputs": texts}
|
||||||
json={"inputs": texts},
|
resp = httpx.post(url, json=json_data, headers=headers)
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
return resp.json()
|
return resp.json()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
|
def invoke_embeddings(server_url: str, texts: list[str], headers: Optional[dict] = None) -> dict:
|
||||||
"""
|
"""
|
||||||
Invoke embeddings endpoint
|
Invoke embeddings endpoint
|
||||||
|
|
||||||
@ -147,15 +149,14 @@ class TeiHelper:
|
|||||||
:param texts: texts to embed
|
:param texts: texts to embed
|
||||||
"""
|
"""
|
||||||
# Use OpenAI compatible API here, which has usage tracking
|
# Use OpenAI compatible API here, which has usage tracking
|
||||||
resp = httpx.post(
|
url = f"{server_url}/v1/embeddings"
|
||||||
f"{server_url}/v1/embeddings",
|
json_data = {"input": texts}
|
||||||
json={"input": texts},
|
resp = httpx.post(url, json=json_data, headers=headers)
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
return resp.json()
|
return resp.json()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
|
def invoke_rerank(server_url: str, query: str, docs: list[str], headers: Optional[dict] = None) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Invoke rerank endpoint
|
Invoke rerank endpoint
|
||||||
|
|
||||||
@ -173,10 +174,7 @@ class TeiHelper:
|
|||||||
:param candidates: candidates to rerank
|
:param candidates: candidates to rerank
|
||||||
"""
|
"""
|
||||||
params = {"query": query, "texts": docs, "return_text": True}
|
params = {"query": query, "texts": docs, "return_text": True}
|
||||||
|
url = f"{server_url}/rerank"
|
||||||
response = httpx.post(
|
response = httpx.post(url, json=params, headers=headers)
|
||||||
server_url + "/rerank",
|
|
||||||
json=params,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
|
@ -51,6 +51,10 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
|
|
||||||
server_url = server_url.removesuffix("/")
|
server_url = server_url.removesuffix("/")
|
||||||
|
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
api_key = credentials["api_key"]
|
||||||
|
if api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
# get model properties
|
# get model properties
|
||||||
context_size = self._get_context_size(model, credentials)
|
context_size = self._get_context_size(model, credentials)
|
||||||
max_chunks = self._get_max_chunks(model, credentials)
|
max_chunks = self._get_max_chunks(model, credentials)
|
||||||
@ -60,7 +64,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
used_tokens = 0
|
used_tokens = 0
|
||||||
|
|
||||||
# get tokenized results from TEI
|
# get tokenized results from TEI
|
||||||
batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
|
batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts, headers)
|
||||||
|
|
||||||
for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
|
for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
|
||||||
# Check if the number of tokens is larger than the context size
|
# Check if the number of tokens is larger than the context size
|
||||||
@ -97,7 +101,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
used_tokens = 0
|
used_tokens = 0
|
||||||
for i in _iter:
|
for i in _iter:
|
||||||
iter_texts = inputs[i : i + max_chunks]
|
iter_texts = inputs[i : i + max_chunks]
|
||||||
results = TeiHelper.invoke_embeddings(server_url, iter_texts)
|
results = TeiHelper.invoke_embeddings(server_url, iter_texts, headers)
|
||||||
embeddings = results["data"]
|
embeddings = results["data"]
|
||||||
embeddings = [embedding["embedding"] for embedding in embeddings]
|
embeddings = [embedding["embedding"] for embedding in embeddings]
|
||||||
batched_embeddings.extend(embeddings)
|
batched_embeddings.extend(embeddings)
|
||||||
@ -127,7 +131,11 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
|
|
||||||
server_url = server_url.removesuffix("/")
|
server_url = server_url.removesuffix("/")
|
||||||
|
|
||||||
batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
|
headers = {
|
||||||
|
"Authorization": f"Bearer {credentials.get('api_key')}",
|
||||||
|
}
|
||||||
|
|
||||||
|
batch_tokens = TeiHelper.invoke_tokenize(server_url, texts, headers)
|
||||||
num_tokens = sum(len(tokens) for tokens in batch_tokens)
|
num_tokens = sum(len(tokens) for tokens in batch_tokens)
|
||||||
return num_tokens
|
return num_tokens
|
||||||
|
|
||||||
@ -141,7 +149,14 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
server_url = credentials["server_url"]
|
server_url = credentials["server_url"]
|
||||||
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
|
headers = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
|
api_key = credentials.get("api_key")
|
||||||
|
|
||||||
|
if api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
|
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
|
||||||
print(extra_args)
|
print(extra_args)
|
||||||
if extra_args.model_type != "embedding":
|
if extra_args.model_type != "embedding":
|
||||||
raise CredentialsValidateFailedError("Current model is not a embedding model")
|
raise CredentialsValidateFailedError("Current model is not a embedding model")
|
||||||
|
@ -20,6 +20,7 @@ env =
|
|||||||
OPENAI_API_KEY = sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
|
OPENAI_API_KEY = sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
|
||||||
TEI_EMBEDDING_SERVER_URL = http://a.abc.com:11451
|
TEI_EMBEDDING_SERVER_URL = http://a.abc.com:11451
|
||||||
TEI_RERANK_SERVER_URL = http://a.abc.com:11451
|
TEI_RERANK_SERVER_URL = http://a.abc.com:11451
|
||||||
|
TEI_API_KEY = ttttttttttttttt
|
||||||
UPSTAGE_API_KEY = up-aaaaaaaaaaaaaaaaaaaa
|
UPSTAGE_API_KEY = up-aaaaaaaaaaaaaaaaaaaa
|
||||||
VOYAGE_API_KEY = va-aaaaaaaaaaaaaaaaaaaa
|
VOYAGE_API_KEY = va-aaaaaaaaaaaaaaaaaaaa
|
||||||
XINFERENCE_CHAT_MODEL_UID = chat
|
XINFERENCE_CHAT_MODEL_UID = chat
|
||||||
|
@ -40,6 +40,7 @@ def test_validate_credentials(setup_tei_mock):
|
|||||||
model="reranker",
|
model="reranker",
|
||||||
credentials={
|
credentials={
|
||||||
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
|
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
|
||||||
|
"api_key": os.environ.get("TEI_API_KEY", ""),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -47,6 +48,7 @@ def test_validate_credentials(setup_tei_mock):
|
|||||||
model=model_name,
|
model=model_name,
|
||||||
credentials={
|
credentials={
|
||||||
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
|
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
|
||||||
|
"api_key": os.environ.get("TEI_API_KEY", ""),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -60,6 +62,7 @@ def test_invoke_model(setup_tei_mock):
|
|||||||
model=model_name,
|
model=model_name,
|
||||||
credentials={
|
credentials={
|
||||||
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
|
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
|
||||||
|
"api_key": os.environ.get("TEI_API_KEY", ""),
|
||||||
},
|
},
|
||||||
texts=["hello", "world"],
|
texts=["hello", "world"],
|
||||||
user="abc-123",
|
user="abc-123",
|
||||||
|
@ -40,6 +40,7 @@ def test_validate_credentials(setup_tei_mock):
|
|||||||
model="embedding",
|
model="embedding",
|
||||||
credentials={
|
credentials={
|
||||||
"server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
|
"server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
|
||||||
|
"api_key": os.environ.get("TEI_API_KEY", ""),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -47,6 +48,7 @@ def test_validate_credentials(setup_tei_mock):
|
|||||||
model=model_name,
|
model=model_name,
|
||||||
credentials={
|
credentials={
|
||||||
"server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
|
"server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
|
||||||
|
"api_key": os.environ.get("TEI_API_KEY", ""),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -61,6 +63,7 @@ def test_invoke_model(setup_tei_mock):
|
|||||||
model=model_name,
|
model=model_name,
|
||||||
credentials={
|
credentials={
|
||||||
"server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
|
"server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
|
||||||
|
"api_key": os.environ.get("TEI_API_KEY", ""),
|
||||||
},
|
},
|
||||||
query="Who is Kasumi?",
|
query="Who is Kasumi?",
|
||||||
docs=[
|
docs=[
|
||||||
|
Loading…
x
Reference in New Issue
Block a user