diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py
index 19fff3a39b..6c7aba2488 100644
--- a/api/core/model_runtime/model_providers/__base/model_provider.py
+++ b/api/core/model_runtime/model_providers/__base/model_provider.py
@@ -112,7 +112,8 @@ class ModelProvider(ABC):
model_class = None
for name, obj in vars(mod).items():
if (isinstance(obj, type) and issubclass(obj, AIModel) and not obj.__abstractmethods__
- and obj != AIModel):
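+                        # only classes defined in this module itself, not model classes imported from elsewhere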
+ and obj != AIModel and obj.__module__ == mod.__name__):
model_class = obj
break
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py
index e4b78abefa..bf00caabd0 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py
@@ -40,87 +40,4 @@ class _CommonOAI_API_Compat:
requests.exceptions.ConnectTimeout, # Timeout
requests.exceptions.ReadTimeout # Timeout
]
- }
-
- def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
- """
- generate custom model entities from credentials
- """
- model_type = ModelType.LLM if credentials.get('__model_type') == 'llm' else ModelType.TEXT_EMBEDDING
-
- entity = AIModelEntity(
- model=model,
- label=I18nObject(en_US=model),
- model_type=model_type,
- fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
- model_properties={
- ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size', 16000),
- ModelPropertyKey.MAX_CHUNKS: credentials.get('max_chunks', 1),
- },
- parameter_rules=[
- ParameterRule(
- name=DefaultParameterName.TEMPERATURE.value,
- label=I18nObject(en_US="Temperature"),
- type=ParameterType.FLOAT,
- default=float(credentials.get('temperature', 1)),
- min=0,
- max=2
- ),
- ParameterRule(
- name=DefaultParameterName.TOP_P.value,
- label=I18nObject(en_US="Top P"),
- type=ParameterType.FLOAT,
- default=float(credentials.get('top_p', 1)),
- min=0,
- max=1
- ),
- ParameterRule(
- name="top_k",
- label=I18nObject(en_US="Top K"),
- type=ParameterType.INT,
- default=int(credentials.get('top_k', 1)),
- min=1,
- max=100
- ),
- ParameterRule(
- name=DefaultParameterName.FREQUENCY_PENALTY.value,
- label=I18nObject(en_US="Frequency Penalty"),
- type=ParameterType.FLOAT,
- default=float(credentials.get('frequency_penalty', 0)),
- min=-2,
- max=2
- ),
- ParameterRule(
- name=DefaultParameterName.PRESENCE_PENALTY.value,
- label=I18nObject(en_US="PRESENCE Penalty"),
- type=ParameterType.FLOAT,
- default=float(credentials.get('PRESENCE_penalty', 0)),
- min=-2,
- max=2
- ),
- ParameterRule(
- name=DefaultParameterName.MAX_TOKENS.value,
- label=I18nObject(en_US="Max Tokens"),
- type=ParameterType.INT,
- default=1024,
- min=1,
- max=int(credentials.get('max_tokens_to_sample', 4096)),
- )
- ],
- pricing=PriceConfig(
- input=Decimal(credentials.get('input_price', 0)),
- output=Decimal(credentials.get('output_price', 0)),
- unit=Decimal(credentials.get('unit', 0)),
- currency=credentials.get('currency', "USD")
- )
- )
-
- if model_type == ModelType.LLM:
- if credentials['mode'] == 'chat':
- entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
- elif credentials['mode'] == 'completion':
- entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
- else:
- raise ValueError(f"Unknown completion type {credentials['completion_type']}")
-
- return entity
+ }
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
index cf694b940b..338c655110 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
@@ -158,7 +158,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
- ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
+ ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")),
ModelPropertyKey.MODE: credentials.get('mode'),
},
parameter_rules=[
@@ -196,9 +196,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
),
ParameterRule(
name=DefaultParameterName.PRESENCE_PENALTY.value,
- label=I18nObject(en_US="PRESENCE Penalty"),
+ label=I18nObject(en_US="Presence Penalty"),
type=ParameterType.FLOAT,
- default=float(credentials.get('PRESENCE_penalty', 0)),
+ default=float(credentials.get('presence_penalty', 0)),
min=-2,
max=2
),
@@ -219,6 +219,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
)
)
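+        # expose the configured chat/completion mode in model_properties so
+        # downstream callers know which endpoint the model expects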
+ if credentials['mode'] == 'chat':
+ entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
+ elif credentials['mode'] == 'completion':
+ entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
+ else:
+            raise ValueError(f"Unknown completion mode {credentials['mode']}")
+
return entity
# validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
@@ -261,7 +268,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
if completion_type is LLMMode.CHAT:
endpoint_url = urljoin(endpoint_url, 'chat/completions')
data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
- elif completion_type == LLMMode.COMPLETION:
+ elif completion_type is LLMMode.COMPLETION:
endpoint_url = urljoin(endpoint_url, 'completions')
data['prompt'] = prompt_messages[0].content
else:
@@ -291,10 +298,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
stream=stream
)
- # Debug: Print request headers and json data
- logger.debug(f"Request headers: {headers}")
- logger.debug(f"Request JSON data: {data}")
-
if response.status_code != 200:
raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
index b2a4af0057..e5d5f9547e 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
@@ -2,8 +2,8 @@ provider: openai_api_compatible
label:
en_US: OpenAI-API-compatible
description:
- en_US: All model providers compatible with OpenAI's API standard, such as Together.ai.
- zh_Hans: 兼容 OpenAI API 的模型供应商,例如 Together.ai。
+ en_US: Model providers compatible with OpenAI's API standard, such as LM Studio.
+  zh_Hans: 兼容 OpenAI API 的模型供应商,例如 LM Studio。
supported_model_types:
- llm
- text-embedding
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
index d59a30e599..19ec73d109 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
@@ -112,7 +112,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
credentials=credentials,
tokens=used_tokens
)
-
+
return TextEmbeddingResult(
embeddings=batched_embeddings,
usage=usage,
diff --git a/docker/volumes/db/scripts/init_extension.sh b/api/core/model_runtime/model_providers/togetherai/__init__.py
similarity index 100%
rename from docker/volumes/db/scripts/init_extension.sh
rename to api/core/model_runtime/model_providers/togetherai/__init__.py
diff --git a/api/core/model_runtime/model_providers/togetherai/_assets/togetherai.svg b/api/core/model_runtime/model_providers/togetherai/_assets/togetherai.svg
new file mode 100644
index 0000000000..e9d918b15e
--- /dev/null
+++ b/api/core/model_runtime/model_providers/togetherai/_assets/togetherai.svg
@@ -0,0 +1,13 @@
+
diff --git a/api/core/model_runtime/model_providers/togetherai/_assets/togetherai_square.svg b/api/core/model_runtime/model_providers/togetherai/_assets/togetherai_square.svg
new file mode 100644
index 0000000000..16bae5235f
--- /dev/null
+++ b/api/core/model_runtime/model_providers/togetherai/_assets/togetherai_square.svg
@@ -0,0 +1,19 @@
+
diff --git a/api/core/model_runtime/model_providers/togetherai/llm/__init__.py b/api/core/model_runtime/model_providers/togetherai/llm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py
new file mode 100644
index 0000000000..f2c74b808b
--- /dev/null
+++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py
@@ -0,0 +1,49 @@
+from typing import Generator, List, Optional, Union
+from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from core.model_runtime.entities.model_entities import AIModelEntity
+from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
+
+class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
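+    """
+    together.ai serves an OpenAI-compatible API, so this subclass reuses the
+    generic OpenAI-compatible implementation and only pins the endpoint URL.
+    """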
+
+ def _update_endpoint_url(self, credentials: dict):
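+        # pin together.ai's OpenAI-compatible base URL so users don't have to
+        # supply endpoint_url; note this mutates the credentials dict in place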
+ credentials['endpoint_url'] = "https://api.together.xyz/v1"
+ return credentials
+
+ def _invoke(self, model: str, credentials: dict,
+ prompt_messages: list[PromptMessage], model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+ stream: bool = True, user: Optional[str] = None) \
+ -> Union[LLMResult, Generator]:
+ cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+ return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+ return super().validate_credentials(model, cred_with_endpoint)
+
+ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+ stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+ cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+ return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+ cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+ return super().get_customizable_model_schema(model, cred_with_endpoint)
+
+ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ tools: Optional[list[PromptMessageTool]] = None) -> int:
+ cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+ return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools)
diff --git a/api/core/model_runtime/model_providers/togetherai/togetherai.py b/api/core/model_runtime/model_providers/togetherai/togetherai.py
new file mode 100644
index 0000000000..e2ede35d69
--- /dev/null
+++ b/api/core/model_runtime/model_providers/togetherai/togetherai.py
@@ -0,0 +1,15 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class TogetherAIProvider(ModelProvider):
+
+ def validate_provider_credentials(self, credentials: dict) -> None:
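+        # together.ai is exposed via customizable-model entries only, so
+        # credentials are validated per model in the LLM class; nothing to do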
+ pass
diff --git a/api/core/model_runtime/model_providers/togetherai/togetherai.yaml b/api/core/model_runtime/model_providers/togetherai/togetherai.yaml
new file mode 100644
index 0000000000..7213750060
--- /dev/null
+++ b/api/core/model_runtime/model_providers/togetherai/togetherai.yaml
@@ -0,0 +1,75 @@
+provider: togetherai
+label:
+ en_US: together.ai
+icon_small:
+ en_US: togetherai_square.svg
+icon_large:
+ en_US: togetherai.svg
+background: "#F1EFED"
+help:
+ title:
+ en_US: Get your API key from together.ai
+ zh_Hans: 从 together.ai 获取 API Key
+ url:
+ en_US: https://api.together.xyz/
+supported_model_types:
+- llm
+configurate_methods:
+- customizable-model
+model_credential_schema:
+ model:
+ label:
+ en_US: Model Name
+ zh_Hans: 模型名称
+ placeholder:
+ en_US: Enter full model name
+ zh_Hans: 输入模型全称
+ credential_form_schemas:
+ - variable: api_key
+ label:
+ en_US: API Key
+ type: secret-input
+    required: true
+ placeholder:
+ zh_Hans: 在此输入您的 API Key
+ en_US: Enter your API Key
+ - variable: mode
+ show_on:
+ - variable: __model_type
+ value: llm
+ label:
+ en_US: Completion mode
+ type: select
+ required: false
+ default: chat
+ placeholder:
+ zh_Hans: 选择对话类型
+ en_US: Select completion mode
+ options:
+ - value: completion
+ label:
+ en_US: Completion
+ zh_Hans: 补全
+ - value: chat
+ label:
+ en_US: Chat
+ zh_Hans: 对话
+ - variable: context_size
+ label:
+ zh_Hans: 模型上下文长度
+ en_US: Model context size
+ required: true
+ type: text-input
+ default: '4096'
+ placeholder:
+ zh_Hans: 在此输入您的模型上下文长度
+ en_US: Enter your Model context size
+ - variable: max_tokens_to_sample
+ label:
+ zh_Hans: 最大 token 上限
+ en_US: Upper bound for max tokens
+ show_on:
+ - variable: __model_type
+ value: llm
+ default: '4096'
+ type: text-input
diff --git a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py
index b86ee682f1..4007222719 100644
--- a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py
+++ b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py
@@ -39,13 +39,15 @@ def test_invoke_model(setup_openai_mock):
},
texts=[
"hello",
- "world"
+ "world",
+ " ".join(["long_text"] * 100),
+ " ".join(["another_long_text"] * 100)
],
user="abc-123"
)
assert isinstance(result, TextEmbeddingResult)
- assert len(result.embeddings) == 2
+ assert len(result.embeddings) == 4
assert result.usage.total_tokens == 2
diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py
index fbaa322881..88a23c6f99 100644
--- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py
+++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py
@@ -46,14 +46,16 @@ def test_invoke_model():
},
texts=[
"hello",
- "world"
+ "world",
+ " ".join(["long_text"] * 100),
+ " ".join(["another_long_text"] * 100)
],
user="abc-123"
)
assert isinstance(result, TextEmbeddingResult)
- assert len(result.embeddings) == 2
- assert result.usage.total_tokens == 2
+ assert len(result.embeddings) == 4
+ assert result.usage.total_tokens == 502
def test_get_num_tokens():
diff --git a/api/tests/integration_tests/model_runtime/togetherai/__init__.py b/api/tests/integration_tests/model_runtime/togetherai/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/integration_tests/model_runtime/togetherai/test_llm.py b/api/tests/integration_tests/model_runtime/togetherai/test_llm.py
new file mode 100644
index 0000000000..f4aad709c1
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/togetherai/test_llm.py
@@ -0,0 +1,121 @@
+import os
+from typing import Generator
+
+import pytest
+
+from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, \
+ SystemPromptMessage, PromptMessageTool
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
+ LLMResultChunk
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.togetherai.llm.llm import TogetherAILargeLanguageModel
+
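+# these tests call the live together.ai API and require the TOGETHER_API_KEY
+# environment variable to be set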
+
+def test_validate_credentials():
+ model = TogetherAILargeLanguageModel()
+
+ with pytest.raises(CredentialsValidateFailedError):
+ model.validate_credentials(
+ model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+ credentials={
+ 'api_key': 'invalid_key',
+ 'mode': 'chat'
+ }
+ )
+
+ model.validate_credentials(
+ model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+ credentials={
+ 'api_key': os.environ.get('TOGETHER_API_KEY'),
+ 'mode': 'chat'
+ }
+ )
+
+def test_invoke_model():
+ model = TogetherAILargeLanguageModel()
+
+ response = model.invoke(
+ model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+ credentials={
+ 'api_key': os.environ.get('TOGETHER_API_KEY'),
+ 'mode': 'completion'
+ },
+ prompt_messages=[
+ SystemPromptMessage(
+ content='You are a helpful AI assistant.',
+ ),
+ UserPromptMessage(
+ content='Who are you?'
+ )
+ ],
+ model_parameters={
+ 'temperature': 1.0,
+ 'top_k': 2,
+ 'top_p': 0.5,
+ },
+ stop=['How'],
+ stream=False,
+ user="abc-123"
+ )
+
+ assert isinstance(response, LLMResult)
+ assert len(response.message.content) > 0
+
+def test_invoke_stream_model():
+ model = TogetherAILargeLanguageModel()
+
+ response = model.invoke(
+ model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+ credentials={
+ 'api_key': os.environ.get('TOGETHER_API_KEY'),
+ 'mode': 'chat'
+ },
+ prompt_messages=[
+ SystemPromptMessage(
+ content='You are a helpful AI assistant.',
+ ),
+ UserPromptMessage(
+ content='Who are you?'
+ )
+ ],
+ model_parameters={
+ 'temperature': 1.0,
+ 'top_k': 2,
+ 'top_p': 0.5,
+ },
+ stop=['How'],
+ stream=True,
+ user="abc-123"
+ )
+
+ assert isinstance(response, Generator)
+
+ for chunk in response:
+ assert isinstance(chunk, LLMResultChunk)
+ assert isinstance(chunk.delta, LLMResultChunkDelta)
+ assert isinstance(chunk.delta.message, AssistantPromptMessage)
+
+def test_get_num_tokens():
+ model = TogetherAILargeLanguageModel()
+
+ num_tokens = model.get_num_tokens(
+ model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+ credentials={
+ 'api_key': os.environ.get('TOGETHER_API_KEY'),
+ },
+ prompt_messages=[
+ SystemPromptMessage(
+ content='You are a helpful AI assistant.',
+ ),
+ UserPromptMessage(
+ content='Hello World!'
+ )
+ ]
+ )
+
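+    # 21 is what the tokenizer approximation in the OpenAI-compatible base
+    # class currently yields for these messages; it may shift with the tokenizer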
+ assert isinstance(num_tokens, int)
+ assert num_tokens == 21