support gpt-4o (#773)

### What problem does this PR solve?
#771 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
KevinHuSh 2024-05-15 11:16:08 +08:00 committed by GitHub
parent 77b1520b66
commit aa1c915d6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 20 additions and 7 deletions

View File

@ -194,7 +194,7 @@ def list_app():
res = {} res = {}
for m in llms: for m in llms:
if model_type and m["model_type"] != model_type: if model_type and m["model_type"].find(model_type)<0:
continue continue
if m["fid"] not in res: if m["fid"] not in res:
res[m["fid"]] = [] res[m["fid"]] = []

View File

@ -143,6 +143,12 @@ def init_llm_factory():
llm_infos = [ llm_infos = [
# ---------------------- OpenAI ------------------------ # ---------------------- OpenAI ------------------------
{ {
"fid": factory_infos[0]["name"],
"llm_name": "gpt-4o",
"tags": "LLM,CHAT,128K",
"max_tokens": 128000,
"model_type": LLMType.CHAT.value + "," + LLMType.IMAGE2TEXT.value
}, {
"fid": factory_infos[0]["name"], "fid": factory_infos[0]["name"],
"llm_name": "gpt-3.5-turbo", "llm_name": "gpt-3.5-turbo",
"tags": "LLM,CHAT,4K", "tags": "LLM,CHAT,4K",

View File

@ -81,7 +81,7 @@ class TenantLLMService(CommonService):
if not model_config: if not model_config:
if llm_type == LLMType.EMBEDDING.value: if llm_type == LLMType.EMBEDDING.value:
llm = LLMService.query(llm_name=llm_name) llm = LLMService.query(llm_name=llm_name)
if llm and llm[0].fid in ["Youdao", "FastEmbed"]: if llm and llm[0].fid in ["Youdao", "FastEmbed", "DeepSeek"]:
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""} model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
if not model_config: if not model_config:
if llm_name == "flag-embedding": if llm_name == "flag-embedding":

View File

@ -86,6 +86,12 @@ default_llm = {
"embedding_model": "", "embedding_model": "",
"image2text_model": "", "image2text_model": "",
"asr_model": "", "asr_model": "",
},
"DeepSeek": {
"chat_model": "deepseek-chat",
"embedding_model": "BAAI/bge-large-zh-v1.5",
"image2text_model": "",
"asr_model": "",
} }
} }
LLM = get_base_config("user_default_llm", {}) LLM = get_base_config("user_default_llm", {})

View File

@ -25,7 +25,8 @@ EmbeddingModel = {
"Tongyi-Qianwen": DefaultEmbedding, #QWenEmbed, "Tongyi-Qianwen": DefaultEmbedding, #QWenEmbed,
"ZHIPU-AI": ZhipuEmbed, "ZHIPU-AI": ZhipuEmbed,
"FastEmbed": FastEmbed, "FastEmbed": FastEmbed,
"Youdao": YoudaoEmbed "Youdao": YoudaoEmbed,
"DeepSeek": DefaultEmbedding
} }

View File

@ -261,7 +261,7 @@ def main():
st = timer() st = timer()
cks = build(r) cks = build(r)
cron_logger.info("Build chunks({}): {}".format(r["name"], timer()-st)) cron_logger.info("Build chunks({}): {:.2f}".format(r["name"], timer()-st))
if cks is None: if cks is None:
continue continue
if not cks: if not cks:
@ -279,7 +279,7 @@ def main():
callback(-1, "Embedding error:{}".format(str(e))) callback(-1, "Embedding error:{}".format(str(e)))
cron_logger.error(str(e)) cron_logger.error(str(e))
tk_count = 0 tk_count = 0
cron_logger.info("Embedding elapsed({}): {}".format(r["name"], timer()-st)) cron_logger.info("Embedding elapsed({:.2f}): {}".format(r["name"], timer()-st))
callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer()-st)) callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer()-st))
init_kb(r) init_kb(r)
@ -291,7 +291,7 @@ def main():
if b % 128 == 0: if b % 128 == 0:
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="") callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
cron_logger.info("Indexing elapsed({}): {}".format(r["name"], timer()-st)) cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer()-st))
if es_r: if es_r:
callback(-1, "Index failure!") callback(-1, "Index failure!")
ELASTICSEARCH.deleteByQuery( ELASTICSEARCH.deleteByQuery(
@ -306,7 +306,7 @@ def main():
DocumentService.increment_chunk_num( DocumentService.increment_chunk_num(
r["doc_id"], r["kb_id"], tk_count, chunk_count, 0) r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
cron_logger.info( cron_logger.info(
"Chunk doc({}), token({}), chunks({}), elapsed:{}".format( "Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
r["id"], tk_count, len(cks), timer()-st)) r["id"], tk_count, len(cks), timer()-st))