mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-16 17:15:53 +08:00
commit
5435d2ffee
@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Dict, List, Tuple
|
from typing import List
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -105,6 +105,7 @@ class SamSession(BaseSession):
|
|||||||
valid_providers = []
|
valid_providers = []
|
||||||
available_providers = ort.get_available_providers()
|
available_providers = ort.get_available_providers()
|
||||||
|
|
||||||
|
if providers:
|
||||||
for provider in providers or []:
|
for provider in providers or []:
|
||||||
if provider in available_providers:
|
if provider in available_providers:
|
||||||
valid_providers.append(provider)
|
valid_providers.append(provider)
|
||||||
@ -142,7 +143,16 @@ class SamSession(BaseSession):
|
|||||||
Returns:
|
Returns:
|
||||||
List[PILImage]: A list of masks generated by the decoder.
|
List[PILImage]: A list of masks generated by the decoder.
|
||||||
"""
|
"""
|
||||||
prompt = kwargs.get("sam_prompt", "{}")
|
prompt = kwargs.get(
|
||||||
|
"sam_prompt",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "point",
|
||||||
|
"label": 1,
|
||||||
|
"data": [int(img.width / 2), int(img.height / 2)],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
schema = {
|
schema = {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user