Feat/firecrawl data source (#5232)

Co-authored-by: Nicolas <nicolascamara29@gmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: takatost <takatost@gmail.com>
Jyong committed 2024-06-15 02:46:02 +08:00 (committed by GitHub)
parent 918ebe1620
commit ba5f8afaa8
36 changed files with 1174 additions and 64 deletions

View File

@@ -215,4 +215,5 @@ WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5
# App configuration
APP_MAX_EXECUTION_TIME=1200

View File

@@ -29,13 +29,13 @@ from .app import (
)
# Import auth controllers
-from .auth import activate, data_source_oauth, login, oauth
+from .auth import activate, data_source_bearer_auth, data_source_oauth, login, oauth
# Import billing controllers
from .billing import billing
# Import datasets controllers
-from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing
+from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing, website
# Import explore controllers
from .explore import (

View File

@@ -0,0 +1,67 @@
from flask_login import current_user
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.auth.error import ApiKeyAuthFailedError
from libs.login import login_required
from services.auth.api_key_auth_service import ApiKeyAuthService
from ..setup import setup_required
from ..wraps import account_initialization_required
class ApiKeyAuthDataSource(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
# The role of the current user in the table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
if data_source_api_key_bindings:
return {
'settings': [data_source_api_key_binding.to_dict() for data_source_api_key_binding in
data_source_api_key_bindings]}
return {'settings': []}
class ApiKeyAuthDataSourceBinding(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
# The role of the current user in the table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('category', type=str, required=True, nullable=False, location='json')
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args)
try:
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
except Exception as e:
raise ApiKeyAuthFailedError(str(e))
return {'result': 'success'}, 200
class ApiKeyAuthDataSourceBindingDelete(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, binding_id):
# The role of the current user in the table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
return {'result': 'success'}, 200
api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source')
api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding')
api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/<uuid:binding_id>')
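
For reference, a minimal sketch of exercising the binding endpoint registered above. The payload shape follows the reqparse arguments and ApiKeyAuthService.validate_api_key_auth_args; the host, the /console/api prefix, and all values are placeholders, and the auth header is only illustrative.

import requests

payload = {
    "category": "website",
    "provider": "firecrawl",
    "credentials": {
        "auth_type": "bearer",  # checked by validate_api_key_auth_args
        "config": {
            "api_key": "fc-your-api-key",            # placeholder value
            "base_url": "https://api.firecrawl.dev",
        },
    },
}
# POST /api-key-auth/data-source/binding validates the key and stores the binding.
resp = requests.post(
    "https://dify.example.com/console/api/api-key-auth/data-source/binding",  # assumed console prefix
    json=payload,
    headers={"Authorization": "Bearer <console-access-token>"},  # placeholder console auth
)
print(resp.json())  # expected: {'result': 'success'}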

View File

@@ -0,0 +1,7 @@
from libs.exception import BaseHTTPException
class ApiKeyAuthFailedError(BaseHTTPException):
error_code = 'auth_failed'
description = "{message}"
code = 500

View File

@@ -16,7 +16,7 @@ from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.login import login_required
from models.dataset import Document
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding
from services.dataset_service import DatasetService, DocumentService
from tasks.document_indexing_sync_task import document_indexing_sync_task
@@ -29,9 +29,9 @@ class DataSourceApi(Resource):
@marshal_with(integrate_list_fields)
def get(self):
# get workspace data source integrates
-data_source_integrates = db.session.query(DataSourceBinding).filter(
-DataSourceBinding.tenant_id == current_user.current_tenant_id,
-DataSourceBinding.disabled == False
+data_source_integrates = db.session.query(DataSourceOauthBinding).filter(
+DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+DataSourceOauthBinding.disabled == False
).all()
base_url = request.url_root.rstrip('/')
@@ -71,7 +71,7 @@ class DataSourceApi(Resource):
def patch(self, binding_id, action):
binding_id = str(binding_id)
action = str(action)
-data_source_binding = DataSourceBinding.query.filter_by(
+data_source_binding = DataSourceOauthBinding.query.filter_by(
id=binding_id
).first()
if data_source_binding is None:
@@ -124,7 +124,7 @@ class DataSourceNotionListApi(Resource):
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info['notion_page_id'])
# get all authorized pages
-data_source_bindings = DataSourceBinding.query.filter_by(
+data_source_bindings = DataSourceOauthBinding.query.filter_by(
tenant_id=current_user.current_tenant_id,
provider='notion',
disabled=False
@@ -163,12 +163,12 @@ class DataSourceNotionApi(Resource):
def get(self, workspace_id, page_id, page_type):
workspace_id = str(workspace_id)
page_id = str(page_id)
-data_source_binding = DataSourceBinding.query.filter(
+data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
-DataSourceBinding.tenant_id == current_user.current_tenant_id,
-DataSourceBinding.provider == 'notion',
-DataSourceBinding.disabled == False,
-DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+DataSourceOauthBinding.provider == 'notion',
+DataSourceOauthBinding.disabled == False,
+DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:

View File

@@ -315,6 +315,22 @@ class DatasetIndexingEstimateApi(Resource):
document_model=args['doc_form']
)
extract_settings.append(extract_setting)
+elif args['info_list']['data_source_type'] == 'website_crawl':
+website_info_list = args['info_list']['website_info_list']
+for url in website_info_list['urls']:
+extract_setting = ExtractSetting(
+datasource_type="website_crawl",
+website_info={
+"provider": website_info_list['provider'],
+"job_id": website_info_list['job_id'],
+"url": url,
+"tenant_id": current_user.current_tenant_id,
+"mode": 'crawl',
+"only_main_content": website_info_list['only_main_content']
+},
+document_model=args['doc_form']
+)
+extract_settings.append(extract_setting)
else:
raise ValueError('Data source type not support')
indexing_runner = IndexingRunner()
@@ -519,6 +535,7 @@ class DatasetRetrievalSettingMockApi(Resource):
raise ValueError(f"Unsupported vector db type {vector_type}.")
class DatasetErrorDocs(Resource):
@setup_required
@login_required

View File

@@ -465,6 +465,20 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
document_model=document.doc_form
)
extract_settings.append(extract_setting)
+elif document.data_source_type == 'website_crawl':
+extract_setting = ExtractSetting(
+datasource_type="website_crawl",
+website_info={
+"provider": data_source_info['provider'],
+"job_id": data_source_info['job_id'],
+"url": data_source_info['url'],
+"tenant_id": current_user.current_tenant_id,
+"mode": data_source_info['mode'],
+"only_main_content": data_source_info['only_main_content']
+},
+document_model=document.doc_form
+)
+extract_settings.append(extract_setting)
else:
raise ValueError('Data source type not support')
@@ -952,6 +966,33 @@ class DocumentRenameApi(DocumentResource):
return document
+class WebsiteDocumentSyncApi(DocumentResource):
+@setup_required
+@login_required
+@account_initialization_required
+def get(self, dataset_id, document_id):
+"""sync website document."""
+dataset_id = str(dataset_id)
+dataset = DatasetService.get_dataset(dataset_id)
+if not dataset:
+raise NotFound('Dataset not found.')
+document_id = str(document_id)
+document = DocumentService.get_document(dataset.id, document_id)
+if not document:
+raise NotFound('Document not found.')
+if document.tenant_id != current_user.current_tenant_id:
+raise Forbidden('No permission.')
+if document.data_source_type != 'website_crawl':
+raise ValueError('Document is not a website document.')
+# 403 if document is archived
+if DocumentService.check_archived(document):
+raise ArchivedDocumentImmutableError()
+# sync document
+DocumentService.sync_website_document(dataset_id, document)
+return {'result': 'success'}, 200
api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
api.add_resource(DatasetDocumentListApi,
'/datasets/<uuid:dataset_id>/documents')
@@ -980,3 +1021,5 @@ api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uui
api.add_resource(DocumentRetryApi, '/datasets/<uuid:dataset_id>/retry')
api.add_resource(DocumentRenameApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename')
+api.add_resource(WebsiteDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync')
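
A hedged sketch of triggering the new website-sync route for an existing website document; IDs, host, prefix, and auth header are placeholders (the route itself is registered directly above).

import requests

dataset_id = "00000000-0000-0000-0000-000000000000"    # placeholder UUIDs
document_id = "00000000-0000-0000-0000-000000000001"
resp = requests.get(
    f"https://dify.example.com/console/api/datasets/{dataset_id}/documents/{document_id}/website-sync",
    headers={"Authorization": "Bearer <console-access-token>"},  # placeholder console auth
)
# On success the document is re-scraped and re-indexed: {'result': 'success'}
print(resp.json())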

View File

@@ -73,6 +73,12 @@ class InvalidMetadataError(BaseHTTPException):
code = 400
+class WebsiteCrawlError(BaseHTTPException):
+error_code = 'crawl_failed'
+description = "{message}"
+code = 500
class DatasetInUseError(BaseHTTPException):
error_code = 'dataset_in_use'
description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."

View File

@@ -0,0 +1,49 @@
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.login import login_required
from services.website_service import WebsiteService
class WebsiteCrawlApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('provider', type=str, choices=['firecrawl'],
required=True, nullable=True, location='json')
parser.add_argument('url', type=str, required=True, nullable=True, location='json')
parser.add_argument('options', type=dict, required=True, nullable=True, location='json')
args = parser.parse_args()
WebsiteService.document_create_args_validate(args)
# crawl url
try:
result = WebsiteService.crawl_url(args)
except Exception as e:
raise WebsiteCrawlError(str(e))
return result, 200
class WebsiteCrawlStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, job_id: str):
parser = reqparse.RequestParser()
parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args')
args = parser.parse_args()
# get crawl status
try:
result = WebsiteService.get_crawl_status(job_id, args['provider'])
except Exception as e:
raise WebsiteCrawlError(str(e))
return result, 200
api.add_resource(WebsiteCrawlApi, '/website/crawl')
api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/<string:job_id>')
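
A hedged sketch of the crawl endpoints added above: start a crawl, then poll its status. The options keys mirror what WebsiteService.crawl_url reads later in this change; the host, prefix, and auth values are placeholders.

import requests

base = "https://dify.example.com/console/api"                  # assumed console prefix
headers = {"Authorization": "Bearer <console-access-token>"}   # placeholder console auth

# Start a Firecrawl crawl job.
resp = requests.post(f"{base}/website/crawl", headers=headers, json={
    "provider": "firecrawl",
    "url": "https://example.com",
    "options": {
        "crawl_sub_pages": True,
        "only_main_content": True,
        "limit": 10,
        "max_depth": 2,
        "includes": "",
        "excludes": "",
    },
})
job_id = resp.json()["job_id"]

# Poll the job until the provider reports completion.
status = requests.get(f"{base}/website/crawl/status/{job_id}",
                      headers=headers, params={"provider": "firecrawl"}).json()
print(status["status"], status.get("total"), status.get("current"))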

View File

@@ -339,7 +339,7 @@ class IndexingRunner:
def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
-> list[Document]:
# load file
-if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
+if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]:
return []
data_source_info = dataset_document.data_source_info_dict
@@ -375,6 +375,23 @@
document_model=dataset_document.doc_form
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
+elif dataset_document.data_source_type == 'website_crawl':
+if (not data_source_info or 'provider' not in data_source_info
+or 'url' not in data_source_info or 'job_id' not in data_source_info):
+raise ValueError("no website import info found")
+extract_setting = ExtractSetting(
+datasource_type="website_crawl",
+website_info={
+"provider": data_source_info['provider'],
+"job_id": data_source_info['job_id'],
+"tenant_id": dataset_document.tenant_id,
+"url": data_source_info['url'],
+"mode": data_source_info['mode'],
+"only_main_content": data_source_info['only_main_content']
+},
+document_model=dataset_document.doc_form
+)
+text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
# update document status to splitting
self._update_document_index_status(
document_id=dataset_document.id,

View File

@@ -124,7 +124,7 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
default=float(credentials.get('presence_penalty', 0)),
min=-2,
max=2
-)
+),
],
pricing=PriceConfig(
input=Decimal(cred_with_endpoint.get('input_price', 0)),

View File

@@ -4,3 +4,4 @@ from enum import Enum
class DatasourceType(Enum):
FILE = "upload_file"
NOTION = "notion_import"
+WEBSITE = "website_crawl"

View File

@@ -1,3 +1,5 @@
+from typing import Optional
from pydantic import BaseModel, ConfigDict
from models.dataset import Document
@@ -19,14 +21,33 @@ class NotionInfo(BaseModel):
super().__init__(**data)
+class WebsiteInfo(BaseModel):
+"""
+website import info.
+"""
+provider: str
+job_id: str
+url: str
+mode: str
+tenant_id: str
+only_main_content: bool = False
+class Config:
+arbitrary_types_allowed = True
+def __init__(self, **data) -> None:
+super().__init__(**data)
class ExtractSetting(BaseModel):
"""
Model class for provider response.
"""
datasource_type: str
-upload_file: UploadFile = None
-notion_info: NotionInfo = None
-document_model: str = None
+upload_file: Optional[UploadFile] = None
+notion_info: Optional[NotionInfo] = None
+website_info: Optional[WebsiteInfo] = None
+document_model: Optional[str] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, **data) -> None:
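
A minimal sketch of constructing the new website ExtractSetting, mirroring how the controllers and indexing runner in this change build it; all values are placeholders, and the nested dict is coerced into WebsiteInfo by pydantic.

from core.rag.extractor.entity.extract_setting import ExtractSetting

extract_setting = ExtractSetting(
    datasource_type="website_crawl",
    website_info={
        "provider": "firecrawl",
        "job_id": "example-job-id",     # returned by WebsiteService.crawl_url
        "url": "https://example.com",
        "tenant_id": "tenant-uuid",     # placeholder workspace id
        "mode": "crawl",
        "only_main_content": True,
    },
    document_model="text_model",
)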

View File

@@ -11,6 +11,7 @@ from core.rag.extractor.csv_extractor import CSVExtractor
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.excel_extractor import ExcelExtractor
+from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
from core.rag.extractor.html_extractor import HtmlExtractor
from core.rag.extractor.markdown_extractor import MarkdownExtractor
from core.rag.extractor.notion_extractor import NotionExtractor
@@ -154,5 +155,17 @@ class ExtractProcessor:
tenant_id=extract_setting.notion_info.tenant_id,
)
return extractor.extract()
+elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
+if extract_setting.website_info.provider == 'firecrawl':
+extractor = FirecrawlWebExtractor(
+url=extract_setting.website_info.url,
+job_id=extract_setting.website_info.job_id,
+tenant_id=extract_setting.website_info.tenant_id,
+mode=extract_setting.website_info.mode,
+only_main_content=extract_setting.website_info.only_main_content
+)
+return extractor.extract()
+else:
+raise ValueError(f"Unsupported website provider: {extract_setting.website_info.provider}")
else:
raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}")

View File

@@ -0,0 +1,132 @@
import json
import time
import requests
from extensions.ext_storage import storage
class FirecrawlApp:
def __init__(self, api_key=None, base_url=None):
self.api_key = api_key
self.base_url = base_url or 'https://api.firecrawl.dev'
if self.api_key is None and self.base_url == 'https://api.firecrawl.dev':
raise ValueError('No API key provided')
def scrape_url(self, url, params=None) -> dict:
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}'
}
json_data = {'url': url}
if params:
json_data.update(params)
response = requests.post(
f'{self.base_url}/v0/scrape',
headers=headers,
json=json_data
)
if response.status_code == 200:
response = response.json()
if response['success'] == True:
data = response['data']
return {
'title': data.get('metadata').get('title'),
'description': data.get('metadata').get('description'),
'source_url': data.get('metadata').get('sourceURL'),
'markdown': data.get('markdown')
}
else:
raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
elif response.status_code in [402, 409, 500]:
error_message = response.json().get('error', 'Unknown error occurred')
raise Exception(f'Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}')
else:
raise Exception(f'Failed to scrape URL. Status code: {response.status_code}')
def crawl_url(self, url, params=None) -> str:
start_time = time.time()
headers = self._prepare_headers()
json_data = {'url': url}
if params:
json_data.update(params)
response = self._post_request(f'{self.base_url}/v0/crawl', json_data, headers)
if response.status_code == 200:
job_id = response.json().get('jobId')
return job_id
else:
self._handle_error(response, 'start crawl job')
def check_crawl_status(self, job_id) -> dict:
headers = self._prepare_headers()
response = self._get_request(f'{self.base_url}/v0/crawl/status/{job_id}', headers)
if response.status_code == 200:
crawl_status_response = response.json()
if crawl_status_response.get('status') == 'completed':
total = crawl_status_response.get('total', 0)
if total == 0:
raise Exception('Failed to check crawl status. Error: No page found')
data = crawl_status_response.get('data', [])
url_data_list = []
for item in data:
if isinstance(item, dict) and 'metadata' in item and 'markdown' in item:
url_data = {
'title': item.get('metadata').get('title'),
'description': item.get('metadata').get('description'),
'source_url': item.get('metadata').get('sourceURL'),
'markdown': item.get('markdown')
}
url_data_list.append(url_data)
if url_data_list:
file_key = 'website_files/' + job_id + '.txt'
if storage.exists(file_key):
storage.delete(file_key)
storage.save(file_key, json.dumps(url_data_list).encode('utf-8'))
return {
'status': 'completed',
'total': crawl_status_response.get('total'),
'current': crawl_status_response.get('current'),
'data': url_data_list
}
else:
return {
'status': crawl_status_response.get('status'),
'total': crawl_status_response.get('total'),
'current': crawl_status_response.get('current'),
'data': []
}
else:
self._handle_error(response, 'check crawl status')
def _prepare_headers(self):
return {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}'
}
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5):
for attempt in range(retries):
response = requests.post(url, headers=headers, json=data)
if response.status_code == 502:
time.sleep(backoff_factor * (2 ** attempt))
else:
return response
return response
def _get_request(self, url, headers, retries=3, backoff_factor=0.5):
for attempt in range(retries):
response = requests.get(url, headers=headers)
if response.status_code == 502:
time.sleep(backoff_factor * (2 ** attempt))
else:
return response
return response
def _handle_error(self, response, action):
error_message = response.json().get('error', 'Unknown error occurred')
raise Exception(f'Failed to {action}. Status code: {response.status_code}. Error: {error_message}')
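
A hedged usage sketch of the FirecrawlApp client above; the API key, URL, and polling interval are placeholders, and since check_crawl_status persists completed results via the storage extension, an initialized Dify app context is assumed.

import time

app = FirecrawlApp(api_key="fc-your-api-key")   # base_url defaults to https://api.firecrawl.dev

# Kick off a crawl job and poll until it finishes.
job_id = app.crawl_url("https://example.com", params={
    "crawlerOptions": {"includes": [], "excludes": [], "limit": 5},
})
status = app.check_crawl_status(job_id)
while status["status"] != "completed":
    time.sleep(5)                               # arbitrary polling interval
    status = app.check_crawl_status(job_id)

for page in status["data"]:
    print(page["source_url"], page["title"])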

View File

@@ -0,0 +1,60 @@
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from services.website_service import WebsiteService
class FirecrawlWebExtractor(BaseExtractor):
"""
Crawl and scrape websites and return content in clean llm-ready markdown.
Args:
url: The URL to scrape or crawl.
job_id: The crawl job id returned by the website provider.
tenant_id: The workspace that owns the provider credentials.
mode: The mode of operation. Defaults to 'crawl'. Options are 'crawl' and 'scrape'.
only_main_content: Whether to return only the main content of each page.
"""
def __init__(
self,
url: str,
job_id: str,
tenant_id: str,
mode: str = 'crawl',
only_main_content: bool = False
):
"""Initialize with url, api_key, base_url and mode."""
self._url = url
self.job_id = job_id
self.tenant_id = tenant_id
self.mode = mode
self.only_main_content = only_main_content
def extract(self) -> list[Document]:
"""Extract content from the URL."""
documents = []
if self.mode == 'crawl':
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, 'firecrawl', self._url, self.tenant_id)
if crawl_data is None:
return []
document = Document(page_content=crawl_data.get('markdown', ''),
metadata={
'source_url': crawl_data.get('source_url'),
'description': crawl_data.get('description'),
'title': crawl_data.get('title')
}
)
documents.append(document)
elif self.mode == 'scrape':
scrape_data = WebsiteService.get_scrape_url_data('firecrawl', self._url, self.tenant_id,
self.only_main_content)
document = Document(page_content=scrape_data.get('markdown', ''),
metadata={
'source_url': scrape_data.get('source_url'),
'description': scrape_data.get('description'),
'title': scrape_data.get('title')
}
)
documents.append(document)
return documents
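
A short, hedged sketch of driving the extractor directly for a finished crawl job; the ids are placeholders, and an app context with a stored Firecrawl binding is assumed because WebsiteService resolves the credentials.

extractor = FirecrawlWebExtractor(
    url="https://example.com",
    job_id="example-job-id",     # from WebsiteService.crawl_url
    tenant_id="tenant-uuid",     # placeholder workspace id
    mode="crawl",
    only_main_content=True,
)
for doc in extractor.extract():
    print(doc.metadata["source_url"], len(doc.page_content))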

View File

@@ -9,7 +9,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Document as DocumentModel
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding
logger = logging.getLogger(__name__)
@@ -345,12 +345,12 @@ class NotionExtractor(BaseExtractor):
@classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
-data_source_binding = DataSourceBinding.query.filter(
+data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
-DataSourceBinding.tenant_id == tenant_id,
-DataSourceBinding.provider == 'notion',
-DataSourceBinding.disabled == False,
-DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
+DataSourceOauthBinding.tenant_id == tenant_id,
+DataSourceOauthBinding.provider == 'notion',
+DataSourceOauthBinding.disabled == False,
+DataSourceOauthBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
)
).first()

View File

@@ -0,0 +1,64 @@
# [REVIEW] Implement if Needed? Do we need a new type of data source
from abc import abstractmethod
import requests
from api.models.source import DataSourceBearerBinding
from flask_login import current_user
from extensions.ext_database import db
class BearerDataSource:
def __init__(self, api_key: str, api_base_url: str):
self.api_key = api_key
self.api_base_url = api_base_url
@abstractmethod
def validate_bearer_data_source(self):
"""
Validate the data source
"""
class FireCrawlDataSource(BearerDataSource):
def validate_bearer_data_source(self):
TEST_CRAWL_SITE_URL = "https://www.google.com"
FIRECRAWL_API_VERSION = "v0"
test_api_endpoint = self.api_base_url.rstrip('/') + f"/{FIRECRAWL_API_VERSION}/scrape"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
data = {
"url": TEST_CRAWL_SITE_URL,
}
response = requests.post(test_api_endpoint, headers=headers, json=data)
return response.json().get("status") == "success"
def save_credentials(self):
# save data source binding
data_source_binding = DataSourceBearerBinding.query.filter(
db.and_(
DataSourceBearerBinding.tenant_id == current_user.current_tenant_id,
DataSourceBearerBinding.provider == 'firecrawl',
DataSourceBearerBinding.endpoint_url == self.api_base_url,
DataSourceBearerBinding.bearer_key == self.api_key
)
).first()
if data_source_binding:
data_source_binding.disabled = False
db.session.commit()
else:
new_data_source_binding = DataSourceBearerBinding(
tenant_id=current_user.current_tenant_id,
provider='firecrawl',
endpoint_url=self.api_base_url,
bearer_key=self.api_key
)
db.session.add(new_data_source_binding)
db.session.commit()

View File

@@ -4,7 +4,7 @@ import requests
from flask_login import current_user
from extensions.ext_database import db
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding
class OAuthDataSource:
@@ -63,11 +63,11 @@ class NotionOAuth(OAuthDataSource):
'total': len(pages)
}
# save data source binding
-data_source_binding = DataSourceBinding.query.filter(
+data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
-DataSourceBinding.tenant_id == current_user.current_tenant_id,
-DataSourceBinding.provider == 'notion',
-DataSourceBinding.access_token == access_token
+DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+DataSourceOauthBinding.provider == 'notion',
+DataSourceOauthBinding.access_token == access_token
)
).first()
if data_source_binding:
@@ -75,7 +75,7 @@ class NotionOAuth(OAuthDataSource):
data_source_binding.disabled = False
db.session.commit()
else:
-new_data_source_binding = DataSourceBinding(
+new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=source_info,
@@ -98,11 +98,11 @@ class NotionOAuth(OAuthDataSource):
'total': len(pages)
}
# save data source binding
-data_source_binding = DataSourceBinding.query.filter(
+data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
-DataSourceBinding.tenant_id == current_user.current_tenant_id,
-DataSourceBinding.provider == 'notion',
-DataSourceBinding.access_token == access_token
+DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+DataSourceOauthBinding.provider == 'notion',
+DataSourceOauthBinding.access_token == access_token
)
).first()
if data_source_binding:
@@ -110,7 +110,7 @@ class NotionOAuth(OAuthDataSource):
data_source_binding.disabled = False
db.session.commit()
else:
-new_data_source_binding = DataSourceBinding(
+new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=source_info,
@@ -121,12 +121,12 @@ class NotionOAuth(OAuthDataSource):
def sync_data_source(self, binding_id: str):
# save data source binding
-data_source_binding = DataSourceBinding.query.filter(
+data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
-DataSourceBinding.tenant_id == current_user.current_tenant_id,
-DataSourceBinding.provider == 'notion',
-DataSourceBinding.id == binding_id,
-DataSourceBinding.disabled == False
+DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+DataSourceOauthBinding.provider == 'notion',
+DataSourceOauthBinding.id == binding_id,
+DataSourceOauthBinding.disabled == False
)
).first()
if data_source_binding:

View File

@@ -0,0 +1,67 @@
"""add-api-key-auth-binding
Revision ID: 7b45942e39bb
Revises: 4e99a8df00ff
Create Date: 2024-05-14 07:31:29.702766
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = '7b45942e39bb'
down_revision = '4e99a8df00ff'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('data_source_api_key_auth_bindings',
sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.StringUUID(), nullable=False),
sa.Column('category', sa.String(length=255), nullable=False),
sa.Column('provider', sa.String(length=255), nullable=False),
sa.Column('credentials', sa.Text(), nullable=True),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('disabled', sa.Boolean(), server_default=sa.text('false'), nullable=True),
sa.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey')
)
with op.batch_alter_table('data_source_api_key_auth_bindings', schema=None) as batch_op:
batch_op.create_index('data_source_api_key_auth_binding_provider_idx', ['provider'], unique=False)
batch_op.create_index('data_source_api_key_auth_binding_tenant_id_idx', ['tenant_id'], unique=False)
with op.batch_alter_table('data_source_bindings', schema=None) as batch_op:
batch_op.drop_index('source_binding_tenant_id_idx')
batch_op.drop_index('source_info_idx')
op.rename_table('data_source_bindings', 'data_source_oauth_bindings')
with op.batch_alter_table('data_source_oauth_bindings', schema=None) as batch_op:
batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False)
batch_op.create_index('source_info_idx', ['source_info'], unique=False, postgresql_using='gin')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('data_source_oauth_bindings', schema=None) as batch_op:
batch_op.drop_index('source_info_idx', postgresql_using='gin')
batch_op.drop_index('source_binding_tenant_id_idx')
op.rename_table('data_source_oauth_bindings', 'data_source_bindings')
with op.batch_alter_table('data_source_bindings', schema=None) as batch_op:
batch_op.create_index('source_info_idx', ['source_info'], unique=False)
batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False)
with op.batch_alter_table('data_source_api_key_auth_bindings', schema=None) as batch_op:
batch_op.drop_index('data_source_api_key_auth_binding_tenant_id_idx')
batch_op.drop_index('data_source_api_key_auth_binding_provider_idx')
op.drop_table('data_source_api_key_auth_bindings')
# ### end Alembic commands ###

View File

@@ -270,7 +270,7 @@ class Document(db.Model):
255), nullable=False, server_default=db.text("'text_model'::character varying"))
doc_language = db.Column(db.String(255), nullable=True)
-DATA_SOURCES = ['upload_file', 'notion_import']
+DATA_SOURCES = ['upload_file', 'notion_import', 'website_crawl']
@property
def display_status(self):
@@ -322,7 +322,7 @@ class Document(db.Model):
'created_at': file_detail.created_at.timestamp()
}
}
-elif self.data_source_type == 'notion_import':
+elif self.data_source_type == 'notion_import' or self.data_source_type == 'website_crawl':
return json.loads(self.data_source_info)
return {}

View File

@@ -1,11 +1,13 @@
+import json
from sqlalchemy.dialects.postgresql import JSONB
from extensions.ext_database import db
from models import StringUUID
-class DataSourceBinding(db.Model):
-__tablename__ = 'data_source_bindings'
+class DataSourceOauthBinding(db.Model):
+__tablename__ = 'data_source_oauth_bindings'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='source_binding_pkey'),
db.Index('source_binding_tenant_id_idx', 'tenant_id'),
@@ -20,3 +22,33 @@ class DataSourceBinding(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
+class DataSourceApiKeyAuthBinding(db.Model):
+__tablename__ = 'data_source_api_key_auth_bindings'
+__table_args__ = (
+db.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey'),
+db.Index('data_source_api_key_auth_binding_tenant_id_idx', 'tenant_id'),
+db.Index('data_source_api_key_auth_binding_provider_idx', 'provider'),
+)
+id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+tenant_id = db.Column(StringUUID, nullable=False)
+category = db.Column(db.String(255), nullable=False)
+provider = db.Column(db.String(255), nullable=False)
+credentials = db.Column(db.Text, nullable=True)  # JSON
+created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
+def to_dict(self):
+return {
+'id': self.id,
+'tenant_id': self.tenant_id,
+'category': self.category,
+'provider': self.provider,
+'credentials': json.loads(self.credentials),
+'created_at': self.created_at.timestamp(),
+'updated_at': self.updated_at.timestamp(),
+'disabled': self.disabled
+}

View File

@@ -78,6 +78,9 @@ CODE_MAX_STRING_LENGTH = "80000"
CODE_EXECUTION_ENDPOINT="http://127.0.0.1:8194"
CODE_EXECUTION_API_KEY="dify-sandbox"
+FIRECRAWL_API_KEY = "fc-"
[tool.poetry]
name = "dify-api"

View File

View File

@@ -0,0 +1,10 @@
from abc import ABC, abstractmethod
class ApiKeyAuthBase(ABC):
def __init__(self, credentials: dict):
self.credentials = credentials
@abstractmethod
def validate_credentials(self):
raise NotImplementedError

View File

@@ -0,0 +1,14 @@
from services.auth.firecrawl import FirecrawlAuth
class ApiKeyAuthFactory:
def __init__(self, provider: str, credentials: dict):
if provider == 'firecrawl':
self.auth = FirecrawlAuth(credentials)
else:
raise ValueError('Invalid provider')
def validate_credentials(self):
return self.auth.validate_credentials()

View File

@@ -0,0 +1,70 @@
import json
from core.helper import encrypter
from extensions.ext_database import db
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
class ApiKeyAuthService:
@staticmethod
def get_provider_auth_list(tenant_id: str) -> list:
data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.disabled.is_(False)
).all()
return data_source_api_key_bindings
@staticmethod
def create_provider_auth(tenant_id: str, args: dict):
auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials()
if auth_result:
# Encrypt the api key
api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key'])
args['credentials']['config']['api_key'] = api_key
data_source_api_key_binding = DataSourceApiKeyAuthBinding()
data_source_api_key_binding.tenant_id = tenant_id
data_source_api_key_binding.category = args['category']
data_source_api_key_binding.provider = args['provider']
data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False)
db.session.add(data_source_api_key_binding)
db.session.commit()
@staticmethod
def get_auth_credentials(tenant_id: str, category: str, provider: str):
data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.category == category,
DataSourceApiKeyAuthBinding.provider == provider,
DataSourceApiKeyAuthBinding.disabled.is_(False)
).first()
if not data_source_api_key_bindings:
return None
credentials = json.loads(data_source_api_key_bindings.credentials)
return credentials
@staticmethod
def delete_provider_auth(tenant_id: str, binding_id: str):
data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.id == binding_id
).first()
if data_source_api_key_binding:
db.session.delete(data_source_api_key_binding)
db.session.commit()
@classmethod
def validate_api_key_auth_args(cls, args):
if 'category' not in args or not args['category']:
raise ValueError('category is required')
if 'provider' not in args or not args['provider']:
raise ValueError('provider is required')
if 'credentials' not in args or not args['credentials']:
raise ValueError('credentials is required')
if not isinstance(args['credentials'], dict):
raise ValueError('credentials must be a dictionary')
if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']:
raise ValueError('auth_type is required')
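
For context, a hedged sketch of the args dict this service expects (shape inferred from validate_api_key_auth_args here and FirecrawlAuth below); values are placeholders, and the calls assume a Flask app context with the database initialized.

args = {
    "category": "website",
    "provider": "firecrawl",
    "credentials": {
        "auth_type": "bearer",
        "config": {
            "api_key": "fc-your-api-key",            # stored encrypted via encrypter.encrypt_token
            "base_url": "https://api.firecrawl.dev",
        },
    },
}
ApiKeyAuthService.validate_api_key_auth_args(args)
ApiKeyAuthService.create_provider_auth("tenant-uuid", args)   # placeholder tenant id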

View File

@@ -0,0 +1,56 @@
import json
import requests
from services.auth.api_key_auth_base import ApiKeyAuthBase
class FirecrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials)
auth_type = credentials.get('auth_type')
if auth_type != 'bearer':
raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer')
self.api_key = credentials.get('config').get('api_key', None)
self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev')
if not self.api_key:
raise ValueError('No API key provided')
def validate_credentials(self):
headers = self._prepare_headers()
options = {
'url': 'https://example.com',
'crawlerOptions': {
'excludes': [],
'includes': [],
'limit': 1
},
'pageOptions': {
'onlyMainContent': True
}
}
response = self._post_request(f'{self.base_url}/v0/crawl', options, headers)
if response.status_code == 200:
return True
else:
self._handle_error(response)
def _prepare_headers(self):
return {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}'
}
def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in [402, 409, 500]:
error_message = response.json().get('error', 'Unknown error occurred')
raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
else:
if response.text:
error_message = json.loads(response.text).get('error', 'Unknown error occurred')
raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}')

View File

@@ -31,7 +31,7 @@ from models.dataset import (
DocumentSegment,
)
from models.model import UploadFile
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding
from services.errors.account import NoPermissionError
from services.errors.dataset import DatasetInUseError, DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError
@@ -48,6 +48,7 @@ from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.retry_document_indexing_task import retry_document_indexing_task
+from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task
class DatasetService:
@@ -508,18 +509,40 @@ class DocumentService:
@staticmethod
def retry_document(dataset_id: str, documents: list[Document]):
for document in documents:
+# add retry flag
+retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
+cache_result = redis_client.get(retry_indexing_cache_key)
+if cache_result is not None:
+raise ValueError("Document is being retried, please try again later")
# retry document indexing
document.indexing_status = 'waiting'
db.session.add(document)
db.session.commit()
-# add retry flag
-retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
redis_client.setex(retry_indexing_cache_key, 600, 1)
# trigger async task
document_ids = [document.id for document in documents]
retry_document_indexing_task.delay(dataset_id, document_ids)
@staticmethod
+def sync_website_document(dataset_id: str, document: Document):
+# add sync flag
+sync_indexing_cache_key = 'document_{}_is_sync'.format(document.id)
+cache_result = redis_client.get(sync_indexing_cache_key)
+if cache_result is not None:
+raise ValueError("Document is being synced, please try again later")
+# sync document indexing
+document.indexing_status = 'waiting'
+data_source_info = document.data_source_info_dict
+data_source_info['mode'] = 'scrape'
+document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
+db.session.add(document)
+db.session.commit()
+redis_client.setex(sync_indexing_cache_key, 600, 1)
+sync_website_document_indexing_task.delay(dataset_id, document.id)
+@staticmethod
def get_documents_position(dataset_id):
document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
if document:
@@ -545,6 +568,9 @@ class DocumentService:
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
for notion_info in notion_info_list:
count = count + len(notion_info['pages'])
+elif document_data["data_source"]["type"] == "website_crawl":
+website_info = document_data["data_source"]['info_list']['website_info_list']
+count = len(website_info['urls'])
batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -683,12 +709,12 @@ class DocumentService:
exist_document[data_source_info['notion_page_id']] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info['workspace_id']
-data_source_binding = DataSourceBinding.query.filter(
+data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
-DataSourceBinding.tenant_id == current_user.current_tenant_id,
-DataSourceBinding.provider == 'notion',
-DataSourceBinding.disabled == False,
-DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+DataSourceOauthBinding.provider == 'notion',
+DataSourceOauthBinding.disabled == False,
+DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
@@ -717,6 +743,28 @@
# delete not selected documents
if len(exist_document) > 0:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
+elif document_data["data_source"]["type"] == "website_crawl":
+website_info = document_data["data_source"]['info_list']['website_info_list']
+urls = website_info['urls']
+for url in urls:
+data_source_info = {
+'url': url,
+'provider': website_info['provider'],
+'job_id': website_info['job_id'],
+'only_main_content': website_info.get('only_main_content', False),
+'mode': 'crawl',
+}
+document = DocumentService.build_document(dataset, dataset_process_rule.id,
+document_data["data_source"]["type"],
+document_data["doc_form"],
+document_data["doc_language"],
+data_source_info, created_from, position,
+account, url, batch)
+db.session.add(document)
+db.session.flush()
+document_ids.append(document.id)
+documents.append(document)
+position += 1
db.session.commit()
# trigger async task
@@ -818,12 +866,12 @@ class DocumentService:
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
for notion_info in notion_info_list:
workspace_id = notion_info['workspace_id']
-data_source_binding = DataSourceBinding.query.filter(
+data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
-DataSourceBinding.tenant_id == current_user.current_tenant_id,
-DataSourceBinding.provider == 'notion',
-DataSourceBinding.disabled == False,
-DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+DataSourceOauthBinding.provider == 'notion',
+DataSourceOauthBinding.disabled == False,
+DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
@@ -835,6 +883,17 @@
"notion_page_icon": page['page_icon'],
"type": page['type']
}
+elif document_data["data_source"]["type"] == "website_crawl":
+website_info = document_data["data_source"]['info_list']['website_info_list']
+urls = website_info['urls']
+for url in urls:
+data_source_info = {
+'url': url,
+'provider': website_info['provider'],
+'job_id': website_info['job_id'],
+'only_main_content': website_info.get('only_main_content', False),
+'mode': 'crawl',
+}
document.data_source_type = document_data["data_source"]["type"]
document.data_source_info = json.dumps(data_source_info)
document.name = file_name
@@ -873,6 +932,9 @@
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
for notion_info in notion_info_list:
count = count + len(notion_info['pages'])
+elif document_data["data_source"]["type"] == "website_crawl":
+website_info = document_data["data_source"]['info_list']['website_info_list']
+count = len(website_info['urls'])
batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -973,6 +1035,10 @@
if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
'notion_info_list']:
raise ValueError("Notion source info is required")
+if args['data_source']['type'] == 'website_crawl':
+if 'website_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
+'website_info_list']:
+raise ValueError("Website source info is required")
@classmethod
def process_rule_args_validate(cls, args: dict):
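
A hedged sketch of the website_crawl portion of the document-create payload that the branches above consume; key names come from the code, values are placeholders.

document_data = {
    "doc_form": "text_model",
    "doc_language": "English",
    "data_source": {
        "type": "website_crawl",
        "info_list": {
            "data_source_type": "website_crawl",
            "website_info_list": {
                "provider": "firecrawl",
                "job_id": "example-job-id",
                "urls": ["https://example.com", "https://example.com/docs"],
                "only_main_content": True,
            },
        },
    },
}
# Each url becomes one Document whose data_source_info looks like:
# {'url': ..., 'provider': 'firecrawl', 'job_id': ..., 'only_main_content': True, 'mode': 'crawl'}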

View File

@@ -0,0 +1,171 @@
import datetime
import json
from flask_login import current_user
from core.helper import encrypter
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from services.auth.api_key_auth_service import ApiKeyAuthService
class WebsiteService:
@classmethod
def document_create_args_validate(cls, args: dict):
if 'url' not in args or not args['url']:
raise ValueError('url is required')
if 'options' not in args or not args['options']:
raise ValueError('options is required')
if 'limit' not in args['options'] or not args['options']['limit']:
raise ValueError('limit is required')
@classmethod
def crawl_url(cls, args: dict) -> dict:
provider = args.get('provider')
url = args.get('url')
options = args.get('options')
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
'website',
provider)
if provider == 'firecrawl':
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id,
token=credentials.get('config').get('api_key')
)
firecrawl_app = FirecrawlApp(api_key=api_key,
base_url=credentials.get('config').get('base_url', None))
crawl_sub_pages = options.get('crawl_sub_pages', False)
only_main_content = options.get('only_main_content', False)
if not crawl_sub_pages:
params = {
'crawlerOptions': {
"includes": [],
"excludes": [],
"generateImgAltText": True,
"limit": 1,
'returnOnlyUrls': False,
'pageOptions': {
'onlyMainContent': only_main_content,
"includeHtml": False
}
}
}
else:
includes = options.get('includes').split(',') if options.get('includes') else []
excludes = options.get('excludes').split(',') if options.get('excludes') else []
params = {
'crawlerOptions': {
"includes": includes if includes else [],
"excludes": excludes if excludes else [],
"generateImgAltText": True,
"limit": options.get('limit', 1),
'returnOnlyUrls': False,
'pageOptions': {
'onlyMainContent': only_main_content,
"includeHtml": False
}
}
}
if options.get('max_depth'):
params['crawlerOptions']['maxDepth'] = options.get('max_depth')
job_id = firecrawl_app.crawl_url(url, params)
website_crawl_time_cache_key = f'website_crawl_{job_id}'
time = str(datetime.datetime.now().timestamp())
redis_client.setex(website_crawl_time_cache_key, 3600, time)
return {
'status': 'active',
'job_id': job_id
}
else:
raise ValueError('Invalid provider')
@classmethod
def get_crawl_status(cls, job_id: str, provider: str) -> dict:
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
'website',
provider)
if provider == 'firecrawl':
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id,
token=credentials.get('config').get('api_key')
)
firecrawl_app = FirecrawlApp(api_key=api_key,
base_url=credentials.get('config').get('base_url', None))
result = firecrawl_app.check_crawl_status(job_id)
crawl_status_data = {
'status': result.get('status', 'active'),
'job_id': job_id,
'total': result.get('total', 0),
'current': result.get('current', 0),
'data': result.get('data', [])
}
if crawl_status_data['status'] == 'completed':
website_crawl_time_cache_key = f'website_crawl_{job_id}'
start_time = redis_client.get(website_crawl_time_cache_key)
if start_time:
end_time = datetime.datetime.now().timestamp()
time_consuming = abs(end_time - float(start_time))
crawl_status_data['time_consuming'] = f"{time_consuming:.2f}"
redis_client.delete(website_crawl_time_cache_key)
else:
raise ValueError('Invalid provider')
return crawl_status_data
@classmethod
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None:
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
'website',
provider)
if provider == 'firecrawl':
file_key = 'website_files/' + job_id + '.txt'
if storage.exists(file_key):
data = storage.load_once(file_key)
if data:
data = json.loads(data.decode('utf-8'))
else:
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=tenant_id,
token=credentials.get('config').get('api_key')
)
firecrawl_app = FirecrawlApp(api_key=api_key,
base_url=credentials.get('config').get('base_url', None))
result = firecrawl_app.check_crawl_status(job_id)
if result.get('status') != 'completed':
raise ValueError('Crawl job is not completed')
data = result.get('data')
if data:
for item in data:
if item.get('source_url') == url:
return item
return None
else:
raise ValueError('Invalid provider')
@classmethod
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None:
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
'website',
provider)
if provider == 'firecrawl':
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=tenant_id,
token=credentials.get('config').get('api_key')
)
firecrawl_app = FirecrawlApp(api_key=api_key,
base_url=credentials.get('config').get('base_url', None))
params = {
'pageOptions': {
'onlyMainContent': only_main_content,
"includeHtml": False
}
}
result = firecrawl_app.scrape_url(url, params)
return result
else:
raise ValueError('Invalid provider')
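
A hedged sketch of calling WebsiteService directly with the same argument shape the controller forwards; values are placeholders, and a request context with a logged-in user is assumed because the service reads current_user.

args = {
    "provider": "firecrawl",
    "url": "https://example.com",
    "options": {
        "crawl_sub_pages": True,
        "only_main_content": True,
        "includes": "docs/*,blog/*",   # comma-separated, split into crawlerOptions.includes
        "excludes": "",
        "limit": 10,
        "max_depth": 2,
    },
}
WebsiteService.document_create_args_validate(args)
result = WebsiteService.crawl_url(args)   # {'status': 'active', 'job_id': '<firecrawl job id>'}
status = WebsiteService.get_crawl_status(result["job_id"], "firecrawl")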

View File

@@ -11,7 +11,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding
@shared_task(queue='dataset')
@@ -43,12 +43,12 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
page_id = data_source_info['notion_page_id']
page_type = data_source_info['type']
page_edited_time = data_source_info['last_edited_time']
-data_source_binding = DataSourceBinding.query.filter(
+data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
-DataSourceBinding.tenant_id == document.tenant_id,
-DataSourceBinding.provider == 'notion',
-DataSourceBinding.disabled == False,
-DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+DataSourceOauthBinding.tenant_id == document.tenant_id,
+DataSourceOauthBinding.provider == 'notion',
+DataSourceOauthBinding.disabled == False,
+DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:

View File

@@ -0,0 +1,90 @@
import datetime
import logging
import time
import click
from celery import shared_task
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
@shared_task(queue='dataset')
def sync_website_document_indexing_task(dataset_id: str, document_id: str):
"""
Async process document
:param dataset_id:
:param document_id:
Usage: sync_website_document_indexing_task.delay(dataset_id, document_id)
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
sync_indexing_cache_key = 'document_{}_is_sync'.format(document_id)
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
if 0 < vector_space.limit <= vector_space.size:
raise ValueError("Your total number of documents plus the number of uploads have over the limit of "
"your subscription.")
except Exception as e:
document = db.session.query(Document).filter(
Document.id == document_id,
Document.dataset_id == dataset_id
).first()
if document:
document.indexing_status = 'error'
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
db.session.add(document)
db.session.commit()
redis_client.delete(sync_indexing_cache_key)
return
logging.info(click.style('Start sync website document: {}'.format(document_id), fg='green'))
document = db.session.query(Document).filter(
Document.id == document_id,
Document.dataset_id == dataset_id
).first()
try:
if document:
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = 'parsing'
document.processing_started_at = datetime.datetime.utcnow()
db.session.add(document)
db.session.commit()
indexing_runner = IndexingRunner()
indexing_runner.run([document])
redis_client.delete(sync_indexing_cache_key)
except Exception as ex:
document.indexing_status = 'error'
document.error = str(ex)
document.stopped_at = datetime.datetime.utcnow()
db.session.add(document)
db.session.commit()
logging.info(click.style(str(ex), fg='yellow'))
redis_client.delete(sync_indexing_cache_key)
pass
end_at = time.perf_counter()
logging.info(click.style('Sync document: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))

View File

@@ -0,0 +1,33 @@
import os
from unittest import mock
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
from core.rag.models.document import Document
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
def test_firecrawl_web_extractor_crawl_mode(mocker):
url = "https://firecrawl.dev"
api_key = os.getenv('FIRECRAWL_API_KEY') or 'fc-'
base_url = 'https://api.firecrawl.dev'
firecrawl_app = FirecrawlApp(api_key=api_key,
base_url=base_url)
params = {
'crawlerOptions': {
"includes": [],
"excludes": [],
"generateImgAltText": True,
"maxDepth": 1,
"limit": 1,
'returnOnlyUrls': False,
}
}
mocked_firecrawl = {
"jobId": "test",
}
mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl))
job_id = firecrawl_app.crawl_url(url, params)
print(job_id)
assert isinstance(job_id, str)

View File