Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-12 11:19:02 +08:00)

Feat/firecrawl data source (#5232)

Co-authored-by: Nicolas <nicolascamara29@gmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: takatost <takatost@gmail.com>

parent 918ebe1620
commit ba5f8afaa8
@@ -215,4 +215,5 @@ WORKFLOW_MAX_EXECUTION_TIME=1200
 WORKFLOW_CALL_MAX_DEPTH=5

 # App configuration
 APP_MAX_EXECUTION_TIME=1200

@@ -29,13 +29,13 @@ from .app import (
 )

 # Import auth controllers
-from .auth import activate, data_source_oauth, login, oauth
+from .auth import activate, data_source_bearer_auth, data_source_oauth, login, oauth

 # Import billing controllers
 from .billing import billing

 # Import datasets controllers
-from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing
+from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing, website

 # Import explore controllers
 from .explore import (
api/controllers/console/auth/data_source_bearer_auth.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+from flask_login import current_user
+from flask_restful import Resource, reqparse
+from werkzeug.exceptions import Forbidden
+
+from controllers.console import api
+from controllers.console.auth.error import ApiKeyAuthFailedError
+from libs.login import login_required
+from services.auth.api_key_auth_service import ApiKeyAuthService
+
+from ..setup import setup_required
+from ..wraps import account_initialization_required
+
+
+class ApiKeyAuthDataSource(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        # The role of the current user in the table must be admin or owner
+        if not current_user.is_admin_or_owner:
+            raise Forbidden()
+        data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
+        if data_source_api_key_bindings:
+            return {
+                'settings': [data_source_api_key_binding.to_dict() for data_source_api_key_binding in
+                             data_source_api_key_bindings]}
+        return {'settings': []}
+
+
+class ApiKeyAuthDataSourceBinding(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        # The role of the current user in the table must be admin or owner
+        if not current_user.is_admin_or_owner:
+            raise Forbidden()
+        parser = reqparse.RequestParser()
+        parser.add_argument('category', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+        ApiKeyAuthService.validate_api_key_auth_args(args)
+        try:
+            ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
+        except Exception as e:
+            raise ApiKeyAuthFailedError(str(e))
+        return {'result': 'success'}, 200
+
+
+class ApiKeyAuthDataSourceBindingDelete(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, binding_id):
+        # The role of the current user in the table must be admin or owner
+        if not current_user.is_admin_or_owner:
+            raise Forbidden()
+
+        ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
+
+        return {'result': 'success'}, 200
+
+
+api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source')
+api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding')
+api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/<uuid:binding_id>')
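For reference, the JSON body that the new binding endpoint parses can be sketched in Python as below. This is an illustration only: the 'website' category label and the console URL prefix are assumptions not shown in this diff, while the 'bearer' auth_type and config.api_key layout follow validate_api_key_auth_args and FirecrawlAuth elsewhere in this commit.

# Illustrative payload for POST .../api-key-auth/data-source/binding (assumed URL prefix)
binding_payload = {
    "category": "website",      # assumed category label for website data sources
    "provider": "firecrawl",
    "credentials": {
        "auth_type": "bearer",  # required by validate_api_key_auth_args / FirecrawlAuth
        "config": {
            "api_key": "fc-your-api-key",
            "base_url": "https://api.firecrawl.dev",  # optional; FirecrawlAuth defaults to this
        },
    },
}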
api/controllers/console/auth/error.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+from libs.exception import BaseHTTPException
+
+
+class ApiKeyAuthFailedError(BaseHTTPException):
+    error_code = 'auth_failed'
+    description = "{message}"
+    code = 500
@@ -16,7 +16,7 @@ from extensions.ext_database import db
 from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
 from libs.login import login_required
 from models.dataset import Document
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding
 from services.dataset_service import DatasetService, DocumentService
 from tasks.document_indexing_sync_task import document_indexing_sync_task

@@ -29,9 +29,9 @@ class DataSourceApi(Resource):
     @marshal_with(integrate_list_fields)
     def get(self):
         # get workspace data source integrates
-        data_source_integrates = db.session.query(DataSourceBinding).filter(
-            DataSourceBinding.tenant_id == current_user.current_tenant_id,
-            DataSourceBinding.disabled == False
+        data_source_integrates = db.session.query(DataSourceOauthBinding).filter(
+            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+            DataSourceOauthBinding.disabled == False
         ).all()

         base_url = request.url_root.rstrip('/')
@@ -71,7 +71,7 @@ class DataSourceApi(Resource):
     def patch(self, binding_id, action):
         binding_id = str(binding_id)
         action = str(action)
-        data_source_binding = DataSourceBinding.query.filter_by(
+        data_source_binding = DataSourceOauthBinding.query.filter_by(
             id=binding_id
         ).first()
         if data_source_binding is None:
@@ -124,7 +124,7 @@ class DataSourceNotionListApi(Resource):
                 data_source_info = json.loads(document.data_source_info)
                 exist_page_ids.append(data_source_info['notion_page_id'])
         # get all authorized pages
-        data_source_bindings = DataSourceBinding.query.filter_by(
+        data_source_bindings = DataSourceOauthBinding.query.filter_by(
             tenant_id=current_user.current_tenant_id,
             provider='notion',
             disabled=False
@@ -163,12 +163,12 @@ class DataSourceNotionApi(Resource):
     def get(self, workspace_id, page_id, page_type):
         workspace_id = str(workspace_id)
         page_id = str(page_id)
-        data_source_binding = DataSourceBinding.query.filter(
+        data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
-                DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceBinding.provider == 'notion',
-                DataSourceBinding.disabled == False,
-                DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                DataSourceOauthBinding.provider == 'notion',
+                DataSourceOauthBinding.disabled == False,
+                DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
             )
         ).first()
         if not data_source_binding:
@@ -315,6 +315,22 @@ class DatasetIndexingEstimateApi(Resource):
                        document_model=args['doc_form']
                    )
                    extract_settings.append(extract_setting)
+        elif args['info_list']['data_source_type'] == 'website_crawl':
+            website_info_list = args['info_list']['website_info_list']
+            for url in website_info_list['urls']:
+                extract_setting = ExtractSetting(
+                    datasource_type="website_crawl",
+                    website_info={
+                        "provider": website_info_list['provider'],
+                        "job_id": website_info_list['job_id'],
+                        "url": url,
+                        "tenant_id": current_user.current_tenant_id,
+                        "mode": 'crawl',
+                        "only_main_content": website_info_list['only_main_content']
+                    },
+                    document_model=args['doc_form']
+                )
+                extract_settings.append(extract_setting)
        else:
            raise ValueError('Data source type not support')
        indexing_runner = IndexingRunner()
@@ -519,6 +535,7 @@ class DatasetRetrievalSettingMockApi(Resource):
             raise ValueError(f"Unsupported vector db type {vector_type}.")


+
 class DatasetErrorDocs(Resource):
     @setup_required
     @login_required
@@ -465,6 +465,20 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                        document_model=document.doc_form
                    )
                    extract_settings.append(extract_setting)
+                elif document.data_source_type == 'website_crawl':
+                    extract_setting = ExtractSetting(
+                        datasource_type="website_crawl",
+                        website_info={
+                            "provider": data_source_info['provider'],
+                            "job_id": data_source_info['job_id'],
+                            "url": data_source_info['url'],
+                            "tenant_id": current_user.current_tenant_id,
+                            "mode": data_source_info['mode'],
+                            "only_main_content": data_source_info['only_main_content']
+                        },
+                        document_model=document.doc_form
+                    )
+                    extract_settings.append(extract_setting)

                else:
                    raise ValueError('Data source type not support')
@@ -952,6 +966,33 @@ class DocumentRenameApi(DocumentResource):
         return document


+class WebsiteDocumentSyncApi(DocumentResource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id, document_id):
+        """sync website document."""
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound('Dataset not found.')
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset.id, document_id)
+        if not document:
+            raise NotFound('Document not found.')
+        if document.tenant_id != current_user.current_tenant_id:
+            raise Forbidden('No permission.')
+        if document.data_source_type != 'website_crawl':
+            raise ValueError('Document is not a website document.')
+        # 403 if document is archived
+        if DocumentService.check_archived(document):
+            raise ArchivedDocumentImmutableError()
+        # sync document
+        DocumentService.sync_website_document(dataset_id, document)
+
+        return {'result': 'success'}, 200
+
+
 api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
 api.add_resource(DatasetDocumentListApi,
                  '/datasets/<uuid:dataset_id>/documents')
@@ -980,3 +1021,5 @@ api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uui
 api.add_resource(DocumentRetryApi, '/datasets/<uuid:dataset_id>/retry')
 api.add_resource(DocumentRenameApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename')
+
+api.add_resource(WebsiteDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync')
@@ -73,6 +73,12 @@ class InvalidMetadataError(BaseHTTPException):
     code = 400


+class WebsiteCrawlError(BaseHTTPException):
+    error_code = 'crawl_failed'
+    description = "{message}"
+    code = 500
+
+
 class DatasetInUseError(BaseHTTPException):
     error_code = 'dataset_in_use'
     description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
api/controllers/console/datasets/website.py (new file, 49 lines)
@@ -0,0 +1,49 @@
+from flask_restful import Resource, reqparse
+
+from controllers.console import api
+from controllers.console.datasets.error import WebsiteCrawlError
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from libs.login import login_required
+from services.website_service import WebsiteService
+
+
+class WebsiteCrawlApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('provider', type=str, choices=['firecrawl'],
+                            required=True, nullable=True, location='json')
+        parser.add_argument('url', type=str, required=True, nullable=True, location='json')
+        parser.add_argument('options', type=dict, required=True, nullable=True, location='json')
+        args = parser.parse_args()
+        WebsiteService.document_create_args_validate(args)
+        # crawl url
+        try:
+            result = WebsiteService.crawl_url(args)
+        except Exception as e:
+            raise WebsiteCrawlError(str(e))
+        return result, 200
+
+
+class WebsiteCrawlStatusApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, job_id: str):
+        parser = reqparse.RequestParser()
+        parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args')
+        args = parser.parse_args()
+        # get crawl status
+        try:
+            result = WebsiteService.get_crawl_status(job_id, args['provider'])
+        except Exception as e:
+            raise WebsiteCrawlError(str(e))
+        return result, 200
+
+
+api.add_resource(WebsiteCrawlApi, '/website/crawl')
+api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/<string:job_id>')
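A hedged sketch of the request this endpoint accepts: the parser above requires provider, url and options, but WebsiteService.crawl_url and get_crawl_status are not part of this diff, so the option keys below are assumptions.

# Illustrative only; option names inside 'options' are assumptions.
crawl_request = {
    "provider": "firecrawl",
    "url": "https://example.com",
    "options": {
        "limit": 10,                # assumed crawl page limit
        "only_main_content": True,  # mirrors the only_main_content flag used elsewhere in this commit
    },
}
# The crawl response is expected to carry a job id; status is then polled via
# GET /website/crawl/status/<job_id>?provider=firecrawl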
@@ -339,7 +339,7 @@ class IndexingRunner:
     def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
             -> list[Document]:
         # load file
-        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
+        if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]:
             return []

         data_source_info = dataset_document.data_source_info_dict
@@ -375,6 +375,23 @@ class IndexingRunner:
                 document_model=dataset_document.doc_form
             )
             text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
+        elif dataset_document.data_source_type == 'website_crawl':
+            if (not data_source_info or 'provider' not in data_source_info
+                    or 'url' not in data_source_info or 'job_id' not in data_source_info):
+                raise ValueError("no website import info found")
+            extract_setting = ExtractSetting(
+                datasource_type="website_crawl",
+                website_info={
+                    "provider": data_source_info['provider'],
+                    "job_id": data_source_info['job_id'],
+                    "tenant_id": dataset_document.tenant_id,
+                    "url": data_source_info['url'],
+                    "mode": data_source_info['mode'],
+                    "only_main_content": data_source_info['only_main_content']
+                },
+                document_model=dataset_document.doc_form
+            )
+            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
         # update document status to splitting
         self._update_document_index_status(
             document_id=dataset_document.id,
@@ -124,7 +124,7 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
                     default=float(credentials.get('presence_penalty', 0)),
                     min=-2,
                     max=2
-                )
+                ),
             ],
             pricing=PriceConfig(
                 input=Decimal(cred_with_endpoint.get('input_price', 0)),
@@ -4,3 +4,4 @@ from enum import Enum
 class DatasourceType(Enum):
     FILE = "upload_file"
     NOTION = "notion_import"
+    WEBSITE = "website_crawl"
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from pydantic import BaseModel, ConfigDict

 from models.dataset import Document
@@ -19,14 +21,33 @@ class NotionInfo(BaseModel):
         super().__init__(**data)


+class WebsiteInfo(BaseModel):
+    """
+    website import info.
+    """
+    provider: str
+    job_id: str
+    url: str
+    mode: str
+    tenant_id: str
+    only_main_content: bool = False
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def __init__(self, **data) -> None:
+        super().__init__(**data)
+
+
 class ExtractSetting(BaseModel):
     """
     Model class for provider response.
     """
     datasource_type: str
-    upload_file: UploadFile = None
-    notion_info: NotionInfo = None
-    document_model: str = None
+    upload_file: Optional[UploadFile]
+    notion_info: Optional[NotionInfo]
+    website_info: Optional[WebsiteInfo]
+    document_model: Optional[str]
     model_config = ConfigDict(arbitrary_types_allowed=True)

     def __init__(self, **data) -> None:
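As a minimal illustration, the new website_crawl branches in the controllers and IndexingRunner build the model above roughly as sketched below; all concrete values are placeholders.

# Sketch mirroring how this commit constructs a website_crawl ExtractSetting.
from core.rag.extractor.entity.extract_setting import ExtractSetting

extract_setting = ExtractSetting(
    datasource_type="website_crawl",
    website_info={
        "provider": "firecrawl",
        "job_id": "example-job-id",       # Firecrawl crawl job id kept in data_source_info
        "url": "https://example.com",
        "tenant_id": "example-tenant-id",
        "mode": "crawl",                  # 'scrape' is used when re-syncing a single page
        "only_main_content": False,
    },
    document_model="text_model",
)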
@@ -11,6 +11,7 @@ from core.rag.extractor.csv_extractor import CSVExtractor
 from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.excel_extractor import ExcelExtractor
+from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
 from core.rag.extractor.html_extractor import HtmlExtractor
 from core.rag.extractor.markdown_extractor import MarkdownExtractor
 from core.rag.extractor.notion_extractor import NotionExtractor
@@ -154,5 +155,17 @@ class ExtractProcessor:
                 tenant_id=extract_setting.notion_info.tenant_id,
             )
             return extractor.extract()
+        elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
+            if extract_setting.website_info.provider == 'firecrawl':
+                extractor = FirecrawlWebExtractor(
+                    url=extract_setting.website_info.url,
+                    job_id=extract_setting.website_info.job_id,
+                    tenant_id=extract_setting.website_info.tenant_id,
+                    mode=extract_setting.website_info.mode,
+                    only_main_content=extract_setting.website_info.only_main_content
+                )
+                return extractor.extract()
+            else:
+                raise ValueError(f"Unsupported website provider: {extract_setting.website_info.provider}")
         else:
             raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}")
api/core/rag/extractor/firecrawl/firecrawl_app.py (new file, 132 lines)
@@ -0,0 +1,132 @@
+import json
+import time
+
+import requests
+
+from extensions.ext_storage import storage
+
+
+class FirecrawlApp:
+    def __init__(self, api_key=None, base_url=None):
+        self.api_key = api_key
+        self.base_url = base_url or 'https://api.firecrawl.dev'
+        if self.api_key is None and self.base_url == 'https://api.firecrawl.dev':
+            raise ValueError('No API key provided')
+
+    def scrape_url(self, url, params=None) -> dict:
+        headers = {
+            'Content-Type': 'application/json',
+            'Authorization': f'Bearer {self.api_key}'
+        }
+        json_data = {'url': url}
+        if params:
+            json_data.update(params)
+        response = requests.post(
+            f'{self.base_url}/v0/scrape',
+            headers=headers,
+            json=json_data
+        )
+        if response.status_code == 200:
+            response = response.json()
+            if response['success'] == True:
+                data = response['data']
+                return {
+                    'title': data.get('metadata').get('title'),
+                    'description': data.get('metadata').get('description'),
+                    'source_url': data.get('metadata').get('sourceURL'),
+                    'markdown': data.get('markdown')
+                }
+            else:
+                raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
+
+        elif response.status_code in [402, 409, 500]:
+            error_message = response.json().get('error', 'Unknown error occurred')
+            raise Exception(f'Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}')
+        else:
+            raise Exception(f'Failed to scrape URL. Status code: {response.status_code}')
+
+    def crawl_url(self, url, params=None) -> str:
+        start_time = time.time()
+        headers = self._prepare_headers()
+        json_data = {'url': url}
+        if params:
+            json_data.update(params)
+        response = self._post_request(f'{self.base_url}/v0/crawl', json_data, headers)
+        if response.status_code == 200:
+            job_id = response.json().get('jobId')
+            return job_id
+        else:
+            self._handle_error(response, 'start crawl job')
+
+    def check_crawl_status(self, job_id) -> dict:
+        headers = self._prepare_headers()
+        response = self._get_request(f'{self.base_url}/v0/crawl/status/{job_id}', headers)
+        if response.status_code == 200:
+            crawl_status_response = response.json()
+            if crawl_status_response.get('status') == 'completed':
+                total = crawl_status_response.get('total', 0)
+                if total == 0:
+                    raise Exception('Failed to check crawl status. Error: No page found')
+                data = crawl_status_response.get('data', [])
+                url_data_list = []
+                for item in data:
+                    if isinstance(item, dict) and 'metadata' in item and 'markdown' in item:
+                        url_data = {
+                            'title': item.get('metadata').get('title'),
+                            'description': item.get('metadata').get('description'),
+                            'source_url': item.get('metadata').get('sourceURL'),
+                            'markdown': item.get('markdown')
+                        }
+                        url_data_list.append(url_data)
+                if url_data_list:
+                    file_key = 'website_files/' + job_id + '.txt'
+                    if storage.exists(file_key):
+                        storage.delete(file_key)
+                    storage.save(file_key, json.dumps(url_data_list).encode('utf-8'))
+                return {
+                    'status': 'completed',
+                    'total': crawl_status_response.get('total'),
+                    'current': crawl_status_response.get('current'),
+                    'data': url_data_list
+                }
+
+            else:
+                return {
+                    'status': crawl_status_response.get('status'),
+                    'total': crawl_status_response.get('total'),
+                    'current': crawl_status_response.get('current'),
+                    'data': []
+                }
+
+        else:
+            self._handle_error(response, 'check crawl status')
+
+    def _prepare_headers(self):
+        return {
+            'Content-Type': 'application/json',
+            'Authorization': f'Bearer {self.api_key}'
+        }
+
+    def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5):
+        for attempt in range(retries):
+            response = requests.post(url, headers=headers, json=data)
+            if response.status_code == 502:
+                time.sleep(backoff_factor * (2 ** attempt))
+            else:
+                return response
+        return response
+
+    def _get_request(self, url, headers, retries=3, backoff_factor=0.5):
+        for attempt in range(retries):
+            response = requests.get(url, headers=headers)
+            if response.status_code == 502:
+                time.sleep(backoff_factor * (2 ** attempt))
+            else:
+                return response
+        return response
+
+    def _handle_error(self, response, action):
+        error_message = response.json().get('error', 'Unknown error occurred')
+        raise Exception(f'Failed to {action}. Status code: {response.status_code}. Error: {error_message}')
+
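A short usage sketch of the client above, assuming a valid Firecrawl API key and the Dify app context (check_crawl_status persists completed crawls via ext_storage); the polling loop is illustrative.

# Illustrative usage of FirecrawlApp (assumes a valid key and Dify app context).
import time

from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp

app = FirecrawlApp(api_key="fc-your-api-key")  # base_url defaults to https://api.firecrawl.dev

# Single-page scrape: returns title / description / source_url / markdown.
page = app.scrape_url("https://example.com")

# Site crawl: start a job, then poll until the status endpoint reports completion.
job_id = app.crawl_url("https://example.com", params={"crawlerOptions": {"limit": 5}})
status = app.check_crawl_status(job_id)
while status["status"] != "completed":
    time.sleep(5)
    status = app.check_crawl_status(job_id)
print(len(status["data"]), "pages crawled")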
api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
+from services.website_service import WebsiteService
+
+
+class FirecrawlWebExtractor(BaseExtractor):
+    """
+    Crawl and scrape websites and return content in clean llm-ready markdown.
+
+    Args:
+        url: The URL to scrape.
+        api_key: The API key for Firecrawl.
+        base_url: The base URL for the Firecrawl API. Defaults to 'https://api.firecrawl.dev'.
+        mode: The mode of operation. Defaults to 'scrape'. Options are 'crawl', 'scrape' and 'crawl_return_urls'.
+    """
+
+    def __init__(
+        self,
+        url: str,
+        job_id: str,
+        tenant_id: str,
+        mode: str = 'crawl',
+        only_main_content: bool = False
+    ):
+        """Initialize with url, api_key, base_url and mode."""
+        self._url = url
+        self.job_id = job_id
+        self.tenant_id = tenant_id
+        self.mode = mode
+        self.only_main_content = only_main_content
+
+    def extract(self) -> list[Document]:
+        """Extract content from the URL."""
+        documents = []
+        if self.mode == 'crawl':
+            crawl_data = WebsiteService.get_crawl_url_data(self.job_id, 'firecrawl', self._url, self.tenant_id)
+            if crawl_data is None:
+                return []
+            document = Document(page_content=crawl_data.get('markdown', ''),
+                                metadata={
+                                    'source_url': crawl_data.get('source_url'),
+                                    'description': crawl_data.get('description'),
+                                    'title': crawl_data.get('title')
+                                }
+                                )
+            documents.append(document)
+        elif self.mode == 'scrape':
+            scrape_data = WebsiteService.get_scrape_url_data('firecrawl', self._url, self.tenant_id,
+                                                             self.only_main_content)
+
+            document = Document(page_content=scrape_data.get('markdown', ''),
+                                metadata={
+                                    'source_url': scrape_data.get('source_url'),
+                                    'description': scrape_data.get('description'),
+                                    'title': scrape_data.get('title')
+                                }
+                                )
+            documents.append(document)
+        return documents
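For context, the new ExtractProcessor branch drives this extractor roughly as sketched below; WebsiteService.get_crawl_url_data and get_scrape_url_data are not shown in this diff, so this only illustrates the call shape inside the Dify app context.

# Illustrative call shape only (WebsiteService itself is not part of this diff).
from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor

extractor = FirecrawlWebExtractor(
    url="https://example.com",
    job_id="example-job-id",        # crawl job id stored in the document's data_source_info
    tenant_id="example-tenant-id",
    mode="crawl",                   # 'scrape' re-fetches the single URL instead of reading crawl results
    only_main_content=False,
)
documents = extractor.extract()     # list[Document] whose page_content is the page markdown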
@@ -9,7 +9,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import Document as DocumentModel
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding

 logger = logging.getLogger(__name__)

@@ -345,12 +345,12 @@ class NotionExtractor(BaseExtractor):

     @classmethod
     def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
-        data_source_binding = DataSourceBinding.query.filter(
+        data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
-                DataSourceBinding.tenant_id == tenant_id,
-                DataSourceBinding.provider == 'notion',
-                DataSourceBinding.disabled == False,
-                DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
+                DataSourceOauthBinding.tenant_id == tenant_id,
+                DataSourceOauthBinding.provider == 'notion',
+                DataSourceOauthBinding.disabled == False,
+                DataSourceOauthBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
             )
         ).first()

api/libs/bearer_data_source.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+# [REVIEW] Implement if Needed? Do we need a new type of data source
+from abc import abstractmethod
+
+import requests
+from api.models.source import DataSourceBearerBinding
+from flask_login import current_user
+
+from extensions.ext_database import db
+
+
+class BearerDataSource:
+    def __init__(self, api_key: str, api_base_url: str):
+        self.api_key = api_key
+        self.api_base_url = api_base_url
+
+    @abstractmethod
+    def validate_bearer_data_source(self):
+        """
+        Validate the data source
+        """
+
+
+class FireCrawlDataSource(BearerDataSource):
+    def validate_bearer_data_source(self):
+        TEST_CRAWL_SITE_URL = "https://www.google.com"
+        FIRECRAWL_API_VERSION = "v0"
+
+        test_api_endpoint = self.api_base_url.rstrip('/') + f"/{FIRECRAWL_API_VERSION}/scrape"
+
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+        data = {
+            "url": TEST_CRAWL_SITE_URL,
+        }
+
+        response = requests.get(test_api_endpoint, headers=headers, json=data)
+
+        return response.json().get("status") == "success"
+
+    def save_credentials(self):
+        # save data source binding
+        data_source_binding = DataSourceBearerBinding.query.filter(
+            db.and_(
+                DataSourceBearerBinding.tenant_id == current_user.current_tenant_id,
+                DataSourceBearerBinding.provider == 'firecrawl',
+                DataSourceBearerBinding.endpoint_url == self.api_base_url,
+                DataSourceBearerBinding.bearer_key == self.api_key
+            )
+        ).first()
+        if data_source_binding:
+            data_source_binding.disabled = False
+            db.session.commit()
+        else:
+            new_data_source_binding = DataSourceBearerBinding(
+                tenant_id=current_user.current_tenant_id,
+                provider='firecrawl',
+                endpoint_url=self.api_base_url,
+                bearer_key=self.api_key
+            )
+            db.session.add(new_data_source_binding)
+            db.session.commit()
@@ -4,7 +4,7 @@ import requests
 from flask_login import current_user

 from extensions.ext_database import db
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding


 class OAuthDataSource:
@@ -63,11 +63,11 @@ class NotionOAuth(OAuthDataSource):
             'total': len(pages)
         }
         # save data source binding
-        data_source_binding = DataSourceBinding.query.filter(
+        data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
-                DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceBinding.provider == 'notion',
-                DataSourceBinding.access_token == access_token
+                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                DataSourceOauthBinding.provider == 'notion',
+                DataSourceOauthBinding.access_token == access_token
             )
         ).first()
         if data_source_binding:
@@ -75,7 +75,7 @@ class NotionOAuth(OAuthDataSource):
             data_source_binding.disabled = False
             db.session.commit()
         else:
-            new_data_source_binding = DataSourceBinding(
+            new_data_source_binding = DataSourceOauthBinding(
                 tenant_id=current_user.current_tenant_id,
                 access_token=access_token,
                 source_info=source_info,
@@ -98,11 +98,11 @@ class NotionOAuth(OAuthDataSource):
             'total': len(pages)
         }
         # save data source binding
-        data_source_binding = DataSourceBinding.query.filter(
+        data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
-                DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceBinding.provider == 'notion',
-                DataSourceBinding.access_token == access_token
+                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                DataSourceOauthBinding.provider == 'notion',
+                DataSourceOauthBinding.access_token == access_token
             )
         ).first()
         if data_source_binding:
@@ -110,7 +110,7 @@ class NotionOAuth(OAuthDataSource):
             data_source_binding.disabled = False
             db.session.commit()
         else:
-            new_data_source_binding = DataSourceBinding(
+            new_data_source_binding = DataSourceOauthBinding(
                 tenant_id=current_user.current_tenant_id,
                 access_token=access_token,
                 source_info=source_info,
@@ -121,12 +121,12 @@ class NotionOAuth(OAuthDataSource):

     def sync_data_source(self, binding_id: str):
         # save data source binding
-        data_source_binding = DataSourceBinding.query.filter(
+        data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
-                DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceBinding.provider == 'notion',
-                DataSourceBinding.id == binding_id,
-                DataSourceBinding.disabled == False
+                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                DataSourceOauthBinding.provider == 'notion',
+                DataSourceOauthBinding.id == binding_id,
+                DataSourceOauthBinding.disabled == False
             )
         ).first()
         if data_source_binding:
(new file, 67 lines; Alembic migration add-api-key-auth-binding)
@@ -0,0 +1,67 @@
+"""add-api-key-auth-binding
+
+Revision ID: 7b45942e39bb
+Revises: 47cc7df8c4f3
+Create Date: 2024-05-14 07:31:29.702766
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '7b45942e39bb'
+down_revision = '4e99a8df00ff'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('data_source_api_key_auth_bindings',
+    sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('tenant_id', models.StringUUID(), nullable=False),
+    sa.Column('category', sa.String(length=255), nullable=False),
+    sa.Column('provider', sa.String(length=255), nullable=False),
+    sa.Column('credentials', sa.Text(), nullable=True),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.Column('disabled', sa.Boolean(), server_default=sa.text('false'), nullable=True),
+    sa.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey')
+    )
+    with op.batch_alter_table('data_source_api_key_auth_bindings', schema=None) as batch_op:
+        batch_op.create_index('data_source_api_key_auth_binding_provider_idx', ['provider'], unique=False)
+        batch_op.create_index('data_source_api_key_auth_binding_tenant_id_idx', ['tenant_id'], unique=False)
+
+    with op.batch_alter_table('data_source_bindings', schema=None) as batch_op:
+        batch_op.drop_index('source_binding_tenant_id_idx')
+        batch_op.drop_index('source_info_idx')
+
+    op.rename_table('data_source_bindings', 'data_source_oauth_bindings')
+
+    with op.batch_alter_table('data_source_oauth_bindings', schema=None) as batch_op:
+        batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False)
+        batch_op.create_index('source_info_idx', ['source_info'], unique=False, postgresql_using='gin')
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+
+    with op.batch_alter_table('data_source_oauth_bindings', schema=None) as batch_op:
+        batch_op.drop_index('source_info_idx', postgresql_using='gin')
+        batch_op.drop_index('source_binding_tenant_id_idx')
+
+    op.rename_table('data_source_oauth_bindings', 'data_source_bindings')
+
+    with op.batch_alter_table('data_source_bindings', schema=None) as batch_op:
+        batch_op.create_index('source_info_idx', ['source_info'], unique=False)
+        batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False)
+
+    with op.batch_alter_table('data_source_api_key_auth_bindings', schema=None) as batch_op:
+        batch_op.drop_index('data_source_api_key_auth_binding_tenant_id_idx')
+        batch_op.drop_index('data_source_api_key_auth_binding_provider_idx')
+
+    op.drop_table('data_source_api_key_auth_bindings')
+    # ### end Alembic commands ###
@@ -270,7 +270,7 @@ class Document(db.Model):
         255), nullable=False, server_default=db.text("'text_model'::character varying"))
     doc_language = db.Column(db.String(255), nullable=True)

-    DATA_SOURCES = ['upload_file', 'notion_import']
+    DATA_SOURCES = ['upload_file', 'notion_import', 'website_crawl']

     @property
     def display_status(self):
@@ -322,7 +322,7 @@ class Document(db.Model):
                     'created_at': file_detail.created_at.timestamp()
                 }
             }
-        elif self.data_source_type == 'notion_import':
+        elif self.data_source_type == 'notion_import' or self.data_source_type == 'website_crawl':
             return json.loads(self.data_source_info)
         return {}

@@ -1,11 +1,13 @@
+import json
+
 from sqlalchemy.dialects.postgresql import JSONB

 from extensions.ext_database import db
 from models import StringUUID


-class DataSourceBinding(db.Model):
-    __tablename__ = 'data_source_bindings'
+class DataSourceOauthBinding(db.Model):
+    __tablename__ = 'data_source_oauth_bindings'
     __table_args__ = (
         db.PrimaryKeyConstraint('id', name='source_binding_pkey'),
         db.Index('source_binding_tenant_id_idx', 'tenant_id'),
@@ -20,3 +22,33 @@ class DataSourceBinding(db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
+
+
+class DataSourceApiKeyAuthBinding(db.Model):
+    __tablename__ = 'data_source_api_key_auth_bindings'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey'),
+        db.Index('data_source_api_key_auth_binding_tenant_id_idx', 'tenant_id'),
+        db.Index('data_source_api_key_auth_binding_provider_idx', 'provider'),
+    )
+
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    category = db.Column(db.String(255), nullable=False)
+    provider = db.Column(db.String(255), nullable=False)
+    credentials = db.Column(db.Text, nullable=True)  # JSON
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
+
+    def to_dict(self):
+        return {
+            'id': self.id,
+            'tenant_id': self.tenant_id,
+            'category': self.category,
+            'provider': self.provider,
+            'credentials': json.loads(self.credentials),
+            'created_at': self.created_at.timestamp(),
+            'updated_at': self.updated_at.timestamp(),
+            'disabled': self.disabled
+        }
@@ -78,6 +78,9 @@ CODE_MAX_STRING_LENGTH = "80000"
 CODE_EXECUTION_ENDPOINT="http://127.0.0.1:8194"
 CODE_EXECUTION_API_KEY="dify-sandbox"

+FIRECRAWL_API_KEY = "fc-"
+
+

 [tool.poetry]
 name = "dify-api"
api/services/auth/__init__.py (new file, empty)

api/services/auth/api_key_auth_base.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+from abc import ABC, abstractmethod
+
+
+class ApiKeyAuthBase(ABC):
+    def __init__(self, credentials: dict):
+        self.credentials = credentials
+
+    @abstractmethod
+    def validate_credentials(self):
+        raise NotImplementedError
api/services/auth/api_key_auth_factory.py (new file, 14 lines)
@@ -0,0 +1,14 @@
+
+from services.auth.firecrawl import FirecrawlAuth
+
+
+class ApiKeyAuthFactory:
+
+    def __init__(self, provider: str, credentials: dict):
+        if provider == 'firecrawl':
+            self.auth = FirecrawlAuth(credentials)
+        else:
+            raise ValueError('Invalid provider')
+
+    def validate_credentials(self):
+        return self.auth.validate_credentials()
api/services/auth/api_key_auth_service.py (new file, 70 lines)
@@ -0,0 +1,70 @@
+import json
+
+from core.helper import encrypter
+from extensions.ext_database import db
+from models.source import DataSourceApiKeyAuthBinding
+from services.auth.api_key_auth_factory import ApiKeyAuthFactory
+
+
+class ApiKeyAuthService:
+
+    @staticmethod
+    def get_provider_auth_list(tenant_id: str) -> list:
+        data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
+            DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
+            DataSourceApiKeyAuthBinding.disabled.is_(False)
+        ).all()
+        return data_source_api_key_bindings
+
+    @staticmethod
+    def create_provider_auth(tenant_id: str, args: dict):
+        auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials()
+        if auth_result:
+            # Encrypt the api key
+            api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key'])
+            args['credentials']['config']['api_key'] = api_key
+
+            data_source_api_key_binding = DataSourceApiKeyAuthBinding()
+            data_source_api_key_binding.tenant_id = tenant_id
+            data_source_api_key_binding.category = args['category']
+            data_source_api_key_binding.provider = args['provider']
+            data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False)
+            db.session.add(data_source_api_key_binding)
+            db.session.commit()
+
+    @staticmethod
+    def get_auth_credentials(tenant_id: str, category: str, provider: str):
+        data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
+            DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
+            DataSourceApiKeyAuthBinding.category == category,
+            DataSourceApiKeyAuthBinding.provider == provider,
+            DataSourceApiKeyAuthBinding.disabled.is_(False)
+        ).first()
+        if not data_source_api_key_bindings:
+            return None
+        credentials = json.loads(data_source_api_key_bindings.credentials)
+        return credentials
+
+    @staticmethod
+    def delete_provider_auth(tenant_id: str, binding_id: str):
+        data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter(
+            DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
+            DataSourceApiKeyAuthBinding.id == binding_id
+        ).first()
+        if data_source_api_key_binding:
+            db.session.delete(data_source_api_key_binding)
+            db.session.commit()
+
+    @classmethod
+    def validate_api_key_auth_args(cls, args):
+        if 'category' not in args or not args['category']:
+            raise ValueError('category is required')
+        if 'provider' not in args or not args['provider']:
+            raise ValueError('provider is required')
+        if 'credentials' not in args or not args['credentials']:
+            raise ValueError('credentials is required')
+        if not isinstance(args['credentials'], dict):
+            raise ValueError('credentials must be a dictionary')
+        if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']:
+            raise ValueError('auth_type is required')
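Tying the pieces together, a hedged sketch of the service-level flow; the tenant id and the 'website' category label are placeholders, and the stored api_key remains encrypted, so a consumer (for example WebsiteService, not shown in this diff) would have to decrypt it before calling Firecrawl.

# Illustrative flow over the service above; concrete values are placeholders.
from services.auth.api_key_auth_service import ApiKeyAuthService

args = {
    "category": "website",  # assumed category label
    "provider": "firecrawl",
    "credentials": {"auth_type": "bearer", "config": {"api_key": "fc-your-api-key"}},
}
ApiKeyAuthService.validate_api_key_auth_args(args)
ApiKeyAuthService.create_provider_auth("example-tenant-id", args)  # validates against Firecrawl, encrypts, stores

stored = ApiKeyAuthService.get_auth_credentials("example-tenant-id", "website", "firecrawl")
# stored['config']['api_key'] is the encrypted token written by create_provider_auth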
api/services/auth/firecrawl.py (new file, 56 lines)
@@ -0,0 +1,56 @@
+import json
+
+import requests
+
+from services.auth.api_key_auth_base import ApiKeyAuthBase
+
+
+class FirecrawlAuth(ApiKeyAuthBase):
+    def __init__(self, credentials: dict):
+        super().__init__(credentials)
+        auth_type = credentials.get('auth_type')
+        if auth_type != 'bearer':
+            raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer')
+        self.api_key = credentials.get('config').get('api_key', None)
+        self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev')
+
+        if not self.api_key:
+            raise ValueError('No API key provided')
+
+    def validate_credentials(self):
+        headers = self._prepare_headers()
+        options = {
+            'url': 'https://example.com',
+            'crawlerOptions': {
+                'excludes': [],
+                'includes': [],
+                'limit': 1
+            },
+            'pageOptions': {
+                'onlyMainContent': True
+            }
+        }
+        response = self._post_request(f'{self.base_url}/v0/crawl', options, headers)
+        if response.status_code == 200:
+            return True
+        else:
+            self._handle_error(response)
+
+    def _prepare_headers(self):
+        return {
+            'Content-Type': 'application/json',
+            'Authorization': f'Bearer {self.api_key}'
+        }
+
+    def _post_request(self, url, data, headers):
+        return requests.post(url, headers=headers, json=data)
+
+    def _handle_error(self, response):
+        if response.status_code in [402, 409, 500]:
+            error_message = response.json().get('error', 'Unknown error occurred')
+            raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
+        else:
+            if response.text:
+                error_message = json.loads(response.text).get('error', 'Unknown error occurred')
+                raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
+            raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}')
@@ -31,7 +31,7 @@ from models.dataset import (
     DocumentSegment,
 )
 from models.model import UploadFile
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding
 from services.errors.account import NoPermissionError
 from services.errors.dataset import DatasetInUseError, DatasetNameDuplicateError
 from services.errors.document import DocumentIndexingError
@@ -48,6 +48,7 @@ from tasks.document_indexing_update_task import document_indexing_update_task
 from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
 from tasks.recover_document_indexing_task import recover_document_indexing_task
 from tasks.retry_document_indexing_task import retry_document_indexing_task
+from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task


 class DatasetService:
@ -508,18 +509,40 @@ class DocumentService:
    @staticmethod
    def retry_document(dataset_id: str, documents: list[Document]):
        for document in documents:
            # add retry flag
            retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
            cache_result = redis_client.get(retry_indexing_cache_key)
            if cache_result is not None:
                raise ValueError("Document is being retried, please try again later")
            # retry document indexing
            document.indexing_status = 'waiting'
            db.session.add(document)
            db.session.commit()
            redis_client.setex(retry_indexing_cache_key, 600, 1)
        # trigger async task
        document_ids = [document.id for document in documents]
        retry_document_indexing_task.delay(dataset_id, document_ids)

    @staticmethod
    def sync_website_document(dataset_id: str, document: Document):
        # add sync flag
        sync_indexing_cache_key = 'document_{}_is_sync'.format(document.id)
        cache_result = redis_client.get(sync_indexing_cache_key)
        if cache_result is not None:
            raise ValueError("Document is being synced, please try again later")
        # sync document indexing
        document.indexing_status = 'waiting'
        data_source_info = document.data_source_info_dict
        data_source_info['mode'] = 'scrape'
        document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
        db.session.add(document)
        db.session.commit()

        redis_client.setex(sync_indexing_cache_key, 600, 1)

        sync_website_document_indexing_task.delay(dataset_id, document.id)

    @staticmethod
    def get_documents_position(dataset_id):
        document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
        if document:
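
sync_website_document guards re-entry with a short-lived Redis flag and flips the stored data_source_info mode from 'crawl' to 'scrape' before queuing the re-index. A minimal sketch of how a console controller might drive it (the helper and identifiers are assumptions, not shown in this hunk):

# Illustrative only; DocumentService.get_document is an assumed existing helper and IDs are placeholders.
document = DocumentService.get_document(dataset_id, document_id)
DocumentService.sync_website_document(dataset_id, document)
# Calling again within 10 minutes hits the 'document_{id}_is_sync' flag
# (set above with a 600-second TTL) and raises "Document is being synced, please try again later".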
@ -545,6 +568,9 @@ class DocumentService:
                notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                for notion_info in notion_info_list:
                    count = count + len(notion_info['pages'])
            elif document_data["data_source"]["type"] == "website_crawl":
                website_info = document_data["data_source"]['info_list']['website_info_list']
                count = len(website_info['urls'])
            batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
            if count > batch_upload_limit:
                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@ -683,12 +709,12 @@ class DocumentService:
                        exist_document[data_source_info['notion_page_id']] = document.id
                for notion_info in notion_info_list:
                    workspace_id = notion_info['workspace_id']
                    data_source_binding = DataSourceOauthBinding.query.filter(
                        db.and_(
                            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                            DataSourceOauthBinding.provider == 'notion',
                            DataSourceOauthBinding.disabled == False,
                            DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                        )
                    ).first()
                    if not data_source_binding:
@ -717,6 +743,28 @@ class DocumentService:
                # delete not selected documents
                if len(exist_document) > 0:
                    clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
            elif document_data["data_source"]["type"] == "website_crawl":
                website_info = document_data["data_source"]['info_list']['website_info_list']
                urls = website_info['urls']
                for url in urls:
                    data_source_info = {
                        'url': url,
                        'provider': website_info['provider'],
                        'job_id': website_info['job_id'],
                        'only_main_content': website_info.get('only_main_content', False),
                        'mode': 'crawl',
                    }
                    document = DocumentService.build_document(dataset, dataset_process_rule.id,
                                                              document_data["data_source"]["type"],
                                                              document_data["doc_form"],
                                                              document_data["doc_language"],
                                                              data_source_info, created_from, position,
                                                              account, url, batch)
                    db.session.add(document)
                    db.session.flush()
                    document_ids.append(document.id)
                    documents.append(document)
                    position += 1
            db.session.commit()

            # trigger async task
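
Each crawled URL becomes its own document row whose data_source_info JSON mirrors the dict built above; a representative stored value, with placeholder data:

# Values are placeholders; only the keys come from the code above.
data_source_info = {
    'url': 'https://firecrawl.dev/blog/launch-week',  # one entry from website_info['urls']
    'provider': 'firecrawl',
    'job_id': 'a1b2c3d4',
    'only_main_content': True,
    'mode': 'crawl',   # later rewritten to 'scrape' by DocumentService.sync_website_document
}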
@ -818,12 +866,12 @@ class DocumentService:
                notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                for notion_info in notion_info_list:
                    workspace_id = notion_info['workspace_id']
                    data_source_binding = DataSourceOauthBinding.query.filter(
                        db.and_(
                            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                            DataSourceOauthBinding.provider == 'notion',
                            DataSourceOauthBinding.disabled == False,
                            DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                        )
                    ).first()
                    if not data_source_binding:
@ -835,6 +883,17 @@ class DocumentService:
                            "notion_page_icon": page['page_icon'],
                            "type": page['type']
                        }
            elif document_data["data_source"]["type"] == "website_crawl":
                website_info = document_data["data_source"]['info_list']['website_info_list']
                urls = website_info['urls']
                for url in urls:
                    data_source_info = {
                        'url': url,
                        'provider': website_info['provider'],
                        'job_id': website_info['job_id'],
                        'only_main_content': website_info.get('only_main_content', False),
                        'mode': 'crawl',
                    }
            document.data_source_type = document_data["data_source"]["type"]
            document.data_source_info = json.dumps(data_source_info)
            document.name = file_name
@ -873,6 +932,9 @@ class DocumentService:
                notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                for notion_info in notion_info_list:
                    count = count + len(notion_info['pages'])
            elif document_data["data_source"]["type"] == "website_crawl":
                website_info = document_data["data_source"]['info_list']['website_info_list']
                count = len(website_info['urls'])
            batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
            if count > batch_upload_limit:
                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@ -973,6 +1035,10 @@ class DocumentService:
            if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
                    'notion_info_list']:
                raise ValueError("Notion source info is required")
        if args['data_source']['type'] == 'website_crawl':
            if 'website_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
                    'website_info_list']:
                raise ValueError("Website source info is required")

    @classmethod
    def process_rule_args_validate(cls, args: dict):
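
Taken together, the validation and batch-limit checks expect a website-crawl creation payload shaped roughly as follows; the values are illustrative and only the keys are taken from the code in this file:

# Sketch of the expected document_data shape; values are placeholders.
document_data = {
    'data_source': {
        'type': 'website_crawl',
        'info_list': {
            'website_info_list': {
                'provider': 'firecrawl',
                'job_id': 'a1b2c3d4',
                'urls': ['https://firecrawl.dev', 'https://firecrawl.dev/pricing'],
                'only_main_content': True,
            },
        },
    },
}
# len(urls) is compared against BATCH_UPLOAD_LIMIT, and a missing or empty
# website_info_list raises "Website source info is required".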
171
api/services/website_service.py
Normal file
@ -0,0 +1,171 @@
import datetime
import json

from flask_login import current_user

from core.helper import encrypter
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from services.auth.api_key_auth_service import ApiKeyAuthService


class WebsiteService:

    @classmethod
    def document_create_args_validate(cls, args: dict):
        if 'url' not in args or not args['url']:
            raise ValueError('url is required')
        if 'options' not in args or not args['options']:
            raise ValueError('options is required')
        if 'limit' not in args['options'] or not args['options']['limit']:
            raise ValueError('limit is required')

    @classmethod
    def crawl_url(cls, args: dict) -> dict:
        provider = args.get('provider')
        url = args.get('url')
        options = args.get('options')
        credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
                                                             'website',
                                                             provider)
        if provider == 'firecrawl':
            # decrypt api_key
            api_key = encrypter.decrypt_token(
                tenant_id=current_user.current_tenant_id,
                token=credentials.get('config').get('api_key')
            )
            firecrawl_app = FirecrawlApp(api_key=api_key,
                                         base_url=credentials.get('config').get('base_url', None))
            crawl_sub_pages = options.get('crawl_sub_pages', False)
            only_main_content = options.get('only_main_content', False)
            if not crawl_sub_pages:
                params = {
                    'crawlerOptions': {
                        "includes": [],
                        "excludes": [],
                        "generateImgAltText": True,
                        "limit": 1,
                        'returnOnlyUrls': False,
                        'pageOptions': {
                            'onlyMainContent': only_main_content,
                            "includeHtml": False
                        }
                    }
                }
            else:
                includes = options.get('includes').split(',') if options.get('includes') else []
                excludes = options.get('excludes').split(',') if options.get('excludes') else []
                params = {
                    'crawlerOptions': {
                        "includes": includes if includes else [],
                        "excludes": excludes if excludes else [],
                        "generateImgAltText": True,
                        "limit": options.get('limit', 1),
                        'returnOnlyUrls': False,
                        'pageOptions': {
                            'onlyMainContent': only_main_content,
                            "includeHtml": False
                        }
                    }
                }
                if options.get('max_depth'):
                    params['crawlerOptions']['maxDepth'] = options.get('max_depth')
            job_id = firecrawl_app.crawl_url(url, params)
            website_crawl_time_cache_key = f'website_crawl_{job_id}'
            time = str(datetime.datetime.now().timestamp())
            redis_client.setex(website_crawl_time_cache_key, 3600, time)
            return {
                'status': 'active',
                'job_id': job_id
            }
        else:
            raise ValueError('Invalid provider')

    @classmethod
    def get_crawl_status(cls, job_id: str, provider: str) -> dict:
        credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
                                                             'website',
                                                             provider)
        if provider == 'firecrawl':
            # decrypt api_key
            api_key = encrypter.decrypt_token(
                tenant_id=current_user.current_tenant_id,
                token=credentials.get('config').get('api_key')
            )
            firecrawl_app = FirecrawlApp(api_key=api_key,
                                         base_url=credentials.get('config').get('base_url', None))
            result = firecrawl_app.check_crawl_status(job_id)
            crawl_status_data = {
                'status': result.get('status', 'active'),
                'job_id': job_id,
                'total': result.get('total', 0),
                'current': result.get('current', 0),
                'data': result.get('data', [])
            }
            if crawl_status_data['status'] == 'completed':
                website_crawl_time_cache_key = f'website_crawl_{job_id}'
                start_time = redis_client.get(website_crawl_time_cache_key)
                if start_time:
                    end_time = datetime.datetime.now().timestamp()
                    time_consuming = abs(end_time - float(start_time))
                    crawl_status_data['time_consuming'] = f"{time_consuming:.2f}"
                    redis_client.delete(website_crawl_time_cache_key)
        else:
            raise ValueError('Invalid provider')
        return crawl_status_data

    @classmethod
    def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None:
        credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
                                                             'website',
                                                             provider)
        if provider == 'firecrawl':
            file_key = 'website_files/' + job_id + '.txt'
            if storage.exists(file_key):
                data = storage.load_once(file_key)
                if data:
                    data = json.loads(data.decode('utf-8'))
            else:
                # decrypt api_key
                api_key = encrypter.decrypt_token(
                    tenant_id=tenant_id,
                    token=credentials.get('config').get('api_key')
                )
                firecrawl_app = FirecrawlApp(api_key=api_key,
                                             base_url=credentials.get('config').get('base_url', None))
                result = firecrawl_app.check_crawl_status(job_id)
                if result.get('status') != 'completed':
                    raise ValueError('Crawl job is not completed')
                data = result.get('data')
            if data:
                for item in data:
                    if item.get('source_url') == url:
                        return item
            return None
        else:
            raise ValueError('Invalid provider')

    @classmethod
    def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None:
        credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
                                                             'website',
                                                             provider)
        if provider == 'firecrawl':
            # decrypt api_key
            api_key = encrypter.decrypt_token(
                tenant_id=tenant_id,
                token=credentials.get('config').get('api_key')
            )
            firecrawl_app = FirecrawlApp(api_key=api_key,
                                         base_url=credentials.get('config').get('base_url', None))
            params = {
                'pageOptions': {
                    'onlyMainContent': only_main_content,
                    "includeHtml": False
                }
            }
            result = firecrawl_app.scrape_url(url, params)
            return result
        else:
            raise ValueError('Invalid provider')
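
A minimal console-side usage sketch for WebsiteService, assuming the tenant already has a Firecrawl API key bound through the bearer-auth endpoint (option values are placeholders):

# Illustrative only; option values are placeholders.
from services.website_service import WebsiteService

crawl = WebsiteService.crawl_url({
    'provider': 'firecrawl',
    'url': 'https://firecrawl.dev',
    'options': {
        'crawl_sub_pages': True,
        'only_main_content': True,
        'includes': '',
        'excludes': 'blog/*',
        'limit': 10,
        'max_depth': 2,
    },
})
# -> {'status': 'active', 'job_id': '<firecrawl job id>'}

status = WebsiteService.get_crawl_status(crawl['job_id'], 'firecrawl')
# Once status['status'] == 'completed', status['data'] holds the crawled pages and
# status['time_consuming'] reports seconds elapsed since the Redis-cached start time.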
@ -11,7 +11,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from models.source import DataSourceOauthBinding


@shared_task(queue='dataset')
@ -43,12 +43,12 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
        page_id = data_source_info['notion_page_id']
        page_type = data_source_info['type']
        page_edited_time = data_source_info['last_edited_time']
        data_source_binding = DataSourceOauthBinding.query.filter(
            db.and_(
                DataSourceOauthBinding.tenant_id == document.tenant_id,
                DataSourceOauthBinding.provider == 'notion',
                DataSourceOauthBinding.disabled == False,
                DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
            )
        ).first()
        if not data_source_binding:
90
api/tasks/sync_website_document_indexing_task.py
Normal file
@ -0,0 +1,90 @@
import datetime
import logging
import time

import click
from celery import shared_task

from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService


@shared_task(queue='dataset')
def sync_website_document_indexing_task(dataset_id: str, document_id: str):
    """
    Async process document
    :param dataset_id:
    :param document_id:

    Usage: sync_website_document_indexing_task.delay(dataset_id, document_id)
    """
    start_at = time.perf_counter()

    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()

    sync_indexing_cache_key = 'document_{}_is_sync'.format(document_id)
    # check document limit
    features = FeatureService.get_features(dataset.tenant_id)
    try:
        if features.billing.enabled:
            vector_space = features.vector_space
            if 0 < vector_space.limit <= vector_space.size:
                raise ValueError("Your total number of documents plus the number of uploads have over the limit of "
                                 "your subscription.")
    except Exception as e:
        document = db.session.query(Document).filter(
            Document.id == document_id,
            Document.dataset_id == dataset_id
        ).first()
        if document:
            document.indexing_status = 'error'
            document.error = str(e)
            document.stopped_at = datetime.datetime.utcnow()
            db.session.add(document)
            db.session.commit()
        redis_client.delete(sync_indexing_cache_key)
        return

    logging.info(click.style('Start sync website document: {}'.format(document_id), fg='green'))
    document = db.session.query(Document).filter(
        Document.id == document_id,
        Document.dataset_id == dataset_id
    ).first()
    try:
        if document:
            # clean old data
            index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

            segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
            if segments:
                index_node_ids = [segment.index_node_id for segment in segments]
                # delete from vector index
                index_processor.clean(dataset, index_node_ids)

                for segment in segments:
                    db.session.delete(segment)
                db.session.commit()

            document.indexing_status = 'parsing'
            document.processing_started_at = datetime.datetime.utcnow()
            db.session.add(document)
            db.session.commit()

        indexing_runner = IndexingRunner()
        indexing_runner.run([document])
        redis_client.delete(sync_indexing_cache_key)
    except Exception as ex:
        document.indexing_status = 'error'
        document.error = str(ex)
        document.stopped_at = datetime.datetime.utcnow()
        db.session.add(document)
        db.session.commit()
        logging.info(click.style(str(ex), fg='yellow'))
        redis_client.delete(sync_indexing_cache_key)
        pass
    end_at = time.perf_counter()
    logging.info(click.style('Sync document: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
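
The task is normally queued by DocumentService.sync_website_document; a minimal sketch of dispatching it directly (IDs are placeholders):

# Illustrative only; IDs are placeholders.
from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task

sync_website_document_indexing_task.delay(dataset_id='dataset-uuid', document_id='document-uuid')
# The 'dataset' queue worker re-checks the vector-space limit, removes the document's old
# segments from the vector index, re-runs IndexingRunner on the document, and clears the
# 'document_{id}_is_sync' Redis flag so another sync can be triggered.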
@ -0,0 +1,33 @@
import os
from unittest import mock

from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
from core.rag.models.document import Document
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response


def test_firecrawl_web_extractor_crawl_mode(mocker):
    url = "https://firecrawl.dev"
    api_key = os.getenv('FIRECRAWL_API_KEY') or 'fc-'
    base_url = 'https://api.firecrawl.dev'
    firecrawl_app = FirecrawlApp(api_key=api_key,
                                 base_url=base_url)
    params = {
        'crawlerOptions': {
            "includes": [],
            "excludes": [],
            "generateImgAltText": True,
            "maxDepth": 1,
            "limit": 1,
            'returnOnlyUrls': False,
        }
    }
    mocked_firecrawl = {
        "jobId": "test",
    }
    mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl))
    job_id = firecrawl_app.crawl_url(url, params)
    print(job_id)
    assert isinstance(job_id, str)
0
api/tests/unit_tests/oss/__init__.py
Normal file
0
api/tests/unit_tests/oss/local/__init__.py
Normal file