diff --git a/api/constants/model_template.py b/api/constants/model_template.py
index 3b8fa3fb55..c35a0b38d6 100644
--- a/api/constants/model_template.py
+++ b/api/constants/model_template.py
@@ -31,6 +31,7 @@ model_templates = {
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
+ "mode": "completion",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
@@ -81,6 +82,7 @@ model_templates = {
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
+ "mode": "chat",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
@@ -137,10 +139,11 @@ demo_model_templates = {
},
opening_statement='',
suggested_questions=None,
- pre_prompt="Please translate the following text into {{target_language}}:\n",
+ pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
+ "mode": "completion",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
@@ -169,6 +172,13 @@ demo_model_templates = {
'Italian',
]
}
+            }, {
+ "paragraph": {
+ "label": "Query",
+ "variable": "query",
+ "required": True,
+ "default": ""
+ }
}
])
)
@@ -200,6 +210,7 @@ demo_model_templates = {
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
+ "mode": "chat",
"completion_params": {
"max_tokens": 300,
"temperature": 0.8,
@@ -255,10 +266,11 @@ demo_model_templates = {
},
opening_statement='',
suggested_questions=None,
- pre_prompt="请将以下文本翻译为{{target_language}}:\n",
+ pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
+ "mode": "completion",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
@@ -287,6 +299,13 @@ demo_model_templates = {
"意大利语",
]
}
+            }, {
+ "paragraph": {
+ "label": "文本内容",
+ "variable": "query",
+ "required": True,
+ "default": ""
+ }
}
])
)
@@ -318,6 +337,7 @@ demo_model_templates = {
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
+ "mode": "chat",
"completion_params": {
"max_tokens": 300,
"temperature": 0.8,
diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index 4834f84555..2476d91887 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -9,7 +9,7 @@ api = ExternalApi(bp)
from . import setup, version, apikey, admin
# Import app controllers
-from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio
+from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
# Import auth controllers
from .auth import login, oauth, data_source_oauth, activate
diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py
new file mode 100644
index 0000000000..ce47e9e4d8
--- /dev/null
+++ b/api/controllers/console/app/advanced_prompt_template.py
@@ -0,0 +1,26 @@
+from flask_restful import Resource, reqparse
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from libs.login import login_required
+from services.advanced_prompt_template_service import AdvancedPromptTemplateService
+
+class AdvancedPromptTemplateList(Resource):
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self):
+
+ parser = reqparse.RequestParser()
+ parser.add_argument('app_mode', type=str, required=True, location='args')
+ parser.add_argument('model_mode', type=str, required=True, location='args')
+ parser.add_argument('has_context', type=str, required=False, default='true', location='args')
+ parser.add_argument('model_name', type=str, required=True, location='args')
+ args = parser.parse_args()
+
+ service = AdvancedPromptTemplateService()
+ return service.get_prompt(args)
+
+api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
\ No newline at end of file
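
For reference, a hypothetical request against the new endpoint; the `/console/api` prefix, host/port, and auth header are assumptions about how the console blueprint is mounted, not part of this patch:

```python
import requests

resp = requests.get(
    "http://localhost:5001/console/api/app/prompt-templates",
    params={
        "app_mode": "chat",                      # required
        "model_mode": "completion",              # required
        "has_context": "true",                   # optional, defaults to 'true'
        "model_name": "gpt-3.5-turbo-instruct",  # required
    },
    headers={"Authorization": "Bearer <console-session-token>"},
)
print(resp.json())
```
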
diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py
index 70275bb70d..f454426ab4 100644
--- a/api/controllers/console/app/generator.py
+++ b/api/controllers/console/app/generator.py
@@ -12,35 +12,6 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededE
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
-class IntroductionGenerateApi(Resource):
- @setup_required
- @login_required
- @account_initialization_required
- def post(self):
- parser = reqparse.RequestParser()
- parser.add_argument('prompt_template', type=str, required=True, location='json')
- args = parser.parse_args()
-
- account = current_user
-
- try:
- answer = LLMGenerator.generate_introduction(
- account.current_tenant_id,
- args['prompt_template']
- )
- except ProviderTokenNotInitError as ex:
- raise ProviderNotInitializeError(ex.description)
- except QuotaExceededError:
- raise ProviderQuotaExceededError()
- except ModelCurrentlyNotSupportError:
- raise ProviderModelCurrentlyNotSupportError()
- except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
- LLMRateLimitError, LLMAuthorizationError) as e:
- raise CompletionRequestError(str(e))
-
- return {'introduction': answer}
-
-
class RuleGenerateApi(Resource):
@setup_required
@login_required
@@ -72,5 +43,4 @@ class RuleGenerateApi(Resource):
return rules
-api.add_resource(IntroductionGenerateApi, '/introduction-generate')
api.add_resource(RuleGenerateApi, '/rule-generate')
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index 1b2765f9bc..d6f9172e57 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -329,7 +329,7 @@ class MessageApi(Resource):
message_id = str(message_id)
# get app info
- app_model = _get_app(app_id, 'chat')
+ app_model = _get_app(app_id)
message = db.session.query(Message).filter(
Message.id == message_id,
diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py
index 9d083f0027..2adc1db45f 100644
--- a/api/controllers/web/message.py
+++ b/api/controllers/web/message.py
@@ -115,7 +115,7 @@ class MessageMoreLikeThisApi(WebApiResource):
streaming = args['response_mode'] == 'streaming'
try:
- response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming)
+ response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app')
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
diff --git a/api/core/completion.py b/api/core/completion.py
index 59d589eabf..768231a53d 100644
--- a/api/core/completion.py
+++ b/api/core/completion.py
@@ -1,4 +1,3 @@
-import json
import logging
from typing import Optional, List, Union
@@ -16,10 +15,8 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
-from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
-from models.dataset import DocumentSegment, Dataset, Document
-from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
+from core.prompt.prompt_template import PromptTemplateParser
+from models.model import App, AppModelConfig, Account, Conversation, EndUser
class Completion:
@@ -30,7 +27,7 @@ class Completion:
"""
errors: ProviderTokenNotInitError
"""
- query = PromptBuilder.process_template(query)
+ query = PromptTemplateParser.remove_template_variables(query)
memory = None
if conversation:
@@ -160,14 +157,28 @@ class Completion:
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
fake_response: Optional[str]):
# get llm prompt
- prompt_messages, stop_words = model_instance.get_prompt(
- mode=mode,
- pre_prompt=app_model_config.pre_prompt,
- inputs=inputs,
- query=query,
- context=agent_execute_result.output if agent_execute_result else None,
- memory=memory
- )
+ if app_model_config.prompt_type == 'simple':
+ prompt_messages, stop_words = model_instance.get_prompt(
+ mode=mode,
+ pre_prompt=app_model_config.pre_prompt,
+ inputs=inputs,
+ query=query,
+ context=agent_execute_result.output if agent_execute_result else None,
+ memory=memory
+ )
+ else:
+ prompt_messages = model_instance.get_advanced_prompt(
+ app_mode=mode,
+ app_model_config=app_model_config,
+ inputs=inputs,
+ query=query,
+ context=agent_execute_result.output if agent_execute_result else None,
+ memory=memory
+ )
+
+ model_config = app_model_config.model_dict
+ completion_params = model_config.get("completion_params", {})
+ stop_words = completion_params.get("stop", [])
cls.recale_llm_max_tokens(
model_instance=model_instance,
@@ -176,7 +187,7 @@ class Completion:
response = model_instance.run(
messages=prompt_messages,
- stop=stop_words,
+ stop=stop_words if stop_words else None,
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
fake_response=fake_response
)
@@ -266,52 +277,3 @@ class Completion:
model_kwargs = model_instance.get_model_kwargs()
model_kwargs.max_tokens = max_tokens
model_instance.set_model_kwargs(model_kwargs)
-
- @classmethod
- def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
- app_model_config: AppModelConfig, user: Account, streaming: bool):
-
- final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
- tenant_id=app.tenant_id,
- model_config=app_model_config.model_dict,
- streaming=streaming
- )
-
- # get llm prompt
- old_prompt_messages, _ = final_model_instance.get_prompt(
- mode='completion',
- pre_prompt=pre_prompt,
- inputs=message.inputs,
- query=message.query,
- context=None,
- memory=None
- )
-
- original_completion = message.answer.strip()
-
- prompt = MORE_LIKE_THIS_GENERATE_PROMPT
- prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)
-
- prompt_messages = [PromptMessage(content=prompt)]
-
- conversation_message_task = ConversationMessageTask(
- task_id=task_id,
- app=app,
- app_model_config=app_model_config,
- user=user,
- inputs=message.inputs,
- query=message.query,
- is_override=True if message.override_model_configs else False,
- streaming=streaming,
- model_instance=final_model_instance
- )
-
- cls.recale_llm_max_tokens(
- model_instance=final_model_instance,
- prompt_messages=prompt_messages
- )
-
- final_model_instance.run(
- messages=prompt_messages,
- callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
- )
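
`PromptTemplateParser` replaces the Jinja-based templates throughout this patch, but its implementation is not part of the diff. A minimal sketch consistent with its call sites (`variable_keys`, `format(inputs, remove_template_variables=...)`, and the `remove_template_variables` classmethod); the variable-name pattern is an assumption:

```python
import re

VARIABLE_PATTERN = re.compile(r'\{\{(#?[a-zA-Z0-9_]{1,32}#?)\}\}')

class PromptTemplateParser:
    """Minimal sketch of the parser implied by this patch's call sites."""

    def __init__(self, template: str):
        self.template = template
        # e.g. "{{query}}\n{{#context#}}" -> ['query', '#context#']
        self.variable_keys = VARIABLE_PATTERN.findall(template)

    def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
        def replace(match: re.Match) -> str:
            value = inputs.get(match.group(1), match.group(0))
            if remove_template_variables:
                # defuse {{...}} inside user-supplied values so input
                # cannot inject further template variables
                value = self.remove_template_variables(str(value))
            return str(value)

        return VARIABLE_PATTERN.sub(replace, self.template)

    @classmethod
    def remove_template_variables(cls, text: str) -> str:
        return VARIABLE_PATTERN.sub(r'{\1}', text)
```

This also explains the change at the top of `Completion.generate`: the incoming query is scrubbed with `remove_template_variables` so that `{{...}}` tokens typed by a user cannot hijack the template.
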
diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py
index ae98f91a88..3be6ffaee3 100644
--- a/api/core/conversation_message_task.py
+++ b/api/core/conversation_message_task.py
@@ -10,7 +10,7 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import JinjaPromptTemplate
+from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -74,10 +74,10 @@ class ConversationMessageTask:
if self.mode == 'chat':
introduction = self.app_model_config.opening_statement
if introduction:
- prompt_template = JinjaPromptTemplate.from_template(template=introduction)
- prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
+ prompt_template = PromptTemplateParser(template=introduction)
+ prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
try:
- introduction = prompt_template.format(**prompt_inputs)
+ introduction = prompt_template.format(prompt_inputs)
except KeyError:
pass
@@ -150,12 +150,12 @@ class ConversationMessageTask:
message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens
- message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
- message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
+ message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
+ message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
- message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
+ message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
total_price = message_total_price + answer_total_price
@@ -163,7 +163,7 @@ class ConversationMessageTask:
self.message.message_tokens = message_tokens
self.message.message_unit_price = message_unit_price
self.message.message_price_unit = message_price_unit
- self.message.answer = PromptBuilder.process_template(
+ self.message.answer = PromptTemplateParser.remove_template_variables(
llm_message.completion.strip()) if llm_message.completion else ''
self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price
@@ -226,15 +226,15 @@ class ConversationMessageTask:
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
agent_loop: AgentLoop):
- agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
- agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
+ agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
+ agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
- loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
+ loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_total_price = loop_message_total_price + loop_answer_total_price
diff --git a/api/core/generator/llm_generator.py b/api/core/generator/llm_generator.py
index 93208df960..a6699f32d7 100644
--- a/api/core/generator/llm_generator.py
+++ b/api/core/generator/llm_generator.py
@@ -10,9 +10,8 @@ from core.model_providers.models.entity.model_params import ModelKwargs
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
-from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
-from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
- GENERATOR_QA_PROMPT
+from core.prompt.prompt_template import PromptTemplateParser
+from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
class LLMGenerator:
@@ -44,78 +43,19 @@ class LLMGenerator:
return answer.strip()
- @classmethod
- def generate_conversation_summary(cls, tenant_id: str, messages):
- max_tokens = 200
-
- model_instance = ModelFactory.get_text_generation_model(
- tenant_id=tenant_id,
- model_kwargs=ModelKwargs(
- max_tokens=max_tokens
- )
- )
-
- prompt = CONVERSATION_SUMMARY_PROMPT
- prompt_with_empty_context = prompt.format(context='')
- prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
- max_context_token_length = model_instance.model_rules.max_tokens.max
- max_context_token_length = max_context_token_length if max_context_token_length else 1500
- rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
-
- context = ''
- for message in messages:
- if not message.answer:
- continue
-
- if len(message.query) > 2000:
- query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
- else:
- query = message.query
-
- if len(message.answer) > 2000:
- answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
- else:
- answer = message.answer
-
- message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
- if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
- context += message_qa_text
-
- if not context:
- return '[message too long, no summary]'
-
- prompt = prompt.format(context=context)
- prompts = [PromptMessage(content=prompt)]
- response = model_instance.run(prompts)
- answer = response.content
- return answer.strip()
-
- @classmethod
- def generate_introduction(cls, tenant_id: str, pre_prompt: str):
- prompt = INTRODUCTION_GENERATE_PROMPT
- prompt = prompt.format(prompt=pre_prompt)
-
- model_instance = ModelFactory.get_text_generation_model(
- tenant_id=tenant_id
- )
-
- prompts = [PromptMessage(content=prompt)]
- response = model_instance.run(prompts)
- answer = response.content
- return answer.strip()
-
@classmethod
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
output_parser = SuggestedQuestionsAfterAnswerOutputParser()
format_instructions = output_parser.get_format_instructions()
- prompt = JinjaPromptTemplate(
- template="{{histories}}\n{{format_instructions}}\nquestions:\n",
- input_variables=["histories"],
- partial_variables={"format_instructions": format_instructions}
+ prompt_template = PromptTemplateParser(
+ template="{{histories}}\n{{format_instructions}}\nquestions:\n"
)
- _input = prompt.format_prompt(histories=histories)
+ prompt = prompt_template.format({
+ "histories": histories,
+ "format_instructions": format_instructions
+ })
try:
model_instance = ModelFactory.get_text_generation_model(
@@ -128,10 +68,10 @@ class LLMGenerator:
except ProviderTokenNotInitError:
return []
- prompts = [PromptMessage(content=_input.to_string())]
+ prompt_messages = [PromptMessage(content=prompt)]
try:
- output = model_instance.run(prompts)
+ output = model_instance.run(prompt_messages)
questions = output_parser.parse(output.content)
except LLMError:
questions = []
@@ -145,19 +85,21 @@ class LLMGenerator:
def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
output_parser = RuleConfigGeneratorOutputParser()
- prompt = OutLinePromptTemplate(
- template=output_parser.get_format_instructions(),
- input_variables=["audiences", "hoping_to_solve"],
- partial_variables={
- "variable": '{variable}',
- "lanA": '{lanA}',
- "lanB": '{lanB}',
- "topic": '{topic}'
- },
- validate_template=False
+ prompt_template = PromptTemplateParser(
+ template=output_parser.get_format_instructions()
)
- _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
+ prompt = prompt_template.format(
+ inputs={
+ "audiences": audiences,
+ "hoping_to_solve": hoping_to_solve,
+ "variable": "{{variable}}",
+ "lanA": "{{lanA}}",
+ "lanB": "{{lanB}}",
+ "topic": "{{topic}}"
+ },
+ remove_template_variables=False
+ )
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
@@ -167,10 +109,10 @@ class LLMGenerator:
)
)
- prompts = [PromptMessage(content=_input.to_string())]
+ prompt_messages = [PromptMessage(content=prompt)]
try:
- output = model_instance.run(prompts)
+ output = model_instance.run(prompt_messages)
rule_config = output_parser.parse(output.content)
except LLMError as e:
raise e
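
One subtlety in `generate_rule_config`: the format-instructions template legitimately contains `{{variable}}`, `{{lanA}}`, `{{lanB}}`, and `{{topic}}` tokens that must survive formatting. The rewrite feeds each token back as its own value and disables variable stripping. With the parser sketched after the `core/completion.py` diff above (template text illustrative):

```python
parser = PromptTemplateParser(
    template="Audience: {{audiences}}\nEmit a config keyed by {{variable}}."
)
text = parser.format(
    inputs={
        "audiences": "developers",
        "variable": "{{variable}}",  # re-inject the literal placeholder
    },
    remove_template_variables=False,  # keep {{variable}} intact in the output
)
assert text == "Audience: developers\nEmit a config keyed by {{variable}}."
```
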
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 1475c143c2..fcf954a985 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -286,7 +286,7 @@ class IndexingRunner:
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
- text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
+ text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
@@ -383,7 +383,7 @@ class IndexingRunner:
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
- text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
+ text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
diff --git a/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py b/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py
index 55d70d38ad..755df1201a 100644
--- a/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py
+++ b/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py
@@ -31,7 +31,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
chat_messages: List[PromptMessage] = []
for message in messages:
- chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
+ chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
if not chat_messages:
diff --git a/api/core/model_providers/models/entity/message.py b/api/core/model_providers/models/entity/message.py
index c37e88fac9..1ae04d67f5 100644
--- a/api/core/model_providers/models/entity/message.py
+++ b/api/core/model_providers/models/entity/message.py
@@ -13,13 +13,13 @@ class LLMRunResult(BaseModel):
class MessageType(enum.Enum):
- HUMAN = 'human'
+ USER = 'user'
ASSISTANT = 'assistant'
SYSTEM = 'system'
class PromptMessage(BaseModel):
- type: MessageType = MessageType.HUMAN
+ type: MessageType = MessageType.USER
content: str = ''
function_call: dict = None
@@ -27,7 +27,7 @@ class PromptMessage(BaseModel):
def to_lc_messages(messages: list[PromptMessage]):
lc_messages = []
for message in messages:
- if message.type == MessageType.HUMAN:
+ if message.type == MessageType.USER:
lc_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
additional_kwargs = {}
@@ -44,7 +44,7 @@ def to_prompt_messages(messages: list[BaseMessage]):
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
- prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
+ prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
elif isinstance(message, AIMessage):
message_kwargs = {
'content': message.content,
@@ -58,7 +58,7 @@ def to_prompt_messages(messages: list[BaseMessage]):
elif isinstance(message, SystemMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
elif isinstance(message, FunctionMessage):
- prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
+ prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
return prompt_messages
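
A small usage sketch of the renamed enum with the adapters above; only the `USER` and `ASSISTANT` paths shown in this hunk are exercised:

```python
from core.model_providers.models.entity.message import (
    MessageType, PromptMessage, to_lc_messages, to_prompt_messages,
)

history = [
    PromptMessage(type=MessageType.USER, content="Good morning"),
    PromptMessage(type=MessageType.ASSISTANT, content="Bonjour"),
]

lc_history = to_lc_messages(history)         # [HumanMessage, AIMessage]
round_trip = to_prompt_messages(lc_history)  # back to PromptMessage
assert [m.type for m in round_trip] == [MessageType.USER, MessageType.ASSISTANT]
```
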
diff --git a/api/core/model_providers/models/llm/base.py b/api/core/model_providers/models/llm/base.py
index 7224bf7141..3a6e8b41ca 100644
--- a/api/core/model_providers/models/llm/base.py
+++ b/api/core/model_providers/models/llm/base.py
@@ -18,7 +18,7 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import JinjaPromptTemplate
+from core.prompt.prompt_template import PromptTemplateParser
from core.third_party.langchain.llms.fake import FakeLLM
import logging
@@ -232,7 +232,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return:
"""
- if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+ if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
@@ -250,7 +250,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return: decimal.Decimal('0.0001')
"""
- if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+ if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
@@ -265,7 +265,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return: decimal.Decimal('0.000001')
"""
- if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+ if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
price_unit = self.price_config['unit']
else:
price_unit = self.price_config['unit']
@@ -330,6 +330,85 @@ class BaseLLM(BaseProviderModel):
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
return [PromptMessage(content=prompt)], stops
+    def get_advanced_prompt(self, app_mode: str,
+                            app_model_config, inputs: dict,
+                            query: str,
+                            context: Optional[str],
+                            memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
+
+ model_mode = app_model_config.model_dict['mode']
+ conversation_histories_role = {}
+
+ raw_prompt_list = []
+ prompt_messages = []
+
+ if app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
+ prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
+ raw_prompt_list = [{
+ 'role': MessageType.USER.value,
+ 'text': prompt_text
+ }]
+ conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
+ elif app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
+ raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
+ elif app_mode == 'completion' and model_mode == ModelMode.CHAT.value:
+ raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
+ elif app_mode == 'completion' and model_mode == ModelMode.COMPLETION.value:
+ prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
+ raw_prompt_list = [{
+ 'role': MessageType.USER.value,
+ 'text': prompt_text
+ }]
+ else:
+            raise Exception("app_mode or model_mode not supported")
+
+ for prompt_item in raw_prompt_list:
+ prompt = prompt_item['text']
+
+ # set prompt template variables
+ prompt_template = PromptTemplateParser(template=prompt)
+ prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+
+ if '#context#' in prompt:
+ if context:
+ prompt_inputs['#context#'] = context
+ else:
+ prompt_inputs['#context#'] = ''
+
+ if '#query#' in prompt:
+ if query:
+ prompt_inputs['#query#'] = query
+ else:
+ prompt_inputs['#query#'] = ''
+
+ if '#histories#' in prompt:
+ if memory and app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
+ memory.human_prefix = conversation_histories_role['user_prefix']
+ memory.ai_prefix = conversation_histories_role['assistant_prefix']
+ histories = self._get_history_messages_from_memory(memory, 2000)
+ prompt_inputs['#histories#'] = histories
+ else:
+ prompt_inputs['#histories#'] = ''
+
+            prompt = prompt_template.format(prompt_inputs)
+
+ prompt = re.sub(r'<\|.*?\|>', '', prompt)
+
+            prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))
+
+ if memory and app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
+ memory.human_prefix = MessageType.USER.value
+ memory.ai_prefix = MessageType.ASSISTANT.value
+ histories = self._get_history_messages_list_from_memory(memory, 2000)
+ prompt_messages.extend(histories)
+
+ if app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
+            prompt_messages.append(PromptMessage(type=MessageType.USER, content=query))
+
+ return prompt_messages
+
def prompt_file_name(self, mode: str) -> str:
if mode == 'completion':
return 'common_completion'
@@ -342,17 +421,17 @@ class BaseLLM(BaseProviderModel):
memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
context_prompt_content = ''
if context and 'context_prompt' in prompt_rules:
- prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
+ prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
context_prompt_content = prompt_template.format(
- context=context
+ {'context': context}
)
pre_prompt_content = ''
if pre_prompt:
- prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
- prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
+ prompt_template = PromptTemplateParser(template=pre_prompt)
+ prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
pre_prompt_content = prompt_template.format(
- **prompt_inputs
+ prompt_inputs
)
prompt = ''
@@ -385,10 +464,8 @@ class BaseLLM(BaseProviderModel):
memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
histories = self._get_history_messages_from_memory(memory, rest_tokens)
- prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
- histories_prompt_content = prompt_template.format(
- histories=histories
- )
+ prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
+ histories_prompt_content = prompt_template.format({'histories': histories})
prompt = ''
for order in prompt_rules['system_prompt_orders']:
@@ -399,10 +476,8 @@ class BaseLLM(BaseProviderModel):
elif order == 'histories_prompt':
prompt += histories_prompt_content
- prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
- query_prompt_content = prompt_template.format(
- query=query
- )
+ prompt_template = PromptTemplateParser(template=query_prompt)
+ query_prompt_content = prompt_template.format({'query': query})
prompt += query_prompt_content
@@ -433,6 +508,16 @@ class BaseLLM(BaseProviderModel):
external_context = memory.load_memory_variables({})
return external_context[memory_key]
+ def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
+ max_token_limit: int) -> List[PromptMessage]:
+ """Get memory messages."""
+ memory.max_token_limit = max_token_limit
+ memory.return_messages = True
+ memory_key = memory.memory_variables[0]
+ external_context = memory.load_memory_variables({})
+ memory.return_messages = False
+ return to_prompt_messages(external_context[memory_key])
+
def _get_prompt_from_messages(self, messages: List[PromptMessage],
model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
if not model_mode:
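
`get_advanced_prompt` selects one of two prompt-config shapes depending on the app_mode × model_mode combination. A hedged sketch of those shapes as implied by the field accesses above (keys from the diff, values illustrative):

```python
# Chat-mode model: an ordered list of role-tagged prompt segments.
chat_prompt_config = {
    "prompt": [
        {"role": "system", "text": "Answer using this context: {{#context#}}"},
    ]
}

# Completion-mode model: a single text template, plus the prefixes used
# when '#histories#' flattens prior turns into the prompt.
completion_prompt_config = {
    "prompt": {
        "text": "{{#context#}}\n\n{{#histories#}}\nHuman: {{#query#}}\nAssistant: "
    },
    "conversation_histories_role": {
        "user_prefix": "Human",
        "assistant_prefix": "Assistant",
    },
}
```

In chat-on-chat mode the history is instead appended as real messages via `_get_history_messages_list_from_memory`, followed by the user's query as the final `USER` message.
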
diff --git a/api/core/model_providers/providers/anthropic_provider.py b/api/core/model_providers/providers/anthropic_provider.py
index 35532b0ec4..eab61c60cc 100644
--- a/api/core/model_providers/providers/anthropic_provider.py
+++ b/api/core/model_providers/providers/anthropic_provider.py
@@ -9,7 +9,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelMode
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.anthropic_model import AnthropicModel
from core.model_providers.models.llm.base import ModelType
@@ -34,10 +34,12 @@ class AnthropicProvider(BaseModelProvider):
{
'id': 'claude-instant-1',
'name': 'claude-instant-1',
+ 'mode': ModelMode.CHAT.value,
},
{
'id': 'claude-2',
'name': 'claude-2',
+ 'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -46,6 +48,9 @@ class AnthropicProvider(BaseModelProvider):
else:
return []
+ def _get_text_generation_model_mode(self, model_name) -> str:
+ return ModelMode.CHAT.value
+
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
diff --git a/api/core/model_providers/providers/azure_openai_provider.py b/api/core/model_providers/providers/azure_openai_provider.py
index 4f7c8b717c..a34b463286 100644
--- a/api/core/model_providers/providers/azure_openai_provider.py
+++ b/api/core/model_providers/providers/azure_openai_provider.py
@@ -12,7 +12,7 @@ from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \
AZURE_OPENAI_API_VERSION
-from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule
+from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule, ModelMode
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -61,6 +61,10 @@ class AzureOpenAIProvider(BaseModelProvider):
}
credentials = json.loads(provider_model.encrypted_config)
+
+ if provider_model.model_type == ModelType.TEXT_GENERATION.value:
+ model_dict['mode'] = self._get_text_generation_model_mode(credentials['base_model_name'])
+
if credentials['base_model_name'] in [
'gpt-4',
'gpt-4-32k',
@@ -77,12 +81,19 @@ class AzureOpenAIProvider(BaseModelProvider):
return model_list
+ def _get_text_generation_model_mode(self, model_name) -> str:
+ if model_name == 'text-davinci-003':
+ return ModelMode.COMPLETION.value
+ else:
+ return ModelMode.CHAT.value
+
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
models = [
{
'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo',
+ 'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -90,6 +101,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
+ 'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -97,6 +109,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4',
'name': 'gpt-4',
+ 'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -104,6 +117,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
+ 'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -111,6 +125,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
+ 'mode': ModelMode.COMPLETION.value,
}
]
diff --git a/api/core/model_providers/providers/baichuan_provider.py b/api/core/model_providers/providers/baichuan_provider.py
index 12c475f92d..784c9df2c6 100644
--- a/api/core/model_providers/providers/baichuan_provider.py
+++ b/api/core/model_providers/providers/baichuan_provider.py
@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
@@ -21,6 +21,9 @@ class BaichuanProvider(BaseModelProvider):
Returns the name of a provider.
"""
return 'baichuan'
+
+ def _get_text_generation_model_mode(self, model_name) -> str:
+ return ModelMode.CHAT.value
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
@@ -28,6 +31,7 @@ class BaichuanProvider(BaseModelProvider):
{
'id': 'baichuan2-53b',
'name': 'Baichuan2-53B',
+ 'mode': ModelMode.CHAT.value,
}
]
else:
diff --git a/api/core/model_providers/providers/base.py b/api/core/model_providers/providers/base.py
index f10aa9f99d..9b05b4f5fd 100644
--- a/api/core/model_providers/providers/base.py
+++ b/api/core/model_providers/providers/base.py
@@ -61,10 +61,19 @@ class BaseModelProvider(BaseModel, ABC):
ProviderModel.is_valid == True
).order_by(ProviderModel.created_at.asc()).all()
- return [{
- 'id': provider_model.model_name,
- 'name': provider_model.model_name
- } for provider_model in provider_models]
+ provider_model_list = []
+ for provider_model in provider_models:
+ provider_model_dict = {
+ 'id': provider_model.model_name,
+ 'name': provider_model.model_name
+ }
+
+ if model_type == ModelType.TEXT_GENERATION:
+ provider_model_dict['mode'] = self._get_text_generation_model_mode(provider_model.model_name)
+
+ provider_model_list.append(provider_model_dict)
+
+ return provider_model_list
@abstractmethod
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
@@ -76,6 +85,16 @@ class BaseModelProvider(BaseModel, ABC):
"""
raise NotImplementedError
+ @abstractmethod
+ def _get_text_generation_model_mode(self, model_name) -> str:
+ """
+ get text generation model mode.
+
+ :param model_name:
+ :return:
+ """
+ raise NotImplementedError
+
@abstractmethod
def get_model_class(self, model_type: ModelType) -> Type:
"""
diff --git a/api/core/model_providers/providers/chatglm_provider.py b/api/core/model_providers/providers/chatglm_provider.py
index 4b2a46ad42..d3c83e37ce 100644
--- a/api/core/model_providers/providers/chatglm_provider.py
+++ b/api/core/model_providers/providers/chatglm_provider.py
@@ -6,7 +6,7 @@ from langchain.llms import ChatGLM
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.chatglm_model import ChatGLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType
@@ -27,15 +27,20 @@ class ChatGLMProvider(BaseModelProvider):
{
'id': 'chatglm2-6b',
'name': 'ChatGLM2-6B',
+ 'mode': ModelMode.COMPLETION.value,
},
{
'id': 'chatglm-6b',
'name': 'ChatGLM-6B',
+ 'mode': ModelMode.COMPLETION.value,
}
]
else:
return []
+ def _get_text_generation_model_mode(self, model_name) -> str:
+ return ModelMode.COMPLETION.value
+
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
diff --git a/api/core/model_providers/providers/huggingface_hub_provider.py b/api/core/model_providers/providers/huggingface_hub_provider.py
index deae4e35df..2cb7ff120a 100644
--- a/api/core/model_providers/providers/huggingface_hub_provider.py
+++ b/api/core/model_providers/providers/huggingface_hub_provider.py
@@ -5,7 +5,7 @@ import requests
from huggingface_hub import HfApi
from core.helper import encrypter
-from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
+from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -29,6 +29,9 @@ class HuggingfaceHubProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
+ def _get_text_generation_model_mode(self, model_name) -> str:
+ return ModelMode.COMPLETION.value
+
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
diff --git a/api/core/model_providers/providers/localai_provider.py b/api/core/model_providers/providers/localai_provider.py
index f5b07b1e6c..89279996f8 100644
--- a/api/core/model_providers/providers/localai_provider.py
+++ b/api/core/model_providers/providers/localai_provider.py
@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
-from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
+from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule, ModelMode
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -27,6 +27,13 @@ class LocalAIProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
+ def _get_text_generation_model_mode(self, model_name) -> str:
+ credentials = self.get_model_credentials(model_name, ModelType.TEXT_GENERATION)
+ if credentials['completion_type'] == 'chat_completion':
+ return ModelMode.CHAT.value
+ else:
+ return ModelMode.COMPLETION.value
+
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
diff --git a/api/core/model_providers/providers/minimax_provider.py b/api/core/model_providers/providers/minimax_provider.py
index c13165d602..f643e1e805 100644
--- a/api/core/model_providers/providers/minimax_provider.py
+++ b/api/core/model_providers/providers/minimax_provider.py
@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding
-from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
@@ -29,10 +29,12 @@ class MinimaxProvider(BaseModelProvider):
{
'id': 'abab5.5-chat',
'name': 'abab5.5-chat',
+ 'mode': ModelMode.COMPLETION.value,
},
{
'id': 'abab5-chat',
'name': 'abab5-chat',
+ 'mode': ModelMode.COMPLETION.value,
}
]
elif model_type == ModelType.EMBEDDINGS:
@@ -45,6 +47,9 @@ class MinimaxProvider(BaseModelProvider):
else:
return []
+ def _get_text_generation_model_mode(self, model_name) -> str:
+ return ModelMode.COMPLETION.value
+
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
diff --git a/api/core/model_providers/providers/openai_provider.py b/api/core/model_providers/providers/openai_provider.py
index 01b2adcedd..de5de28025 100644
--- a/api/core/model_providers/providers/openai_provider.py
+++ b/api/core/model_providers/providers/openai_provider.py
@@ -13,8 +13,8 @@ from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
-from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
-from core.model_providers.models.llm.openai_model import OpenAIModel
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
+from core.model_providers.models.llm.openai_model import OpenAIModel, COMPLETION_MODELS
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers
@@ -36,6 +36,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo',
+ 'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -43,10 +44,12 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo-instruct',
'name': 'GPT-3.5-Turbo-Instruct',
+ 'mode': ModelMode.COMPLETION.value,
},
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
+ 'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -54,6 +57,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4',
'name': 'gpt-4',
+ 'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -61,6 +65,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
+ 'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -68,6 +73,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
+ 'mode': ModelMode.COMPLETION.value,
}
]
@@ -100,6 +106,12 @@ class OpenAIProvider(BaseModelProvider):
else:
return []
+ def _get_text_generation_model_mode(self, model_name) -> str:
+ if model_name in COMPLETION_MODELS:
+ return ModelMode.COMPLETION.value
+ else:
+ return ModelMode.CHAT.value
+
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
diff --git a/api/core/model_providers/providers/openllm_provider.py b/api/core/model_providers/providers/openllm_provider.py
index a691507b9f..ea0e0b860d 100644
--- a/api/core/model_providers/providers/openllm_provider.py
+++ b/api/core/model_providers/providers/openllm_provider.py
@@ -3,7 +3,7 @@ from typing import Type
from core.helper import encrypter
from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding
-from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
+from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -24,6 +24,9 @@ class OpenLLMProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
+ def _get_text_generation_model_mode(self, model_name) -> str:
+ return ModelMode.COMPLETION.value
+
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
diff --git a/api/core/model_providers/providers/replicate_provider.py b/api/core/model_providers/providers/replicate_provider.py
index 9324d432a4..be9a7aa7ae 100644
--- a/api/core/model_providers/providers/replicate_provider.py
+++ b/api/core/model_providers/providers/replicate_provider.py
@@ -6,7 +6,8 @@ import replicate
from replicate.exceptions import ReplicateError
from core.helper import encrypter
-from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType
+from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType, \
+ ModelMode
from core.model_providers.models.llm.replicate_model import ReplicateModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -26,6 +27,9 @@ class ReplicateProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
+ def _get_text_generation_model_mode(self, model_name) -> str:
+ return ModelMode.CHAT.value if model_name.endswith('-chat') else ModelMode.COMPLETION.value
+
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
diff --git a/api/core/model_providers/providers/spark_provider.py b/api/core/model_providers/providers/spark_provider.py
index 89ed5d30b7..4174c01163 100644
--- a/api/core/model_providers/providers/spark_provider.py
+++ b/api/core/model_providers/providers/spark_provider.py
@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.spark_model import SparkModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.spark import ChatSpark
@@ -30,15 +30,20 @@ class SparkProvider(BaseModelProvider):
{
'id': 'spark',
'name': 'Spark V1.5',
+ 'mode': ModelMode.CHAT.value,
},
{
'id': 'spark-v2',
'name': 'Spark V2.0',
+ 'mode': ModelMode.CHAT.value,
}
]
else:
return []
+ def _get_text_generation_model_mode(self, model_name) -> str:
+ return ModelMode.CHAT.value
+
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
diff --git a/api/core/model_providers/providers/tongyi_provider.py b/api/core/model_providers/providers/tongyi_provider.py
index d48b4447f8..49ff731ac5 100644
--- a/api/core/model_providers/providers/tongyi_provider.py
+++ b/api/core/model_providers/providers/tongyi_provider.py
@@ -4,7 +4,7 @@ from typing import Type
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.tongyi_model import TongyiModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi
@@ -26,15 +26,20 @@ class TongyiProvider(BaseModelProvider):
{
'id': 'qwen-turbo',
'name': 'qwen-turbo',
+ 'mode': ModelMode.COMPLETION.value,
},
{
'id': 'qwen-plus',
'name': 'qwen-plus',
+ 'mode': ModelMode.COMPLETION.value,
}
]
else:
return []
+ def _get_text_generation_model_mode(self, model_name) -> str:
+ return ModelMode.COMPLETION.value
+
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
diff --git a/api/core/model_providers/providers/wenxin_provider.py b/api/core/model_providers/providers/wenxin_provider.py
index d6d1816323..e729358c0a 100644
--- a/api/core/model_providers/providers/wenxin_provider.py
+++ b/api/core/model_providers/providers/wenxin_provider.py
@@ -4,7 +4,7 @@ from typing import Type
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.wenxin_model import WenxinModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.wenxin import Wenxin
@@ -26,19 +26,25 @@ class WenxinProvider(BaseModelProvider):
{
'id': 'ernie-bot',
'name': 'ERNIE-Bot',
+ 'mode': ModelMode.COMPLETION.value,
},
{
'id': 'ernie-bot-turbo',
'name': 'ERNIE-Bot-turbo',
+ 'mode': ModelMode.COMPLETION.value,
},
{
'id': 'bloomz-7b',
'name': 'BLOOMZ-7B',
+ 'mode': ModelMode.COMPLETION.value,
}
]
else:
return []
+ def _get_text_generation_model_mode(self, model_name) -> str:
+ return ModelMode.COMPLETION.value
+
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
diff --git a/api/core/model_providers/providers/xinference_provider.py b/api/core/model_providers/providers/xinference_provider.py
index f56c5fb59d..fff0119eaf 100644
--- a/api/core/model_providers/providers/xinference_provider.py
+++ b/api/core/model_providers/providers/xinference_provider.py
@@ -6,7 +6,7 @@ from langchain.embeddings import XinferenceEmbeddings
from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
-from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
+from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -26,6 +26,9 @@ class XinferenceProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
+ def _get_text_generation_model_mode(self, model_name) -> str:
+ return ModelMode.COMPLETION.value
+
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
diff --git a/api/core/model_providers/providers/zhipuai_provider.py b/api/core/model_providers/providers/zhipuai_provider.py
index 0f7dae5f4f..9b56851688 100644
--- a/api/core/model_providers/providers/zhipuai_provider.py
+++ b/api/core/model_providers/providers/zhipuai_provider.py
@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
-from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
@@ -29,18 +29,22 @@ class ZhipuAIProvider(BaseModelProvider):
{
'id': 'chatglm_pro',
'name': 'chatglm_pro',
+ 'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_std',
'name': 'chatglm_std',
+ 'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_lite',
'name': 'chatglm_lite',
+ 'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_lite_32k',
'name': 'chatglm_lite_32k',
+ 'mode': ModelMode.CHAT.value,
}
]
elif model_type == ModelType.EMBEDDINGS:
@@ -53,6 +57,9 @@ class ZhipuAIProvider(BaseModelProvider):
else:
return []
+ def _get_text_generation_model_mode(self, model_name) -> str:
+ return ModelMode.CHAT.value
+
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
diff --git a/api/core/orchestrator_rule_parser.py b/api/core/orchestrator_rule_parser.py
index f359cf82fd..2ba732ee3d 100644
--- a/api/core/orchestrator_rule_parser.py
+++ b/api/core/orchestrator_rule_parser.py
@@ -1,4 +1,3 @@
-import math
from typing import Optional
from langchain import WikipediaAPIWrapper
@@ -50,6 +49,7 @@ class OrchestratorRuleParser:
tool_configs = agent_mode_config.get('tools', [])
agent_provider_name = model_dict.get('provider', 'openai')
agent_model_name = model_dict.get('name', 'gpt-4')
+ dataset_configs = self.app_model_config.dataset_configs_dict
agent_model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
@@ -96,13 +96,14 @@ class OrchestratorRuleParser:
summary_model_instance = None
tools = self.to_tools(
- agent_model_instance=agent_model_instance,
tool_configs=tool_configs,
+ callbacks=[agent_callback, DifyStdOutCallbackHandler()],
+ agent_model_instance=agent_model_instance,
conversation_message_task=conversation_message_task,
rest_tokens=rest_tokens,
- callbacks=[agent_callback, DifyStdOutCallbackHandler()],
return_resource=return_resource,
- retriever_from=retriever_from
+ retriever_from=retriever_from,
+ dataset_configs=dataset_configs
)
if len(tools) == 0:
@@ -170,20 +171,12 @@ class OrchestratorRuleParser:
return None
- def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list,
- conversation_message_task: ConversationMessageTask,
- rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False,
- retriever_from: str = 'dev') -> list[BaseTool]:
+ def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
"""
Convert app agent tool configs to tools
- :param agent_model_instance:
- :param rest_tokens:
:param tool_configs: app agent tool configs
- :param conversation_message_task:
:param callbacks:
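+        :param kwargs: tool-specific arguments (model instance, dataset configs, etc.) forwarded to each to_*_tool builder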
- :param return_resource:
- :param retriever_from:
:return:
"""
tools = []
@@ -195,15 +188,15 @@ class OrchestratorRuleParser:
tool = None
if tool_type == "dataset":
- tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens, return_resource, retriever_from)
+ tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs)
elif tool_type == "web_reader":
- tool = self.to_web_reader_tool(agent_model_instance)
+ tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs)
elif tool_type == "google_search":
- tool = self.to_google_search_tool()
+ tool = self.to_google_search_tool(tool_config=tool_val, **kwargs)
elif tool_type == "wikipedia":
- tool = self.to_wikipedia_tool()
+ tool = self.to_wikipedia_tool(tool_config=tool_val, **kwargs)
elif tool_type == "current_datetime":
- tool = self.to_current_datetime_tool()
+ tool = self.to_current_datetime_tool(tool_config=tool_val, **kwargs)
if tool:
if tool.callbacks is not None:
@@ -215,12 +208,15 @@ class OrchestratorRuleParser:
return tools
def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
- rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \
+ dataset_configs: dict, rest_tokens: int,
+ return_resource: bool = False, retriever_from: str = 'dev',
+ **kwargs) \
-> Optional[BaseTool]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param rest_tokens:
:param tool_config:
+ :param dataset_configs:
:param conversation_message_task:
:param return_resource:
:param retriever_from:
@@ -238,10 +234,20 @@ class OrchestratorRuleParser:
if dataset and dataset.available_document_count == 0:
return None
- k = self._dynamic_calc_retrieve_k(dataset, rest_tokens)
+ top_k = dataset_configs.get("top_k", 2)
+
+ # dynamically lower top_k when the remaining tokens cannot support the configured top_k
+ top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
+
+ score_threshold = None
+ score_threshold_config = dataset_configs.get("score_threshold")
+ if score_threshold_config and score_threshold_config.get("enable"):
+ score_threshold = score_threshold_config.get("value")
+
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
- k=k,
+ top_k=top_k,
+ score_threshold=score_threshold,
callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
conversation_message_task=conversation_message_task,
return_resource=return_resource,
@@ -250,7 +256,7 @@ class OrchestratorRuleParser:
return tool
- def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
+ def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]:
"""
A tool for reading web pages
@@ -278,7 +284,7 @@ class OrchestratorRuleParser:
return tool
- def to_google_search_tool(self) -> Optional[BaseTool]:
+ def to_google_search_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
func_kwargs = tool_provider.credentials_to_func_kwargs()
if not func_kwargs:
@@ -296,12 +302,12 @@ class OrchestratorRuleParser:
return tool
- def to_current_datetime_tool(self) -> Optional[BaseTool]:
+ def to_current_datetime_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
tool = DatetimeTool()
return tool
- def to_wikipedia_tool(self) -> Optional[BaseTool]:
+ def to_wikipedia_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
class WikipediaInput(BaseModel):
query: str = Field(..., description="search query.")
@@ -312,22 +318,18 @@ class OrchestratorRuleParser:
)
@classmethod
- def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
- DEFAULT_K = 2
- CONTEXT_TOKENS_PERCENT = 0.3
- MAX_K = 10
-
+ def _dynamic_calc_retrieve_k(cls, dataset: Dataset, top_k: int, rest_tokens: int) -> int:
if rest_tokens == -1:
- return DEFAULT_K
+ return top_k
processing_rule = dataset.latest_process_rule
if not processing_rule:
- return DEFAULT_K
+ return top_k
if processing_rule.mode == "custom":
rules = processing_rule.rules_dict
if not rules:
- return DEFAULT_K
+ return top_k
segmentation = rules["segmentation"]
segment_max_tokens = segmentation["max_tokens"]
@@ -335,14 +337,7 @@ class OrchestratorRuleParser:
segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
# when rest_tokens cannot cover top_k full segments
- if rest_tokens < segment_max_tokens * DEFAULT_K:
+ if rest_tokens < segment_max_tokens * top_k:
return rest_tokens // segment_max_tokens
- context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
-
- # when context_limit_tokens is less than default context tokens, use default_k
- if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
- return DEFAULT_K
-
- # Expand the k value when there's still some room left in the 30% rest tokens space, but less than the MAX_K
- return min(context_limit_tokens // segment_max_tokens, MAX_K)
+ return min(top_k, 10)
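To make the new clamping behavior concrete, here is a minimal, self-contained sketch of the logic above (illustrative function and parameter names, not the actual class method):

```python
# Sketch of the reworked retrieval-k logic: the user-configured top_k is
# only ever lowered when the token budget is too small, and capped at 10.
def calc_retrieve_k(top_k: int, rest_tokens: int, segment_max_tokens: int) -> int:
    if rest_tokens == -1:  # unknown/unlimited token budget
        return top_k
    # not enough room for top_k full segments: shrink k to what fits
    if rest_tokens < segment_max_tokens * top_k:
        return rest_tokens // segment_max_tokens
    return min(top_k, 10)

assert calc_retrieve_k(top_k=4, rest_tokens=-1, segment_max_tokens=500) == 4
assert calc_retrieve_k(top_k=4, rest_tokens=1000, segment_max_tokens=500) == 2
assert calc_retrieve_k(top_k=20, rest_tokens=100_000, segment_max_tokens=500) == 10
```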
diff --git a/api/core/prompt/advanced_prompt_templates.py b/api/core/prompt/advanced_prompt_templates.py
new file mode 100644
index 0000000000..c5eee005b6
--- /dev/null
+++ b/api/core/prompt/advanced_prompt_templates.py
@@ -0,0 +1,79 @@
+CONTEXT = "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answering the user:\n- If you don't know, just say that you don't know.\n- If you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer in the same language as the user's question.\n"
+
+BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n"
+
+CHAT_APP_COMPLETION_PROMPT_CONFIG = {
+ "completion_prompt_config": {
+ "prompt": {
+ "text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside XML tags.\n\n\n{{#histories#}}\n\n\n\nHuman: {{#query#}}\n\nAssistant: "
+ },
+ "conversation_histories_role": {
+ "user_prefix": "Human",
+ "assistant_prefix": "Assistant"
+ }
+ }
+}
+
+CHAT_APP_CHAT_PROMPT_CONFIG = {
+ "chat_prompt_config": {
+ "prompt": [{
+ "role": "system",
+ "text": "{{#pre_prompt#}}"
+ }]
+ }
+}
+
+COMPLETION_APP_CHAT_PROMPT_CONFIG = {
+ "chat_prompt_config": {
+ "prompt": [{
+ "role": "user",
+ "text": "{{#pre_prompt#}}"
+ }]
+ }
+}
+
+COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
+ "completion_prompt_config": {
+ "prompt": {
+ "text": "{{#pre_prompt#}}"
+ }
+ }
+}
+
+BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
+ "completion_prompt_config": {
+ "prompt": {
+ "text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}"
+ },
+ "conversation_histories_role": {
+ "user_prefix": "用户",
+ "assistant_prefix": "助手"
+ }
+ }
+}
+
+BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
+ "chat_prompt_config": {
+ "prompt": [{
+ "role": "system",
+ "text": "{{#pre_prompt#}}"
+ }]
+ }
+}
+
+BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = {
+ "chat_prompt_config": {
+ "prompt": [{
+ "role": "user",
+ "text": "{{#pre_prompt#}}"
+ }]
+ }
+}
+
+BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
+ "completion_prompt_config": {
+ "prompt": {
+ "text": "{{#pre_prompt#}}"
+ }
+ }
+}
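The `conversation_histories_role` prefixes exist because completion-mode models receive chat history flattened into plain text. A hedged sketch of how `{{#histories#}}` is presumably rendered with them (the actual formatter lives elsewhere in the codebase):

```python
from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG

role = CHAT_APP_COMPLETION_PROMPT_CONFIG["completion_prompt_config"]["conversation_histories_role"]

# Hypothetical history rendering using the configured prefixes.
history = [("user", "What is Dify?"), ("assistant", "An LLM app platform.")]
rendered = "\n".join(
    f"{role['user_prefix'] if speaker == 'user' else role['assistant_prefix']}: {text}"
    for speaker, text in history
)
print(rendered)
# Human: What is Dify?
# Assistant: An LLM app platform.
```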
diff --git a/api/core/prompt/prompt_builder.py b/api/core/prompt/prompt_builder.py
index 073cf2ce25..cc2a11a78f 100644
--- a/api/core/prompt/prompt_builder.py
+++ b/api/core/prompt/prompt_builder.py
@@ -1,38 +1,24 @@
-import re
+from langchain.schema import BaseMessage, SystemMessage, AIMessage, HumanMessage
-from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
-from langchain.schema import BaseMessage
-
-from core.prompt.prompt_template import JinjaPromptTemplate
+from core.prompt.prompt_template import PromptTemplateParser
class PromptBuilder:
+ @classmethod
+ def parse_prompt(cls, prompt: str, inputs: dict) -> str:
+ prompt_template = PromptTemplateParser(prompt)
+ prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+ prompt = prompt_template.format(prompt_inputs)
+ return prompt
+
@classmethod
def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
- prompt_template = JinjaPromptTemplate.from_template(prompt_content)
- system_prompt_template = SystemMessagePromptTemplate(prompt=prompt_template)
- prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs}
- system_message = system_prompt_template.format(**prompt_inputs)
- return system_message
+ return SystemMessage(content=cls.parse_prompt(prompt_content, inputs))
@classmethod
def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
- prompt_template = JinjaPromptTemplate.from_template(prompt_content)
- ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template)
- prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs}
- ai_message = ai_prompt_template.format(**prompt_inputs)
- return ai_message
+ return AIMessage(content=cls.parse_prompt(prompt_content, inputs))
@classmethod
def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
- prompt_template = JinjaPromptTemplate.from_template(prompt_content)
- human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template)
- human_message = human_prompt_template.format(**inputs)
- return human_message
-
- @classmethod
- def process_template(cls, template: str):
- processed_template = re.sub(r'\{{2}(.+)\}{2}', r'{\1}', template)
- # processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template)
- # processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template)
- return processed_template
+ return HumanMessage(content=cls.parse_prompt(prompt_content, inputs))
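A short usage sketch of the slimmed-down builder: each helper now parses the template with PromptTemplateParser and wraps the result in the matching langchain message type.

```python
from core.prompt.prompt_builder import PromptBuilder

# Keys present in `inputs` are substituted; unknown {{keys}} are left verbatim.
msg = PromptBuilder.to_system_message(
    "You are a translator. Target language: {{target_language}}.",
    {"target_language": "German"},
)
print(type(msg).__name__, "-", msg.content)
# SystemMessage - You are a translator. Target language: German.
```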
diff --git a/api/core/prompt/prompt_template.py b/api/core/prompt/prompt_template.py
index c51c4700c1..fbf09d2c64 100644
--- a/api/core/prompt/prompt_template.py
+++ b/api/core/prompt/prompt_template.py
@@ -1,79 +1,39 @@
import re
-from typing import Any
-from jinja2 import Environment, meta
-from langchain import PromptTemplate
-from langchain.formatting import StrictFormatter
+REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{1,29}|#histories#|#query#|#context#)\}\}")
-class JinjaPromptTemplate(PromptTemplate):
- template_format: str = "jinja2"
- """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
+class PromptTemplateParser:
+ """
+ Rules:
+
+ 1. Template variables must be enclosed in `{{}}`.
+ 2. The template variable Key can only contain letters, numbers, and underscores, must be 2 to 30 characters long,
+ and can only start with a letter or an underscore (as enforced by REGEX above).
+ 3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
+ 4. In addition to the above, 3 types of special template variable Keys are accepted:
+ `{{#histories#}}` `{{#query#}}` `{{#context#}}`. No other `{{##}}` template variables are allowed.
+ """
+
+ def __init__(self, template: str):
+ self.template = template
+ self.variable_keys = self.extract()
+
+ def extract(self) -> list:
+ # Regular expression to match the template rules
+ return re.findall(REGEX, self.template)
+
+ def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
+ def replacer(match):
+ key = match.group(1)
+ value = inputs.get(key, match.group(0)) # return original matched string if key not found
+
+ if remove_template_variables:
+ return PromptTemplateParser.remove_template_variables(value)
+ return value
+
+ return re.sub(REGEX, replacer, self.template)
@classmethod
- def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
- """Load a prompt template from a template."""
- env = Environment()
- template = template.replace("{{}}", "{}")
- ast = env.parse(template)
- input_variables = meta.find_undeclared_variables(ast)
-
- if "partial_variables" in kwargs:
- partial_variables = kwargs["partial_variables"]
- input_variables = {
- var for var in input_variables if var not in partial_variables
- }
-
- return cls(
- input_variables=list(sorted(input_variables)), template=template, **kwargs
- )
-
-
-class OutLinePromptTemplate(PromptTemplate):
- @classmethod
- def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
- """Load a prompt template from a template."""
- input_variables = {
- v for _, v, _, _ in OneLineFormatter().parse(template) if v is not None
- }
- return cls(
- input_variables=list(sorted(input_variables)), template=template, **kwargs
- )
-
- def format(self, **kwargs: Any) -> str:
- """Format the prompt with the inputs.
-
- Args:
- kwargs: Any arguments to be passed to the prompt template.
-
- Returns:
- A formatted string.
-
- Example:
-
- .. code-block:: python
-
- prompt.format(variable1="foo")
- """
- kwargs = self._merge_partial_and_user_variables(**kwargs)
- return OneLineFormatter().format(self.template, **kwargs)
-
-
-class OneLineFormatter(StrictFormatter):
- def parse(self, format_string):
- last_end = 0
- results = []
- for match in re.finditer(r"{([a-zA-Z_]\w*)}", format_string):
- field_name = match.group(1)
- start, end = match.span()
-
- literal_text = format_string[last_end:start]
- last_end = end
-
- results.append((literal_text, field_name, '', None))
-
- remaining_literal_text = format_string[last_end:]
- if remaining_literal_text:
- results.append((remaining_literal_text, None, None, None))
-
- return results
+ def remove_template_variables(cls, text: str):
+ return re.sub(REGEX, r'{\1}', text)
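To make the parser rules concrete, a small usage sketch (module path as introduced above):

```python
from core.prompt.prompt_template import PromptTemplateParser

parser = PromptTemplateParser("Translate into {{target_language}}:\n{{#query#}}")
print(parser.variable_keys)  # ['target_language', '#query#']

# Matched keys are replaced from `inputs`; missing keys stay verbatim. By
# default any {{...}} markers inside the substituted *values* are defused
# to {...}, so user input cannot inject new template variables.
print(parser.format({"target_language": "French", "#query#": "{{pre_prompt}} hi"}))
# Translate into French:
# {pre_prompt} hi
```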
diff --git a/api/core/prompt/prompts.py b/api/core/prompt/prompts.py
index 979fe9be96..44cf954d3e 100644
--- a/api/core/prompt/prompts.py
+++ b/api/core/prompt/prompts.py
@@ -61,36 +61,6 @@ User Input: yo, 你今天咋样?
User Input:
"""
-CONVERSATION_SUMMARY_PROMPT = (
- "Please generate a short summary of the following conversation.\n"
- "If the following conversation communicating in English, you should only return an English summary.\n"
- "If the following conversation communicating in Chinese, you should only return a Chinese summary.\n"
- "[Conversation Start]\n"
- "{context}\n"
- "[Conversation End]\n\n"
- "summary:"
-)
-
-INTRODUCTION_GENERATE_PROMPT = (
- "I am designing a product for users to interact with an AI through dialogue. "
- "The Prompt given to the AI before the conversation is:\n\n"
- "```\n{prompt}\n```\n\n"
- "Please generate a brief introduction of no more than 50 words that greets the user, based on this Prompt. "
- "Do not reveal the developer's motivation or deep logic behind the Prompt, "
- "but focus on building a relationship with the user:\n"
-)
-
-MORE_LIKE_THIS_GENERATE_PROMPT = (
- "-----\n"
- "{original_completion}\n"
- "-----\n\n"
- "Please use the above content as a sample for generating the result, "
- "and include key information points related to the original sample in the result. "
- "Try to rephrase this information in different ways and predict according to the rules below.\n\n"
- "-----\n"
- "{prompt}\n"
-)
-
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"Please help me predict the three most likely questions that human would ask, "
"and keeping each question under 20 characters.\n"
@@ -157,10 +127,10 @@ and fill in variables, with a welcome sentence, and keep TLDR.
```
<< MY INTENDED AUDIENCES >>
-{audiences}
+{{audiences}}
<< HOPING TO SOLVE >>
-{hoping_to_solve}
+{{hoping_to_solve}}
<< OUTPUT >>
"""
\ No newline at end of file
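The `{audiences}` to `{{audiences}}` switch above follows the template migration: these generator prompts are presumably rendered through the new PromptTemplateParser rather than f-string style formatting, so variables need double braces:

```python
from core.prompt.prompt_template import PromptTemplateParser

tpl = PromptTemplateParser("<< MY INTENDED AUDIENCES >>\n{{audiences}}")
print(tpl.format({"audiences": "indie game developers"}))
# << MY INTENDED AUDIENCES >>
# indie game developers
```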
diff --git a/api/core/tool/dataset_retriever_tool.py b/api/core/tool/dataset_retriever_tool.py
index 33fec157ea..2c14f40d15 100644
--- a/api/core/tool/dataset_retriever_tool.py
+++ b/api/core/tool/dataset_retriever_tool.py
@@ -1,5 +1,5 @@
import json
-from typing import Type
+from typing import Type, Optional
from flask import current_app
from langchain.tools import BaseTool
@@ -28,7 +28,8 @@ class DatasetRetrieverTool(BaseTool):
tenant_id: str
dataset_id: str
- k: int = 3
+ top_k: int = 2
+ score_threshold: Optional[float] = None
conversation_message_task: ConversationMessageTask
return_resource: bool
retriever_from: str
@@ -66,7 +67,7 @@ class DatasetRetrieverTool(BaseTool):
)
)
- documents = kw_table_index.search(query, search_kwargs={'k': self.k})
+ documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
return str("\n".join([document.page_content for document in documents]))
else:
@@ -80,20 +81,21 @@ class DatasetRetrieverTool(BaseTool):
return ''
except ProviderTokenNotInitError:
return ''
- embeddings = CacheEmbedding(embedding_model)
+ embeddings = CacheEmbedding(embedding_model)
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
- if self.k > 0:
+ if self.top_k > 0:
documents = vector_index.search(
query,
search_type='similarity_score_threshold',
search_kwargs={
- 'k': self.k,
+ 'k': self.top_k,
+ 'score_threshold': self.score_threshold,
'filter': {
'group_id': [dataset.id]
}
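For reference, the shape of the vector search arguments after this change; leaving `score_threshold` at None preserves the previous unfiltered behavior (the values below are illustrative only):

```python
# Illustrative search_kwargs; after this change self.top_k defaults to 2
# and self.score_threshold to None (no score filtering).
search_kwargs = {
    "k": 2,                      # replaces the old self.k (default was 3)
    "score_threshold": 0.5,      # hits scoring below this are dropped
    "filter": {"group_id": ["<dataset-id>"]},  # placeholder dataset id
}
```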
diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py
index 02020e9192..2414950a54 100644
--- a/api/events/event_handlers/__init__.py
+++ b/api/events/event_handlers/__init__.py
@@ -4,5 +4,4 @@ from .clean_when_document_deleted import handle
from .clean_when_dataset_deleted import handle
from .update_app_dataset_join_when_app_model_config_updated import handle
from .generate_conversation_name_when_first_message_created import handle
-from .generate_conversation_summary_when_few_message_created import handle
from .create_document_index import handle
diff --git a/api/events/event_handlers/generate_conversation_summary_when_few_message_created.py b/api/events/event_handlers/generate_conversation_summary_when_few_message_created.py
deleted file mode 100644
index df62a90b8e..0000000000
--- a/api/events/event_handlers/generate_conversation_summary_when_few_message_created.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from events.message_event import message_was_created
-from tasks.generate_conversation_summary_task import generate_conversation_summary_task
-
-
-@message_was_created.connect
-def handle(sender, **kwargs):
- message = sender
- conversation = kwargs.get('conversation')
- is_first_message = kwargs.get('is_first_message')
-
- if not is_first_message and conversation.mode == 'chat' and not conversation.summary:
- history_message_count = conversation.message_count
- if history_message_count >= 5:
- generate_conversation_summary_task.delay(conversation.id)
diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py
index b370ac41e0..fccfa5df30 100644
--- a/api/fields/app_fields.py
+++ b/api/fields/app_fields.py
@@ -28,6 +28,10 @@ model_config_fields = {
'dataset_query_variable': fields.String,
'pre_prompt': fields.String,
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
+ 'prompt_type': fields.String,
+ 'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'),
+ 'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'),
+ 'dataset_configs': fields.Raw(attribute='dataset_configs_dict')
}
app_detail_fields = {
diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py
index dcfbe8a069..df43a62fb6 100644
--- a/api/fields/conversation_fields.py
+++ b/api/fields/conversation_fields.py
@@ -123,6 +123,7 @@ conversation_with_summary_fields = {
'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String,
'from_account_id': fields.String,
+ 'name': fields.String,
'summary': fields.String(attribute='summary_or_query'),
'read_at': TimestampField,
'created_at': TimestampField,
diff --git a/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py
new file mode 100644
index 0000000000..cbb04bb01e
--- /dev/null
+++ b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py
@@ -0,0 +1,37 @@
+"""add advanced prompt templates
+
+Revision ID: b3a09c049e8e
+Revises: 2e9819ca5b28
+Create Date: 2023-10-10 15:23:23.395420
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision = 'b3a09c049e8e'
+down_revision = '2e9819ca5b28'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
+ batch_op.add_column(sa.Column('chat_prompt_config', sa.Text(), nullable=True))
+ batch_op.add_column(sa.Column('completion_prompt_config', sa.Text(), nullable=True))
+ batch_op.add_column(sa.Column('dataset_configs', sa.Text(), nullable=True))
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.drop_column('dataset_configs')
+ batch_op.drop_column('completion_prompt_config')
+ batch_op.drop_column('chat_prompt_config')
+ batch_op.drop_column('prompt_type')
+
+ # ### end Alembic commands ###
diff --git a/api/models/model.py b/api/models/model.py
index f372f516da..d3f5c8135f 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -93,6 +93,10 @@ class AppModelConfig(db.Model):
agent_mode = db.Column(db.Text)
sensitive_word_avoidance = db.Column(db.Text)
retriever_resource = db.Column(db.Text)
+ prompt_type = db.Column(db.String(255), nullable=False, default='simple')
+ chat_prompt_config = db.Column(db.Text)
+ completion_prompt_config = db.Column(db.Text)
+ dataset_configs = db.Column(db.Text)
@property
def app(self):
@@ -139,6 +143,18 @@ class AppModelConfig(db.Model):
def agent_mode_dict(self) -> dict:
return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": []}
+ @property
+ def chat_prompt_config_dict(self) -> dict:
+ return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}
+
+ @property
+ def completion_prompt_config_dict(self) -> dict:
+ return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}
+
+ @property
+ def dataset_configs_dict(self) -> dict:
+ return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}}
+
def to_dict(self) -> dict:
return {
"provider": "",
@@ -155,7 +171,11 @@ class AppModelConfig(db.Model):
"user_input_form": self.user_input_form_list,
"dataset_query_variable": self.dataset_query_variable,
"pre_prompt": self.pre_prompt,
- "agent_mode": self.agent_mode_dict
+ "agent_mode": self.agent_mode_dict,
+ "prompt_type": self.prompt_type,
+ "chat_prompt_config": self.chat_prompt_config_dict,
+ "completion_prompt_config": self.completion_prompt_config_dict,
+ "dataset_configs": self.dataset_configs_dict
}
def from_model_config_dict(self, model_config: dict):
@@ -177,6 +197,13 @@ class AppModelConfig(db.Model):
self.agent_mode = json.dumps(model_config['agent_mode'])
self.retriever_resource = json.dumps(model_config['retriever_resource']) \
if model_config.get('retriever_resource') else None
+ self.prompt_type = model_config.get('prompt_type', 'simple')
+ self.chat_prompt_config = json.dumps(model_config.get('chat_prompt_config')) \
+ if model_config.get('chat_prompt_config') else None
+ self.completion_prompt_config = json.dumps(model_config.get('completion_prompt_config')) \
+ if model_config.get('completion_prompt_config') else None
+ self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \
+ if model_config.get('dataset_configs') else None
return self
def copy(self):
@@ -197,7 +224,11 @@ class AppModelConfig(db.Model):
dataset_query_variable=self.dataset_query_variable,
pre_prompt=self.pre_prompt,
agent_mode=self.agent_mode,
- retriever_resource=self.retriever_resource
+ retriever_resource=self.retriever_resource,
+ prompt_type=self.prompt_type,
+ chat_prompt_config=self.chat_prompt_config,
+ completion_prompt_config=self.completion_prompt_config,
+ dataset_configs=self.dataset_configs
)
return new_app_model_config
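A quick sketch of the new JSON-backed properties: unset columns fall back to safe defaults through the `*_dict` accessors (plain attribute access; no database session is needed for this illustration):

```python
import json

from models.model import AppModelConfig

cfg = AppModelConfig()
assert cfg.chat_prompt_config_dict == {}
assert cfg.completion_prompt_config_dict == {}
assert cfg.dataset_configs_dict == {"top_k": 2, "score_threshold": {"enable": False}}

cfg.dataset_configs = json.dumps({"top_k": 4, "score_threshold": {"enable": True, "value": 0.6}})
assert cfg.dataset_configs_dict["top_k"] == 4
```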
diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py
new file mode 100644
index 0000000000..3ef2b6059e
--- /dev/null
+++ b/api/services/advanced_prompt_template_service.py
@@ -0,0 +1,56 @@
+
+import copy
+
+from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \
+ BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT
+
+class AdvancedPromptTemplateService:
+
+ def get_prompt(self, args: dict) -> dict:
+ app_mode = args['app_mode']
+ model_mode = args['model_mode']
+ model_name = args['model_name']
+ has_context = args['has_context']
+
+ if 'baichuan' in model_name:
+ return self.get_baichuan_prompt(app_mode, model_mode, has_context)
+ else:
+ return self.get_common_prompt(app_mode, model_mode, has_context)
+
+ def get_common_prompt(self, app_mode: str, model_mode: str, has_context: str) -> dict:
+ if app_mode == 'chat':
+ if model_mode == 'completion':
+ return self.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT)
+ elif model_mode == 'chat':
+ return self.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT)
+ elif app_mode == 'completion':
+ if model_mode == 'completion':
+ return self.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT)
+ elif model_mode == 'chat':
+ return self.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT)
+
+ def get_completion_prompt(self, prompt_template: dict, has_context: str, context: str) -> dict:
+ if has_context == 'true':
+ prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text']
+
+ return prompt_template
+
+
+ def get_chat_prompt(self, prompt_template: dict, has_context: str, context: str) -> dict:
+ if has_context == 'true':
+ prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text']
+
+ return prompt_template
+
+
+ def get_baichuan_prompt(self, app_mode: str, model_mode: str, has_context: str) -> dict:
+ if app_mode == 'chat':
+ if model_mode == 'completion':
+ return self.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
+ elif model_mode == 'chat':
+ return self.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
+ elif app_mode == 'completion':
+ if model_mode == 'completion':
+ return self.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
+ elif model_mode == 'chat':
+ return self.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
\ No newline at end of file
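A usage sketch for the new service. Note that `has_context` is evidently handled as the string 'true' rather than a boolean (see the comparisons above), matching how it arrives from query arguments:

```python
from services.advanced_prompt_template_service import AdvancedPromptTemplateService

service = AdvancedPromptTemplateService()
prompt = service.get_prompt({
    "app_mode": "chat",
    "model_mode": "completion",
    "model_name": "gpt-3.5-turbo-instruct",
    "has_context": "true",  # string, not bool
})
# With has_context == 'true', CONTEXT is prepended to the prompt text.
text = prompt["completion_prompt_config"]["prompt"]["text"]
assert text.startswith("Use the following context")
```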
diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py
index 916a1078e5..4acb2f346f 100644
--- a/api/services/app_model_config_service.py
+++ b/api/services/app_model_config_service.py
@@ -3,7 +3,7 @@ import uuid
from core.agent.agent_executor import PlanningStrategy
from core.model_providers.model_provider_factory import ModelProviderFactory
-from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.models.entity.model_params import ModelType, ModelMode
from models.account import Account
from services.dataset_service import DatasetService
@@ -34,40 +34,28 @@ class AppModelConfigService:
# max_tokens
if 'max_tokens' not in cp:
cp["max_tokens"] = 512
- #
- # if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
- # llm_constant.max_context_token_length[model_name]:
- # raise ValueError(
- # "max_tokens must be an integer greater than 0 "
- # "and not exceeding the maximum value of the corresponding model")
- #
+
# temperature
if 'temperature' not in cp:
cp["temperature"] = 1
- #
- # if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2:
- # raise ValueError("temperature must be a float between 0 and 2")
- #
+
# top_p
if 'top_p' not in cp:
cp["top_p"] = 1
- # if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2:
- # raise ValueError("top_p must be a float between 0 and 2")
- #
# presence_penalty
if 'presence_penalty' not in cp:
cp["presence_penalty"] = 0
- # if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2:
- # raise ValueError("presence_penalty must be a float between -2 and 2")
- #
# frequency_penalty
if 'frequency_penalty' not in cp:
cp["frequency_penalty"] = 0
- # if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2:
- # raise ValueError("frequency_penalty must be a float between -2 and 2")
+ # stop
+ if 'stop' not in cp:
+ cp["stop"] = []
+ elif not isinstance(cp["stop"], list):
+ raise ValueError("stop in model.completion_params must be of list type")
# Filter out extra parameters
filtered_cp = {
@@ -75,7 +63,8 @@ class AppModelConfigService:
"temperature": cp["temperature"],
"top_p": cp["top_p"],
"presence_penalty": cp["presence_penalty"],
- "frequency_penalty": cp["frequency_penalty"]
+ "frequency_penalty": cp["frequency_penalty"],
+ "stop": cp["stop"]
}
return filtered_cp
@@ -211,6 +200,10 @@ class AppModelConfigService:
model_ids = [m['id'] for m in model_list]
if config["model"]["name"] not in model_ids:
raise ValueError("model.name must be in the specified model list")
+
+ # model.mode
+ if 'mode' not in config['model'] or not config['model']["mode"]:
+ config['model']["mode"] = ""
# model.completion_params
if 'completion_params' not in config["model"]:
@@ -339,6 +332,9 @@ class AppModelConfigService:
# dataset_query_variable
AppModelConfigService.is_dataset_query_variable_valid(config, mode)
+ # advanced prompt validation
+ AppModelConfigService.is_advanced_prompt_valid(config, mode)
+
# Filter out extra parameters
filtered_config = {
"opening_statement": config["opening_statement"],
@@ -351,12 +347,17 @@ class AppModelConfigService:
"model": {
"provider": config["model"]["provider"],
"name": config["model"]["name"],
+ "mode": config['model']["mode"],
"completion_params": config["model"]["completion_params"]
},
"user_input_form": config["user_input_form"],
"dataset_query_variable": config.get('dataset_query_variable'),
"pre_prompt": config["pre_prompt"],
- "agent_mode": config["agent_mode"]
+ "agent_mode": config["agent_mode"],
+ "prompt_type": config["prompt_type"],
+ "chat_prompt_config": config["chat_prompt_config"],
+ "completion_prompt_config": config["completion_prompt_config"],
+ "dataset_configs": config["dataset_configs"]
}
return filtered_config
@@ -375,4 +376,51 @@ class AppModelConfigService:
if dataset_exists and not dataset_query_variable:
raise ValueError("Dataset query variable is required when dataset is exist")
+
+ @staticmethod
+ def is_advanced_prompt_valid(config: dict, app_mode: str) -> None:
+ # prompt_type
+ if 'prompt_type' not in config or not config["prompt_type"]:
+ config["prompt_type"] = "simple"
+
+ if config['prompt_type'] not in ['simple', 'advanced']:
+ raise ValueError("prompt_type must be in ['simple', 'advanced']")
+
+ # chat_prompt_config
+ if 'chat_prompt_config' not in config or not config["chat_prompt_config"]:
+ config["chat_prompt_config"] = {}
+
+ if not isinstance(config["chat_prompt_config"], dict):
+ raise ValueError("chat_prompt_config must be of object type")
+
+ # completion_prompt_config
+ if 'completion_prompt_config' not in config or not config["completion_prompt_config"]:
+ config["completion_prompt_config"] = {}
+
+ if not isinstance(config["completion_prompt_config"], dict):
+ raise ValueError("completion_prompt_config must be of object type")
+
+ # dataset_configs
+ if 'dataset_configs' not in config or not config["dataset_configs"]:
+ config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}}
+
+ if not isinstance(config["dataset_configs"], dict):
+ raise ValueError("dataset_configs must be of object type")
+
+ if config['prompt_type'] == 'advanced':
+ if not config['chat_prompt_config'] and not config['completion_prompt_config']:
+ raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced")
+
+ if config['model']["mode"] not in ['chat', 'completion']:
+ raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
+
+ if app_mode == 'chat' and config['model']["mode"] == ModelMode.COMPLETION.value:
+ user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
+ assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
+
+ if not user_prefix:
+ config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
+
+ if not assistant_prefix:
+ config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
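A sketch of the new validation path: with `prompt_type` set to 'advanced' but neither prompt config supplied, validation fails before the model mode is checked.

```python
from services.app_model_config_service import AppModelConfigService

config = {
    "prompt_type": "advanced",
    "chat_prompt_config": {},
    "completion_prompt_config": {},
    "dataset_configs": {"top_k": 2, "score_threshold": {"enable": False}},
    "model": {"mode": "chat"},
}
try:
    AppModelConfigService.is_advanced_prompt_valid(config, "chat")
except ValueError as e:
    print(e)
    # chat_prompt_config or completion_prompt_config is required when prompt_type is advanced
```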
diff --git a/api/services/completion_service.py b/api/services/completion_service.py
index c95905c6c8..e2a28357cb 100644
--- a/api/services/completion_service.py
+++ b/api/services/completion_service.py
@@ -244,7 +244,8 @@ class CompletionService:
@classmethod
def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser],
- message_id: str, streaming: bool = True) -> Union[dict | Generator]:
+ message_id: str, streaming: bool = True,
+ retriever_from: str = 'dev') -> Union[dict | Generator]:
if not user:
raise ValueError('user cannot be None')
@@ -266,14 +267,11 @@ class CompletionService:
raise MoreLikeThisDisabledError()
app_model_config = message.app_model_config
-
- if message.override_model_configs:
- override_model_configs = json.loads(message.override_model_configs)
- pre_prompt = override_model_configs.get("pre_prompt", '')
- elif app_model_config:
- pre_prompt = app_model_config.pre_prompt
- else:
- raise AppModelConfigBrokenError()
+ model_dict = app_model_config.model_dict
+ completion_params = model_dict.get('completion_params')
+ completion_params['temperature'] = 0.9
+ model_dict['completion_params'] = completion_params
+ app_model_config.model = json.dumps(model_dict)
generate_task_id = str(uuid.uuid4())
@@ -282,58 +280,28 @@ class CompletionService:
user = cls.get_real_user_instead_of_proxy_obj(user)
- generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={
+ generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'generate_task_id': generate_task_id,
'detached_app_model': app_model,
'app_model_config': app_model_config,
- 'detached_message': message,
- 'pre_prompt': pre_prompt,
+ 'query': message.query,
+ 'inputs': message.inputs,
'detached_user': user,
- 'streaming': streaming
+ 'detached_conversation': None,
+ 'streaming': streaming,
+ 'is_model_config_override': True,
+ 'retriever_from': retriever_from
})
generate_worker_thread.start()
- cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
+ # wait for 10 minutes to close the thread
+ cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
+ generate_task_id)
return cls.compact_response(pubsub, streaming)
- @classmethod
- def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
- app_model_config: AppModelConfig, detached_message: Message, pre_prompt: str,
- detached_user: Union[Account, EndUser], streaming: bool):
- with flask_app.app_context():
- # fixed the state of the model object when it detached from the original session
- user = db.session.merge(detached_user)
- app_model = db.session.merge(detached_app_model)
- message = db.session.merge(detached_message)
-
- try:
- # run
- Completion.generate_more_like_this(
- task_id=generate_task_id,
- app=app_model,
- user=user,
- message=message,
- pre_prompt=pre_prompt,
- app_model_config=app_model_config,
- streaming=streaming
- )
- except ConversationTaskStoppedException:
- pass
- except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
- LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
- ModelCurrentlyNotSupportError) as e:
- PubHandler.pub_error(user, generate_task_id, e)
- except LLMAuthorizationError:
- PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
- except Exception as e:
- logging.exception("Unknown Error in completion")
- PubHandler.pub_error(user, generate_task_id, e)
- finally:
- db.session.commit()
-
@classmethod
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
if user_inputs is None:
diff --git a/api/services/provider_service.py b/api/services/provider_service.py
index 34064d0c33..f9acedf8c2 100644
--- a/api/services/provider_service.py
+++ b/api/services/provider_service.py
@@ -482,6 +482,9 @@ class ProviderService:
'features': []
}
+ if 'mode' in model:
+ valid_model_dict['model_mode'] = model['mode']
+
if 'features' in model:
valid_model_dict['features'] = model['features']
diff --git a/api/tasks/generate_conversation_summary_task.py b/api/tasks/generate_conversation_summary_task.py
deleted file mode 100644
index 791f141d5b..0000000000
--- a/api/tasks/generate_conversation_summary_task.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import logging
-import time
-
-import click
-from celery import shared_task
-from werkzeug.exceptions import NotFound
-
-from core.generator.llm_generator import LLMGenerator
-from core.model_providers.error import LLMError, ProviderTokenNotInitError
-from extensions.ext_database import db
-from models.model import Conversation, Message
-
-
-@shared_task(queue='generation')
-def generate_conversation_summary_task(conversation_id: str):
- """
- Async Generate conversation summary
- :param conversation_id:
-
- Usage: generate_conversation_summary_task.delay(conversation_id)
- """
- logging.info(click.style('Start generate conversation summary: {}'.format(conversation_id), fg='green'))
- start_at = time.perf_counter()
-
- conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
- if not conversation:
- raise NotFound('Conversation not found')
-
- try:
- # get conversation messages count
- history_message_count = conversation.message_count
- if history_message_count >= 5 and not conversation.summary:
- app_model = conversation.app
- if not app_model:
- return
-
- history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
- .order_by(Message.created_at.asc()).all()
-
- conversation.summary = LLMGenerator.generate_conversation_summary(app_model.tenant_id, history_messages)
- db.session.add(conversation)
- db.session.commit()
- except (LLMError, ProviderTokenNotInitError):
- conversation.summary = '[No Summary]'
- db.session.commit()
- pass
- except Exception as e:
- conversation.summary = '[No Summary]'
- db.session.commit()
- logging.exception(e)
-
- end_at = time.perf_counter()
- logging.info(
- click.style('Conversation summary generated: {} latency: {}'.format(conversation_id, end_at - start_at),
- fg='green'))
diff --git a/api/tests/integration_tests/models/llm/test_anthropic_model.py b/api/tests/integration_tests/models/llm/test_anthropic_model.py
index 32013b27aa..f0636f6e79 100644
--- a/api/tests/integration_tests/models/llm/test_anthropic_model.py
+++ b/api/tests/integration_tests/models/llm/test_anthropic_model.py
@@ -44,7 +44,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('claude-2')
rst = model.get_num_tokens([
- PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+ PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 6
diff --git a/api/tests/integration_tests/models/llm/test_azure_openai_model.py b/api/tests/integration_tests/models/llm/test_azure_openai_model.py
index 1df272d1cc..9d289f404d 100644
--- a/api/tests/integration_tests/models/llm/test_azure_openai_model.py
+++ b/api/tests/integration_tests/models/llm/test_azure_openai_model.py
@@ -69,7 +69,7 @@ def test_chat_get_num_tokens(mock_decrypt, mocker):
openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
rst = openai_model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
- PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+ PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 22
diff --git a/api/tests/integration_tests/models/llm/test_baichuan_model.py b/api/tests/integration_tests/models/llm/test_baichuan_model.py
index 15610e1d1d..c70b14ce2b 100644
--- a/api/tests/integration_tests/models/llm/test_baichuan_model.py
+++ b/api/tests/integration_tests/models/llm/test_baichuan_model.py
@@ -48,7 +48,7 @@ def test_chat_get_num_tokens(mock_decrypt):
model = get_mock_model('baichuan2-53b')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
- PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+ PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst > 0
@@ -59,7 +59,7 @@ def test_chat_run(mock_decrypt, mocker):
model = get_mock_model('baichuan2-53b')
messages = [
- PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
+ PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages,
@@ -73,7 +73,7 @@ def test_chat_stream_run(mock_decrypt, mocker):
model = get_mock_model('baichuan2-53b', streaming=True)
messages = [
- PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
+ PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages
diff --git a/api/tests/integration_tests/models/llm/test_huggingface_hub_model.py b/api/tests/integration_tests/models/llm/test_huggingface_hub_model.py
index eda95102c9..2c8c4556bc 100644
--- a/api/tests/integration_tests/models/llm/test_huggingface_hub_model.py
+++ b/api/tests/integration_tests/models/llm/test_huggingface_hub_model.py
@@ -71,7 +71,7 @@ def test_hosted_inference_api_get_num_tokens(mock_decrypt, mock_model_info, mock
mocker
)
rst = model.get_num_tokens([
- PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+ PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5
@@ -88,7 +88,7 @@ def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocke
mocker
)
rst = model.get_num_tokens([
- PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+ PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5
diff --git a/api/tests/integration_tests/models/llm/test_minimax_model.py b/api/tests/integration_tests/models/llm/test_minimax_model.py
index d93f8ad735..43634f3499 100644
--- a/api/tests/integration_tests/models/llm/test_minimax_model.py
+++ b/api/tests/integration_tests/models/llm/test_minimax_model.py
@@ -48,7 +48,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('abab5.5-chat')
rst = model.get_num_tokens([
- PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+ PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5
diff --git a/api/tests/integration_tests/models/llm/test_openai_model.py b/api/tests/integration_tests/models/llm/test_openai_model.py
index 3deeb2f02c..e6044c0bb5 100644
--- a/api/tests/integration_tests/models/llm/test_openai_model.py
+++ b/api/tests/integration_tests/models/llm/test_openai_model.py
@@ -52,7 +52,7 @@ def test_chat_get_num_tokens(mock_decrypt):
openai_model = get_mock_openai_model('gpt-3.5-turbo')
rst = openai_model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
- PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+ PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 22
diff --git a/api/tests/integration_tests/models/llm/test_openllm_model.py b/api/tests/integration_tests/models/llm/test_openllm_model.py
index d515f35048..8a70e6ace4 100644
--- a/api/tests/integration_tests/models/llm/test_openllm_model.py
+++ b/api/tests/integration_tests/models/llm/test_openllm_model.py
@@ -55,7 +55,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt, mocker):
model = get_mock_model('facebook/opt-125m', mocker)
rst = model.get_num_tokens([
- PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+ PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5
diff --git a/api/tests/integration_tests/models/llm/test_replicate_model.py b/api/tests/integration_tests/models/llm/test_replicate_model.py
index 13efc19881..d5e55def41 100644
--- a/api/tests/integration_tests/models/llm/test_replicate_model.py
+++ b/api/tests/integration_tests/models/llm/test_replicate_model.py
@@ -58,7 +58,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt, mocker):
model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker)
rst = model.get_num_tokens([
- PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+ PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 7
diff --git a/api/tests/integration_tests/models/llm/test_spark_model.py b/api/tests/integration_tests/models/llm/test_spark_model.py
index d07bfb279a..e6fa45f0cb 100644
--- a/api/tests/integration_tests/models/llm/test_spark_model.py
+++ b/api/tests/integration_tests/models/llm/test_spark_model.py
@@ -52,7 +52,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('spark')
rst = model.get_num_tokens([
- PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+ PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 6
diff --git a/api/tests/integration_tests/models/llm/test_tongyi_model.py b/api/tests/integration_tests/models/llm/test_tongyi_model.py
index 8c34497ac7..b448c29f47 100644
--- a/api/tests/integration_tests/models/llm/test_tongyi_model.py
+++ b/api/tests/integration_tests/models/llm/test_tongyi_model.py
@@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('qwen-turbo')
rst = model.get_num_tokens([
- PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+ PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5
diff --git a/api/tests/integration_tests/models/llm/test_wenxin_model.py b/api/tests/integration_tests/models/llm/test_wenxin_model.py
index 29a0de3262..8cc4779160 100644
--- a/api/tests/integration_tests/models/llm/test_wenxin_model.py
+++ b/api/tests/integration_tests/models/llm/test_wenxin_model.py
@@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('ernie-bot')
rst = model.get_num_tokens([
- PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+ PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5
diff --git a/api/tests/integration_tests/models/llm/test_xinference_model.py b/api/tests/integration_tests/models/llm/test_xinference_model.py
index aab075fae2..01d5fcdd9f 100644
--- a/api/tests/integration_tests/models/llm/test_xinference_model.py
+++ b/api/tests/integration_tests/models/llm/test_xinference_model.py
@@ -57,7 +57,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt, mocker):
model = get_mock_model('llama-2-chat', mocker)
rst = model.get_num_tokens([
- PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+ PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5
diff --git a/api/tests/integration_tests/models/llm/test_zhipuai_model.py b/api/tests/integration_tests/models/llm/test_zhipuai_model.py
index 4bc47bec9b..8f1a60e8f2 100644
--- a/api/tests/integration_tests/models/llm/test_zhipuai_model.py
+++ b/api/tests/integration_tests/models/llm/test_zhipuai_model.py
@@ -46,7 +46,7 @@ def test_chat_get_num_tokens(mock_decrypt):
model = get_mock_model('chatglm_lite')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
- PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+ PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst > 0
@@ -57,7 +57,7 @@ def test_chat_run(mock_decrypt, mocker):
model = get_mock_model('chatglm_lite')
messages = [
- PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
+ PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages,
@@ -71,7 +71,7 @@ def test_chat_stream_run(mock_decrypt, mocker):
model = get_mock_model('chatglm_lite', streaming=True)
messages = [
- PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
+ PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages
diff --git a/api/tests/unit_tests/model_providers/fake_model_provider.py b/api/tests/unit_tests/model_providers/fake_model_provider.py
index 4e14d5924e..35c44061dc 100644
--- a/api/tests/unit_tests/model_providers/fake_model_provider.py
+++ b/api/tests/unit_tests/model_providers/fake_model_provider.py
@@ -1,7 +1,7 @@
from typing import Type
from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
+from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, ModelMode
from core.model_providers.models.llm.openai_model import OpenAIModel
from core.model_providers.providers.base import BaseModelProvider
@@ -12,7 +12,10 @@ class FakeModelProvider(BaseModelProvider):
return 'fake'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
- return [{'id': 'test_model', 'name': 'Test Model'}]
+ return [{'id': 'test_model', 'name': 'Test Model', 'mode': 'completion'}]
+
+ def _get_text_generation_model_mode(self, model_name) -> str:
+ return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
return OpenAIModel
diff --git a/api/tests/unit_tests/model_providers/test_base_model_provider.py b/api/tests/unit_tests/model_providers/test_base_model_provider.py
index 7d6e56eb0a..534599c319 100644
--- a/api/tests/unit_tests/model_providers/test_base_model_provider.py
+++ b/api/tests/unit_tests/model_providers/test_base_model_provider.py
@@ -24,7 +24,7 @@ def test_get_supported_model_list(mocker):
provider = FakeModelProvider(provider=Provider())
result = provider.get_supported_model_list(ModelType.TEXT_GENERATION)
- assert result == [{'id': 'test_model', 'name': 'test_model'}]
+ assert result == [{'id': 'test_model', 'name': 'test_model', 'mode': 'completion'}]
def test_check_quota_over_limit(mocker):