feat: Add support for TEI API key authentication (#11006)

Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
Co-authored-by: crazywoola <427733928@qq.com>
Author: kenwoodjw
Date: 2024-11-23 23:55:35 +08:00 (committed by GitHub)
Parent: 16c41585e1
Commit: 096c0ad564
7 changed files with 63 additions and 26 deletions
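
Overview: every TEI HTTP call now accepts an optional headers dict, and each call site builds it from an optional api_key credential. A minimal sketch of the resulting request pattern, using httpx as the helper below does (the URL and token are placeholders; the /rerank path and payload shape follow the diff):

    import httpx

    server_url = "http://192.168.1.100:8080"  # placeholder TEI endpoint
    api_key = "my-tei-token"  # placeholder credential; may be empty

    headers = {"Content-Type": "application/json"}
    if api_key:
        # Same pattern as the call sites below: attach the token only when configured
        headers["Authorization"] = f"Bearer {api_key}"

    response = httpx.post(
        f"{server_url}/rerank",
        json={"query": "Who is Kasumi?", "texts": ["doc one", "doc two"], "return_text": True},
        headers=headers,
    )
    response.raise_for_status()
    print(response.json())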

View File

@@ -34,3 +34,11 @@ model_credential_schema:
       placeholder:
         zh_Hans: 在此输入Text Embedding Inference的服务器地址如 http://192.168.1.100:8080
         en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: false
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
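
The api_key credential is optional (required: false) and rendered as a secret input, so the model code below must handle both shapes of the credentials dict (values are placeholders):

    # With a key configured:
    credentials = {"server_url": "http://192.168.1.100:8080", "api_key": "my-tei-token"}

    # Without one (the key may be absent or empty):
    credentials = {"server_url": "http://192.168.1.100:8080"}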

View File

@@ -51,8 +51,13 @@ class HuggingfaceTeiRerankModel(RerankModel):
         server_url = server_url.removesuffix("/")
+        headers = {"Content-Type": "application/json"}
+
+        api_key = credentials.get("api_key")
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
 
         try:
-            results = TeiHelper.invoke_rerank(server_url, query, docs)
+            results = TeiHelper.invoke_rerank(server_url, query, docs, headers)
             rerank_documents = []
             for result in results:
@@ -80,7 +85,11 @@ class HuggingfaceTeiRerankModel(RerankModel):
         """
         try:
             server_url = credentials["server_url"]
-            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
+            headers = {"Content-Type": "application/json"}
+            api_key = credentials.get("api_key")
+            if api_key:
+                headers["Authorization"] = f"Bearer {api_key}"
+            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
             if extra_args.model_type != "reranker":
                 raise CredentialsValidateFailedError("Current model is not a rerank model")
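
The same four-line header-building block appears in _invoke and validate_credentials here, and again in the embedding model below. A hypothetical refactor (the name build_tei_headers is illustrative, not part of this commit) could centralize it:

    from typing import Optional

    def build_tei_headers(api_key: Optional[str]) -> dict[str, str]:
        # Mirrors the repeated pattern in this commit: JSON content type always,
        # Authorization only when a key is actually configured.
        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers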

View File

@@ -26,13 +26,15 @@ cache_lock = Lock()
 class TeiHelper:
     @staticmethod
-    def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
+    def get_tei_extra_parameter(
+        server_url: str, model_name: str, headers: Optional[dict] = None
+    ) -> TeiModelExtraParameter:
         TeiHelper._clean_cache()
         with cache_lock:
             if model_name not in cache:
                 cache[model_name] = {
                     "expires": time() + 300,
-                    "value": TeiHelper._get_tei_extra_parameter(server_url),
+                    "value": TeiHelper._get_tei_extra_parameter(server_url, headers),
                 }
             return cache[model_name]["value"]
@@ -47,7 +49,7 @@ class TeiHelper:
             pass
 
     @staticmethod
-    def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
+    def _get_tei_extra_parameter(server_url: str, headers: Optional[dict] = None) -> TeiModelExtraParameter:
         """
         get tei model extra parameter like model_type, max_input_length, max_batch_requests
         """
@@ -61,7 +63,7 @@ class TeiHelper:
         session.mount("https://", HTTPAdapter(max_retries=3))
 
         try:
-            response = session.get(url, timeout=10)
+            response = session.get(url, headers=headers, timeout=10)
         except (MissingSchema, ConnectionError, Timeout) as e:
             raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}")
         if response.status_code != 200:
@@ -86,7 +88,7 @@ class TeiHelper:
         )
 
     @staticmethod
-    def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
+    def invoke_tokenize(server_url: str, texts: list[str], headers: Optional[dict] = None) -> list[list[dict]]:
         """
         Invoke tokenize endpoint
@@ -114,15 +116,15 @@ class TeiHelper:
         :param server_url: server url
         :param texts: texts to tokenize
         """
-        resp = httpx.post(
-            f"{server_url}/tokenize",
-            json={"inputs": texts},
-        )
+        url = f"{server_url}/tokenize"
+        json_data = {"inputs": texts}
+        resp = httpx.post(url, json=json_data, headers=headers)
         resp.raise_for_status()
         return resp.json()
 
     @staticmethod
-    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
+    def invoke_embeddings(server_url: str, texts: list[str], headers: Optional[dict] = None) -> dict:
         """
         Invoke embeddings endpoint
@@ -147,15 +149,14 @@ class TeiHelper:
         :param texts: texts to embed
         """
         # Use OpenAI compatible API here, which has usage tracking
-        resp = httpx.post(
-            f"{server_url}/v1/embeddings",
-            json={"input": texts},
-        )
+        url = f"{server_url}/v1/embeddings"
+        json_data = {"input": texts}
+        resp = httpx.post(url, json=json_data, headers=headers)
         resp.raise_for_status()
         return resp.json()
 
     @staticmethod
-    def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
+    def invoke_rerank(server_url: str, query: str, docs: list[str], headers: Optional[dict] = None) -> list[dict]:
         """
         Invoke rerank endpoint
@@ -173,10 +174,7 @@ class TeiHelper:
         :param candidates: candidates to rerank
         """
         params = {"query": query, "texts": docs, "return_text": True}
-        response = httpx.post(
-            server_url + "/rerank",
-            json=params,
-        )
+        url = f"{server_url}/rerank"
+        response = httpx.post(url, json=params, headers=headers)
         response.raise_for_status()
         return response.json()
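
Each TeiHelper entry point now takes headers as an optional trailing argument, so existing callers that omit it keep the old unauthenticated behavior. A usage sketch (the import path is assumed from the repository layout; URL and token are placeholders):

    from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper

    headers = {"Content-Type": "application/json", "Authorization": "Bearer my-tei-token"}

    # Cached metadata lookup, then an authenticated rerank call
    extra_args = TeiHelper.get_tei_extra_parameter("http://192.168.1.100:8080", "reranker", headers)
    results = TeiHelper.invoke_rerank("http://192.168.1.100:8080", "Who is Kasumi?", ["doc one"], headers)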

View File

@@ -51,6 +51,10 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         server_url = server_url.removesuffix("/")
+        headers = {"Content-Type": "application/json"}
+        api_key = credentials.get("api_key")
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
 
         # get model properties
         context_size = self._get_context_size(model, credentials)
         max_chunks = self._get_max_chunks(model, credentials)
@@ -60,7 +64,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         used_tokens = 0
 
         # get tokenized results from TEI
-        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
+        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts, headers)
 
         for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
             # Check if the number of tokens is larger than the context size
@@ -97,7 +101,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
             used_tokens = 0
             for i in _iter:
                 iter_texts = inputs[i : i + max_chunks]
-                results = TeiHelper.invoke_embeddings(server_url, iter_texts)
+                results = TeiHelper.invoke_embeddings(server_url, iter_texts, headers)
                 embeddings = results["data"]
                 embeddings = [embedding["embedding"] for embedding in embeddings]
                 batched_embeddings.extend(embeddings)
@@ -127,7 +131,11 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         server_url = server_url.removesuffix("/")
 
-        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
+        headers = {
+            "Authorization": f"Bearer {credentials.get('api_key')}",
+        }
+
+        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts, headers)
         num_tokens = sum(len(tokens) for tokens in batch_tokens)
         return num_tokens
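
Note that this hunk attaches the Authorization header unconditionally, so a missing key yields the literal value "Bearer None"; a TEI server that does not enforce auth will typically ignore it, but a guarded variant of the method body (a sketch, not part of the commit) would match the other call sites:

    headers = {"Content-Type": "application/json"}
    api_key = credentials.get("api_key")
    if api_key:  # avoid sending "Bearer None" when no key is configured
        headers["Authorization"] = f"Bearer {api_key}"
    batch_tokens = TeiHelper.invoke_tokenize(server_url, texts, headers)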
@@ -141,7 +149,14 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         """
         try:
             server_url = credentials["server_url"]
-            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
+            headers = {"Content-Type": "application/json"}
+
+            api_key = credentials.get("api_key")
+            if api_key:
+                headers["Authorization"] = f"Bearer {api_key}"
+
+            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
             print(extra_args)
             if extra_args.model_type != "embedding":
                 raise CredentialsValidateFailedError("Current model is not a embedding model")

View File

@@ -20,6 +20,7 @@ env =
     OPENAI_API_KEY = sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
     TEI_EMBEDDING_SERVER_URL = http://a.abc.com:11451
     TEI_RERANK_SERVER_URL = http://a.abc.com:11451
+    TEI_API_KEY = ttttttttttttttt
     UPSTAGE_API_KEY = up-aaaaaaaaaaaaaaaaaaaa
     VOYAGE_API_KEY = va-aaaaaaaaaaaaaaaaaaaa
     XINFERENCE_CHAT_MODEL_UID = chat

View File

@@ -40,6 +40,7 @@ def test_validate_credentials(setup_tei_mock):
             model="reranker",
             credentials={
                 "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+                "api_key": os.environ.get("TEI_API_KEY", ""),
             },
         )
@@ -47,6 +48,7 @@ def test_validate_credentials(setup_tei_mock):
         model=model_name,
         credentials={
             "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+            "api_key": os.environ.get("TEI_API_KEY", ""),
         },
     )
@@ -60,6 +62,7 @@ def test_invoke_model(setup_tei_mock):
         model=model_name,
         credentials={
             "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+            "api_key": os.environ.get("TEI_API_KEY", ""),
         },
         texts=["hello", "world"],
         user="abc-123",

View File

@@ -40,6 +40,7 @@ def test_validate_credentials(setup_tei_mock):
             model="embedding",
             credentials={
                 "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+                "api_key": os.environ.get("TEI_API_KEY", ""),
             },
         )
@@ -47,6 +48,7 @@ def test_validate_credentials(setup_tei_mock):
         model=model_name,
         credentials={
             "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+            "api_key": os.environ.get("TEI_API_KEY", ""),
         },
     )
@@ -61,6 +63,7 @@ def test_invoke_model(setup_tei_mock):
         model=model_name,
         credentials={
             "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+            "api_key": os.environ.get("TEI_API_KEY", ""),
         },
         query="Who is Kasumi?",
         docs=[