feat: query prompt template support in chatflow (#3791)

Co-authored-by: Joel <iamjoel007@gmail.com>
takatost 2024-04-25 18:01:53 +08:00 committed by GitHub
parent 80b9507e7a
commit 12435774ca
12 changed files with 113 additions and 18 deletions

View File

@@ -31,7 +31,8 @@ class AdvancedPromptTransform(PromptTransform):
                             context: Optional[str],
                             memory_config: Optional[MemoryConfig],
                             memory: Optional[TokenBufferMemory],
-                            model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
+                            model_config: ModelConfigWithCredentialsEntity,
+                            query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
        inputs = {key: str(value) for key, value in inputs.items()}

        prompt_messages = []
@@ -53,6 +54,7 @@ class AdvancedPromptTransform(PromptTransform):
                prompt_template=prompt_template,
                inputs=inputs,
                query=query,
+               query_prompt_template=query_prompt_template,
                files=files,
                context=context,
                memory_config=memory_config,
@@ -121,7 +123,8 @@ class AdvancedPromptTransform(PromptTransform):
                                       context: Optional[str],
                                       memory_config: Optional[MemoryConfig],
                                       memory: Optional[TokenBufferMemory],
-                                      model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
+                                      model_config: ModelConfigWithCredentialsEntity,
+                                      query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
        """
        Get chat model prompt messages.
        """
@@ -148,6 +151,20 @@ class AdvancedPromptTransform(PromptTransform):
            elif prompt_item.role == PromptMessageRole.ASSISTANT:
                prompt_messages.append(AssistantPromptMessage(content=prompt))

+       if query and query_prompt_template:
+           prompt_template = PromptTemplateParser(
+               template=query_prompt_template,
+               with_variable_tmpl=self.with_variable_tmpl
+           )
+           prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+
+           prompt_inputs['#sys.query#'] = query
+
+           prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+
+           query = prompt_template.format(
+               prompt_inputs
+           )
+
        if memory and memory_config:
            prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
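In this chat-model path, the new query_prompt_template lets the user query be rewritten through a template before it is appended as the final user message. A rough, self-contained sketch of that substitution step, using a simplified regex stand-in for PromptTemplateParser (the real parser also handles context injection and the with_variable_tmpl syntax; the function name here is hypothetical):

import re
from typing import Optional

def apply_query_prompt_template(query: str,
                                query_prompt_template: Optional[str],
                                inputs: dict[str, str]) -> str:
    # Simplified stand-in for PromptTemplateParser: fills {{#sys.query#}} and
    # plain {{variable}} placeholders from the node inputs.
    if not (query and query_prompt_template):
        return query

    values = dict(inputs)
    values['#sys.query#'] = query

    def replace(match: re.Match) -> str:
        key = match.group(1)
        # Leave unknown placeholders untouched rather than failing.
        return values.get(key, match.group(0))

    return re.sub(r'\{\{([^{}]+)\}\}', replace, query_prompt_template)

# Example: wrap the raw question with extra instructions before it is sent.
print(apply_query_prompt_template(
    query='What is RAG?',
    query_prompt_template='Answer using the glossary for {{product}}: {{#sys.query#}}',
    inputs={'product': 'Dify'},
))
# -> Answer using the glossary for Dify: What is RAG?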

View File

@@ -40,3 +40,4 @@ class MemoryConfig(BaseModel):
    role_prefix: Optional[RolePrefix] = None
    window: WindowConfig
+   query_prompt_template: Optional[str] = None
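Because the new field defaults to None, existing memory configurations keep validating unchanged. A quick illustration with a pared-down pydantic model (the WindowConfig stub and reduced field set are simplifying assumptions, not the full entity):

from typing import Optional
from pydantic import BaseModel

class WindowConfig(BaseModel):
    # Pared-down stand-in for the real WindowConfig entity
    enabled: bool = False
    size: Optional[int] = None

class MemoryConfig(BaseModel):
    window: WindowConfig
    # None means the query is used as-is, i.e. the default {{#sys.query#}}
    query_prompt_template: Optional[str] = None

# Older payloads that omit query_prompt_template still validate:
legacy = MemoryConfig(window=WindowConfig(enabled=False, size=10))
assert legacy.query_prompt_template is None

# Newer payloads can wrap the query with extra instructions:
templated = MemoryConfig(
    window=WindowConfig(enabled=True, size=10),
    query_prompt_template='Context: {{#context#}}\n\nQuestion: {{#sys.query#}}',
)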

View File

@@ -74,6 +74,7 @@ class LLMNode(BaseNode):
            node_data=node_data,
            query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])
            if node_data.memory else None,
+           query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
            inputs=inputs,
            files=files,
            context=context,
@@ -209,6 +210,17 @@ class LLMNode(BaseNode):
            inputs[variable_selector.variable] = variable_value

+       memory = node_data.memory
+       if memory and memory.query_prompt_template:
+           query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
+                                       .extract_variable_selectors())
+           for variable_selector in query_variable_selectors:
+               variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
+               if variable_value is None:
+                   raise ValueError(f'Variable {variable_selector.variable} not found')
+
+               inputs[variable_selector.variable] = variable_value
+
        return inputs

    def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
@@ -302,7 +314,8 @@ class LLMNode(BaseNode):
        return None

-   def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+   def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[
+       ModelInstance, ModelConfigWithCredentialsEntity]:
        """
        Fetch model config
        :param node_data_model: node data model
@@ -407,6 +420,7 @@ class LLMNode(BaseNode):
    def _fetch_prompt_messages(self, node_data: LLMNodeData,
                               query: Optional[str],
+                              query_prompt_template: Optional[str],
                               inputs: dict[str, str],
                               files: list[FileVar],
                               context: Optional[str],
@@ -417,6 +431,7 @@ class LLMNode(BaseNode):
        Fetch prompt messages
        :param node_data: node data
        :param query: query
+       :param query_prompt_template: query prompt template
        :param inputs: inputs
        :param files: files
        :param context: context
@@ -433,7 +448,8 @@ class LLMNode(BaseNode):
            context=context,
            memory_config=node_data.memory,
            memory=memory,
-           model_config=model_config
+           model_config=model_config,
+           query_prompt_template=query_prompt_template,
        )

        stop = model_config.stop
@@ -539,6 +555,13 @@ class LLMNode(BaseNode):
        for variable_selector in variable_selectors:
            variable_mapping[variable_selector.variable] = variable_selector.value_selector

+       memory = node_data.memory
+       if memory and memory.query_prompt_template:
+           query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
+                                       .extract_variable_selectors())
+           for variable_selector in query_variable_selectors:
+               variable_mapping[variable_selector.variable] = variable_selector.value_selector
+
        if node_data.context.enabled:
            variable_mapping['#context#'] = node_data.context.variable_selector
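Both the input-fetching and variable-mapping paths above depend on pulling {{#node_id.variable#}} selectors out of the query template. A minimal sketch of that extraction with a regex stand-in for VariableTemplateParser.extract_variable_selectors() (the real parser returns VariableSelector objects rather than plain lists, so this is only an approximation):

import re

def extract_variable_selectors(template: str) -> list[list[str]]:
    # Turns '{{#1711617514996.sys.query#}}' into ['1711617514996', 'sys', 'query'].
    selectors = []
    for match in re.findall(r'\{\{#([^#]+)#\}\}', template or ''):
        selectors.append(match.split('.'))
    return selectors

template = 'Docs for {{#1711617514996.product#}}: {{#sys.query#}}'
print(extract_variable_selectors(template))
# -> [['1711617514996', 'product'], ['sys', 'query']]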

View File

@@ -30,6 +30,9 @@ export const checkHasQueryBlock = (text: string) => {
 * {{#1711617514996.sys.query#}} => [sys, query]
 */
export const getInputVars = (text: string): ValueSelector[] => {
+  if (!text)
+    return []
+
  const allVars = text.match(/{{#([^#]*)#}}/g)
  if (allVars && allVars?.length > 0) {
    // {{#context#}}, {{#query#}} is not input vars

View File

@@ -146,6 +146,7 @@ const Editor: FC<Props> = ({
        <PromptEditor
          instanceId={instanceId}
          compact
+         className='min-h-[56px]'
          style={isExpand ? { height: editorExpandHeight - 5 } : {}}
          value={value}
          contextBlock={{

View File

@@ -272,10 +272,12 @@ export const getNodeUsedVars = (node: Node): ValueSelector[] => {
      const payload = (data as LLMNodeType)
      const isChatModel = payload.model?.mode === 'chat'
      let prompts: string[] = []
-     if (isChatModel)
-       prompts = (payload.prompt_template as PromptItem[])?.map(p => p.text) || []
-     else
-       prompts = [(payload.prompt_template as PromptItem).text]
+     if (isChatModel) {
+       prompts = (payload.prompt_template as PromptItem[])?.map(p => p.text) || []
+       if (payload.memory?.query_prompt_template)
+         prompts.push(payload.memory.query_prompt_template)
+     }
+     else { prompts = [(payload.prompt_template as PromptItem).text] }

      const inputVars: ValueSelector[] = matchNotSystemVars(prompts)
      const contextVar = (data as LLMNodeType).context?.variable_selector ? [(data as LLMNodeType).context?.variable_selector] : []
@@ -375,6 +377,8 @@ export const updateNodeVars = (oldNode: Node, oldVarSelector: ValueSelector, new
            text: replaceOldVarInText(prompt.text, oldVarSelector, newVarSelector),
          }
        })
+       if (payload.memory?.query_prompt_template)
+         payload.memory.query_prompt_template = replaceOldVarInText(payload.memory.query_prompt_template, oldVarSelector, newVarSelector)
      }
      else {
        payload.prompt_template = {

View File

@@ -50,6 +50,13 @@ const nodeDefault: NodeDefault<LLMNodeType> = {
      if (isPromptyEmpty)
        errorMessages = t(`${i18nPrefix}.fieldRequired`, { field: t('workflow.nodes.llm.prompt') })
    }
+   if (!errorMessages && !!payload.memory) {
+     const isChatModel = payload.model.mode === 'chat'
+     // if query_prompt_template is not set, it falls back to the default {{#sys.query#}}
+     if (isChatModel && !!payload.memory.query_prompt_template && !payload.memory.query_prompt_template.includes('{{#sys.query#}}'))
+       errorMessages = t('workflow.nodes.llm.sysQueryInUser')
+   }
    return {
      isValid: !errorMessages,
      errorMessage: errorMessages,
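The validation above enforces one invariant: when a chat-model node sets a custom query template, the template must still contain the {{#sys.query#}} placeholder, otherwise the user's question never reaches the model. The same rule expressed as a plain predicate, sketched in Python for clarity (not part of the commit; names are illustrative):

from typing import Optional

SYS_QUERY_PLACEHOLDER = '{{#sys.query#}}'

def is_query_template_valid(query_prompt_template: Optional[str], is_chat_model: bool) -> bool:
    # An unset/empty template falls back to the default '{{#sys.query#}}', so it is valid.
    if not query_prompt_template:
        return True
    # A custom template on a chat model must keep the sys.query placeholder.
    if is_chat_model:
        return SYS_QUERY_PLACEHOLDER in query_prompt_template
    return True

assert is_query_template_valid(None, is_chat_model=True)
assert is_query_template_valid('Rewrite the question: {{#sys.query#}}', is_chat_model=True)
assert not is_query_template_valid('Rewrite the question.', is_chat_model=True)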

View File

@@ -50,7 +50,10 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
    handleContextVarChange,
    filterInputVar,
    filterVar,
+   availableVars,
+   availableNodes,
    handlePromptChange,
+   handleSyeQueryChange,
    handleMemoryChange,
    handleVisionResolutionEnabledChange,
    handleVisionResolutionChange,
@@ -204,19 +207,20 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
                  <HelpCircle className='w-3.5 h-3.5 text-gray-400' />
                </TooltipPlus>
              </div>}
-             value={'{{#sys.query#}}'}
-             onChange={() => { }}
-             readOnly
+             value={inputs.memory.query_prompt_template || '{{#sys.query#}}'}
+             onChange={handleSyeQueryChange}
+             readOnly={readOnly}
              isShowContext={false}
              isChatApp
-             isChatModel={false}
-             hasSetBlockStatus={{
-               query: false,
-               history: true,
-               context: true,
-             }}
-             availableNodes={[startNode!]}
+             isChatModel
+             hasSetBlockStatus={hasSetBlockStatus}
+             nodesOutputVars={availableVars}
+             availableNodes={availableNodes}
            />
+           {inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && (
+             <div className='leading-[18px] text-xs font-normal text-[#DC6803]'>{t(`${i18nPrefix}.sysQueryInUser`)}</div>
+           )}
          </div>
        </div>
      )}

View File

@@ -8,6 +8,7 @@ import {
  useIsChatMode,
  useNodesReadOnly,
} from '../../hooks'
+import useAvailableVarList from '../_base/hooks/use-available-var-list'
import type { LLMNodeType } from './types'
import { Resolution } from '@/types/app'
import { useModelListAndDefaultModelAndCurrentProviderAndModel, useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
@@ -206,6 +207,24 @@ const useConfig = (id: string, payload: LLMNodeType) => {
    setInputs(newInputs)
  }, [inputs, setInputs])

+  const handleSyeQueryChange = useCallback((newQuery: string) => {
+    const newInputs = produce(inputs, (draft) => {
+      if (!draft.memory) {
+        draft.memory = {
+          window: {
+            enabled: false,
+            size: 10,
+          },
+          query_prompt_template: newQuery,
+        }
+      }
+      else {
+        draft.memory.query_prompt_template = newQuery
+      }
+    })
+    setInputs(newInputs)
+  }, [inputs, setInputs])
+
  const handleVisionResolutionEnabledChange = useCallback((enabled: boolean) => {
    const newInputs = produce(inputs, (draft) => {
      if (!draft.vision) {
@@ -248,6 +267,14 @@ const useConfig = (id: string, payload: LLMNodeType) => {
    return [VarType.arrayObject, VarType.array, VarType.string].includes(varPayload.type)
  }, [])

+  const {
+    availableVars,
+    availableNodes,
+  } = useAvailableVarList(id, {
+    onlyLeafNodeVar: false,
+    filterVar,
+  })
+
  // single run
  const {
    isShowSingleRun,
@@ -322,8 +349,10 @@ const useConfig = (id: string, payload: LLMNodeType) => {
    const allVarStrArr = (() => {
      const arr = isChatModel ? (inputs.prompt_template as PromptItem[]).map(item => item.text) : [(inputs.prompt_template as PromptItem).text]
-     if (isChatMode && isChatModel && !!inputs.memory)
+     if (isChatMode && isChatModel && !!inputs.memory) {
        arr.push('{{#sys.query#}}')
+       arr.push(inputs.memory.query_prompt_template)
+     }

      return arr
    })()
@@ -346,8 +375,11 @@ const useConfig = (id: string, payload: LLMNodeType) => {
    handleContextVarChange,
    filterInputVar,
    filterVar,
+   availableVars,
+   availableNodes,
    handlePromptChange,
    handleMemoryChange,
+   handleSyeQueryChange,
    handleVisionResolutionEnabledChange,
    handleVisionResolutionChange,
    isShowSingleRun,

View File

@@ -143,6 +143,7 @@ export type Memory = {
    enabled: boolean
    size: number | string | null
  }
+ query_prompt_template: string
}

export enum VarType {

View File

@@ -204,6 +204,7 @@ const translation = {
    singleRun: {
      variable: 'Variable',
    },
+   sysQueryInUser: 'sys.query in user message is required',
  },
  knowledgeRetrieval: {
    queryVariable: 'Query Variable',

View File

@@ -204,6 +204,7 @@ const translation = {
    singleRun: {
      variable: '变量',
    },
+   sysQueryInUser: 'user message 中必须包含 sys.query',
  },
  knowledgeRetrieval: {
    queryVariable: '查询变量',