diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py
index 19fff3a39b..6c7aba2488 100644
--- a/api/core/model_runtime/model_providers/__base/model_provider.py
+++ b/api/core/model_runtime/model_providers/__base/model_provider.py
@@ -112,7 +112,7 @@ class ModelProvider(ABC):
         model_class = None
         for name, obj in vars(mod).items():
             if (isinstance(obj, type) and issubclass(obj, AIModel) and not obj.__abstractmethods__
-                    and obj != AIModel):
+                    and obj != AIModel and obj.__module__ == mod.__name__):
                 model_class = obj
                 break
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py
index e4b78abefa..bf00caabd0 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py
@@ -40,87 +40,4 @@ class _CommonOAI_API_Compat:
                 requests.exceptions.ConnectTimeout,  # Timeout
                 requests.exceptions.ReadTimeout  # Timeout
             ]
-        }
-
-    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
-        """
-        generate custom model entities from credentials
-        """
-        model_type = ModelType.LLM if credentials.get('__model_type') == 'llm' else ModelType.TEXT_EMBEDDING
-
-        entity = AIModelEntity(
-            model=model,
-            label=I18nObject(en_US=model),
-            model_type=model_type,
-            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
-            model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size', 16000),
-                ModelPropertyKey.MAX_CHUNKS: credentials.get('max_chunks', 1),
-            },
-            parameter_rules=[
-                ParameterRule(
-                    name=DefaultParameterName.TEMPERATURE.value,
-                    label=I18nObject(en_US="Temperature"),
-                    type=ParameterType.FLOAT,
-                    default=float(credentials.get('temperature', 1)),
-                    min=0,
-                    max=2
-                ),
-                ParameterRule(
-                    name=DefaultParameterName.TOP_P.value,
-                    label=I18nObject(en_US="Top P"),
-                    type=ParameterType.FLOAT,
-                    default=float(credentials.get('top_p', 1)),
-                    min=0,
-                    max=1
-                ),
-                ParameterRule(
-                    name="top_k",
-                    label=I18nObject(en_US="Top K"),
-                    type=ParameterType.INT,
-                    default=int(credentials.get('top_k', 1)),
-                    min=1,
-                    max=100
-                ),
-                ParameterRule(
-                    name=DefaultParameterName.FREQUENCY_PENALTY.value,
-                    label=I18nObject(en_US="Frequency Penalty"),
-                    type=ParameterType.FLOAT,
-                    default=float(credentials.get('frequency_penalty', 0)),
-                    min=-2,
-                    max=2
-                ),
-                ParameterRule(
-                    name=DefaultParameterName.PRESENCE_PENALTY.value,
-                    label=I18nObject(en_US="PRESENCE Penalty"),
-                    type=ParameterType.FLOAT,
-                    default=float(credentials.get('PRESENCE_penalty', 0)),
-                    min=-2,
-                    max=2
-                ),
-                ParameterRule(
-                    name=DefaultParameterName.MAX_TOKENS.value,
-                    label=I18nObject(en_US="Max Tokens"),
-                    type=ParameterType.INT,
-                    default=1024,
-                    min=1,
-                    max=int(credentials.get('max_tokens_to_sample', 4096)),
-                )
-            ],
-            pricing=PriceConfig(
-                input=Decimal(credentials.get('input_price', 0)),
-                output=Decimal(credentials.get('output_price', 0)),
-                unit=Decimal(credentials.get('unit', 0)),
-                currency=credentials.get('currency', "USD")
-            )
-        )
-
-        if model_type == ModelType.LLM:
-            if credentials['mode'] == 'chat':
-                entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
-            elif credentials['mode'] == 'completion':
-                entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
-            else:
-                raise ValueError(f"Unknown completion type {credentials['completion_type']}")
-
-        return entity
+        }
\ No newline at end of file
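Note on the `__module__` guard added in `model_provider.py` above: `vars(mod)` also yields names a module merely imports, so without the guard the loop can match a base class re-exported from another module instead of the provider's own subclass. A minimal standalone sketch of the fixed lookup, with a hypothetical module path standing in for the real provider modules:

```python
# Minimal sketch of the fixed class discovery. "my_pkg.models" and
# find_concrete_model_class are illustrative, not part of this PR;
# `base` stands in for the AIModel abstract base.
import importlib

def find_concrete_model_class(module_name: str, base: type):
    mod = importlib.import_module(module_name)
    for _name, obj in vars(mod).items():
        if (isinstance(obj, type) and issubclass(obj, base)
                and not getattr(obj, '__abstractmethods__', None)  # concrete only
                and obj is not base
                and obj.__module__ == mod.__name__):  # skip imported/re-exported classes
            return obj
    return None
```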
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
index cf694b940b..338c655110 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
@@ -158,7 +158,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             model_type=ModelType.LLM,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")),
                 ModelPropertyKey.MODE: credentials.get('mode'),
             },
             parameter_rules=[
@@ -196,9 +196,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 ),
                 ParameterRule(
                     name=DefaultParameterName.PRESENCE_PENALTY.value,
-                    label=I18nObject(en_US="PRESENCE Penalty"),
+                    label=I18nObject(en_US="Presence Penalty"),
                     type=ParameterType.FLOAT,
-                    default=float(credentials.get('PRESENCE_penalty', 0)),
+                    default=float(credentials.get('presence_penalty', 0)),
                     min=-2,
                     max=2
                 ),
@@ -219,6 +219,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             )
         )
 
+        if credentials['mode'] == 'chat':
+            entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
+        elif credentials['mode'] == 'completion':
+            entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
+        else:
+            raise ValueError(f"Unknown completion type {credentials['mode']}")
+
         return entity
 
     # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
@@ -261,7 +268,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         if completion_type is LLMMode.CHAT:
             endpoint_url = urljoin(endpoint_url, 'chat/completions')
             data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
-        elif completion_type == LLMMode.COMPLETION:
+        elif completion_type is LLMMode.COMPLETION:
             endpoint_url = urljoin(endpoint_url, 'completions')
             data['prompt'] = prompt_messages[0].content
         else:
@@ -291,10 +298,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             stream=stream
         )
 
-        # Debug: Print request headers and json data
-        logger.debug(f"Request headers: {headers}")
-        logger.debug(f"Request JSON data: {data}")
-
        if response.status_code != 200:
             raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
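Both completion-type branches above now compare with `is`, the idiomatic check for enum members: each member is a singleton, so identity and equality agree and `is` cannot be confused by a custom `__eq__`. A short self-contained illustration (the `LLMMode` stub and `resolve_mode` helper below only mirror the runtime's code, they are not part of it):

```python
from enum import Enum

class LLMMode(Enum):  # stand-in for the runtime's enum
    CHAT = "chat"
    COMPLETION = "completion"

def resolve_mode(credentials: dict) -> str:
    # same dispatch shape as the block added to get_customizable_model_schema
    if credentials['mode'] == 'chat':
        return LLMMode.CHAT.value
    elif credentials['mode'] == 'completion':
        return LLMMode.COMPLETION.value
    raise ValueError(f"Unknown completion type {credentials['mode']}")

mode = LLMMode(resolve_mode({'mode': 'chat'}))
assert mode is LLMMode.CHAT and mode == LLMMode.CHAT  # identity and equality agree
```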
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
index b2a4af0057..e5d5f9547e 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
@@ -2,8 +2,8 @@ provider: openai_api_compatible
 label:
   en_US: OpenAI-API-compatible
 description:
-  en_US: All model providers compatible with OpenAI's API standard, such as Together.ai.
-  zh_Hans: 兼容 OpenAI API 的模型供应商,例如 Together.ai。
+  en_US: Model providers compatible with OpenAI's API standard, such as LM Studio.
+  zh_Hans: 兼容 OpenAI API 的模型供应商,例如 LM Studio。
 supported_model_types:
 - llm
 - text-embedding
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
index d59a30e599..19ec73d109 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
@@ -112,7 +112,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
             credentials=credentials,
             tokens=used_tokens
         )
-        
+
         return TextEmbeddingResult(
             embeddings=batched_embeddings,
             usage=usage,
diff --git a/docker/volumes/db/scripts/init_extension.sh b/api/core/model_runtime/model_providers/togetherai/__init__.py
similarity index 100%
rename from docker/volumes/db/scripts/init_extension.sh
rename to api/core/model_runtime/model_providers/togetherai/__init__.py
diff --git a/api/core/model_runtime/model_providers/togetherai/_assets/togetherai.svg b/api/core/model_runtime/model_providers/togetherai/_assets/togetherai.svg
new file mode 100644
index 0000000000..e9d918b15e
--- /dev/null
+++ b/api/core/model_runtime/model_providers/togetherai/_assets/togetherai.svg
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/api/core/model_runtime/model_providers/togetherai/_assets/togetherai_square.svg b/api/core/model_runtime/model_providers/togetherai/_assets/togetherai_square.svg
new file mode 100644
index 0000000000..16bae5235f
--- /dev/null
+++ b/api/core/model_runtime/model_providers/togetherai/_assets/togetherai_square.svg
@@ -0,0 +1,19 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/api/core/model_runtime/model_providers/togetherai/llm/__init__.py b/api/core/model_runtime/model_providers/togetherai/llm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py
new file mode 100644
index 0000000000..f2c74b808b
--- /dev/null
+++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py
@@ -0,0 +1,45 @@
+from typing import Generator, List, Optional, Union
+from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from core.model_runtime.entities.model_entities import AIModelEntity
+from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
+
+
+class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
+
+    def _update_endpoint_url(self, credentials: dict):
+        credentials['endpoint_url'] = "https://api.together.xyz/v1"
+        return credentials
+
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                stream: bool = True, user: Optional[str] = None) \
+            -> Union[LLMResult, Generator]:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super().validate_credentials(model, cred_with_endpoint)
+
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                  stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super().get_customizable_model_schema(model, cred_with_endpoint)
+
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools)
+
+
diff --git a/api/core/model_runtime/model_providers/togetherai/togetherai.py b/api/core/model_runtime/model_providers/togetherai/togetherai.py
new file mode 100644
index 0000000000..e2ede35d69
--- /dev/null
+++ b/api/core/model_runtime/model_providers/togetherai/togetherai.py
@@ -0,0 +1,13 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class TogetherAIProvider(ModelProvider):
+
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        pass
diff --git a/api/core/model_runtime/model_providers/togetherai/togetherai.yaml b/api/core/model_runtime/model_providers/togetherai/togetherai.yaml
new file mode 100644
index 0000000000..7213750060
--- /dev/null
+++ b/api/core/model_runtime/model_providers/togetherai/togetherai.yaml
@@ -0,0 +1,75 @@
+provider: togetherai
+label:
+  en_US: together.ai
+icon_small:
+  en_US: togetherai_square.svg
+icon_large:
+  en_US: togetherai.svg
+background: "#F1EFED"
+help:
+  title:
+    en_US: Get your API key from together.ai
+    zh_Hans: 从 together.ai 获取 API Key
+  url:
+    en_US: https://api.together.xyz/
+supported_model_types:
+- llm
+configurate_methods:
+- customizable-model
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter full model name
+      zh_Hans: 输入模型全称
+  credential_form_schemas:
+  - variable: api_key
+    label:
+      en_US: API Key
+    type: secret-input
+    required: false
+    placeholder:
+      zh_Hans: 在此输入您的 API Key
+      en_US: Enter your API Key
+  - variable: mode
+    show_on:
+    - variable: __model_type
+      value: llm
+    label:
+      en_US: Completion mode
+    type: select
+    required: false
+    default: chat
+    placeholder:
+      zh_Hans: 选择对话类型
+      en_US: Select completion mode
+    options:
+    - value: completion
+      label:
+        en_US: Completion
+        zh_Hans: 补全
+    - value: chat
+      label:
+        en_US: Chat
+        zh_Hans: 对话
+  - variable: context_size
+    label:
+      zh_Hans: 模型上下文长度
+      en_US: Model context size
+    required: true
+    type: text-input
+    default: '4096'
+    placeholder:
+      zh_Hans: 在此输入您的模型上下文长度
+      en_US: Enter your Model context size
+  - variable: max_tokens_to_sample
+    label:
+      zh_Hans: 最大 token 上限
+      en_US: Upper bound for max tokens
+    show_on:
+    - variable: __model_type
+      value: llm
+    default: '4096'
+    type: text-input
\ No newline at end of file
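`TogetherAILargeLanguageModel` above reuses the OpenAI-compatible implementation wholesale by pinning `endpoint_url` into the credentials before delegating every call to the superclass. The same wrapper pattern fits any host that speaks the OpenAI wire format; a sketch with a hypothetical provider (`ExampleHostedLLM` and its base URL are illustrative, not part of this PR):

```python
# Hedged sketch of the endpoint-override pattern used by the together.ai
# provider. Note the dict copy: the PR's _update_endpoint_url assigns into
# the caller's credentials in place, while copying keeps the caller's dict
# untouched across repeated calls.
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import (
    OAIAPICompatLargeLanguageModel,
)


class ExampleHostedLLM(OAIAPICompatLargeLanguageModel):
    BASE_URL = "https://llm.example.com/v1"  # hypothetical endpoint

    def _update_endpoint_url(self, credentials: dict) -> dict:
        creds = dict(credentials)  # avoid mutating the caller's dict
        creds['endpoint_url'] = self.BASE_URL
        return creds
```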
diff --git a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py
index b86ee682f1..4007222719 100644
--- a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py
+++ b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py
@@ -39,13 +39,15 @@ def test_invoke_model(setup_openai_mock):
         },
         texts=[
             "hello",
-            "world"
+            "world",
+            " ".join(["long_text"] * 100),
+            " ".join(["another_long_text"] * 100)
         ],
         user="abc-123"
     )
 
     assert isinstance(result, TextEmbeddingResult)
-    assert len(result.embeddings) == 2
+    assert len(result.embeddings) == 4
     assert result.usage.total_tokens == 2
diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py
index fbaa322881..88a23c6f99 100644
--- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py
+++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py
@@ -46,14 +46,16 @@ def test_invoke_model():
         },
         texts=[
             "hello",
-            "world"
+            "world",
+            " ".join(["long_text"] * 100),
+            " ".join(["another_long_text"] * 100)
         ],
         user="abc-123"
     )
 
     assert isinstance(result, TextEmbeddingResult)
-    assert len(result.embeddings) == 2
-    assert result.usage.total_tokens == 2
+    assert len(result.embeddings) == 4
+    assert result.usage.total_tokens == 502
 
 
 def test_get_num_tokens():
diff --git a/api/tests/integration_tests/model_runtime/togetherai/__init__.py b/api/tests/integration_tests/model_runtime/togetherai/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
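The two ~100-word inputs above are long enough to exercise the embedding model's batched request path, which is why the live openai_api_compatible test now expects four embeddings and a real token total (the mocked openai test keeps its stubbed count of 2). A rough sketch of the chunking idea, not the class under test:

```python
# Minimal sketch of max_chunks-style batching: split the input texts into
# fixed-size chunks, one request per chunk, and sum the reported usage.
# batch_texts is a hypothetical helper, not the implementation under test.
from typing import List

def batch_texts(texts: List[str], max_chunks: int) -> List[List[str]]:
    return [texts[i:i + max_chunks] for i in range(0, len(texts), max_chunks)]

batches = batch_texts(["hello", "world", "long " * 100, "other " * 100], max_chunks=2)
assert len(batches) == 2  # two requests of two texts each
```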
+ ) + ], + model_parameters={ + 'temperature': 1.0, + 'top_k': 2, + 'top_p': 0.5, + }, + stop=['How'], + stream=False, + user="abc-123" + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + +def test_invoke_stream_model(): + model = TogetherAILargeLanguageModel() + + response = model.invoke( + model='mistralai/Mixtral-8x7B-Instruct-v0.1', + credentials={ + 'api_key': os.environ.get('TOGETHER_API_KEY'), + 'mode': 'chat' + }, + prompt_messages=[ + SystemPromptMessage( + content='You are a helpful AI assistant.', + ), + UserPromptMessage( + content='Who are you?' + ) + ], + model_parameters={ + 'temperature': 1.0, + 'top_k': 2, + 'top_p': 0.5, + }, + stop=['How'], + stream=True, + user="abc-123" + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + +def test_get_num_tokens(): + model = TogetherAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model='mistralai/Mixtral-8x7B-Instruct-v0.1', + credentials={ + 'api_key': os.environ.get('TOGETHER_API_KEY'), + }, + prompt_messages=[ + SystemPromptMessage( + content='You are a helpful AI assistant.', + ), + UserPromptMessage( + content='Hello World!' + ) + ] + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 21