diff --git a/api/core/model_runtime/model_providers/cohere/cohere.yaml b/api/core/model_runtime/model_providers/cohere/cohere.yaml index c889a6bfe0..bd40057fe9 100644 --- a/api/core/model_runtime/model_providers/cohere/cohere.yaml +++ b/api/core/model_runtime/model_providers/cohere/cohere.yaml @@ -32,6 +32,15 @@ provider_credential_schema: zh_Hans: 在此输入您的 API Key en_US: Enter your API Key show_on: [ ] + - variable: base_url + label: + zh_Hans: API Base + en_US: API Base + type: text-input + required: false + placeholder: + zh_Hans: 在此输入您的 API Base,如 https://api.cohere.ai/v1 + en_US: Enter your API Base, e.g. https://api.cohere.ai/v1 model_credential_schema: model: label: @@ -70,3 +79,12 @@ model_credential_schema: placeholder: zh_Hans: 在此输入您的 API Key en_US: Enter your API Key + - variable: base_url + label: + zh_Hans: API Base + en_US: API Base + type: text-input + required: false + placeholder: + zh_Hans: 在此输入您的 API Base,如 https://api.cohere.ai/v1 + en_US: Enter your API Base, e.g. https://api.cohere.ai/v1 diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py index 6ace77b813..f9fae5e8ca 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/llm.py +++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py @@ -173,7 +173,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # initialize client - client = cohere.Client(credentials.get('api_key')) + client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) if stop: model_parameters['end_sequences'] = stop @@ -233,7 +233,8 @@ class CohereLargeLanguageModel(LargeLanguageModel): return response - def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse], + def _handle_generate_stream_response(self, model: str, credentials: dict, + response: Iterator[GenerateStreamedResponse], prompt_messages: list[PromptMessage]) -> Generator: """ Handle llm stream response @@ -317,7 +318,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # initialize client - client = cohere.Client(credentials.get('api_key')) + client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) if stop: model_parameters['stop_sequences'] = stop @@ -636,7 +637,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: number of tokens """ # initialize client - client = cohere.Client(credentials.get('api_key')) + client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) response = client.tokenize( text=text, diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py index 4194f27eb9..d2fdb30c6f 100644 --- a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py @@ -44,7 +44,7 @@ class CohereRerankModel(RerankModel): ) # initialize client - client = cohere.Client(credentials.get('api_key')) + client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) response = client.rerank( query=query, documents=docs, diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index 8269a41810..0540fb740f 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -141,7 +141,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): return [] # initialize client - client = cohere.Client(credentials.get('api_key')) + client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) response = client.tokenize( text=text, @@ -180,7 +180,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): :return: embeddings and used tokens """ # initialize client - client = cohere.Client(credentials.get('api_key')) + client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) # call embedding model response = client.embed(