import enum
from typing import Generic, TypeVar

from sqlalchemy import CHAR, VARCHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID


class StringUUID(TypeDecorator):
    """Column type that exposes UUID values to Python code as strings.

    On PostgreSQL the column is backed by the dialect's ``UUID`` type; on
    every other dialect it is backed by ``CHAR(36)``.
    """

    impl = CHAR
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Convert an outgoing parameter to its database representation.

        ``None`` passes through untouched. PostgreSQL receives ``str(value)``.
        Other dialects receive string values unchanged and ``value.hex`` for
        anything else (e.g. ``uuid.UUID`` objects).
        """
        if value is None:
            return None
        if dialect.name == "postgresql":
            return str(value)
        # process_result_value hands plain strings back to callers, so the
        # same strings come back here as bind parameters; ``str`` has no
        # ``.hex`` attribute and must be passed through as-is.
        if isinstance(value, str):
            return value
        return value.hex

    def load_dialect_impl(self, dialect):
        """Pick the concrete column type for ``dialect``."""
        backing_type = UUID() if dialect.name == "postgresql" else CHAR(36)
        return dialect.type_descriptor(backing_type)

    def process_result_value(self, value, dialect):
        """Return the loaded value as ``str``; ``None`` stays ``None``."""
        return None if value is None else str(value)


# Type variable for the StrEnum subclass that an EnumText column maps to.
_E = TypeVar("_E", bound=enum.StrEnum)


class EnumText(TypeDecorator, Generic[_E]):
    """VARCHAR-backed column type that maps stored text to ``StrEnum`` members.

    Values are written as the enum member's string value and read back as
    members of the configured enum class.
    """

    impl = VARCHAR
    cache_ok = True

    # Public mirrors of the constructor arguments. SQLAlchemy derives a type's
    # cache key from instance attributes named after the ``__init__``
    # parameters and ignores underscore-prefixed names, so these are what make
    # ``cache_ok = True`` distinguish differently configured instances.
    enum_class: type[_E]
    length: int

    _length: int
    _enum_class: type[_E]

    def __init__(self, enum_class: type[_E], length: int | None = None):
        """Configure the type for ``enum_class``.

        Args:
            enum_class: The ``StrEnum`` subclass stored in the column.
            length: Column length. When omitted, the longest enum value's
                length is used, with a minimum of 20.

        Raises:
            ValueError: If ``length`` is shorter than the longest enum value.
        """
        self._enum_class = enum_class
        max_enum_value_len = max(len(e.value) for e in enum_class)
        if length is not None:
            if length < max_enum_value_len:
                raise ValueError("length should be greater than enum value length.")
            self._length = length
        else:
            # Leave some room for longer enum values added in the future.
            self._length = max(max_enum_value_len, 20)

        # Expose the effective configuration under the constructor parameter
        # names so it takes part in SQLAlchemy's statement cache key.
        self.enum_class = enum_class
        self.length = self._length

    def process_bind_param(self, value: _E | str | None, dialect):
        """Convert an enum member or string into the text sent to the database.

        Raises:
            ValueError: If a plain string is not a value of the enum class.
            TypeError: If ``value`` is neither a string nor an enum member.
        """
        if value is None:
            return value
        if isinstance(value, self._enum_class):
            return value.value
        if isinstance(value, str):
            # Constructing the member validates the string; the result is
            # discarded and the original string is what gets bound.
            self._enum_class(value)
            return value
        raise TypeError(f"expected str or {self._enum_class}, got {type(value)}")

    def load_dialect_impl(self, dialect):
        """Back the column with a ``VARCHAR`` of the configured length."""
        return dialect.type_descriptor(VARCHAR(self._length))

    def process_result_value(self, value, dialect) -> _E | None:
        """Convert loaded text into a member of the enum class.

        Raises:
            TypeError: If the loaded value is not a string.
            ValueError: If the string is not a value of the enum class.
        """
        if value is None:
            return value
        if not isinstance(value, str):
            raise TypeError(f"expected str, got {type(value)}")
        return self._enum_class(value)

    def compare_values(self, x, y):
        """Compare two values; ``None`` only equals ``None``."""
        if x is None or y is None:
            return x is y
        return x == y