feat(new llm): add support for openrouter (#3042)

This commit is contained in:
Salem Korayem 2024-04-02 13:38:46 +03:00 committed by GitHub
parent e12a0c154c
commit 6b4c8e76e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 272 additions and 0 deletions

View File

@ -6,6 +6,7 @@
- cohere
- bedrock
- togetherai
- openrouter
- ollama
- mistralai
- groq

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 13 KiB

View File

@ -0,0 +1,10 @@
<svg width="25" height="21" viewBox="0 0 25 21" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M1.05858 10.1738C1.76158 10.1738 4.47988 9.56715 5.88589 8.77041C7.2919 7.97367 7.2919 7.97367 10.1977 5.91152C13.8766 3.30069 16.4779 4.17486 20.7428 4.17486" fill="black"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M11.4182 7.63145L11.3787 7.65951C8.50565 9.69845 8.42504 9.75566 6.92566 10.6053C5.98567 11.138 4.74704 11.5436 3.75151 11.8089C2.80313 12.0615 1.71203 12.2829 1.05858 12.2829V8.06483C1.05075 8.06483 1.05422 8.06445 1.06984 8.06276C1.11491 8.05788 1.26116 8.04203 1.52896 7.9926C1.84599 7.9341 2.24205 7.84582 2.6657 7.73296C3.55657 7.49564 4.3801 7.1996 4.84612 6.93552C4.88175 6.91533 4.91635 6.89573 4.95001 6.87666C6.15007 6.19693 6.15657 6.19325 8.97708 4.1916C12.5199 1.67735 15.5815 1.83587 18.5849 1.99138C19.3056 2.0287 20.0229 2.06584 20.7428 2.06584V6.28388C19.6102 6.28388 18.6583 6.24193 17.8263 6.20527C15.1245 6.08621 13.685 6.02278 11.4182 7.63145Z" fill="black"/>
<path d="M24.8671 4.20087L17.6613 8.36117V0.0405881L24.8671 4.20087Z" fill="black"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M17.6378 0L24.9139 4.20087L17.6378 8.40176V0ZM17.6847 0.0811762V8.32058L24.8202 4.20087L17.6847 0.0811762Z" fill="black"/>
<path d="M0.917975 10.1764C1.62098 10.1764 4.33927 10.7831 5.74529 11.5799C7.1513 12.3766 7.1513 12.3766 10.0571 14.4388C13.736 17.0496 16.3373 16.1754 20.6022 16.1754" fill="black"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M0.929234 12.2875C0.913615 12.2858 0.910145 12.2854 0.917975 12.2854V8.06741C1.57142 8.06741 2.66253 8.28878 3.61091 8.54142C4.60644 8.80663 5.84507 9.21231 6.78506 9.74497C8.28444 10.5946 8.36505 10.6518 11.2381 12.6908L11.2776 12.7188C13.5444 14.3275 14.9839 14.2641 17.6857 14.145C18.5177 14.1083 19.4696 14.0664 20.6022 14.0664V18.2844C19.8823 18.2844 19.165 18.3216 18.4443 18.3589C15.4409 18.5144 12.3793 18.6729 8.83648 16.1587C6.01597 14.157 6.00947 14.1533 4.80941 13.4736C4.77575 13.4545 4.74115 13.4349 4.70551 13.4148C4.2395 13.1507 3.41597 12.8546 2.5251 12.6173C2.10145 12.5045 1.70538 12.4162 1.38836 12.3577C1.12056 12.3083 0.974309 12.2924 0.929234 12.2875Z" fill="black"/>
<path d="M24.7265 16.1494L17.5207 11.9892V20.3097L24.7265 16.1494Z" fill="black"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M17.4972 11.9486L24.7733 16.1494L17.4972 20.3503V11.9486ZM17.5441 12.0297V20.2691L24.6796 16.1494L17.5441 12.0297Z" fill="black"/>
</svg>

After

Width:  |  Height:  |  Size: 2.4 KiB

View File

@ -0,0 +1,46 @@
from collections.abc import Generator
from typing import Optional, Union
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
    """LLM implementation for openrouter.ai.

    OpenRouter exposes an OpenAI-compatible API, so every operation is
    delegated to OAIAPICompatLargeLanguageModel with the endpoint URL
    forced to the OpenRouter base URL.
    """

    def _update_endpoint_url(self, credentials: dict) -> dict:
        """Return credentials with ``endpoint_url`` pinned to the OpenRouter API base.

        Works on a shallow copy: the original implementation wrote into the
        caller's dict in place, leaking the injected ``endpoint_url`` back
        into any credentials object shared with other code.
        """
        cred_with_endpoint = dict(credentials)
        cred_with_endpoint['endpoint_url'] = "https://openrouter.ai/api/v1"
        return cred_with_endpoint

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """Invoke the model through the OpenAI-compatible base with the OpenRouter endpoint."""
        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
        return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """Validate the API key by delegating to the OpenAI-compatible base class."""
        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
        return super().validate_credentials(model, cred_with_endpoint)

    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                  stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
        """Generate a completion (blocking or streaming) against the OpenRouter endpoint."""
        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
        return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """Build the customizable-model schema with the OpenRouter endpoint applied."""
        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
        return super().get_customizable_model_schema(model, cred_with_endpoint)

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """Count prompt tokens via the OpenAI-compatible base implementation."""
        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
        return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools)

View File

@ -0,0 +1,11 @@
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class OpenRouterProvider(ModelProvider):
    """Provider entry point for openrouter.ai."""

    def validate_provider_credentials(self, credentials: dict) -> None:
        """Validate provider-level credentials.

        Intentionally a no-op: openrouter is configured exclusively via
        ``customizable-model`` (see the provider YAML), so the API key is
        validated per model rather than at the provider level.
        """
        pass

View File

@ -0,0 +1,75 @@
# Provider manifest for openrouter.ai.
# OpenRouter proxies many upstream models behind one OpenAI-compatible API,
# so models are user-defined (customizable-model) rather than predefined.
provider: openrouter
label:
  en_US: openrouter.ai
icon_small:
  en_US: openrouter_square.svg
icon_large:
  en_US: openrouter.svg
# Brand background color shown behind the provider card in the console.
background: "#F1EFED"
help:
  title:
    en_US: Get your API key from openrouter.ai
    zh_Hans: 从 openrouter.ai 获取 API Key
  url:
    en_US: https://openrouter.ai/keys
supported_model_types:
  - llm
configurate_methods:
  # No predefined model list: the user enters the full model name below.
  - customizable-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter full model name
      zh_Hans: 输入模型全称
  credential_form_schemas:
    - variable: api_key
      required: true
      label:
        en_US: API Key
      type: secret-input
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
    # Whether the model is driven through the chat or completion endpoint.
    - variable: mode
      show_on:
        - variable: __model_type
          value: llm
      label:
        en_US: Completion mode
      type: select
      required: false
      default: chat
      placeholder:
        zh_Hans: 选择对话类型
        en_US: Select completion mode
      options:
        - value: completion
          label:
            en_US: Completion
            zh_Hans: 补全
        - value: chat
          label:
            en_US: Chat
            zh_Hans: 对话
    - variable: context_size
      label:
        zh_Hans: 模型上下文长度
        en_US: Model context size
      required: true
      type: text-input
      default: "4096"
      placeholder:
        zh_Hans: 在此输入您的模型上下文长度
        en_US: Enter your Model context size
    - variable: max_tokens_to_sample
      label:
        zh_Hans: 最大 token 上限
        en_US: Upper bound for max tokens
      show_on:
        - variable: __model_type
          value: llm
      default: "4096"
      type: text-input

View File

@ -0,0 +1,118 @@
import os
from typing import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool,
SystemPromptMessage, UserPromptMessage)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.openrouter.llm.llm import OpenRouterLargeLanguageModel
def test_validate_credentials():
    """Credential validation must reject a bogus key and accept a real one."""
    model = OpenRouterLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='mistralai/mixtral-8x7b-instruct',
            credentials={
                'api_key': 'invalid_key',
                'mode': 'chat'
            }
        )

    model.validate_credentials(
        model='mistralai/mixtral-8x7b-instruct',
        credentials={
            # Fixed: read the OpenRouter key; the original read
            # TOGETHER_API_KEY, copied from the togetherai test suite.
            'api_key': os.environ.get('OPENROUTER_API_KEY'),
            'mode': 'chat'
        }
    )
def test_invoke_model():
    """Non-streaming invoke returns a populated LLMResult."""
    model = OpenRouterLargeLanguageModel()

    response = model.invoke(
        model='mistralai/mixtral-8x7b-instruct',
        credentials={
            # Fixed: read the OpenRouter key; the original read
            # TOGETHER_API_KEY, copied from the togetherai test suite.
            'api_key': os.environ.get('OPENROUTER_API_KEY'),
            'mode': 'completion'
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
        },
        stop=['How'],
        stream=False,
        user="abc-123"
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
def test_invoke_stream_model():
    """Streaming invoke yields well-formed LLMResultChunk objects."""
    model = OpenRouterLargeLanguageModel()

    response = model.invoke(
        model='mistralai/mixtral-8x7b-instruct',
        credentials={
            # Fixed: read the OpenRouter key; the original read
            # TOGETHER_API_KEY, copied from the togetherai test suite.
            'api_key': os.environ.get('OPENROUTER_API_KEY'),
            'mode': 'chat'
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
        },
        stop=['How'],
        stream=True,
        user="abc-123"
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
def test_get_num_tokens():
    """Token counting returns the expected count for a fixed prompt."""
    model = OpenRouterLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model='mistralai/mixtral-8x7b-instruct',
        credentials={
            # Fixed: read the OpenRouter key; the original read
            # TOGETHER_API_KEY, copied from the togetherai test suite.
            'api_key': os.environ.get('OPENROUTER_API_KEY'),
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ]
    )

    assert isinstance(num_tokens, int)
    assert num_tokens == 21