mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-02 05:50:36 +08:00
feat: agent node add memory (#15976)
This commit is contained in:
parent
3d76f09c3a
commit
dcdec98c8e
@ -70,11 +70,20 @@ class AgentStrategyIdentity(ToolIdentity):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AgentFeature(enum.StrEnum):
|
||||||
|
"""
|
||||||
|
Agent Feature, used to describe the features of the agent strategy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
HISTORY_MESSAGES = "history-messages"
|
||||||
|
|
||||||
|
|
||||||
class AgentStrategyEntity(BaseModel):
|
class AgentStrategyEntity(BaseModel):
|
||||||
identity: AgentStrategyIdentity
|
identity: AgentStrategyIdentity
|
||||||
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
|
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
|
||||||
description: I18nObject = Field(..., description="The description of the agent strategy")
|
description: I18nObject = Field(..., description="The description of the agent strategy")
|
||||||
output_schema: Optional[dict] = None
|
output_schema: Optional[dict] = None
|
||||||
|
features: Optional[list[AgentFeature]] = None
|
||||||
|
|
||||||
# pydantic configs
|
# pydantic configs
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
@ -1,15 +1,18 @@
|
|||||||
import json
|
import json
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import Any, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from core.agent.entities import AgentToolEntity
|
from core.agent.entities import AgentToolEntity
|
||||||
from core.agent.plugin_entities import AgentStrategyParameter
|
from core.agent.plugin_entities import AgentStrategyParameter
|
||||||
from core.model_manager import ModelManager
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||||
from core.plugin.manager.plugin import PluginInstallationManager
|
from core.plugin.manager.plugin import PluginInstallationManager
|
||||||
|
from core.provider_manager import ProviderManager
|
||||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
|
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
|
from core.variables.segments import StringSegment
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
@ -19,7 +22,9 @@ from core.workflow.nodes.enums import NodeType
|
|||||||
from core.workflow.nodes.event.event import RunCompletedEvent
|
from core.workflow.nodes.event.event import RunCompletedEvent
|
||||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||||
|
from extensions.ext_database import db
|
||||||
from factories.agent_factory import get_plugin_agent_strategy
|
from factories.agent_factory import get_plugin_agent_strategy
|
||||||
|
from models.model import Conversation
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
@ -233,17 +238,20 @@ class AgentNode(ToolNode):
|
|||||||
value = tool_value
|
value = tool_value
|
||||||
if parameter.type == "model-selector":
|
if parameter.type == "model-selector":
|
||||||
value = cast(dict[str, Any], value)
|
value = cast(dict[str, Any], value)
|
||||||
model_instance = ModelManager().get_model_instance(
|
model_instance, model_schema = self._fetch_model(value)
|
||||||
tenant_id=self.tenant_id,
|
# memory config
|
||||||
provider=value.get("provider", ""),
|
history_prompt_messages = []
|
||||||
model_type=ModelType(value.get("model_type", "")),
|
if node_data.memory:
|
||||||
model=value.get("model", ""),
|
memory = self._fetch_memory(model_instance)
|
||||||
)
|
if memory:
|
||||||
models = model_instance.model_type_instance.plugin_model_provider.declaration.models
|
prompt_messages = memory.get_history_prompt_messages(
|
||||||
finded_model = next((model for model in models if model.model == value.get("model", "")), None)
|
message_limit=node_data.memory.window.size if node_data.memory.window.size else None
|
||||||
|
)
|
||||||
value["entity"] = finded_model.model_dump(mode="json") if finded_model else None
|
history_prompt_messages = [
|
||||||
|
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
|
||||||
|
]
|
||||||
|
value["history_prompt_messages"] = history_prompt_messages
|
||||||
|
value["entity"] = model_schema.model_dump(mode="json") if model_schema else None
|
||||||
result[parameter_name] = value
|
result[parameter_name] = value
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@ -297,3 +305,46 @@ class AgentNode(ToolNode):
|
|||||||
except StopIteration:
|
except StopIteration:
|
||||||
icon = None
|
icon = None
|
||||||
return icon
|
return icon
|
||||||
|
|
||||||
|
def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
|
||||||
|
# get conversation id
|
||||||
|
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||||
|
["sys", SystemVariableKey.CONVERSATION_ID.value]
|
||||||
|
)
|
||||||
|
if not isinstance(conversation_id_variable, StringSegment):
|
||||||
|
return None
|
||||||
|
conversation_id = conversation_id_variable.value
|
||||||
|
|
||||||
|
# get conversation
|
||||||
|
conversation = (
|
||||||
|
db.session.query(Conversation)
|
||||||
|
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not conversation:
|
||||||
|
return None
|
||||||
|
|
||||||
|
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||||
|
|
||||||
|
return memory
|
||||||
|
|
||||||
|
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||||
|
provider_manager = ProviderManager()
|
||||||
|
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||||
|
tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
|
||||||
|
)
|
||||||
|
model_name = value.get("model", "")
|
||||||
|
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||||
|
model_type=ModelType.LLM, model=model_name
|
||||||
|
)
|
||||||
|
provider_name = provider_model_bundle.configuration.provider.provider
|
||||||
|
model_type_instance = provider_model_bundle.model_type_instance
|
||||||
|
model_instance = ModelManager().get_model_instance(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
provider=provider_name,
|
||||||
|
model_type=ModelType(value.get("model_type", "")),
|
||||||
|
model=model_name,
|
||||||
|
)
|
||||||
|
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||||
|
return model_instance, model_schema
|
||||||
|
@ -3,6 +3,7 @@ from typing import Any, Literal, Union
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||||
from core.tools.entities.tool_entities import ToolSelector
|
from core.tools.entities.tool_entities import ToolSelector
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData
|
from core.workflow.nodes.base.entities import BaseNodeData
|
||||||
|
|
||||||
@ -11,6 +12,7 @@ class AgentNodeData(BaseNodeData):
|
|||||||
agent_strategy_provider_name: str # redundancy
|
agent_strategy_provider_name: str # redundancy
|
||||||
agent_strategy_name: str
|
agent_strategy_name: str
|
||||||
agent_strategy_label: str # redundancy
|
agent_strategy_label: str # redundancy
|
||||||
|
memory: MemoryConfig | None = None
|
||||||
|
|
||||||
class AgentInput(BaseModel):
|
class AgentInput(BaseModel):
|
||||||
value: Union[list[str], list[ToolSelector], Any]
|
value: Union[list[str], list[ToolSelector], Any]
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import type { CredentialFormSchemaBase } from '../header/account-setting/model-provider-page/declarations'
|
import type { CredentialFormSchemaBase } from '../header/account-setting/model-provider-page/declarations'
|
||||||
import type { ToolCredential } from '@/app/components/tools/types'
|
import type { ToolCredential } from '@/app/components/tools/types'
|
||||||
import type { Locale } from '@/i18n'
|
import type { Locale } from '@/i18n'
|
||||||
|
import type { AgentFeature } from '@/app/components/workflow/nodes/agent/types'
|
||||||
export enum PluginType {
|
export enum PluginType {
|
||||||
tool = 'tool',
|
tool = 'tool',
|
||||||
model = 'model',
|
model = 'model',
|
||||||
@ -418,6 +418,7 @@ export type StrategyDetail = {
|
|||||||
parameters: StrategyParamItem[]
|
parameters: StrategyParamItem[]
|
||||||
description: Record<Locale, string>
|
description: Record<Locale, string>
|
||||||
output_schema: Record<string, any>
|
output_schema: Record<string, any>
|
||||||
|
features: AgentFeature[]
|
||||||
}
|
}
|
||||||
|
|
||||||
export type StrategyDeclaration = {
|
export type StrategyDeclaration = {
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import type { FC } from 'react'
|
import type { FC } from 'react'
|
||||||
import { memo, useMemo } from 'react'
|
import { memo, useMemo } from 'react'
|
||||||
import type { NodePanelProps } from '../../types'
|
import type { NodePanelProps } from '../../types'
|
||||||
import type { AgentNodeType } from './types'
|
import { AgentFeature, type AgentNodeType } from './types'
|
||||||
import Field from '../_base/components/field'
|
import Field from '../_base/components/field'
|
||||||
import { AgentStrategy } from '../_base/components/agent-strategy'
|
import { AgentStrategy } from '../_base/components/agent-strategy'
|
||||||
import useConfig from './use-config'
|
import useConfig from './use-config'
|
||||||
@ -16,6 +16,8 @@ import { useLogs } from '@/app/components/workflow/run/hooks'
|
|||||||
import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/components/before-run-form/form'
|
import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/components/before-run-form/form'
|
||||||
import { toType } from '@/app/components/tools/utils/to-form-schema'
|
import { toType } from '@/app/components/tools/utils/to-form-schema'
|
||||||
import { useStore } from '../../store'
|
import { useStore } from '../../store'
|
||||||
|
import Split from '../_base/components/split'
|
||||||
|
import MemoryConfig from '../_base/components/memory-config'
|
||||||
|
|
||||||
const i18nPrefix = 'workflow.nodes.agent'
|
const i18nPrefix = 'workflow.nodes.agent'
|
||||||
|
|
||||||
@ -35,10 +37,10 @@ const AgentPanel: FC<NodePanelProps<AgentNodeType>> = (props) => {
|
|||||||
currentStrategy,
|
currentStrategy,
|
||||||
formData,
|
formData,
|
||||||
onFormChange,
|
onFormChange,
|
||||||
|
isChatMode,
|
||||||
availableNodesWithParent,
|
availableNodesWithParent,
|
||||||
availableVars,
|
availableVars,
|
||||||
|
readOnly,
|
||||||
isShowSingleRun,
|
isShowSingleRun,
|
||||||
hideSingleRun,
|
hideSingleRun,
|
||||||
runningStatus,
|
runningStatus,
|
||||||
@ -49,6 +51,7 @@ const AgentPanel: FC<NodePanelProps<AgentNodeType>> = (props) => {
|
|||||||
setRunInputData,
|
setRunInputData,
|
||||||
varInputs,
|
varInputs,
|
||||||
outputSchema,
|
outputSchema,
|
||||||
|
handleMemoryChange,
|
||||||
} = useConfig(props.id, props.data)
|
} = useConfig(props.id, props.data)
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const nodeInfo = useMemo(() => {
|
const nodeInfo = useMemo(() => {
|
||||||
@ -106,6 +109,20 @@ const AgentPanel: FC<NodePanelProps<AgentNodeType>> = (props) => {
|
|||||||
nodeId={props.id}
|
nodeId={props.id}
|
||||||
/>
|
/>
|
||||||
</Field>
|
</Field>
|
||||||
|
<div className='px-4 py-2'>
|
||||||
|
{isChatMode && currentStrategy?.features.includes(AgentFeature.HISTORY_MESSAGES) && (
|
||||||
|
<>
|
||||||
|
<Split />
|
||||||
|
<MemoryConfig
|
||||||
|
className='mt-4'
|
||||||
|
readonly={readOnly}
|
||||||
|
config={{ data: inputs.memory }}
|
||||||
|
onChange={handleMemoryChange}
|
||||||
|
canSetRoleName={false}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<OutputVars>
|
<OutputVars>
|
||||||
<VarItem
|
<VarItem
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import type { CommonNodeType } from '@/app/components/workflow/types'
|
import type { CommonNodeType, Memory } from '@/app/components/workflow/types'
|
||||||
import type { ToolVarInputs } from '../tool/types'
|
import type { ToolVarInputs } from '../tool/types'
|
||||||
|
|
||||||
export type AgentNodeType = CommonNodeType & {
|
export type AgentNodeType = CommonNodeType & {
|
||||||
@ -8,4 +8,9 @@ export type AgentNodeType = CommonNodeType & {
|
|||||||
agent_parameters?: ToolVarInputs
|
agent_parameters?: ToolVarInputs
|
||||||
output_schema: Record<string, any>
|
output_schema: Record<string, any>
|
||||||
plugin_unique_identifier?: string
|
plugin_unique_identifier?: string
|
||||||
|
memory?: Memory
|
||||||
|
}
|
||||||
|
|
||||||
|
export enum AgentFeature {
|
||||||
|
HISTORY_MESSAGES = 'history-messages',
|
||||||
}
|
}
|
||||||
|
@ -4,14 +4,16 @@ import useVarList from '../_base/hooks/use-var-list'
|
|||||||
import useOneStepRun from '../_base/hooks/use-one-step-run'
|
import useOneStepRun from '../_base/hooks/use-one-step-run'
|
||||||
import type { AgentNodeType } from './types'
|
import type { AgentNodeType } from './types'
|
||||||
import {
|
import {
|
||||||
|
useIsChatMode,
|
||||||
useNodesReadOnly,
|
useNodesReadOnly,
|
||||||
} from '@/app/components/workflow/hooks'
|
} from '@/app/components/workflow/hooks'
|
||||||
import { useCallback, useMemo } from 'react'
|
import { useCallback, useMemo } from 'react'
|
||||||
import { type ToolVarInputs, VarType } from '../tool/types'
|
import { type ToolVarInputs, VarType } from '../tool/types'
|
||||||
import { useCheckInstalled, useFetchPluginsInMarketPlaceByIds } from '@/service/use-plugins'
|
import { useCheckInstalled, useFetchPluginsInMarketPlaceByIds } from '@/service/use-plugins'
|
||||||
import type { Var } from '../../types'
|
import type { Memory, Var } from '../../types'
|
||||||
import { VarType as VarKindType } from '../../types'
|
import { VarType as VarKindType } from '../../types'
|
||||||
import useAvailableVarList from '../_base/hooks/use-available-var-list'
|
import useAvailableVarList from '../_base/hooks/use-available-var-list'
|
||||||
|
import produce from 'immer'
|
||||||
|
|
||||||
export type StrategyStatus = {
|
export type StrategyStatus = {
|
||||||
plugin: {
|
plugin: {
|
||||||
@ -175,6 +177,13 @@ const useConfig = (id: string, payload: AgentNodeType) => {
|
|||||||
return res
|
return res
|
||||||
}, [inputs.output_schema])
|
}, [inputs.output_schema])
|
||||||
|
|
||||||
|
const handleMemoryChange = useCallback((newMemory?: Memory) => {
|
||||||
|
const newInputs = produce(inputs, (draft) => {
|
||||||
|
draft.memory = newMemory
|
||||||
|
})
|
||||||
|
setInputs(newInputs)
|
||||||
|
}, [inputs, setInputs])
|
||||||
|
const isChatMode = useIsChatMode()
|
||||||
return {
|
return {
|
||||||
readOnly,
|
readOnly,
|
||||||
inputs,
|
inputs,
|
||||||
@ -202,6 +211,8 @@ const useConfig = (id: string, payload: AgentNodeType) => {
|
|||||||
runResult,
|
runResult,
|
||||||
varInputs,
|
varInputs,
|
||||||
outputSchema,
|
outputSchema,
|
||||||
|
handleMemoryChange,
|
||||||
|
isChatMode,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user