diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index 874b2800b2..f2d1bd305a 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -1,9 +1,9 @@ -from enum import Enum +from enum import StrEnum from pydantic import BaseModel, ValidationInfo, field_validator -class TracingProviderEnum(Enum): +class TracingProviderEnum(StrEnum): LANGFUSE = "langfuse" LANGSMITH = "langsmith" OPIK = "opik" diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 2c68055f87..2bcca6ccea 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -16,11 +16,7 @@ from sqlalchemy.orm import Session from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token from core.ops.entities.config_entity import ( OPS_FILE_PATH, - LangfuseConfig, - LangSmithConfig, - OpikConfig, TracingProviderEnum, - WeaveConfig, ) from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, @@ -33,11 +29,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace -from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace -from core.ops.opik_trace.opik_trace import OpikDataTrace from core.ops.utils import get_message_data -from core.ops.weave_trace.weave_trace import WeaveDataTrace from extensions.ext_database import db from extensions.ext_storage import storage from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig @@ -45,36 +37,58 @@ from models.workflow import WorkflowAppLog, WorkflowRun from tasks.ops_trace_task import process_trace_tasks -def build_opik_trace_instance(config: OpikConfig): - return OpikDataTrace(config) +class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]): + def __getitem__(self, provider: str) -> dict[str, Any]: + match provider: + case TracingProviderEnum.LANGFUSE: + from core.ops.entities.config_entity import LangfuseConfig + from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace + + return { + "config_class": LangfuseConfig, + "secret_keys": ["public_key", "secret_key"], + "other_keys": ["host", "project_key"], + "trace_instance": LangFuseDataTrace, + } + + case TracingProviderEnum.LANGSMITH: + from core.ops.entities.config_entity import LangSmithConfig + from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace + + return { + "config_class": LangSmithConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "endpoint"], + "trace_instance": LangSmithDataTrace, + } + + case TracingProviderEnum.OPIK: + from core.ops.entities.config_entity import OpikConfig + from core.ops.opik_trace.opik_trace import OpikDataTrace + + return { + "config_class": OpikConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "url", "workspace"], + "trace_instance": OpikDataTrace, + } + + case TracingProviderEnum.WEAVE: + from core.ops.entities.config_entity import WeaveConfig + from core.ops.weave_trace.weave_trace import WeaveDataTrace + + return { + "config_class": WeaveConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "entity", "endpoint"], + "trace_instance": WeaveDataTrace, + } + + case _: + raise KeyError(f"Unsupported tracing provider: {provider}") -provider_config_map: dict[str, dict[str, Any]] = { - TracingProviderEnum.LANGFUSE.value: { - "config_class": LangfuseConfig, - "secret_keys": ["public_key", "secret_key"], - "other_keys": ["host", "project_key"], - "trace_instance": LangFuseDataTrace, - }, - TracingProviderEnum.LANGSMITH.value: { - "config_class": LangSmithConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "endpoint"], - "trace_instance": LangSmithDataTrace, - }, - TracingProviderEnum.OPIK.value: { - "config_class": OpikConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "url", "workspace"], - "trace_instance": lambda config: build_opik_trace_instance(config), - }, - TracingProviderEnum.WEAVE.value: { - "config_class": WeaveConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "entity", "endpoint"], - "trace_instance": WeaveDataTrace, - }, -} +provider_config_map: dict[str, dict[str, Any]] = OpsTraceProviderConfigMap() class OpsTraceManager: