From 0f01b75a0540d977e91e7228503af3ea8aff670f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Tue, 24 Aug 2021 15:16:41 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20Add=20sqlmodel.sql.expression=20?= =?UTF-8?q?generation=20script=20(select=20overloads)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/generate_select.py | 55 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 scripts/generate_select.py diff --git a/scripts/generate_select.py b/scripts/generate_select.py new file mode 100644 index 0000000..b66a167 --- /dev/null +++ b/scripts/generate_select.py @@ -0,0 +1,55 @@ +from itertools import product +from pathlib import Path +from typing import List, Tuple + +import black +from jinja2 import Template +from pydantic import BaseModel + +template_path = Path(__file__).parent.parent / "sqlmodel/sql/expression.py.jinja2" +destiny_path = Path(__file__).parent.parent / "sqlmodel/sql/expression.py" + + +number_of_types = 4 + + +class Arg(BaseModel): + name: str + annotation: str + + +arg_groups: List[Arg] = [] + +signatures: List[Tuple[List[Arg], List[str]]] = [] + +for total_args in range(2, number_of_types + 1): + arg_types_tuples = product(["scalar", "model"], repeat=total_args) + for arg_type_tuple in arg_types_tuples: + args: List[Arg] = [] + return_types: List[str] = [] + for i, arg_type in enumerate(arg_type_tuple): + if arg_type == "scalar": + t_var = f"_TScalar_{i}" + arg = Arg(name=f"entity_{i}", annotation=t_var) + ret_type = t_var + else: + t_type = f"_TModel_{i}" + t_var = f"Type[{t_type}]" + arg = Arg(name=f"entity_{i}", annotation=t_var) + ret_type = t_type + args.append(arg) + return_types.append(ret_type) + signatures.append((args, return_types)) + +template: Template = Template(template_path.read_text()) + +result = template.render(number_of_types=number_of_types, signatures=signatures) + +result = ( + "# WARNING: do not modify this code, it is generated by " + "expression.py.jinja2\n\n" + result +) + +result = black.format_str(result, mode=black.Mode()) + +destiny_path.write_text(result)