import os
from itertools import product
from pathlib import Path
from typing import List, Tuple

import black
from jinja2 import Template
from pydantic import BaseModel

# Directory two levels above this file; the template and the generated
# module are both resolved relative to it.
_base_dir = Path(__file__).parent.parent
template_path = _base_dir / "sqlmodel/sql/expression.py.jinja2"
destiny_path = _base_dir / "sqlmodel/sql/expression.py"


# Upper bound on the number of entities per generated signature;
# signatures are produced for 2 up to this many entities.
number_of_types = 4


class Arg(BaseModel):
    """One parameter of a generated signature, passed to the Jinja2 template."""

    # Parameter name as built by this script, e.g. "entity_0".
    name: str
    # Annotation source text, e.g. "_TScalar_0" or "Type[_TModel_0]".
    annotation: str


# Declared but never populated in the visible part of this script.
arg_groups: List[Arg] = []

# One entry per generated signature: (parameters, matching return types).
signatures: List[Tuple[List[Arg], List[str]]] = []

# Enumerate every scalar/model combination for each arity from 2 up to
# number_of_types (inclusive), in itertools.product order.
for arity in range(2, number_of_types + 1):
    for kinds in product(["scalar", "model"], repeat=arity):
        signature_args: List[Arg] = []
        signature_returns: List[str] = []
        for position, kind in enumerate(kinds):
            is_scalar = kind == "scalar"
            # The type variable doubles as the return type for this position.
            type_var = (
                f"_TScalar_{position}" if is_scalar else f"_TModel_{position}"
            )
            # Scalars are annotated with the type variable itself; models are
            # passed as classes, hence the Type[...] wrapper.
            annotation = type_var if is_scalar else f"Type[{type_var}]"
            signature_args.append(
                Arg(name=f"entity_{position}", annotation=annotation)
            )
            signature_returns.append(type_var)
        signatures.append((signature_args, signature_returns))

# Render the generated module from the Jinja2 template. Python source files
# are UTF-8 by default, so use an explicit encoding instead of the locale's.
template: Template = Template(template_path.read_text(encoding="utf-8"))

result = template.render(number_of_types=number_of_types, signatures=signatures)

# Prepend a banner so the generated module is not edited by hand.
result = (
    "# WARNING: do not modify this code, it is generated by "
    "expression.py.jinja2\n\n" + result
)

# Run the output through black so it can be compared verbatim with the
# file on disk.
result = black.format_str(result, mode=black.Mode())

# A missing output file counts as "out of date" rather than crashing with
# FileNotFoundError: in check mode it triggers the RuntimeError below,
# otherwise the file is created.
try:
    current_content = destiny_path.read_text(encoding="utf-8")
except FileNotFoundError:
    current_content = None

if current_content != result:
    # With CHECK_JINJA set (to any non-empty value) the script only verifies
    # and fails on drift; otherwise it regenerates the file.
    if os.getenv("CHECK_JINJA"):
        raise RuntimeError(
            "sqlmodel/sql/expression.py content not update with Jinja2 template"
        )
    # Only write when something changed, so an up-to-date file is left untouched.
    destiny_path.write_text(result, encoding="utf-8")