diff --git a/agent/component/exesql.py b/agent/component/exesql.py
index 744e86600..5007c4283 100644
--- a/agent/component/exesql.py
+++ b/agent/component/exesql.py
@@ -60,7 +60,8 @@ class ExeSQLParam(GenerateParam):
class ExeSQL(Generate, ABC):
component_name = "ExeSQL"
- def _refactor(self,ans):
+ def _refactor(self, ans):
+ ans = re.sub(r".*", "", ans, flags=re.DOTALL)
match = re.search(r"```sql\s*(.*?)\s*```", ans, re.DOTALL)
if match:
ans = match.group(1) # Query content
@@ -78,7 +79,7 @@ class ExeSQL(Generate, ABC):
ans = self.get_input()
ans = "".join([str(a) for a in ans["content"]]) if "content" in ans else ""
ans = self._refactor(ans)
- logging.info("db_type: ",self._param.db_type)
+ logging.info("db_type: ", self._param.db_type)
if self._param.db_type in ["mysql", "mariadb"]:
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
port=self._param.port, password=self._param.password)
@@ -87,11 +88,11 @@ class ExeSQL(Generate, ABC):
port=self._param.port, password=self._param.password)
elif self._param.db_type == 'mssql':
conn_str = (
- r'DRIVER={ODBC Driver 17 for SQL Server};'
- r'SERVER=' + self._param.host + ',' + str(self._param.port) + ';'
- r'DATABASE=' + self._param.database + ';'
- r'UID=' + self._param.username + ';'
- r'PWD=' + self._param.password
+ r'DRIVER={ODBC Driver 17 for SQL Server};'
+ r'SERVER=' + self._param.host + ',' + str(self._param.port) + ';'
+ r'DATABASE=' + self._param.database + ';'
+ r'UID=' + self._param.username + ';'
+ r'PWD=' + self._param.password
)
db = pyodbc.connect(conn_str)
try:
@@ -101,12 +102,12 @@ class ExeSQL(Generate, ABC):
if not hasattr(self, "_loop"):
setattr(self, "_loop", 0)
self._loop += 1
- input_list=re.split(r';', ans.replace(r"\n", " "))
+ input_list = re.split(r';', ans.replace(r"\n", " "))
sql_res = []
for i in range(len(input_list)):
- single_sql=input_list[i]
+ single_sql = input_list[i]
while self._loop <= self._param.loop:
- self._loop+=1
+ self._loop += 1
if not single_sql:
break
try:
@@ -116,11 +117,12 @@ class ExeSQL(Generate, ABC):
sql_res.append({"content": "No record in the database!"})
break
if self._param.db_type == 'mssql':
- single_res = pd.DataFrame.from_records(cursor.fetchmany(self._param.top_n),columns = [desc[0] for desc in cursor.description])
+ single_res = pd.DataFrame.from_records(cursor.fetchmany(self._param.top_n),
+ columns=[desc[0] for desc in cursor.description])
else:
single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.top_n)])
single_res.columns = [i[0] for i in cursor.description]
- sql_res.append({"content": single_res.to_markdown(index=False, floatfmt=".6f")})
+ sql_res.append({"content": single_res.to_markdown(index=False, floatfmt=".6f")})
break
except Exception as e:
single_sql = self._regenerate_sql(single_sql, str(e), **kwargs)
@@ -133,19 +135,19 @@ class ExeSQL(Generate, ABC):
return ExeSQL.be_output("")
return pd.DataFrame(sql_res)
- def _regenerate_sql(self, failed_sql, error_message,**kwargs):
+ def _regenerate_sql(self, failed_sql, error_message, **kwargs):
prompt = f'''
## You are the Repair SQL Statement Helper, please modify the original SQL statement based on the SQL query error report.
## The original SQL statement is as follows:{failed_sql}.
## The contents of the SQL query error report is as follows:{error_message}.
## Answer only the modified SQL statement. Please do not give any explanation, just answer the code.
'''
- self._param.prompt=prompt
+ self._param.prompt = prompt
kwargs_ = deepcopy(kwargs)
kwargs_["stream"] = False
response = Generate._run(self, [], **kwargs_)
try:
- regenerated_sql = response.loc[0,"content"]
+ regenerated_sql = response.loc[0, "content"]
return regenerated_sql
except Exception as e:
logging.error(f"Failed to regenerate SQL: {e}")
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index d9103baec..0e8771e33 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -567,7 +567,7 @@ Requirements:
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple):
kwd = kwd[0]
- kwd = re.sub(r".*", "", kwd)
+ kwd = re.sub(r".*", "", kwd, flags=re.DOTALL)
if kwd.find("**ERROR**") >= 0:
return ""
return kwd
@@ -597,7 +597,7 @@ Requirements:
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple):
kwd = kwd[0]
- kwd = re.sub(r".*", "", kwd)
+ kwd = re.sub(r".*", "", kwd, flags=re.DOTALL)
if kwd.find("**ERROR**") >= 0:
return ""
return kwd
@@ -668,7 +668,7 @@ Output: What's the weather in Rochester on {tomorrow}?
###############
"""
ans = chat_mdl.chat(prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2})
- ans = re.sub(r".*", "", ans)
+ ans = re.sub(r".*", "", ans, flags=re.DOTALL)
return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
@@ -793,7 +793,7 @@ Output:
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
if isinstance(kwd, tuple):
kwd = kwd[0]
- kwd = re.sub(r".*", "", kwd)
+ kwd = re.sub(r".*", "", kwd, flags=re.DOTALL)
if kwd.find("**ERROR**") >= 0:
raise Exception(kwd)
diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py
index 50184002f..d211cf381 100644
--- a/graphrag/general/extractor.py
+++ b/graphrag/general/extractor.py
@@ -60,7 +60,7 @@ class Extractor:
if response:
return response
response = self._llm.chat(system, hist, conf)
- response = re.sub(r".*", "", response)
+ response = re.sub(r".*", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0:
raise Exception(response)
set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)