build python version rag-flow (#21)

* clean rust version project

* clean rust version project

* build python version rag-flow
This commit is contained in:
KevinHuSh 2024-01-15 08:46:22 +08:00 committed by GitHub
parent db8cae3f1e
commit 30791976d5
123 changed files with 4985 additions and 4239 deletions

View File

@ -1,9 +0,0 @@
# Database
HOST=127.0.0.1
PORT=8000
DATABASE_URL="postgresql://infiniflow:infiniflow@localhost/docgpt"
# S3 Storage
MINIO_HOST="127.0.0.1:9000"
MINIO_USR="infiniflow"
MINIO_PWD="infiniflow_docgpt"

View File

@ -1,42 +0,0 @@
[package]
name = "doc_gpt"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
actix-web = "4.3.1"
actix-rt = "2.8.0"
actix-files = "0.6.2"
actix-multipart = "0.4"
actix-session = { version = "0.5" }
actix-identity = { version = "0.4" }
actix-web-httpauth = { version = "0.6" }
actix-ws = "0.2.5"
uuid = { version = "1.6.1", features = [
"v4",
"fast-rng",
"macro-diagnostics",
] }
thiserror = "1.0"
postgres = "0.19.7"
sea-orm = { version = "0.12.9", features = ["sqlx-postgres", "runtime-tokio-native-tls", "macros"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1.0"
tracing-subscriber = "0.3.18"
dotenvy = "0.15.7"
listenfd = "1.0.1"
chrono = "0.4.31"
migration = { path = "./migration" }
minio = "0.1.0"
futures-util = "0.3.29"
actix-multipart-extract = "0.1.5"
regex = "1.10.2"
tokio = { version = "1.35.1", features = ["rt", "time", "macros"] }
[[bin]]
name = "doc_gpt"
[workspace]
members = [".", "migration"]

0
python/conf/mapping.json → conf/mapping.json Executable file → Normal file
View File

30
conf/private.pem Normal file
View File

@ -0,0 +1,30 @@
-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: DES-EDE3-CBC,EFF8327C41E531AD
7jdPFDAA6fiTzOIU7XGzKuT324JKZEcK5vBRJqBkA5XO6ENN1wLdhh3zQbl1Ejfv
KMSUIgbtQEJB4bvOzS//okbZa1vCNYuTS/NGcpKUnhqdOmAL3hl/kOtOLLjTZrwo
3KX8iujLH7wQ64GxArtpUuaFq1k0whN1BB5RGJp3IO/L6pMpSWVRKO+JPUrD1Ujr
XA/LUKQJaZtXVUVOYPtIwbyqPsh93QBetJnRwwV3gNOwGpcX2jDpyTxDUkLJCPPg
6Hw0pwlQEd8A11sjxCBbASwLeJO1L0w69QiX9chyOkZ+sfDsVpPt/wf1NexA7Cdj
9uifJ4JGbby39QD6mInZGtnRzQRdafjuXlBR2I0Qa7fBRu8QsfhmLbWZfWno7j08
4bAAoqB1vRNfSu8LVJXdEEh/HKuwu11pgRr5eH8WQ3hJg+Y2k7zDHpp1VaHL7/Kn
S+aN5bhQ4Xt0Ujdi1+rsmNchnF6LWsDezHWJeWUM6X7dJnqIBl8oCyghbghT8Tyw
aEKWXc2+7FsP5yd0NfG3PFYOLdLgfI43pHTAv5PEQ47w9r1XOwfblKKBUDEzaput
T3t5wQ6wxdyhRxeO4arCHfe/i+j3fzvhlwgbuwrmrkWGWSS86eMTaoGM8+uUrHv0
6TbU0tj6DKKUslVk1dCHh9TnmNsXZuLJkceZF38PSKNxhzudU8OTtzhS0tFL91HX
vo7N+XdiGMs8oOSpjE6RPlhFhVAKGJpXwBj/vXLLcmzesA7ZB2kYtFKMIdsUQpls
PE/4K5PEX2d8pxA5zxo0HleA1YjW8i5WEcDQThZQzj2sWvg06zSjenVFrbCm9Bro
hFpAB/3zJHxdRN2MpNpvK35WITy1aDUdX1WdyrlcRtIE5ssFTSoxSj9ibbDZ78+z
gtbw/MUi6vU6Yz1EjvoYu/bmZAHt9Aagcxw6k58fjO2cEB9njK7xbbiZUSwpJhEe
U/PxK+SdOU/MmGKeqdgqSfhJkq0vhacvsEjFGRAfivSCHkL0UjhObU+rSJ3g1RMO
oukAev6TOAwbTKVWjg3/EX+pl/zorAgaPNYFX64TSH4lE3VjeWApITb9Z5C/sVxR
xW6hU9qyjzWYWY+91y16nkw1l7VQvWHUZwV7QzTScC2BOzDVpeqY1KiYJxgoo6sX
ZCqR5oh4vToG4W8ZrRyauwUaZJ3r+zhAgm+6n6TJQNwFEl0muji+1nPl32EiFsRs
qR6CtuhUOVQM4VnILDwFJfuGYRFtKzQgvseLNU4ZqAVqQj8l4ARGAP2P1Au/uUKy
oGzI7a+b5MvRHuvkxPAclOgXgX/8yyOLaBg+mgaqv9h2JIJD28PzouFl3BajRaVB
7GWTnROJYhX5SuX/g585SLRKoQUtK0WhdJCjTRfyRJPwfdppgdTbWO99R4G+ir02
JQdSkZf2vmZRXenPNTEPDOUY6nVN6sUuBjmtOwoUF194ODgpYB6IaHqK08sa1pUh
1mZyxitHdPbygePTe20XWMZFoK2knAqN0JPPbbNjCqiVV+7oqQAnkDIutspu9t2m
ny3jefFmNozbblQMghLUrq+x9wOEgvS76Sqvq3DG/2BkLzJF3MNkvw==
-----END RSA PRIVATE KEY-----

9
conf/public.pem Normal file
View File

@ -0,0 +1,9 @@
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/
z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp
2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOO
UEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVK
RNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK
6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs
2wIDAQAB
-----END PUBLIC KEY-----

28
conf/service_conf.yaml Normal file
View File

@ -0,0 +1,28 @@
authentication:
client:
switch: false
http_app_key:
http_secret_key:
site:
switch: false
permission:
switch: false
component: false
dataset: false
ragflow:
# you must set a real IP address; 127.0.0.1 and 0.0.0.0 are not supported
host: 127.0.0.1
http_port: 9380
database:
name: 'rag_flow'
user: 'root'
passwd: 'infini_rag_flow'
host: '123.60.95.134'
port: 5455
max_connections: 100
stale_timeout: 30
oauth:
github:
client_id: 302129228f0d96055bee
secret_key: e518e55ccfcdfcae8996afc40f110e9c95f14fc4
url: https://github.com/login/oauth/access_token
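
For orientation, a minimal sketch of how a Python service might read this file; it assumes PyYAML is available and that the file is loaded from conf/service_conf.yaml. The variable names below are illustrative, not part of this commit.

```python
# Hedged sketch: load service_conf.yaml and pull out the fields defined above.
import yaml  # assumes PyYAML is installed

with open("conf/service_conf.yaml", "r") as f:
    conf = yaml.safe_load(f)

http_host = conf["ragflow"]["host"]        # must be a real IP, per the comment above
http_port = conf["ragflow"]["http_port"]   # 9380 in this file
db = conf["database"]
dsn = f"mysql://{db['user']}:{db['passwd']}@{db['host']}:{db['port']}/{db['name']}"
print(http_host, http_port, dsn)
```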

View File

@ -1,21 +0,0 @@
# Version of Elastic products
STACK_VERSION=8.11.3
# Set the cluster name
CLUSTER_NAME=docgpt
# Port to expose Elasticsearch HTTP API to the host
ES_PORT=9200
# Port to expose Kibana to the host
KIBANA_PORT=6601
# Increase or decrease based on the available host memory (in bytes)
MEM_LIMIT=4073741824
POSTGRES_USER=root
POSTGRES_PASSWORD=infiniflow_docgpt
POSTGRES_DB=docgpt
MINIO_USER=infiniflow
MINIO_PASSWORD=infiniflow_docgpt

View File

@ -1,7 +1,7 @@
version: '2.2'
services:
  es01:
-    container_name: docgpt-es-01
+    container_name: ragflow-es-01
    image: docker.elastic.co/elasticsearch/elasticsearch:${STACK_VERSION}
    volumes:
      - esdata01:/usr/share/elasticsearch/data
@ -20,14 +20,14 @@ services:
        soft: -1
        hard: -1
    networks:
-      - docgpt
+      - ragflow
    restart: always

  kibana:
    depends_on:
      - es01
    image: docker.elastic.co/kibana/kibana:${STACK_VERSION}
-    container_name: docgpt-kibana
+    container_name: ragflow-kibana
    volumes:
      - kibanadata:/usr/share/kibana/data
    ports:
@ -37,26 +37,39 @@ services:
      - ELASTICSEARCH_HOSTS=http://es01:9200
    mem_limit: ${MEM_LIMIT}
    networks:
-      - docgpt
+      - ragflow

-  postgres:
-    image: postgres
-    container_name: docgpt-postgres
+  mysql:
+    image: mysql:5.7.18
+    container_name: ragflow-mysql
    environment:
-      - POSTGRES_USER=${POSTGRES_USER}
-      - POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
-      - POSTGRES_DB=${POSTGRES_DB}
+      - MYSQL_ROOT_PASSWORD=${MYSQL_PASSWORD}
+      - TZ="Asia/Shanghai"
+    command:
+      --max_connections=1000
+      --character-set-server=utf8mb4
+      --collation-server=utf8mb4_general_ci
+      --default-authentication-plugin=mysql_native_password
+      --tls_version="TLSv1.2,TLSv1.3"
+      --init-file /data/application/init.sql
    ports:
-      - 5455:5432
+      - ${MYSQL_PORT}:3306
    volumes:
-      - pg_data:/var/lib/postgresql/data
+      - mysql_data:/var/lib/mysql
+      - ./init.sql:/data/application/init.sql
    networks:
-      - docgpt
+      - ragflow
+    healthcheck:
+      test: [ "CMD-SHELL", "curl --silent localhost:3306 >/dev/null || exit 1" ]
+      interval: 10s
+      timeout: 10s
+      retries: 3
    restart: always

  minio:
    image: quay.io/minio/minio:RELEASE.2023-12-20T01-00-02Z
-    container_name: docgpt-minio
+    container_name: ragflow-minio
    command: server --console-address ":9001" /data
    ports:
      - 9000:9000
@ -67,7 +80,7 @@ services:
    volumes:
      - minio_data:/data
    networks:
-      - docgpt
+      - ragflow
    restart: always
@ -76,11 +89,11 @@ volumes:
    driver: local
  kibanadata:
    driver: local
-  pg_data:
+  mysql_data:
    driver: local
  minio_data:
    driver: local
networks:
-  docgpt:
+  ragflow:
    driver: bridge

2
docker/init.sql Normal file
View File

@ -0,0 +1,2 @@
CREATE DATABASE IF NOT EXISTS rag_flow;
USE rag_flow;

View File

@ -1,20 +0,0 @@
[package]
name = "migration"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
name = "migration"
path = "src/lib.rs"
[dependencies]
async-std = { version = "1", features = ["attributes", "tokio1"] }
chrono = "0.4.31"
[dependencies.sea-orm-migration]
version = "0.12.0"
features = [
"runtime-tokio-rustls", # `ASYNC_RUNTIME` feature
"sqlx-postgres", # `DATABASE_DRIVER` feature
]

View File

@ -1,41 +0,0 @@
# Running Migrator CLI
- Generate a new migration file
```sh
cargo run -- generate MIGRATION_NAME
```
- Apply all pending migrations
```sh
cargo run
```
```sh
cargo run -- up
```
- Apply first 10 pending migrations
```sh
cargo run -- up -n 10
```
- Rollback last applied migrations
```sh
cargo run -- down
```
- Rollback last 10 applied migrations
```sh
cargo run -- down -n 10
```
- Drop all tables from the database, then reapply all migrations
```sh
cargo run -- fresh
```
- Rollback all applied migrations, then reapply all migrations
```sh
cargo run -- refresh
```
- Rollback all applied migrations
```sh
cargo run -- reset
```
- Check the status of all migrations
```sh
cargo run -- status
```

View File

@ -1,12 +0,0 @@
pub use sea_orm_migration::prelude::*;
mod m20220101_000001_create_table;
pub struct Migrator;
#[async_trait::async_trait]
impl MigratorTrait for Migrator {
fn migrations() -> Vec<Box<dyn MigrationTrait>> {
vec![Box::new(m20220101_000001_create_table::Migration)]
}
}

View File

@ -1,440 +0,0 @@
use sea_orm_migration::prelude::*;
use chrono::{ FixedOffset, Utc };
#[allow(dead_code)]
fn now() -> chrono::DateTime<FixedOffset> {
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
}
#[derive(DeriveMigrationName)]
pub struct Migration;
#[async_trait::async_trait]
impl MigrationTrait for Migration {
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager.create_table(
Table::create()
.table(UserInfo::Table)
.if_not_exists()
.col(
ColumnDef::new(UserInfo::Uid)
.big_integer()
.not_null()
.auto_increment()
.primary_key()
)
.col(ColumnDef::new(UserInfo::Email).string().not_null())
.col(ColumnDef::new(UserInfo::Nickname).string().not_null())
.col(ColumnDef::new(UserInfo::AvatarBase64).string())
.col(ColumnDef::new(UserInfo::ColorScheme).string().default("dark"))
.col(ColumnDef::new(UserInfo::ListStyle).string().default("list"))
.col(ColumnDef::new(UserInfo::Language).string().default("chinese"))
.col(ColumnDef::new(UserInfo::Password).string().not_null())
.col(
ColumnDef::new(UserInfo::LastLoginAt)
.timestamp_with_time_zone()
.default(Expr::current_timestamp())
)
.col(
ColumnDef::new(UserInfo::CreatedAt)
.timestamp_with_time_zone()
.default(Expr::current_timestamp())
.not_null()
)
.col(
ColumnDef::new(UserInfo::UpdatedAt)
.timestamp_with_time_zone()
.default(Expr::current_timestamp())
.not_null()
)
.col(ColumnDef::new(UserInfo::IsDeleted).boolean().default(false))
.to_owned()
).await?;
manager.create_table(
Table::create()
.table(TagInfo::Table)
.if_not_exists()
.col(
ColumnDef::new(TagInfo::Tid)
.big_integer()
.not_null()
.auto_increment()
.primary_key()
)
.col(ColumnDef::new(TagInfo::Uid).big_integer().not_null())
.col(ColumnDef::new(TagInfo::TagName).string().not_null())
.col(ColumnDef::new(TagInfo::Regx).string())
.col(ColumnDef::new(TagInfo::Color).tiny_unsigned().default(1))
.col(ColumnDef::new(TagInfo::Icon).tiny_unsigned().default(1))
.col(ColumnDef::new(TagInfo::FolderId).big_integer())
.col(
ColumnDef::new(TagInfo::CreatedAt)
.timestamp_with_time_zone()
.default(Expr::current_timestamp())
.not_null()
)
.col(
ColumnDef::new(TagInfo::UpdatedAt)
.timestamp_with_time_zone()
.default(Expr::current_timestamp())
.not_null()
)
.col(ColumnDef::new(TagInfo::IsDeleted).boolean().default(false))
.to_owned()
).await?;
manager.create_table(
Table::create()
.table(Tag2Doc::Table)
.if_not_exists()
.col(
ColumnDef::new(Tag2Doc::Id)
.big_integer()
.not_null()
.auto_increment()
.primary_key()
)
.col(ColumnDef::new(Tag2Doc::TagId).big_integer())
.col(ColumnDef::new(Tag2Doc::Did).big_integer())
.to_owned()
).await?;
manager.create_table(
Table::create()
.table(Kb2Doc::Table)
.if_not_exists()
.col(
ColumnDef::new(Kb2Doc::Id)
.big_integer()
.not_null()
.auto_increment()
.primary_key()
)
.col(ColumnDef::new(Kb2Doc::KbId).big_integer())
.col(ColumnDef::new(Kb2Doc::Did).big_integer())
.col(ColumnDef::new(Kb2Doc::KbProgress).float().default(0))
.col(ColumnDef::new(Kb2Doc::KbProgressMsg).string().default(""))
.col(
ColumnDef::new(Kb2Doc::UpdatedAt)
.timestamp_with_time_zone()
.default(Expr::current_timestamp())
.not_null()
)
.col(ColumnDef::new(Kb2Doc::IsDeleted).boolean().default(false))
.to_owned()
).await?;
manager.create_table(
Table::create()
.table(Dialog2Kb::Table)
.if_not_exists()
.col(
ColumnDef::new(Dialog2Kb::Id)
.big_integer()
.not_null()
.auto_increment()
.primary_key()
)
.col(ColumnDef::new(Dialog2Kb::DialogId).big_integer())
.col(ColumnDef::new(Dialog2Kb::KbId).big_integer())
.to_owned()
).await?;
manager.create_table(
Table::create()
.table(Doc2Doc::Table)
.if_not_exists()
.col(
ColumnDef::new(Doc2Doc::Id)
.big_integer()
.not_null()
.auto_increment()
.primary_key()
)
.col(ColumnDef::new(Doc2Doc::ParentId).big_integer())
.col(ColumnDef::new(Doc2Doc::Did).big_integer())
.to_owned()
).await?;
manager.create_table(
Table::create()
.table(KbInfo::Table)
.if_not_exists()
.col(
ColumnDef::new(KbInfo::KbId)
.big_integer()
.auto_increment()
.not_null()
.primary_key()
)
.col(ColumnDef::new(KbInfo::Uid).big_integer().not_null())
.col(ColumnDef::new(KbInfo::KbName).string().not_null())
.col(ColumnDef::new(KbInfo::Icon).tiny_unsigned().default(1))
.col(
ColumnDef::new(KbInfo::CreatedAt)
.timestamp_with_time_zone()
.default(Expr::current_timestamp())
.not_null()
)
.col(
ColumnDef::new(KbInfo::UpdatedAt)
.timestamp_with_time_zone()
.default(Expr::current_timestamp())
.not_null()
)
.col(ColumnDef::new(KbInfo::IsDeleted).boolean().default(false))
.to_owned()
).await?;
manager.create_table(
Table::create()
.table(DocInfo::Table)
.if_not_exists()
.col(
ColumnDef::new(DocInfo::Did)
.big_integer()
.not_null()
.auto_increment()
.primary_key()
)
.col(ColumnDef::new(DocInfo::Uid).big_integer().not_null())
.col(ColumnDef::new(DocInfo::DocName).string().not_null())
.col(ColumnDef::new(DocInfo::Location).string().not_null())
.col(ColumnDef::new(DocInfo::Size).big_integer().not_null())
.col(ColumnDef::new(DocInfo::Type).string().not_null())
.col(ColumnDef::new(DocInfo::ThumbnailBase64).string().default(""))
.comment("doc type|folder")
.col(
ColumnDef::new(DocInfo::CreatedAt)
.timestamp_with_time_zone()
.default(Expr::current_timestamp())
.not_null()
)
.col(
ColumnDef::new(DocInfo::UpdatedAt)
.timestamp_with_time_zone()
.default(Expr::current_timestamp())
.not_null()
)
.col(ColumnDef::new(DocInfo::IsDeleted).boolean().default(false))
.to_owned()
).await?;
manager.create_table(
Table::create()
.table(DialogInfo::Table)
.if_not_exists()
.col(
ColumnDef::new(DialogInfo::DialogId)
.big_integer()
.not_null()
.auto_increment()
.primary_key()
)
.col(ColumnDef::new(DialogInfo::Uid).big_integer().not_null())
.col(ColumnDef::new(DialogInfo::KbId).big_integer().not_null())
.col(ColumnDef::new(DialogInfo::DialogName).string().not_null())
.col(ColumnDef::new(DialogInfo::History).string().comment("json"))
.col(
ColumnDef::new(DialogInfo::CreatedAt)
.timestamp_with_time_zone()
.default(Expr::current_timestamp())
.not_null()
)
.col(
ColumnDef::new(DialogInfo::UpdatedAt)
.timestamp_with_time_zone()
.default(Expr::current_timestamp())
.not_null()
)
.col(ColumnDef::new(DialogInfo::IsDeleted).boolean().default(false))
.to_owned()
).await?;
let root_insert = Query::insert()
.into_table(UserInfo::Table)
.columns([UserInfo::Email, UserInfo::Nickname, UserInfo::Password])
.values_panic(["kai.hu@infiniflow.org".into(), "root".into(), "123456".into()])
.to_owned();
let doc_insert = Query::insert()
.into_table(DocInfo::Table)
.columns([
DocInfo::Uid,
DocInfo::DocName,
DocInfo::Size,
DocInfo::Type,
DocInfo::Location,
])
.values_panic([(1).into(), "/".into(), (0).into(), "folder".into(), "".into()])
.to_owned();
let tag_insert = Query::insert()
.into_table(TagInfo::Table)
.columns([TagInfo::Uid, TagInfo::TagName, TagInfo::Regx, TagInfo::Color, TagInfo::Icon])
.values_panic([
(1).into(),
"Video".into(),
".*\\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)".into(),
(1).into(),
(1).into(),
])
.values_panic([
(1).into(),
"Picture".into(),
".*\\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)".into(),
(2).into(),
(2).into(),
])
.values_panic([
(1).into(),
"Music".into(),
".*\\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)".into(),
(3).into(),
(3).into(),
])
.values_panic([
(1).into(),
"Document".into(),
".*\\.(pdf|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)".into(),
(3).into(),
(3).into(),
])
.to_owned();
manager.exec_stmt(root_insert).await?;
manager.exec_stmt(doc_insert).await?;
manager.exec_stmt(tag_insert).await?;
Ok(())
}
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager.drop_table(Table::drop().table(UserInfo::Table).to_owned()).await?;
manager.drop_table(Table::drop().table(TagInfo::Table).to_owned()).await?;
manager.drop_table(Table::drop().table(Tag2Doc::Table).to_owned()).await?;
manager.drop_table(Table::drop().table(Kb2Doc::Table).to_owned()).await?;
manager.drop_table(Table::drop().table(Dialog2Kb::Table).to_owned()).await?;
manager.drop_table(Table::drop().table(Doc2Doc::Table).to_owned()).await?;
manager.drop_table(Table::drop().table(KbInfo::Table).to_owned()).await?;
manager.drop_table(Table::drop().table(DocInfo::Table).to_owned()).await?;
manager.drop_table(Table::drop().table(DialogInfo::Table).to_owned()).await?;
Ok(())
}
}
#[derive(DeriveIden)]
enum UserInfo {
Table,
Uid,
Email,
Nickname,
AvatarBase64,
ColorScheme,
ListStyle,
Language,
Password,
LastLoginAt,
CreatedAt,
UpdatedAt,
IsDeleted,
}
#[derive(DeriveIden)]
enum TagInfo {
Table,
Tid,
Uid,
TagName,
Regx,
Color,
Icon,
FolderId,
CreatedAt,
UpdatedAt,
IsDeleted,
}
#[derive(DeriveIden)]
enum Tag2Doc {
Table,
Id,
TagId,
Did,
}
#[derive(DeriveIden)]
enum Kb2Doc {
Table,
Id,
KbId,
Did,
KbProgress,
KbProgressMsg,
UpdatedAt,
IsDeleted,
}
#[derive(DeriveIden)]
enum Dialog2Kb {
Table,
Id,
DialogId,
KbId,
}
#[derive(DeriveIden)]
enum Doc2Doc {
Table,
Id,
ParentId,
Did,
}
#[derive(DeriveIden)]
enum KbInfo {
Table,
KbId,
Uid,
KbName,
Icon,
CreatedAt,
UpdatedAt,
IsDeleted,
}
#[derive(DeriveIden)]
enum DocInfo {
Table,
Did,
Uid,
DocName,
Location,
Size,
Type,
ThumbnailBase64,
CreatedAt,
UpdatedAt,
IsDeleted,
}
#[derive(DeriveIden)]
enum DialogInfo {
Table,
Uid,
KbId,
DialogId,
DialogName,
History,
CreatedAt,
UpdatedAt,
IsDeleted,
}

View File

@ -1,6 +0,0 @@
use sea_orm_migration::prelude::*;
#[async_std::main]
async fn main() {
cli::run_cli(migration::Migrator).await;
}

29
python/Dockerfile Normal file
View File

@ -0,0 +1,29 @@
FROM ubuntu:22.04 as base
RUN apt-get update
ENV TZ="Asia/Taipei"
RUN apt-get install -yq \
build-essential \
curl \
libncursesw5-dev \
libssl-dev \
libsqlite3-dev \
libgdbm-dev \
libc6-dev \
libbz2-dev \
software-properties-common \
python3.11 python3.11-dev python3-pip
RUN apt-get install -yq git
RUN pip3 config set global.index-url https://mirror.baidu.com/pypi/simple
RUN pip3 config set global.trusted-host mirror.baidu.com
RUN pip3 install --upgrade pip
RUN pip3 install torch==2.0.1
RUN pip3 install torch-model-archiver==0.8.2
RUN pip3 install torchvision==0.15.2
COPY requirements.txt .
WORKDIR /docgpt
ENV PYTHONPATH=/docgpt/

View File

@ -1,22 +0,0 @@
```shell
docker pull postgres
LOCAL_POSTGRES_DATA=./postgres-data
docker run
--name docass-postgres
-p 5455:5432
-v $LOCAL_POSTGRES_DATA:/var/lib/postgresql/data
-e POSTGRES_USER=root
-e POSTGRES_PASSWORD=infiniflow_docass
-e POSTGRES_DB=docass
-d
postgres
docker network create elastic
docker pull elasticsearch:8.11.3;
docker pull docker.elastic.co/kibana/kibana:8.11.3
```

63
python/] Normal file
View File

@ -0,0 +1,63 @@
from abc import ABC
from openai import OpenAI
import os
import base64
from io import BytesIO
class Base(ABC):
def describe(self, image, max_tokens=300):
raise NotImplementedError("Please implement encode method!")
class GptV4(Base):
def __init__(self):
import openai
openai.api_key = os.environ["OPENAI_API_KEY"]
self.client = OpenAI()
def describe(self, image, max_tokens=300):
buffered = BytesIO()
try:
image.save(buffered, format="JPEG")
except Exception as e:
image.save(buffered, format="PNG")
b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
res = self.client.chat.completions.create(
model="gpt-4-vision-preview",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等。",
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{b64}"
},
},
],
}
],
max_tokens=max_tokens,
)
return res.choices[0].message.content.strip()
class QWen(Base):
def chat(self, system, history, gen_conf):
from http import HTTPStatus
from dashscope import Generation
from dashscope.api_entities.dashscope_response import Role
# export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
messages = [{"role": "system", "content": system}] + history  # build the request from the system prompt plus prior turns
response = Generation.call(
Generation.Models.qwen_turbo,
messages=messages,
result_format='message'
)
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content']
return response.message

View File

@ -1,41 +0,0 @@
{
"version":1,
"disable_existing_loggers":false,
"formatters":{
"simple":{
"format":"%(asctime)s - %(name)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s"
}
},
"handlers":{
"console":{
"class":"logging.StreamHandler",
"level":"DEBUG",
"formatter":"simple",
"stream":"ext://sys.stdout"
},
"info_file_handler":{
"class":"logging.handlers.TimedRotatingFileHandler",
"level":"INFO",
"formatter":"simple",
"filename":"log/info.log",
"when": "MIDNIGHT",
"interval":1,
"backupCount":30,
"encoding":"utf8"
},
"error_file_handler":{
"class":"logging.handlers.TimedRotatingFileHandler",
"level":"ERROR",
"formatter":"simple",
"filename":"log/errors.log",
"when": "MIDNIGHT",
"interval":1,
"backupCount":30,
"encoding":"utf8"
}
},
"root":{
"level":"DEBUG",
"handlers":["console","info_file_handler","error_file_handler"]
}
}

View File

@ -1,9 +0,0 @@
[infiniflow]
es=http://es01:9200
postgres_user=root
postgres_password=infiniflow_docgpt
postgres_host=postgres
postgres_port=5432
minio_host=minio:9000
minio_user=infiniflow
minio_password=infiniflow_docgpt

View File

@ -1,21 +0,0 @@
import os
from .embedding_model import *
from .chat_model import *
from .cv_model import *
EmbeddingModel = None
ChatModel = None
CvModel = None
if os.environ.get("OPENAI_API_KEY"):
EmbeddingModel = GptEmbed()
ChatModel = GptTurbo()
CvModel = GptV4()
elif os.environ.get("DASHSCOPE_API_KEY"):
EmbeddingModel = QWenEmbd()
ChatModel = QWenChat()
CvModel = QWenCV()
else:
EmbeddingModel = HuEmbedding()

View File

@ -1,61 +0,0 @@
from abc import ABC
from openai import OpenAI
from FlagEmbedding import FlagModel
import torch
import os
import numpy as np
class Base(ABC):
def encode(self, texts: list, batch_size=32):
raise NotImplementedError("Please implement encode method!")
class HuEmbedding(Base):
def __init__(self):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
For Linux:
export HF_ENDPOINT=https://hf-mirror.com
For Windows:
Good luck
^_-
"""
self.model = FlagModel("BAAI/bge-large-zh-v1.5",
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
def encode(self, texts: list, batch_size=32):
res = []
for i in range(0, len(texts), batch_size):
res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
return np.array(res)
class GptEmbed(Base):
def __init__(self):
self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
def encode(self, texts: list, batch_size=32):
res = self.client.embeddings.create(input=texts,
model="text-embedding-ada-002")
return [d["embedding"] for d in res["data"]]
class QWenEmbd(Base):
def encode(self, texts: list, batch_size=32, text_type="document"):
# export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
import dashscope
from http import HTTPStatus
res = []
for txt in texts:
resp = dashscope.TextEmbedding.call(
model=dashscope.TextEmbedding.Models.text_embedding_v2,
input=txt[:2048],
text_type=text_type
)
res.append(resp["output"]["embeddings"][0]["embedding"])
return res

0
python/output/ToPDF.pdf Normal file
View File

View File

@ -1,25 +0,0 @@
from openpyxl import load_workbook
import sys
from io import BytesIO
class HuExcelParser:
def __call__(self, fnm):
if isinstance(fnm, str):
wb = load_workbook(fnm)
else:
wb = load_workbook(BytesIO(fnm))
res = []
for sheetname in wb.sheetnames:
ws = wb[sheetname]
lines = []
for r in ws.rows:
lines.append(
"\t".join([str(c.value) if c.value is not None else "" for c in r]))
res.append(f"{sheetname}\n" + "\n".join(lines))
return res
if __name__ == "__main__":
psr = HuExcelParser()
psr(sys.argv[1])

View File

@ -1,194 +0,0 @@
accelerate==0.24.1
addict==2.4.0
aiobotocore==2.7.0
aiofiles==23.2.1
aiohttp==3.8.6
aioitertools==0.11.0
aiosignal==1.3.1
aliyun-python-sdk-core==2.14.0
aliyun-python-sdk-kms==2.16.2
altair==5.1.2
anyio==3.7.1
astor==0.8.1
async-timeout==4.0.3
attrdict==2.0.1
attrs==23.1.0
Babel==2.13.1
bce-python-sdk==0.8.92
beautifulsoup4==4.12.2
bitsandbytes==0.41.1
blinker==1.7.0
botocore==1.31.64
cachetools==5.3.2
certifi==2023.7.22
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
contourpy==1.2.0
crcmod==1.7
cryptography==41.0.5
cssselect==1.2.0
cssutils==2.9.0
cycler==0.12.1
Cython==3.0.5
datasets==2.13.0
datrie==0.8.2
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.6
einops==0.7.0
elastic-transport==8.10.0
elasticsearch==8.10.1
elasticsearch-dsl==8.9.0
et-xmlfile==1.1.0
fastapi==0.104.1
ffmpy==0.3.1
filelock==3.13.1
fire==0.5.0
FlagEmbedding==1.1.5
Flask==3.0.0
flask-babel==4.0.0
fonttools==4.44.0
frozenlist==1.4.0
fsspec==2023.10.0
future==0.18.3
gast==0.5.4
-e git+https://github.com/ggerganov/llama.cpp.git@5f6e0c0dff1e7a89331e6b25eca9a9fd71324069#egg=gguf&subdirectory=gguf-py
gradio==3.50.2
gradio_client==0.6.1
greenlet==3.0.1
h11==0.14.0
hanziconv==0.3.2
httpcore==1.0.1
httpx==0.25.1
huggingface-hub==0.17.3
idna==3.4
imageio==2.31.6
imgaug==0.4.0
importlib-metadata==6.8.0
importlib-resources==6.1.0
install==1.3.5
itsdangerous==2.1.2
Jinja2==3.1.2
jmespath==0.10.0
joblib==1.3.2
jsonschema==4.19.2
jsonschema-specifications==2023.7.1
kiwisolver==1.4.5
lazy_loader==0.3
lmdb==1.4.1
lxml==4.9.3
MarkupSafe==2.1.3
matplotlib==3.8.1
modelscope==1.9.4
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.14
networkx==3.2.1
nltk==3.8.1
numpy==1.24.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.52
nvidia-nvtx-cu12==12.1.105
opencv-contrib-python==4.6.0.66
opencv-python==4.6.0.66
openpyxl==3.1.2
opt-einsum==3.3.0
orjson==3.9.10
oss2==2.18.3
packaging==23.2
paddleocr==2.7.0.3
paddlepaddle-gpu==2.5.2.post120
pandas==2.1.2
pdf2docx==0.5.5
pdfminer.six==20221105
pdfplumber==0.10.3
Pillow==10.0.1
platformdirs==3.11.0
premailer==3.10.0
protobuf==4.25.0
psutil==5.9.6
pyarrow==14.0.0
pyclipper==1.3.0.post5
pycocotools==2.0.7
pycparser==2.21
pycryptodome==3.19.0
pydantic==1.10.13
pydub==0.25.1
PyMuPDF==1.20.2
pyparsing==3.1.1
pypdfium2==4.23.1
python-dateutil==2.8.2
python-docx==1.1.0
python-multipart==0.0.6
pytz==2023.3.post1
PyYAML==6.0.1
rapidfuzz==3.5.2
rarfile==4.1
referencing==0.30.2
regex==2023.10.3
requests==2.31.0
rpds-py==0.12.0
s3fs==2023.10.0
safetensors==0.4.0
scikit-image==0.22.0
scikit-learn==1.3.2
scipy==1.11.3
semantic-version==2.10.0
sentence-transformers==2.2.2
sentencepiece==0.1.98
shapely==2.0.2
simplejson==3.19.2
six==1.16.0
sniffio==1.3.0
sortedcontainers==2.4.0
soupsieve==2.5
SQLAlchemy==2.0.23
starlette==0.27.0
sympy==1.12
tabulate==0.9.0
tblib==3.0.0
termcolor==2.3.0
threadpoolctl==3.2.0
tifffile==2023.9.26
tiktoken==0.5.1
timm==0.9.10
tokenizers==0.13.3
tomli==2.0.1
toolz==0.12.0
torch==2.1.0
torchaudio==2.1.0
torchvision==0.16.0
tornado==6.3.3
tqdm==4.66.1
transformers==4.33.0
transformers-stream-generator==0.0.4
triton==2.1.0
typing_extensions==4.8.0
tzdata==2023.3
urllib3==2.0.7
uvicorn==0.24.0
uvloop==0.19.0
visualdl==2.5.3
websockets==11.0.3
Werkzeug==3.0.1
wrapt==1.15.0
xgboost==2.0.1
xinference==0.6.0
xorbits==0.7.0
xoscar==0.1.3
xxhash==3.4.1
yapf==0.40.2
yarl==1.9.2
zipp==3.17.0

8
python/res/1-0.tm Normal file
View File

@ -0,0 +1,8 @@
2023-12-20 11:44:08.791336+00:00
2023-12-20 11:44:08.853249+00:00
2023-12-20 11:44:08.909933+00:00
2023-12-21 00:47:09.996757+00:00
2023-12-20 11:44:08.965855+00:00
2023-12-20 11:44:09.011682+00:00
2023-12-21 00:47:10.063326+00:00
2023-12-20 11:44:09.069486+00:00

View File

@ -0,0 +1,3 @@
2023-12-27 08:21:49.309802+00:00
2023-12-27 08:37:22.407772+00:00
2023-12-27 08:59:18.845627+00:00

View File

@ -1,118 +0,0 @@
import sys, datetime, random, re, cv2
from os.path import dirname, realpath
sys.path.append(dirname(realpath(__file__)) + "/../")
from util.db_conn import Postgres
from util.minio_conn import HuMinio
from util import findMaxDt
import base64
from io import BytesIO
import pandas as pd
from PIL import Image
import pdfplumber
PG = Postgres("infiniflow", "docgpt")
MINIO = HuMinio("infiniflow")
def set_thumbnail(did, base64):
sql = f"""
update doc_info set thumbnail_base64='{base64}'
where
did={did}
"""
PG.update(sql)
def collect(comm, mod, tm):
sql = f"""
select
did, uid, doc_name, location, updated_at
from doc_info
where
updated_at >= '{tm}'
and MOD(did, {comm}) = {mod}
and is_deleted=false
and type <> 'folder'
and thumbnail_base64=''
order by updated_at asc
limit 10
"""
docs = PG.select(sql)
if len(docs) == 0:return pd.DataFrame()
mtm = str(docs["updated_at"].max())[:19]
print("TOTAL:", len(docs), "To: ", mtm)
return docs
def build(row):
if not re.search(r"\.(pdf|jpg|jpeg|png|gif|svg|apng|icon|ico|webp|mpg|mpeg|avi|rm|rmvb|mov|wmv|mp4)$",
row["doc_name"].lower().strip()):
set_thumbnail(row["did"], "_")
return
def thumbnail(img, SIZE=128):
w,h = img.size
p = SIZE/max(w, h)
w, h = int(w*p), int(h*p)
img.thumbnail((w, h))
buffered = BytesIO()
try:
img.save(buffered, format="JPEG")
except Exception as e:
try:
img.save(buffered, format="PNG")
except Exception as ee:
pass
return base64.b64encode(buffered.getvalue()).decode("utf-8")
iobytes = BytesIO(MINIO.get("%s-upload"%str(row["uid"]), row["location"]))
if re.search(r"\.pdf$", row["doc_name"].lower().strip()):
pdf = pdfplumber.open(iobytes)
img = pdf.pages[0].to_image().annotated
set_thumbnail(row["did"], thumbnail(img))
if re.search(r"\.(jpg|jpeg|png|gif|svg|apng|webp|icon|ico)$", row["doc_name"].lower().strip()):
img = Image.open(iobytes)
set_thumbnail(row["did"], thumbnail(img))
if re.search(r"\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|mp4)$", row["doc_name"].lower().strip()):
url = MINIO.get_presigned_url("%s-upload"%str(row["uid"]),
row["location"],
expires=datetime.timedelta(seconds=60)
)
cap = cv2.VideoCapture(url)
succ = cap.isOpened()
i = random.randint(1, 11)
while succ:
ret, frame = cap.read()
if not ret: break
if i > 0:
i -= 1
continue
img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
print(img.size)
set_thumbnail(row["did"], thumbnail(img))
cap.release()
cv2.destroyAllWindows()
def main(comm, mod):
global model
tm_fnm = f"res/thumbnail-{comm}-{mod}.tm"
tm = findMaxDt(tm_fnm)
rows = collect(comm, mod, tm)
if len(rows) == 0:return
tmf = open(tm_fnm, "a+")
for _, r in rows.iterrows():
build(r)
tmf.write(str(r["updated_at"]) + "\n")
tmf.close()
if __name__ == "__main__":
from mpi4py import MPI
comm = MPI.COMM_WORLD
main(comm.Get_size(), comm.Get_rank())

View File

@ -1,165 +0,0 @@
#-*- coding:utf-8 -*-
import sys, os, re,inspect,json,traceback,logging,argparse, copy
sys.path.append(os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe())))+"/../")
from tornado.web import RequestHandler,Application
from tornado.ioloop import IOLoop
from tornado.httpserver import HTTPServer
from tornado.options import define,options
from util import es_conn, setup_logging
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
from nlp import huqie
from nlp import query as Query
from nlp import search
from llm import HuEmbedding, GptTurbo
import numpy as np
from io import BytesIO
from util import config
from timeit import default_timer as timer
from collections import OrderedDict
from llm import ChatModel, EmbeddingModel
SE = None
CFIELD="content_ltks"
EMBEDDING = EmbeddingModel
LLM = ChatModel
def get_QA_pairs(hists):
pa = []
for h in hists:
for k in ["user", "assistant"]:
if h.get(k):
pa.append({
"content": h[k],
"role": k,
})
for p in pa[:-1]: assert len(p) == 2, p
return pa
def get_instruction(sres, top_i, max_len=8096, fld="content_ltks"):
max_len //= len(top_i)
# add instruction to prompt
instructions = [re.sub(r"[\r\n]+", " ", sres.field[sres.ids[i]][fld]) for i in top_i]
if len(instructions)>2:
# Said that LLM is sensitive to the first and the last one, so
# rearrange the order of references
instructions.append(copy.deepcopy(instructions[1]))
instructions.pop(1)
def token_num(txt):
c = 0
for tk in re.split(r"[,。/?‘’”“:;:;!]", txt):
if re.match(r"[a-zA-Z-]+$", tk):
c += 1
continue
c += len(tk)
return c
_inst = ""
for ins in instructions:
if token_num(_inst) > 4096:
_inst += "\n知识库:" + instructions[-1][:max_len]
break
_inst += "\n知识库:" + ins[:max_len]
return _inst
def prompt_and_answer(history, inst):
hist = get_QA_pairs(history)
chks = []
for s in re.split(r"[:;;。\n\r]+", inst):
if s: chks.append(s)
chks = len(set(chks))/(0.1+len(chks))
print("Duplication portion:", chks)
system = """
你是一个智能助手请总结知识库的内容来回答问题请列举知识库中的数据详细回答%s当所有知识库内容都与问题无关时你的回答必须包括"知识库中未找到您要的答案!这是我所知道的,仅作参考。"这句话回答需要考虑聊天历史
以下是知识库
%s
以上是知识库
"""%((",最好总结成表格" if chks<0.6 and chks>0 else ""), inst)
print("【PROMPT】:", system)
start = timer()
response = LLM.chat(system, hist, {"temperature": 0.2, "max_tokens": 512})
print("GENERATE: ", timer()-start)
print("===>>", response)
return response
class Handler(RequestHandler):
def post(self):
global SE,MUST_TK_NUM
param = json.loads(self.request.body.decode('utf-8'))
try:
question = param.get("history",[{"user": "Hi!"}])[-1]["user"]
res = SE.search({
"question": question,
"kb_ids": param.get("kb_ids", []),
"size": param.get("topn", 15)},
search.index_name(param["uid"])
)
sim = SE.rerank(res, question)
rk_idx = np.argsort(sim*-1)
topidx = [i for i in rk_idx if sim[i] >= param.get("similarity", 0.5)][:param.get("topn",12)]
inst = get_instruction(res, topidx)
ans = prompt_and_answer(param["history"], inst)
ans = SE.insert_citations(ans, topidx, res)
refer = OrderedDict()
docnms = {}
for i in rk_idx:
did = res.field[res.ids[i]]["doc_id"]
if did not in docnms: docnms[did] = res.field[res.ids[i]]["docnm_kwd"]
if did not in refer: refer[did] = []
refer[did].append({
"chunk_id": res.ids[i],
"content": res.field[res.ids[i]]["content_ltks"],
"image": ""
})
print("::::::::::::::", ans)
self.write(json.dumps({
"code":0,
"msg":"success",
"data":{
"uid": param["uid"],
"dialog_id": param["dialog_id"],
"assistant": ans,
"refer": [{
"did": did,
"doc_name": docnms[did],
"chunks": chunks
} for did, chunks in refer.items()]
}
}))
logging.info("SUCCESS[%d]"%(res.total)+json.dumps(param, ensure_ascii=False))
except Exception as e:
logging.error("Request 500: "+str(e))
self.write(json.dumps({
"code":500,
"msg":str(e),
"data":{}
}))
print(traceback.format_exc())
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--port", default=4455, type=int, help="Port used for service")
ARGS = parser.parse_args()
SE = search.Dealer(es_conn.HuEs("infiniflow"), EMBEDDING)
app = Application([(r'/v1/chat/completions', Handler)],debug=False)
http_server = HTTPServer(app)
http_server.bind(ARGS.port)
http_server.start(3)
IOLoop.current().start()
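
For reference, a hedged sketch of how a client might call this /v1/chat/completions handler. The field names mirror what the handler reads (uid, dialog_id, kb_ids, topn, history); the host, port and ids are placeholders.

```python
# Hypothetical client call; port 4455 is the handler's default --port.
import json
import requests

payload = {
    "uid": 1,                    # placeholder user id
    "dialog_id": 2,              # placeholder dialog id
    "kb_ids": [3],               # knowledge bases to search
    "topn": 12,                  # how many chunks to keep after rerank
    "history": [{"user": "What does the knowledge base say about pricing?"}],
}
r = requests.post("http://127.0.0.1:4455/v1/chat/completions", data=json.dumps(payload))
print(r.json()["data"]["assistant"])
```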

View File

@ -1,258 +0,0 @@
import json, os, sys, hashlib, copy, time, random, re
from os.path import dirname, realpath
sys.path.append(dirname(realpath(__file__)) + "/../")
from util.es_conn import HuEs
from util.db_conn import Postgres
from util.minio_conn import HuMinio
from util import rmSpace, findMaxDt
from FlagEmbedding import FlagModel
from nlp import huchunk, huqie, search
from io import BytesIO
import pandas as pd
from elasticsearch_dsl import Q
from PIL import Image
from parser import (
PdfParser,
DocxParser,
ExcelParser
)
from nlp.huchunk import (
PdfChunker,
DocxChunker,
ExcelChunker,
PptChunker,
TextChunker
)
ES = HuEs("infiniflow")
BATCH_SIZE = 64
PG = Postgres("infiniflow", "docgpt")
MINIO = HuMinio("infiniflow")
PDF = PdfChunker(PdfParser())
DOC = DocxChunker(DocxParser())
EXC = ExcelChunker(ExcelParser())
PPT = PptChunker()
def chuck_doc(name, binary):
suff = os.path.split(name)[-1].lower().split(".")[-1]
if suff.find("pdf") >= 0: return PDF(binary)
if suff.find("doc") >= 0: return DOC(binary)
if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(binary)
if suff.find("ppt") >= 0: return PPT(binary)
if os.environ.get("PARSE_IMAGE") \
and re.search(r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$",
name.lower()):
from llm import CvModel
txt = CvModel.describe(binary)
field = TextChunker.Fields()
field.text_chunks = [(txt, binary)]
field.table_chunks = []
return TextChunker()(binary)
def collect(comm, mod, tm):
sql = f"""
select
id as kb2doc_id,
kb_id,
did,
updated_at,
is_deleted
from kb2_doc
where
updated_at >= '{tm}'
and kb_progress = 0
and MOD(did, {comm}) = {mod}
order by updated_at asc
limit 1000
"""
kb2doc = PG.select(sql)
if len(kb2doc) == 0:return pd.DataFrame()
sql = """
select
did,
uid,
doc_name,
location,
size
from doc_info
where
did in (%s)
"""%",".join([str(i) for i in kb2doc["did"].unique()])
docs = PG.select(sql)
docs = docs.fillna("")
docs = docs.join(kb2doc.set_index("did"), on="did", how="left")
mtm = str(docs["updated_at"].max())[:19]
print("TOTAL:", len(docs), "To: ", mtm)
return docs
def set_progress(kb2doc_id, prog, msg="Processing..."):
sql = f"""
update kb2_doc set kb_progress={prog}, kb_progress_msg='{msg}'
where
id={kb2doc_id}
"""
PG.update(sql)
def build(row):
if row["size"] > 256000000:
set_progress(row["kb2doc_id"], -1, "File size exceeds( <= 256Mb )")
return []
res = ES.search(Q("term", doc_id=row["did"]))
if ES.getTotal(res) > 0:
ES.updateScriptByQuery(Q("term", doc_id=row["did"]),
scripts="""
if(!ctx._source.kb_id.contains('%s'))
ctx._source.kb_id.add('%s');
"""%(str(row["kb_id"]), str(row["kb_id"])),
idxnm = search.index_name(row["uid"])
)
set_progress(row["kb2doc_id"], 1, "Done")
return []
random.seed(time.time())
set_progress(row["kb2doc_id"], random.randint(0, 20)/100., "Finished preparing! Start to slice file!")
try:
obj = chuck_doc(row["doc_name"], MINIO.get("%s-upload"%str(row["uid"]), row["location"]))
except Exception as e:
if re.search("(No such file|not found)", str(e)):
set_progress(row["kb2doc_id"], -1, "Can not find file <%s>"%row["doc_name"])
else:
set_progress(row["kb2doc_id"], -1, f"Internal system error: %s"%str(e).replace("'", ""))
return []
if not obj.text_chunks and not obj.table_chunks:
set_progress(row["kb2doc_id"], 1, "Nothing added! Mostly, file type unsupported yet.")
return []
set_progress(row["kb2doc_id"], random.randint(20, 60)/100., "Finished slicing files. Start to embedding the content.")
doc = {
"doc_id": row["did"],
"kb_id": [str(row["kb_id"])],
"docnm_kwd": os.path.split(row["location"])[-1],
"title_tks": huqie.qie(os.path.split(row["location"])[-1]),
"updated_at": str(row["updated_at"]).replace("T", " ")[:19]
}
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
output_buffer = BytesIO()
docs = []
md5 = hashlib.md5()
for txt, img in obj.text_chunks:
d = copy.deepcopy(doc)
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest()
d["content_ltks"] = huqie.qie(txt)
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
if not img:
docs.append(d)
continue
if isinstance(img, Image): img.save(output_buffer, format='JPEG')
else: output_buffer = BytesIO(img)
MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"],
output_buffer.getvalue())
d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"])
docs.append(d)
for arr, img in obj.table_chunks:
for i, txt in enumerate(arr):
d = copy.deepcopy(doc)
d["content_ltks"] = huqie.qie(txt)
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest()
if not img:
docs.append(d)
continue
img.save(output_buffer, format='JPEG')
MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"],
output_buffer.getvalue())
d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"])
docs.append(d)
set_progress(row["kb2doc_id"], random.randint(60, 70)/100., "Continue embedding the content.")
return docs
def init_kb(row):
idxnm = search.index_name(row["uid"])
if ES.indexExist(idxnm): return
return ES.createIdx(idxnm, json.load(open("conf/mapping.json", "r")))
model = None
def embedding(docs):
global model
tts = model.encode([rmSpace(d["title_tks"]) for d in docs])
cnts = model.encode([rmSpace(d["content_ltks"]) for d in docs])
vects = 0.1 * tts + 0.9 * cnts
assert len(vects) == len(docs)
for i,d in enumerate(docs):d["q_vec"] = vects[i].tolist()
def rm_doc_from_kb(df):
if len(df) == 0:return
for _,r in df.iterrows():
ES.updateScriptByQuery(Q("term", doc_id=r["did"]),
scripts="""
if(ctx._source.kb_id.contains('%s'))
ctx._source.kb_id.remove(
ctx._source.kb_id.indexOf('%s')
);
"""%(str(r["kb_id"]),str(r["kb_id"])),
idxnm = search.index_name(r["uid"])
)
if len(df) == 0:return
sql = """
delete from kb2_doc where id in (%s)
"""%",".join([str(i) for i in df["kb2doc_id"]])
PG.update(sql)
def main(comm, mod):
global model
from llm import HuEmbedding
model = HuEmbedding()
tm_fnm = f"res/{comm}-{mod}.tm"
tm = findMaxDt(tm_fnm)
rows = collect(comm, mod, tm)
if len(rows) == 0:return
rm_doc_from_kb(rows.loc[rows.is_deleted == True])
rows = rows.loc[rows.is_deleted == False].reset_index(drop=True)
if len(rows) == 0:return
tmf = open(tm_fnm, "a+")
for _, r in rows.iterrows():
cks = build(r)
if not cks:
tmf.write(str(r["updated_at"]) + "\n")
continue
## TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ")
embedding(cks)
set_progress(r["kb2doc_id"], random.randint(70, 95)/100.,
"Finished embedding! Start to build index!")
init_kb(r)
es_r = ES.bulk(cks, search.index_name(r["uid"]))
if es_r:
set_progress(r["kb2doc_id"], -1, "Index failure!")
print(es_r)
else: set_progress(r["kb2doc_id"], 1., "Done!")
tmf.write(str(r["updated_at"]) + "\n")
tmf.close()
if __name__ == "__main__":
from mpi4py import MPI
comm = MPI.COMM_WORLD
main(comm.Get_size(), comm.Get_rank())

15
python/tmp.log Normal file
View File

@ -0,0 +1,15 @@
Fetching 6 files: 0%| | 0/6 [00:00<?, ?it/s] Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 106184.91it/s]
----------- Model Configuration -----------
Model Arch: GFL
Transform Order:
--transform op: Resize
--transform op: NormalizeImage
--transform op: Permute
--transform op: PadStride
--------------------------------------------
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
The `max_size` parameter is deprecated and will be removed in v4.26. Please specify in `size['longest_edge'] instead`.
Some weights of the model checkpoint at microsoft/table-transformer-structure-recognition were not used when initializing TableTransformerForObjectDetection: ['model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
- This IS expected if you are initializing TableTransformerForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TableTransformerForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
WARNING:root:The files are stored in /opt/home/kevinhu/docgpt/, please check it!

View File

@ -1,31 +0,0 @@
from configparser import ConfigParser
import os
import inspect
CF = ConfigParser()
__fnm = os.path.join(os.path.dirname(__file__), '../conf/sys.cnf')
if not os.path.exists(__fnm):
__fnm = os.path.join(os.path.dirname(__file__), '../../conf/sys.cnf')
assert os.path.exists(
__fnm), f"【EXCEPTION】can't find {__fnm}." + os.path.dirname(__file__)
if not os.path.exists(__fnm):
__fnm = "./sys.cnf"
CF.read(__fnm)
class Config:
def __init__(self, env):
self.env = env
if env == "spark":
CF.read("./cv.cnf")
def get(self, key, default=None):
global CF
return os.environ.get(key.upper(),
CF[self.env].get(key, default)
)
def init(env):
return Config(env)

View File

@ -1,70 +0,0 @@
import logging
import time
from util import config
import pandas as pd
class Postgres(object):
def __init__(self, env, dbnm):
self.config = config.init(env)
self.conn = None
self.dbnm = dbnm
self.__open__()
def __open__(self):
import psycopg2
try:
if self.conn:
self.__close__()
del self.conn
except Exception as e:
pass
try:
self.conn = psycopg2.connect(f"""dbname={self.dbnm}
user={self.config.get('postgres_user')}
password={self.config.get('postgres_password')}
host={self.config.get('postgres_host')}
port={self.config.get('postgres_port')}""")
except Exception as e:
logging.error(
"Fail to connect %s " %
self.config.get("pgdb_host") + str(e))
def __close__(self):
try:
self.conn.close()
except Exception as e:
logging.error(
"Fail to close %s " %
self.config.get("pgdb_host") + str(e))
def select(self, sql):
for _ in range(10):
try:
return pd.read_sql(sql, self.conn)
except Exception as e:
logging.error(f"Fail to exec {sql} " + str(e))
self.__open__()
time.sleep(1)
return pd.DataFrame()
def update(self, sql):
for _ in range(10):
try:
cur = self.conn.cursor()
cur.execute(sql)
updated_rows = cur.rowcount
self.conn.commit()
cur.close()
return updated_rows
except Exception as e:
logging.error(f"Fail to exec {sql} " + str(e))
self.__open__()
time.sleep(1)
return 0
if __name__ == "__main__":
Postgres("infiniflow", "docgpt")

View File

@ -1,36 +0,0 @@
import json
import logging.config
import os
def log_dir():
fnm = os.path.join(os.path.dirname(__file__), '../log/')
if not os.path.exists(fnm):
fnm = os.path.join(os.path.dirname(__file__), '../../log/')
assert os.path.exists(fnm), f"Can't locate log dir: {fnm}"
return fnm
def setup_logging(default_path="conf/logging.json",
default_level=logging.INFO,
env_key="LOG_CFG"):
path = default_path
value = os.getenv(env_key, None)
if value:
path = value
if os.path.exists(path):
with open(path, "r") as f:
config = json.load(f)
fnm = log_dir()
config["handlers"]["info_file_handler"]["filename"] = fnm + "info.log"
config["handlers"]["error_file_handler"]["filename"] = fnm + "error.log"
logging.config.dictConfig(config)
else:
logging.basicConfig(level=default_level)
__fnm = os.path.join(os.path.dirname(__file__), 'conf/logging.json')
if not os.path.exists(__fnm):
__fnm = os.path.join(os.path.dirname(__file__), '../../conf/logging.json')
setup_logging(__fnm)

0
rag/__init__.py Normal file
View File

32
rag/llm/__init__.py Normal file
View File

@ -0,0 +1,32 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .embedding_model import *
from .chat_model import *
from .cv_model import *
EmbeddingModel = {
"local": HuEmbedding,
"OpenAI": OpenAIEmbed,
"通义千问": QWenEmbed,
}
CvModel = {
"OpenAI": GptV4,
"通义千问": QWenCV,
}
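
A hedged usage sketch for these factory maps (not part of the commit): callers look up a class by supplier name and construct it. The supplier string and API key below are placeholders.

```python
# Hypothetical caller: "OpenAI" and "通义千问" take an API key, "local" does not.
supplier = "OpenAI"                     # placeholder supplier name
api_key = "sk-..."                      # placeholder key
if supplier == "local":
    embd_mdl = EmbeddingModel[supplier]()
else:
    embd_mdl = EmbeddingModel[supplier](api_key)
vectors, used_tokens = embd_mdl.encode(["hello", "world"])
```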

View File

@ -1,3 +1,18 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
from openai import OpenAI
import os

View File

@ -1,3 +1,18 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
from openai import OpenAI
import os
@ -6,6 +21,9 @@ from io import BytesIO
class Base(ABC):
+    def __init__(self, key, model_name):
+        pass

    def describe(self, image, max_tokens=300):
        raise NotImplementedError("Please implement encode method!")
@ -40,14 +58,15 @@ class Base(ABC):
class GptV4(Base):
-    def __init__(self):
-        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+    def __init__(self, key, model_name="gpt-4-vision-preview"):
+        self.client = OpenAI(key)
+        self.model_name = model_name

    def describe(self, image, max_tokens=300):
        b64 = self.image2base64(image)
        res = self.client.chat.completions.create(
-            model="gpt-4-vision-preview",
+            model=self.model_name,
            messages=self.prompt(b64),
            max_tokens=max_tokens,
        )
@ -55,11 +74,15 @@ class GptV4(Base):
class QWenCV(Base):
+    def __init__(self, key, model_name="qwen-vl-chat-v1"):
+        import dashscope
+        dashscope.api_key = key
+        self.model_name = model_name

    def describe(self, image, max_tokens=300):
        from http import HTTPStatus
        from dashscope import MultiModalConversation
-        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
-        response = MultiModalConversation.call(model=MultiModalConversation.Models.qwen_vl_chat_v1,
+        response = MultiModalConversation.call(model=self.model_name,
                                                messages=self.prompt(self.image2base64(image)))
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0]['message']['content']
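
A hedged usage sketch of the reworked constructors above, which now take the API key directly instead of reading environment variables; the key values and image path are placeholders.

```python
# Hypothetical usage; "sample.jpg" and the keys are placeholders.
from PIL import Image

img = Image.open("sample.jpg")
gpt_cv = GptV4(key="sk-...")                 # OpenAI-backed captioning
print(gpt_cv.describe(img, max_tokens=300))

qwen_cv = QWenCV(key="dashscope-key")        # DashScope-backed captioning
print(qwen_cv.describe(img))
```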

View File

@ -0,0 +1,94 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
import dashscope
from openai import OpenAI
from FlagEmbedding import FlagModel
import torch
import os
import numpy as np
from rag.utils import num_tokens_from_string
class Base(ABC):
def __init__(self, key, model_name):
pass
def encode(self, texts: list, batch_size=32):
raise NotImplementedError("Please implement encode method!")
class HuEmbedding(Base):
def __init__(self):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
For Linux:
export HF_ENDPOINT=https://hf-mirror.com
For Windows:
Good luck
^_-
"""
self.model = FlagModel("BAAI/bge-large-zh-v1.5",
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
def encode(self, texts: list, batch_size=32):
token_count = 0
for t in texts: token_count += num_tokens_from_string(t)
res = []
for i in range(0, len(texts), batch_size):
res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
return np.array(res), token_count
class OpenAIEmbed(Base):
def __init__(self, key, model_name="text-embedding-ada-002"):
self.client = OpenAI(key)
self.model_name = model_name
def encode(self, texts: list, batch_size=32):
token_count = 0
for t in texts: token_count += num_tokens_from_string(t)
res = self.client.embeddings.create(input=texts,
model=self.model_name)
return [d["embedding"] for d in res["data"]], token_count
class QWenEmbed(Base):
def __init__(self, key, model_name="text_embedding_v2"):
dashscope.api_key = key
self.model_name = model_name
def encode(self, texts: list, batch_size=32, text_type="document"):
import dashscope
res = []
token_count = 0
for txt in texts:
resp = dashscope.TextEmbedding.call(
model=self.model_name,
input=txt[:2048],
text_type=text_type
)
res.append(resp["output"]["embeddings"][0]["embedding"])
token_count += resp["usage"]["total_tokens"]
return res, token_count

0
rag/nlp/__init__.py Normal file
View File

11
python/nlp/huqie.py → rag/nlp/huqie.py Executable file → Normal file
View File

@ -9,6 +9,8 @@ import string
import sys
from hanziconv import HanziConv
+from web_server.utils.file_utils import get_project_base_directory

class Huqie:
    def key_(self, line):
@ -41,14 +43,7 @@ class Huqie:
        self.DEBUG = debug
        self.DENOMINATOR = 1000000
        self.trie_ = datrie.Trie(string.printable)
-        self.DIR_ = ""
-        if os.path.exists("../res/huqie.txt"):
-            self.DIR_ = "../res/huqie"
-        if os.path.exists("./res/huqie.txt"):
-            self.DIR_ = "./res/huqie"
-        if os.path.exists("./huqie.txt"):
-            self.DIR_ = "./huqie"
-        assert self.DIR_, f"【Can't find huqie】"
+        self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")

        self.SPLIT_CHAR = r"([ ,\.<>/?;'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)"
        try:

6
python/nlp/query.py → rag/nlp/query.py Executable file → Normal file
View File

@ -1,12 +1,12 @@
+# -*- coding: utf-8 -*-
import json
import re
-import sys
-import os
import logging
import copy
import math
from elasticsearch_dsl import Q, Search
-from nlp import huqie, term_weight, synonym
+from rag.nlp import huqie, term_weight, synonym


class EsQueryer:
View File

@ -1,13 +1,11 @@
+# -*- coding: utf-8 -*-
import re
from elasticsearch_dsl import Q, Search, A
from typing import List, Optional, Tuple, Dict, Union
from dataclasses import dataclass
-from util import setup_logging, rmSpace
+from rag.utils import rmSpace
-from nlp import huqie, query
+from rag.nlp import huqie, query
-from datetime import datetime
-from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
import numpy as np
-from copy import deepcopy


def index_name(uid): return f"docgpt_{uid}"

13
python/nlp/synonym.py → rag/nlp/synonym.py Executable file → Normal file
View File

@ -1,8 +1,11 @@
import json
+import os
import time
import logging
import re

+from web_server.utils.file_utils import get_project_base_directory


class Dealer:
    def __init__(self, redis=None):
@ -10,15 +13,9 @@ class Dealer:
        self.lookup_num = 100000000
        self.load_tm = time.time() - 1000000
        self.dictionary = None
+        path = os.path.join(get_project_base_directory(), "rag/res", "synonym.json")
        try:
-            self.dictionary = json.load(open("./synonym.json", 'r'))
-        except Exception as e:
-            pass
-        try:
-            self.dictionary = json.load(open("./res/synonym.json", 'r'))
-        except Exception as e:
-            try:
-                self.dictionary = json.load(open("../res/synonym.json", 'r'))
+            self.dictionary = json.load(open(path, 'r'))
        except Exception as e:
            logging.warn("Miss synonym.json")
            self.dictionary = {}

12
python/nlp/term_weight.py → rag/nlp/term_weight.py Executable file → Normal file
View File

@ -1,9 +1,11 @@
# -*- coding: utf-8 -*-
import math import math
import json import json
import re import re
import os import os
import numpy as np import numpy as np
from nlp import huqie from rag.nlp import huqie
from web_server.utils.file_utils import get_project_base_directory
class Dealer: class Dealer:
@ -60,16 +62,14 @@ class Dealer:
return set(res.keys()) return set(res.keys())
return res return res
fnm = os.path.join(os.path.dirname(__file__), '../res/') fnm = os.path.join(get_project_base_directory(), "res")
if not os.path.exists(fnm):
fnm = os.path.join(os.path.dirname(__file__), '../../res/')
self.ne, self.df = {}, {} self.ne, self.df = {}, {}
try: try:
self.ne = json.load(open(fnm + "ner.json", "r")) self.ne = json.load(open(os.path.join(fnm, "ner.json"), "r"))
except Exception as e: except Exception as e:
print("[WARNING] Load ner.json FAIL!") print("[WARNING] Load ner.json FAIL!")
try: try:
self.df = load_dict(fnm + "term.freq") self.df = load_dict(os.path.join(fnm, "term.freq"))
except Exception as e: except Exception as e:
print("[WARNING] Load term.freq FAIL!") print("[WARNING] Load term.freq FAIL!")

View File

@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
from docx import Document from docx import Document
import re import re
import pandas as pd import pandas as pd
from collections import Counter from collections import Counter
from nlp import huqie from rag.nlp import huqie
from io import BytesIO from io import BytesIO

View File

@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
from openpyxl import load_workbook
import sys
from io import BytesIO
class HuExcelParser:
def __call__(self, fnm):
if isinstance(fnm, str):
wb = load_workbook(fnm)
else:
wb = load_workbook(BytesIO(fnm))
res = []
for sheetname in wb.sheetnames:
ws = wb[sheetname]
rows = list(ws.rows)
ti = list(rows[0])
for r in list(rows[1:]):
l = []
for i,c in enumerate(r):
if not c.value:continue
t = str(ti[i].value) if i < len(ti) else ""
t += ("" if t else "") + str(c.value)
l.append(t)
l = "; ".join(l)
if sheetname.lower().find("sheet") <0: l += " ——"+sheetname
res.append(l)
return res
if __name__ == "__main__":
psr = HuExcelParser()
psr(sys.argv[1])
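For reference, a hedged worked example of what one worksheet row becomes after the joins above (toy data, not from this commit): a header row ["item", "cost"], a data row ["apple", 3], and a sheet named "Prices" yield roughly "item: apple; cost: 3 ——Prices", since the sheet name is appended whenever it is not a default "SheetN" name.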

View File

@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import xgboost as xgb import xgboost as xgb
from io import BytesIO from io import BytesIO
import torch import torch
@ -6,11 +7,11 @@ import pdfplumber
import logging import logging
from PIL import Image from PIL import Image
import numpy as np import numpy as np
from nlp import huqie from rag.nlp import huqie
from collections import Counter from collections import Counter
from copy import deepcopy from copy import deepcopy
from cv.table_recognize import TableTransformer from rag.cv.table_recognize import TableTransformer
from cv.ppdetection import PPDet from rag.cv.ppdetection import PPDet
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
logging.getLogger("pdfminer").setLevel(logging.WARNING) logging.getLogger("pdfminer").setLevel(logging.WARNING)

0
python/res/ner.json → rag/res/ner.json Executable file → Normal file
View File

37
rag/settings.py Normal file
View File

@ -0,0 +1,37 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from web_server.utils import get_base_config, decrypt_database_config
from web_server.utils.file_utils import get_project_base_directory
from web_server.utils.log_utils import LoggerFactory, getLogger
# Server
RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
SUBPROCESS_STD_LOG_NAME = "std.log"
ES = get_base_config("es", {})
MINIO = decrypt_database_config(name="minio")
DOC_MAXIMUM_SIZE = 64 * 1024 * 1024
# Logger
LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "rag"))
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
LoggerFactory.LEVEL = 10
es_logger = getLogger("es")
minio_logger = getLogger("minio")
cron_logger = getLogger("cron_logger")

279
rag/svr/parse_user_docs.py Normal file
View File

@ -0,0 +1,279 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import hashlib
import copy
import time
import random
import re
from timeit import default_timer as timer
from rag.llm import EmbeddingModel, CvModel
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import ELASTICSEARCH, num_tokens_from_string
from rag.utils import MINIO
from rag.utils import rmSpace, findMaxDt
from rag.nlp import huchunk, huqie, search
from io import BytesIO
import pandas as pd
from elasticsearch_dsl import Q
from PIL import Image
from rag.parser import (
PdfParser,
DocxParser,
ExcelParser
)
from rag.nlp.huchunk import (
PdfChunker,
DocxChunker,
ExcelChunker,
PptChunker,
TextChunker
)
from web_server.db import LLMType
from web_server.db.services.document_service import DocumentService
from web_server.db.services.llm_service import TenantLLMService
from web_server.utils import get_format_time
from web_server.utils.file_utils import get_project_base_directory
BATCH_SIZE = 64
PDF = PdfChunker(PdfParser())
DOC = DocxChunker(DocxParser())
EXC = ExcelChunker(ExcelParser())
PPT = PptChunker()
def chuck_doc(name, binary, cvmdl=None):
suff = os.path.split(name)[-1].lower().split(".")[-1]
if suff.find("pdf") >= 0:
return PDF(binary)
if suff.find("doc") >= 0:
return DOC(binary)
if re.match(r"(xlsx|xlsm|xltx|xltm)", suff):
return EXC(binary)
if suff.find("ppt") >= 0:
return PPT(binary)
if cvmdl and re.search(r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$",
name.lower()):
txt = cvmdl.describe(binary)
field = TextChunker.Fields()
field.text_chunks = [(txt, binary)]
field.table_chunks = []
return TextChunker()(binary)
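A hedged dispatch sketch for chuck_doc above; the file names and the `binary` bytes are hypothetical.

    fields = chuck_doc("report.pdf", binary)           # routed to the PDF chunker
    fields = chuck_doc("sheet.xlsx", binary)           # routed to the Excel chunker
    fields = chuck_doc("photo.png", binary, cvmdl)     # captioned by the image-to-text model
    fields = chuck_doc("notes.txt", binary)            # falls through to the plain TextChunker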
def collect(comm, mod, tm):
docs = DocumentService.get_newly_uploaded(tm, mod, comm)
if len(docs) == 0:
return pd.DataFrame()
docs = pd.DataFrame(docs)
mtm = str(docs["update_time"].max())[:19]
cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
return docs
def set_progress(docid, prog, msg="Processing...", begin=False):
d = {"progress": prog, "progress_msg": msg}
if begin:
d["process_begin_at"] = get_format_time()
try:
DocumentService.update_by_id(
docid, {"progress": prog, "progress_msg": msg})
except Exception as e:
cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
def build(row, cvmdl=None):
if row["size"] > DOC_MAXIMUM_SIZE:
set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
return []
res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
if ELASTICSEARCH.getTotal(res) > 0:
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
scripts="""
if(!ctx._source.kb_id.contains('%s'))
ctx._source.kb_id.add('%s');
""" % (str(row["kb_id"]), str(row["kb_id"])),
idxnm=search.index_name(row["tenant_id"])
)
set_progress(row["id"], 1, "Done")
return []
random.seed(time.time())
set_progress(row["id"], random.randint(0, 20) /
100., "Finished preparing! Start to slice file!", True)
try:
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]))
except Exception as e:
if re.search("(No such file|not found)", str(e)):
set_progress(
row["id"], -1, "Cannot find file <%s>" %
row["name"])
else:
set_progress(
row["id"], -1, "Internal server error: %s" %
str(e).replace("'", ""))
return []
if not obj.text_chunks and not obj.table_chunks:
set_progress(
row["id"],
1,
"Nothing added! Mostly, file type unsupported yet.")
return []
set_progress(row["id"], random.randint(20, 60) / 100.,
"Finished slicing files. Start to embedding the content.")
doc = {
"doc_id": row["did"],
"kb_id": [str(row["kb_id"])],
"docnm_kwd": os.path.split(row["location"])[-1],
"title_tks": huqie.qie(row["name"]),
"updated_at": str(row["update_time"]).replace("T", " ")[:19]
}
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
output_buffer = BytesIO()
docs = []
md5 = hashlib.md5()
for txt, img in obj.text_chunks:
d = copy.deepcopy(doc)
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest()
d["content_ltks"] = huqie.qie(txt)
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
if not img:
docs.append(d)
continue
if isinstance(img, Image.Image):  # Image is the PIL module; the image class is Image.Image
output_buffer = BytesIO()  # fresh buffer per chunk so earlier JPEG bytes do not accumulate
img.save(output_buffer, format='JPEG')
else:
output_buffer = BytesIO(img)
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
docs.append(d)
for arr, img in obj.table_chunks:
for i, txt in enumerate(arr):
d = copy.deepcopy(doc)
d["content_ltks"] = huqie.qie(txt)
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest()
if not img:
docs.append(d)
continue
output_buffer = BytesIO()  # fresh buffer per table image as well
img.save(output_buffer, format='JPEG')
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
docs.append(d)
set_progress(row["id"], random.randint(60, 70) /
100., "Continue embedding the content.")
return docs
def init_kb(row):
idxnm = search.index_name(row["tenant_id"])
if ELASTICSEARCH.indexExist(idxnm):
return
return ELASTICSEARCH.createIdx(idxnm, json.load(
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
def embedding(docs, mdl):
tts, cnts = [rmSpace(d["title_tks"]) for d in docs], [rmSpace(d["content_ltks"]) for d in docs]
tk_count = 0
tts, c = mdl.encode(tts)
tk_count += c
cnts, c = mdl.encode(cnts)
tk_count += c
import numpy as np  # local import; the encoders above may return plain Python lists
vects = 0.1 * np.array(tts) + 0.9 * np.array(cnts)
assert len(vects) == len(docs)
for i, d in enumerate(docs):
d["q_vec"] = vects[i].tolist()
return tk_count
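The weighted sum above biases the stored vector toward the content embedding (0.9) over the title embedding (0.1). A toy numeric sketch of the same arithmetic:

    import numpy as np

    title_vec   = np.array([[1.0, 0.0]])               # toy 2-d embeddings
    content_vec = np.array([[0.0, 1.0]])
    q_vec = 0.1 * title_vec + 0.9 * content_vec         # -> [[0.1, 0.9]], stored per chunk as d["q_vec"]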
def model_instance(tenant_id, llm_type):
model_config = TenantLLMService.query(tenant_id=tenant_id, model_type=llm_type)
if not model_config:return
model_config = model_config[0]
if llm_type == LLMType.EMBEDDING:
if model_config.llm_factory not in EmbeddingModel: return
return EmbeddingModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
if llm_type == LLMType.IMAGE2TEXT:
if model_config.llm_factory not in CvModel: return
return CvModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
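A hedged sketch of how the factory above is typically called; the tenant id and image bytes are hypothetical.

    embd_mdl = model_instance("tenant-42", LLMType.EMBEDDING)     # hypothetical tenant id
    if embd_mdl:
        vects, tokens = embd_mdl.encode(["some chunk of text"])
    cv_mdl = model_instance("tenant-42", LLMType.IMAGE2TEXT)
    if cv_mdl:
        caption = cv_mdl.describe(image_bytes)                    # image_bytes: raw picture bytes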
def main(comm, mod):
global model
from rag.llm import HuEmbedding
model = HuEmbedding()
tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
tm = findMaxDt(tm_fnm)
rows = collect(comm, mod, tm)
if len(rows) == 0:
return
tmf = open(tm_fnm, "a+")
for _, r in rows.iterrows():
embd_mdl = model_instance(r["tenant_id"], LLMType.EMBEDDING)
if not embd_mdl:
set_progress(r["id"], -1, "Can't find embedding model!")
cron_logger.error("Tenant({}) can't find embedding model!".format(r["tenant_id"]))
continue
cv_mdl = model_instance(r["tenant_id"], LLMType.IMAGE2TEXT)
st_tm = timer()
cks = build(r, cv_mdl)
if not cks:
tmf.write(str(r["updated_at"]) + "\n")
continue
# TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ")
try:
tk_count = embedding(cks, embd_mdl)
except Exception as e:
set_progress(r["id"], -1, "Embedding error:{}".format(str(e)))
cron_logger.error(str(e))
continue
set_progress(r["id"], random.randint(70, 95) / 100.,
"Finished embedding! Start to build index!")
init_kb(r)
es_r = ELASTICSEARCH.bulk(cks, search.index_name(r["tenant_id"]))
if es_r:
set_progress(r["id"], -1, "Index failure!")
cron_logger.error(str(es_r))
else:
set_progress(r["id"], 1., "Done!")
DocumentService.update_by_id(r["id"], {"token_num": tk_count, "chunk_num": len(cks), "process_duation": timer()-st_tm})
tmf.write(str(r["update_time"]) + "\n")
tmf.close()
if __name__ == "__main__":
from mpi4py import MPI
comm = MPI.COMM_WORLD
main(comm.Get_size(), comm.Get_rank())
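Given the mpi4py entry point above, the worker is presumably launched with several ranks so that each rank handles its slice of the newly uploaded documents, e.g. `mpirun -n 4 python rag/svr/parse_user_docs.py`. This launch command is a hedged guess and is not documented in this commit.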

View File

@ -1,6 +1,23 @@
import os
import re import re
import tiktoken
def singleton(cls, *args, **kw):
instances = {}
def _singleton():
key = str(cls) + str(os.getpid())
if key not in instances:
instances[key] = cls(*args, **kw)
return instances[key]
return _singleton
from .minio_conn import MINIO
from .es_conn import ELASTICSEARCH
def rmSpace(txt): def rmSpace(txt):
txt = re.sub(r"([^a-z0-9.,]) +([^ ])", r"\1\2", txt) txt = re.sub(r"([^a-z0-9.,]) +([^ ])", r"\1\2", txt)
return re.sub(r"([^ ]) +([^a-z0-9.,])", r"\1\2", txt) return re.sub(r"([^ ]) +([^a-z0-9.,])", r"\1\2", txt)
@ -22,3 +39,9 @@ def findMaxDt(fnm):
except Exception as e: except Exception as e:
print("WARNING: can't find " + fnm) print("WARNING: can't find " + fnm)
return m return m
def num_tokens_from_string(string: str) -> int:
"""Returns the number of tokens in a text string."""
encoding = tiktoken.get_encoding('cl100k_base')
num_tokens = len(encoding.encode(string))
return num_tokens
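A hedged sketch of the two helpers added above: the per-process singleton decorator and the tiktoken-based token counter.

    from rag.utils import singleton, num_tokens_from_string

    @singleton
    class Conn:
        def __init__(self):
            self.opened = True

    assert Conn() is Conn()                          # one instance per process
    print(num_tokens_from_string("hello world"))     # token count under cl100k_base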

92
python/util/es_conn.py → rag/utils/es_conn.py Executable file → Normal file
View File

@ -1,51 +1,39 @@
import re import re
import logging
import json import json
import time import time
import copy import copy
import elasticsearch import elasticsearch
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
from elasticsearch_dsl import UpdateByQuery, Search, Index, Q from elasticsearch_dsl import UpdateByQuery, Search, Index
from util import config from rag.settings import es_logger
from rag import settings
from rag.utils import singleton
logging.info("Elasticsearch version: ", elasticsearch.__version__) es_logger.info("Elasticsearch version: "+ str(elasticsearch.__version__))
def instance(env):
CF = config.init(env)
ES_DRESS = CF.get("es").split(",")
ES = Elasticsearch(
ES_DRESS,
timeout=600
)
logging.info("ES: ", ES_DRESS, ES.info())
return ES
@singleton
class HuEs: class HuEs:
def __init__(self, env): def __init__(self):
self.env = env
self.info = {} self.info = {}
self.config = config.init(env)
self.conn() self.conn()
self.idxnm = self.config.get("idx_nm", "") self.idxnm = settings.ES.get("index_name", "")
if not self.es.ping(): if not self.es.ping():
raise Exception("Can't connect to ES cluster") raise Exception("Can't connect to ES cluster")
def conn(self): def conn(self):
for _ in range(10): for _ in range(10):
try: try:
c = instance(self.env) self.es = Elasticsearch(
if c: settings.ES["hosts"].split(","),
self.es = c timeout=600
self.info = c.info() )
logging.info("Connect to es.") if self.es:
self.info = self.es.info()
es_logger.info("Connect to es.")
break break
except Exception as e: except Exception as e:
logging.error("Fail to connect to es: " + str(e)) es_logger.error("Fail to connect to es: " + str(e))
time.sleep(1) time.sleep(1)
def version(self): def version(self):
@ -80,11 +68,11 @@ class HuEs:
refresh=False, refresh=False,
doc_type="_doc", doc_type="_doc",
retry_on_conflict=100) retry_on_conflict=100)
logging.info("Successfully upsert: %s" % id) es_logger.info("Successfully upsert: %s" % id)
T = True T = True
break break
except Exception as e: except Exception as e:
logging.warning("Fail to index: " + es_logger.warning("Fail to index: " +
json.dumps(d, ensure_ascii=False) + str(e)) json.dumps(d, ensure_ascii=False) + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE): if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3) time.sleep(3)
@ -94,7 +82,7 @@ class HuEs:
if not T: if not T:
res.append(d) res.append(d)
logging.error( es_logger.error(
"Fail to index: " + "Fail to index: " +
re.sub( re.sub(
"[\r\n]", "[\r\n]",
@ -147,7 +135,7 @@ class HuEs:
return res return res
except Exception as e: except Exception as e:
logging.warn("Fail to bulk: " + str(e)) es_logger.warn("Fail to bulk: " + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE): if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3) time.sleep(3)
continue continue
@ -162,7 +150,7 @@ class HuEs:
ids[id] = copy.deepcopy(d["raw"]) ids[id] = copy.deepcopy(d["raw"])
acts.append({"update": {"_id": id, "_index": self.idxnm}}) acts.append({"update": {"_id": id, "_index": self.idxnm}})
acts.append(d["script"]) acts.append(d["script"])
logging.info("bulk upsert: %s" % id) es_logger.info("bulk upsert: %s" % id)
res = [] res = []
for _ in range(10): for _ in range(10):
@ -189,7 +177,7 @@ class HuEs:
return res return res
except Exception as e: except Exception as e:
logging.warning("Fail to bulk: " + str(e)) es_logger.warning("Fail to bulk: " + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE): if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3) time.sleep(3)
continue continue
@ -212,10 +200,10 @@ class HuEs:
id=d["id"], id=d["id"],
refresh=True, refresh=True,
doc_type="_doc") doc_type="_doc")
logging.info("Remove %s" % d["id"]) es_logger.info("Remove %s" % d["id"])
return True return True
except Exception as e: except Exception as e:
logging.warn("Fail to delete: " + str(d) + str(e)) es_logger.warn("Fail to delete: " + str(d) + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE): if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3) time.sleep(3)
continue continue
@ -223,7 +211,7 @@ class HuEs:
return True return True
self.conn() self.conn()
logging.error("Fail to delete: " + str(d)) es_logger.error("Fail to delete: " + str(d))
return False return False
@ -242,7 +230,7 @@ class HuEs:
raise Exception("Es Timeout.") raise Exception("Es Timeout.")
return res return res
except Exception as e: except Exception as e:
logging.error( es_logger.error(
"ES search exception: " + "ES search exception: " +
str(e) + str(e) +
"【Q】" + "【Q】" +
@ -250,7 +238,7 @@ class HuEs:
if str(e).find("Timeout") > 0: if str(e).find("Timeout") > 0:
continue continue
raise e raise e
logging.error("ES search timeout for 3 times!") es_logger.error("ES search timeout for 3 times!")
raise Exception("ES search timeout.") raise Exception("ES search timeout.")
def updateByQuery(self, q, d): def updateByQuery(self, q, d):
@ -267,7 +255,7 @@ class HuEs:
r = ubq.execute() r = ubq.execute()
return True return True
except Exception as e: except Exception as e:
logging.error("ES updateByQuery exception: " + es_logger.error("ES updateByQuery exception: " +
str(e) + "【Q】" + str(q.to_dict())) str(e) + "【Q】" + str(q.to_dict()))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue continue
@ -288,7 +276,7 @@ class HuEs:
r = ubq.execute() r = ubq.execute()
return True return True
except Exception as e: except Exception as e:
logging.error("ES updateByQuery exception: " + es_logger.error("ES updateByQuery exception: " +
str(e) + "【Q】" + str(q.to_dict())) str(e) + "【Q】" + str(q.to_dict()))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue continue
@ -304,7 +292,7 @@ class HuEs:
body=Search().query(query).to_dict()) body=Search().query(query).to_dict())
return True return True
except Exception as e: except Exception as e:
logging.error("ES updateByQuery deleteByQuery: " + es_logger.error("ES updateByQuery deleteByQuery: " +
str(e) + "【Q】" + str(query.to_dict())) str(e) + "【Q】" + str(query.to_dict()))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue continue
@ -329,7 +317,8 @@ class HuEs:
routing=routing, refresh=False) # , doc_type="_doc") routing=routing, refresh=False) # , doc_type="_doc")
return True return True
except Exception as e: except Exception as e:
logging.error("ES update exception: " + str(e) + " id" + str(id) + ", version:" + str(self.version()) + es_logger.error(
"ES update exception: " + str(e) + " id" + str(id) + ", version:" + str(self.version()) +
json.dumps(script, ensure_ascii=False)) json.dumps(script, ensure_ascii=False))
if str(e).find("Timeout") > 0: if str(e).find("Timeout") > 0:
continue continue
@ -342,7 +331,7 @@ class HuEs:
try: try:
return s.exists() return s.exists()
except Exception as e: except Exception as e:
logging.error("ES updateByQuery indexExist: " + str(e)) es_logger.error("ES updateByQuery indexExist: " + str(e))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue continue
@ -354,7 +343,7 @@ class HuEs:
return self.es.exists(index=(idxnm if idxnm else self.idxnm), return self.es.exists(index=(idxnm if idxnm else self.idxnm),
id=docid) id=docid)
except Exception as e: except Exception as e:
logging.error("ES Doc Exist: " + str(e)) es_logger.error("ES Doc Exist: " + str(e))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue continue
return False return False
@ -368,13 +357,13 @@ class HuEs:
settings=mapping["settings"], settings=mapping["settings"],
mappings=mapping["mappings"]) mappings=mapping["mappings"])
except Exception as e: except Exception as e:
logging.error("ES create index error %s ----%s" % (idxnm, str(e))) es_logger.error("ES create index error %s ----%s" % (idxnm, str(e)))
def deleteIdx(self, idxnm): def deleteIdx(self, idxnm):
try: try:
return self.es.indices.delete(idxnm, allow_no_indices=True) return self.es.indices.delete(idxnm, allow_no_indices=True)
except Exception as e: except Exception as e:
logging.error("ES delete index error %s ----%s" % (idxnm, str(e))) es_logger.error("ES delete index error %s ----%s" % (idxnm, str(e)))
def getTotal(self, res): def getTotal(self, res):
if isinstance(res["hits"]["total"], type({})): if isinstance(res["hits"]["total"], type({})):
@ -405,12 +394,12 @@ class HuEs:
) )
break break
except Exception as e: except Exception as e:
logging.error("ES scrolling fail. " + str(e)) es_logger.error("ES scrolling fail. " + str(e))
time.sleep(3) time.sleep(3)
sid = page['_scroll_id'] sid = page['_scroll_id']
scroll_size = page['hits']['total']["value"] scroll_size = page['hits']['total']["value"]
logging.info("[TOTAL]%d" % scroll_size) es_logger.info("[TOTAL]%d" % scroll_size)
# Start scrolling # Start scrolling
while scroll_size > 0: while scroll_size > 0:
yield page["hits"]["hits"] yield page["hits"]["hits"]
@ -419,10 +408,13 @@ class HuEs:
page = self.es.scroll(scroll_id=sid, scroll=scroll_time) page = self.es.scroll(scroll_id=sid, scroll=scroll_time)
break break
except Exception as e: except Exception as e:
logging.error("ES scrolling fail. " + str(e)) es_logger.error("ES scrolling fail. " + str(e))
time.sleep(3) time.sleep(3)
# Update the scroll ID # Update the scroll ID
sid = page['_scroll_id'] sid = page['_scroll_id']
# Get the number of results that we returned in the last scroll # Get the number of results that we returned in the last scroll
scroll_size = len(page['hits']['hits']) scroll_size = len(page['hits']['hits'])
ELASTICSEARCH = HuEs()
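With the module-level singleton above, callers share one connection. A hedged sketch of the lookup pattern used by rag/svr/parse_user_docs.py; the document id is hypothetical.

    from elasticsearch_dsl import Q
    from rag.utils import ELASTICSEARCH

    res = ELASTICSEARCH.search(Q("term", doc_id="some-doc-id"))   # hypothetical id
    print(ELASTICSEARCH.getTotal(res))                            # hit-count helper defined above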

View File

@ -1,13 +1,15 @@
import logging import os
import time import time
from util import config
from minio import Minio from minio import Minio
from io import BytesIO from io import BytesIO
from rag import settings
from rag.settings import minio_logger
from rag.utils import singleton
@singleton
class HuMinio(object): class HuMinio(object):
def __init__(self, env): def __init__(self):
self.config = config.init(env)
self.conn = None self.conn = None
self.__open__() self.__open__()
@ -19,15 +21,14 @@ class HuMinio(object):
pass pass
try: try:
self.conn = Minio(self.config.get("minio_host"), self.conn = Minio(settings.MINIO["host"],
access_key=self.config.get("minio_user"), access_key=settings.MINIO["user"],
secret_key=self.config.get("minio_password"), secret_key=settings.MINIO["passwd"],
secure=False secure=False
) )
except Exception as e: except Exception as e:
logging.error( minio_logger.error(
"Fail to connect %s " % "Fail to connect %s " % settings.MINIO["host"] + str(e))
self.config.get("minio_host") + str(e))
def __close__(self): def __close__(self):
del self.conn del self.conn
@ -45,34 +46,51 @@ class HuMinio(object):
) )
return r return r
except Exception as e: except Exception as e:
logging.error(f"Fail put {bucket}/{fnm}: " + str(e)) minio_logger.error(f"Fail put {bucket}/{fnm}: " + str(e))
self.__open__() self.__open__()
time.sleep(1) time.sleep(1)
def rm(self, bucket, fnm):
try:
self.conn.remove_object(bucket, fnm)
except Exception as e:
minio_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e))
def get(self, bucket, fnm): def get(self, bucket, fnm):
for _ in range(10): for _ in range(10):
try: try:
r = self.conn.get_object(bucket, fnm) r = self.conn.get_object(bucket, fnm)
return r.read() return r.read()
except Exception as e: except Exception as e:
logging.error(f"fail get {bucket}/{fnm}: " + str(e)) minio_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
self.__open__() self.__open__()
time.sleep(1) time.sleep(1)
return return
def obj_exist(self, bucket, fnm):
try:
if self.conn.stat_object(bucket, fnm):return True
return False
except Exception as e:
minio_logger.error(f"Fail put {bucket}/{fnm}: " + str(e))
return False
def get_presigned_url(self, bucket, fnm, expires): def get_presigned_url(self, bucket, fnm, expires):
for _ in range(10): for _ in range(10):
try: try:
return self.conn.get_presigned_url("GET", bucket, fnm, expires) return self.conn.get_presigned_url("GET", bucket, fnm, expires)
except Exception as e: except Exception as e:
logging.error(f"fail get {bucket}/{fnm}: " + str(e)) minio_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
self.__open__() self.__open__()
time.sleep(1) time.sleep(1)
return return
MINIO = HuMinio()
if __name__ == "__main__": if __name__ == "__main__":
conn = HuMinio("infiniflow") conn = HuMinio()
fnm = "/opt/home/kevinhu/docgpt/upload/13/11-408.jpg" fnm = "/opt/home/kevinhu/docgpt/upload/13/11-408.jpg"
from PIL import Image from PIL import Image
img = Image.open(fnm) img = Image.open(fnm)

View File

@ -1,178 +0,0 @@
use std::collections::HashMap;
use actix_web::{ HttpResponse, post, web };
use serde::Deserialize;
use serde_json::Value;
use serde_json::json;
use crate::api::JsonResponse;
use crate::AppState;
use crate::errors::AppError;
use crate::service::dialog_info::Query;
use crate::service::dialog_info::Mutation;
#[derive(Debug, Deserialize)]
pub struct ListParams {
pub uid: i64,
pub dialog_id: Option<i64>,
}
#[post("/v1.0/dialogs")]
async fn list(
params: web::Json<ListParams>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let mut result = HashMap::new();
if let Some(dia_id) = params.dialog_id {
let dia = Query::find_dialog_info_by_id(&data.conn, dia_id).await?.unwrap();
let kb = crate::service::kb_info::Query
::find_kb_info_by_id(&data.conn, dia.kb_id).await?
.unwrap();
print!("{:?}", dia.history);
let hist: Value = serde_json::from_str(&dia.history)?;
let detail =
json!({
"dialog_id": dia_id,
"dialog_name": dia.dialog_name.to_owned(),
"created_at": dia.created_at.to_string().to_owned(),
"updated_at": dia.updated_at.to_string().to_owned(),
"history": hist,
"kb_info": kb
});
result.insert("dialogs", vec![detail]);
} else {
let mut dias = Vec::<Value>::new();
for dia in Query::find_dialog_infos_by_uid(&data.conn, params.uid).await? {
let kb = crate::service::kb_info::Query
::find_kb_info_by_id(&data.conn, dia.kb_id).await?
.unwrap();
let hist: Value = serde_json::from_str(&dia.history)?;
dias.push(
json!({
"dialog_id": dia.dialog_id,
"dialog_name": dia.dialog_name.to_owned(),
"created_at": dia.created_at.to_string().to_owned(),
"updated_at": dia.updated_at.to_string().to_owned(),
"history": hist,
"kb_info": kb
})
);
}
result.insert("dialogs", dias);
}
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: result,
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}
#[derive(Debug, Deserialize)]
pub struct RmParams {
pub uid: i64,
pub dialog_id: i64,
}
#[post("/v1.0/delete_dialog")]
async fn delete(
params: web::Json<RmParams>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let _ = Mutation::delete_dialog_info(&data.conn, params.dialog_id).await?;
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: (),
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}
#[derive(Debug, Deserialize)]
pub struct CreateParams {
pub uid: i64,
pub dialog_id: Option<i64>,
pub kb_id: i64,
pub name: String,
}
#[post("/v1.0/create_dialog")]
async fn create(
param: web::Json<CreateParams>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let mut result = HashMap::new();
if let Some(dia_id) = param.dialog_id {
result.insert("dialog_id", dia_id);
let dia = Query::find_dialog_info_by_id(&data.conn, dia_id).await?;
let _ = Mutation::update_dialog_info_by_id(
&data.conn,
dia_id,
&param.name,
&dia.unwrap().history
).await?;
} else {
let dia = Mutation::create_dialog_info(
&data.conn,
param.uid,
param.kb_id,
&param.name
).await?;
result.insert("dialog_id", dia.dialog_id.unwrap());
}
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: result,
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}
#[derive(Debug, Deserialize)]
pub struct UpdateHistoryParams {
pub uid: i64,
pub dialog_id: i64,
pub history: Value,
}
#[post("/v1.0/update_history")]
async fn update_history(
param: web::Json<UpdateHistoryParams>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let mut json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: (),
};
if let Some(dia) = Query::find_dialog_info_by_id(&data.conn, param.dialog_id).await? {
let _ = Mutation::update_dialog_info_by_id(
&data.conn,
param.dialog_id,
&dia.dialog_name,
&param.history.to_string()
).await?;
} else {
json_response.code = 500;
json_response.err = "Can't find dialog data!".to_owned();
}
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}

View File

@ -1,314 +0,0 @@
use std::collections::{HashMap};
use std::io::BufReader;
use actix_multipart_extract::{ File, Multipart, MultipartForm };
use actix_web::web::Bytes;
use actix_web::{ HttpResponse, post, web };
use chrono::{ Utc, FixedOffset };
use minio::s3::args::{ BucketExistsArgs, MakeBucketArgs, PutObjectArgs };
use sea_orm::DbConn;
use crate::api::JsonResponse;
use crate::AppState;
use crate::entity::doc_info::Model;
use crate::errors::AppError;
use crate::service::doc_info::{ Mutation, Query };
use serde::Deserialize;
use regex::Regex;
fn now() -> chrono::DateTime<FixedOffset> {
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
}
#[derive(Debug, Deserialize)]
pub struct ListParams {
pub uid: i64,
pub filter: FilterParams,
pub sortby: String,
pub page: Option<u32>,
pub per_page: Option<u32>,
}
#[derive(Debug, Deserialize)]
pub struct FilterParams {
pub keywords: Option<String>,
pub folder_id: Option<i64>,
pub tag_id: Option<i64>,
pub kb_id: Option<i64>,
}
#[post("/v1.0/docs")]
async fn list(
params: web::Json<ListParams>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let docs = Query::find_doc_infos_by_params(&data.conn, params.into_inner()).await?;
let mut result = HashMap::new();
result.insert("docs", docs);
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: result,
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}
#[derive(Deserialize, MultipartForm, Debug)]
pub struct UploadForm {
#[multipart(max_size = 512MB)]
file_field: File,
uid: i64,
did: i64,
}
fn file_type(filename: &String) -> String {
let fnm = filename.to_lowercase();
if
let Some(_) = Regex::new(r"\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$")
.unwrap()
.captures(&fnm)
{
return "Video".to_owned();
}
if
let Some(_) = Regex::new(
r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$"
)
.unwrap()
.captures(&fnm)
{
return "Picture".to_owned();
}
if
let Some(_) = Regex::new(r"\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$")
.unwrap()
.captures(&fnm)
{
return "Music".to_owned();
}
if
let Some(_) = Regex::new(r"\.(pdf|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$")
.unwrap()
.captures(&fnm)
{
return "Document".to_owned();
}
"Other".to_owned()
}
#[post("/v1.0/upload")]
async fn upload(
payload: Multipart<UploadForm>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
Ok(HttpResponse::Ok().body("File uploaded successfully"))
}
pub(crate) async fn _upload_file(uid: i64, did: i64, file_name: &str, bytes: &[u8], data: &web::Data<AppState>) -> Result<(), AppError> {
async fn add_number_to_filename(
file_name: &str,
conn: &DbConn,
uid: i64,
parent_id: i64
) -> String {
let mut i = 0;
let mut new_file_name = file_name.to_string();
let arr: Vec<&str> = file_name.split(".").collect();
let suffix = String::from(arr[arr.len() - 1]);
let preffix = arr[..arr.len() - 1].join(".");
let mut docs = Query::find_doc_infos_by_name(
conn,
uid,
&new_file_name,
Some(parent_id)
).await.unwrap();
while docs.len() > 0 {
i += 1;
new_file_name = format!("{}_{}.{}", preffix, i, suffix);
docs = Query::find_doc_infos_by_name(
conn,
uid,
&new_file_name,
Some(parent_id)
).await.unwrap();
}
new_file_name
}
let fnm = add_number_to_filename(file_name, &data.conn, uid, did).await;
let bucket_name = format!("{}-upload", uid);
let s3_client: &minio::s3::client::Client = &data.s3_client;
let buckets_exists = s3_client
.bucket_exists(&BucketExistsArgs::new(&bucket_name).unwrap()).await
.unwrap();
if !buckets_exists {
print!("Create bucket: {}", bucket_name.clone());
s3_client.make_bucket(&MakeBucketArgs::new(&bucket_name).unwrap()).await.unwrap();
} else {
print!("Existing bucket: {}", bucket_name.clone());
}
let location = format!("/{}/{}", did, fnm)
.as_bytes()
.to_vec()
.iter()
.map(|b| format!("{:02x}", b).to_string())
.collect::<Vec<String>>()
.join("");
print!("===>{}", location.clone());
s3_client.put_object(
&mut PutObjectArgs::new(
&bucket_name,
&location,
&mut BufReader::new(bytes),
Some(bytes.len()),
None
)?
).await?;
let doc = Mutation::create_doc_info(&data.conn, Model {
did: Default::default(),
uid: uid,
doc_name: fnm.clone(),
size: bytes.len() as i64,
location,
r#type: file_type(&fnm),
thumbnail_base64: Default::default(),
created_at: now(),
updated_at: now(),
is_deleted: Default::default(),
}).await?;
let _ = Mutation::place_doc(&data.conn, did, doc.did.unwrap()).await?;
Ok(())
}
#[derive(Deserialize, Debug)]
pub struct RmDocsParam {
uid: i64,
dids: Vec<i64>,
}
#[post("/v1.0/delete_docs")]
async fn delete(
params: web::Json<RmDocsParam>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let _ = Mutation::delete_doc_info(&data.conn, &params.dids).await?;
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: (),
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}
#[derive(Debug, Deserialize)]
pub struct MvParams {
pub uid: i64,
pub dids: Vec<i64>,
pub dest_did: i64,
}
#[post("/v1.0/mv_docs")]
async fn mv(
params: web::Json<MvParams>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
Mutation::mv_doc_info(&data.conn, params.dest_did, &params.dids).await?;
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: (),
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}
#[derive(Debug, Deserialize)]
pub struct NewFoldParams {
pub uid: i64,
pub parent_id: i64,
pub name: String,
}
#[post("/v1.0/new_folder")]
async fn new_folder(
params: web::Json<NewFoldParams>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let doc = Mutation::create_doc_info(&data.conn, Model {
did: Default::default(),
uid: params.uid,
doc_name: params.name.to_string(),
size: 0,
r#type: "folder".to_string(),
location: "".to_owned(),
thumbnail_base64: Default::default(),
created_at: now(),
updated_at: now(),
is_deleted: Default::default(),
}).await?;
let _ = Mutation::place_doc(&data.conn, params.parent_id, doc.did.unwrap()).await?;
Ok(HttpResponse::Ok().body("Folder created successfully"))
}
#[derive(Debug, Deserialize)]
pub struct RenameParams {
pub uid: i64,
pub did: i64,
pub name: String,
}
#[post("/v1.0/rename")]
async fn rename(
params: web::Json<RenameParams>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let docs = Query::find_doc_infos_by_name(&data.conn, params.uid, &params.name, None).await?;
if docs.len() > 0 {
let json_response = JsonResponse {
code: 500,
err: "Name duplicated!".to_owned(),
data: (),
};
return Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
);
}
let doc = Mutation::rename(&data.conn, params.did, &params.name).await?;
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: doc,
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}

View File

@ -1,166 +0,0 @@
use std::collections::HashMap;
use actix_web::{ get, HttpResponse, post, web };
use serde::Serialize;
use crate::api::JsonResponse;
use crate::AppState;
use crate::entity::kb_info;
use crate::errors::AppError;
use crate::service::kb_info::Mutation;
use crate::service::kb_info::Query;
use serde::Deserialize;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AddDocs2KbParams {
pub uid: i64,
pub dids: Vec<i64>,
pub kb_id: i64,
}
#[post("/v1.0/create_kb")]
async fn create(
model: web::Json<kb_info::Model>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let mut docs = Query::find_kb_infos_by_name(
&data.conn,
model.kb_name.to_owned()
).await.unwrap();
if docs.len() > 0 {
let json_response = JsonResponse {
code: 201,
err: "Duplicated name.".to_owned(),
data: (),
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
} else {
let model = Mutation::create_kb_info(&data.conn, model.into_inner()).await?;
let mut result = HashMap::new();
result.insert("kb_id", model.kb_id.unwrap());
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: result,
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}
}
#[post("/v1.0/add_docs_to_kb")]
async fn add_docs_to_kb(
param: web::Json<AddDocs2KbParams>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let _ = Mutation::add_docs(&data.conn, param.kb_id, param.dids.to_owned()).await?;
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: (),
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}
#[post("/v1.0/anti_kb_docs")]
async fn anti_kb_docs(
param: web::Json<AddDocs2KbParams>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let _ = Mutation::remove_docs(&data.conn, param.dids.to_owned(), Some(param.kb_id)).await?;
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: (),
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}
#[get("/v1.0/kbs")]
async fn list(
model: web::Json<kb_info::Model>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let kbs = Query::find_kb_infos_by_uid(&data.conn, model.uid).await?;
let mut result = HashMap::new();
result.insert("kbs", kbs);
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: result,
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}
#[post("/v1.0/delete_kb")]
async fn delete(
model: web::Json<kb_info::Model>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let _ = Mutation::delete_kb_info(&data.conn, model.kb_id).await?;
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: (),
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DocIdsParams {
pub uid: i64,
pub dids: Vec<i64>,
}
#[post("/v1.0/all_relevents")]
async fn all_relevents(
params: web::Json<DocIdsParams>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let dids = crate::service::doc_info::Query::all_descendent_ids(&data.conn, &params.dids).await?;
let mut result = HashMap::new();
let kbs = Query::find_kb_by_docs(&data.conn, dids).await?;
result.insert("kbs", kbs);
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: result,
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}

View File

@ -1,14 +0,0 @@
use serde::{ Deserialize, Serialize };
pub(crate) mod tag_info;
pub(crate) mod kb_info;
pub(crate) mod dialog_info;
pub(crate) mod doc_info;
pub(crate) mod user_info;
#[derive(Debug, Deserialize, Serialize)]
struct JsonResponse<T> {
code: u32,
err: String,
data: T,
}

View File

@ -1,81 +0,0 @@
use std::collections::HashMap;
use actix_web::{ get, HttpResponse, post, web };
use serde::Deserialize;
use crate::api::JsonResponse;
use crate::AppState;
use crate::entity::tag_info;
use crate::errors::AppError;
use crate::service::tag_info::{ Mutation, Query };
#[derive(Debug, Deserialize)]
pub struct TagListParams {
pub uid: i64,
}
#[post("/v1.0/create_tag")]
async fn create(
model: web::Json<tag_info::Model>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let model = Mutation::create_tag(&data.conn, model.into_inner()).await?;
let mut result = HashMap::new();
result.insert("tid", model.tid.unwrap());
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: result,
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}
#[post("/v1.0/delete_tag")]
async fn delete(
model: web::Json<tag_info::Model>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let _ = Mutation::delete_tag(&data.conn, model.tid).await?;
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: (),
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}
//#[get("/v1.0/tags", wrap = "HttpAuthentication::bearer(validator)")]
#[post("/v1.0/tags")]
async fn list(
param: web::Json<TagListParams>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let tags = Query::find_tags_by_uid(param.uid, &data.conn).await?;
let mut result = HashMap::new();
result.insert("tags", tags);
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: result,
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}

View File

@ -1,149 +0,0 @@
use std::collections::HashMap;
use std::io::SeekFrom;
use std::ptr::null;
use actix_identity::Identity;
use actix_web::{ HttpResponse, post, web };
use chrono::{ FixedOffset, Utc };
use sea_orm::ActiveValue::NotSet;
use serde::{ Deserialize, Serialize };
use crate::api::JsonResponse;
use crate::AppState;
use crate::entity::{ doc_info, tag_info };
use crate::entity::user_info::Model;
use crate::errors::{ AppError, UserError };
use crate::service::user_info::Mutation;
use crate::service::user_info::Query;
fn now() -> chrono::DateTime<FixedOffset> {
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
}
pub(crate) fn create_auth_token(user: &Model) -> u64 {
use std::{ collections::hash_map::DefaultHasher, hash::{ Hash, Hasher } };
let mut hasher = DefaultHasher::new();
user.hash(&mut hasher);
hasher.finish()
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub(crate) struct LoginParams {
pub(crate) email: String,
pub(crate) password: String,
}
#[post("/v1.0/login")]
async fn login(
data: web::Data<AppState>,
identity: Identity,
input: web::Json<LoginParams>
) -> Result<HttpResponse, AppError> {
match Query::login(&data.conn, &input.email, &input.password).await? {
Some(user) => {
let _ = Mutation::update_login_status(user.uid, &data.conn).await?;
let token = create_auth_token(&user).to_string();
identity.remember(token.clone());
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: token.clone(),
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.append_header(("X-Auth-Token", token))
.body(serde_json::to_string(&json_response)?)
)
}
None => Err(UserError::LoginFailed.into()),
}
}
#[post("/v1.0/register")]
async fn register(
model: web::Json<Model>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let mut result = HashMap::new();
let u = Query::find_user_infos(&data.conn, &model.email).await?;
if let Some(_) = u {
let json_response = JsonResponse {
code: 500,
err: "Email registered!".to_owned(),
data: (),
};
return Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
);
}
let usr = Mutation::create_user(&data.conn, &model).await?;
result.insert("uid", usr.uid.clone().unwrap());
crate::service::doc_info::Mutation::create_doc_info(&data.conn, doc_info::Model {
did: Default::default(),
uid: usr.uid.clone().unwrap(),
doc_name: "/".into(),
size: 0,
location: "".into(),
thumbnail_base64: "".into(),
r#type: "folder".to_string(),
created_at: now(),
updated_at: now(),
is_deleted: Default::default(),
}).await?;
let tnm = vec!["Video", "Picture", "Music", "Document"];
let tregx = vec![
".*\\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa)",
".*\\.(png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng)",
".*\\.(WAV|FLAC|APE|ALAC|WavPack|WV|MP3|AAC|Ogg|Vorbis|Opus)",
".*\\.(pdf|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp)"
];
for i in 0..4 {
crate::service::tag_info::Mutation::create_tag(&data.conn, tag_info::Model {
tid: Default::default(),
uid: usr.uid.clone().unwrap(),
tag_name: tnm[i].to_owned(),
regx: tregx[i].to_owned(),
color: (i + 1).to_owned() as i16,
icon: (i + 1).to_owned() as i16,
folder_id: 0,
created_at: Default::default(),
updated_at: Default::default(),
}).await?;
}
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: result,
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}
#[post("/v1.0/setting")]
async fn setting(
model: web::Json<Model>,
data: web::Data<AppState>
) -> Result<HttpResponse, AppError> {
let _ = Mutation::update_user_by_id(&data.conn, &model).await?;
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: (),
};
Ok(
HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response)?)
)
}

View File

@ -1,38 +0,0 @@
use sea_orm::entity::prelude::*;
use serde::{ Deserialize, Serialize };
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
#[sea_orm(table_name = "dialog2_kb")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = true)]
pub id: i64,
#[sea_orm(index)]
pub dialog_id: i64,
#[sea_orm(index)]
pub kb_id: i64,
}
#[derive(Debug, Clone, Copy, EnumIter)]
pub enum Relation {
DialogInfo,
KbInfo,
}
impl RelationTrait for Relation {
fn def(&self) -> RelationDef {
match self {
Self::DialogInfo =>
Entity::belongs_to(super::dialog_info::Entity)
.from(Column::DialogId)
.to(super::dialog_info::Column::DialogId)
.into(),
Self::KbInfo =>
Entity::belongs_to(super::kb_info::Entity)
.from(Column::KbId)
.to(super::kb_info::Column::KbId)
.into(),
}
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -1,38 +0,0 @@
use chrono::{ DateTime, FixedOffset };
use sea_orm::entity::prelude::*;
use serde::{ Deserialize, Serialize };
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
#[sea_orm(table_name = "dialog_info")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub dialog_id: i64,
#[sea_orm(index)]
pub uid: i64,
#[serde(skip_deserializing)]
pub kb_id: i64,
pub dialog_name: String,
pub history: String,
#[serde(skip_deserializing)]
pub created_at: DateTime<FixedOffset>,
#[serde(skip_deserializing)]
pub updated_at: DateTime<FixedOffset>,
#[serde(skip_deserializing)]
pub is_deleted: bool,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl Related<super::kb_info::Entity> for Entity {
fn to() -> RelationDef {
super::dialog2_kb::Relation::KbInfo.def()
}
fn via() -> Option<RelationDef> {
Some(super::dialog2_kb::Relation::DialogInfo.def().rev())
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -1,38 +0,0 @@
use sea_orm::entity::prelude::*;
use serde::{ Deserialize, Serialize };
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
#[sea_orm(table_name = "doc2_doc")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = true)]
pub id: i64,
#[sea_orm(index)]
pub parent_id: i64,
#[sea_orm(index)]
pub did: i64,
}
#[derive(Debug, Clone, Copy, EnumIter)]
pub enum Relation {
Parent,
Child,
}
impl RelationTrait for Relation {
fn def(&self) -> RelationDef {
match self {
Self::Parent =>
Entity::belongs_to(super::doc_info::Entity)
.from(Column::ParentId)
.to(super::doc_info::Column::Did)
.into(),
Self::Child =>
Entity::belongs_to(super::doc_info::Entity)
.from(Column::Did)
.to(super::doc_info::Column::Did)
.into(),
}
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -1,62 +0,0 @@
use sea_orm::entity::prelude::*;
use serde::{ Deserialize, Serialize };
use crate::entity::kb_info;
use chrono::{ DateTime, FixedOffset };
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Deserialize, Serialize)]
#[sea_orm(table_name = "doc_info")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub did: i64,
#[sea_orm(index)]
pub uid: i64,
pub doc_name: String,
pub size: i64,
#[sea_orm(column_name = "type")]
pub r#type: String,
#[serde(skip_deserializing)]
pub location: String,
#[serde(skip_deserializing)]
pub thumbnail_base64: String,
#[serde(skip_deserializing)]
pub created_at: DateTime<FixedOffset>,
#[serde(skip_deserializing)]
pub updated_at: DateTime<FixedOffset>,
#[serde(skip_deserializing)]
pub is_deleted: bool,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl Related<super::tag_info::Entity> for Entity {
fn to() -> RelationDef {
super::tag2_doc::Relation::Tag.def()
}
fn via() -> Option<RelationDef> {
Some(super::tag2_doc::Relation::DocInfo.def().rev())
}
}
impl Related<super::kb_info::Entity> for Entity {
fn to() -> RelationDef {
super::kb2_doc::Relation::KbInfo.def()
}
fn via() -> Option<RelationDef> {
Some(super::kb2_doc::Relation::DocInfo.def().rev())
}
}
impl Related<super::doc2_doc::Entity> for Entity {
fn to() -> RelationDef {
super::doc2_doc::Relation::Parent.def()
}
fn via() -> Option<RelationDef> {
Some(super::doc2_doc::Relation::Child.def().rev())
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -1,47 +0,0 @@
use sea_orm::entity::prelude::*;
use serde::{ Deserialize, Serialize };
use chrono::{ DateTime, FixedOffset };
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Deserialize, Serialize)]
#[sea_orm(table_name = "kb2_doc")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = true)]
pub id: i64,
#[sea_orm(index)]
pub kb_id: i64,
#[sea_orm(index)]
pub did: i64,
#[serde(skip_deserializing)]
pub kb_progress: f32,
#[serde(skip_deserializing)]
pub kb_progress_msg: String,
#[serde(skip_deserializing)]
pub updated_at: DateTime<FixedOffset>,
#[serde(skip_deserializing)]
pub is_deleted: bool,
}
#[derive(Debug, Clone, Copy, EnumIter)]
pub enum Relation {
DocInfo,
KbInfo,
}
impl RelationTrait for Relation {
fn def(&self) -> RelationDef {
match self {
Self::DocInfo =>
Entity::belongs_to(super::doc_info::Entity)
.from(Column::Did)
.to(super::doc_info::Column::Did)
.into(),
Self::KbInfo =>
Entity::belongs_to(super::kb_info::Entity)
.from(Column::KbId)
.to(super::kb_info::Column::KbId)
.into(),
}
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -1,47 +0,0 @@
use sea_orm::entity::prelude::*;
use serde::{ Deserialize, Serialize };
use chrono::{ DateTime, FixedOffset };
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
#[sea_orm(table_name = "kb_info")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
#[serde(skip_deserializing)]
pub kb_id: i64,
#[sea_orm(index)]
pub uid: i64,
pub kb_name: String,
pub icon: i16,
#[serde(skip_deserializing)]
pub created_at: DateTime<FixedOffset>,
#[serde(skip_deserializing)]
pub updated_at: DateTime<FixedOffset>,
#[serde(skip_deserializing)]
pub is_deleted: bool,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl Related<super::doc_info::Entity> for Entity {
fn to() -> RelationDef {
super::kb2_doc::Relation::DocInfo.def()
}
fn via() -> Option<RelationDef> {
Some(super::kb2_doc::Relation::KbInfo.def().rev())
}
}
impl Related<super::dialog_info::Entity> for Entity {
fn to() -> RelationDef {
super::dialog2_kb::Relation::DialogInfo.def()
}
fn via() -> Option<RelationDef> {
Some(super::dialog2_kb::Relation::KbInfo.def().rev())
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -1,9 +0,0 @@
pub(crate) mod user_info;
pub(crate) mod tag_info;
pub(crate) mod tag2_doc;
pub(crate) mod kb2_doc;
pub(crate) mod dialog2_kb;
pub(crate) mod doc2_doc;
pub(crate) mod kb_info;
pub(crate) mod doc_info;
pub(crate) mod dialog_info;

View File

@ -1,38 +0,0 @@
use sea_orm::entity::prelude::*;
use serde::{ Deserialize, Serialize };
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
#[sea_orm(table_name = "tag2_doc")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = true)]
pub id: i64,
#[sea_orm(index)]
pub tag_id: i64,
#[sea_orm(index)]
pub did: i64,
}
#[derive(Debug, Clone, Copy, EnumIter)]
pub enum Relation {
Tag,
DocInfo,
}
impl RelationTrait for Relation {
fn def(&self) -> sea_orm::RelationDef {
match self {
Self::Tag =>
Entity::belongs_to(super::tag_info::Entity)
.from(Column::TagId)
.to(super::tag_info::Column::Tid)
.into(),
Self::DocInfo =>
Entity::belongs_to(super::doc_info::Entity)
.from(Column::Did)
.to(super::doc_info::Column::Did)
.into(),
}
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -1,40 +0,0 @@
use sea_orm::entity::prelude::*;
use serde::{ Deserialize, Serialize };
use chrono::{ DateTime, FixedOffset };
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Deserialize, Serialize)]
#[sea_orm(table_name = "tag_info")]
pub struct Model {
#[sea_orm(primary_key)]
#[serde(skip_deserializing)]
pub tid: i64,
#[sea_orm(index)]
pub uid: i64,
pub tag_name: String,
#[serde(skip_deserializing)]
pub regx: String,
pub color: i16,
pub icon: i16,
#[serde(skip_deserializing)]
pub folder_id: i64,
#[serde(skip_deserializing)]
pub created_at: DateTime<FixedOffset>,
#[serde(skip_deserializing)]
pub updated_at: DateTime<FixedOffset>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl Related<super::doc_info::Entity> for Entity {
fn to() -> RelationDef {
super::tag2_doc::Relation::DocInfo.def()
}
fn via() -> Option<RelationDef> {
Some(super::tag2_doc::Relation::Tag.def().rev())
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -1,30 +0,0 @@
use sea_orm::entity::prelude::*;
use serde::{ Deserialize, Serialize };
use chrono::{ DateTime, FixedOffset };
#[derive(Clone, Debug, PartialEq, Eq, Hash, DeriveEntityModel, Deserialize, Serialize)]
#[sea_orm(table_name = "user_info")]
pub struct Model {
#[sea_orm(primary_key)]
#[serde(skip_deserializing)]
pub uid: i64,
pub email: String,
pub nickname: String,
pub avatar_base64: String,
pub color_scheme: String,
pub list_style: String,
pub language: String,
pub password: String,
#[serde(skip_deserializing)]
pub last_login_at: DateTime<FixedOffset>,
#[serde(skip_deserializing)]
pub created_at: DateTime<FixedOffset>,
#[serde(skip_deserializing)]
pub updated_at: DateTime<FixedOffset>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -1,83 +0,0 @@
use actix_web::{HttpResponse, ResponseError};
use thiserror::Error;
#[derive(Debug, Error)]
pub(crate) enum AppError {
#[error("`{0}`")]
User(#[from] UserError),
#[error("`{0}`")]
Json(#[from] serde_json::Error),
#[error("`{0}`")]
Actix(#[from] actix_web::Error),
#[error("`{0}`")]
Db(#[from] sea_orm::DbErr),
#[error("`{0}`")]
MinioS3(#[from] minio::s3::error::Error),
#[error("`{0}`")]
Std(#[from] std::io::Error),
}
#[derive(Debug, Error)]
pub(crate) enum UserError {
#[error("`username` field of `User` cannot be empty!")]
EmptyUsername,
#[error("`username` field of `User` cannot contain whitespaces!")]
UsernameInvalidCharacter,
#[error("`password` field of `User` cannot be empty!")]
EmptyPassword,
#[error("`password` field of `User` cannot contain whitespaces!")]
PasswordInvalidCharacter,
#[error("Could not find any `User` for id: `{0}`!")]
NotFound(i64),
#[error("Failed to login user!")]
LoginFailed,
#[error("User is not logged in!")]
NotLoggedIn,
#[error("Invalid authorization token!")]
InvalidToken,
#[error("Could not find any `User`!")]
Empty,
}
impl ResponseError for AppError {
fn status_code(&self) -> actix_web::http::StatusCode {
match self {
AppError::User(user_error) => match user_error {
UserError::EmptyUsername => actix_web::http::StatusCode::UNPROCESSABLE_ENTITY,
UserError::UsernameInvalidCharacter => {
actix_web::http::StatusCode::UNPROCESSABLE_ENTITY
}
UserError::EmptyPassword => actix_web::http::StatusCode::UNPROCESSABLE_ENTITY,
UserError::PasswordInvalidCharacter => {
actix_web::http::StatusCode::UNPROCESSABLE_ENTITY
}
UserError::NotFound(_) => actix_web::http::StatusCode::NOT_FOUND,
UserError::NotLoggedIn => actix_web::http::StatusCode::UNAUTHORIZED,
UserError::Empty => actix_web::http::StatusCode::NOT_FOUND,
UserError::LoginFailed => actix_web::http::StatusCode::NOT_FOUND,
UserError::InvalidToken => actix_web::http::StatusCode::UNAUTHORIZED,
},
AppError::Actix(fail) => fail.as_response_error().status_code(),
_ => actix_web::http::StatusCode::INTERNAL_SERVER_ERROR,
}
}
fn error_response(&self) -> HttpResponse {
let status_code = self.status_code();
let response = HttpResponse::build(status_code).body(self.to_string());
response
}
}

View File

@ -1,145 +0,0 @@
mod api;
mod entity;
mod service;
mod errors;
mod web_socket;
use std::env;
use actix_files::Files;
use actix_identity::{ CookieIdentityPolicy, IdentityService, RequestIdentity };
use actix_session::CookieSession;
use actix_web::{ web, App, HttpServer, middleware, Error };
use actix_web::cookie::time::Duration;
use actix_web::dev::ServiceRequest;
use actix_web::error::ErrorUnauthorized;
use actix_web_httpauth::extractors::bearer::BearerAuth;
use listenfd::ListenFd;
use minio::s3::client::Client;
use minio::s3::creds::StaticProvider;
use minio::s3::http::BaseUrl;
use sea_orm::{ Database, DatabaseConnection };
use migration::{ Migrator, MigratorTrait };
use crate::errors::{ AppError, UserError };
use crate::web_socket::doc_info::upload_file_ws;
#[derive(Debug, Clone)]
struct AppState {
conn: DatabaseConnection,
s3_client: Client,
}
pub(crate) async fn validator(
req: ServiceRequest,
credentials: BearerAuth
) -> Result<ServiceRequest, Error> {
if let Some(token) = req.get_identity() {
println!("{}, {}", credentials.token(), token);
(credentials.token() == token)
.then(|| req)
.ok_or(ErrorUnauthorized(UserError::InvalidToken))
} else {
Err(ErrorUnauthorized(UserError::NotLoggedIn))
}
}
#[actix_web::main]
async fn main() -> Result<(), AppError> {
std::env::set_var("RUST_LOG", "debug");
tracing_subscriber::fmt::init();
// get env vars
dotenvy::dotenv().ok();
let db_url = env::var("DATABASE_URL").expect("DATABASE_URL is not set in .env file");
let host = env::var("HOST").expect("HOST is not set in .env file");
let port = env::var("PORT").expect("PORT is not set in .env file");
let server_url = format!("{host}:{port}");
let mut s3_base_url = env::var("MINIO_HOST").expect("MINIO_HOST is not set in .env file");
let s3_access_key = env::var("MINIO_USR").expect("MINIO_USR is not set in .env file");
let s3_secret_key = env::var("MINIO_PWD").expect("MINIO_PWD is not set in .env file");
if s3_base_url.find("http") != Some(0) {
s3_base_url = format!("http://{}", s3_base_url);
}
// establish connection to database and apply migrations
// -> create post table if not exists
let conn = Database::connect(&db_url).await.unwrap();
Migrator::up(&conn, None).await.unwrap();
let static_provider = StaticProvider::new(s3_access_key.as_str(), s3_secret_key.as_str(), None);
let s3_client = Client::new(
s3_base_url.parse::<BaseUrl>()?,
Some(Box::new(static_provider)),
None,
Some(true)
)?;
let state = AppState { conn, s3_client };
// create server and try to serve over socket if possible
let mut listenfd = ListenFd::from_env();
let mut server = HttpServer::new(move || {
App::new()
.service(Files::new("/static", "./static"))
.app_data(web::Data::new(state.clone()))
.wrap(
IdentityService::new(
CookieIdentityPolicy::new(&[0; 32])
.name("auth-cookie")
.login_deadline(Duration::seconds(120))
.secure(false)
)
)
.wrap(
CookieSession::signed(&[0; 32])
.name("session-cookie")
.secure(false)
// WARNING(alex): This uses the `time` crate, not `std::time`!
.expires_in_time(Duration::seconds(60))
)
.wrap(middleware::Logger::default())
.configure(init)
});
server = match listenfd.take_tcp_listener(0)? {
Some(listener) => server.listen(listener)?,
None => server.bind(&server_url)?,
};
println!("Starting server at {server_url}");
server.run().await?;
Ok(())
}
fn init(cfg: &mut web::ServiceConfig) {
cfg.service(api::tag_info::create);
cfg.service(api::tag_info::delete);
cfg.service(api::tag_info::list);
cfg.service(api::kb_info::create);
cfg.service(api::kb_info::delete);
cfg.service(api::kb_info::list);
cfg.service(api::kb_info::add_docs_to_kb);
cfg.service(api::kb_info::anti_kb_docs);
cfg.service(api::kb_info::all_relevents);
cfg.service(api::doc_info::list);
cfg.service(api::doc_info::delete);
cfg.service(api::doc_info::mv);
cfg.service(api::doc_info::upload);
cfg.service(api::doc_info::new_folder);
cfg.service(api::doc_info::rename);
cfg.service(api::dialog_info::list);
cfg.service(api::dialog_info::delete);
cfg.service(api::dialog_info::create);
cfg.service(api::dialog_info::update_history);
cfg.service(api::user_info::login);
cfg.service(api::user_info::register);
cfg.service(api::user_info::setting);
cfg.service(web::resource("/ws-upload-doc").route(web::get().to(upload_file_ws)));
}

View File

@ -1,107 +0,0 @@
use chrono::{ Local, FixedOffset, Utc };
use migration::Expr;
use sea_orm::{
ActiveModelTrait,
DbConn,
DbErr,
DeleteResult,
EntityTrait,
PaginatorTrait,
QueryOrder,
UpdateResult,
};
use sea_orm::ActiveValue::Set;
use sea_orm::QueryFilter;
use sea_orm::ColumnTrait;
use crate::entity::dialog_info;
use crate::entity::dialog_info::Entity;
fn now() -> chrono::DateTime<FixedOffset> {
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
}
pub struct Query;
impl Query {
pub async fn find_dialog_info_by_id(
db: &DbConn,
id: i64
) -> Result<Option<dialog_info::Model>, DbErr> {
Entity::find_by_id(id).one(db).await
}
pub async fn find_dialog_infos(db: &DbConn) -> Result<Vec<dialog_info::Model>, DbErr> {
Entity::find().all(db).await
}
pub async fn find_dialog_infos_by_uid(
db: &DbConn,
uid: i64
) -> Result<Vec<dialog_info::Model>, DbErr> {
Entity::find()
.filter(dialog_info::Column::Uid.eq(uid))
.filter(dialog_info::Column::IsDeleted.eq(false))
.all(db).await
}
pub async fn find_dialog_infos_in_page(
db: &DbConn,
page: u64,
posts_per_page: u64
) -> Result<(Vec<dialog_info::Model>, u64), DbErr> {
// Setup paginator
let paginator = Entity::find()
.order_by_asc(dialog_info::Column::DialogId)
.paginate(db, posts_per_page);
let num_pages = paginator.num_pages().await?;
// Fetch paginated posts
paginator.fetch_page(page - 1).await.map(|p| (p, num_pages))
}
}
pub struct Mutation;
impl Mutation {
pub async fn create_dialog_info(
db: &DbConn,
uid: i64,
kb_id: i64,
name: &String
) -> Result<dialog_info::ActiveModel, DbErr> {
(dialog_info::ActiveModel {
dialog_id: Default::default(),
uid: Set(uid),
kb_id: Set(kb_id),
dialog_name: Set(name.to_owned()),
history: Set("".to_owned()),
created_at: Set(now()),
updated_at: Set(now()),
is_deleted: Default::default(),
}).save(db).await
}
pub async fn update_dialog_info_by_id(
db: &DbConn,
dialog_id: i64,
dialog_name: &String,
history: &String
) -> Result<UpdateResult, DbErr> {
Entity::update_many()
.col_expr(dialog_info::Column::DialogName, Expr::value(dialog_name))
.col_expr(dialog_info::Column::History, Expr::value(history))
.col_expr(dialog_info::Column::UpdatedAt, Expr::value(now()))
.filter(dialog_info::Column::DialogId.eq(dialog_id))
.exec(db).await
}
pub async fn delete_dialog_info(db: &DbConn, dialog_id: i64) -> Result<UpdateResult, DbErr> {
Entity::update_many()
.col_expr(dialog_info::Column::IsDeleted, Expr::value(true))
.filter(dialog_info::Column::DialogId.eq(dialog_id))
.exec(db).await
}
pub async fn delete_all_dialog_infos(db: &DbConn) -> Result<DeleteResult, DbErr> {
Entity::delete_many().exec(db).await
}
}

View File

@ -1,335 +0,0 @@
use chrono::{ Utc, FixedOffset };
use sea_orm::{
ActiveModelTrait,
ColumnTrait,
DbConn,
DbErr,
DeleteResult,
EntityTrait,
PaginatorTrait,
QueryOrder,
Unset,
Unchanged,
ConditionalStatement,
QuerySelect,
JoinType,
RelationTrait,
DbBackend,
Statement,
UpdateResult,
};
use sea_orm::ActiveValue::Set;
use sea_orm::QueryFilter;
use crate::api::doc_info::ListParams;
use crate::entity::{ doc2_doc, doc_info };
use crate::entity::doc_info::Entity;
use crate::service;
fn now() -> chrono::DateTime<FixedOffset> {
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
}
pub struct Query;
impl Query {
pub async fn find_doc_info_by_id(
db: &DbConn,
id: i64
) -> Result<Option<doc_info::Model>, DbErr> {
Entity::find_by_id(id).one(db).await
}
pub async fn find_doc_infos(db: &DbConn) -> Result<Vec<doc_info::Model>, DbErr> {
Entity::find().all(db).await
}
pub async fn find_doc_infos_by_uid(
db: &DbConn,
uid: i64
) -> Result<Vec<doc_info::Model>, DbErr> {
Entity::find().filter(doc_info::Column::Uid.eq(uid)).all(db).await
}
pub async fn find_doc_infos_by_name(
db: &DbConn,
uid: i64,
name: &String,
parent_id: Option<i64>
) -> Result<Vec<doc_info::Model>, DbErr> {
let mut dids = Vec::<i64>::new();
if let Some(pid) = parent_id {
for d2d in doc2_doc::Entity
::find()
.filter(doc2_doc::Column::ParentId.eq(pid))
.all(db).await? {
dids.push(d2d.did);
}
} else {
let doc = Entity::find()
.filter(doc_info::Column::DocName.eq(name.clone()))
.filter(doc_info::Column::Uid.eq(uid))
.all(db).await?;
if doc.len() == 0 {
return Ok(vec![]);
}
assert!(doc.len() > 0);
let d2d = doc2_doc::Entity
::find()
.filter(doc2_doc::Column::Did.eq(doc[0].did))
.all(db).await?;
assert!(d2d.len() <= 1, "Did: {}->{}", doc[0].did, d2d.len());
if d2d.len() > 0 {
for d2d_ in doc2_doc::Entity
::find()
.filter(doc2_doc::Column::ParentId.eq(d2d[0].parent_id))
.all(db).await? {
dids.push(d2d_.did);
}
}
}
Entity::find()
.filter(doc_info::Column::DocName.eq(name.clone()))
.filter(doc_info::Column::Uid.eq(uid))
.filter(doc_info::Column::Did.is_in(dids))
.filter(doc_info::Column::IsDeleted.eq(false))
.all(db).await
}
pub async fn all_descendent_ids(db: &DbConn, doc_ids: &Vec<i64>) -> Result<Vec<i64>, DbErr> {
let mut dids = doc_ids.clone();
let mut i: usize = 0;
loop {
if dids.len() == i {
break;
}
for d in doc2_doc::Entity
::find()
.filter(doc2_doc::Column::ParentId.eq(dids[i]))
.all(db).await? {
dids.push(d.did);
}
i += 1;
}
Ok(dids)
}
pub async fn find_doc_infos_by_params(
db: &DbConn,
params: ListParams
) -> Result<Vec<doc_info::Model>, DbErr> {
// Setup paginator
let mut sql: String =
"
select
a.did,
a.uid,
a.doc_name,
a.location,
a.size,
a.type,
a.created_at,
a.updated_at,
a.is_deleted
from
doc_info as a
".to_owned();
let mut cond: String = format!(" a.uid={} and a.is_deleted=False ", params.uid);
if let Some(kb_id) = params.filter.kb_id {
sql.push_str(
&format!(" inner join kb2_doc on kb2_doc.did = a.did and kb2_doc.kb_id={}", kb_id)
);
}
if let Some(folder_id) = params.filter.folder_id {
sql.push_str(
&format!(" inner join doc2_doc on a.did = doc2_doc.did and doc2_doc.parent_id={}", folder_id)
);
}
// Fetch paginated posts
if let Some(tag_id) = params.filter.tag_id {
let tag = service::tag_info::Query
::find_tag_info_by_id(tag_id, &db).await
.unwrap()
.unwrap();
if tag.folder_id > 0 {
sql.push_str(
&format!(
" inner join doc2_doc on a.did = doc2_doc.did and doc2_doc.parent_id={}",
tag.folder_id
)
);
}
if tag.regx.len() > 0 {
cond.push_str(&format!(" and (type='{}' or doc_name ~ '{}') ", tag.tag_name, tag.regx));
}
}
if let Some(keywords) = params.filter.keywords {
cond.push_str(&format!(" and doc_name like '%{}%'", keywords));
}
if cond.len() > 0 {
sql.push_str(&" where ");
sql.push_str(&cond);
}
let mut orderby = params.sortby.clone();
if orderby.len() == 0 {
orderby = "updated_at desc".to_owned();
}
sql.push_str(&format!(" order by {}", orderby));
let mut page_size: u32 = 30;
if let Some(pg_sz) = params.per_page {
page_size = pg_sz;
}
let mut page: u32 = 0;
if let Some(pg) = params.page {
page = pg;
}
sql.push_str(&format!(" limit {} offset {} ;", page_size, page * page_size));
print!("{}", sql);
Entity::find()
.from_raw_sql(Statement::from_sql_and_values(DbBackend::Postgres, sql, vec![]))
.all(db).await
}
pub async fn find_doc_infos_in_page(
db: &DbConn,
page: u64,
posts_per_page: u64
) -> Result<(Vec<doc_info::Model>, u64), DbErr> {
// Setup paginator
let paginator = Entity::find()
.order_by_asc(doc_info::Column::Did)
.paginate(db, posts_per_page);
let num_pages = paginator.num_pages().await?;
// Fetch paginated posts
paginator.fetch_page(page - 1).await.map(|p| (p, num_pages))
}
}
pub struct Mutation;
impl Mutation {
pub async fn mv_doc_info(db: &DbConn, dest_did: i64, dids: &[i64]) -> Result<(), DbErr> {
for did in dids {
let d = doc2_doc::Entity
::find()
.filter(doc2_doc::Column::Did.eq(did.to_owned()))
.all(db).await?;
let _ = (doc2_doc::ActiveModel {
id: Set(d[0].id),
did: Set(did.to_owned()),
parent_id: Set(dest_did),
}).update(db).await?;
}
Ok(())
}
pub async fn place_doc(
db: &DbConn,
dest_did: i64,
did: i64
) -> Result<doc2_doc::ActiveModel, DbErr> {
(doc2_doc::ActiveModel {
id: Default::default(),
parent_id: Set(dest_did),
did: Set(did),
}).save(db).await
}
pub async fn create_doc_info(
db: &DbConn,
form_data: doc_info::Model
) -> Result<doc_info::ActiveModel, DbErr> {
(doc_info::ActiveModel {
did: Default::default(),
uid: Set(form_data.uid.to_owned()),
doc_name: Set(form_data.doc_name.to_owned()),
size: Set(form_data.size.to_owned()),
r#type: Set(form_data.r#type.to_owned()),
location: Set(form_data.location.to_owned()),
thumbnail_base64: Default::default(),
created_at: Set(form_data.created_at.to_owned()),
updated_at: Set(form_data.updated_at.to_owned()),
is_deleted: Default::default(),
}).save(db).await
}
pub async fn update_doc_info_by_id(
db: &DbConn,
id: i64,
form_data: doc_info::Model
) -> Result<doc_info::Model, DbErr> {
let doc_info: doc_info::ActiveModel = Entity::find_by_id(id)
.one(db).await?
.ok_or(DbErr::Custom("Cannot find.".to_owned()))
.map(Into::into)?;
(doc_info::ActiveModel {
did: doc_info.did,
uid: Set(form_data.uid.to_owned()),
doc_name: Set(form_data.doc_name.to_owned()),
size: Set(form_data.size.to_owned()),
r#type: Set(form_data.r#type.to_owned()),
location: Set(form_data.location.to_owned()),
thumbnail_base64: doc_info.thumbnail_base64,
created_at: doc_info.created_at,
updated_at: Set(now()),
is_deleted: Default::default(),
}).update(db).await
}
pub async fn delete_doc_info(db: &DbConn, doc_ids: &Vec<i64>) -> Result<UpdateResult, DbErr> {
let mut dids = doc_ids.clone();
let mut i: usize = 0;
loop {
if dids.len() == i {
break;
}
let mut doc: doc_info::ActiveModel = Entity::find_by_id(dids[i])
.one(db).await?
.ok_or(DbErr::Custom(format!("Can't find doc:{}", dids[i])))
.map(Into::into)?;
doc.updated_at = Set(now());
doc.is_deleted = Set(true);
let _ = doc.update(db).await?;
for d in doc2_doc::Entity
::find()
.filter(doc2_doc::Column::ParentId.eq(dids[i]))
.all(db).await? {
dids.push(d.did);
}
let _ = doc2_doc::Entity
::delete_many()
.filter(doc2_doc::Column::ParentId.eq(dids[i]))
.exec(db).await?;
let _ = doc2_doc::Entity
::delete_many()
.filter(doc2_doc::Column::Did.eq(dids[i]))
.exec(db).await?;
i += 1;
}
crate::service::kb_info::Mutation::remove_docs(&db, dids, None).await
}
pub async fn rename(db: &DbConn, doc_id: i64, name: &String) -> Result<doc_info::Model, DbErr> {
let mut doc: doc_info::ActiveModel = Entity::find_by_id(doc_id)
.one(db).await?
.ok_or(DbErr::Custom(format!("Can't find doc:{}", doc_id)))
.map(Into::into)?;
doc.updated_at = Set(now());
doc.doc_name = Set(name.clone());
doc.update(db).await
}
pub async fn delete_all_doc_infos(db: &DbConn) -> Result<DeleteResult, DbErr> {
Entity::delete_many().exec(db).await
}
}

View File

@ -1,168 +0,0 @@
use chrono::{ Local, FixedOffset, Utc };
use migration::Expr;
use sea_orm::{
ActiveModelTrait,
ColumnTrait,
DbConn,
DbErr,
DeleteResult,
EntityTrait,
PaginatorTrait,
QueryFilter,
QueryOrder,
UpdateResult,
};
use sea_orm::ActiveValue::Set;
use crate::entity::kb_info;
use crate::entity::kb2_doc;
use crate::entity::kb_info::Entity;
fn now() -> chrono::DateTime<FixedOffset> {
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
}
pub struct Query;
impl Query {
pub async fn find_kb_info_by_id(db: &DbConn, id: i64) -> Result<Option<kb_info::Model>, DbErr> {
Entity::find_by_id(id).one(db).await
}
pub async fn find_kb_infos(db: &DbConn) -> Result<Vec<kb_info::Model>, DbErr> {
Entity::find().all(db).await
}
pub async fn find_kb_infos_by_uid(db: &DbConn, uid: i64) -> Result<Vec<kb_info::Model>, DbErr> {
Entity::find().filter(kb_info::Column::Uid.eq(uid)).all(db).await
}
pub async fn find_kb_infos_by_name(
db: &DbConn,
name: String
) -> Result<Vec<kb_info::Model>, DbErr> {
Entity::find().filter(kb_info::Column::KbName.eq(name)).all(db).await
}
pub async fn find_kb_by_docs(
db: &DbConn,
doc_ids: Vec<i64>
) -> Result<Vec<kb_info::Model>, DbErr> {
let mut kbids = Vec::<i64>::new();
for k in kb2_doc::Entity
::find()
.filter(kb2_doc::Column::Did.is_in(doc_ids))
.all(db).await? {
kbids.push(k.kb_id);
}
Entity::find().filter(kb_info::Column::KbId.is_in(kbids)).all(db).await
}
pub async fn find_kb_infos_in_page(
db: &DbConn,
page: u64,
posts_per_page: u64
) -> Result<(Vec<kb_info::Model>, u64), DbErr> {
// Setup paginator
let paginator = Entity::find()
.order_by_asc(kb_info::Column::KbId)
.paginate(db, posts_per_page);
let num_pages = paginator.num_pages().await?;
// Fetch paginated posts
paginator.fetch_page(page - 1).await.map(|p| (p, num_pages))
}
}
pub struct Mutation;
impl Mutation {
pub async fn create_kb_info(
db: &DbConn,
form_data: kb_info::Model
) -> Result<kb_info::ActiveModel, DbErr> {
(kb_info::ActiveModel {
kb_id: Default::default(),
uid: Set(form_data.uid.to_owned()),
kb_name: Set(form_data.kb_name.to_owned()),
icon: Set(form_data.icon.to_owned()),
created_at: Set(now()),
updated_at: Set(now()),
is_deleted: Default::default(),
}).save(db).await
}
pub async fn add_docs(db: &DbConn, kb_id: i64, doc_ids: Vec<i64>) -> Result<(), DbErr> {
for did in doc_ids {
let res = kb2_doc::Entity
::find()
.filter(kb2_doc::Column::KbId.eq(kb_id))
.filter(kb2_doc::Column::Did.eq(did))
.all(db).await?;
if res.len() > 0 {
continue;
}
let _ = (kb2_doc::ActiveModel {
id: Default::default(),
kb_id: Set(kb_id),
did: Set(did),
kb_progress: Set(0.0),
kb_progress_msg: Set("".to_owned()),
updated_at: Set(now()),
is_deleted: Default::default(),
}).save(db).await?;
}
Ok(())
}
pub async fn remove_docs(
db: &DbConn,
doc_ids: Vec<i64>,
kb_id: Option<i64>
) -> Result<UpdateResult, DbErr> {
let update = kb2_doc::Entity
::update_many()
.col_expr(kb2_doc::Column::IsDeleted, Expr::value(true))
.col_expr(kb2_doc::Column::KbProgress, Expr::value(0))
.col_expr(kb2_doc::Column::KbProgressMsg, Expr::value(""))
.filter(kb2_doc::Column::Did.is_in(doc_ids));
if let Some(kbid) = kb_id {
update.filter(kb2_doc::Column::KbId.eq(kbid)).exec(db).await
} else {
update.exec(db).await
}
}
pub async fn update_kb_info_by_id(
db: &DbConn,
id: i64,
form_data: kb_info::Model
) -> Result<kb_info::Model, DbErr> {
let kb_info: kb_info::ActiveModel = Entity::find_by_id(id)
.one(db).await?
.ok_or(DbErr::Custom("Cannot find.".to_owned()))
.map(Into::into)?;
(kb_info::ActiveModel {
kb_id: kb_info.kb_id,
uid: kb_info.uid,
kb_name: Set(form_data.kb_name.to_owned()),
icon: Set(form_data.icon.to_owned()),
created_at: kb_info.created_at,
updated_at: Set(now()),
is_deleted: Default::default(),
}).update(db).await
}
pub async fn delete_kb_info(db: &DbConn, kb_id: i64) -> Result<DeleteResult, DbErr> {
let kb: kb_info::ActiveModel = Entity::find_by_id(kb_id)
.one(db).await?
.ok_or(DbErr::Custom("Cannot find.".to_owned()))
.map(Into::into)?;
kb.delete(db).await
}
pub async fn delete_all_kb_infos(db: &DbConn) -> Result<DeleteResult, DbErr> {
Entity::delete_many().exec(db).await
}
}

View File

@ -1,5 +0,0 @@
pub(crate) mod dialog_info;
pub(crate) mod tag_info;
pub(crate) mod kb_info;
pub(crate) mod doc_info;
pub(crate) mod user_info;

View File

@ -1,108 +0,0 @@
use chrono::{ FixedOffset, Utc };
use sea_orm::{
ActiveModelTrait,
DbConn,
DbErr,
DeleteResult,
EntityTrait,
PaginatorTrait,
QueryOrder,
ColumnTrait,
QueryFilter,
};
use sea_orm::ActiveValue::{ Set, NotSet };
use crate::entity::tag_info;
use crate::entity::tag_info::Entity;
fn now() -> chrono::DateTime<FixedOffset> {
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
}
pub struct Query;
impl Query {
pub async fn find_tag_info_by_id(
id: i64,
db: &DbConn
) -> Result<Option<tag_info::Model>, DbErr> {
Entity::find_by_id(id).one(db).await
}
pub async fn find_tags_by_uid(uid: i64, db: &DbConn) -> Result<Vec<tag_info::Model>, DbErr> {
Entity::find().filter(tag_info::Column::Uid.eq(uid)).all(db).await
}
pub async fn find_tag_infos_in_page(
db: &DbConn,
page: u64,
posts_per_page: u64
) -> Result<(Vec<tag_info::Model>, u64), DbErr> {
// Setup paginator
let paginator = Entity::find()
.order_by_asc(tag_info::Column::Tid)
.paginate(db, posts_per_page);
let num_pages = paginator.num_pages().await?;
// Fetch paginated posts
paginator.fetch_page(page - 1).await.map(|p| (p, num_pages))
}
}
pub struct Mutation;
impl Mutation {
pub async fn create_tag(
db: &DbConn,
form_data: tag_info::Model
) -> Result<tag_info::ActiveModel, DbErr> {
(tag_info::ActiveModel {
tid: Default::default(),
uid: Set(form_data.uid.to_owned()),
tag_name: Set(form_data.tag_name.to_owned()),
regx: Set(form_data.regx.to_owned()),
color: Set(form_data.color.to_owned()),
icon: Set(form_data.icon.to_owned()),
folder_id: match form_data.folder_id {
0 => NotSet,
_ => Set(form_data.folder_id.to_owned()),
},
created_at: Set(now()),
updated_at: Set(now()),
}).save(db).await
}
pub async fn update_tag_by_id(
db: &DbConn,
id: i64,
form_data: tag_info::Model
) -> Result<tag_info::Model, DbErr> {
let tag: tag_info::ActiveModel = Entity::find_by_id(id)
.one(db).await?
.ok_or(DbErr::Custom("Cannot find tag.".to_owned()))
.map(Into::into)?;
(tag_info::ActiveModel {
tid: tag.tid,
uid: tag.uid,
tag_name: Set(form_data.tag_name.to_owned()),
regx: Set(form_data.regx.to_owned()),
color: Set(form_data.color.to_owned()),
icon: Set(form_data.icon.to_owned()),
folder_id: Set(form_data.folder_id.to_owned()),
created_at: Default::default(),
updated_at: Set(now()),
}).update(db).await
}
pub async fn delete_tag(db: &DbConn, tid: i64) -> Result<DeleteResult, DbErr> {
let tag: tag_info::ActiveModel = Entity::find_by_id(tid)
.one(db).await?
.ok_or(DbErr::Custom("Cannot find tag.".to_owned()))
.map(Into::into)?;
tag.delete(db).await
}
pub async fn delete_all_tags(db: &DbConn) -> Result<DeleteResult, DbErr> {
Entity::delete_many().exec(db).await
}
}

View File

@ -1,131 +0,0 @@
use chrono::{ FixedOffset, Utc };
use migration::Expr;
use sea_orm::{
ActiveModelTrait,
ColumnTrait,
DbConn,
DbErr,
DeleteResult,
EntityTrait,
PaginatorTrait,
QueryFilter,
QueryOrder,
UpdateResult,
};
use sea_orm::ActiveValue::Set;
use crate::entity::user_info;
use crate::entity::user_info::Entity;
fn now() -> chrono::DateTime<FixedOffset> {
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
}
pub struct Query;
impl Query {
pub async fn find_user_info_by_id(
db: &DbConn,
id: i64
) -> Result<Option<user_info::Model>, DbErr> {
Entity::find_by_id(id).one(db).await
}
pub async fn login(
db: &DbConn,
email: &str,
password: &str
) -> Result<Option<user_info::Model>, DbErr> {
Entity::find()
.filter(user_info::Column::Email.eq(email))
.filter(user_info::Column::Password.eq(password))
.one(db).await
}
pub async fn find_user_infos(
db: &DbConn,
email: &String
) -> Result<Option<user_info::Model>, DbErr> {
Entity::find().filter(user_info::Column::Email.eq(email)).one(db).await
}
pub async fn find_user_infos_in_page(
db: &DbConn,
page: u64,
posts_per_page: u64
) -> Result<(Vec<user_info::Model>, u64), DbErr> {
// Setup paginator
let paginator = Entity::find()
.order_by_asc(user_info::Column::Uid)
.paginate(db, posts_per_page);
let num_pages = paginator.num_pages().await?;
// Fetch paginated posts
paginator.fetch_page(page - 1).await.map(|p| (p, num_pages))
}
}
pub struct Mutation;
impl Mutation {
pub async fn create_user(
db: &DbConn,
form_data: &user_info::Model
) -> Result<user_info::ActiveModel, DbErr> {
(user_info::ActiveModel {
uid: Default::default(),
email: Set(form_data.email.to_owned()),
nickname: Set(form_data.nickname.to_owned()),
avatar_base64: Set(form_data.avatar_base64.to_owned()),
color_scheme: Set(form_data.color_scheme.to_owned()),
list_style: Set(form_data.list_style.to_owned()),
language: Set(form_data.language.to_owned()),
password: Set(form_data.password.to_owned()),
last_login_at: Set(now()),
created_at: Set(now()),
updated_at: Set(now()),
}).save(db).await
}
pub async fn update_user_by_id(
db: &DbConn,
form_data: &user_info::Model
) -> Result<user_info::Model, DbErr> {
let usr: user_info::ActiveModel = Entity::find_by_id(form_data.uid)
.one(db).await?
.ok_or(DbErr::Custom("Cannot find user.".to_owned()))
.map(Into::into)?;
(user_info::ActiveModel {
uid: Set(form_data.uid),
email: Set(form_data.email.to_owned()),
nickname: Set(form_data.nickname.to_owned()),
avatar_base64: Set(form_data.avatar_base64.to_owned()),
color_scheme: Set(form_data.color_scheme.to_owned()),
list_style: Set(form_data.list_style.to_owned()),
language: Set(form_data.language.to_owned()),
password: Set(form_data.password.to_owned()),
updated_at: Set(now()),
last_login_at: usr.last_login_at,
created_at: usr.created_at,
}).update(db).await
}
pub async fn update_login_status(uid: i64, db: &DbConn) -> Result<UpdateResult, DbErr> {
Entity::update_many()
.col_expr(user_info::Column::LastLoginAt, Expr::value(now()))
.filter(user_info::Column::Uid.eq(uid))
.exec(db).await
}
pub async fn delete_user(db: &DbConn, tid: i64) -> Result<DeleteResult, DbErr> {
let tag: user_info::ActiveModel = Entity::find_by_id(tid)
.one(db).await?
.ok_or(DbErr::Custom("Cannot find tag.".to_owned()))
.map(Into::into)?;
tag.delete(db).await
}
pub async fn delete_all(db: &DbConn) -> Result<DeleteResult, DbErr> {
Entity::delete_many().exec(db).await
}
}

View File

@ -1,97 +0,0 @@
use std::io::{Cursor, Write};
use std::time::{Duration, Instant};
use actix_rt::time::interval;
use actix_web::{HttpRequest, HttpResponse, rt, web};
use actix_web::web::Buf;
use actix_ws::Message;
use futures_util::{future, StreamExt};
use futures_util::future::Either;
use uuid::Uuid;
use crate::api::doc_info::_upload_file;
use crate::AppState;
use crate::errors::AppError;
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// How long before lack of client response causes a timeout.
const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
pub async fn upload_file_ws(req: HttpRequest, stream: web::Payload, data: web::Data<AppState>) -> Result<HttpResponse, AppError> {
let (res, session, msg_stream) = actix_ws::handle(&req, stream)?;
// spawn websocket handler (and don't await it) so that the response is returned immediately
rt::spawn(upload_file_handler(data, session, msg_stream));
Ok(res)
}
async fn upload_file_handler(
data: web::Data<AppState>,
mut session: actix_ws::Session,
mut msg_stream: actix_ws::MessageStream,
) {
let mut bytes = Cursor::new(vec![]);
let mut last_heartbeat = Instant::now();
let mut interval = interval(HEARTBEAT_INTERVAL);
let reason = loop {
let tick = interval.tick();
tokio::pin!(tick);
match future::select(msg_stream.next(), tick).await {
// received message from WebSocket client
Either::Left((Some(Ok(msg)), _)) => {
match msg {
Message::Text(text) => {
session.text(text).await.unwrap();
}
Message::Binary(bin) => {
let mut pos = 0; // bytes of this frame copied into the upload buffer so far
while pos < bin.len() {
let bytes_written = bytes.write(&bin[pos..]).unwrap();
pos += bytes_written
};
session.binary(bin).await.unwrap();
}
Message::Close(reason) => {
break reason;
}
Message::Ping(bytes) => {
last_heartbeat = Instant::now();
let _ = session.pong(&bytes).await;
}
Message::Pong(_) => {
last_heartbeat = Instant::now();
}
Message::Continuation(_) | Message::Nop => {}
};
}
Either::Left((Some(Err(_)), _)) => {
break None;
}
Either::Left((None, _)) => break None,
Either::Right((_inst, _)) => {
if Instant::now().duration_since(last_heartbeat) > CLIENT_TIMEOUT {
break None;
}
let _ = session.ping(b"").await;
}
}
};
let _ = session.close(reason).await;
if !bytes.has_remaining() {
return;
}
let uid = bytes.get_i64();
let did = bytes.get_i64();
_upload_file(uid, did, &Uuid::new_v4().to_string(), &bytes.into_inner(), &data).await.unwrap();
}
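Editor's note: the removed handler above buffers every binary WebSocket frame and, once the socket closes, reads two big-endian i64 values (uid, then did) from the front of the buffer before handing the data to _upload_file. A minimal client-side sketch of that framing, in Python, assuming a local server exposing the old /ws-upload-doc route and the third-party websockets package (host, port and file name are illustrative only):
import asyncio
import struct
import websockets  # third-party client library, used here only for illustration

async def upload(uid: int, did: int, path: str) -> None:
    # Prefix the payload with the big-endian uid and did, mirroring the two
    # get_i64() calls in the Rust handler, then append the raw file bytes.
    header = struct.pack(">qq", uid, did)
    with open(path, "rb") as f:
        payload = header + f.read()
    async with websockets.connect("ws://127.0.0.1:8000/ws-upload-doc") as ws:
        await ws.send(payload)

# asyncio.run(upload(1, 2, "report.pdf"))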

View File

@ -1 +0,0 @@
pub mod doc_info;

0
web_server/__init__.py Normal file
View File

147
web_server/apps/__init__.py Normal file
View File

@ -0,0 +1,147 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from flask import Blueprint, Flask, request
from werkzeug.wrappers.request import Request
from flask_cors import CORS
from web_server.db import StatusEnum
from web_server.db.services import UserService
from web_server.utils import CustomJSONEncoder
from flask_session import Session
from flask_login import LoginManager
from web_server.settings import RetCode, SECRET_KEY, stat_logger
from web_server.hook import HookManager
from web_server.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters
from web_server.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger
from web_server.utils.api_utils import get_json_result, server_error_response
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
__all__ = ['app']
logger = logging.getLogger('flask.app')
for h in access_logger.handlers:
logger.addHandler(h)
Request.json = property(lambda self: self.get_json(force=True, silent=True))
app = Flask(__name__)
CORS(app, supports_credentials=True, max_age=2592000)
app.url_map.strict_slashes = False
app.json_encoder = CustomJSONEncoder
app.errorhandler(Exception)(server_error_response)
## convenience for dev and debug
#app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False
app.config["SESSION_TYPE"] = "filesystem"
app.config['MAX_CONTENT_LENGTH'] = 64 * 1024 * 1024
Session(app)
login_manager = LoginManager()
login_manager.init_app(app)
def search_pages_path(pages_dir):
return [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')]
def register_page(page_path):
page_name = page_path.stem.rstrip('_app')
module_name = '.'.join(page_path.parts[page_path.parts.index('web_server'):-1] + (page_name, ))
spec = spec_from_file_location(module_name, page_path)
page = module_from_spec(spec)
page.app = app
page.manager = Blueprint(page_name, module_name)
sys.modules[module_name] = page
spec.loader.exec_module(page)
page_name = getattr(page, 'page_name', page_name)
url_prefix = f'/{API_VERSION}/{page_name}'
app.register_blueprint(page.manager, url_prefix=url_prefix)
return url_prefix
pages_dir = [
Path(__file__).parent,
Path(__file__).parent.parent / 'web_server' / 'apps',
]
client_urls_prefix = [
register_page(path)
for dir in pages_dir
for path in search_pages_path(dir)
]
def client_authentication_before_request():
result = HookManager.client_authentication(ClientAuthenticationParameters(
request.full_path, request.headers,
request.form, request.data, request.json,
))
if result.code != RetCode.SUCCESS:
return get_json_result(result.code, result.message)
def site_authentication_before_request():
for url_prefix in client_urls_prefix:
if request.path.startswith(url_prefix):
return
result = HookManager.site_authentication(AuthenticationParameters(
request.headers.get('site_signature'),
request.json,
))
if result.code != RetCode.SUCCESS:
return get_json_result(result.code, result.message)
@app.before_request
def authentication_before_request():
if CLIENT_AUTHENTICATION:
return client_authentication_before_request()
if SITE_AUTHENTICATION:
return site_authentication_before_request()
@login_manager.request_loader
def load_user(web_request):
jwt = Serializer(secret_key=SECRET_KEY)
authorization = web_request.headers.get("Authorization")
if authorization:
try:
access_token = str(jwt.loads(authorization))
user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value)
if user:
return user[0]
else:
return None
except Exception as e:
stat_logger.exception(e)
return None
else:
return None
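For orientation, a minimal sketch of a page module that the search_pages_path/register_page machinery above would pick up. register_page injects `app` and `manager` (a Blueprint) into the module before executing it, so a hypothetical web_server/apps/ping_app.py only has to declare routes on `manager`; the file name and route below are illustrative and not part of this commit:
# web_server/apps/ping_app.py  (hypothetical example, not in this commit)
from web_server.utils.api_utils import get_json_result

# `manager` is the Blueprint injected by register_page(); the page name is the
# file stem minus the trailing "_app", so this route ends up under
# /<API_VERSION>/ping/status.
@manager.route('/status', methods=['GET'])
def status():
    return get_json_result(data={"alive": True})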

View File

@ -0,0 +1,235 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pathlib
from elasticsearch_dsl import Q
from flask import request
from flask_login import login_required, current_user
from rag.nlp import search
from rag.utils import ELASTICSEARCH
from web_server.db.services import duplicate_name
from web_server.db.services.kb_service import KnowledgebaseService
from web_server.db.services.user_service import TenantService
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
from web_server.utils import get_uuid, get_format_time
from web_server.db import StatusEnum, FileType
from web_server.db.services.document_service import DocumentService
from web_server.settings import RetCode
from web_server.utils.api_utils import get_json_result
from rag.utils.minio_conn import MINIO
from web_server.utils.file_utils import filename_type
@manager.route('/upload', methods=['POST'])
@login_required
@validate_request("kb_id")
def upload():
kb_id = request.form.get("kb_id")
if not kb_id:
return get_json_result(
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
if 'file' not in request.files:
return get_json_result(
data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
file = request.files['file']
if file.filename == '':
return get_json_result(
data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
try:
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
return get_data_error_result(
retmsg="Can't find this knowledgebase!")
filename = duplicate_name(
DocumentService.query,
name=file.filename,
kb_id=kb.id)
location = filename
while MINIO.obj_exist(kb_id, location):
location += "_"
blob = request.files['file'].read()
MINIO.put(kb_id, location, blob)
doc = DocumentService.insert({
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": kb.parser_id,
"created_by": current_user.id,
"type": filename_type(filename),
"name": filename,
"location": location,
"size": len(blob)
})
return get_json_result(data=doc.to_json())
except Exception as e:
return server_error_response(e)
@manager.route('/create', methods=['POST'])
@login_required
@validate_request("name", "kb_id")
def create():
req = request.json
kb_id = req["kb_id"]
if not kb_id:
return get_json_result(
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
try:
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
return get_data_error_result(
retmsg="Can't find this knowledgebase!")
if DocumentService.query(name=req["name"], kb_id=kb_id):
return get_data_error_result(
retmsg="Duplicated document name in the same knowledgebase.")
doc = DocumentService.insert({
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": kb.parser_id,
"created_by": current_user.id,
"type": FileType.VIRTUAL,
"name": req["name"],
"location": "",
"size": 0
})
return get_json_result(data=doc.to_json())
except Exception as e:
return server_error_response(e)
@manager.route('/list', methods=['GET'])
@login_required
def list():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
keywords = request.args.get("keywords", "")
page_number = request.args.get("page", 1)
items_per_page = request.args.get("page_size", 15)
orderby = request.args.get("orderby", "create_time")
desc = request.args.get("desc", True)
try:
docs = DocumentService.get_by_kb_id(
kb_id, page_number, items_per_page, orderby, desc, keywords)
return get_json_result(data=docs)
except Exception as e:
return server_error_response(e)
@manager.route('/change_status', methods=['POST'])
@login_required
@validate_request("doc_id", "status")
def change_status():
req = request.json
if str(req["status"]) not in ["0", "1"]:
return get_json_result(
data=False,
retmsg='"Status" must be either 0 or 1!',
retcode=RetCode.ARGUMENT_ERROR)
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(retmsg="Document not found!")
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
if not e:
return get_data_error_result(
retmsg="Can't find this knowledgebase!")
if not DocumentService.update_by_id(
req["doc_id"], {"status": str(req["status"])}):
return get_data_error_result(
retmsg="Database error (Document update)!")
if str(req["status"]) == "0":
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
scripts="""
if(ctx._source.kb_id.contains('%s'))
ctx._source.kb_id.remove(
ctx._source.kb_id.indexOf('%s')
);
""" % (doc.kb_id, doc.kb_id),
idxnm=search.index_name(
kb.tenant_id)
)
else:
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
scripts="""
if(!ctx._source.kb_id.contains('%s'))
ctx._source.kb_id.add('%s');
""" % (doc.kb_id, doc.kb_id),
idxnm=search.index_name(
kb.tenant_id)
)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@manager.route('/rm', methods=['POST'])
@login_required
@validate_request("doc_id")
def rm():
req = request.json
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(retmsg="Document not found!")
if not DocumentService.delete_by_id(req["doc_id"]):
return get_data_error_result(
retmsg="Database error (Document removal)!")
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
MINIO.rm(kb.id, doc.location)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@manager.route('/rename', methods=['POST'])
@login_required
@validate_request("doc_id", "name", "old_name")
def rename():
req = request.json
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
req["old_name"].lower()).suffix:
return get_json_result(
data=False,
retmsg="The file extension can't be changed.",
retcode=RetCode.ARGUMENT_ERROR)
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(retmsg="Document not found!")
if DocumentService.query(name=req["name"], kb_id=doc.kb_id):
return get_data_error_result(
retmsg="Duplicated document name in the same knowledgebase.")
if not DocumentService.update_by_id(
req["doc_id"], {"name": req["name"]}):
return get_data_error_result(
retmsg="Database error (Document rename)!")
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
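A hedged usage sketch of the upload route above, assuming this module is registered as the document page, that the API prefix is /v1 and the server listens locally, and that the caller already holds the Authorization credential returned by the user login endpoint (all of these are assumptions, not taken from this diff):
import requests

BASE = "http://127.0.0.1:9380/v1"          # host, port and prefix assumed
headers = {"Authorization": "<credential returned by /user/login>"}

# /upload expects a multipart form with a kb_id field and a file part.
with open("report.pdf", "rb") as f:
    r = requests.post(f"{BASE}/document/upload",
                      headers=headers,
                      data={"kb_id": "<knowledgebase id>"},
                      files={"file": f})
print(r.json())                            # the created document record on success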

102
web_server/apps/kb_app.py Normal file
View File

@ -0,0 +1,102 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request
from flask_login import login_required, current_user
from web_server.db.services import duplicate_name
from web_server.db.services.user_service import TenantService, UserTenantService
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
from web_server.utils import get_uuid, get_format_time
from web_server.db import StatusEnum, UserTenantRole
from web_server.db.services.kb_service import KnowledgebaseService
from web_server.db.db_models import Knowledgebase
from web_server.settings import stat_logger, RetCode
from web_server.utils.api_utils import get_json_result
@manager.route('/create', methods=['post'])
@login_required
@validate_request("name", "description", "permission", "embd_id", "parser_id")
def create():
req = request.json
req["name"] = req["name"].strip()
req["name"] = duplicate_name(KnowledgebaseService.query, name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)
try:
req["id"] = get_uuid()
req["tenant_id"] = current_user.id
req["created_by"] = current_user.id
if not KnowledgebaseService.save(**req): return get_data_error_result()
return get_json_result(data={"kb_id": req["id"]})
except Exception as e:
return server_error_response(e)
@manager.route('/update', methods=['post'])
@login_required
@validate_request("kb_id", "name", "description", "permission", "embd_id", "parser_id")
def update():
req = request.json
req["name"] = req["name"].strip()
try:
if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
return get_json_result(data=False, retmsg='Only the owner of the knowledgebase is authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
if not e: return get_data_error_result(retmsg="Can't find this knowledgebase!")
if req["name"].lower() != kb.name.lower() \
and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value))>1:
return get_data_error_result(retmsg="Duplicated knowledgebase name.")
del req["kb_id"]
if not KnowledgebaseService.update_by_id(kb.id, req): return get_data_error_result()
e, kb = KnowledgebaseService.get_by_id(kb.id)
if not e: return get_data_error_result(retmsg="Database error (Knowledgebase rename)!")
return get_json_result(data=kb.to_json())
except Exception as e:
return server_error_response(e)
@manager.route('/list', methods=['GET'])
@login_required
def list():
page_number = request.args.get("page", 1)
items_per_page = request.args.get("page_size", 15)
orderby = request.args.get("orderby", "create_time")
desc = request.args.get("desc", True)
try:
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
kbs = KnowledgebaseService.get_by_tenant_ids([m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc)
return get_json_result(data=kbs)
except Exception as e:
return server_error_response(e)
@manager.route('/rm', methods=['post'])
@login_required
@validate_request("kb_id")
def rm():
req = request.json
try:
if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
return get_json_result(data=False, retmsg='Only the owner of the knowledgebase is authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
if not KnowledgebaseService.update_by_id(req["kb_id"], {"status": StatusEnum.IN_VALID.value}): return get_data_error_result(retmsg="Database error (Knowledgebase removal)!")
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
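A hedged sketch of creating a knowledgebase through the /create route above; the handler validates name, description, permission, embd_id and parser_id, and returns the new kb_id in the data field. Host, port, prefix and the placeholder ids are assumptions for illustration:
import requests

BASE = "http://127.0.0.1:9380/v1"          # host, port and prefix assumed
headers = {"Authorization": "<credential returned by /user/login>"}

payload = {
    "name": "product-docs",
    "description": "Product manuals and FAQs",
    "permission": "me",                     # TenantPermission.ME
    "embd_id": "<embedding model id>",
    "parser_id": "<parser id>",
}
r = requests.post(f"{BASE}/kb/create", json=payload, headers=headers)
print(r.json())                             # contains {"kb_id": "..."} on success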

226
web_server/apps/user_app.py Normal file
View File

@ -0,0 +1,226 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request, session, redirect, url_for
from werkzeug.security import generate_password_hash, check_password_hash
from flask_login import login_required, current_user, login_user, logout_user
from web_server.utils.api_utils import server_error_response, validate_request
from web_server.utils import get_uuid, get_format_time, decrypt, download_img
from web_server.db import UserTenantRole
from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
from web_server.db.services.user_service import UserService, TenantService, UserTenantService
from web_server.settings import stat_logger
from web_server.utils.api_utils import get_json_result, cors_reponse
@manager.route('/login', methods=['POST', 'GET'])
def login():
userinfo = None
login_channel = "password"
if session.get("access_token"):
login_channel = session["access_token_from"]
if session["access_token_from"] == "github":
userinfo = user_info_from_github(session["access_token"])
elif not request.json:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
retmsg='Unauthorized!')
email = request.json.get('email') if not userinfo else userinfo["email"]
users = UserService.query(email=email)
if not users:
if request.json is not None:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
avatar = ""
try:
avatar = download_img(userinfo["avatar_url"])
except Exception as e:
stat_logger.exception(e)
try:
users = user_register({
"access_token": session["access_token"],
"email": userinfo["email"],
"avatar": avatar,
"nickname": userinfo["login"],
"login_channel": login_channel,
"last_login_time": get_format_time(),
"is_superuser": False,
})
if not users: raise Exception('Failed to register user.')
if len(users) > 1: raise Exception('Same e-mail already exists!')
user = users[0]
login_user(user)
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
except Exception as e:
stat_logger.exception(e)
return server_error_response(e)
elif not request.json:
login_user(users[0])
return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!")
password = request.json.get('password')
try:
password = decrypt(password)
except:
return get_json_result(data=False, retcode=RetCode.SERVER_ERROR, retmsg='Failed to decrypt password')
user = UserService.query_user(email, password)
if user:
response_data = user.to_json()
user.access_token = get_uuid()
login_user(user)
user.save()
msg = "Welcome back!"
return cors_reponse(data=response_data, auth=user.get_id(), retmsg=msg)
else:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Email and Password do not match!')
@manager.route('/github_callback', methods=['GET'])
def github_callback():
try:
import requests
res = requests.post(GITHUB_OAUTH.get("url"), data={
"client_id": GITHUB_OAUTH.get("client_id"),
"client_secret": GITHUB_OAUTH.get("secret_key"),
"code": request.args.get('code')
},headers={"Accept": "application/json"})
res = res.json()
if "error" in res:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
retmsg=res["error_description"])
if "user:email" not in res["scope"].split(","):
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
session["access_token"] = res["access_token"]
session["access_token_from"] = "github"
return redirect(url_for("user.login"), code=307)
except Exception as e:
stat_logger.exception(e)
return server_error_response(e)
def user_info_from_github(access_token):
import requests
headers = {"Accept": "application/json", 'Authorization': f"token {access_token}"}
res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
user_info = res.json()
email_info = requests.get(f"https://api.github.com/user/emails?access_token={access_token}", headers=headers).json()
user_info["email"] = next((email for email in email_info if email['primary'] == True), None)["email"]
return user_info
@manager.route("/logout", methods=['GET'])
@login_required
def log_out():
current_user.access_token = ""
current_user.save()
logout_user()
return get_json_result(data=True)
@manager.route("/setting", methods=["POST"])
@login_required
def setting_user():
update_dict = {}
request_data = request.json
if request_data.get("password"):
new_password = request_data.get("new_password")
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')
if new_password: update_dict["password"] = generate_password_hash(decrypt(new_password))
for k in request_data.keys():
if k in ["password", "new_password"]:continue
update_dict[k] = request_data[k]
try:
UserService.update_by_id(current_user.id, update_dict)
return get_json_result(data=True)
except Exception as e:
stat_logger.exception(e)
return get_json_result(data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)
@manager.route("/info", methods=["GET"])
@login_required
def user_info():
return get_json_result(data=current_user.to_dict())
def user_register(user):
user_id = get_uuid()
user["id"] = user_id
tenant = {
"id": user_id,
"name": user["nickname"] + "s Kingdom",
"llm_id": CHAT_MDL,
"embd_id": EMBEDDING_MDL,
"asr_id": ASR_MDL,
"parser_ids": PARSERS,
"img2txt_id": IMAGE2TEXT_MDL
}
usr_tenant = {
"tenant_id": user_id,
"user_id": user_id,
"invited_by": user_id,
"role": UserTenantRole.OWNER
}
if not UserService.save(**user):return
TenantService.save(**tenant)
UserTenantService.save(**usr_tenant)
return UserService.query(email=user["email"])
@manager.route("/register", methods=["POST"])
@validate_request("nickname", "email", "password")
def user_add():
req = request.json
if UserService.query(email=req["email"]):
return get_json_result(data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
user_dict = {
"access_token": get_uuid(),
"email": req["email"],
"nickname": req["nickname"],
"password": decrypt(req["password"]),
"login_channel": "password",
"last_login_time": get_format_time(),
"is_superuser": False,
}
try:
users = user_register(user_dict)
if not users: raise Exception('Failed to register user.')
if len(users) > 1: raise Exception('Same e-mail already exists!')
user = users[0]
login_user(user)
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
except Exception as e:
stat_logger.exception(e)
return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
@manager.route("/tenant_info", methods=["GET"])
@login_required
def tenant_info():
try:
tenants = TenantService.get_by_user_id(current_user.id)
return get_json_result(data=tenants)
except Exception as e:
return server_error_response(e)
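A hedged sketch of registering a user against the /register route above. Both register and login pass the submitted password through web_server.utils.decrypt, so the client is expected to send it in encrypted form; the placeholder stands in for that ciphertext, and host, port and prefix are assumptions:
import requests

BASE = "http://127.0.0.1:9380/v1"          # host, port and prefix assumed

r = requests.post(f"{BASE}/user/register", json={
    "nickname": "alice",
    "email": "alice@example.com",
    "password": "<password encrypted the way decrypt() expects>",
})
print(r.json())
# On success the user is logged in and an auth credential is returned via
# cors_reponse(); later requests send it back in the Authorization header,
# which load_user() in web_server/apps/__init__.py resolves to a user.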

54
web_server/db/__init__.py Normal file
View File

@ -0,0 +1,54 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from enum import Enum
from enum import IntEnum
from strenum import StrEnum
class StatusEnum(Enum):
VALID = "1"
IN_VALID = "0"
class UserTenantRole(StrEnum):
OWNER = 'owner'
ADMIN = 'admin'
NORMAL = 'normal'
class TenantPermission(StrEnum):
ME = 'me'
TEAM = 'team'
class SerializedType(IntEnum):
PICKLE = 1
JSON = 2
class FileType(StrEnum):
PDF = 'pdf'
DOC = 'doc'
VISUAL = 'visual'
AURAL = 'aural'
VIRTUAL = 'virtual'
class LLMType(StrEnum):
CHAT = 'chat'
EMBEDDING = 'embedding'
SPEECH2TEXT = 'speech2text'
IMAGE2TEXT = 'image2text'
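A small illustration of how the enums above behave: StrEnum members compare equal to their string values, which is why the services can store FileType.PDF directly or filter on UserTenantRole strings, while StatusEnum is a plain Enum and needs .value. Illustrative only:
from web_server.db import FileType, StatusEnum, UserTenantRole

assert FileType.PDF == 'pdf'                # StrEnum members behave like str
assert UserTenantRole.OWNER == 'owner'
assert StatusEnum.VALID.value == '1'        # plain Enum, so .value is required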

616
web_server/db/db_models.py Normal file
View File

@ -0,0 +1,616 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import os
import sys
import typing
import operator
from functools import wraps
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from flask_login import UserMixin
from peewee import (
BigAutoField, BigIntegerField, BooleanField, CharField,
CompositeKey, Insert, IntegerField, TextField, FloatField, DateTimeField,
Field, Model, Metadata
)
from playhouse.pool import PooledMySQLDatabase
from web_server.db import SerializedType
from web_server.settings import DATABASE, stat_logger, SECRET_KEY
from web_server.utils.log_utils import getLogger
from web_server import utils
LOGGER = getLogger()
def singleton(cls, *args, **kw):
instances = {}
def _singleton():
key = str(cls) + str(os.getpid())
if key not in instances:
instances[key] = cls(*args, **kw)
return instances[key]
return _singleton
CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_access", "write_access"}
class LongTextField(TextField):
field_type = 'LONGTEXT'
class JSONField(LongTextField):
default_value = {}
def __init__(self, object_hook=None, object_pairs_hook=None, **kwargs):
self._object_hook = object_hook
self._object_pairs_hook = object_pairs_hook
super().__init__(**kwargs)
def db_value(self, value):
if value is None:
value = self.default_value
return utils.json_dumps(value)
def python_value(self, value):
if not value:
return self.default_value
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
class ListField(JSONField):
default_value = []
class SerializedField(LongTextField):
def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs):
self._serialized_type = serialized_type
self._object_hook = object_hook
self._object_pairs_hook = object_pairs_hook
super().__init__(**kwargs)
def db_value(self, value):
if self._serialized_type == SerializedType.PICKLE:
return utils.serialize_b64(value, to_str=True)
elif self._serialized_type == SerializedType.JSON:
if value is None:
return None
return utils.json_dumps(value, with_type=True)
else:
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
def python_value(self, value):
if self._serialized_type == SerializedType.PICKLE:
return utils.deserialize_b64(value)
elif self._serialized_type == SerializedType.JSON:
if value is None:
return {}
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
else:
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
def is_continuous_field(cls: typing.Type) -> bool:
if cls in CONTINUOUS_FIELD_TYPE:
return True
for p in cls.__bases__:
if p in CONTINUOUS_FIELD_TYPE:
return True
elif p != Field and p != object:
if is_continuous_field(p):
return True
else:
return False
def auto_date_timestamp_field():
return {f"{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
def auto_date_timestamp_db_field():
return {f"f_{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
def remove_field_name_prefix(field_name):
return field_name[2:] if field_name.startswith('f_') else field_name
class BaseModel(Model):
create_time = BigIntegerField(null=True)
create_date = DateTimeField(null=True)
update_time = BigIntegerField(null=True)
update_date = DateTimeField(null=True)
def to_json(self):
# This function is obsolete
return self.to_dict()
def to_dict(self):
return self.__dict__['__data__']
def to_human_model_dict(self, only_primary_with: list = None):
model_dict = self.__dict__['__data__']
if not only_primary_with:
return {remove_field_name_prefix(k): v for k, v in model_dict.items()}
human_model_dict = {}
for k in self._meta.primary_key.field_names:
human_model_dict[remove_field_name_prefix(k)] = model_dict[k]
for k in only_primary_with:
human_model_dict[k] = model_dict[f'f_{k}']
return human_model_dict
@property
def meta(self) -> Metadata:
return self._meta
@classmethod
def get_primary_keys_name(cls):
return cls._meta.primary_key.field_names if isinstance(cls._meta.primary_key, CompositeKey) else [
cls._meta.primary_key.name]
@classmethod
def getter_by(cls, attr):
return operator.attrgetter(attr)(cls)
@classmethod
def query(cls, reverse=None, order_by=None, **kwargs):
filters = []
for f_n, f_v in kwargs.items():
attr_name = '%s' % f_n
if not hasattr(cls, attr_name) or f_v is None:
continue
if type(f_v) in {list, set}:
f_v = list(f_v)
if is_continuous_field(type(getattr(cls, attr_name))):
if len(f_v) == 2:
for i, v in enumerate(f_v):
if isinstance(v, str) and f_n in auto_date_timestamp_field():
# time type: %Y-%m-%d %H:%M:%S
f_v[i] = utils.date_string_to_timestamp(v)
lt_value = f_v[0]
gt_value = f_v[1]
if lt_value is not None and gt_value is not None:
filters.append(cls.getter_by(attr_name).between(lt_value, gt_value))
elif lt_value is not None:
filters.append(operator.attrgetter(attr_name)(cls) >= lt_value)
elif gt_value is not None:
filters.append(operator.attrgetter(attr_name)(cls) <= gt_value)
else:
filters.append(operator.attrgetter(attr_name)(cls) << f_v)
else:
filters.append(operator.attrgetter(attr_name)(cls) == f_v)
if filters:
query_records = cls.select().where(*filters)
if reverse is not None:
if not order_by or not hasattr(cls, f"{order_by}"):
order_by = "create_time"
if reverse is True:
query_records = query_records.order_by(cls.getter_by(f"{order_by}").desc())
elif reverse is False:
query_records = query_records.order_by(cls.getter_by(f"{order_by}").asc())
return [query_record for query_record in query_records]
else:
return []
@classmethod
def insert(cls, __data=None, **insert):
if isinstance(__data, dict) and __data:
__data[cls._meta.combined["create_time"]] = utils.current_timestamp()
if insert:
insert["create_time"] = utils.current_timestamp()
return super().insert(__data, **insert)
# update and insert will call this method
@classmethod
def _normalize_data(cls, data, kwargs):
normalized = super()._normalize_data(data, kwargs)
if not normalized:
return {}
normalized[cls._meta.combined["update_time"]] = utils.current_timestamp()
for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
cls._meta.combined[f"{f_n}_time"] in normalized and \
normalized[cls._meta.combined[f"{f_n}_time"]] is not None:
normalized[cls._meta.combined[f"{f_n}_date"]] = utils.timestamp_to_date(
normalized[cls._meta.combined[f"{f_n}_time"]])
return normalized
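# Illustrative sketch (not part of the original commit): how BaseModel.query
# composes filters. Scalar kwargs become equality tests, a two-element list on
# a continuous field (IntegerField/FloatField/DateTimeField) becomes a BETWEEN
# range, and reverse/order_by control the ordering.
def _base_model_query_example():
    # Documents of one (hypothetical) knowledge base created up to now,
    # newest first; Document is defined further down in this module.
    return Document.query(kb_id="kb_0",
                          create_time=[0, utils.current_timestamp()],
                          reverse=True, order_by="create_time")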
class JsonSerializedField(SerializedField):
def __init__(self, object_hook=utils.from_dict_hook, object_pairs_hook=None, **kwargs):
super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
object_pairs_hook=object_pairs_hook, **kwargs)
@singleton
class BaseDataBase:
def __init__(self):
database_config = DATABASE.copy()
db_name = database_config.pop("name")
self.database_connection = PooledMySQLDatabase(db_name, **database_config)
stat_logger.info('initialized MySQL database in cluster mode successfully')
class DatabaseLock:
def __init__(self, lock_name, timeout=10, db=None):
self.lock_name = lock_name
self.timeout = int(timeout)
self.db = db if db else DB
def lock(self):
# SQL parameters only support %s format placeholders
cursor = self.db.execute_sql("SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
ret = cursor.fetchone()
if ret[0] == 0:
raise Exception(f'timed out acquiring mysql lock {self.lock_name}')
elif ret[0] == 1:
return True
else:
raise Exception(f'failed to acquire lock {self.lock_name}')
def unlock(self):
cursor = self.db.execute_sql("SELECT RELEASE_LOCK(%s)", (self.lock_name,))
ret = cursor.fetchone()
if ret[0] == 0:
raise Exception(f'mysql lock {self.lock_name} was not established by this thread')
elif ret[0] == 1:
return True
else:
raise Exception(f'mysql lock {self.lock_name} does not exist')
def __enter__(self):
if isinstance(self.db, PooledMySQLDatabase):
self.lock()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if isinstance(self.db, PooledMySQLDatabase):
self.unlock()
def __call__(self, func):
@wraps(func)
def magic(*args, **kwargs):
with self:
return func(*args, **kwargs)
return magic
DB = BaseDataBase().database_connection
DB.lock = DatabaseLock
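# Illustrative sketch (not part of the original commit): DatabaseLock can be
# used either as a context manager or as a decorator, serialising work across
# processes through MySQL GET_LOCK / RELEASE_LOCK.
def _database_lock_example():
    with DB.lock("init_data", timeout=10):
        pass  # critical section guarded by the named MySQL lock

    @DB.lock("nightly_cleanup", timeout=5)
    def _cleanup():
        pass  # the decorator wraps every call in the same lock

    return _cleanup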
def close_connection():
try:
if DB:
DB.close()
except Exception as e:
LOGGER.exception(e)
class DataBaseModel(BaseModel):
class Meta:
database = DB
@DB.connection_context()
def init_database_tables():
members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
table_objs = []
create_failed_list = []
for name, obj in members:
if obj != DataBaseModel and issubclass(obj, DataBaseModel):
table_objs.append(obj)
LOGGER.info(f"start create table {obj.__name__}")
try:
obj.create_table()
LOGGER.info(f"create table success: {obj.__name__}")
except Exception as e:
LOGGER.exception(e)
create_failed_list.append(obj.__name__)
if create_failed_list:
LOGGER.info(f"create tables failed: {create_failed_list}")
raise Exception(f"create tables failed: {create_failed_list}")
def fill_db_model_object(model_object, human_model_dict):
for k, v in human_model_dict.items():
attr_name = '%s' % k
if hasattr(model_object.__class__, attr_name):
setattr(model_object, attr_name, v)
return model_object
class User(DataBaseModel, UserMixin):
id = CharField(max_length=32, primary_key=True)
access_token = CharField(max_length=255, null=True)
nickname = CharField(max_length=100, null=False, help_text="nickname")
password = CharField(max_length=255, null=True, help_text="password")
email = CharField(max_length=255, null=False, help_text="email", index=True)
avatar = TextField(null=True, help_text="avatar base64 string")
language = CharField(max_length=32, null=True, help_text="English|Chinese", default="Chinese")
color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Dark")
last_login_time = DateTimeField(null=True)
is_authenticated = CharField(max_length=1, null=False, default="1")
is_active = CharField(max_length=1, null=False, default="1")
is_anonymous = CharField(max_length=1, null=False, default="0")
login_channel = CharField(null=True, help_text="channel from which the user logged in")
status = CharField(max_length=1, null=True, help_text="is it valid (0: wasted, 1: valid)", default="1")
is_superuser = BooleanField(null=True, help_text="is root", default=False)
def __str__(self):
return self.email
def get_id(self):
jwt = Serializer(secret_key=SECRET_KEY)
return jwt.dumps(str(self.access_token))
class Meta:
db_table = "user"
class Tenant(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
name = CharField(max_length=100, null=True, help_text="Tenant name")
public_key = CharField(max_length=255, null=True)
llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID")
img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID")
parser_ids = CharField(max_length=128, null=False, help_text="default parser IDs")
status = CharField(max_length=1, null=True, help_text="is it valid (0: wasted, 1: valid)", default="1")
class Meta:
db_table = "tenant"
class UserTenant(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
user_id = CharField(max_length=32, null=False)
tenant_id = CharField(max_length=32, null=False)
role = CharField(max_length=32, null=False, help_text="UserTenantRole")
invited_by = CharField(max_length=32, null=False)
status = CharField(max_length=1, null=True, help_text="is it valid (0: wasted, 1: valid)", default="1")
class Meta:
db_table = "user_tenant"
class InvitationCode(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
code = CharField(max_length=32, null=False)
visit_time = DateTimeField(null=True)
user_id = CharField(max_length=32, null=True)
tenant_id = CharField(max_length=32, null=True)
status = CharField(max_length=1, null=True, help_text="is it valid (0: wasted, 1: valid)", default="1")
class Meta:
db_table = "invitation_code"
class LLMFactories(DataBaseModel):
name = CharField(max_length=128, null=False, help_text="LLM factory name", primary_key=True)
logo = TextField(null=True, help_text="llm logo base64")
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
status = CharField(max_length=1, null=True, help_text="is it valid (0: wasted, 1: valid)", default="1")
def __str__(self):
return self.name
class Meta:
db_table = "llm_factories"
class LLM(DataBaseModel):
# default LLMs for all users
llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
status = CharField(max_length=1, null=True, help_text="is it valid (0: wasted, 1: valid)", default="1")
def __str__(self):
return self.llm_name
class Meta:
db_table = "llm"
class TenantLLM(DataBaseModel):
tenant_id = CharField(max_length=32, null=False)
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
llm_name = CharField(max_length=128, null=False, help_text="LLM name")
api_key = CharField(max_length=255, null=True, help_text="API KEY")
api_base = CharField(max_length=255, null=True, help_text="API Base")
def __str__(self):
return self.llm_name
class Meta:
db_table = "tenant_llm"
primary_key = CompositeKey('tenant_id', 'llm_factory')
class Knowledgebase(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
avatar = TextField(null=True, help_text="avatar base64 string")
tenant_id = CharField(max_length=32, null=False)
name = CharField(max_length=128, null=False, help_text="KB name", index=True)
description = TextField(null=True, help_text="KB description")
permission = CharField(max_length=16, null=False, help_text="me|team")
created_by = CharField(max_length=32, null=False)
doc_num = IntegerField(default=0)
embd_id = CharField(max_length=32, null=False, help_text="default embedding model ID")
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
status = CharField(max_length=1, null=True, help_text="is it valid (0: wasted, 1: valid)", default="1")
def __str__(self):
return self.name
class Meta:
db_table = "knowledgebase"
class Document(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
kb_id = CharField(max_length=256, null=False, index=True)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
source_type = CharField(max_length=128, null=False, default="local", help_text="where does this document come from")
type = CharField(max_length=32, null=False, help_text="file extension")
created_by = CharField(max_length=32, null=False, help_text="who created it")
name = CharField(max_length=255, null=True, help_text="file name", index=True)
location = CharField(max_length=255, null=True, help_text="where is it stored")
size = IntegerField(default=0)
token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0)
progress = FloatField(default=0)
progress_msg = CharField(max_length=255, null=True, help_text="process message", default="")
process_begin_at = DateTimeField(null=True)
process_duation = FloatField(default=0)
status = CharField(max_length=1, null=True, help_text="is it valid (0: wasted, 1: valid)", default="1")
class Meta:
db_table = "document"
class Dialog(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
tenant_id = CharField(max_length=32, null=False)
name = CharField(max_length=255, null=True, help_text="dialog application name")
description = TextField(null=True, help_text="Dialog description")
icon = CharField(max_length=16, null=False, help_text="dialog icon")
language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese")
llm_id = CharField(max_length=32, null=False, help_text="default llm ID")
llm_setting_type = CharField(max_length=8, null=False, help_text="Creative|Precise|Evenly|Custom",
default="Creative")
llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
"presence_penalty": 0.4, "max_tokens": 215})
prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced")
prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好我是您的助手小樱长得可爱又善良can I help you?",
"parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
status = CharField(max_length=1, null=True, help_text="is it valid (0: wasted, 1: valid)", default="1")
class Meta:
db_table = "dialog"
class DialogKb(DataBaseModel):
dialog_id = CharField(max_length=32, null=False, index=True)
kb_id = CharField(max_length=32, null=False)
class Meta:
db_table = "dialog_kb"
primary_key = CompositeKey('dialog_id', 'kb_id')
class Conversation(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
dialog_id = CharField(max_length=32, null=False, index=True)
name = CharField(max_length=255, null=True, help_text="converastion name")
message = JSONField(null=True)
class Meta:
db_table = "conversation"
"""
class Job(DataBaseModel):
# multi-party common configuration
f_user_id = CharField(max_length=25, null=True)
f_job_id = CharField(max_length=25, index=True)
f_name = CharField(max_length=500, null=True, default='')
f_description = TextField(null=True, default='')
f_tag = CharField(max_length=50, null=True, default='')
f_dsl = JSONField()
f_runtime_conf = JSONField()
f_runtime_conf_on_party = JSONField()
f_train_runtime_conf = JSONField(null=True)
f_roles = JSONField()
f_initiator_role = CharField(max_length=50)
f_initiator_party_id = CharField(max_length=50)
f_status = CharField(max_length=50)
f_status_code = IntegerField(null=True)
f_user = JSONField()
# this party configuration
f_role = CharField(max_length=50, index=True)
f_party_id = CharField(max_length=10, index=True)
f_is_initiator = BooleanField(null=True, default=False)
f_progress = IntegerField(null=True, default=0)
f_ready_signal = BooleanField(default=False)
f_ready_time = BigIntegerField(null=True)
f_cancel_signal = BooleanField(default=False)
f_cancel_time = BigIntegerField(null=True)
f_rerun_signal = BooleanField(default=False)
f_end_scheduling_updates = IntegerField(null=True, default=0)
f_engine_name = CharField(max_length=50, null=True)
f_engine_type = CharField(max_length=10, null=True)
f_cores = IntegerField(default=0)
f_memory = IntegerField(default=0) # MB
f_remaining_cores = IntegerField(default=0)
f_remaining_memory = IntegerField(default=0) # MB
f_resource_in_use = BooleanField(default=False)
f_apply_resource_time = BigIntegerField(null=True)
f_return_resource_time = BigIntegerField(null=True)
f_inheritance_info = JSONField(null=True)
f_inheritance_status = CharField(max_length=50, null=True)
f_start_time = BigIntegerField(null=True)
f_start_date = DateTimeField(null=True)
f_end_time = BigIntegerField(null=True)
f_end_date = DateTimeField(null=True)
f_elapsed = BigIntegerField(null=True)
class Meta:
db_table = "t_job"
primary_key = CompositeKey('f_job_id', 'f_role', 'f_party_id')
class PipelineComponentMeta(DataBaseModel):
f_model_id = CharField(max_length=100, index=True)
f_model_version = CharField(max_length=100, index=True)
f_role = CharField(max_length=50, index=True)
f_party_id = CharField(max_length=10, index=True)
f_component_name = CharField(max_length=100, index=True)
f_component_module_name = CharField(max_length=100)
f_model_alias = CharField(max_length=100, index=True)
f_model_proto_index = JSONField(null=True)
f_run_parameters = JSONField(null=True)
f_archive_sha256 = CharField(max_length=100, null=True)
f_archive_from_ip = CharField(max_length=100, null=True)
class Meta:
db_table = 't_pipeline_component_meta'
indexes = (
(('f_model_id', 'f_model_version', 'f_role', 'f_party_id', 'f_component_name'), True),
)
"""

View File

@ -0,0 +1,157 @@
#
# Copyright 2021 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
import json
import time
from functools import wraps
from shortuuid import ShortUUID
from web_server.versions import get_fate_version
from web_server.errors.error_services import *
from web_server.settings import (
GRPC_PORT, HOST, HTTP_PORT,
RANDOM_INSTANCE_ID, stat_logger,
)
instance_id = ShortUUID().random(length=8) if RANDOM_INSTANCE_ID else f'flow-{HOST}-{HTTP_PORT}'
server_instance = (
f'{HOST}:{GRPC_PORT}',
json.dumps({
'instance_id': instance_id,
'timestamp': round(time.time() * 1000),
'version': get_fate_version() or '',
'host': HOST,
'grpc_port': GRPC_PORT,
'http_port': HTTP_PORT,
}),
)
def check_service_supported(method):
"""Decorator to check if `service_name` is supported.
The attribute `supported_services` MUST be defined in class.
The first and second arguments of `method` MUST be `self` and `service_name`.
:param Callable method: The class method.
:return: The inner wrapper function.
:rtype: Callable
"""
@wraps(method)
def magic(self, service_name, *args, **kwargs):
if service_name not in self.supported_services:
raise ServiceNotSupported(service_name=service_name)
return method(self, service_name, *args, **kwargs)
return magic
class ServicesDB(abc.ABC):
"""Database for storage service urls.
Abstract base class for the real backends.
"""
@property
@abc.abstractmethod
def supported_services(self):
"""The names of supported services.
The returned list SHOULD contain `fateflow` (model download) and `servings` (FATE-Serving).
:return: The service names.
:rtype: list
"""
pass
@abc.abstractmethod
def _get_serving(self):
pass
def get_serving(self):
try:
return self._get_serving()
except ServicesError as e:
stat_logger.exception(e)
return []
@abc.abstractmethod
def _insert(self, service_name, service_url, value=''):
pass
@check_service_supported
def insert(self, service_name, service_url, value=''):
"""Insert a service url to database.
:param str service_name: The service name.
:param str service_url: The service url.
:return: None
"""
try:
self._insert(service_name, service_url, value)
except ServicesError as e:
stat_logger.exception(e)
@abc.abstractmethod
def _delete(self, service_name, service_url):
pass
@check_service_supported
def delete(self, service_name, service_url):
"""Delete a service url from database.
:param str service_name: The service name.
:param str service_url: The service url.
:return: None
"""
try:
self._delete(service_name, service_url)
except ServicesError as e:
stat_logger.exception(e)
def register_flow(self):
"""Call `self.insert` for insert the flow server address to databae.
:return: None
"""
self.insert('flow-server', *server_instance)
def unregister_flow(self):
"""Call `self.delete` for delete the flow server address from databae.
:return: None
"""
self.delete('flow-server', server_instance[0])
@abc.abstractmethod
def _get_urls(self, service_name, with_values=False):
pass
@check_service_supported
def get_urls(self, service_name, with_values=False):
"""Query service urls from database. The urls may belong to other nodes.
Currently, only `fateflow` (model download) urls and `servings` (FATE-Serving) urls are supported.
`fateflow` is a url containing scheme, host, port and path,
while `servings` only contains host and port.
:param str service_name: The service name.
:return: The service urls.
:rtype: list
"""
try:
return self._get_urls(service_name, with_values)
except ServicesError as e:
stat_logger.exception(e)
return []
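# Illustrative sketch (not part of the original commit): the smallest possible
# concrete backend for ServicesDB. A real implementation would talk to
# ZooKeeper, etcd or a database; this one keeps everything in a dict.
class _InMemoryServicesDB(ServicesDB):
    def __init__(self):
        self._urls = {}

    @property
    def supported_services(self):
        return ['fateflow', 'servings', 'flow-server']

    def _get_serving(self):
        return self._urls.get('servings', [])

    def _insert(self, service_name, service_url, value=''):
        self._urls.setdefault(service_name, []).append((service_url, value))

    def _delete(self, service_name, service_url):
        self._urls[service_name] = [
            (url, v) for url, v in self._urls.get(service_name, [])
            if url != service_url
        ]

    def _get_urls(self, service_name, with_values=False):
        entries = self._urls.get(service_name, [])
        return entries if with_values else [url for url, _ in entries]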

131
web_server/db/db_utils.py Normal file
View File

@ -0,0 +1,131 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import operator
from functools import reduce
from typing import Dict, Type, Union
from web_server.utils import current_timestamp, timestamp_to_date
from web_server.db.db_models import DB, DataBaseModel
from web_server.db.runtime_config import RuntimeConfig
from web_server.utils.log_utils import getLogger
from enum import Enum
LOGGER = getLogger()
@DB.connection_context()
def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
DB.create_tables([model])
current_time = current_timestamp()
current_date = timestamp_to_date(current_time)
for data in data_source:
if 'f_create_time' not in data:
data['f_create_time'] = current_time
data['f_create_date'] = timestamp_to_date(data['f_create_time'])
data['f_update_time'] = current_time
data['f_update_date'] = current_date
preserve = tuple(data_source[0].keys() - {'f_create_time', 'f_create_date'})
batch_size = 50 if RuntimeConfig.USE_LOCAL_DATABASE else 1000
for i in range(0, len(data_source), batch_size):
with DB.atomic():
query = model.insert_many(data_source[i:i + batch_size])
if replace_on_conflict:
query = query.on_conflict(preserve=preserve)
query.execute()
def get_dynamic_db_model(base, job_id):
return type(base.model(table_index=get_dynamic_tracking_table_index(job_id=job_id)))
def get_dynamic_tracking_table_index(job_id):
return job_id[:8]
def fill_db_model_object(model_object, human_model_dict):
for k, v in human_model_dict.items():
attr_name = 'f_%s' % k
if hasattr(model_object.__class__, attr_name):
setattr(model_object, attr_name, v)
return model_object
# https://docs.peewee-orm.com/en/latest/peewee/query_operators.html
supported_operators = {
'==': operator.eq,
'<': operator.lt,
'<=': operator.le,
'>': operator.gt,
'>=': operator.ge,
'!=': operator.ne,
'<<': operator.lshift,
'>>': operator.rshift,
'%': operator.mod,
'**': operator.pow,
'^': operator.xor,
'~': operator.inv,
}
def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
expression = []
for field, value in query.items():
if not isinstance(value, (list, tuple)):
value = ('==', value)
op, *val = value
field = getattr(model, f'f_{field}')
value = supported_operators[op](field, val[0]) if op in supported_operators else getattr(field, op)(*val)
expression.append(value)
return reduce(operator.iand, expression)
def query_db(model: Type[DataBaseModel], limit: int = 0, offset: int = 0,
query: dict = None, order_by: Union[str, list, tuple] = None):
data = model.select()
if query:
data = data.where(query_dict2expression(model, query))
count = data.count()
if not order_by:
order_by = 'create_time'
if not isinstance(order_by, (list, tuple)):
order_by = (order_by, 'asc')
order_by, order = order_by
order_by = getattr(model, f'f_{order_by}')
order_by = getattr(order_by, order)()
data = data.order_by(order_by)
if limit > 0:
data = data.limit(limit)
if offset > 0:
data = data.offset(offset)
return list(data), count
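# Illustrative sketch (not part of the original commit): query_db prepends
# "f_" to every key in `query` and to order_by, so it targets FATE-style
# models whose columns are named f_status, f_create_time, and so on.
# Operators other than equality are written as ("op", value) tuples, resolved
# via supported_operators or any peewee column method name.
def _query_db_example(model: Type[DataBaseModel]):
    records, total = query_db(
        model,
        limit=20,
        offset=0,
        query={"status": "1", "create_time": (">", 0)},  # f_status == "1" AND f_create_time > 0
        order_by=("create_time", "desc"),                # ORDER BY f_create_time DESC
    )
    return records, total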
class StatusEnum(Enum):
# sample availability status
VALID = "1"
IN_VALID = "0"

141
web_server/db/init_data.py Normal file
View File

@ -0,0 +1,141 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import uuid
from web_server.db import LLMType
from web_server.db.db_models import init_database_tables as init_web_db
from web_server.db.services import UserService
from web_server.db.services.llm_service import LLMFactoriesService, LLMService
def init_superuser():
user_info = {
"id": uuid.uuid1().hex,
"password": "admin",
"nickname": "admin",
"is_superuser": True,
"email": "kai.hu@infiniflow.org",
"creator": "system",
"status": "1",
}
UserService.save(**user_info)
def init_llm_factory():
factory_infos = [{
"name": "OpenAI",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "通义千问",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "智普AI",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "文心一言",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},
]
llm_infos = [{
"fid": factory_infos[0]["name"],
"llm_name": "gpt-3.5-turbo",
"tags": "LLM,CHAT,4K",
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[0]["name"],
"llm_name": "gpt-3.5-turbo-16k-0613",
"tags": "LLM,CHAT,16k",
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[0]["name"],
"llm_name": "text-embedding-ada-002",
"tags": "TEXT EMBEDDING,8K",
"model_type": LLMType.EMBEDDING.value
},{
"fid": factory_infos[0]["name"],
"llm_name": "whisper-1",
"tags": "SPEECH2TEXT",
"model_type": LLMType.SPEECH2TEXT.value
},{
"fid": factory_infos[0]["name"],
"llm_name": "gpt-4",
"tags": "LLM,CHAT,8K",
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[0]["name"],
"llm_name": "gpt-4-32k",
"tags": "LLM,CHAT,32K",
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[0]["name"],
"llm_name": "gpt-4-vision-preview",
"tags": "LLM,CHAT,IMAGE2TEXT",
"model_type": LLMType.IMAGE2TEXT.value
},{
"fid": factory_infos[1]["name"],
"llm_name": "qwen-turbo",
"tags": "LLM,CHAT,8K",
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[1]["name"],
"llm_name": "qwen-plus",
"tags": "LLM,CHAT,32K",
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[1]["name"],
"llm_name": "text-embedding-v2",
"tags": "TEXT EMBEDDING,2K",
"model_type": LLMType.EMBEDDING.value
},{
"fid": factory_infos[1]["name"],
"llm_name": "paraformer-realtime-8k-v1",
"tags": "SPEECH2TEXT",
"model_type": LLMType.SPEECH2TEXT.value
},{
"fid": factory_infos[1]["name"],
"llm_name": "qwen_vl_chat_v1",
"tags": "LLM,CHAT,IMAGE2TEXT",
"model_type": LLMType.IMAGE2TEXT.value
},
]
for info in factory_infos:
LLMFactoriesService.save(**info)
for info in llm_infos:
LLMService.save(**info)
def init_web_data():
start_time = time.time()
if not UserService.get_all().count():
init_superuser()
if not LLMService.get_all().count():
    init_llm_factory()
print("init web data success:{}".format(time.time() - start_time))
if __name__ == '__main__':
init_web_db()
init_web_data()

View File

@ -0,0 +1,21 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import operator
import time
import typing
from web_server.utils.log_utils import sql_logger
import peewee

View File

@ -0,0 +1,27 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class ReloadConfigBase:
@classmethod
def get_all(cls):
configs = {}
for k, v in cls.__dict__.items():
if not callable(getattr(cls, k)) and not k.startswith("__") and not k.startswith("_"):
configs[k] = v
return configs
@classmethod
def get(cls, config_name):
return getattr(cls, config_name) if hasattr(cls, config_name) else None

View File

@ -0,0 +1,54 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from web_server.versions import get_versions
from .reload_config_base import ReloadConfigBase
class RuntimeConfig(ReloadConfigBase):
DEBUG = None
WORK_MODE = None
HTTP_PORT = None
JOB_SERVER_HOST = None
JOB_SERVER_VIP = None
ENV = dict()
SERVICE_DB = None
LOAD_CONFIG_MANAGER = False
@classmethod
def init_config(cls, **kwargs):
for k, v in kwargs.items():
if hasattr(cls, k):
setattr(cls, k, v)
@classmethod
def init_env(cls):
cls.ENV.update(get_versions())
@classmethod
def load_config_manager(cls):
cls.LOAD_CONFIG_MANAGER = True
@classmethod
def get_env(cls, key):
return cls.ENV.get(key, None)
@classmethod
def get_all_env(cls):
return cls.ENV
@classmethod
def set_service_db(cls, service_db):
cls.SERVICE_DB = service_db
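# Illustrative sketch (not part of the original commit): RuntimeConfig is a
# process-wide bag of settings; the server entry point seeds it once at
# startup and other modules read it back through the class attributes. The
# port value below is a placeholder.
def _runtime_config_example():
    RuntimeConfig.init_config(WORK_MODE=0, HTTP_PORT=9380, DEBUG=False)
    RuntimeConfig.init_env()          # records component versions in ENV
    return RuntimeConfig.get_all()    # inherited from ReloadConfigBase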

View File

@ -0,0 +1,164 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import socket
from pathlib import Path
from web_server import utils
from .db_models import DB, ServiceRegistryInfo, ServerRegistryInfo
from .reload_config_base import ReloadConfigBase
class ServiceRegistry(ReloadConfigBase):
@classmethod
@DB.connection_context()
def load_service(cls, **kwargs) -> [ServiceRegistryInfo]:
service_registry_list = ServiceRegistryInfo.query(**kwargs)
return [service for service in service_registry_list]
@classmethod
@DB.connection_context()
def save_service_info(cls, server_name, service_name, uri, method="POST", server_info=None, params=None, data=None, headers=None, protocol="http"):
if not server_info:
server_list = ServerRegistry.query_server_info_from_db(server_name=server_name)
if not server_list:
raise Exception(f"no found server {server_name}")
server_info = server_list[0]
url = f"{server_info.f_protocol}://{server_info.f_host}:{server_info.f_port}{uri}"
else:
url = f"{server_info.get('protocol', protocol)}://{server_info.get('host')}:{server_info.get('port')}{uri}"
service_info = {
"f_server_name": server_name,
"f_service_name": service_name,
"f_url": url,
"f_method": method,
"f_params": params if params else {},
"f_data": data if data else {},
"f_headers": headers if headers else {}
}
entity_model, status = ServiceRegistryInfo.get_or_create(
f_server_name=server_name,
f_service_name=service_name,
defaults=service_info)
if status is False:
for key in service_info:
setattr(entity_model, key, service_info[key])
entity_model.save(force_insert=False)
class ServerRegistry(ReloadConfigBase):
FATEBOARD = None
FATE_ON_STANDALONE = None
FATE_ON_EGGROLL = None
FATE_ON_SPARK = None
MODEL_STORE_ADDRESS = None
SERVINGS = None
FATEMANAGER = None
STUDIO = None
@classmethod
def load(cls):
cls.load_server_info_from_conf()
cls.load_server_info_from_db()
@classmethod
def load_server_info_from_conf(cls):
path = Path(utils.file_utils.get_project_base_directory()) / 'conf' / utils.SERVICE_CONF
conf = utils.file_utils.load_yaml_conf(path)
if not isinstance(conf, dict):
raise ValueError('invalid config file')
local_path = path.with_name(f'local.{utils.SERVICE_CONF}')
if local_path.exists():
local_conf = utils.file_utils.load_yaml_conf(local_path)
if not isinstance(local_conf, dict):
raise ValueError('invalid local config file')
conf.update(local_conf)
for k, v in conf.items():
if isinstance(v, dict):
setattr(cls, k.upper(), v)
@classmethod
def register(cls, server_name, server_info):
cls.save_server_info_to_db(server_name, server_info.get("host"), server_info.get("port"), protocol=server_info.get("protocol", "http"))
setattr(cls, server_name, server_info)
@classmethod
def save(cls, service_config):
update_server = {}
for server_name, server_info in service_config.items():
cls.parameter_check(server_info)
api_info = server_info.pop("api", {})
for service_name, info in api_info.items():
ServiceRegistry.save_service_info(server_name, service_name, uri=info.get('uri'), method=info.get('method', 'POST'), server_info=server_info)
cls.save_server_info_to_db(server_name, server_info.get("host"), server_info.get("port"), protocol="http")
setattr(cls, server_name.upper(), server_info)
return update_server
@classmethod
def parameter_check(cls, service_info):
if "host" in service_info and "port" in service_info:
cls.connection_test(service_info.get("host"), service_info.get("port"))
@classmethod
def connection_test(cls, ip, port):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
result = s.connect_ex((ip, port))
if result != 0:
raise ConnectionRefusedError(f"connection refused: host {ip}, port {port}")
@classmethod
def query(cls, service_name, default=None):
service_info = getattr(cls, service_name, default)
if not service_info:
service_info = utils.get_base_config(service_name, default)
return service_info
@classmethod
@DB.connection_context()
def query_server_info_from_db(cls, server_name=None) -> [ServerRegistryInfo]:
if server_name:
server_list = ServerRegistryInfo.select().where(ServerRegistryInfo.f_server_name==server_name.upper())
else:
server_list = ServerRegistryInfo.select()
return [server for server in server_list]
@classmethod
@DB.connection_context()
def load_server_info_from_db(cls):
for server in cls.query_server_info_from_db():
server_info = {
"host": server.f_host,
"port": server.f_port,
"protocol": server.f_protocol
}
setattr(cls, server.f_server_name.upper(), server_info)
@classmethod
@DB.connection_context()
def save_server_info_to_db(cls, server_name, host, port, protocol="http"):
server_info = {
"f_server_name": server_name,
"f_host": host,
"f_port": port,
"f_protocol": protocol
}
entity_model, status = ServerRegistryInfo.get_or_create(
f_server_name=server_name,
defaults=server_info)
if status is False:
for key in server_info:
setattr(entity_model, key, server_info[key])
entity_model.save(force_insert=False)
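# Illustrative sketch (not part of the original commit): ServerRegistry.save
# expects a mapping of server name -> connection info, with an optional nested
# "api" section whose entries are registered through ServiceRegistry. The
# host, port and URI below are placeholders.
def _server_registry_example():
    ServerRegistry.save({
        "servings": {
            "host": "127.0.0.1",
            "port": 8000,
            "api": {"health": {"uri": "/health", "method": "GET"}},
        },
    })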

View File

@ -0,0 +1,38 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pathlib
import re
from .user_service import UserService
def duplicate_name(query_func, **kwargs):
fnm = kwargs["name"]
objs = query_func(**kwargs)
if not objs: return fnm
ext = pathlib.Path(fnm).suffix #.jpg
nm = re.sub(r"%s$"%ext, "", fnm)
r = re.search(r"\([0-9]+\)$", nm)
c = 0
if r:
c = int(r.group(1))
nm = re.sub(r"\([0-9]+\)$", "", nm)
c += 1
nm = f"{nm}({c})"
if ext: nm += f"{ext}"
kwargs["name"] = nm
return duplicate_name(query_func, **kwargs)
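# Illustrative sketch (not part of the original commit): duplicate_name keeps
# appending "(n)" before the extension until query_func reports no clash, so a
# second upload of "report.pdf" becomes "report(1).pdf", then "report(2).pdf".
def _duplicate_name_example():
    taken = {"report.pdf", "report(1).pdf"}
    return duplicate_name(lambda **kw: kw["name"] in taken, name="report.pdf")  # -> "report(2).pdf"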

View File

@ -0,0 +1,153 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
import peewee
from web_server.db.db_models import DB
from web_server.utils import datetime_format
class CommonService:
model = None
@classmethod
@DB.connection_context()
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
@classmethod
@DB.connection_context()
def get_all(cls, cols=None, reverse=None, order_by=None):
if cols:
query_records = cls.model.select(*cols)
else:
query_records = cls.model.select()
if reverse is not None:
if not order_by or not hasattr(cls.model, order_by):
order_by = "create_time"
if reverse is True:
query_records = query_records.order_by(cls.model.getter_by(order_by).desc())
elif reverse is False:
query_records = query_records.order_by(cls.model.getter_by(order_by).asc())
return query_records
@classmethod
@DB.connection_context()
def get(cls, **kwargs):
return cls.model.get(**kwargs)
@classmethod
@DB.connection_context()
def get_or_none(cls, **kwargs):
try:
return cls.model.get(**kwargs)
except peewee.DoesNotExist:
return None
@classmethod
@DB.connection_context()
def save(cls, **kwargs):
#if "id" not in kwargs:
# kwargs["id"] = get_uuid()
sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj
@classmethod
@DB.connection_context()
def insert_many(cls, data_list, batch_size=100):
with DB.atomic():
for i in range(0, len(data_list), batch_size):
cls.model.insert_many(data_list[i:i + batch_size]).execute()
@classmethod
@DB.connection_context()
def update_many_by_id(cls, data_list):
cur = datetime_format(datetime.now())
with DB.atomic():
for data in data_list:
data["update_time"] = cur
cls.model.update(data).where(cls.model.id == data["id"]).execute()
@classmethod
@DB.connection_context()
def update_by_id(cls, pid, data):
data["update_time"] = datetime_format(datetime.now())
num = cls.model.update(data).where(cls.model.id == pid).execute()
return num
@classmethod
@DB.connection_context()
def get_by_id(cls, pid):
try:
obj = cls.model.query(id=pid)[0]
return True, obj
except Exception as e:
return False, None
@classmethod
@DB.connection_context()
def get_by_ids(cls, pids, cols=None):
if cols:
objs = cls.model.select(*cols)
else:
objs = cls.model.select()
return objs.where(cls.model.id.in_(pids))
@classmethod
@DB.connection_context()
def delete_by_id(cls, pid):
return cls.model.delete().where(cls.model.id == pid).execute()
@classmethod
@DB.connection_context()
def filter_delete(cls, filters):
with DB.atomic():
num = cls.model.delete().where(*filters).execute()
return num
@classmethod
@DB.connection_context()
def filter_update(cls, filters, update_data):
with DB.atomic():
cls.model.update(update_data).where(*filters).execute()
@staticmethod
def cut_list(tar_list, n):
length = len(tar_list)
arr = range(length)
result = [tuple(tar_list[x:(x + n)]) for x in arr[::n]]
return result
@classmethod
@DB.connection_context()
def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None):
in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
if not filters:
filters = []
res_list = []
if cols:
for i in in_filters_tuple_list:
query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters)
if query_records:
res_list.extend([query_record for query_record in query_records])
else:
for i in in_filters_tuple_list:
query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters)
if query_records:
res_list.extend([query_record for query_record in query_records])
return res_list
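# Illustrative sketch (not part of the original commit): concrete services are
# thin subclasses that only bind `model`; everything else comes from
# CommonService. The service below is hypothetical here, but mirrors how
# UserService and LLMService are wired in web_server/db/services.
from web_server.db.db_models import Knowledgebase

class KnowledgebaseService(CommonService):
    model = Knowledgebase

def _service_usage_example(tenant_id):
    kbs = KnowledgebaseService.query(tenant_id=tenant_id, status="1")
    return [kb.to_dict() for kb in kbs]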

Some files were not shown because too many files have changed in this diff.