mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 04:56:02 +08:00
feat: query prompt template support in chatflow (#3791)
Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
parent
80b9507e7a
commit
12435774ca
@ -31,7 +31,8 @@ class AdvancedPromptTransform(PromptTransform):
|
|||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
memory_config: Optional[MemoryConfig],
|
memory_config: Optional[MemoryConfig],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
|
query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
|
||||||
inputs = {key: str(value) for key, value in inputs.items()}
|
inputs = {key: str(value) for key, value in inputs.items()}
|
||||||
|
|
||||||
prompt_messages = []
|
prompt_messages = []
|
||||||
@ -53,6 +54,7 @@ class AdvancedPromptTransform(PromptTransform):
|
|||||||
prompt_template=prompt_template,
|
prompt_template=prompt_template,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
query=query,
|
query=query,
|
||||||
|
query_prompt_template=query_prompt_template,
|
||||||
files=files,
|
files=files,
|
||||||
context=context,
|
context=context,
|
||||||
memory_config=memory_config,
|
memory_config=memory_config,
|
||||||
@ -121,7 +123,8 @@ class AdvancedPromptTransform(PromptTransform):
|
|||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
memory_config: Optional[MemoryConfig],
|
memory_config: Optional[MemoryConfig],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
|
query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
|
||||||
"""
|
"""
|
||||||
Get chat model prompt messages.
|
Get chat model prompt messages.
|
||||||
"""
|
"""
|
||||||
@ -148,6 +151,20 @@ class AdvancedPromptTransform(PromptTransform):
|
|||||||
elif prompt_item.role == PromptMessageRole.ASSISTANT:
|
elif prompt_item.role == PromptMessageRole.ASSISTANT:
|
||||||
prompt_messages.append(AssistantPromptMessage(content=prompt))
|
prompt_messages.append(AssistantPromptMessage(content=prompt))
|
||||||
|
|
||||||
|
if query and query_prompt_template:
|
||||||
|
prompt_template = PromptTemplateParser(
|
||||||
|
template=query_prompt_template,
|
||||||
|
with_variable_tmpl=self.with_variable_tmpl
|
||||||
|
)
|
||||||
|
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||||
|
prompt_inputs['#sys.query#'] = query
|
||||||
|
|
||||||
|
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
|
||||||
|
|
||||||
|
query = prompt_template.format(
|
||||||
|
prompt_inputs
|
||||||
|
)
|
||||||
|
|
||||||
if memory and memory_config:
|
if memory and memory_config:
|
||||||
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
|
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
|
||||||
|
|
||||||
|
@ -40,3 +40,4 @@ class MemoryConfig(BaseModel):
|
|||||||
|
|
||||||
role_prefix: Optional[RolePrefix] = None
|
role_prefix: Optional[RolePrefix] = None
|
||||||
window: WindowConfig
|
window: WindowConfig
|
||||||
|
query_prompt_template: Optional[str] = None
|
||||||
|
@ -74,6 +74,7 @@ class LLMNode(BaseNode):
|
|||||||
node_data=node_data,
|
node_data=node_data,
|
||||||
query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])
|
query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])
|
||||||
if node_data.memory else None,
|
if node_data.memory else None,
|
||||||
|
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
files=files,
|
files=files,
|
||||||
context=context,
|
context=context,
|
||||||
@ -209,6 +210,17 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
inputs[variable_selector.variable] = variable_value
|
inputs[variable_selector.variable] = variable_value
|
||||||
|
|
||||||
|
memory = node_data.memory
|
||||||
|
if memory and memory.query_prompt_template:
|
||||||
|
query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
|
||||||
|
.extract_variable_selectors())
|
||||||
|
for variable_selector in query_variable_selectors:
|
||||||
|
variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
|
||||||
|
if variable_value is None:
|
||||||
|
raise ValueError(f'Variable {variable_selector.variable} not found')
|
||||||
|
|
||||||
|
inputs[variable_selector.variable] = variable_value
|
||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
|
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
|
||||||
@ -302,7 +314,8 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[
|
||||||
|
ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||||
"""
|
"""
|
||||||
Fetch model config
|
Fetch model config
|
||||||
:param node_data_model: node data model
|
:param node_data_model: node data model
|
||||||
@ -407,6 +420,7 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
def _fetch_prompt_messages(self, node_data: LLMNodeData,
|
def _fetch_prompt_messages(self, node_data: LLMNodeData,
|
||||||
query: Optional[str],
|
query: Optional[str],
|
||||||
|
query_prompt_template: Optional[str],
|
||||||
inputs: dict[str, str],
|
inputs: dict[str, str],
|
||||||
files: list[FileVar],
|
files: list[FileVar],
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
@ -417,6 +431,7 @@ class LLMNode(BaseNode):
|
|||||||
Fetch prompt messages
|
Fetch prompt messages
|
||||||
:param node_data: node data
|
:param node_data: node data
|
||||||
:param query: query
|
:param query: query
|
||||||
|
:param query_prompt_template: query prompt template
|
||||||
:param inputs: inputs
|
:param inputs: inputs
|
||||||
:param files: files
|
:param files: files
|
||||||
:param context: context
|
:param context: context
|
||||||
@ -433,7 +448,8 @@ class LLMNode(BaseNode):
|
|||||||
context=context,
|
context=context,
|
||||||
memory_config=node_data.memory,
|
memory_config=node_data.memory,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
model_config=model_config
|
model_config=model_config,
|
||||||
|
query_prompt_template=query_prompt_template,
|
||||||
)
|
)
|
||||||
stop = model_config.stop
|
stop = model_config.stop
|
||||||
|
|
||||||
@ -539,6 +555,13 @@ class LLMNode(BaseNode):
|
|||||||
for variable_selector in variable_selectors:
|
for variable_selector in variable_selectors:
|
||||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||||
|
|
||||||
|
memory = node_data.memory
|
||||||
|
if memory and memory.query_prompt_template:
|
||||||
|
query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
|
||||||
|
.extract_variable_selectors())
|
||||||
|
for variable_selector in query_variable_selectors:
|
||||||
|
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||||
|
|
||||||
if node_data.context.enabled:
|
if node_data.context.enabled:
|
||||||
variable_mapping['#context#'] = node_data.context.variable_selector
|
variable_mapping['#context#'] = node_data.context.variable_selector
|
||||||
|
|
||||||
|
@ -30,6 +30,9 @@ export const checkHasQueryBlock = (text: string) => {
|
|||||||
* {{#1711617514996.sys.query#}} => [sys, query]
|
* {{#1711617514996.sys.query#}} => [sys, query]
|
||||||
*/
|
*/
|
||||||
export const getInputVars = (text: string): ValueSelector[] => {
|
export const getInputVars = (text: string): ValueSelector[] => {
|
||||||
|
if (!text)
|
||||||
|
return []
|
||||||
|
|
||||||
const allVars = text.match(/{{#([^#]*)#}}/g)
|
const allVars = text.match(/{{#([^#]*)#}}/g)
|
||||||
if (allVars && allVars?.length > 0) {
|
if (allVars && allVars?.length > 0) {
|
||||||
// {{#context#}}, {{#query#}} is not input vars
|
// {{#context#}}, {{#query#}} is not input vars
|
||||||
|
@ -146,6 +146,7 @@ const Editor: FC<Props> = ({
|
|||||||
<PromptEditor
|
<PromptEditor
|
||||||
instanceId={instanceId}
|
instanceId={instanceId}
|
||||||
compact
|
compact
|
||||||
|
className='min-h-[56px]'
|
||||||
style={isExpand ? { height: editorExpandHeight - 5 } : {}}
|
style={isExpand ? { height: editorExpandHeight - 5 } : {}}
|
||||||
value={value}
|
value={value}
|
||||||
contextBlock={{
|
contextBlock={{
|
||||||
|
@ -272,10 +272,12 @@ export const getNodeUsedVars = (node: Node): ValueSelector[] => {
|
|||||||
const payload = (data as LLMNodeType)
|
const payload = (data as LLMNodeType)
|
||||||
const isChatModel = payload.model?.mode === 'chat'
|
const isChatModel = payload.model?.mode === 'chat'
|
||||||
let prompts: string[] = []
|
let prompts: string[] = []
|
||||||
if (isChatModel)
|
if (isChatModel) {
|
||||||
prompts = (payload.prompt_template as PromptItem[])?.map(p => p.text) || []
|
prompts = (payload.prompt_template as PromptItem[])?.map(p => p.text) || []
|
||||||
else
|
if (payload.memory?.query_prompt_template)
|
||||||
prompts = [(payload.prompt_template as PromptItem).text]
|
prompts.push(payload.memory.query_prompt_template)
|
||||||
|
}
|
||||||
|
else { prompts = [(payload.prompt_template as PromptItem).text] }
|
||||||
|
|
||||||
const inputVars: ValueSelector[] = matchNotSystemVars(prompts)
|
const inputVars: ValueSelector[] = matchNotSystemVars(prompts)
|
||||||
const contextVar = (data as LLMNodeType).context?.variable_selector ? [(data as LLMNodeType).context?.variable_selector] : []
|
const contextVar = (data as LLMNodeType).context?.variable_selector ? [(data as LLMNodeType).context?.variable_selector] : []
|
||||||
@ -375,6 +377,8 @@ export const updateNodeVars = (oldNode: Node, oldVarSelector: ValueSelector, new
|
|||||||
text: replaceOldVarInText(prompt.text, oldVarSelector, newVarSelector),
|
text: replaceOldVarInText(prompt.text, oldVarSelector, newVarSelector),
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
if (payload.memory?.query_prompt_template)
|
||||||
|
payload.memory.query_prompt_template = replaceOldVarInText(payload.memory.query_prompt_template, oldVarSelector, newVarSelector)
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
payload.prompt_template = {
|
payload.prompt_template = {
|
||||||
|
@ -50,6 +50,13 @@ const nodeDefault: NodeDefault<LLMNodeType> = {
|
|||||||
if (isPromptyEmpty)
|
if (isPromptyEmpty)
|
||||||
errorMessages = t(`${i18nPrefix}.fieldRequired`, { field: t('workflow.nodes.llm.prompt') })
|
errorMessages = t(`${i18nPrefix}.fieldRequired`, { field: t('workflow.nodes.llm.prompt') })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!errorMessages && !!payload.memory) {
|
||||||
|
const isChatModel = payload.model.mode === 'chat'
|
||||||
|
// payload.memory.query_prompt_template not pass is default: {{#sys.query#}}
|
||||||
|
if (isChatModel && !!payload.memory.query_prompt_template && !payload.memory.query_prompt_template.includes('{{#sys.query#}}'))
|
||||||
|
errorMessages = t('workflow.nodes.llm.sysQueryInUser')
|
||||||
|
}
|
||||||
return {
|
return {
|
||||||
isValid: !errorMessages,
|
isValid: !errorMessages,
|
||||||
errorMessage: errorMessages,
|
errorMessage: errorMessages,
|
||||||
|
@ -50,7 +50,10 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
|
|||||||
handleContextVarChange,
|
handleContextVarChange,
|
||||||
filterInputVar,
|
filterInputVar,
|
||||||
filterVar,
|
filterVar,
|
||||||
|
availableVars,
|
||||||
|
availableNodes,
|
||||||
handlePromptChange,
|
handlePromptChange,
|
||||||
|
handleSyeQueryChange,
|
||||||
handleMemoryChange,
|
handleMemoryChange,
|
||||||
handleVisionResolutionEnabledChange,
|
handleVisionResolutionEnabledChange,
|
||||||
handleVisionResolutionChange,
|
handleVisionResolutionChange,
|
||||||
@ -204,19 +207,20 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
|
|||||||
<HelpCircle className='w-3.5 h-3.5 text-gray-400' />
|
<HelpCircle className='w-3.5 h-3.5 text-gray-400' />
|
||||||
</TooltipPlus>
|
</TooltipPlus>
|
||||||
</div>}
|
</div>}
|
||||||
value={'{{#sys.query#}}'}
|
value={inputs.memory.query_prompt_template || '{{#sys.query#}}'}
|
||||||
onChange={() => { }}
|
onChange={handleSyeQueryChange}
|
||||||
readOnly
|
readOnly={readOnly}
|
||||||
isShowContext={false}
|
isShowContext={false}
|
||||||
isChatApp
|
isChatApp
|
||||||
isChatModel={false}
|
isChatModel
|
||||||
hasSetBlockStatus={{
|
hasSetBlockStatus={hasSetBlockStatus}
|
||||||
query: false,
|
nodesOutputVars={availableVars}
|
||||||
history: true,
|
availableNodes={availableNodes}
|
||||||
context: true,
|
|
||||||
}}
|
|
||||||
availableNodes={[startNode!]}
|
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
{inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && (
|
||||||
|
<div className='leading-[18px] text-xs font-normal text-[#DC6803]'>{t(`${i18nPrefix}.sysQueryInUser`)}</div>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
@ -8,6 +8,7 @@ import {
|
|||||||
useIsChatMode,
|
useIsChatMode,
|
||||||
useNodesReadOnly,
|
useNodesReadOnly,
|
||||||
} from '../../hooks'
|
} from '../../hooks'
|
||||||
|
import useAvailableVarList from '../_base/hooks/use-available-var-list'
|
||||||
import type { LLMNodeType } from './types'
|
import type { LLMNodeType } from './types'
|
||||||
import { Resolution } from '@/types/app'
|
import { Resolution } from '@/types/app'
|
||||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel, useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
import { useModelListAndDefaultModelAndCurrentProviderAndModel, useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||||
@ -206,6 +207,24 @@ const useConfig = (id: string, payload: LLMNodeType) => {
|
|||||||
setInputs(newInputs)
|
setInputs(newInputs)
|
||||||
}, [inputs, setInputs])
|
}, [inputs, setInputs])
|
||||||
|
|
||||||
|
const handleSyeQueryChange = useCallback((newQuery: string) => {
|
||||||
|
const newInputs = produce(inputs, (draft) => {
|
||||||
|
if (!draft.memory) {
|
||||||
|
draft.memory = {
|
||||||
|
window: {
|
||||||
|
enabled: false,
|
||||||
|
size: 10,
|
||||||
|
},
|
||||||
|
query_prompt_template: newQuery,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
draft.memory.query_prompt_template = newQuery
|
||||||
|
}
|
||||||
|
})
|
||||||
|
setInputs(newInputs)
|
||||||
|
}, [inputs, setInputs])
|
||||||
|
|
||||||
const handleVisionResolutionEnabledChange = useCallback((enabled: boolean) => {
|
const handleVisionResolutionEnabledChange = useCallback((enabled: boolean) => {
|
||||||
const newInputs = produce(inputs, (draft) => {
|
const newInputs = produce(inputs, (draft) => {
|
||||||
if (!draft.vision) {
|
if (!draft.vision) {
|
||||||
@ -248,6 +267,14 @@ const useConfig = (id: string, payload: LLMNodeType) => {
|
|||||||
return [VarType.arrayObject, VarType.array, VarType.string].includes(varPayload.type)
|
return [VarType.arrayObject, VarType.array, VarType.string].includes(varPayload.type)
|
||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
|
const {
|
||||||
|
availableVars,
|
||||||
|
availableNodes,
|
||||||
|
} = useAvailableVarList(id, {
|
||||||
|
onlyLeafNodeVar: false,
|
||||||
|
filterVar,
|
||||||
|
})
|
||||||
|
|
||||||
// single run
|
// single run
|
||||||
const {
|
const {
|
||||||
isShowSingleRun,
|
isShowSingleRun,
|
||||||
@ -322,8 +349,10 @@ const useConfig = (id: string, payload: LLMNodeType) => {
|
|||||||
|
|
||||||
const allVarStrArr = (() => {
|
const allVarStrArr = (() => {
|
||||||
const arr = isChatModel ? (inputs.prompt_template as PromptItem[]).map(item => item.text) : [(inputs.prompt_template as PromptItem).text]
|
const arr = isChatModel ? (inputs.prompt_template as PromptItem[]).map(item => item.text) : [(inputs.prompt_template as PromptItem).text]
|
||||||
if (isChatMode && isChatModel && !!inputs.memory)
|
if (isChatMode && isChatModel && !!inputs.memory) {
|
||||||
arr.push('{{#sys.query#}}')
|
arr.push('{{#sys.query#}}')
|
||||||
|
arr.push(inputs.memory.query_prompt_template)
|
||||||
|
}
|
||||||
|
|
||||||
return arr
|
return arr
|
||||||
})()
|
})()
|
||||||
@ -346,8 +375,11 @@ const useConfig = (id: string, payload: LLMNodeType) => {
|
|||||||
handleContextVarChange,
|
handleContextVarChange,
|
||||||
filterInputVar,
|
filterInputVar,
|
||||||
filterVar,
|
filterVar,
|
||||||
|
availableVars,
|
||||||
|
availableNodes,
|
||||||
handlePromptChange,
|
handlePromptChange,
|
||||||
handleMemoryChange,
|
handleMemoryChange,
|
||||||
|
handleSyeQueryChange,
|
||||||
handleVisionResolutionEnabledChange,
|
handleVisionResolutionEnabledChange,
|
||||||
handleVisionResolutionChange,
|
handleVisionResolutionChange,
|
||||||
isShowSingleRun,
|
isShowSingleRun,
|
||||||
|
@ -143,6 +143,7 @@ export type Memory = {
|
|||||||
enabled: boolean
|
enabled: boolean
|
||||||
size: number | string | null
|
size: number | string | null
|
||||||
}
|
}
|
||||||
|
query_prompt_template: string
|
||||||
}
|
}
|
||||||
|
|
||||||
export enum VarType {
|
export enum VarType {
|
||||||
|
@ -204,6 +204,7 @@ const translation = {
|
|||||||
singleRun: {
|
singleRun: {
|
||||||
variable: 'Variable',
|
variable: 'Variable',
|
||||||
},
|
},
|
||||||
|
sysQueryInUser: 'sys.query in user message is required',
|
||||||
},
|
},
|
||||||
knowledgeRetrieval: {
|
knowledgeRetrieval: {
|
||||||
queryVariable: 'Query Variable',
|
queryVariable: 'Query Variable',
|
||||||
|
@ -204,6 +204,7 @@ const translation = {
|
|||||||
singleRun: {
|
singleRun: {
|
||||||
variable: '变量',
|
variable: '变量',
|
||||||
},
|
},
|
||||||
|
sysQueryInUser: 'user message 中必须包含 sys.query',
|
||||||
},
|
},
|
||||||
knowledgeRetrieval: {
|
knowledgeRetrieval: {
|
||||||
queryVariable: '查询变量',
|
queryVariable: '查询变量',
|
||||||
|
Loading…
x
Reference in New Issue
Block a user