mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-16 12:35:54 +08:00
commit
5435d2ffee
@ -1,6 +1,6 @@
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@ -105,9 +105,10 @@ class SamSession(BaseSession):
|
||||
valid_providers = []
|
||||
available_providers = ort.get_available_providers()
|
||||
|
||||
for provider in providers or []:
|
||||
if provider in available_providers:
|
||||
valid_providers.append(provider)
|
||||
if providers:
|
||||
for provider in providers or []:
|
||||
if provider in available_providers:
|
||||
valid_providers.append(provider)
|
||||
else:
|
||||
valid_providers.extend(available_providers)
|
||||
|
||||
@ -142,7 +143,16 @@ class SamSession(BaseSession):
|
||||
Returns:
|
||||
List[PILImage]: A list of masks generated by the decoder.
|
||||
"""
|
||||
prompt = kwargs.get("sam_prompt", "{}")
|
||||
prompt = kwargs.get(
|
||||
"sam_prompt",
|
||||
[
|
||||
{
|
||||
"type": "point",
|
||||
"label": 1,
|
||||
"data": [int(img.width / 2), int(img.height / 2)],
|
||||
}
|
||||
],
|
||||
)
|
||||
schema = {
|
||||
"type": "array",
|
||||
"items": {
|
||||
|
Loading…
x
Reference in New Issue
Block a user