From a18dde9b0dda6a32d87903c00da41e5211a5901c Mon Sep 17 00:00:00 2001
From: takatost
Date: Sun, 21 Jan 2024 20:52:56 +0800
Subject: [PATCH] feat: add cohere llm and embedding (#2115)

---
 .../__base/large_language_model.py            |   5 +
 .../model_providers/cohere/cohere.yaml        |  45 +-
 .../model_providers/cohere/llm/__init__.py    |   0
 .../model_providers/cohere/llm/_position.yaml |   8 +
 .../cohere/llm/command-chat.yaml              |  62 ++
 .../cohere/llm/command-light-chat.yaml        |  62 ++
 .../llm/command-light-nightly-chat.yaml       |  62 ++
 .../cohere/llm/command-light-nightly.yaml     |  44 ++
 .../cohere/llm/command-light.yaml             |  44 ++
 .../cohere/llm/command-nightly-chat.yaml      |  62 ++
 .../cohere/llm/command-nightly.yaml           |  44 ++
 .../model_providers/cohere/llm/command.yaml   |  44 ++
 .../model_providers/cohere/llm/llm.py         | 565 ++++++++++++++++++
 .../cohere/text_embedding/__init__.py         |   0
 .../cohere/text_embedding/_position.yaml      |   7 +
 .../embed-english-light-v2.0.yaml             |   9 +
 .../embed-english-light-v3.0.yaml             |   9 +
 .../text_embedding/embed-english-v2.0.yaml    |   9 +
 .../text_embedding/embed-english-v3.0.yaml    |   9 +
 .../embed-multilingual-light-v3.0.yaml        |   9 +
 .../embed-multilingual-v2.0.yaml              |   9 +
 .../embed-multilingual-v3.0.yaml              |   9 +
 .../cohere/text_embedding/text_embedding.py   | 234 ++++++++
 api/core/spiltter/fixed_text_splitter.py      |   3 +
 api/requirements.txt                          |   2 +-
 .../model_runtime/cohere/test_llm.py          | 272 +++++++++
 .../cohere/test_text_embedding.py             |  64 ++
 27 files changed, 1689 insertions(+), 3 deletions(-)
 create mode 100644 api/core/model_runtime/model_providers/cohere/llm/__init__.py
 create mode 100644 api/core/model_runtime/model_providers/cohere/llm/_position.yaml
 create mode 100644 api/core/model_runtime/model_providers/cohere/llm/command-chat.yaml
 create mode 100644 api/core/model_runtime/model_providers/cohere/llm/command-light-chat.yaml
 create mode 100644 api/core/model_runtime/model_providers/cohere/llm/command-light-nightly-chat.yaml
 create mode 100644 api/core/model_runtime/model_providers/cohere/llm/command-light-nightly.yaml
 create mode 100644 api/core/model_runtime/model_providers/cohere/llm/command-light.yaml
 create mode 100644 api/core/model_runtime/model_providers/cohere/llm/command-nightly-chat.yaml
 create mode 100644 api/core/model_runtime/model_providers/cohere/llm/command-nightly.yaml
 create mode 100644 api/core/model_runtime/model_providers/cohere/llm/command.yaml
 create mode 100644 api/core/model_runtime/model_providers/cohere/llm/llm.py
 create mode 100644 api/core/model_runtime/model_providers/cohere/text_embedding/__init__.py
 create mode 100644 api/core/model_runtime/model_providers/cohere/text_embedding/_position.yaml
 create mode 100644 api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-light-v2.0.yaml
 create mode 100644 api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-light-v3.0.yaml
 create mode 100644 api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-v2.0.yaml
 create mode 100644 api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-v3.0.yaml
 create mode 100644 api/core/model_runtime/model_providers/cohere/text_embedding/embed-multilingual-light-v3.0.yaml
 create mode 100644 api/core/model_runtime/model_providers/cohere/text_embedding/embed-multilingual-v2.0.yaml
 create mode 100644 api/core/model_runtime/model_providers/cohere/text_embedding/embed-multilingual-v3.0.yaml
 create mode 100644 api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py
 create mode 100644 api/tests/integration_tests/model_runtime/cohere/test_llm.py
 create mode 100644 api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py

diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py
index 0bf6a385ac..75ea7bacef 100644
--- a/api/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import re
 import time
 from abc import abstractmethod
 from typing import Generator, List, Optional, Union
@@ -212,6 +213,10 @@ class LargeLanguageModel(AIModel):
         """
         raise NotImplementedError
 
+    def enforce_stop_tokens(self, text: str, stop: List[str]) -> str:
+        """Cut off the text as soon as any stop words occur."""
+        return re.split("|".join(re.escape(s) for s in stop), text, maxsplit=1)[0]
+
     def _llm_result_to_stream(self, result: LLMResult) -> Generator:
         """
         Transform llm result to stream
diff --git a/api/core/model_runtime/model_providers/cohere/cohere.yaml b/api/core/model_runtime/model_providers/cohere/cohere.yaml
index b9a5fcfe0c..c889a6bfe0 100644
--- a/api/core/model_runtime/model_providers/cohere/cohere.yaml
+++ b/api/core/model_runtime/model_providers/cohere/cohere.yaml
@@ -14,9 +14,12 @@ help:
   url:
     en_US: https://dashboard.cohere.com/api-keys
 supported_model_types:
+  - llm
+  - text-embedding
   - rerank
 configurate_methods:
   - predefined-model
+  - customizable-model
 provider_credential_schema:
   credential_form_schemas:
     - variable: api_key
@@ -26,6 +29,44 @@ provider_credential_schema:
       type: secret-input
       required: true
       placeholder:
-        zh_Hans: 请填写 API Key
-        en_US: Please fill in API Key
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
       show_on: [ ]
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your model name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+    - variable: mode
+      show_on:
+        - variable: __model_type
+          value: llm
+      label:
+        en_US: Completion mode
+      type: select
+      required: false
+      default: chat
+      placeholder:
+        zh_Hans: 选择对话类型
+        en_US: Select completion mode
+      options:
+        - value: completion
+          label:
+            en_US: Completion
+            zh_Hans: 补全
+        - value: chat
+          label:
+            en_US: Chat
+            zh_Hans: 对话
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
diff --git a/api/core/model_runtime/model_providers/cohere/llm/__init__.py b/api/core/model_runtime/model_providers/cohere/llm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/cohere/llm/_position.yaml b/api/core/model_runtime/model_providers/cohere/llm/_position.yaml
new file mode 100644
index 0000000000..367117c9e8
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/llm/_position.yaml
@@ -0,0 +1,8 @@
+- command-chat
+- command-light-chat
+- command-nightly-chat
+- command-light-nightly-chat
+- command
+- command-light
+- command-nightly
+- command-light-nightly
diff --git a/api/core/model_runtime/model_providers/cohere/llm/command-chat.yaml b/api/core/model_runtime/model_providers/cohere/llm/command-chat.yaml
new file mode 100644
index 0000000000..4bcfae6e5d
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/llm/command-chat.yaml
@@ -0,0 +1,62 @@
+model: command-chat
+label:
+  zh_Hans: command-chat
+  en_US: command-chat
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: max_tokens
+    use_template: max_tokens
+    default: 256
+    max: 4096
+  - name: preamble_override
+    label:
+      zh_Hans: 前导文本
+      en_US: Preamble
+    type: string
+    help:
+      zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。
+      en_US: When specified, the default Cohere preamble will be replaced with the provided one.
+    required: false
+  - name: prompt_truncation
+    label:
+      zh_Hans: 提示截断
+      en_US: Prompt Truncation
+    type: string
+    help:
+      zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。
+      en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
+    required: true
+    default: 'AUTO'
+    options:
+      - 'AUTO'
+      - 'OFF'
+pricing:
+  input: '1.0'
+  output: '2.0'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/cohere/llm/command-light-chat.yaml b/api/core/model_runtime/model_providers/cohere/llm/command-light-chat.yaml
new file mode 100644
index 0000000000..8d8075967c
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/llm/command-light-chat.yaml
@@ -0,0 +1,62 @@
+model: command-light-chat
+label:
+  zh_Hans: command-light-chat
+  en_US: command-light-chat
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: max_tokens
+    use_template: max_tokens
+    default: 256
+    max: 4096
+  - name: preamble_override
+    label:
+      zh_Hans: 前导文本
+      en_US: Preamble
+    type: string
+    help:
+      zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。
+      en_US: When specified, the default Cohere preamble will be replaced with the provided one.
+    required: false
+  - name: prompt_truncation
+    label:
+      zh_Hans: 提示截断
+      en_US: Prompt Truncation
+    type: string
+    help:
+      zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。
+      en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
+    required: true
+    default: 'AUTO'
+    options:
+      - 'AUTO'
+      - 'OFF'
+pricing:
+  input: '0.3'
+  output: '0.6'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/cohere/llm/command-light-nightly-chat.yaml b/api/core/model_runtime/model_providers/cohere/llm/command-light-nightly-chat.yaml
new file mode 100644
index 0000000000..4b6b66951e
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/llm/command-light-nightly-chat.yaml
@@ -0,0 +1,62 @@
+model: command-light-nightly-chat
+label:
+  zh_Hans: command-light-nightly-chat
+  en_US: command-light-nightly-chat
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: max_tokens
+    use_template: max_tokens
+    default: 256
+    max: 4096
+  - name: preamble_override
+    label:
+      zh_Hans: 前导文本
+      en_US: Preamble
+    type: string
+    help:
+      zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。
+      en_US: When specified, the default Cohere preamble will be replaced with the provided one.
+    required: false
+  - name: prompt_truncation
+    label:
+      zh_Hans: 提示截断
+      en_US: Prompt Truncation
+    type: string
+    help:
+      zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。
+      en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
+    required: true
+    default: 'AUTO'
+    options:
+      - 'AUTO'
+      - 'OFF'
+pricing:
+  input: '0.3'
+  output: '0.6'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/cohere/llm/command-light-nightly.yaml b/api/core/model_runtime/model_providers/cohere/llm/command-light-nightly.yaml
new file mode 100644
index 0000000000..6a76c25019
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/llm/command-light-nightly.yaml
@@ -0,0 +1,44 @@
+model: command-light-nightly
+label:
+  zh_Hans: command-light-nightly
+  en_US: command-light-nightly
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: completion
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 256
+    max: 4096
+pricing:
+  input: '0.3'
+  output: '0.6'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/cohere/llm/command-light.yaml b/api/core/model_runtime/model_providers/cohere/llm/command-light.yaml
new file mode 100644
index 0000000000..ff9a594b66
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/llm/command-light.yaml
@@ -0,0 +1,44 @@
+model: command-light
+label:
+  zh_Hans: command-light
+  en_US: command-light
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: completion
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 256
+    max: 4096
+pricing:
+  input: '0.3'
+  output: '0.6'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/cohere/llm/command-nightly-chat.yaml b/api/core/model_runtime/model_providers/cohere/llm/command-nightly-chat.yaml
new file mode 100644
index 0000000000..811f237c88
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/llm/command-nightly-chat.yaml
@@ -0,0 +1,62 @@
+model: command-nightly-chat
+label:
+  zh_Hans: command-nightly-chat
+  en_US: command-nightly-chat
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: max_tokens
+    use_template: max_tokens
+    default: 256
+    max: 4096
+  - name: preamble_override
+    label:
+      zh_Hans: 前导文本
+      en_US: Preamble
+    type: string
+    help:
+      zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。
+      en_US: When specified, the default Cohere preamble will be replaced with the provided one.
+    required: false
+  - name: prompt_truncation
+    label:
+      zh_Hans: 提示截断
+      en_US: Prompt Truncation
+    type: string
+    help:
+      zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。
+      en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
+    required: true
+    default: 'AUTO'
+    options:
+      - 'AUTO'
+      - 'OFF'
+pricing:
+  input: '1.0'
+  output: '2.0'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/cohere/llm/command-nightly.yaml b/api/core/model_runtime/model_providers/cohere/llm/command-nightly.yaml
new file mode 100644
index 0000000000..2c99bf7684
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/llm/command-nightly.yaml
@@ -0,0 +1,44 @@
+model: command-nightly
+label:
+  zh_Hans: command-nightly
+  en_US: command-nightly
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: completion
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 256
+    max: 4096
+pricing:
+  input: '1.0'
+  output: '2.0'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/cohere/llm/command.yaml b/api/core/model_runtime/model_providers/cohere/llm/command.yaml
new file mode 100644
index 0000000000..d41c2951fc
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/llm/command.yaml
@@ -0,0 +1,44 @@
+model: command
+label:
+  zh_Hans: command
+  en_US: command
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: completion
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 256
+    max: 4096
+pricing:
+  input: '1.0'
+  output: '2.0'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py
new file mode 100644
index 0000000000..7b7687bc99
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py
@@ -0,0 +1,565 @@
+import logging
+from typing import Generator, List, Optional, Union, cast, Tuple
+
+import cohere
+from cohere.responses import Chat, Generations
+from cohere.responses.chat import StreamingChat, StreamTextGeneration, StreamEnd
+from cohere.responses.generation import StreamingText, StreamingGenerations
+
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage,
+                                                          PromptMessageContentType, SystemPromptMessage,
+                                                          TextPromptMessageContent, UserPromptMessage,
+                                                          PromptMessageTool)
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
+from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeError, \
+    InvokeRateLimitError, InvokeAuthorizationError, InvokeBadRequestError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
+logger = logging.getLogger(__name__)
+
+
+class CohereLargeLanguageModel(LargeLanguageModel):
+    """
+    Model class for Cohere large language model.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                stream: bool = True, user: Optional[str] = None) \
+            -> Union[LLMResult, Generator]:
+        """
+        Invoke large language model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        # get model mode
+        model_mode = self.get_model_mode(model, credentials)
+
+        if model_mode == LLMMode.CHAT:
+            return self._chat_generate(
+                model=model,
+                credentials=credentials,
+                prompt_messages=prompt_messages,
+                model_parameters=model_parameters,
+                stop=stop,
+                stream=stream,
+                user=user
+            )
+        else:
+            return self._generate(
+                model=model,
+                credentials=credentials,
+                prompt_messages=prompt_messages,
+                model_parameters=model_parameters,
+                stop=stop,
+                stream=stream,
+                user=user
+            )
+
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :return:
+        """
+        # get model mode
+        model_mode = self.get_model_mode(model, credentials)
+
+        try:
+            if model_mode == LLMMode.CHAT:
+                return self._num_tokens_from_messages(model, credentials, prompt_messages)
+            else:
+                return self._num_tokens_from_string(model, credentials, prompt_messages[0].content)
+        except Exception as e:
+            raise self._transform_invoke_error(e)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            # get model mode
+            model_mode = self.get_model_mode(model, credentials)
+
+            if model_mode == LLMMode.CHAT:
+                self._chat_generate(
+                    model=model,
+                    credentials=credentials,
+                    prompt_messages=[UserPromptMessage(content='ping')],
+                    model_parameters={
+                        'max_tokens': 20,
+                        'temperature': 0,
+                    },
+                    stream=False
+                )
+            else:
+                self._generate(
+                    model=model,
+                    credentials=credentials,
+                    prompt_messages=[UserPromptMessage(content='ping')],
+                    model_parameters={
+                        'max_tokens': 20,
+                        'temperature': 0,
+                    },
+                    stream=False
+                )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _generate(self, model: str, credentials: dict,
+                  prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
+                  stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+        """
+        Invoke llm model
+
+        :param model: model name
+        :param credentials: credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        # initialize client
+        client = cohere.Client(credentials.get('api_key'))
+
+        if stop:
+            model_parameters['end_sequences'] = stop
+
+        response = client.generate(
+            prompt=prompt_messages[0].content,
+            model=model,
+            stream=stream,
+            **model_parameters,
+        )
+
+        if stream:
+            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+
+        return self._handle_generate_response(model, credentials, response, prompt_messages)
+
+    def _handle_generate_response(self, model: str, credentials: dict, response: Generations,
+                                  prompt_messages: list[PromptMessage]) \
+            -> LLMResult:
+        """
+        Handle llm response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: llm response
+        """
+        assistant_text = response.generations[0].text
+
+        # transform assistant message to prompt message
+        assistant_prompt_message = AssistantPromptMessage(
+            content=assistant_text
+        )
+
+        # calculate num tokens
+        prompt_tokens = response.meta['billed_units']['input_tokens']
+        completion_tokens = response.meta['billed_units']['output_tokens']
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        # transform response
+        response = LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            message=assistant_prompt_message,
+            usage=usage
+        )
+
+        return response
+
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations,
+                                         prompt_messages: list[PromptMessage]) -> Generator:
+        """
+        Handle llm stream response
+
+        :param model: model name
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: llm response chunk generator
+        """
+        index = 1
+        full_assistant_content = ''
+        for chunk in response:
+            if isinstance(chunk, StreamingText):
+                chunk = cast(StreamingText, chunk)
+                text = chunk.text
+
+                if text is None:
+                    continue
+
+                # transform assistant message to prompt message
+                assistant_prompt_message = AssistantPromptMessage(
+                    content=text
+                )
+
+                full_assistant_content += text
+
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                    )
+                )
+
+                index += 1
+            elif chunk is None:
+                # calculate num tokens
+                prompt_tokens = response.meta['billed_units']['input_tokens']
+                completion_tokens = response.meta['billed_units']['output_tokens']
+
+                # transform usage
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=AssistantPromptMessage(content=''),
+                        finish_reason=response.finish_reason,
+                        usage=usage
+                    )
+                )
+                break
+
+    def _chat_generate(self, model: str, credentials: dict,
+                       prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
+                       stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+        """
+        Invoke llm chat model
+
+        :param model: model name
+        :param credentials: credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        # initialize client
+        client = cohere.Client(credentials.get('api_key'))
+
+        if user:
+            model_parameters['user_name'] = user
+
+        message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
+
+        # chat model
+        real_model = model
+        if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
+            real_model = model.removesuffix('-chat')
+
+        response = client.chat(
+            message=message,
+            chat_history=chat_histories,
+            model=real_model,
+            stream=stream,
+            return_preamble=True,
+            **model_parameters,
+        )
+
+        if stream:
+            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop)
+
+        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop)
+
+    def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat,
+                                       prompt_messages: list[PromptMessage], stop: Optional[List[str]] = None) \
+            -> LLMResult:
+        """
+        Handle llm chat response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :param stop: stop words
+        :return: llm response
+        """
+        assistant_text = response.text
+
+        # transform assistant message to prompt message
+        assistant_prompt_message = AssistantPromptMessage(
+            content=assistant_text
+        )
+
+        # calculate num tokens
+        prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
+        completion_tokens = self._num_tokens_from_messages(model, credentials, [assistant_prompt_message])
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        if stop:
+            # enforce stop tokens
+            assistant_text = self.enforce_stop_tokens(assistant_text, stop)
+            assistant_prompt_message = AssistantPromptMessage(
+                content=assistant_text
+            )
+
+        # transform response
+        response = LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            message=assistant_prompt_message,
+            usage=usage,
+            system_fingerprint=response.preamble
+        )
+
+        return response
+
+    def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat,
+                                              prompt_messages: list[PromptMessage],
+                                              stop: Optional[List[str]] = None) -> Generator:
+        """
+        Handle llm chat stream response
+
+        :param model: model name
+        :param response: response
+        :param prompt_messages: prompt messages
+        :param stop: stop words
+        :return: llm response chunk generator
+        """
+
+        def final_response(full_text: str, index: int, finish_reason: Optional[str] = None,
+                           preamble: Optional[str] = None) -> LLMResultChunk:
+            # calculate num tokens
+            prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
+
+            full_assistant_prompt_message = AssistantPromptMessage(
+                content=full_text
+            )
+            completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message])
+
+            # transform usage
+            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+            return LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                system_fingerprint=preamble,
+                delta=LLMResultChunkDelta(
+                    index=index,
+                    message=AssistantPromptMessage(content=''),
+                    finish_reason=finish_reason,
+                    usage=usage
+                )
+            )
+
+        index = 1
+        full_assistant_content = ''
+        for chunk in response:
+            if isinstance(chunk, StreamTextGeneration):
+                chunk = cast(StreamTextGeneration, chunk)
+                text = chunk.text
+
+                if text is None:
+                    continue
+
+                # transform assistant message to prompt message
+                assistant_prompt_message = AssistantPromptMessage(
+                    content=text
+                )
+
+                # stop
+                # notice: This logic can only cover few stop scenarios
+                if stop and text in stop:
+                    yield final_response(full_assistant_content, index, 'stop')
+                    break
+
+                full_assistant_content += text
+
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                    )
+                )
+
+                index += 1
+            elif isinstance(chunk, StreamEnd):
+                chunk = cast(StreamEnd, chunk)
+                yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble)
+                index += 1
+
+    def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
+            -> Tuple[str, list[dict]]:
+        """
+        Convert prompt messages to message and chat histories
+        :param prompt_messages: prompt messages
+        :return:
+        """
+        chat_histories = []
+        for prompt_message in prompt_messages:
+            chat_histories.append(self._convert_prompt_message_to_dict(prompt_message))
+
+        # get latest message from chat histories and pop it
+        if len(chat_histories) > 0:
+            latest_message = chat_histories.pop()
+            message = latest_message['message']
+        else:
+            raise ValueError('Prompt messages is empty')
+
+        return message, chat_histories
+
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+        """
+        Convert PromptMessage to dict for Cohere model
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "USER", "message": message.content}
+            else:
+                sub_message_text = ''
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(TextPromptMessageContent, message_content)
+                        sub_message_text += message_content.data
+
+                message_dict = {"role": "USER", "message": sub_message_text}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {"role": "CHATBOT", "message": message.content}
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {"role": "USER", "message": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        if message.name is not None:
+            message_dict["user_name"] = message.name
+
+        return message_dict
+
+    def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> int:
+        """
+        Calculate num tokens for text completion model.
+
+        :param model: model name
+        :param credentials: credentials
+        :param text: prompt text
+        :return: number of tokens
+        """
+        # initialize client
+        client = cohere.Client(credentials.get('api_key'))
+
+        response = client.tokenize(
+            text=text,
+            model=model
+        )
+
+        return response.length
+
+    def _num_tokens_from_messages(self, model: str, credentials: dict, messages: List[PromptMessage]) -> int:
+        """Calculate num tokens for Cohere model."""
+        messages = [self._convert_prompt_message_to_dict(m) for m in messages]
+        message_strs = [f"{message['role']}: {message['message']}" for message in messages]
+        message_str = "\n".join(message_strs)
+
+        real_model = model
+        if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
+            real_model = model.removesuffix('-chat')
+
+        return self._num_tokens_from_string(real_model, credentials, message_str)
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        """
+        Cohere supports fine-tuning of their models. This method returns the schema of the base model
+        but renamed to the fine-tuned model name.
+
+        :param model: model name
+        :param credentials: credentials
+
+        :return: model schema
+        """
+        # get model schema
+        models = self.predefined_models()
+        model_map = {model.model: model for model in models}
+
+        mode = credentials.get('mode')
+
+        if mode == 'chat':
+            base_model_schema = model_map['command-light-chat']
+        else:
+            base_model_schema = model_map['command-light']
+
+        base_model_schema = cast(AIModelEntity, base_model_schema)
+
+        base_model_schema_features = base_model_schema.features or []
+        base_model_schema_model_properties = base_model_schema.model_properties or {}
+        base_model_schema_parameters_rules = base_model_schema.parameter_rules or []
+
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(
+                zh_Hans=model,
+                en_US=model
+            ),
+            model_type=ModelType.LLM,
+            features=[feature for feature in base_model_schema_features],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                key: property for key, property in base_model_schema_model_properties.items()
+            },
+            parameter_rules=[rule for rule in base_model_schema_parameters_rules],
+            pricing=base_model_schema.pricing
+        )
+
+        return entity
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+                cohere.CohereConnectionError
+            ],
+            InvokeServerUnavailableError: [],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [],
+            InvokeBadRequestError: [
+                cohere.CohereAPIError,
+                cohere.CohereError,
+            ]
+        }
diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/__init__.py b/api/core/model_runtime/model_providers/cohere/text_embedding/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/_position.yaml b/api/core/model_runtime/model_providers/cohere/text_embedding/_position.yaml
new file mode 100644
index 0000000000..967a946f34
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/text_embedding/_position.yaml
@@ -0,0 +1,7 @@
+- embed-multilingual-v3.0
+- embed-multilingual-light-v3.0
+- embed-english-v3.0
+- embed-english-light-v3.0
+- embed-multilingual-v2.0
+- embed-english-v2.0
+- embed-english-light-v2.0
diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-light-v2.0.yaml b/api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-light-v2.0.yaml
new file mode 100644
index 0000000000..8d2aaf1737
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-light-v2.0.yaml
@@ -0,0 +1,9 @@
+model: embed-english-light-v2.0
+model_type: text-embedding
+model_properties:
+  context_size: 1024
+  max_chunks: 48
+pricing:
+  input: '0.1'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-light-v3.0.yaml b/api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-light-v3.0.yaml
new file mode 100644
index 0000000000..43b79922e3
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-light-v3.0.yaml
@@ -0,0 +1,9 @@
+model: embed-english-light-v3.0
+model_type: text-embedding
+model_properties:
+  context_size: 384
+  max_chunks: 48
+pricing:
+  input: '0.1'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-v2.0.yaml b/api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-v2.0.yaml
new file mode 100644
index 0000000000..acee82b202
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-v2.0.yaml
@@ -0,0 +1,9 @@
+model: embed-english-v2.0
+model_type: text-embedding
+model_properties:
+  context_size: 4096
+  max_chunks: 48
+pricing:
+  input: '0.1'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-v3.0.yaml b/api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-v3.0.yaml
new file mode 100644
index 0000000000..0ad713253e
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-v3.0.yaml
@@ -0,0 +1,9 @@
+model: embed-english-v3.0
+model_type: text-embedding
+model_properties:
+  context_size: 1024
+  max_chunks: 48
+pricing:
+  input: '0.1'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/embed-multilingual-light-v3.0.yaml b/api/core/model_runtime/model_providers/cohere/text_embedding/embed-multilingual-light-v3.0.yaml
new file mode 100644
index 0000000000..c253067233
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/text_embedding/embed-multilingual-light-v3.0.yaml
@@ -0,0 +1,9 @@
+model: embed-multilingual-light-v3.0
+model_type: text-embedding
+model_properties:
+  context_size: 384
+  max_chunks: 48
+pricing:
+  input: '0.1'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/embed-multilingual-v2.0.yaml b/api/core/model_runtime/model_providers/cohere/text_embedding/embed-multilingual-v2.0.yaml
new file mode 100644
index 0000000000..4dbc37d5e8
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/text_embedding/embed-multilingual-v2.0.yaml
@@ -0,0 +1,9 @@
+model: embed-multilingual-v2.0
+model_type: text-embedding
+model_properties:
+  context_size: 768
+  max_chunks: 48
+pricing:
+  input: '0.1'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/embed-multilingual-v3.0.yaml b/api/core/model_runtime/model_providers/cohere/text_embedding/embed-multilingual-v3.0.yaml
new file mode 100644
index 0000000000..ec689ada1b
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/text_embedding/embed-multilingual-v3.0.yaml
@@ -0,0 +1,9 @@
+model: embed-multilingual-v3.0
+model_type: text-embedding
+model_properties:
+  context_size: 1024
+  max_chunks: 48
+pricing:
+  input: '0.1'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py
new file mode 100644
index 0000000000..d824ed0b3d
--- /dev/null
+++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py
@@ -0,0 +1,234 @@
+import time
+from typing import Optional, Tuple
+
+import cohere
+import numpy as np
+from cohere.responses import Tokens
+
+from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
+    InvokeAuthorizationError, InvokeBadRequestError, InvokeError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+
+
+class CohereTextEmbeddingModel(TextEmbeddingModel):
+    """
+    Model class for Cohere text embedding model.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
+        """
+        Invoke text embedding model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :param user: unique user id
+        :return: embeddings result
+        """
+        # get model properties
+        context_size = self._get_context_size(model, credentials)
+        max_chunks = self._get_max_chunks(model, credentials)
+
+        embeddings: list[list[float]] = [[] for _ in range(len(texts))]
+        tokens = []
+        indices = []
+        used_tokens = 0
+
+        for i, text in enumerate(texts):
+            tokenize_response = self._tokenize(
+                model=model,
+                credentials=credentials,
+                text=text
+            )
+
+            for j in range(0, tokenize_response.length, context_size):
+                tokens += [tokenize_response.token_strings[j: j + context_size]]
+                indices += [i]
+
+        batched_embeddings = []
+        _iter = range(0, len(tokens), max_chunks)
+
+        for i in _iter:
+            # call embedding model
+            embeddings_batch, embedding_used_tokens = self._embedding_invoke(
+                model=model,
+                credentials=credentials,
+                texts=["".join(token) for token in tokens[i: i + max_chunks]]
+            )
+
+            used_tokens += embedding_used_tokens
+            batched_embeddings += embeddings_batch
+
+        results: list[list[list[float]]] = [[] for _ in range(len(texts))]
+        num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
+        for i in range(len(indices)):
+            results[indices[i]].append(batched_embeddings[i])
+            num_tokens_in_batch[indices[i]].append(len(tokens[i]))
+
+        for i in range(len(texts)):
+            _result = results[i]
+            if len(_result) == 0:
+                embeddings_batch, embedding_used_tokens = self._embedding_invoke(
+                    model=model,
+                    credentials=credentials,
+                    texts=[""]
+                )
+
+                used_tokens += embedding_used_tokens
+                average = embeddings_batch[0]
+            else:
+                average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
+            embeddings[i] = (average / np.linalg.norm(average)).tolist()
+
+        # calc usage
+        usage = self._calc_response_usage(
+            model=model,
+            credentials=credentials,
+            tokens=used_tokens
+        )
+
+        return TextEmbeddingResult(
+            embeddings=embeddings,
+            usage=usage,
+            model=model
+        )
+
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return:
+        """
+        if len(texts) == 0:
+            return 0
+
+        full_text = ' '.join(texts)
+
+        try:
+            response = self._tokenize(
+                model=model,
+                credentials=credentials,
+                text=full_text
+            )
+        except Exception as e:
+            raise self._transform_invoke_error(e)
+
+        return response.length
+
+    def _tokenize(self, model: str, credentials: dict, text: str) -> Tokens:
+        """
+        Tokenize text
+        :param model: model name
+        :param credentials: model credentials
+        :param text: text to tokenize
+        :return:
+        """
+        # initialize client
+        client = cohere.Client(credentials.get('api_key'))
+
+        response = client.tokenize(
+            text=text,
+            model=model
+        )
+
+        return response
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            # call embedding model
+            self._embedding_invoke(
+                model=model,
+                credentials=credentials,
+                texts=['ping']
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> Tuple[list[list[float]], int]:
+        """
+        Invoke embedding model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return: embeddings and used tokens
+        """
+        # initialize client
+        client = cohere.Client(credentials.get('api_key'))
+
+        # call embedding model
+        response = client.embed(
+            texts=texts,
+            model=model,
+            input_type='search_document' if len(texts) > 1 else 'search_query'
+        )
+
+        return response.embeddings, response.meta['billed_units']['input_tokens']
+
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param tokens: input tokens
+        :return: usage
+        """
+        # get input price info
+        input_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.INPUT,
+            tokens=tokens
+        )
+
+        # transform usage
+        usage = EmbeddingUsage(
+            tokens=tokens,
+            total_tokens=tokens,
+            unit_price=input_price_info.unit_price,
+            price_unit=input_price_info.unit,
+            total_price=input_price_info.total_amount,
+            currency=input_price_info.currency,
+            latency=time.perf_counter() - self.started_at
+        )
+
+        return usage
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+                cohere.CohereConnectionError
+            ],
+            InvokeServerUnavailableError: [],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [],
+            InvokeBadRequestError: [
+                cohere.CohereAPIError,
+                cohere.CohereError,
+            ]
+        }
diff --git a/api/core/spiltter/fixed_text_splitter.py b/api/core/spiltter/fixed_text_splitter.py
index b8f384eee9..a6895998cf 100644
--- a/api/core/spiltter/fixed_text_splitter.py
+++ b/api/core/spiltter/fixed_text_splitter.py
@@ -24,6 +24,9 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
         **kwargs: Any,
     ):
         def _token_encoder(text: str) -> int:
+            if not text:
+                return 0
+
             if embedding_model_instance:
                 embedding_model_type_instance = embedding_model_instance.model_type_instance
                 embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
diff --git a/api/requirements.txt b/api/requirements.txt
index d5103da2bc..9e62f9bd75 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -54,7 +54,7 @@ zhipuai==1.0.7
 werkzeug==2.3.8
 pymilvus==2.3.0
 qdrant-client==1.6.4
-cohere~=4.32
+cohere~=4.44
 pyyaml~=6.0.1
 numpy~=1.25.2
 unstructured[docx,pptx,msg,md,ppt]~=0.10.27
diff --git a/api/tests/integration_tests/model_runtime/cohere/test_llm.py b/api/tests/integration_tests/model_runtime/cohere/test_llm.py
new file mode 100644
index 0000000000..613675c1fd
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/cohere/test_llm.py
@@ -0,0 +1,272 @@
+import os
+from typing import Generator
+
+import pytest
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (AssistantPromptMessage, SystemPromptMessage,
+                                                          UserPromptMessage)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel
+
+
+def test_validate_credentials_for_chat_model():
+    model = CohereLargeLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='command-light-chat',
+            credentials={
+                'api_key': 'invalid_key'
+            }
+        )
+
+    model.validate_credentials(
+        model='command-light-chat',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        }
+    )
+
+
+def test_validate_credentials_for_completion_model():
+    model = CohereLargeLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='command-light',
+            credentials={
+                'api_key': 'invalid_key'
+            }
+        )
+
+    model.validate_credentials(
+        model='command-light',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        }
+    )
+
+
+def test_invoke_completion_model():
+    model = CohereLargeLanguageModel()
+
+    credentials = {
+        'api_key': os.environ.get('COHERE_API_KEY')
+    }
+
+    result = model.invoke(
+        model='command-light',
+        credentials=credentials,
+        prompt_messages=[
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.0,
+            'max_tokens': 1
+        },
+        stream=False,
+        user="abc-123"
+    )
+
+    assert isinstance(result, LLMResult)
+    assert len(result.message.content) > 0
+    assert model._num_tokens_from_string('command-light', credentials, result.message.content) == 1
+
+
+def test_invoke_stream_completion_model():
+    model = CohereLargeLanguageModel()
+
+    result = model.invoke(
+        model='command-light',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        },
+        prompt_messages=[
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.0,
+            'max_tokens': 100
+        },
+        stream=True,
+        user="abc-123"
+    )
+
+    assert isinstance(result, Generator)
+
+    for chunk in result:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+
+
+def test_invoke_chat_model():
+    model = CohereLargeLanguageModel()
+
+    result = model.invoke(
+        model='command-light-chat',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.0,
+            'p': 0.99,
+            'presence_penalty': 0.0,
+            'frequency_penalty': 0.0,
+            'max_tokens': 10
+        },
+        stop=['How'],
+        stream=False,
+        user="abc-123"
+    )
+
+    assert isinstance(result, LLMResult)
+    assert len(result.message.content) > 0
+
+    for chunk in model._llm_result_to_stream(result):
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+
+
+def test_invoke_stream_chat_model():
+    model = CohereLargeLanguageModel()
+
+    result = model.invoke(
+        model='command-light-chat',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.0,
+            'max_tokens': 100
+        },
+        stream=True,
+        user="abc-123"
+    )
+
+    assert isinstance(result, Generator)
+
+    for chunk in result:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+        if chunk.delta.finish_reason is not None:
+            assert chunk.delta.usage is not None
+            assert chunk.delta.usage.completion_tokens > 0
+
+
+def test_get_num_tokens():
+    model = CohereLargeLanguageModel()
+
+    num_tokens = model.get_num_tokens(
+        model='command-light',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        },
+        prompt_messages=[
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ]
+    )
+
+    assert num_tokens == 3
+
+    num_tokens = model.get_num_tokens(
+        model='command-light-chat',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ]
+    )
+
+    assert num_tokens == 15
+
+
+def test_fine_tuned_model():
+    model = CohereLargeLanguageModel()
+
+    # test invoke
+    result = model.invoke(
+        model='85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY'),
+            'mode': 'completion'
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.0,
+            'max_tokens': 100
+        },
+        stream=False,
+        user="abc-123"
+    )
+
+    assert isinstance(result, LLMResult)
+
+
+def test_fine_tuned_chat_model():
+    model = CohereLargeLanguageModel()
+
+    # test invoke
+    result = model.invoke(
+        model='94f2d55a-4c79-4c00-bde4-23962e74b170-ft',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY'),
+            'mode': 'chat'
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.0,
+            'max_tokens': 100
+        },
+        stream=False,
+        user="abc-123"
+    )
+
+    assert isinstance(result, LLMResult)
diff --git a/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py b/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py
new file mode 100644
index 0000000000..9a15acc260
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py
@@ -0,0 +1,64 @@
+import os
+
+import pytest
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.cohere.text_embedding.text_embedding import CohereTextEmbeddingModel
+
+
+def test_validate_credentials():
+    model = CohereTextEmbeddingModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='embed-multilingual-v3.0',
+            credentials={
+                'api_key': 'invalid_key'
+            }
+        )
+
+    model.validate_credentials(
+        model='embed-multilingual-v3.0',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        }
+    )
+
+
+def test_invoke_model():
+    model = CohereTextEmbeddingModel()
+
+    result = model.invoke(
+        model='embed-multilingual-v3.0',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        },
+        texts=[
+            "hello",
+            "world",
+            " ".join(["long_text"] * 100),
+            " ".join(["another_long_text"] * 100)
+        ],
+        user="abc-123"
+    )
+
+    assert isinstance(result, TextEmbeddingResult)
+    assert len(result.embeddings) == 4
+    assert result.usage.total_tokens == 811
+
+
+def test_get_num_tokens():
+    model = CohereTextEmbeddingModel()
+
+    num_tokens = model.get_num_tokens(
+        model='embed-multilingual-v3.0',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        },
+        texts=[
+            "hello",
+            "world"
+        ]
+    )
+
+    assert num_tokens == 3