diff --git a/agent/component/exesql.py b/agent/component/exesql.py index 744e86600..5007c4283 100644 --- a/agent/component/exesql.py +++ b/agent/component/exesql.py @@ -60,7 +60,8 @@ class ExeSQLParam(GenerateParam): class ExeSQL(Generate, ABC): component_name = "ExeSQL" - def _refactor(self,ans): + def _refactor(self, ans): + ans = re.sub(r".*", "", ans, flags=re.DOTALL) match = re.search(r"```sql\s*(.*?)\s*```", ans, re.DOTALL) if match: ans = match.group(1) # Query content @@ -78,7 +79,7 @@ class ExeSQL(Generate, ABC): ans = self.get_input() ans = "".join([str(a) for a in ans["content"]]) if "content" in ans else "" ans = self._refactor(ans) - logging.info("db_type: ",self._param.db_type) + logging.info("db_type: ", self._param.db_type) if self._param.db_type in ["mysql", "mariadb"]: db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host, port=self._param.port, password=self._param.password) @@ -87,11 +88,11 @@ class ExeSQL(Generate, ABC): port=self._param.port, password=self._param.password) elif self._param.db_type == 'mssql': conn_str = ( - r'DRIVER={ODBC Driver 17 for SQL Server};' - r'SERVER=' + self._param.host + ',' + str(self._param.port) + ';' - r'DATABASE=' + self._param.database + ';' - r'UID=' + self._param.username + ';' - r'PWD=' + self._param.password + r'DRIVER={ODBC Driver 17 for SQL Server};' + r'SERVER=' + self._param.host + ',' + str(self._param.port) + ';' + r'DATABASE=' + self._param.database + ';' + r'UID=' + self._param.username + ';' + r'PWD=' + self._param.password ) db = pyodbc.connect(conn_str) try: @@ -101,12 +102,12 @@ class ExeSQL(Generate, ABC): if not hasattr(self, "_loop"): setattr(self, "_loop", 0) self._loop += 1 - input_list=re.split(r';', ans.replace(r"\n", " ")) + input_list = re.split(r';', ans.replace(r"\n", " ")) sql_res = [] for i in range(len(input_list)): - single_sql=input_list[i] + single_sql = input_list[i] while self._loop <= self._param.loop: - self._loop+=1 + self._loop += 1 if not single_sql: break try: @@ -116,11 +117,12 @@ class ExeSQL(Generate, ABC): sql_res.append({"content": "No record in the database!"}) break if self._param.db_type == 'mssql': - single_res = pd.DataFrame.from_records(cursor.fetchmany(self._param.top_n),columns = [desc[0] for desc in cursor.description]) + single_res = pd.DataFrame.from_records(cursor.fetchmany(self._param.top_n), + columns=[desc[0] for desc in cursor.description]) else: single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.top_n)]) single_res.columns = [i[0] for i in cursor.description] - sql_res.append({"content": single_res.to_markdown(index=False, floatfmt=".6f")}) + sql_res.append({"content": single_res.to_markdown(index=False, floatfmt=".6f")}) break except Exception as e: single_sql = self._regenerate_sql(single_sql, str(e), **kwargs) @@ -133,19 +135,19 @@ class ExeSQL(Generate, ABC): return ExeSQL.be_output("") return pd.DataFrame(sql_res) - def _regenerate_sql(self, failed_sql, error_message,**kwargs): + def _regenerate_sql(self, failed_sql, error_message, **kwargs): prompt = f''' ## You are the Repair SQL Statement Helper, please modify the original SQL statement based on the SQL query error report. ## The original SQL statement is as follows:{failed_sql}. ## The contents of the SQL query error report is as follows:{error_message}. ## Answer only the modified SQL statement. Please do not give any explanation, just answer the code. ''' - self._param.prompt=prompt + self._param.prompt = prompt kwargs_ = deepcopy(kwargs) kwargs_["stream"] = False response = Generate._run(self, [], **kwargs_) try: - regenerated_sql = response.loc[0,"content"] + regenerated_sql = response.loc[0, "content"] return regenerated_sql except Exception as e: logging.error(f"Failed to regenerate SQL: {e}") diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index d9103baec..0e8771e33 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -567,7 +567,7 @@ Requirements: kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2}) if isinstance(kwd, tuple): kwd = kwd[0] - kwd = re.sub(r".*", "", kwd) + kwd = re.sub(r".*", "", kwd, flags=re.DOTALL) if kwd.find("**ERROR**") >= 0: return "" return kwd @@ -597,7 +597,7 @@ Requirements: kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2}) if isinstance(kwd, tuple): kwd = kwd[0] - kwd = re.sub(r".*", "", kwd) + kwd = re.sub(r".*", "", kwd, flags=re.DOTALL) if kwd.find("**ERROR**") >= 0: return "" return kwd @@ -668,7 +668,7 @@ Output: What's the weather in Rochester on {tomorrow}? ############### """ ans = chat_mdl.chat(prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2}) - ans = re.sub(r".*", "", ans) + ans = re.sub(r".*", "", ans, flags=re.DOTALL) return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"] @@ -793,7 +793,7 @@ Output: kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5}) if isinstance(kwd, tuple): kwd = kwd[0] - kwd = re.sub(r".*", "", kwd) + kwd = re.sub(r".*", "", kwd, flags=re.DOTALL) if kwd.find("**ERROR**") >= 0: raise Exception(kwd) diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py index 50184002f..d211cf381 100644 --- a/graphrag/general/extractor.py +++ b/graphrag/general/extractor.py @@ -60,7 +60,7 @@ class Extractor: if response: return response response = self._llm.chat(system, hist, conf) - response = re.sub(r".*", "", response) + response = re.sub(r".*", "", response, flags=re.DOTALL) if response.find("**ERROR**") >= 0: raise Exception(response) set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)