diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
index d3480d2c47..29b516ac02 100644
--- a/api/core/prompt/advanced_prompt_transform.py
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -31,7 +31,8 @@ class AdvancedPromptTransform(PromptTransform):
                          context: Optional[str],
                          memory_config: Optional[MemoryConfig],
                          memory: Optional[TokenBufferMemory],
-                         model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
+                         model_config: ModelConfigWithCredentialsEntity,
+                         query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
         inputs = {key: str(value) for key, value in inputs.items()}
 
         prompt_messages = []
@@ -53,6 +54,7 @@ class AdvancedPromptTransform(PromptTransform):
                 prompt_template=prompt_template,
                 inputs=inputs,
                 query=query,
+                query_prompt_template=query_prompt_template,
                 files=files,
                 context=context,
                 memory_config=memory_config,
@@ -121,7 +123,8 @@ class AdvancedPromptTransform(PromptTransform):
                                              context: Optional[str],
                                              memory_config: Optional[MemoryConfig],
                                              memory: Optional[TokenBufferMemory],
-                                             model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
+                                             model_config: ModelConfigWithCredentialsEntity,
+                                             query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
         """
         Get chat model prompt messages.
         """
@@ -148,6 +151,20 @@ class AdvancedPromptTransform(PromptTransform):
             elif prompt_item.role == PromptMessageRole.ASSISTANT:
                 prompt_messages.append(AssistantPromptMessage(content=prompt))
 
+        if query and query_prompt_template:
+            prompt_template = PromptTemplateParser(
+                template=query_prompt_template,
+                with_variable_tmpl=self.with_variable_tmpl
+            )
+            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+            prompt_inputs['#sys.query#'] = query
+
+            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+
+            query = prompt_template.format(
+                prompt_inputs
+            )
+
         if memory and memory_config:
             prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py
index 97ac2e3e2a..2be00bdf0e 100644
--- a/api/core/prompt/entities/advanced_prompt_entities.py
+++ b/api/core/prompt/entities/advanced_prompt_entities.py
@@ -40,3 +40,4 @@ class MemoryConfig(BaseModel):
 
     role_prefix: Optional[RolePrefix] = None
     window: WindowConfig
+    query_prompt_template: Optional[str] = None
diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py
index a894e19a61..c8b7f279ab 100644
--- a/api/core/workflow/nodes/llm/llm_node.py
+++ b/api/core/workflow/nodes/llm/llm_node.py
@@ -74,6 +74,7 @@ class LLMNode(BaseNode):
             node_data=node_data,
             query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])
             if node_data.memory else None,
+            query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
             inputs=inputs,
             files=files,
             context=context,
@@ -209,6 +210,17 @@ class LLMNode(BaseNode):
 
             inputs[variable_selector.variable] = variable_value
 
+        memory = node_data.memory
+        if memory and memory.query_prompt_template:
+            query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
+                                        .extract_variable_selectors())
+            for variable_selector in query_variable_selectors:
+                variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
+                if variable_value is None:
+                    raise ValueError(f'Variable {variable_selector.variable} not found')
+
+                inputs[variable_selector.variable] = variable_value
+
         return inputs
 
     def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
@@ -302,7 +314,8 @@ class LLMNode(BaseNode):
 
         return None
 
-    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[
+            ModelInstance, ModelConfigWithCredentialsEntity]:
         """
         Fetch model config
         :param node_data_model: node data model
@@ -407,6 +420,7 @@ class LLMNode(BaseNode):
 
     def _fetch_prompt_messages(self, node_data: LLMNodeData,
                                query: Optional[str],
+                               query_prompt_template: Optional[str],
                                inputs: dict[str, str],
                                files: list[FileVar],
                                context: Optional[str],
@@ -417,6 +431,7 @@ class LLMNode(BaseNode):
         Fetch prompt messages
         :param node_data: node data
         :param query: query
+        :param query_prompt_template: query prompt template
         :param inputs: inputs
         :param files: files
         :param context: context
@@ -433,7 +448,8 @@ class LLMNode(BaseNode):
             context=context,
             memory_config=node_data.memory,
             memory=memory,
-            model_config=model_config
+            model_config=model_config,
+            query_prompt_template=query_prompt_template,
         )
 
         stop = model_config.stop
@@ -539,6 +555,13 @@ class LLMNode(BaseNode):
         for variable_selector in variable_selectors:
             variable_mapping[variable_selector.variable] = variable_selector.value_selector
 
+        memory = node_data.memory
+        if memory and memory.query_prompt_template:
+            query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
+                                        .extract_variable_selectors())
+            for variable_selector in query_variable_selectors:
+                variable_mapping[variable_selector.variable] = variable_selector.value_selector
+
         if node_data.context.enabled:
             variable_mapping['#context#'] = node_data.context.variable_selector
 
diff --git a/web/app/components/base/prompt-editor/constants.tsx b/web/app/components/base/prompt-editor/constants.tsx
index c24a265684..e02963128a 100644
--- a/web/app/components/base/prompt-editor/constants.tsx
+++ b/web/app/components/base/prompt-editor/constants.tsx
@@ -30,6 +30,9 @@ export const checkHasQueryBlock = (text: string) => {
  * {{#1711617514996.sys.query#}} => [sys, query]
  */
 export const getInputVars = (text: string): ValueSelector[] => {
+  if (!text)
+    return []
+
   const allVars = text.match(/{{#([^#]*)#}}/g)
   if (allVars && allVars?.length > 0) {
     // {{#context#}}, {{#query#}} is not input vars
diff --git a/web/app/components/workflow/nodes/_base/components/prompt/editor.tsx b/web/app/components/workflow/nodes/_base/components/prompt/editor.tsx
index 38e56daf2e..f3802d4612 100644
--- a/web/app/components/workflow/nodes/_base/components/prompt/editor.tsx
+++ b/web/app/components/workflow/nodes/_base/components/prompt/editor.tsx
@@ -146,6 +146,7 @@ const Editor: FC<Props> = ({
     {
       const payload = (data as LLMNodeType)
       const isChatModel = payload.model?.mode === 'chat'
       let prompts: string[] = []
-      if (isChatModel)
+      if (isChatModel) {
         prompts = (payload.prompt_template as PromptItem[])?.map(p => p.text) || []
-      else
-        prompts = [(payload.prompt_template as PromptItem).text]
+        if (payload.memory?.query_prompt_template)
+          prompts.push(payload.memory.query_prompt_template)
+      }
+      else { prompts = [(payload.prompt_template as PromptItem).text] }
       const inputVars: ValueSelector[] = matchNotSystemVars(prompts)
       const contextVar = (data as LLMNodeType).context?.variable_selector ? [(data as LLMNodeType).context?.variable_selector] : []
@@ -375,6 +377,8 @@ export const updateNodeVars = (oldNode: Node, oldVarSelector: ValueSelector, new
           text: replaceOldVarInText(prompt.text, oldVarSelector, newVarSelector),
         }
       })
+      if (payload.memory?.query_prompt_template)
+        payload.memory.query_prompt_template = replaceOldVarInText(payload.memory.query_prompt_template, oldVarSelector, newVarSelector)
     }
     else {
       payload.prompt_template = {
diff --git a/web/app/components/workflow/nodes/llm/default.ts b/web/app/components/workflow/nodes/llm/default.ts
index fbe360c4e2..8ad6d86260 100644
--- a/web/app/components/workflow/nodes/llm/default.ts
+++ b/web/app/components/workflow/nodes/llm/default.ts
@@ -50,6 +50,13 @@ const nodeDefault: NodeDefault<LLMNodeType> = {
       if (isPromptyEmpty)
         errorMessages = t(`${i18nPrefix}.fieldRequired`, { field: t('workflow.nodes.llm.prompt') })
     }
+
+    if (!errorMessages && !!payload.memory) {
+      const isChatModel = payload.model.mode === 'chat'
+      // when payload.memory.query_prompt_template is not set, it defaults to {{#sys.query#}}
+      if (isChatModel && !!payload.memory.query_prompt_template && !payload.memory.query_prompt_template.includes('{{#sys.query#}}'))
+        errorMessages = t('workflow.nodes.llm.sysQueryInUser')
+    }
     return {
       isValid: !errorMessages,
       errorMessage: errorMessages,
diff --git a/web/app/components/workflow/nodes/llm/panel.tsx b/web/app/components/workflow/nodes/llm/panel.tsx
index bf68cc76d1..705ab82fd7 100644
--- a/web/app/components/workflow/nodes/llm/panel.tsx
+++ b/web/app/components/workflow/nodes/llm/panel.tsx
@@ -50,7 +50,10 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
     handleContextVarChange,
     filterInputVar,
     filterVar,
+    availableVars,
+    availableNodes,
     handlePromptChange,
+    handleSyeQueryChange,
     handleMemoryChange,
     handleVisionResolutionEnabledChange,
     handleVisionResolutionChange,
@@ -204,19 +207,20 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
                 }
-                value={'{{#sys.query#}}'}
-                onChange={() => { }}
-                readOnly
+                value={inputs.memory.query_prompt_template || '{{#sys.query#}}'}
+                onChange={handleSyeQueryChange}
+                readOnly={readOnly}
                 isShowContext={false}
                 isChatApp
-                isChatModel={false}
-                hasSetBlockStatus={{
-                  query: false,
-                  history: true,
-                  context: true,
-                }}
-                availableNodes={[startNode!]}
+                isChatModel
+                hasSetBlockStatus={hasSetBlockStatus}
+                nodesOutputVars={availableVars}
+                availableNodes={availableNodes}
               />
+
+              {inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && (
+                <div>{t(`${i18nPrefix}.sysQueryInUser`)}</div>
+              )}
             )}
diff --git a/web/app/components/workflow/nodes/llm/use-config.ts b/web/app/components/workflow/nodes/llm/use-config.ts
index 5efb49aa9d..8ccbb50cca 100644
--- a/web/app/components/workflow/nodes/llm/use-config.ts
+++ b/web/app/components/workflow/nodes/llm/use-config.ts
@@ -8,6 +8,7 @@ import {
   useIsChatMode,
   useNodesReadOnly,
 } from '../../hooks'
+import useAvailableVarList from '../_base/hooks/use-available-var-list'
 import type { LLMNodeType } from './types'
 import { Resolution } from '@/types/app'
 import { useModelListAndDefaultModelAndCurrentProviderAndModel, useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
@@ -206,6 +207,24 @@ const useConfig = (id: string, payload: LLMNodeType) => {
     setInputs(newInputs)
   }, [inputs, setInputs])
 
+  const handleSyeQueryChange = useCallback((newQuery: string) => {
+    const newInputs = produce(inputs, (draft) => {
+      if (!draft.memory) {
+        draft.memory = {
+          window: {
+            enabled: false,
+            size: 10,
+          },
+          query_prompt_template: newQuery,
+        }
+      }
+      else {
+        draft.memory.query_prompt_template = newQuery
+      }
+    })
+    setInputs(newInputs)
+  }, [inputs, setInputs])
+
   const handleVisionResolutionEnabledChange = useCallback((enabled: boolean) => {
     const newInputs = produce(inputs, (draft) => {
       if (!draft.vision) {
@@ -248,6 +267,14 @@ const useConfig = (id: string, payload: LLMNodeType) => {
     return [VarType.arrayObject, VarType.array, VarType.string].includes(varPayload.type)
   }, [])
 
+  const {
+    availableVars,
+    availableNodes,
+  } = useAvailableVarList(id, {
+    onlyLeafNodeVar: false,
+    filterVar,
+  })
+
   // single run
   const {
     isShowSingleRun,
@@ -322,8 +349,10 @@ const useConfig = (id: string, payload: LLMNodeType) => {
   const allVarStrArr = (() => {
     const arr = isChatModel ? (inputs.prompt_template as PromptItem[]).map(item => item.text) : [(inputs.prompt_template as PromptItem).text]
 
-    if (isChatMode && isChatModel && !!inputs.memory)
+    if (isChatMode && isChatModel && !!inputs.memory) {
       arr.push('{{#sys.query#}}')
+      arr.push(inputs.memory.query_prompt_template)
+    }
 
     return arr
   })()
@@ -346,8 +375,11 @@ const useConfig = (id: string, payload: LLMNodeType) => {
     handleContextVarChange,
     filterInputVar,
     filterVar,
+    availableVars,
+    availableNodes,
     handlePromptChange,
     handleMemoryChange,
+    handleSyeQueryChange,
     handleVisionResolutionEnabledChange,
     handleVisionResolutionChange,
     isShowSingleRun,
diff --git a/web/app/components/workflow/types.ts b/web/app/components/workflow/types.ts
index 82fb1c3168..aa4238b27c 100644
--- a/web/app/components/workflow/types.ts
+++ b/web/app/components/workflow/types.ts
@@ -143,6 +143,7 @@ export type Memory = {
     enabled: boolean
     size: number | string | null
   }
+  query_prompt_template: string
 }
 
 export enum VarType {
diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts
index 949fcb061a..0d1ced6120 100644
--- a/web/i18n/en-US/workflow.ts
+++ b/web/i18n/en-US/workflow.ts
@@ -204,6 +204,7 @@ const translation = {
       singleRun: {
         variable: 'Variable',
       },
+      sysQueryInUser: 'sys.query in user message is required',
     },
     knowledgeRetrieval: {
       queryVariable: 'Query Variable',
diff --git a/web/i18n/zh-Hans/workflow.ts b/web/i18n/zh-Hans/workflow.ts
index 2805c59fa4..2771f8c1ca 100644
--- a/web/i18n/zh-Hans/workflow.ts
+++ b/web/i18n/zh-Hans/workflow.ts
@@ -204,6 +204,7 @@ const translation = {
       singleRun: {
         variable: '变量',
       },
+      sysQueryInUser: 'user message 中必须包含 sys.query',
    },
    knowledgeRetrieval: {
      queryVariable: '查询变量',
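
The net effect of the backend half of this patch is that the user turn sent to a chat model is no longer always the raw `sys.query`: when `memory.query_prompt_template` is set, that template is rendered with the node's input variables plus a `#sys.query#` key, and the result becomes the user message. Below is a minimal sketch of that rendering step for illustration only — `format_query` is a hypothetical stand-in with a simplified placeholder syntax, not Dify's actual `PromptTemplateParser`:

```python
import re

def format_query(query_prompt_template: str, query: str, inputs: dict[str, str]) -> str:
    """Render the user-message template, substituting {{#sys.query#}}
    and any node input variables (simplified stand-in for the logic
    added to _get_chat_model_prompt_messages above)."""
    variables = {**inputs, '#sys.query#': query}

    def replace(match: re.Match) -> str:
        key = match.group(1)
        # leave unknown placeholders untouched
        return variables.get(key, match.group(0))

    # matches {{#sys.query#}} as well as plain {{variable}} placeholders
    return re.sub(r'\{\{(#?[a-zA-Z0-9_.]+#?)\}\}', replace, query_prompt_template)

# e.g. a template that wraps the user's question with extra instructions:
print(format_query(
    'Answer using the notes below.\nNotes: {{notes}}\nQuestion: {{#sys.query#}}',
    'How do I reset my password?',
    {'notes': 'Passwords are reset under Settings > Security.'},
))
```

This also explains the validation added in `default.ts` and mirrored in `panel.tsx`: a custom template that omits `{{#sys.query#}}` would silently drop the user's question from the prompt, so the node is marked invalid with the `sysQueryInUser` message.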