feat: add proxy configuration for Cohere model (#4152)

This commit is contained in:
Moonlit 2024-05-07 18:12:13 +08:00 committed by GitHub
parent 591b993685
commit 2fdd64c1b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 26 additions and 7 deletions

View File

@ -32,6 +32,15 @@ provider_credential_schema:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
show_on: [ ]
- variable: base_url
label:
zh_Hans: API Base
en_US: API Base
type: text-input
required: false
placeholder:
zh_Hans: 在此输入您的 API Base，如 https://api.cohere.ai/v1
en_US: Enter your API Base, e.g. https://api.cohere.ai/v1
model_credential_schema:
model:
label:
@ -70,3 +79,12 @@ model_credential_schema:
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: base_url
label:
zh_Hans: API Base
en_US: API Base
type: text-input
required: false
placeholder:
zh_Hans: 在此输入您的 API Base，如 https://api.cohere.ai/v1
en_US: Enter your API Base, e.g. https://api.cohere.ai/v1

View File

@ -173,7 +173,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
# initialize client
client = cohere.Client(credentials.get('api_key'))
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
if stop:
model_parameters['end_sequences'] = stop
@ -233,7 +233,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
return response
def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse],
def _handle_generate_stream_response(self, model: str, credentials: dict,
response: Iterator[GenerateStreamedResponse],
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
@ -317,7 +318,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
# initialize client
client = cohere.Client(credentials.get('api_key'))
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
if stop:
model_parameters['stop_sequences'] = stop
@ -636,7 +637,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: number of tokens
"""
# initialize client
client = cohere.Client(credentials.get('api_key'))
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
response = client.tokenize(
text=text,

View File

@ -44,7 +44,7 @@ class CohereRerankModel(RerankModel):
)
# initialize client
client = cohere.Client(credentials.get('api_key'))
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
response = client.rerank(
query=query,
documents=docs,

View File

@ -141,7 +141,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
return []
# initialize client
client = cohere.Client(credentials.get('api_key'))
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
response = client.tokenize(
text=text,
@ -180,7 +180,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
:return: embeddings and used tokens
"""
# initialize client
client = cohere.Client(credentials.get('api_key'))
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
# call embedding model
response = client.embed(