mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-19 12:39:59 +08:00
Dynamic Context Window Size for Ollama Chat (#6582)
# Dynamic Context Window Size for Ollama Chat ## Problem Statement Previously, the Ollama chat implementation used a fixed context window size of 32768 tokens. This caused two main issues: 1. Performance degradation due to unnecessarily large context windows for small conversations 2. Potential business logic failures when using smaller fixed sizes (e.g., 2048 tokens) ## Solution Implemented a dynamic context window size calculation that: 1. Uses a base context size of 8192 tokens 2. Applies a 1.2x buffer ratio to the total token count 3. Adds multiples of 8192 tokens based on the buffered token count 4. Implements a smart context size update strategy ## Implementation Details ### Token Counting Logic ```python def count_tokens(text): """Calculate token count for text""" # Simple calculation: 1 token per ASCII character # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.) total = 0 for char in text: if ord(char) < 128: # ASCII characters total += 1 else: # Non-ASCII characters total += 2 return total ``` ### Dynamic Context Calculation ```python def _calculate_dynamic_ctx(self, history): """Calculate dynamic context window size""" # Calculate total tokens for all messages total_tokens = 0 for message in history: content = message.get("content", "") content_tokens = count_tokens(content) role_tokens = 4 # Role marker token overhead total_tokens += content_tokens + role_tokens # Apply 1.2x buffer ratio total_tokens_with_buffer = int(total_tokens * 1.2) # Calculate context size in multiples of 8192 if total_tokens_with_buffer <= 8192: ctx_size = 8192 else: ctx_multiplier = (total_tokens_with_buffer // 8192) + 1 ctx_size = ctx_multiplier * 8192 return ctx_size ``` ### Integration in Chat Method ```python def chat(self, system, history, gen_conf): if system: history.insert(0, {"role": "system", "content": system}) if "max_tokens" in gen_conf: del gen_conf["max_tokens"] try: # Calculate new context size new_ctx_size = self._calculate_dynamic_ctx(history) # Prepare options with context size options = { "num_ctx": new_ctx_size } # Add other generation options if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"] if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"] if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"] if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"] if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"] # Make API call with dynamic context size response = self.client.chat( model=self.model_name, messages=history, options=options, keep_alive=60 ) return response["message"]["content"].strip(), response.get("eval_count", 0) + response.get("prompt_eval_count", 0) except Exception as e: return "**ERROR**: " + str(e), 0 ``` ## Benefits 1. **Improved Performance**: Uses appropriate context windows based on conversation length 2. **Better Resource Utilization**: Context window size scales with content 3. **Maintained Compatibility**: Works with existing business logic 4. **Predictable Scaling**: Context growth in 8192-token increments 5. **Smart Updates**: Context size updates are optimized to reduce unnecessary model reloads ## Future Considerations 1. Fine-tune buffer ratio based on usage patterns 2. Add monitoring for context window utilization 3. Consider language-specific token counting optimizations 4. Implement adaptive threshold based on conversation patterns 5. Add metrics for context size update frequency --------- Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
parent
1fbc4870f0
commit
c61df5dd25
@ -179,7 +179,41 @@ class Base(ABC):
|
||||
except Exception:
|
||||
pass
|
||||
return 0
|
||||
|
||||
def _calculate_dynamic_ctx(self, history):
|
||||
"""Calculate dynamic context window size"""
|
||||
def count_tokens(text):
|
||||
"""Calculate token count for text"""
|
||||
# Simple calculation: 1 token per ASCII character
|
||||
# 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
|
||||
total = 0
|
||||
for char in text:
|
||||
if ord(char) < 128: # ASCII characters
|
||||
total += 1
|
||||
else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
|
||||
total += 2
|
||||
return total
|
||||
|
||||
# Calculate total tokens for all messages
|
||||
total_tokens = 0
|
||||
for message in history:
|
||||
content = message.get("content", "")
|
||||
# Calculate content tokens
|
||||
content_tokens = count_tokens(content)
|
||||
# Add role marker token overhead
|
||||
role_tokens = 4
|
||||
total_tokens += content_tokens + role_tokens
|
||||
|
||||
# Apply 1.2x buffer ratio
|
||||
total_tokens_with_buffer = int(total_tokens * 1.2)
|
||||
|
||||
if total_tokens_with_buffer <= 8192:
|
||||
ctx_size = 8192
|
||||
else:
|
||||
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
|
||||
ctx_size = ctx_multiplier * 8192
|
||||
|
||||
return ctx_size
|
||||
|
||||
class GptTurbo(Base):
|
||||
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
|
||||
@ -469,7 +503,7 @@ class ZhipuChat(Base):
|
||||
|
||||
class OllamaChat(Base):
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
|
||||
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
|
||||
self.model_name = model_name
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
@ -478,7 +512,12 @@ class OllamaChat(Base):
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
try:
|
||||
options = {"num_ctx": 32768}
|
||||
# Calculate context size
|
||||
ctx_size = self._calculate_dynamic_ctx(history)
|
||||
|
||||
options = {
|
||||
"num_ctx": ctx_size
|
||||
}
|
||||
if "temperature" in gen_conf:
|
||||
options["temperature"] = gen_conf["temperature"]
|
||||
if "max_tokens" in gen_conf:
|
||||
@ -489,9 +528,11 @@ class OllamaChat(Base):
|
||||
options["presence_penalty"] = gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf:
|
||||
options["frequency_penalty"] = gen_conf["frequency_penalty"]
|
||||
response = self.client.chat(model=self.model_name, messages=history, options=options, keep_alive=-1)
|
||||
|
||||
response = self.client.chat(model=self.model_name, messages=history, options=options, keep_alive=10)
|
||||
ans = response["message"]["content"].strip()
|
||||
return ans, response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
|
||||
token_count = response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
|
||||
return ans, token_count
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
@ -500,28 +541,38 @@ class OllamaChat(Base):
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
options = {}
|
||||
if "temperature" in gen_conf:
|
||||
options["temperature"] = gen_conf["temperature"]
|
||||
if "max_tokens" in gen_conf:
|
||||
options["num_predict"] = gen_conf["max_tokens"]
|
||||
if "top_p" in gen_conf:
|
||||
options["top_p"] = gen_conf["top_p"]
|
||||
if "presence_penalty" in gen_conf:
|
||||
options["presence_penalty"] = gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf:
|
||||
options["frequency_penalty"] = gen_conf["frequency_penalty"]
|
||||
ans = ""
|
||||
try:
|
||||
response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=-1)
|
||||
for resp in response:
|
||||
if resp["done"]:
|
||||
yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
|
||||
ans = resp["message"]["content"]
|
||||
yield ans
|
||||
# Calculate context size
|
||||
ctx_size = self._calculate_dynamic_ctx(history)
|
||||
options = {
|
||||
"num_ctx": ctx_size
|
||||
}
|
||||
if "temperature" in gen_conf:
|
||||
options["temperature"] = gen_conf["temperature"]
|
||||
if "max_tokens" in gen_conf:
|
||||
options["num_predict"] = gen_conf["max_tokens"]
|
||||
if "top_p" in gen_conf:
|
||||
options["top_p"] = gen_conf["top_p"]
|
||||
if "presence_penalty" in gen_conf:
|
||||
options["presence_penalty"] = gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf:
|
||||
options["frequency_penalty"] = gen_conf["frequency_penalty"]
|
||||
|
||||
ans = ""
|
||||
try:
|
||||
response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=10 )
|
||||
for resp in response:
|
||||
if resp["done"]:
|
||||
token_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
|
||||
yield token_count
|
||||
ans = resp["message"]["content"]
|
||||
yield ans
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
yield 0
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
yield 0
|
||||
yield "**ERROR**: " + str(e)
|
||||
yield 0
|
||||
|
||||
|
||||
class LocalAIChat(Base):
|
||||
|
Loading…
x
Reference in New Issue
Block a user