diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 87bd59be9..89e5593b3 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -169,8 +169,8 @@ class TenantLLMService(CommonService):
 
         num = 0
         try:
-            for u in cls.query(tenant_id = tenant_id, llm_name=mdlnm):
-                num += cls.model.update(used_tokens = u.used_tokens + used_tokens)\
+            for u in cls.query(tenant_id=tenant_id, llm_name=mdlnm):
+                num += cls.model.update(used_tokens=u.used_tokens + used_tokens)\
                     .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
                     .execute()
         except Exception as e:
@@ -252,7 +252,6 @@ class LLMBundle(object):
                 return
             yield chunk
 
-
     def chat(self, system, history, gen_conf):
         txt, used_tokens = self.mdl.chat(system, history, gen_conf)
         if not TenantLLMService.increase_usage(
diff --git a/graphrag/index.py b/graphrag/index.py
index b4e93c551..02d978f71 100644
--- a/graphrag/index.py
+++ b/graphrag/index.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import re
 from concurrent.futures import ThreadPoolExecutor
 import json
 from functools import reduce
@@ -24,7 +23,7 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import TenantService
 from graphrag.community_reports_extractor import CommunityReportsExtractor
 from graphrag.entity_resolution import EntityResolution
-from graphrag.graph_extractor import GraphExtractor
+from graphrag.graph_extractor import GraphExtractor, DEFAULT_ENTITY_TYPES
 from graphrag.mind_map_extractor import MindMapExtractor
 from rag.nlp import rag_tokenizer
 from rag.utils import num_tokens_from_string
@@ -52,7 +51,7 @@ def graph_merge(g1, g2):
     return g
 
 
-def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, entity_types=["organization", "person", "location", "event", "time"]):
+def build_knowledge_graph_chunks(tenant_id: str, chunks: List[str], callback, entity_types=DEFAULT_ENTITY_TYPES):
     _, tenant = TenantService.get_by_id(tenant_id)
     llm_bdl = LLMBundle(tenant_id, LLMType.CHAT, tenant.llm_id)
     ext = GraphExtractor(llm_bdl)
diff --git a/rag/app/knowledge_graph.py b/rag/app/knowledge_graph.py
index 12b87fe09..b7bcddd64 100644
--- a/rag/app/knowledge_graph.py
+++ b/rag/app/knowledge_graph.py
@@ -1,6 +1,6 @@
 import re
 
-from graphrag.index import build_knowlege_graph_chunks
+from graphrag.index import build_knowledge_graph_chunks
 from rag.app import naive
 from rag.nlp import rag_tokenizer, tokenize_chunks
 
@@ -15,9 +15,9 @@ def chunk(filename, binary, tenant_id, from_page=0, to_page=100000,
     parser_config["layout_recognize"] = False
     sections = naive.chunk(filename, binary, from_page=from_page, to_page=to_page,
                            section_only=True, parser_config=parser_config, callback=callback)
-    chunks = build_knowlege_graph_chunks(tenant_id, sections, callback,
-                                         parser_config.get("entity_types", ["organization", "person", "location", "event", "time"])
-                                         )
+    chunks = build_knowledge_graph_chunks(tenant_id, sections, callback,
+                                          parser_config.get("entity_types", ["organization", "person", "location", "event", "time"])
+                                          )
     for c in chunks:
         c["docnm_kwd"] = filename
     doc = {
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index ce65f7dd3..bfce43499 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -20,7 +20,6 @@ from abc import ABC
 from openai import OpenAI
 import openai
 from ollama import Client
-from volcengine.maas.v2 import MaasService
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string
 from groq import Groq
@@ -29,6 +28,7 @@ import json
 import requests
 import asyncio
 
+
 class Base(ABC):
     def __init__(self, key, model_name, base_url):
         self.client = OpenAI(api_key=key, base_url=base_url)
diff --git a/rag/utils/__init__.py b/rag/utils/__init__.py
index 796b0e965..7cd56e9f8 100644
--- a/rag/utils/__init__.py
+++ b/rag/utils/__init__.py
@@ -78,11 +78,9 @@ encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
 def num_tokens_from_string(string: str) -> int:
     """Returns the number of tokens in a text string."""
     try:
-        num_tokens = len(encoder.encode(string))
-        return num_tokens
-    except Exception as e:
-        pass
-        return 0
+        return len(encoder.encode(string))
+    except Exception:
+        return 0
 
 
 def truncate(string: str, max_len: int) -> str: