Refa: remove max tokens for image2txt models. (#6078)

### What problem does this PR solve?

#6063

### Type of change


- [x] Refactoring
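
For context, a minimal caller-side sketch of the behavior after this refactor: the image2txt models stop forwarding `max_tokens` to the providers, so output length falls back to each provider's own default, and any `max_tokens` left in `gen_conf` is simply ignored. The import path, constructor arguments, and exact `chat()` signature below are assumptions for illustration only; they are not taken from this diff.

```python
# Illustrative sketch only, not part of this PR: the import path, the GptV4
# constructor arguments, and the chat() signature are assumed; only the
# gen_conf handling reflects the change in this diff.
from rag.llm.cv_model import GptV4  # hypothetical import path

cv_mdl = GptV4(key="sk-...", model_name="gpt-4o", lang="English")  # assumed ctor

with open("invoice.png", "rb") as f:
    raw = f.read()

# describe(): max_tokens is no longer forwarded to the provider, so the
# provider's default governs how long the description can be.
caption, tokens_used = cv_mdl.describe(raw)

# chat()/chat_streamly(): only temperature and top_p are read from gen_conf
# after this change; a leftover "max_tokens" key has no effect.
gen_conf = {"temperature": 0.3, "top_p": 0.7, "max_tokens": 1000}  # max_tokens ignored
history = [{"role": "user", "content": "What is the total amount on this invoice?"}]
answer, tokens_used = cv_mdl.chat("", history, gen_conf, image=cv_mdl.image2base64(raw))
```

Because the removed keyword was read via `gen_conf.get(...)`, downstream callers need no change; existing `gen_conf` dicts keep working and the extra key is just ignored.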
Kevin Hu 2025-03-14 13:51:45 +08:00 committed by GitHub
parent 42eb99554f
commit 56b228f187


@@ -49,7 +49,6 @@ class Base(ABC):
             response = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=history,
-                max_tokens=gen_conf.get("max_tokens", 1000),
                 temperature=gen_conf.get("temperature", 0.3),
                 top_p=gen_conf.get("top_p", 0.7)
             )
@@ -71,7 +70,6 @@ class Base(ABC):
             response = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=history,
-                max_tokens=gen_conf.get("max_tokens", 1000),
                 temperature=gen_conf.get("temperature", 0.3),
                 top_p=gen_conf.get("top_p", 0.7),
                 stream=True
@@ -157,8 +155,7 @@ class GptV4(Base):
         res = self.client.chat.completions.create(
             model=self.model_name,
-            messages=prompt,
-            max_tokens=max_tokens,
+            messages=prompt
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
@@ -181,8 +178,7 @@ class AzureGptV4(Base):
         res = self.client.chat.completions.create(
             model=self.model_name,
-            messages=prompt,
-            max_tokens=max_tokens,
+            messages=prompt
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
@@ -241,7 +237,6 @@ class QWenCV(Base):
             if his["role"] == "user":
                 his["content"] = self.chat_prompt(his["content"], image)
         response = MultiModalConversation.call(model=self.model_name, messages=history,
-                                               max_tokens=gen_conf.get("max_tokens", 1000),
                                                temperature=gen_conf.get("temperature", 0.3),
                                                top_p=gen_conf.get("top_p", 0.7))
@@ -271,7 +266,6 @@ class QWenCV(Base):
         tk_count = 0
         try:
             response = MultiModalConversation.call(model=self.model_name, messages=history,
-                                                   max_tokens=gen_conf.get("max_tokens", 1000),
                                                    temperature=gen_conf.get("temperature", 0.3),
                                                    top_p=gen_conf.get("top_p", 0.7),
                                                    stream=True)
@@ -306,8 +300,7 @@ class Zhipu4V(Base):
         res = self.client.chat.completions.create(
             model=self.model_name,
-            messages=prompt,
-            max_tokens=max_tokens,
+            messages=prompt
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
@@ -322,7 +315,6 @@ class Zhipu4V(Base):
             response = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=history,
-                max_tokens=gen_conf.get("max_tokens", 1000),
                 temperature=gen_conf.get("temperature", 0.3),
                 top_p=gen_conf.get("top_p", 0.7)
             )
@@ -344,7 +336,6 @@ class Zhipu4V(Base):
             response = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=history,
-                max_tokens=gen_conf.get("max_tokens", 1000),
                 temperature=gen_conf.get("temperature", 0.3),
                 top_p=gen_conf.get("top_p", 0.7),
                 stream=True
@@ -376,12 +367,10 @@ class OllamaCV(Base):
     def describe(self, image, max_tokens=1024):
         prompt = self.prompt("")
         try:
-            options = {"num_predict": max_tokens}
             response = self.client.generate(
                 model=self.model_name,
                 prompt=prompt[0]["content"][1]["text"],
-                images=[image],
-                options=options
+                images=[image]
             )
             ans = response["response"].strip()
             return ans, 128
@@ -399,8 +388,6 @@ class OllamaCV(Base):
             options = {}
             if "temperature" in gen_conf:
                 options["temperature"] = gen_conf["temperature"]
-            if "max_tokens" in gen_conf:
-                options["num_predict"] = gen_conf["max_tokens"]
             if "top_p" in gen_conf:
                 options["top_k"] = gen_conf["top_p"]
             if "presence_penalty" in gen_conf:
@@ -429,8 +416,6 @@ class OllamaCV(Base):
             options = {}
            if "temperature" in gen_conf:
                 options["temperature"] = gen_conf["temperature"]
-            if "max_tokens" in gen_conf:
-                options["num_predict"] = gen_conf["max_tokens"]
             if "top_p" in gen_conf:
                 options["top_k"] = gen_conf["top_p"]
             if "presence_penalty" in gen_conf:
@@ -480,8 +465,7 @@ class XinferenceCV(Base):
         res = self.client.chat.completions.create(
             model=self.model_name,
-            messages=self.prompt(b64),
-            max_tokens=max_tokens,
+            messages=self.prompt(b64)
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
@@ -497,15 +481,13 @@ class GeminiCV(Base):
     def describe(self, image, max_tokens=2048):
         from PIL.Image import open
-        gen_config = {'max_output_tokens':max_tokens}
         prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
             "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
         b64 = self.image2base64(image)
         img = open(BytesIO(base64.b64decode(b64)))
         input = [prompt,img]
         res = self.model.generate_content(
-            input,
-            generation_config=gen_config,
+            input
         )
         return res.text,res.usage_metadata.total_token_count
@@ -525,7 +507,7 @@ class GeminiCV(Base):
             history[-1]["parts"].append("data:image/jpeg;base64," + image)
             response = self.model.generate_content(history, generation_config=GenerationConfig(
-                max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
+                temperature=gen_conf.get("temperature", 0.3),
                 top_p=gen_conf.get("top_p", 0.7)))
             ans = response.text
@@ -551,7 +533,7 @@ class GeminiCV(Base):
             history[-1]["parts"].append("data:image/jpeg;base64," + image)
             response = self.model.generate_content(history, generation_config=GenerationConfig(
-                max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
+                temperature=gen_conf.get("temperature", 0.3),
                 top_p=gen_conf.get("top_p", 0.7)), stream=True)
             for resp in response:
@@ -618,8 +600,7 @@ class NvidiaCV(Base):
                 "Authorization": f"Bearer {self.key}",
             },
             json={
-                "messages": self.prompt(b64),
-                "max_tokens": max_tokens,
+                "messages": self.prompt(b64)
             },
         )
         response = response.json()