Source code for linkml.generators.sqlalchemygen

import logging
import os
from collections import defaultdict
from dataclasses import dataclass
from types import ModuleType
from typing import List, Optional, Union

import click
from jinja2 import Template
from linkml_runtime.linkml_model import Annotation, ClassDefinition, ClassDefinitionName, SchemaDefinition
from linkml_runtime.utils.compile_python import compile_python
from linkml_runtime.utils.formatutils import camelcase, underscore
from linkml_runtime.utils.schemaview import SchemaView
from sqlalchemy import Enum

from linkml._version import __version__
from linkml.generators.pydanticgen import PydanticGenerator
from linkml.generators.pythongen import PythonGenerator
from linkml.generators.sqlalchemy import sqlalchemy_declarative_template_str, sqlalchemy_imperative_template_str
from linkml.generators.sqltablegen import SQLTableGenerator
from linkml.transformers.relmodel_transformer import ForeignKeyPolicy, RelationalModelTransformer
from linkml.utils.generator import Generator, shared_arguments

logger = logging.getLogger(__name__)


class TemplateEnum(Enum):
    DECLARATIVE = "declarative"
    IMPERATIVE = "imperative"


[docs]@dataclass class SQLAlchemyGenerator(Generator): """ Generates SQL Alchemy classes See also: :class:`~linkml.generators.sqltablegen.SQLTableGenerator` """ # ClassVars generatorname = os.path.basename(__file__) generatorversion = "0.1.1" valid_formats = ["sqla"] file_extension = "py" uses_schemaloader = False template: Optional[TemplateEnum] = None # ObjectVars original_schema: Union[SchemaDefinition, str] = None def __post_init__(self) -> None: self.original_schema = self.schema self.schemaview = SchemaView(self.schema) super().__post_init__() def generate_sqla( self, model_path: Optional[str] = None, no_model_import: bool = False, template: Optional[TemplateEnum] = None, foreign_key_policy: Optional[ForeignKeyPolicy] = None, **kwargs, ) -> str: # src_sv = SchemaView(self.schema) # self.schema = src_sv.schema template = template or self.template or TemplateEnum.IMPERATIVE sqltr = RelationalModelTransformer(self.schemaview) if foreign_key_policy: sqltr.foreign_key_policy = foreign_key_policy tgen = SQLTableGenerator(self.schemaview.schema) tr_result = sqltr.transform(**kwargs) tr_schema = tr_result.schema for c in tr_schema.classes.values(): for a in c.attributes.values(): sql_range = tgen.get_sql_range(a, tr_schema) sql_type = sql_range.__repr__() ann = Annotation("sql_type", sql_type) a.annotations[ann.tag] = ann # a.sql_type = sql_type if template == TemplateEnum.IMPERATIVE: template_str = sqlalchemy_imperative_template_str elif template == TemplateEnum.DECLARATIVE: template_str = sqlalchemy_declarative_template_str else: raise Exception(f"Unknown template type: {template}") template_obj = Template(template_str) if model_path is None: model_path = self.schema.name logger.info(f"Package for dataclasses == {model_path}") backrefs = defaultdict(list) for m in tr_result.mappings: backrefs[m.source_class].append(m) # for c in tr_schema.classes.values(): # if len(c.attributes) == 0: # skip[c.name] = True # #raise ValueError(f'Class must have attrs: {c.name}') self.add_safe_aliases(tr_schema) tr_sv = SchemaView(tr_schema) rel_schema_classes_ordered = [tr_sv.get_class(cn, strict=True) for cn in self.order_classes_by_hierarchy(tr_sv)] rel_schema_classes_ordered = [c for c in rel_schema_classes_ordered if not self.skip(c)] for c in rel_schema_classes_ordered: # For SQLA there needs to be a primary key for each class; # autogenerate this as a compound key if none declared has_pk = any(a for a in c.attributes.values() if "primary_key" in a.annotations) if not has_pk: for a in c.attributes.values(): ann = Annotation("primary_key", "true") a.annotations[ann.tag] = ann code = template_obj.render( model_path=model_path, mappings=tr_result.mappings, backrefs=backrefs, classname=camelcase, no_model_import=no_model_import, is_join_table=lambda c: any(tag for tag in c.annotations.keys() if tag == "linkml:derived_from"), classes=rel_schema_classes_ordered, ) logger.debug(f"# Generated code:\n{code}") return code
[docs] def serialize(self, **kwargs) -> str: return self.generate_sqla(**kwargs)
def compile_sqla( self, compile_python_dataclasses=False, pydantic=False, model_path=None, template: TemplateEnum = TemplateEnum.IMPERATIVE, **kwargs, ) -> ModuleType: """ Generates and compiles SQL Alchemy bindings - If template is DECLARATIVE, then a single python module with classes is generated - If template is IMPERATIVE, only mappings are generated - if compile_python_dataclasses is True then a standard datamodel is generated :param compile_python_dataclasses: (default False) :param pydantic: :param model_path: :param template: :param kwargs: :return: """ if model_path is None: model_path = self.schema.name if template == TemplateEnum.DECLARATIVE: sqla_code = self.generate_sqla(model_path=None, no_model_import=True, template=template, **kwargs) return compile_python(sqla_code, package_path=model_path) elif compile_python_dataclasses: # concatenate the python dataclasses with the sqla code if pydantic: # mixin inheritance doesn't get along with SQLAlchemy's imperative (aka classical) mapping pygen = PydanticGenerator(self.original_schema, extra_fields="allow", gen_mixin_inheritance=False) else: pygen = PythonGenerator(self.original_schema) dc_code = pygen.serialize() sqla_code = self.generate_sqla(model_path=None, no_model_import=True, **kwargs) return compile_python(f"{dc_code}\n{sqla_code}", package_path=model_path) else: code = self.generate_sqla(model_path=model_path, **kwargs) return compile_python(code, package_path=model_path) @staticmethod def add_safe_aliases(schema: SchemaDefinition) -> None: for c in schema.classes.values(): # c.alias = underscore(c.name) for a in c.attributes.values(): a.alias = underscore(a.name) @staticmethod def skip(cls: ClassDefinition) -> bool: is_skip = len(cls.attributes) == 0 if is_skip: logger.error(f"SKIPPING: {cls.name}") return is_skip # TODO: move this @staticmethod def order_classes_by_hierarchy(sv: SchemaView) -> List[ClassDefinitionName]: olist = sv.class_roots() unprocessed = [cn for cn in sv.all_classes() if cn not in olist] while len(unprocessed) > 0: ext_list = [cn for cn in unprocessed if not any(p for p in sv.class_parents(cn) if p not in olist)] if len(ext_list) == 0: raise ValueError(f"Cycle in hierarchy, cannot process: {unprocessed}") olist += ext_list unprocessed = [cn for cn in unprocessed if cn not in olist] return olist
@shared_arguments(SQLAlchemyGenerator) @click.option( "--declarative/--no-declarative", default=True, show_default=True, help="Generate SQL Alchemy declarative vs imperative", ) @click.option( "--generate-classes/--no-generate-classes", default=False, show_default=True, help="If True, generate Python datamodel (imperative mode only)", ) @click.option( "--pydantic/--no-pydantic", default=False, show_default=True, help="If True, generate Pydantic classes (imperative mode only)", ) @click.option( "--use-foreign-keys/--no-use-foreign-keys", default=True, show_default=True, help="Emit FK declarations", ) @click.version_option(__version__, "-V", "--version") @click.command(name="sqla") def cli(yamlfile, declarative, generate_classes, pydantic, use_foreign_keys=True, **args): """Generate SQL DDL representation""" if pydantic: pygen = PydanticGenerator(yamlfile) print(pygen.serialize()) gen = SQLAlchemyGenerator(yamlfile, **args) if declarative: t = TemplateEnum.DECLARATIVE else: t = TemplateEnum.IMPERATIVE if use_foreign_keys: foreign_key_policy = None # default else: foreign_key_policy = ForeignKeyPolicy.NO_FOREIGN_KEYS print(gen.generate_sqla(template=t, foreign_key_policy=foreign_key_policy)) if generate_classes: raise NotImplementedError("generate classes not implemented") if __name__ == "__main__": cli()