Merge pull request #693 from fa0311/main

Fixed bugs related to sum
This commit is contained in:
Daniel Gatis 2024-11-29 21:10:45 -03:00 committed by GitHub
commit 5435d2ffee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,6 @@
import os import os
from copy import deepcopy from copy import deepcopy
from typing import Dict, List, Tuple from typing import List
import cv2 import cv2
import numpy as np import numpy as np
@ -105,9 +105,10 @@ class SamSession(BaseSession):
valid_providers = [] valid_providers = []
available_providers = ort.get_available_providers() available_providers = ort.get_available_providers()
for provider in providers or []: if providers:
if provider in available_providers: for provider in providers or []:
valid_providers.append(provider) if provider in available_providers:
valid_providers.append(provider)
else: else:
valid_providers.extend(available_providers) valid_providers.extend(available_providers)
@ -142,7 +143,16 @@ class SamSession(BaseSession):
Returns: Returns:
List[PILImage]: A list of masks generated by the decoder. List[PILImage]: A list of masks generated by the decoder.
""" """
prompt = kwargs.get("sam_prompt", "{}") prompt = kwargs.get(
"sam_prompt",
[
{
"type": "point",
"label": 1,
"data": [int(img.width / 2), int(img.height / 2)],
}
],
)
schema = { schema = {
"type": "array", "type": "array",
"items": { "items": {