mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-06-04 11:24:00 +08:00
fix user login issue (#85)
This commit is contained in:
parent
e86c820461
commit
0429107e80
@ -33,49 +33,14 @@ from api.utils.api_utils import get_json_result, cors_reponse
|
|||||||
|
|
||||||
@manager.route('/login', methods=['POST', 'GET'])
|
@manager.route('/login', methods=['POST', 'GET'])
|
||||||
def login():
|
def login():
|
||||||
userinfo = None
|
|
||||||
login_channel = "password"
|
login_channel = "password"
|
||||||
if session.get("access_token"):
|
if not request.json:
|
||||||
login_channel = session["access_token_from"]
|
|
||||||
if session["access_token_from"] == "github":
|
|
||||||
userinfo = user_info_from_github(session["access_token"])
|
|
||||||
elif not request.json:
|
|
||||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
||||||
retmsg='Unautherized!')
|
retmsg='Unautherized!')
|
||||||
|
|
||||||
email = request.json.get('email') if not userinfo else userinfo["email"]
|
email = request.json.get('email', "")
|
||||||
users = UserService.query(email=email)
|
users = UserService.query(email=email)
|
||||||
if not users:
|
if not users: return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
|
||||||
if request.json is not None:
|
|
||||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
|
|
||||||
avatar = ""
|
|
||||||
try:
|
|
||||||
avatar = download_img(userinfo["avatar_url"])
|
|
||||||
except Exception as e:
|
|
||||||
stat_logger.exception(e)
|
|
||||||
user_id = get_uuid()
|
|
||||||
try:
|
|
||||||
users = user_register(user_id, {
|
|
||||||
"access_token": session["access_token"],
|
|
||||||
"email": userinfo["email"],
|
|
||||||
"avatar": avatar,
|
|
||||||
"nickname": userinfo["login"],
|
|
||||||
"login_channel": login_channel,
|
|
||||||
"last_login_time": get_format_time(),
|
|
||||||
"is_superuser": False,
|
|
||||||
})
|
|
||||||
if not users: raise Exception('Register user failure.')
|
|
||||||
if len(users) > 1: raise Exception('Same E-mail exist!')
|
|
||||||
user = users[0]
|
|
||||||
login_user(user)
|
|
||||||
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
|
|
||||||
except Exception as e:
|
|
||||||
rollback_user_registration(user_id)
|
|
||||||
stat_logger.exception(e)
|
|
||||||
return server_error_response(e)
|
|
||||||
elif not request.json:
|
|
||||||
login_user(users[0])
|
|
||||||
return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!")
|
|
||||||
|
|
||||||
password = request.json.get('password')
|
password = request.json.get('password')
|
||||||
try:
|
try:
|
||||||
@ -97,28 +62,50 @@ def login():
|
|||||||
|
|
||||||
@manager.route('/github_callback', methods=['GET'])
|
@manager.route('/github_callback', methods=['GET'])
|
||||||
def github_callback():
|
def github_callback():
|
||||||
try:
|
import requests
|
||||||
import requests
|
res = requests.post(GITHUB_OAUTH.get("url"), data={
|
||||||
res = requests.post(GITHUB_OAUTH.get("url"), data={
|
"client_id": GITHUB_OAUTH.get("client_id"),
|
||||||
"client_id": GITHUB_OAUTH.get("client_id"),
|
"client_secret": GITHUB_OAUTH.get("secret_key"),
|
||||||
"client_secret": GITHUB_OAUTH.get("secret_key"),
|
"code": request.args.get('code')
|
||||||
"code": request.args.get('code')
|
}, headers={"Accept": "application/json"})
|
||||||
},headers={"Accept": "application/json"})
|
res = res.json()
|
||||||
res = res.json()
|
if "error" in res:
|
||||||
if "error" in res:
|
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
||||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
retmsg=res["error_description"])
|
||||||
retmsg=res["error_description"])
|
|
||||||
|
|
||||||
if "user:email" not in res["scope"].split(","):
|
if "user:email" not in res["scope"].split(","):
|
||||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
|
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
|
||||||
|
|
||||||
session["access_token"] = res["access_token"]
|
session["access_token"] = res["access_token"]
|
||||||
session["access_token_from"] = "github"
|
session["access_token_from"] = "github"
|
||||||
return redirect(url_for("user.login"), code=307)
|
userinfo = user_info_from_github(session["access_token"])
|
||||||
|
users = UserService.query(email=userinfo["email"])
|
||||||
|
user_id = get_uuid()
|
||||||
|
if not users:
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
avatar = download_img(userinfo["avatar_url"])
|
||||||
|
except Exception as e:
|
||||||
|
stat_logger.exception(e)
|
||||||
|
avatar = ""
|
||||||
|
users = user_register(user_id, {
|
||||||
|
"access_token": session["access_token"],
|
||||||
|
"email": userinfo["email"],
|
||||||
|
"avatar": avatar,
|
||||||
|
"nickname": userinfo["login"],
|
||||||
|
"login_channel": "github",
|
||||||
|
"last_login_time": get_format_time(),
|
||||||
|
"is_superuser": False,
|
||||||
|
})
|
||||||
|
if not users: raise Exception('Register user failure.')
|
||||||
|
if len(users) > 1: raise Exception('Same E-mail exist!')
|
||||||
|
user = users[0]
|
||||||
|
login_user(user)
|
||||||
|
except Exception as e:
|
||||||
|
rollback_user_registration(user_id)
|
||||||
|
stat_logger.exception(e)
|
||||||
|
|
||||||
except Exception as e:
|
return redirect("/knowledge")
|
||||||
stat_logger.exception(e)
|
|
||||||
return server_error_response(e)
|
|
||||||
|
|
||||||
|
|
||||||
def user_info_from_github(access_token):
|
def user_info_from_github(access_token):
|
||||||
@ -208,7 +195,7 @@ def user_register(user_id, user):
|
|||||||
for llm in LLMService.query(fid=LLM_FACTORY):
|
for llm in LLMService.query(fid=LLM_FACTORY):
|
||||||
tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
|
tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
|
||||||
|
|
||||||
if not UserService.insert(**user):return
|
if not UserService.save(**user):return
|
||||||
TenantService.insert(**tenant)
|
TenantService.insert(**tenant)
|
||||||
UserTenantService.insert(**usr_tenant)
|
UserTenantService.insert(**usr_tenant)
|
||||||
TenantLLMService.insert_many(tenant_llm)
|
TenantLLMService.insert_many(tenant_llm)
|
||||||
|
@ -69,7 +69,6 @@ class TaskStatus(StrEnum):
|
|||||||
|
|
||||||
|
|
||||||
class ParserType(StrEnum):
|
class ParserType(StrEnum):
|
||||||
GENERAL = "general"
|
|
||||||
PRESENTATION = "presentation"
|
PRESENTATION = "presentation"
|
||||||
LAWS = "laws"
|
LAWS = "laws"
|
||||||
MANUAL = "manual"
|
MANUAL = "manual"
|
||||||
|
@ -475,7 +475,7 @@ class Knowledgebase(DataBaseModel):
|
|||||||
similarity_threshold = FloatField(default=0.2)
|
similarity_threshold = FloatField(default=0.2)
|
||||||
vector_similarity_weight = FloatField(default=0.3)
|
vector_similarity_weight = FloatField(default=0.3)
|
||||||
|
|
||||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.GENERAL.value)
|
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value)
|
||||||
parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
|
parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
|
||||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ def init_superuser():
|
|||||||
"password": "admin",
|
"password": "admin",
|
||||||
"nickname": "admin",
|
"nickname": "admin",
|
||||||
"is_superuser": True,
|
"is_superuser": True,
|
||||||
"email": "kai.hu@infiniflow.org",
|
"email": "admin@ragflow.io",
|
||||||
"creator": "system",
|
"creator": "system",
|
||||||
"status": "1",
|
"status": "1",
|
||||||
}
|
}
|
||||||
@ -61,7 +61,7 @@ def init_superuser():
|
|||||||
TenantService.insert(**tenant)
|
TenantService.insert(**tenant)
|
||||||
UserTenantService.insert(**usr_tenant)
|
UserTenantService.insert(**usr_tenant)
|
||||||
TenantLLMService.insert_many(tenant_llm)
|
TenantLLMService.insert_many(tenant_llm)
|
||||||
print("【INFO】Super user initialized. \033[93muser name: admin, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
|
print("【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
|
||||||
|
|
||||||
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
||||||
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
|
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
|
||||||
|
@ -13,6 +13,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
import peewee
|
import peewee
|
||||||
from werkzeug.security import generate_password_hash, check_password_hash
|
from werkzeug.security import generate_password_hash, check_password_hash
|
||||||
|
|
||||||
@ -20,7 +22,7 @@ from api.db import UserTenantRole
|
|||||||
from api.db.db_models import DB, UserTenant
|
from api.db.db_models import DB, UserTenant
|
||||||
from api.db.db_models import User, Tenant
|
from api.db.db_models import User, Tenant
|
||||||
from api.db.services.common_service import CommonService
|
from api.db.services.common_service import CommonService
|
||||||
from api.utils import get_uuid, get_format_time
|
from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
|
||||||
from api.db import StatusEnum
|
from api.db import StatusEnum
|
||||||
|
|
||||||
|
|
||||||
@ -53,6 +55,11 @@ class UserService(CommonService):
|
|||||||
kwargs["id"] = get_uuid()
|
kwargs["id"] = get_uuid()
|
||||||
if "password" in kwargs:
|
if "password" in kwargs:
|
||||||
kwargs["password"] = generate_password_hash(str(kwargs["password"]))
|
kwargs["password"] = generate_password_hash(str(kwargs["password"]))
|
||||||
|
|
||||||
|
kwargs["create_time"] = current_timestamp()
|
||||||
|
kwargs["create_date"] = datetime_format(datetime.now())
|
||||||
|
kwargs["update_time"] = current_timestamp()
|
||||||
|
kwargs["update_date"] = datetime_format(datetime.now())
|
||||||
obj = cls.model(**kwargs).save(force_insert=True)
|
obj = cls.model(**kwargs).save(force_insert=True)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
@ -66,10 +73,10 @@ class UserService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def update_user(cls, user_id, user_dict):
|
def update_user(cls, user_id, user_dict):
|
||||||
date_time = get_format_time()
|
|
||||||
with DB.atomic():
|
with DB.atomic():
|
||||||
if user_dict:
|
if user_dict:
|
||||||
user_dict["update_time"] = date_time
|
user_dict["update_time"] = current_timestamp()
|
||||||
|
user_dict["update_date"] = datetime_format(datetime.now())
|
||||||
cls.model.update(user_dict).where(cls.model.id == user_id).execute()
|
cls.model.update(user_dict).where(cls.model.id == user_id).execute()
|
||||||
|
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
|
|||||||
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
|
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
|
||||||
|
|
||||||
API_KEY = LLM.get("api_key", "infiniflow API Key")
|
API_KEY = LLM.get("api_key", "infiniflow API Key")
|
||||||
PARSERS = LLM.get("parsers", "general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
|
PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
|
||||||
|
|
||||||
# distribution
|
# distribution
|
||||||
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
|
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
|
||||||
|
@ -25,7 +25,7 @@ class HuParser:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.ocr = OCR()
|
self.ocr = OCR()
|
||||||
if not hasattr(self, "model_speciess"):
|
if not hasattr(self, "model_speciess"):
|
||||||
self.model_speciess = ParserType.GENERAL.value
|
self.model_speciess = ParserType.NAIVE.value
|
||||||
self.layouter = LayoutRecognizer("layout."+self.model_speciess)
|
self.layouter = LayoutRecognizer("layout."+self.model_speciess)
|
||||||
self.tbl_det = TableStructureRecognizer()
|
self.tbl_det = TableStructureRecognizer()
|
||||||
|
|
||||||
|
@ -34,8 +34,7 @@ class LayoutRecognizer(Recognizer):
|
|||||||
"Equation",
|
"Equation",
|
||||||
]
|
]
|
||||||
def __init__(self, domain):
|
def __init__(self, domain):
|
||||||
super().__init__(self.labels, domain,
|
super().__init__(self.labels, domain) #, os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
||||||
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
|
||||||
|
|
||||||
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
|
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
|
||||||
def __is_garbage(b):
|
def __is_garbage(b):
|
||||||
|
@ -33,8 +33,7 @@ class TableStructureRecognizer(Recognizer):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(self.labels, "tsr",
|
super().__init__(self.labels, "tsr")#,os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
||||||
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
|
||||||
|
|
||||||
def __call__(self, images, thr=0.2):
|
def __call__(self, images, thr=0.2):
|
||||||
tbls = super().__call__(images, thr)
|
tbls = super().__call__(images, thr)
|
||||||
|
@ -1,11 +1,17 @@
|
|||||||
import copy
|
import copy
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from api.db import ParserType
|
||||||
from rag.nlp import huqie, tokenize
|
from rag.nlp import huqie, tokenize
|
||||||
from deepdoc.parser import PdfParser
|
from deepdoc.parser import PdfParser
|
||||||
from rag.utils import num_tokens_from_string
|
from rag.utils import num_tokens_from_string
|
||||||
|
|
||||||
|
|
||||||
class Pdf(PdfParser):
|
class Pdf(PdfParser):
|
||||||
|
def __init__(self):
|
||||||
|
self.model_speciess = ParserType.MANUAL.value
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def __call__(self, filename, binary=None, from_page=0,
|
def __call__(self, filename, binary=None, from_page=0,
|
||||||
to_page=100000, zoomin=3, callback=None):
|
to_page=100000, zoomin=3, callback=None):
|
||||||
self.__images__(
|
self.__images__(
|
||||||
|
@ -30,11 +30,21 @@ class Pdf(PdfParser):
|
|||||||
|
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
start = timer()
|
start = timer()
|
||||||
|
start = timer()
|
||||||
self._layouts_rec(zoomin)
|
self._layouts_rec(zoomin)
|
||||||
callback(0.77, "Layout analysis finished")
|
callback(0.5, "Layout analysis finished.")
|
||||||
|
print("paddle layouts:", timer() - start)
|
||||||
|
self._table_transformer_job(zoomin)
|
||||||
|
callback(0.7, "Table analysis finished.")
|
||||||
|
self._text_merge()
|
||||||
|
self._concat_downward(concat_between_pages=False)
|
||||||
|
self._filter_forpages()
|
||||||
|
callback(0.77, "Text merging finished")
|
||||||
|
tbls = self._extract_table_figure(True, zoomin, False)
|
||||||
|
|
||||||
cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1)))
|
cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1)))
|
||||||
self._naive_vertical_merge()
|
#self._naive_vertical_merge()
|
||||||
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes]
|
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls
|
||||||
|
|
||||||
|
|
||||||
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
|
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
|
||||||
@ -44,11 +54,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
|||||||
Successive text will be sliced into pieces using 'delimiter'.
|
Successive text will be sliced into pieces using 'delimiter'.
|
||||||
Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
|
Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
eng = lang.lower() == "english"#is_english(cks)
|
||||||
doc = {
|
doc = {
|
||||||
"docnm_kwd": filename,
|
"docnm_kwd": filename,
|
||||||
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
||||||
}
|
}
|
||||||
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
|
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
|
||||||
|
res = []
|
||||||
pdf_parser = None
|
pdf_parser = None
|
||||||
sections = []
|
sections = []
|
||||||
if re.search(r"\.docx?$", filename, re.IGNORECASE):
|
if re.search(r"\.docx?$", filename, re.IGNORECASE):
|
||||||
@ -58,8 +71,19 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
|||||||
callback(0.8, "Finish parsing.")
|
callback(0.8, "Finish parsing.")
|
||||||
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
||||||
pdf_parser = Pdf()
|
pdf_parser = Pdf()
|
||||||
sections = pdf_parser(filename if not binary else binary,
|
sections, tbls = pdf_parser(filename if not binary else binary,
|
||||||
from_page=from_page, to_page=to_page, callback=callback)
|
from_page=from_page, to_page=to_page, callback=callback)
|
||||||
|
# add tables
|
||||||
|
for img, rows in tbls:
|
||||||
|
bs = 10
|
||||||
|
de = ";" if eng else ";"
|
||||||
|
for i in range(0, len(rows), bs):
|
||||||
|
d = copy.deepcopy(doc)
|
||||||
|
r = de.join(rows[i:i + bs])
|
||||||
|
r = re.sub(r"\t——(来自| in ).*”%s" % de, "", r)
|
||||||
|
tokenize(d, r, eng)
|
||||||
|
d["image"] = img
|
||||||
|
res.append(d)
|
||||||
elif re.search(r"\.txt$", filename, re.IGNORECASE):
|
elif re.search(r"\.txt$", filename, re.IGNORECASE):
|
||||||
callback(0.1, "Start to parse.")
|
callback(0.1, "Start to parse.")
|
||||||
txt = ""
|
txt = ""
|
||||||
@ -79,8 +103,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
|||||||
|
|
||||||
parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?"})
|
parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?"})
|
||||||
cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimiter"])
|
cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimiter"])
|
||||||
eng = lang.lower() == "english"#is_english(cks)
|
|
||||||
res = []
|
|
||||||
# wrap up to es documents
|
# wrap up to es documents
|
||||||
for ck in cks:
|
for ck in cks:
|
||||||
print("--", ck)
|
print("--", ck)
|
||||||
|
@ -37,7 +37,7 @@ from rag.nlp import search
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture
|
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive
|
||||||
|
|
||||||
from api.db import LLMType, ParserType
|
from api.db import LLMType, ParserType
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
@ -48,7 +48,7 @@ from api.utils.file_utils import get_project_base_directory
|
|||||||
BATCH_SIZE = 64
|
BATCH_SIZE = 64
|
||||||
|
|
||||||
FACTORY = {
|
FACTORY = {
|
||||||
ParserType.GENERAL.value: laws,
|
ParserType.NAIVE.value: naive,
|
||||||
ParserType.PAPER.value: paper,
|
ParserType.PAPER.value: paper,
|
||||||
ParserType.BOOK.value: book,
|
ParserType.BOOK.value: book,
|
||||||
ParserType.PRESENTATION.value: presentation,
|
ParserType.PRESENTATION.value: presentation,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user