feat: agent node add memory (#15976)

This commit is contained in:
Novice 2025-04-03 16:40:58 +08:00 committed by GitHub
parent 3d76f09c3a
commit dcdec98c8e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 116 additions and 20 deletions

View File

@ -70,11 +70,20 @@ class AgentStrategyIdentity(ToolIdentity):
pass pass
class AgentFeature(enum.StrEnum):
"""
Agent Feature, used to describe the features of the agent strategy.
"""
HISTORY_MESSAGES = "history-messages"
class AgentStrategyEntity(BaseModel): class AgentStrategyEntity(BaseModel):
identity: AgentStrategyIdentity identity: AgentStrategyIdentity
parameters: list[AgentStrategyParameter] = Field(default_factory=list) parameters: list[AgentStrategyParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The description of the agent strategy") description: I18nObject = Field(..., description="The description of the agent strategy")
output_schema: Optional[dict] = None output_schema: Optional[dict] = None
features: Optional[list[AgentFeature]] = None
# pydantic configs # pydantic configs
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())

View File

@ -1,15 +1,18 @@
import json import json
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast from typing import Any, Optional, cast
from core.agent.entities import AgentToolEntity from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter from core.agent.plugin_entities import AgentStrategyParameter
from core.model_manager import ModelManager from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelType from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.plugin.manager.exc import PluginDaemonClientSideError from core.plugin.manager.exc import PluginDaemonClientSideError
from core.plugin.manager.plugin import PluginInstallationManager from core.plugin.manager.plugin import PluginInstallationManager
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.variables.segments import StringSegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
@ -19,7 +22,9 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import RunCompletedEvent from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.tool.tool_node import ToolNode from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories.agent_factory import get_plugin_agent_strategy from factories.agent_factory import get_plugin_agent_strategy
from models.model import Conversation
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
@ -233,17 +238,20 @@ class AgentNode(ToolNode):
value = tool_value value = tool_value
if parameter.type == "model-selector": if parameter.type == "model-selector":
value = cast(dict[str, Any], value) value = cast(dict[str, Any], value)
model_instance = ModelManager().get_model_instance( model_instance, model_schema = self._fetch_model(value)
tenant_id=self.tenant_id, # memory config
provider=value.get("provider", ""), history_prompt_messages = []
model_type=ModelType(value.get("model_type", "")), if node_data.memory:
model=value.get("model", ""), memory = self._fetch_memory(model_instance)
) if memory:
models = model_instance.model_type_instance.plugin_model_provider.declaration.models prompt_messages = memory.get_history_prompt_messages(
finded_model = next((model for model in models if model.model == value.get("model", "")), None) message_limit=node_data.memory.window.size if node_data.memory.window.size else None
)
value["entity"] = finded_model.model_dump(mode="json") if finded_model else None history_prompt_messages = [
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
]
value["history_prompt_messages"] = history_prompt_messages
value["entity"] = model_schema.model_dump(mode="json") if model_schema else None
result[parameter_name] = value result[parameter_name] = value
return result return result
@ -297,3 +305,46 @@ class AgentNode(ToolNode):
except StopIteration: except StopIteration:
icon = None icon = None
return icon return icon
def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
# get conversation id
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID.value]
)
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
# get conversation
conversation = (
db.session.query(Conversation)
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
.first()
)
if not conversation:
return None
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
)
model_name = value.get("model", "")
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM, model=model_name
)
provider_name = provider_model_bundle.configuration.provider.provider
model_type_instance = provider_model_bundle.model_type_instance
model_instance = ModelManager().get_model_instance(
tenant_id=self.tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),
model=model_name,
)
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_instance, model_schema

View File

@ -3,6 +3,7 @@ from typing import Any, Literal, Union
from pydantic import BaseModel from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector from core.tools.entities.tool_entities import ToolSelector
from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.base.entities import BaseNodeData
@ -11,6 +12,7 @@ class AgentNodeData(BaseNodeData):
agent_strategy_provider_name: str # redundancy agent_strategy_provider_name: str # redundancy
agent_strategy_name: str agent_strategy_name: str
agent_strategy_label: str # redundancy agent_strategy_label: str # redundancy
memory: MemoryConfig | None = None
class AgentInput(BaseModel): class AgentInput(BaseModel):
value: Union[list[str], list[ToolSelector], Any] value: Union[list[str], list[ToolSelector], Any]

View File

@ -1,7 +1,7 @@
import type { CredentialFormSchemaBase } from '../header/account-setting/model-provider-page/declarations' import type { CredentialFormSchemaBase } from '../header/account-setting/model-provider-page/declarations'
import type { ToolCredential } from '@/app/components/tools/types' import type { ToolCredential } from '@/app/components/tools/types'
import type { Locale } from '@/i18n' import type { Locale } from '@/i18n'
import type { AgentFeature } from '@/app/components/workflow/nodes/agent/types'
export enum PluginType { export enum PluginType {
tool = 'tool', tool = 'tool',
model = 'model', model = 'model',
@ -418,6 +418,7 @@ export type StrategyDetail = {
parameters: StrategyParamItem[] parameters: StrategyParamItem[]
description: Record<Locale, string> description: Record<Locale, string>
output_schema: Record<string, any> output_schema: Record<string, any>
features: AgentFeature[]
} }
export type StrategyDeclaration = { export type StrategyDeclaration = {

View File

@ -1,7 +1,7 @@
import type { FC } from 'react' import type { FC } from 'react'
import { memo, useMemo } from 'react' import { memo, useMemo } from 'react'
import type { NodePanelProps } from '../../types' import type { NodePanelProps } from '../../types'
import type { AgentNodeType } from './types' import { AgentFeature, type AgentNodeType } from './types'
import Field from '../_base/components/field' import Field from '../_base/components/field'
import { AgentStrategy } from '../_base/components/agent-strategy' import { AgentStrategy } from '../_base/components/agent-strategy'
import useConfig from './use-config' import useConfig from './use-config'
@ -16,6 +16,8 @@ import { useLogs } from '@/app/components/workflow/run/hooks'
import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/components/before-run-form/form' import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/components/before-run-form/form'
import { toType } from '@/app/components/tools/utils/to-form-schema' import { toType } from '@/app/components/tools/utils/to-form-schema'
import { useStore } from '../../store' import { useStore } from '../../store'
import Split from '../_base/components/split'
import MemoryConfig from '../_base/components/memory-config'
const i18nPrefix = 'workflow.nodes.agent' const i18nPrefix = 'workflow.nodes.agent'
@ -35,10 +37,10 @@ const AgentPanel: FC<NodePanelProps<AgentNodeType>> = (props) => {
currentStrategy, currentStrategy,
formData, formData,
onFormChange, onFormChange,
isChatMode,
availableNodesWithParent, availableNodesWithParent,
availableVars, availableVars,
readOnly,
isShowSingleRun, isShowSingleRun,
hideSingleRun, hideSingleRun,
runningStatus, runningStatus,
@ -49,6 +51,7 @@ const AgentPanel: FC<NodePanelProps<AgentNodeType>> = (props) => {
setRunInputData, setRunInputData,
varInputs, varInputs,
outputSchema, outputSchema,
handleMemoryChange,
} = useConfig(props.id, props.data) } = useConfig(props.id, props.data)
const { t } = useTranslation() const { t } = useTranslation()
const nodeInfo = useMemo(() => { const nodeInfo = useMemo(() => {
@ -106,6 +109,20 @@ const AgentPanel: FC<NodePanelProps<AgentNodeType>> = (props) => {
nodeId={props.id} nodeId={props.id}
/> />
</Field> </Field>
<div className='px-4 py-2'>
{isChatMode && currentStrategy?.features.includes(AgentFeature.HISTORY_MESSAGES) && (
<>
<Split />
<MemoryConfig
className='mt-4'
readonly={readOnly}
config={{ data: inputs.memory }}
onChange={handleMemoryChange}
canSetRoleName={false}
/>
</>
)}
</div>
<div> <div>
<OutputVars> <OutputVars>
<VarItem <VarItem

View File

@ -1,4 +1,4 @@
import type { CommonNodeType } from '@/app/components/workflow/types' import type { CommonNodeType, Memory } from '@/app/components/workflow/types'
import type { ToolVarInputs } from '../tool/types' import type { ToolVarInputs } from '../tool/types'
export type AgentNodeType = CommonNodeType & { export type AgentNodeType = CommonNodeType & {
@ -8,4 +8,9 @@ export type AgentNodeType = CommonNodeType & {
agent_parameters?: ToolVarInputs agent_parameters?: ToolVarInputs
output_schema: Record<string, any> output_schema: Record<string, any>
plugin_unique_identifier?: string plugin_unique_identifier?: string
memory?: Memory
}
export enum AgentFeature {
HISTORY_MESSAGES = 'history-messages',
} }

View File

@ -4,14 +4,16 @@ import useVarList from '../_base/hooks/use-var-list'
import useOneStepRun from '../_base/hooks/use-one-step-run' import useOneStepRun from '../_base/hooks/use-one-step-run'
import type { AgentNodeType } from './types' import type { AgentNodeType } from './types'
import { import {
useIsChatMode,
useNodesReadOnly, useNodesReadOnly,
} from '@/app/components/workflow/hooks' } from '@/app/components/workflow/hooks'
import { useCallback, useMemo } from 'react' import { useCallback, useMemo } from 'react'
import { type ToolVarInputs, VarType } from '../tool/types' import { type ToolVarInputs, VarType } from '../tool/types'
import { useCheckInstalled, useFetchPluginsInMarketPlaceByIds } from '@/service/use-plugins' import { useCheckInstalled, useFetchPluginsInMarketPlaceByIds } from '@/service/use-plugins'
import type { Var } from '../../types' import type { Memory, Var } from '../../types'
import { VarType as VarKindType } from '../../types' import { VarType as VarKindType } from '../../types'
import useAvailableVarList from '../_base/hooks/use-available-var-list' import useAvailableVarList from '../_base/hooks/use-available-var-list'
import produce from 'immer'
export type StrategyStatus = { export type StrategyStatus = {
plugin: { plugin: {
@ -175,6 +177,13 @@ const useConfig = (id: string, payload: AgentNodeType) => {
return res return res
}, [inputs.output_schema]) }, [inputs.output_schema])
const handleMemoryChange = useCallback((newMemory?: Memory) => {
const newInputs = produce(inputs, (draft) => {
draft.memory = newMemory
})
setInputs(newInputs)
}, [inputs, setInputs])
const isChatMode = useIsChatMode()
return { return {
readOnly, readOnly,
inputs, inputs,
@ -202,6 +211,8 @@ const useConfig = (id: string, payload: AgentNodeType) => {
runResult, runResult,
varInputs, varInputs,
outputSchema, outputSchema,
handleMemoryChange,
isChatMode,
} }
} }