mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-06-04 11:24:00 +08:00

### What problem does this PR solve? https://github.com/infiniflow/ragflow/issues/8006 The Categorize component itself should work, but the components downstream of it appear unable to get its output. This change adds the Categorize output as an attribute. However, base.py contains the logic ` if self.component_name.lower().find("switch") < 0 and self.get_component_name(u) in ["relevant", "categorize"]: continue` — when execution reaches this branch, it will not try to get the output from Categorize (I do not have the full context behind this condition). ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
115 lines
4.2 KiB
Python
115 lines
4.2 KiB
Python
#
|
||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
#
|
||
import logging
|
||
from abc import ABC
|
||
from api.db import LLMType
|
||
from api.db.services.llm_service import LLMBundle
|
||
from agent.component import GenerateParam, Generate
|
||
|
||
|
||
class CategorizeParam(GenerateParam):
    """
    Define the Categorize component parameters.

    ``category_description`` maps each category name to a dict. The keys read
    in this class are ``"to"`` (required by ``check``), ``"description"``
    (optional free text) and ``"examples"`` (optional, one sample user input
    per line).
    """

    def __init__(self):
        super().__init__()
        # category name -> {"to": ..., "description": ..., "examples": ...}
        self.category_description = {}
        # The most recent prompt produced by get_prompt().
        self.prompt = ""

    def check(self):
        """Validate the parameters.

        Delegates to the base-class check, then rejects an empty category
        mapping (via the inherited ``check_empty`` helper).

        Raises:
            ValueError: if a category name is empty or a category has no
                ``"to"`` target.
        """
        super().check()
        self.check_empty(self.category_description, "[Categorize] Category examples")
        for k, v in self.category_description.items():
            if not k:
                raise ValueError("[Categorize] Category name can not be empty!")
            if not v.get("to"):
                raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")

    def get_prompt(self, chat_hist):
        """Build the classification prompt and cache it in ``self.prompt``.

        Args:
            chat_hist: the user text to classify; it is placed after the
                final ``USER:`` marker of the prompt.

        Returns:
            The prompt string (also assigned to ``self.prompt``).
        """
        # Few-shot examples: each non-blank line of a category's "examples"
        # becomes one "USER / Category" pair. `or ""` tolerates a None value,
        # and splitlines() also handles "\r\n" line endings.
        cate_lines = []
        for c, desc in self.category_description.items():
            for line in (desc.get("examples") or "").splitlines():
                line = line.strip()
                if not line:
                    continue
                cate_lines.append("USER: {}\nCategory: {}".format(line, c))

        descriptions = [
            "\nCategory: {}\nDescription: {}".format(c, desc["description"])
            for c, desc in self.category_description.items()
            if desc.get("description")
        ]

        # Prefix every example with "- " (a bare "\n\n- ".join() would leave
        # the first example without its bullet).
        examples = "\n\n".join("- " + e for e in cate_lines)

        self.prompt = """
Role: You're a text classifier.
Task: You need to categorize the user’s questions into {} categories, namely: {}

Here's description of each category:
{}

You could learn from the following examples:
{}
You could learn from the above examples.

Requirements:
- Just mention the category names, no need for any additional words.

---- Real Data ----
USER: {}\n
""".format(
            len(self.category_description),
            "/".join(self.category_description.keys()),
            "\n".join(descriptions),
            examples,
            chat_hist
        )
        return self.prompt
|
||
|
||
|
||
class Categorize(Generate, ABC):
    """Route the flow to one of several downstream targets using an LLM classifier.

    The prompt comes from ``self._param.get_prompt``. The LLM answer is matched
    (case-insensitively) against the configured category names, and the winning
    category's ``"to"`` value is emitted as this component's output.
    """

    component_name = "Categorize"

    @staticmethod
    def _count_categories(answer, names):
        """Count case-insensitive, non-overlapping occurrences of each name in *answer*.

        Longer names are matched first and their occurrences are blanked out
        before shorter names are counted. This stops a name that is a substring
        of another (e.g. "Product" inside "Product Return") from being credited
        for the longer category's mentions.

        Args:
            answer: the LLM answer text.
            names: the category names, in configuration order.

        Returns:
            dict mapping each name to its count, in the same order as *names*,
            so callers using max() keep the first-configured category on ties.
        """
        text = answer.lower()
        counts = {}
        # sorted() is stable, so equal-length names keep their relative order.
        for name in sorted(names, key=lambda n: len(n.lower()), reverse=True):
            needle = name.lower()
            if not needle:
                counts[name] = 0
                continue
            counts[name] = text.count(needle)
            # Use a separator that category names will not contain, so removing
            # a match cannot join its neighbours into a new spurious match.
            text = text.replace(needle, "\x00")
        return {name: counts[name] for name in names}

    def _run(self, history, **kwargs):
        """Classify the upstream input and output the chosen category's ``"to"`` value.

        Args:
            history: not used by this component.
            **kwargs: not used by this component.

        Returns:
            ``Categorize.be_output(...)`` wrapping the winning category's
            ``"to"`` value. It is also stored via ``self.set_output``.
            If no category name appears in the LLM answer, the last configured
            category is used.
        """
        upstream = self.get_input()
        # Join all upstream "content" values into the text to classify.
        query = " - ".join(upstream["content"]) if "content" in upstream else ""
        chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)

        # Build the prompt once and reuse it for the recorded info and the call.
        prompt = self._param.get_prompt(query)
        # The record and the call get separate message/conf objects, in case
        # the chat call mutates its arguments.
        self._canvas.set_component_infor(
            self._id,
            {
                "prompt": prompt,
                "messages": [{"role": "user", "content": "\nCategory: "}],
                "conf": self._param.gen_conf(),
            },
        )
        ans = chat_mdl.chat(prompt, [{"role": "user", "content": "\nCategory: "}], self._param.gen_conf())
        logging.debug("input: %s, answer: %s", query, ans)

        categories = self._param.category_description
        counts = self._count_categories(ans, list(categories))
        if any(counts.values()):
            # Highest count wins; max() keeps the first-configured category on ties.
            chosen = max(counts, key=counts.get)
        else:
            # No category name found in the answer: fall back to the last one.
            chosen = list(categories)[-1]

        res = Categorize.be_output(categories[chosen]["to"])
        self.set_output(res)
        return res

    def debug(self, **kwargs):
        """Run the classifier with an empty history and return the chosen component's name.

        The first cell of the ``_run`` output is treated as a component id and
        resolved through ``self._canvas.get_component_name``.
        """
        df = self._run([], **kwargs)
        cpn_id = df.iloc[0, 0]
        return Categorize.be_output(self._canvas.get_component_name(cpn_id))
|
||
|