use minio to store uploaded files; build dialog server; (#16)

* format code

* use minio to store uploaded files; build dialog server;
KevinHuSh 2023-12-25 19:05:59 +08:00 committed by GitHub
parent d4fd138954
commit 3245107dc7
13 changed files with 520 additions and 134 deletions


@ -1 +1,2 @@
from .embedding_model import HuEmbedding
from .chat_model import GptTurbo

python/llm/chat_model.py Normal file

@ -0,0 +1,34 @@
from abc import ABC
import openai
import os
class Base(ABC):
def chat(self, system, history, gen_conf):
raise NotImplementedError("Please implement chat method!")
class GptTurbo(Base):
def __init__(self):
openai.api_key = os.environ["OPENAI_KEY"]
def chat(self, system, history, gen_conf):
history.insert(0, {"role": "system", "content": system})
res = openai.ChatCompletion.create(model="gpt-3.5-turbo",
messages=history,
**gen_conf)
return res.choices[0].message.content.strip()
class QWen(Base):
def chat(self, system, history, gen_conf):
from http import HTTPStatus
from dashscope import Generation
history.insert(0, {"role": "system", "content": system})
response = Generation.call(
Generation.Models.qwen_turbo,
messages=history,
result_format='message'
)
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content']
return response.message
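
Both models are driven through the same Base.chat(system, history, gen_conf) interface; a minimal usage sketch, assuming OPENAI_KEY is set in the environment (dialog_svr.py below passes gen_conf the same way):

from llm import GptTurbo

mdl = GptTurbo()
history = [{"role": "user", "content": "Summarize what the knowledge base says about MinIO."}]
# the system prompt is inserted first, the rest of the history follows; gen_conf is forwarded to the API
answer = mdl.chat("You are a helpful assistant.", history, {"temperature": 0.2, "max_tokens": 512})
print(answer)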


@ -1,6 +1,7 @@
from abc import ABC
from FlagEmbedding import FlagModel
import torch
import numpy as np
class Base(ABC):
def encode(self, texts: list, batch_size=32):
@ -27,5 +28,5 @@ class HuEmbedding(Base):
def encode(self, texts: list, batch_size=32):
res = []
for i in range(0, len(texts), batch_size):
res.extend(self.encode(texts[i:i+batch_size]))
return res
res.extend(self.model.encode(texts[i:i+batch_size]).tolist())
return np.array(res)
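
With the fix above, encode() batches through the underlying FlagEmbedding model and returns one numpy row per input text; a quick sketch of the expected call:

from llm import HuEmbedding

emb = HuEmbedding()
vecs = emb.encode(["MinIO stores the uploaded files.", "The dialog server answers questions."])
print(vecs.shape)  # (2, embedding_dim)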


@ -372,7 +372,7 @@ class PptChunker(HuChunker):
def __call__(self, fnm):
from pptx import Presentation
ppt = Presentation(fnm)
ppt = Presentation(fnm) if isinstance(fnm, str) else Presentation(BytesIO(fnm))
flds = self.Fields()
flds.text_chunks = []
for slide in ppt.slides:
@ -396,7 +396,9 @@ class TextChunker(HuChunker):
@staticmethod
def is_binary_file(file_path):
mime = magic.Magic(mime=True)
file_type = mime.from_file(file_path)
if isinstance(file_path, str):
file_type = mime.from_file(file_path)
else: file_type = mime.from_buffer(file_path)
if 'text' in file_type:
return False
else:
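
These chunkers now accept either a file path or raw bytes (e.g. the object fetched back from MinIO); a minimal sketch, where the module path for PptChunker and the sample file name are assumptions:

from nlp.huchunker import PptChunker  # module path is an assumption

PPT = PptChunker()
flds_from_path = PPT("slides.pptx")          # local file path
with open("slides.pptx", "rb") as f:
    flds_from_bytes = PPT(f.read())          # raw bytes, as returned by MINIO.get()
print(len(flds_from_bytes.text_chunks))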

python/nlp/search.py Normal file

@ -0,0 +1,221 @@
import re
from elasticsearch_dsl import Q,Search,A
from typing import List, Optional, Tuple,Dict, Union
from dataclasses import dataclass
from util import setup_logging, rmSpace
from nlp import huqie, query
from datetime import datetime
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
import numpy as np
from copy import deepcopy
class Dealer:
def __init__(self, es, emb_mdl):
self.qryr = query.EsQueryer(es)
self.qryr.flds = ["title_tks^10", "title_sm_tks^5", "content_ltks^2", "content_sm_ltks"]
self.es = es
self.emb_mdl = emb_mdl
@dataclass
class SearchResult:
total:int
ids: List[str]
query_vector: List[float] = None
field: Optional[Dict] = None
highlight: Optional[Dict] = None
aggregation: Union[List, Dict, None] = None
keywords: Optional[List[str]] = None
group_docs: List[List] = None
def _vector(self, txt, sim=0.8, topk=10):
return {
"field": "q_vec",
"k": topk,
"similarity": sim,
"num_candidates": 1000,
"query_vector": self.emb_mdl.encode_queries(txt)
}
def search(self, req, idxnm, tks_num=3):
keywords = []
qst = req.get("question", "")
bqry,keywords = self.qryr.question(qst)
if req.get("kb_ids"): bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
bqry.filter.append(Q("exists", field="q_tks"))
bqry.boost = 0.05
print(bqry)
s = Search()
pg = int(req.get("page", 1))-1
ps = int(req.get("size", 1000))
src = req.get("field", ["docnm_kwd", "content_ltks", "kb_id",
"image_id", "doc_id", "q_vec"])
s = s.query(bqry)[pg*ps:(pg+1)*ps]
s = s.highlight("content_ltks")
s = s.highlight("title_ltks")
if not qst: s = s.sort({"create_time":{"order":"desc", "unmapped_type":"date"}})
s = s.highlight_options(
fragment_size = 120,
number_of_fragments=5,
boundary_scanner_locale="zh-CN",
boundary_scanner="SENTENCE",
boundary_chars=",./;:\\!(),。?:!……()——、"
)
s = s.to_dict()
q_vec = []
if req.get("vector"):
s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps)
s["knn"]["filter"] = bqry.to_dict()
del s["highlight"]
q_vec = s["knn"]["query_vector"]
res = self.es.search(s, idxnm=idxnm, timeout="600s",src=src)
print("TOTAL: ", self.es.getTotal(res))
if self.es.getTotal(res) == 0 and "knn" in s:
bqry,_ = self.qryr.question(qst, min_match="10%")
if req.get("kb_ids"): bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
s["query"] = bqry.to_dict()
s["knn"]["filter"] = bqry.to_dict()
s["knn"]["similarity"] = 0.7
res = self.es.search(s, idxnm=idxnm, timeout="600s",src=src)
kwds = set([])
for k in keywords:
kwds.add(k)
for kk in huqie.qieqie(k).split(" "):
if len(kk) < 2:continue
if kk in kwds:continue
kwds.add(kk)
aggs = self.getAggregation(res, "docnm_kwd")
return self.SearchResult(
total = self.es.getTotal(res),
ids = self.es.getDocIds(res),
query_vector = q_vec,
aggregation = aggs,
highlight = self.getHighlight(res),
field = self.getFields(res, ["docnm_kwd", "content_ltks",
"kb_id","image_id", "doc_id", "q_vec"]),
keywords = list(kwds)
)
def getAggregation(self, res, g):
if not "aggregations" in res or "aggs_"+g not in res["aggregations"]:return
bkts = res["aggregations"]["aggs_"+g]["buckets"]
return [(b["key"], b["doc_count"]) for b in bkts]
def getHighlight(self, res):
def rmspace(line):
eng = set(list("qwertyuioplkjhgfdsazxcvbnm"))
r = []
for t in line.split(" "):
if not t:continue
if len(r)>0 and len(t)>0 and r[-1][-1] in eng and t[0] in eng:r.append(" ")
r.append(t)
r = "".join(r)
return r
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
if not hlts:continue
ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]])
return ans
def getFields(self, sres, flds):
res = {}
if not flds:return {}
for d in self.es.getSource(sres):
m = {n:d.get(n) for n in flds if d.get(n) is not None}
for n,v in m.items():
if type(v) == type([]):
m[n] = "\t".join([str(vv) for vv in v])
continue
if type(v) != type(""):m[n] = str(m[n])
m[n] = rmSpace(m[n])
if m:res[d["id"]] = m
return res
@staticmethod
def trans2floats(txt):
return [float(t) for t in txt.split("\t")]
def insert_citations(self, ans, top_idx, sres, vfield = "q_vec", cfield="content_ltks"):
ins_embd = [Dealer.trans2floats(sres.field[sres.ids[i]][vfield]) for i in top_idx]
ins_tw =[sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
s = 0
e = 0
res = ""
def citeit():
nonlocal s, e, ans, res
if not ins_embd:return
embd = self.emb_mdl.encode(ans[s: e])
sim = self.qryr.hybrid_similarity(embd,
ins_embd,
huqie.qie(ans[s:e]).split(" "),
ins_tw)
print(ans[s: e], sim)
mx = np.max(sim)*0.99
if mx < 0.55:return
cita = list(set([top_idx[i] for i in range(len(ins_embd)) if sim[i] >mx]))[:4]
for i in cita: res += f"@?{i}?@"
return cita
punct = set(";。?!")
if not self.qryr.isChinese(ans):
punct.add("?")
punct.add(".")
while e < len(ans):
if e - s < 12 or ans[e] not in punct:
e += 1
continue
if ans[e] == "." and e+1<len(ans) and re.match(r"[0-9]", ans[e+1]):
e += 1
continue
if ans[e] == "." and e-2>=0 and ans[e-2] == "\n":
e += 1
continue
res += ans[s: e]
citeit()
res += ans[e]
e += 1
s = e
if s< len(ans):
res += ans[s:]
citeit()
return res
def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, vfield="q_vec", cfield="content_ltks"):
ins_embd = [Dealer.trans2floats(sres.field[i]["q_vec"]) for i in sres.ids]
if not ins_embd: return []
ins_tw =[sres.field[i][cfield].split(" ") for i in sres.ids]
#return CosineSimilarity([sres.query_vector], ins_embd)[0]
sim = self.qryr.hybrid_similarity(sres.query_vector,
ins_embd,
huqie.qie(query).split(" "),
ins_tw, tkweight, vtweight)
return sim
if __name__ == "__main__":
from util import es_conn
from llm import HuEmbedding
SE = Dealer(es_conn.HuEs("infiniflow"), HuEmbedding())
qs = [
"胡凯",
""
]
for q in qs:
print(">>>>>>>>>>>>>>>>>>>>", q)
print(SE.search({"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*"))
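
dialog_svr.py below wires this retriever up end to end; a condensed sketch of that flow ("vector": True is added here so rerank() has a query vector to work with, and the kb id is the same placeholder used in the test block above):

import numpy as np
from util import es_conn
from llm import HuEmbedding
from nlp import search

SE = search.Dealer(es_conn.HuEs("infiniflow"), HuEmbedding())
res = SE.search({"question": "胡凯", "kb_ids": ["64f072a75f3b97c865718c4a"],
                 "size": 15, "vector": True}, "infiniflow_*")
sim = SE.rerank(res, "胡凯")
top_idx = [i for i in np.argsort(sim * -1) if sim[i] >= 0.5][:12]
# insert_citations() appends @?chunk_id?@ markers so answer spans can be traced back to chunks
print(SE.insert_citations("An answer produced by the LLM.", top_idx, res))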


@ -3,6 +3,7 @@ import re
import pandas as pd
from collections import Counter
from nlp import huqie
from io import BytesIO
class HuDocxParser:
@ -97,7 +98,7 @@ class HuDocxParser:
return ["\n".join(lines)]
def __call__(self, fnm):
self.doc = Document(fnm)
self.doc = Document(fnm) if isinstance(fnm, str) else Document(BytesIO(fnm))
secs = [(p.text, p.style.name) for p in self.doc.paragraphs]
tbls = [self.__extract_table_content(tb) for tb in self.doc.tables]
return secs, tbls


@ -1,10 +1,12 @@
from openpyxl import load_workbook
import sys
from io import BytesIO
class HuExcelParser:
def __call__(self, fnm):
wb = load_workbook(fnm)
if isinstance(fnm, str):wb = load_workbook(fnm)
else: wb = load_workbook(BytesIO(fnm))
res = []
for sheetname in wb.sheetnames:
ws = wb[sheetname]


@ -1,4 +1,5 @@
import xgboost as xgb
from io import BytesIO
import torch
import re
import pdfplumber
@ -1525,7 +1526,7 @@ class HuParser:
return "\n\n".join(res)
def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
self.pdf = pdfplumber.open(fnm)
self.pdf = pdfplumber.open(fnm) if isinstance(fnm, str) else pdfplumber.open(BytesIO(fnm))
self.lefted_chars = []
self.mean_height = []
self.mean_width = []

python/svr/dialog_svr.py Executable file

@ -0,0 +1,164 @@
#-*- coding:utf-8 -*-
import sys, os, re,inspect,json,traceback,logging,argparse, copy
sys.path.append(os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe())))+"/../")
from tornado.web import RequestHandler,Application
from tornado.ioloop import IOLoop
from tornado.httpserver import HTTPServer
from tornado.options import define,options
from util import es_conn, setup_logging
from nlp import search
from svr.rpc_proxy import RPCProxy
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
from nlp import huqie
from nlp import query as Query
from llm import HuEmbedding, GptTurbo
import numpy as np
from io import BytesIO
from util import config
from timeit import default_timer as timer
from collections import OrderedDict
SE = None
CFIELD="content_ltks"
EMBEDDING = HuEmbedding()
LLM = GptTurbo()
def get_QA_pairs(hists):
pa = []
for h in hists:
for k in ["user", "assistant"]:
if h.get(k):
pa.append({
"content": h[k],
"role": k,
})
for p in pa[:-1]: assert len(p) == 2, p
return pa
def get_instruction(sres, top_i, max_len=8096, fld="content_ltks"):
max_len //= len(top_i)
# add instruction to prompt
instructions = [re.sub(r"[\r\n]+", " ", sres.field[sres.ids[i]][fld]) for i in top_i]
if len(instructions)>2:
# Said that LLM is sensitive to the first and the last one, so
# rearrange the order of references
instructions.append(copy.deepcopy(instructions[1]))
instructions.pop(1)
def token_num(txt):
c = 0
for tk in re.split(r"[,。/?‘’”“:;:;!]", txt):
if re.match(r"[a-zA-Z-]+$", tk):
c += 1
continue
c += len(tk)
return c
_inst = ""
for ins in instructions:
if token_num(_inst) > 4096:
_inst += "\n知识库:" + instructions[-1][:max_len]
break
_inst += "\n知识库:" + ins[:max_len]
return _inst
def prompt_and_answer(history, inst):
hist = get_QA_pairs(history)
chks = []
for s in re.split(r"[:;;。\n\r]+", inst):
if s: chks.append(s)
chks = len(set(chks))/(0.1+len(chks))
print("Duplication portion:", chks)
system = """
你是一个智能助手请总结知识库的内容来回答问题请列举知识库中的数据详细回答%s当所有知识库内容都与问题无关时你的回答必须包括"知识库中未找到您要的答案!这是我所知道的,仅作参考。"这句话回答需要考虑聊天历史
以下是知识库
%s
以上是知识库
"""%((",最好总结成表格" if chks<0.6 and chks>0 else ""), inst)
print("【PROMPT】:", system)
start = timer()
response = LLM.chat(system, hist, {"temperature": 0.2, "max_tokens": 512})
print("GENERATE: ", timer()-start)
print("===>>", response)
return response
class Handler(RequestHandler):
def post(self):
global SE,MUST_TK_NUM
param = json.loads(self.request.body.decode('utf-8'))
try:
question = param.get("history",[{"user": "Hi!"}])[-1]["user"]
res = SE.search({
"question": question,
"kb_ids": param.get("kb_ids", []),
"size": param.get("topn", 15)
})
sim = SE.rerank(res, question)
rk_idx = np.argsort(sim*-1)
topidx = [i for i in rk_idx if sim[i] >= param.get("similarity", 0.5)][:param.get("topn",12)]
inst = get_instruction(res, topidx)
ans = prompt_and_answer(param["history"], inst)
ans = SE.insert_citations(ans, topidx, res)
refer = OrderedDict()
docnms = {}
for i in rk_idx:
did = res.field[res.ids[i]]["doc_id"]
if did not in docnms: docnms[did] = res.field[res.ids[i]]["docnm_kwd"]
if did not in refer: refer[did] = []
refer[did].append({
"chunk_id": res.ids[i],
"content": res.field[res.ids[i]]["content_ltks"]),
"image": ""
})
print("::::::::::::::", ans)
self.write(json.dumps({
"code":0,
"msg":"success",
"data":{
"uid": param["uid"],
"dialog_id": param["dialog_id"],
"assistant": ans
"refer": [{
"did": did,
"doc_name": docnms[did],
"chunks": chunks
} for did, chunks in refer.items()]
}
}))
logging.info("SUCCESS[%d]"%(res.total)+json.dumps(param, ensure_ascii=False))
except Exception as e:
logging.error("Request 500: "+str(e))
self.write(json.dumps({
"code":500,
"msg":str(e),
"data":{}
}))
print(traceback.format_exc())
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--port", default=4455, type=int, help="Port used for service")
ARGS = parser.parse_args()
SE = search.Dealer(es_conn.HuEs("infiniflow"), EMBEDDING)
app = Application([(r'/v1/chat/completions', Handler)],debug=False)
http_server = HTTPServer(app)
http_server.bind(ARGS.port)
http_server.start(3)
IOLoop.current().start()
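
A minimal sketch of a client call against this handler (host, uid, dialog_id and the kb id are placeholders; the response fields mirror the JSON written above):

import requests

payload = {
    "uid": 1,
    "dialog_id": 1,
    "kb_ids": ["64f072a75f3b97c865718c4a"],
    "topn": 15,
    "similarity": 0.5,
    "history": [{"user": "What does the knowledge base say about MinIO?"}],
}
r = requests.post("http://127.0.0.1:4455/v1/chat/completions", json=payload)
body = r.json()
print(body["data"]["assistant"])        # answer text with @?chunk_id?@ citation markers
for doc in body["data"]["refer"]:       # cited documents and their chunks
    print(doc["doc_name"], len(doc["chunks"]))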


@ -34,18 +34,14 @@ DOC = DocxChunker(DocxParser())
EXC = ExcelChunker(ExcelParser())
PPT = PptChunker()
UPLOAD_LOCATION = os.environ.get("UPLOAD_LOCATION", "./")
logging.warning(f"The files are stored in {UPLOAD_LOCATION}, please check it!")
def chuck_doc(name):
def chuck_doc(name, binary):
suff = os.path.split(name)[-1].lower().split(".")[-1]
if suff.find("pdf") >= 0: return PDF(name)
if suff.find("doc") >= 0: return DOC(name)
if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(name)
if suff.find("ppt") >= 0: return PPT(name)
if suff.find("pdf") >= 0: return PDF(binary)
if suff.find("doc") >= 0: return DOC(binary)
if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(binary)
if suff.find("ppt") >= 0: return PPT(binary)
return TextChunker()(name)
return TextChunker()(binary)
def collect(comm, mod, tm):
@ -115,7 +111,7 @@ def build(row):
random.seed(time.time())
set_progress(row["kb2doc_id"], random.randint(0, 20)/100., "Finished preparing! Start to slice file!")
try:
obj = chuck_doc(os.path.join(UPLOAD_LOCATION, row["location"]))
obj = chuck_doc(row["doc_name"], MINIO.get("%s-upload"%str(row["uid"]), row["location"]))
except Exception as e:
if re.search("(No such file|not found)", str(e)):
set_progress(row["kb2doc_id"], -1, "Can not find file <%s>"%row["doc_name"])
@ -133,9 +129,11 @@ def build(row):
doc = {
"doc_id": row["did"],
"kb_id": [str(row["kb_id"])],
"docnm_kwd": os.path.split(row["location"])[-1],
"title_tks": huqie.qie(os.path.split(row["location"])[-1]),
"updated_at": str(row["updated_at"]).replace("T", " ")[:19]
}
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
output_buffer = BytesIO()
docs = []
md5 = hashlib.md5()
@ -144,11 +142,14 @@ def build(row):
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest()
d["content_ltks"] = huqie.qie(txt)
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
if not img:
docs.append(d)
continue
img.save(output_buffer, format='JPEG')
d["img_bin"] = str(output_buffer.getvalue())
MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"],
output_buffer.getvalue())
d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"])
docs.append(d)
for arr, img in obj.table_chunks:
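
The MINIO helper used above (MINIO.get()/MINIO.put()) is not part of this diff; a minimal sketch of the interface it is assumed to expose, built on the minio Python client (class name, endpoint and credentials are placeholders):

from io import BytesIO
from minio import Minio

class HuMinio:
    def __init__(self, host, user, pwd):
        self.conn = Minio(host, access_key=user, secret_key=pwd, secure=False)

    def put(self, bucket, key, binary):
        # create the per-user bucket on first write, then store the raw bytes
        if not self.conn.bucket_exists(bucket):
            self.conn.make_bucket(bucket)
        self.conn.put_object(bucket, key, BytesIO(binary), len(binary))

    def get(self, bucket, key):
        # return the stored object as bytes, ready for chuck_doc()
        return self.conn.get_object(bucket, key).read()

MINIO = HuMinio("127.0.0.1:9000", "minioadmin", "minioadmin")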


@ -1,9 +1,9 @@
use std::collections::HashMap;
use std::io::Write;
use actix_multipart_extract::{File, Multipart, MultipartForm};
use actix_web::{get, HttpResponse, post, web};
use chrono::{Utc, FixedOffset};
use minio::s3::args::{BucketExistsArgs, MakeBucketArgs, UploadObjectArgs};
use std::io::BufReader;
use actix_multipart_extract::{ File, Multipart, MultipartForm };
use actix_web::{ HttpResponse, post, web };
use chrono::{ Utc, FixedOffset };
use minio::s3::args::{ BucketExistsArgs, MakeBucketArgs, PutObjectArgs };
use sea_orm::DbConn;
use crate::api::JsonResponse;
use crate::AppState;
@ -12,9 +12,6 @@ use crate::errors::AppError;
use crate::service::doc_info::{ Mutation, Query };
use serde::Deserialize;
const BUCKET_NAME: &'static str = "docgpt-upload";
fn now() -> chrono::DateTime<FixedOffset> {
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
}
@ -74,53 +71,71 @@ async fn upload(
) -> Result<HttpResponse, AppError> {
let uid = payload.uid;
let file_name = payload.file_field.name.as_str();
async fn add_number_to_filename(file_name: &str, conn:&DbConn, uid:i64, parent_id:i64) -> String {
async fn add_number_to_filename(
file_name: &str,
conn: &DbConn,
uid: i64,
parent_id: i64
) -> String {
let mut i = 0;
let mut new_file_name = file_name.to_string();
let arr: Vec<&str> = file_name.split(".").collect();
let suffix = String::from(arr[arr.len()-1]);
let preffix = arr[..arr.len()-1].join(".");
let mut docs = Query::find_doc_infos_by_name(conn, uid, &new_file_name, Some(parent_id)).await.unwrap();
while docs.len()>0 {
let suffix = String::from(arr[arr.len() - 1]);
let preffix = arr[..arr.len() - 1].join(".");
let mut docs = Query::find_doc_infos_by_name(
conn,
uid,
&new_file_name,
Some(parent_id)
).await.unwrap();
while docs.len() > 0 {
i += 1;
new_file_name = format!("{}_{}.{}", preffix, i, suffix);
docs = Query::find_doc_infos_by_name(conn, uid, &new_file_name, Some(parent_id)).await.unwrap();
docs = Query::find_doc_infos_by_name(
conn,
uid,
&new_file_name,
Some(parent_id)
).await.unwrap();
}
new_file_name
}
let fnm = add_number_to_filename(file_name, &data.conn, uid, payload.did).await;
let s3_client = &data.s3_client;
let bucket_name = format!("{}-upload", payload.uid);
let s3_client: &minio::s3::client::Client = &data.s3_client;
let buckets_exists = s3_client
.bucket_exists(&BucketExistsArgs::new(BUCKET_NAME)?)
.await?;
.bucket_exists(&BucketExistsArgs::new(&bucket_name).unwrap()).await
.unwrap();
if !buckets_exists {
s3_client
.make_bucket(&MakeBucketArgs::new(BUCKET_NAME)?)
.await?;
print!("Create bucket: {}", bucket_name.clone());
s3_client.make_bucket(&MakeBucketArgs::new(&bucket_name).unwrap()).await.unwrap();
} else {
print!("Existing bucket: {}", bucket_name.clone());
}
s3_client
.upload_object(
&mut UploadObjectArgs::new(
BUCKET_NAME,
fnm.as_str(),
format!("/{}/{}-{}", payload.uid, payload.did, fnm).as_str()
)?
)
.await?;
let location = format!("/{}/{}", payload.did, fnm);
print!("===>{}", location.clone());
s3_client.put_object(
&mut PutObjectArgs::new(
&bucket_name,
&location,
&mut BufReader::new(payload.file_field.bytes.as_slice()),
Some(payload.file_field.bytes.len()),
None
)?
).await?;
let location = format!("/{}/{}", BUCKET_NAME, fnm);
let doc = Mutation::create_doc_info(&data.conn, Model {
did:Default::default(),
uid: uid,
did: Default::default(),
uid: uid,
doc_name: fnm,
size: payload.file_field.bytes.len() as i64,
location,
r#type: "doc".to_string(),
created_at: now(),
updated_at: now(),
is_deleted:Default::default(),
is_deleted: Default::default(),
}).await?;
let _ = Mutation::place_doc(&data.conn, payload.did, doc.did.unwrap()).await?;


@ -1,58 +0,0 @@
use std::collections::HashMap;
use actix_web::{get, HttpResponse, post, web};
use actix_web::http::Error;
use crate::api::JsonResponse;
use crate::AppState;
use crate::entity::tag_info;
use crate::service::tag_info::{Mutation, Query};
#[post("/v1.0/create_tag")]
async fn create(model: web::Json<tag_info::Model>, data: web::Data<AppState>) -> Result<HttpResponse, Error> {
let model = Mutation::create_tag(&data.conn, model.into_inner()).await.unwrap();
let mut result = HashMap::new();
result.insert("tid", model.tid.unwrap());
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: result,
};
Ok(HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response).unwrap()))
}
#[post("/v1.0/delete_tag")]
async fn delete(model: web::Json<tag_info::Model>, data: web::Data<AppState>) -> Result<HttpResponse, Error> {
let _ = Mutation::delete_tag(&data.conn, model.tid).await.unwrap();
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: (),
};
Ok(HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response).unwrap()))
}
#[get("/v1.0/tags")]
async fn list(data: web::Data<AppState>) -> Result<HttpResponse, Error> {
let tags = Query::find_tag_infos(&data.conn).await.unwrap();
let mut result = HashMap::new();
result.insert("tags", tags);
let json_response = JsonResponse {
code: 200,
err: "".to_owned(),
data: result,
};
Ok(HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_string(&json_response).unwrap()))
}


@ -5,9 +5,9 @@ mod errors;
use std::env;
use actix_files::Files;
use actix_identity::{CookieIdentityPolicy, IdentityService, RequestIdentity};
use actix_identity::{ CookieIdentityPolicy, IdentityService, RequestIdentity };
use actix_session::CookieSession;
use actix_web::{web, App, HttpServer, middleware, Error};
use actix_web::{ web, App, HttpServer, middleware, Error };
use actix_web::cookie::time::Duration;
use actix_web::dev::ServiceRequest;
use actix_web::error::ErrorUnauthorized;
@ -16,9 +16,9 @@ use listenfd::ListenFd;
use minio::s3::client::Client;
use minio::s3::creds::StaticProvider;
use minio::s3::http::BaseUrl;
use sea_orm::{Database, DatabaseConnection};
use migration::{Migrator, MigratorTrait};
use crate::errors::{AppError, UserError};
use sea_orm::{ Database, DatabaseConnection };
use migration::{ Migrator, MigratorTrait };
use crate::errors::{ AppError, UserError };
#[derive(Debug, Clone)]
struct AppState {
@ -28,10 +28,10 @@ struct AppState {
pub(crate) async fn validator(
req: ServiceRequest,
credentials: BearerAuth,
credentials: BearerAuth
) -> Result<ServiceRequest, Error> {
if let Some(token) = req.get_identity() {
println!("{}, {}",credentials.token(), token);
println!("{}, {}", credentials.token(), token);
(credentials.token() == token)
.then(|| req)
.ok_or(ErrorUnauthorized(UserError::InvalidToken))
@ -52,26 +52,25 @@ async fn main() -> Result<(), AppError> {
let port = env::var("PORT").expect("PORT is not set in .env file");
let server_url = format!("{host}:{port}");
let s3_base_url = env::var("S3_BASE_URL").expect("S3_BASE_URL is not set in .env file");
let s3_access_key = env::var("S3_ACCESS_KEY").expect("S3_ACCESS_KEY is not set in .env file");;
let s3_secret_key = env::var("S3_SECRET_KEY").expect("S3_SECRET_KEY is not set in .env file");;
let mut s3_base_url = env::var("MINIO_HOST").expect("MINIO_HOST is not set in .env file");
let s3_access_key = env::var("MINIO_USR").expect("MINIO_USR is not set in .env file");
let s3_secret_key = env::var("MINIO_PWD").expect("MINIO_PWD is not set in .env file");
if s3_base_url.find("http") != Some(0) {
s3_base_url = format!("http://{}", s3_base_url);
}
// establish connection to database and apply migrations
// -> create post table if not exists
let conn = Database::connect(&db_url).await.unwrap();
Migrator::up(&conn, None).await.unwrap();
let static_provider = StaticProvider::new(
s3_access_key.as_str(),
s3_secret_key.as_str(),
None,
);
let static_provider = StaticProvider::new(s3_access_key.as_str(), s3_secret_key.as_str(), None);
let s3_client = Client::new(
s3_base_url.parse::<BaseUrl>()?,
Some(Box::new(static_provider)),
None,
None,
Some(true)
)?;
let state = AppState { conn, s3_client };
@ -82,18 +81,20 @@ async fn main() -> Result<(), AppError> {
App::new()
.service(Files::new("/static", "./static"))
.app_data(web::Data::new(state.clone()))
.wrap(IdentityService::new(
CookieIdentityPolicy::new(&[0; 32])
.name("auth-cookie")
.login_deadline(Duration::seconds(120))
.secure(false),
))
.wrap(
IdentityService::new(
CookieIdentityPolicy::new(&[0; 32])
.name("auth-cookie")
.login_deadline(Duration::seconds(120))
.secure(false)
)
)
.wrap(
CookieSession::signed(&[0; 32])
.name("session-cookie")
.secure(false)
// WARNING(alex): This uses the `time` crate, not `std::time`!
.expires_in_time(Duration::seconds(60)),
.expires_in_time(Duration::seconds(60))
)
.wrap(middleware::Logger::default())
.configure(init)
@ -137,4 +138,4 @@ fn init(cfg: &mut web::ServiceConfig) {
cfg.service(api::user_info::login);
cfg.service(api::user_info::register);
cfg.service(api::user_info::setting);
}
}