mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-11 19:09:01 +08:00
Remove <think> for exeSql component. (#5069)
### What problem does this PR solve? #5061 #5067 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
parent
4694604836
commit
84b4b38cbb
@ -60,7 +60,8 @@ class ExeSQLParam(GenerateParam):
|
|||||||
class ExeSQL(Generate, ABC):
|
class ExeSQL(Generate, ABC):
|
||||||
component_name = "ExeSQL"
|
component_name = "ExeSQL"
|
||||||
|
|
||||||
def _refactor(self,ans):
|
def _refactor(self, ans):
|
||||||
|
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
|
||||||
match = re.search(r"```sql\s*(.*?)\s*```", ans, re.DOTALL)
|
match = re.search(r"```sql\s*(.*?)\s*```", ans, re.DOTALL)
|
||||||
if match:
|
if match:
|
||||||
ans = match.group(1) # Query content
|
ans = match.group(1) # Query content
|
||||||
@ -78,7 +79,7 @@ class ExeSQL(Generate, ABC):
|
|||||||
ans = self.get_input()
|
ans = self.get_input()
|
||||||
ans = "".join([str(a) for a in ans["content"]]) if "content" in ans else ""
|
ans = "".join([str(a) for a in ans["content"]]) if "content" in ans else ""
|
||||||
ans = self._refactor(ans)
|
ans = self._refactor(ans)
|
||||||
logging.info("db_type: ",self._param.db_type)
|
logging.info("db_type: ", self._param.db_type)
|
||||||
if self._param.db_type in ["mysql", "mariadb"]:
|
if self._param.db_type in ["mysql", "mariadb"]:
|
||||||
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
|
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
|
||||||
port=self._param.port, password=self._param.password)
|
port=self._param.port, password=self._param.password)
|
||||||
@ -87,11 +88,11 @@ class ExeSQL(Generate, ABC):
|
|||||||
port=self._param.port, password=self._param.password)
|
port=self._param.port, password=self._param.password)
|
||||||
elif self._param.db_type == 'mssql':
|
elif self._param.db_type == 'mssql':
|
||||||
conn_str = (
|
conn_str = (
|
||||||
r'DRIVER={ODBC Driver 17 for SQL Server};'
|
r'DRIVER={ODBC Driver 17 for SQL Server};'
|
||||||
r'SERVER=' + self._param.host + ',' + str(self._param.port) + ';'
|
r'SERVER=' + self._param.host + ',' + str(self._param.port) + ';'
|
||||||
r'DATABASE=' + self._param.database + ';'
|
r'DATABASE=' + self._param.database + ';'
|
||||||
r'UID=' + self._param.username + ';'
|
r'UID=' + self._param.username + ';'
|
||||||
r'PWD=' + self._param.password
|
r'PWD=' + self._param.password
|
||||||
)
|
)
|
||||||
db = pyodbc.connect(conn_str)
|
db = pyodbc.connect(conn_str)
|
||||||
try:
|
try:
|
||||||
@ -101,12 +102,12 @@ class ExeSQL(Generate, ABC):
|
|||||||
if not hasattr(self, "_loop"):
|
if not hasattr(self, "_loop"):
|
||||||
setattr(self, "_loop", 0)
|
setattr(self, "_loop", 0)
|
||||||
self._loop += 1
|
self._loop += 1
|
||||||
input_list=re.split(r';', ans.replace(r"\n", " "))
|
input_list = re.split(r';', ans.replace(r"\n", " "))
|
||||||
sql_res = []
|
sql_res = []
|
||||||
for i in range(len(input_list)):
|
for i in range(len(input_list)):
|
||||||
single_sql=input_list[i]
|
single_sql = input_list[i]
|
||||||
while self._loop <= self._param.loop:
|
while self._loop <= self._param.loop:
|
||||||
self._loop+=1
|
self._loop += 1
|
||||||
if not single_sql:
|
if not single_sql:
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
@ -116,11 +117,12 @@ class ExeSQL(Generate, ABC):
|
|||||||
sql_res.append({"content": "No record in the database!"})
|
sql_res.append({"content": "No record in the database!"})
|
||||||
break
|
break
|
||||||
if self._param.db_type == 'mssql':
|
if self._param.db_type == 'mssql':
|
||||||
single_res = pd.DataFrame.from_records(cursor.fetchmany(self._param.top_n),columns = [desc[0] for desc in cursor.description])
|
single_res = pd.DataFrame.from_records(cursor.fetchmany(self._param.top_n),
|
||||||
|
columns=[desc[0] for desc in cursor.description])
|
||||||
else:
|
else:
|
||||||
single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.top_n)])
|
single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.top_n)])
|
||||||
single_res.columns = [i[0] for i in cursor.description]
|
single_res.columns = [i[0] for i in cursor.description]
|
||||||
sql_res.append({"content": single_res.to_markdown(index=False, floatfmt=".6f")})
|
sql_res.append({"content": single_res.to_markdown(index=False, floatfmt=".6f")})
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
single_sql = self._regenerate_sql(single_sql, str(e), **kwargs)
|
single_sql = self._regenerate_sql(single_sql, str(e), **kwargs)
|
||||||
@ -133,19 +135,19 @@ class ExeSQL(Generate, ABC):
|
|||||||
return ExeSQL.be_output("")
|
return ExeSQL.be_output("")
|
||||||
return pd.DataFrame(sql_res)
|
return pd.DataFrame(sql_res)
|
||||||
|
|
||||||
def _regenerate_sql(self, failed_sql, error_message,**kwargs):
|
def _regenerate_sql(self, failed_sql, error_message, **kwargs):
|
||||||
prompt = f'''
|
prompt = f'''
|
||||||
## You are the Repair SQL Statement Helper, please modify the original SQL statement based on the SQL query error report.
|
## You are the Repair SQL Statement Helper, please modify the original SQL statement based on the SQL query error report.
|
||||||
## The original SQL statement is as follows:{failed_sql}.
|
## The original SQL statement is as follows:{failed_sql}.
|
||||||
## The contents of the SQL query error report is as follows:{error_message}.
|
## The contents of the SQL query error report is as follows:{error_message}.
|
||||||
## Answer only the modified SQL statement. Please do not give any explanation, just answer the code.
|
## Answer only the modified SQL statement. Please do not give any explanation, just answer the code.
|
||||||
'''
|
'''
|
||||||
self._param.prompt=prompt
|
self._param.prompt = prompt
|
||||||
kwargs_ = deepcopy(kwargs)
|
kwargs_ = deepcopy(kwargs)
|
||||||
kwargs_["stream"] = False
|
kwargs_["stream"] = False
|
||||||
response = Generate._run(self, [], **kwargs_)
|
response = Generate._run(self, [], **kwargs_)
|
||||||
try:
|
try:
|
||||||
regenerated_sql = response.loc[0,"content"]
|
regenerated_sql = response.loc[0, "content"]
|
||||||
return regenerated_sql
|
return regenerated_sql
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Failed to regenerate SQL: {e}")
|
logging.error(f"Failed to regenerate SQL: {e}")
|
||||||
|
@ -567,7 +567,7 @@ Requirements:
|
|||||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
||||||
if isinstance(kwd, tuple):
|
if isinstance(kwd, tuple):
|
||||||
kwd = kwd[0]
|
kwd = kwd[0]
|
||||||
kwd = re.sub(r"<think>.*</think>", "", kwd)
|
kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
|
||||||
if kwd.find("**ERROR**") >= 0:
|
if kwd.find("**ERROR**") >= 0:
|
||||||
return ""
|
return ""
|
||||||
return kwd
|
return kwd
|
||||||
@ -597,7 +597,7 @@ Requirements:
|
|||||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
||||||
if isinstance(kwd, tuple):
|
if isinstance(kwd, tuple):
|
||||||
kwd = kwd[0]
|
kwd = kwd[0]
|
||||||
kwd = re.sub(r"<think>.*</think>", "", kwd)
|
kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
|
||||||
if kwd.find("**ERROR**") >= 0:
|
if kwd.find("**ERROR**") >= 0:
|
||||||
return ""
|
return ""
|
||||||
return kwd
|
return kwd
|
||||||
@ -668,7 +668,7 @@ Output: What's the weather in Rochester on {tomorrow}?
|
|||||||
###############
|
###############
|
||||||
"""
|
"""
|
||||||
ans = chat_mdl.chat(prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2})
|
ans = chat_mdl.chat(prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2})
|
||||||
ans = re.sub(r"<think>.*</think>", "", ans)
|
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
|
||||||
return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
|
return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
|
||||||
|
|
||||||
|
|
||||||
@ -793,7 +793,7 @@ Output:
|
|||||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
|
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
|
||||||
if isinstance(kwd, tuple):
|
if isinstance(kwd, tuple):
|
||||||
kwd = kwd[0]
|
kwd = kwd[0]
|
||||||
kwd = re.sub(r"<think>.*</think>", "", kwd)
|
kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
|
||||||
if kwd.find("**ERROR**") >= 0:
|
if kwd.find("**ERROR**") >= 0:
|
||||||
raise Exception(kwd)
|
raise Exception(kwd)
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ class Extractor:
|
|||||||
if response:
|
if response:
|
||||||
return response
|
return response
|
||||||
response = self._llm.chat(system, hist, conf)
|
response = self._llm.chat(system, hist, conf)
|
||||||
response = re.sub(r"<think>.*</think>", "", response)
|
response = re.sub(r"<think>.*</think>", "", response, flags=re.DOTALL)
|
||||||
if response.find("**ERROR**") >= 0:
|
if response.find("**ERROR**") >= 0:
|
||||||
raise Exception(response)
|
raise Exception(response)
|
||||||
set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
|
set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user