From 026175c8f789ddaac13390bd15dac3aebb33ab89 Mon Sep 17 00:00:00 2001 From: yalei <269870927@qq.com> Date: Fri, 24 May 2024 20:30:48 +0800 Subject: [PATCH] feat: update notion extractor (#3898) Co-authored-by: duyalei <> --- api/core/rag/extractor/notion_extractor.py | 38 +++---- .../unit_tests/core/rag/extractor/__init__.py | 0 .../rag/extractor/test_notion_extractor.py | 102 ++++++++++++++++++ 3 files changed, 118 insertions(+), 22 deletions(-) create mode 100644 api/tests/unit_tests/core/rag/extractor/__init__.py create mode 100644 api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index c40064fd1d..1885ad3aca 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -19,8 +19,12 @@ SEARCH_URL = "https://api.notion.com/v1/search" RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}" RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}" -HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3'] - +# if user want split by headings, use the corresponding splitter +HEADING_SPLITTER = { + 'heading_1': '# ', + 'heading_2': '## ', + 'heading_3': '### ', +} class NotionExtractor(BaseExtractor): @@ -73,8 +77,7 @@ class NotionExtractor(BaseExtractor): docs.extend(page_text_documents) elif notion_page_type == 'page': page_text_list = self._get_notion_block_data(notion_obj_id) - for page_text in page_text_list: - docs.append(Document(page_content=page_text)) + docs.append(Document(page_content='\n'.join(page_text_list))) else: raise ValueError("notion page type not supported") @@ -96,7 +99,7 @@ class NotionExtractor(BaseExtractor): data = res.json() - database_content_list = [] + database_content = [] if 'results' not in data or data["results"] is None: return [] for result in data["results"]: @@ -131,10 +134,9 @@ class NotionExtractor(BaseExtractor): row_content = row_content + f'{key}:{value_content}\n' else: row_content = row_content + f'{key}:{value}\n' - document = Document(page_content=row_content) - database_content_list.append(document) + database_content.append(row_content) - return database_content_list + return [Document(page_content='\n'.join(database_content))] def _get_notion_block_data(self, page_id: str) -> list[str]: result_lines_arr = [] @@ -154,8 +156,6 @@ class NotionExtractor(BaseExtractor): json=query_dict ) data = res.json() - # current block's heading - heading = '' for result in data["results"]: result_type = result["type"] result_obj = result[result_type] @@ -172,8 +172,6 @@ class NotionExtractor(BaseExtractor): if "text" in rich_text: text = rich_text["text"]["content"] cur_result_text_arr.append(text) - if result_type in HEADING_TYPE: - heading = text result_block_id = result["id"] has_children = result["has_children"] @@ -185,11 +183,10 @@ class NotionExtractor(BaseExtractor): cur_result_text_arr.append(children_text) cur_result_text = "\n".join(cur_result_text_arr) - cur_result_text += "\n\n" - if result_type in HEADING_TYPE: - result_lines_arr.append(cur_result_text) + if result_type in HEADING_SPLITTER: + result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}") else: - result_lines_arr.append(f'{heading}\n{cur_result_text}') + result_lines_arr.append(cur_result_text + '\n\n') if data["next_cursor"] is None: break @@ -218,7 +215,6 @@ class NotionExtractor(BaseExtractor): data = res.json() if 'results' not in data or data["results"] is None: break - heading = '' for result in data["results"]: result_type = result["type"] result_obj = result[result_type] @@ -235,8 +231,6 @@ class NotionExtractor(BaseExtractor): text = rich_text["text"]["content"] prefix = "\t" * num_tabs cur_result_text_arr.append(prefix + text) - if result_type in HEADING_TYPE: - heading = text result_block_id = result["id"] has_children = result["has_children"] block_type = result["type"] @@ -247,10 +241,10 @@ class NotionExtractor(BaseExtractor): cur_result_text_arr.append(children_text) cur_result_text = "\n".join(cur_result_text_arr) - if result_type in HEADING_TYPE: - result_lines_arr.append(cur_result_text) + if result_type in HEADING_SPLITTER: + result_lines_arr.append(f'{HEADING_SPLITTER[result_type]}{cur_result_text}') else: - result_lines_arr.append(f'{heading}\n{cur_result_text}') + result_lines_arr.append(cur_result_text + '\n\n') if data["next_cursor"] is None: break diff --git a/api/tests/unit_tests/core/rag/extractor/__init__.py b/api/tests/unit_tests/core/rag/extractor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py new file mode 100644 index 0000000000..b231fe479b --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py @@ -0,0 +1,102 @@ +from unittest import mock + +from core.rag.extractor import notion_extractor + +user_id = "user1" +database_id = "database1" +page_id = "page1" + + +extractor = notion_extractor.NotionExtractor( + notion_workspace_id='x', + notion_obj_id='x', + notion_page_type='page', + tenant_id='x', + notion_access_token='x') + + +def _generate_page(page_title: str): + return { + "object": "page", + "id": page_id, + "properties": { + "Page": { + "type": "title", + "title": [ + { + "type": "text", + "text": {"content": page_title}, + "plain_text": page_title + } + ] + } + } + } + + +def _generate_block(block_id: str, block_type: str, block_text: str): + return { + "object": "block", + "id": block_id, + "parent": { + "type": "page_id", + "page_id": page_id + }, + "type": block_type, + "has_children": False, + block_type: { + "rich_text": [ + { + "type": "text", + "text": {"content": block_text}, + "plain_text": block_text, + }] + } + } + + +def _mock_response(data): + response = mock.Mock() + response.status_code = 200 + response.json.return_value = data + return response + + +def _remove_multiple_new_lines(text): + while '\n\n' in text: + text = text.replace("\n\n", "\n") + return text.strip() + + +def test_notion_page(mocker): + texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"] + mocked_notion_page = { + "object": "list", + "results": [ + _generate_block("b1", "heading_1", texts[0]), + _generate_block("b2", "heading_2", texts[1]), + _generate_block("b3", "paragraph", texts[2]), + _generate_block("b4", "heading_3", texts[3]) + ], + "next_cursor": None + } + mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page)) + + page_docs = extractor._load_data_as_documents(page_id, "page") + assert len(page_docs) == 1 + content = _remove_multiple_new_lines(page_docs[0].page_content) + assert content == '# Head 1\n## 1.1\nparagraph 1\n### 1.1.1' + + +def test_notion_database(mocker): + page_title_list = ["page1", "page2", "page3"] + mocked_notion_database = { + "object": "list", + "results": [_generate_page(i) for i in page_title_list], + "next_cursor": None + } + mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database)) + database_docs = extractor._load_data_as_documents(database_id, "database") + assert len(database_docs) == 1 + content = _remove_multiple_new_lines(database_docs[0].page_content) + assert content == '\n'.join([f'Page:{i}' for i in page_title_list])