mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-16 04:15:57 +08:00
fix: Ensure model config integrity in retrieval processes (#20576)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
257bf13fef
commit
36f1b4b222
@ -175,7 +175,9 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||||||
dataset_retrieval = DatasetRetrieval()
|
dataset_retrieval = DatasetRetrieval()
|
||||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
|
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
|
||||||
# fetch model config
|
# fetch model config
|
||||||
model_instance, model_config = self._fetch_model_config(node_data.single_retrieval_config.model) # type: ignore
|
if node_data.single_retrieval_config is None:
|
||||||
|
raise ValueError("single_retrieval_config is required")
|
||||||
|
model_instance, model_config = self.get_model_config(node_data.single_retrieval_config.model)
|
||||||
# check model is support tool calling
|
# check model is support tool calling
|
||||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||||
@ -426,7 +428,7 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||||||
raise ValueError("metadata_model_config is required")
|
raise ValueError("metadata_model_config is required")
|
||||||
# get metadata model instance
|
# get metadata model instance
|
||||||
# fetch model config
|
# fetch model config
|
||||||
model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config) # type: ignore
|
model_instance, model_config = self.get_model_config(metadata_model_config)
|
||||||
# fetch prompt messages
|
# fetch prompt messages
|
||||||
prompt_template = self._get_prompt_template(
|
prompt_template = self._get_prompt_template(
|
||||||
node_data=node_data,
|
node_data=node_data,
|
||||||
@ -552,14 +554,7 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||||||
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
|
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
|
||||||
return variable_mapping
|
return variable_mapping
|
||||||
|
|
||||||
def _fetch_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: # type: ignore
|
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||||
"""
|
|
||||||
Fetch model config
|
|
||||||
:param model: model
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if model is None:
|
|
||||||
raise ValueError("model is required")
|
|
||||||
model_name = model.name
|
model_name = model.name
|
||||||
provider_name = model.provider
|
provider_name = model.provider
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user