Add 2 embeding models from OpenAI (#812)

### What problem does this PR solve?

#810 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
KevinHuSh 2024-05-17 08:51:29 +08:00 committed by GitHub
parent d54d1375a5
commit e73ce39b66
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 40 additions and 0 deletions

View File

@ -16,6 +16,7 @@
import os
import time
import uuid
from copy import deepcopy
from api.db import LLMType, UserTenantRole
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
@ -166,6 +167,18 @@ def init_llm_factory():
"tags": "TEXT EMBEDDING,8K",
"max_tokens": 8191,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[0]["name"],
"llm_name": "text-embedding-3-small",
"tags": "TEXT EMBEDDING,8K",
"max_tokens": 8191,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[0]["name"],
"llm_name": "text-embedding-3-large",
"tags": "TEXT EMBEDDING,8K",
"max_tokens": 8191,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[0]["name"],
"llm_name": "whisper-1",
@ -376,6 +389,23 @@ def init_llm_factory():
LLMFactoriesService.filter_delete([LLMFactoriesService.model.name == "QAnything"])
LLMService.filter_delete([LLMService.model.fid == "QAnything"])
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
## insert openai two embedding models to the current openai user.
print("Start to insert 2 OpenAI embedding models...")
tenant_ids = set([row.tenant_id for row in TenantLLMService.get_openai_models()])
for tid in tenant_ids:
for row in TenantLLMService.get_openai_models(llm_factory="OpenAI", tenant_id=tid):
row = row.to_dict()
row["model_type"] = LLMType.EMBEDDING.value
row["llm_name"] = "text-embedding-3-small"
row["used_tokens"] = 0
try:
TenantLLMService.save(**row)
row = deepcopy(row)
row["llm_name"] = "text-embedding-3-large"
TenantLLMService.save(**row)
except Exception as e:
pass
break
"""
drop table llm;
drop table llm_factories;

View File

@ -135,6 +135,16 @@ class TenantLLMService(CommonService):
.execute()
return num
@classmethod
@DB.connection_context()
def get_openai_models(cls):
objs = cls.model.select().where(
(cls.model.llm_factory == "OpenAI"),
~(cls.model.llm_name == "text-embedding-3-small"),
~(cls.model.llm_name == "text-embedding-3-large")
).dicts()
return list(objs)
class LLMBundle(object):
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):