diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml index 7b4416f44e..c06f122984 100644 --- a/api/core/model_runtime/model_providers/_position.yaml +++ b/api/core/model_runtime/model_providers/_position.yaml @@ -6,6 +6,7 @@ - cohere - bedrock - togetherai +- openrouter - ollama - mistralai - groq diff --git a/api/core/model_runtime/model_providers/openrouter/__init__.py b/api/core/model_runtime/model_providers/openrouter/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/openrouter/_assets/openrouter.svg b/api/core/model_runtime/model_providers/openrouter/_assets/openrouter.svg new file mode 100644 index 0000000000..2e9590d923 --- /dev/null +++ b/api/core/model_runtime/model_providers/openrouter/_assets/openrouter.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/api/core/model_runtime/model_providers/openrouter/_assets/openrouter_square.svg b/api/core/model_runtime/model_providers/openrouter/_assets/openrouter_square.svg new file mode 100644 index 0000000000..ed81fc041f --- /dev/null +++ b/api/core/model_runtime/model_providers/openrouter/_assets/openrouter_square.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/api/core/model_runtime/model_providers/openrouter/llm/__init__.py b/api/core/model_runtime/model_providers/openrouter/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llm.py b/api/core/model_runtime/model_providers/openrouter/llm/llm.py new file mode 100644 index 0000000000..bb62fc7bb2 --- /dev/null +++ b/api/core/model_runtime/model_providers/openrouter/llm/llm.py @@ -0,0 +1,46 @@ +from collections.abc import Generator +from typing import Optional, Union + +from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from 
core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + + +class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel): + + def _update_endpoint_url(self, credentials: dict): + credentials['endpoint_url'] = "https://openrouter.ai/api/v1" + return credentials + + def _invoke(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None) \ + -> Union[LLMResult, Generator]: + cred_with_endpoint = self._update_endpoint_url(credentials=credentials) + + return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + + def validate_credentials(self, model: str, credentials: dict) -> None: + cred_with_endpoint = self._update_endpoint_url(credentials=credentials) + + return super().validate_credentials(model, cred_with_endpoint) + + def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + cred_with_endpoint = self._update_endpoint_url(credentials=credentials) + + return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + cred_with_endpoint = self._update_endpoint_url(credentials=credentials) + + return super().get_customizable_model_schema(model, cred_with_endpoint) + + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None) -> int: + cred_with_endpoint = 
self._update_endpoint_url(credentials=credentials) + + return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools) diff --git a/api/core/model_runtime/model_providers/openrouter/openrouter.py b/api/core/model_runtime/model_providers/openrouter/openrouter.py new file mode 100644 index 0000000000..81313fd29a --- /dev/null +++ b/api/core/model_runtime/model_providers/openrouter/openrouter.py @@ -0,0 +1,11 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class OpenRouterProvider(ModelProvider): + + def validate_provider_credentials(self, credentials: dict) -> None: + pass diff --git a/api/core/model_runtime/model_providers/openrouter/openrouter.yaml b/api/core/model_runtime/model_providers/openrouter/openrouter.yaml new file mode 100644 index 0000000000..48a70700bb --- /dev/null +++ b/api/core/model_runtime/model_providers/openrouter/openrouter.yaml @@ -0,0 +1,75 @@ +provider: openrouter +label: + en_US: openrouter.ai +icon_small: + en_US: openrouter_square.svg +icon_large: + en_US: openrouter.svg +background: "#F1EFED" +help: + title: + en_US: Get your API key from openrouter.ai + zh_Hans: 从 openrouter.ai 获取 API Key + url: + en_US: https://openrouter.ai/keys +supported_model_types: + - llm +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter full model name + zh_Hans: 输入模型全称 + credential_form_schemas: + - variable: api_key + required: true + label: + en_US: API Key + type: secret-input + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: mode + show_on: + - variable: __model_type + value: llm + label: + en_US: Completion mode + type: select + required: false + default: chat + placeholder: + zh_Hans: 选择对话类型 + en_US: Select completion mode + options: + - value: completion + label: + en_US: Completion + zh_Hans: 补全 + - 
value: chat + label: + en_US: Chat + zh_Hans: 对话 + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: "4096" + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: max_tokens_to_sample + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + show_on: + - variable: __model_type + value: llm + default: "4096" + type: text-input diff --git a/api/tests/integration_tests/model_runtime/openrouter/__init__.py b/api/tests/integration_tests/model_runtime/openrouter/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/openrouter/test_llm.py b/api/tests/integration_tests/model_runtime/openrouter/test_llm.py new file mode 100644 index 0000000000..c0164e6418 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/openrouter/test_llm.py @@ -0,0 +1,118 @@ +import os +from typing import Generator + +import pytest +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool, + SystemPromptMessage, UserPromptMessage) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.openrouter.llm.llm import OpenRouterLargeLanguageModel + + +def test_validate_credentials(): + model = OpenRouterLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='mistralai/mixtral-8x7b-instruct', + credentials={ + 'api_key': 'invalid_key', + 'mode': 'chat' + } + ) + + model.validate_credentials( + model='mistralai/mixtral-8x7b-instruct', + credentials={ + 'api_key': os.environ.get('OPENROUTER_API_KEY'), + 'mode': 'chat' + } + ) + + +def test_invoke_model(): + model = OpenRouterLargeLanguageModel() + + response = model.invoke( + 
model='mistralai/mixtral-8x7b-instruct', + credentials={ + 'api_key': os.environ.get('OPENROUTER_API_KEY'), + 'mode': 'completion' + }, + prompt_messages=[ + SystemPromptMessage( + content='You are a helpful AI assistant.', + ), + UserPromptMessage( + content='Who are you?' + ) + ], + model_parameters={ + 'temperature': 1.0, + 'top_k': 2, + 'top_p': 0.5, + }, + stop=['How'], + stream=False, + user="abc-123" + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = OpenRouterLargeLanguageModel() + + response = model.invoke( + model='mistralai/mixtral-8x7b-instruct', + credentials={ + 'api_key': os.environ.get('OPENROUTER_API_KEY'), + 'mode': 'chat' + }, + prompt_messages=[ + SystemPromptMessage( + content='You are a helpful AI assistant.', + ), + UserPromptMessage( + content='Who are you?' + ) + ], + model_parameters={ + 'temperature': 1.0, + 'top_k': 2, + 'top_p': 0.5, + }, + stop=['How'], + stream=True, + user="abc-123" + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + + +def test_get_num_tokens(): + model = OpenRouterLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model='mistralai/mixtral-8x7b-instruct', + credentials={ + 'api_key': os.environ.get('OPENROUTER_API_KEY'), + }, + prompt_messages=[ + SystemPromptMessage( + content='You are a helpful AI assistant.', + ), + UserPromptMessage( + content='Hello World!' + ) + ] + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 21