Common: Support PostgreSQL database as the metadata db. (#2357)

https://github.com/infiniflow/ragflow/issues/2356

### What problem does this PR solve?

As title

### Type of change

- [X] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Fachuan Bai 2024-09-12 15:12:39 +08:00 committed by GitHub
parent ba834aee26
commit f8e9a0590f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 89 additions and 12 deletions

View File

@ -18,18 +18,19 @@ import os
import sys import sys
import typing import typing
import operator import operator
from enum import Enum
from functools import wraps from functools import wraps
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from flask_login import UserMixin from flask_login import UserMixin
from playhouse.migrate import MySQLMigrator, migrate from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
from peewee import ( from peewee import (
BigIntegerField, BooleanField, CharField, BigIntegerField, BooleanField, CharField,
CompositeKey, IntegerField, TextField, FloatField, DateTimeField, CompositeKey, IntegerField, TextField, FloatField, DateTimeField,
Field, Model, Metadata Field, Model, Metadata
) )
from playhouse.pool import PooledMySQLDatabase from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
from api.db import SerializedType, ParserType from api.db import SerializedType, ParserType
from api.settings import DATABASE, stat_logger, SECRET_KEY from api.settings import DATABASE, stat_logger, SECRET_KEY, DATABASE_TYPE
from api.utils.log_utils import getLogger from api.utils.log_utils import getLogger
from api import utils from api import utils
@ -58,8 +59,13 @@ AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {
"write_access"} "write_access"}
class TextFieldType(Enum):
MYSQL = 'LONGTEXT'
POSTGRES = 'TEXT'
class LongTextField(TextField): class LongTextField(TextField):
field_type = 'LONGTEXT' field_type = TextFieldType[DATABASE_TYPE.upper()].value
class JSONField(LongTextField): class JSONField(LongTextField):
@ -266,18 +272,69 @@ class JsonSerializedField(SerializedField):
super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook, super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
object_pairs_hook=object_pairs_hook, **kwargs) object_pairs_hook=object_pairs_hook, **kwargs)
class PooledDatabase(Enum):
MYSQL = PooledMySQLDatabase
POSTGRES = PooledPostgresqlDatabase
class DatabaseMigrator(Enum):
MYSQL = MySQLMigrator
POSTGRES = PostgresqlMigrator
@singleton @singleton
class BaseDataBase: class BaseDataBase:
def __init__(self): def __init__(self):
database_config = DATABASE.copy() database_config = DATABASE.copy()
db_name = database_config.pop("name") db_name = database_config.pop("name")
self.database_connection = PooledMySQLDatabase( self.database_connection = PooledDatabase[DATABASE_TYPE.upper()].value(db_name, **database_config)
db_name, **database_config) stat_logger.info('init database on cluster mode successfully')
stat_logger.info('init mysql database on cluster mode successfully')
class PostgresDatabaseLock:
def __init__(self, lock_name, timeout=10, db=None):
self.lock_name = lock_name
self.timeout = int(timeout)
self.db = db if db else DB
class DatabaseLock: def lock(self):
cursor = self.db.execute_sql("SELECT pg_try_advisory_lock(%s)", self.timeout)
ret = cursor.fetchone()
if ret[0] == 0:
raise Exception(f'acquire postgres lock {self.lock_name} timeout')
elif ret[0] == 1:
return True
else:
raise Exception(f'failed to acquire lock {self.lock_name}')
def unlock(self):
cursor = self.db.execute_sql("SELECT pg_advisory_unlock(%s)", self.timeout)
ret = cursor.fetchone()
if ret[0] == 0:
raise Exception(
f'postgres lock {self.lock_name} was not established by this thread')
elif ret[0] == 1:
return True
else:
raise Exception(f'postgres lock {self.lock_name} does not exist')
def __enter__(self):
if isinstance(self.db, PostgresDatabaseLock):
self.lock()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if isinstance(self.db, PostgresDatabaseLock):
self.unlock()
def __call__(self, func):
@wraps(func)
def magic(*args, **kwargs):
with self:
return func(*args, **kwargs)
return magic
class MysqlDatabaseLock:
def __init__(self, lock_name, timeout=10, db=None): def __init__(self, lock_name, timeout=10, db=None):
self.lock_name = lock_name self.lock_name = lock_name
self.timeout = int(timeout) self.timeout = int(timeout)
@ -325,8 +382,13 @@ class DatabaseLock:
return magic return magic
class DatabaseLock(Enum):
MYSQL = MysqlDatabaseLock
POSTGRES = PostgresDatabaseLock
DB = BaseDataBase().database_connection DB = BaseDataBase().database_connection
DB.lock = DatabaseLock DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value
def close_connection(): def close_connection():
@ -918,7 +980,7 @@ class CanvasTemplate(DataBaseModel):
def migrate_db(): def migrate_db():
with DB.transaction(): with DB.transaction():
migrator = MySQLMigrator(DB) migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB)
try: try:
migrate( migrate(
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",

View File

@ -17,6 +17,8 @@ import operator
from functools import reduce from functools import reduce
from typing import Dict, Type, Union from typing import Dict, Type, Union
from playhouse.pool import PooledMySQLDatabase
from api.utils import current_timestamp, timestamp_to_date from api.utils import current_timestamp, timestamp_to_date
from api.db.db_models import DB, DataBaseModel from api.db.db_models import DB, DataBaseModel
@ -49,7 +51,10 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
with DB.atomic(): with DB.atomic():
query = model.insert_many(data_source[i:i + batch_size]) query = model.insert_many(data_source[i:i + batch_size])
if replace_on_conflict: if replace_on_conflict:
if isinstance(DB, PooledMySQLDatabase):
query = query.on_conflict(preserve=preserve) query = query.on_conflict(preserve=preserve)
else:
query = query.on_conflict(conflict_target="id", preserve=preserve)
query.execute() query.execute()

View File

@ -164,7 +164,8 @@ RANDOM_INSTANCE_ID = get_base_config(
PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy") PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol") PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")
DATABASE = decrypt_database_config(name="mysql") DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
# Switch # Switch
# upload # upload

View File

@ -9,6 +9,14 @@ mysql:
port: 3306 port: 3306
max_connections: 100 max_connections: 100
stale_timeout: 30 stale_timeout: 30
postgres:
name: 'rag_flow'
user: 'rag_flow'
password: 'infini_rag_flow'
host: 'postgres'
port: 5432
max_connections: 100
stale_timeout: 30
minio: minio:
user: 'rag_flow' user: 'rag_flow'
password: 'infini_rag_flow' password: 'infini_rag_flow'

View File

@ -31,6 +31,7 @@ Flask==3.0.3
Flask_Cors==5.0.0 Flask_Cors==5.0.0
Flask_Login==0.6.3 Flask_Login==0.6.3
flask_session==0.8.0 flask_session==0.8.0
psycopg2==2.9.9
google_search_results==2.4.2 google_search_results==2.4.2
groq==0.9.0 groq==0.9.0
hanziconv==0.3.2 hanziconv==0.3.2