feat: add proxy configuration for Cohere model (#4152)

This commit is contained in:
Moonlit 2024-05-07 18:12:13 +08:00 committed by GitHub
parent 591b993685
commit 2fdd64c1b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 26 additions and 7 deletions

View File

@ -32,6 +32,15 @@ provider_credential_schema:
zh_Hans: 在此输入您的 API Key zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key en_US: Enter your API Key
show_on: [ ] show_on: [ ]
- variable: base_url
label:
zh_Hans: API Base
en_US: API Base
type: text-input
required: false
placeholder:
zh_Hans: 在此输入您的 API Base,如 https://api.cohere.ai/v1
en_US: Enter your API Base, e.g. https://api.cohere.ai/v1
model_credential_schema: model_credential_schema:
model: model:
label: label:
@ -70,3 +79,12 @@ model_credential_schema:
placeholder: placeholder:
zh_Hans: 在此输入您的 API Key zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key en_US: Enter your API Key
- variable: base_url
label:
zh_Hans: API Base
en_US: API Base
type: text-input
required: false
placeholder:
zh_Hans: 在此输入您的 API Base,如 https://api.cohere.ai/v1
en_US: Enter your API Base, e.g. https://api.cohere.ai/v1

View File

@ -173,7 +173,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result :return: full response or stream response chunk generator result
""" """
# initialize client # initialize client
client = cohere.Client(credentials.get('api_key')) client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
if stop: if stop:
model_parameters['end_sequences'] = stop model_parameters['end_sequences'] = stop
@ -233,7 +233,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
return response return response
def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse], def _handle_generate_stream_response(self, model: str, credentials: dict,
response: Iterator[GenerateStreamedResponse],
prompt_messages: list[PromptMessage]) -> Generator: prompt_messages: list[PromptMessage]) -> Generator:
""" """
Handle llm stream response Handle llm stream response
@ -317,7 +318,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result :return: full response or stream response chunk generator result
""" """
# initialize client # initialize client
client = cohere.Client(credentials.get('api_key')) client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
if stop: if stop:
model_parameters['stop_sequences'] = stop model_parameters['stop_sequences'] = stop
@ -636,7 +637,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:return: number of tokens :return: number of tokens
""" """
# initialize client # initialize client
client = cohere.Client(credentials.get('api_key')) client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
response = client.tokenize( response = client.tokenize(
text=text, text=text,

View File

@ -44,7 +44,7 @@ class CohereRerankModel(RerankModel):
) )
# initialize client # initialize client
client = cohere.Client(credentials.get('api_key')) client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
response = client.rerank( response = client.rerank(
query=query, query=query,
documents=docs, documents=docs,

View File

@ -141,7 +141,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
return [] return []
# initialize client # initialize client
client = cohere.Client(credentials.get('api_key')) client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
response = client.tokenize( response = client.tokenize(
text=text, text=text,
@ -180,7 +180,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
:return: embeddings and used tokens :return: embeddings and used tokens
""" """
# initialize client # initialize client
client = cohere.Client(credentials.get('api_key')) client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
# call embedding model # call embedding model
response = client.embed( response = client.embed(