Source code for linkml.generators.pydanticgen.pydanticgen

import inspect
import logging
import os
import textwrap
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from types import ModuleType
from typing import (
    Dict,
    List,
    Literal,
    Optional,
    Set,
    Type,
    Union,
)

import click
from jinja2 import ChoiceLoader, Environment, FileSystemLoader
from linkml_runtime.linkml_model.meta import (
    Annotation,
    ClassDefinition,
    SchemaDefinition,
    SlotDefinition,
    TypeDefinition,
)
from linkml_runtime.utils.compile_python import compile_python
from linkml_runtime.utils.formatutils import camelcase, underscore
from linkml_runtime.utils.schemaview import SchemaView
from pydantic.version import VERSION as PYDANTIC_VERSION

from linkml._version import __version__
from linkml.generators.common.type_designators import (
    get_accepted_type_designator_values,
    get_type_designator_value,
)
from linkml.generators.oocodegen import OOCodeGenerator
from linkml.generators.pydanticgen.array import ArrayRangeGenerator, ArrayRepresentation
from linkml.generators.pydanticgen.build import SlotResult
from linkml.generators.pydanticgen.template import (
    ConditionalImport,
    Import,
    Imports,
    ObjectImport,
    PydanticAttribute,
    PydanticBaseModel,
    PydanticClass,
    PydanticModule,
    TemplateModel,
)
from linkml.utils import deprecation_warning
from linkml.utils.generator import shared_arguments
from linkml.utils.ifabsent_functions import ifabsent_value_declaration

if int(PYDANTIC_VERSION[0]) == 1:
    deprecation_warning("pydantic-v1")


def _get_pyrange(t: TypeDefinition, sv: SchemaView) -> str:
    pyrange = t.repr if t is not None else None
    if pyrange is None:
        pyrange = t.base
    if t.base == "XSDDateTime":
        pyrange = "datetime "
    if t.base == "XSDDate":
        pyrange = "date"
    if pyrange is None and t.typeof is not None:
        pyrange = _get_pyrange(sv.get_type(t.typeof), sv)
    if pyrange is None:
        raise Exception(f"No python type for range: {t.name} // {t}")
    return pyrange


DEFAULT_IMPORTS = (
    Imports()
    + Import(module="__future__", objects=[ObjectImport(name="annotations")])
    + Import(module="datetime", objects=[ObjectImport(name="datetime"), ObjectImport(name="date")])
    + Import(module="decimal", objects=[ObjectImport(name="Decimal")])
    + Import(module="enum", objects=[ObjectImport(name="Enum")])
    + Import(module="re")
    + Import(module="sys")
    + Import(
        module="typing",
        objects=[
            ObjectImport(name="Any"),
            ObjectImport(name="List"),
            ObjectImport(name="Literal"),
            ObjectImport(name="Dict"),
            ObjectImport(name="Optional"),
            ObjectImport(name="Union"),
        ],
    )
    + Import(module="pydantic.version", objects=[ObjectImport(name="VERSION", alias="PYDANTIC_VERSION")])
    + ConditionalImport(
        condition="int(PYDANTIC_VERSION[0])>=2",
        module="pydantic",
        objects=[
            ObjectImport(name="BaseModel"),
            ObjectImport(name="ConfigDict"),
            ObjectImport(name="Field"),
            ObjectImport(name="field_validator"),
        ],
        alternative=Import(
            module="pydantic",
            objects=[ObjectImport(name="BaseModel"), ObjectImport(name="Field"), ObjectImport(name="validator")],
        ),
    )
)


[docs]@dataclass class PydanticGenerator(OOCodeGenerator): """ Generates Pydantic-compliant classes from a schema This is an alternative to the dataclasses-based Pythongen """ # ClassVar overrides generatorname = os.path.basename(__file__) generatorversion = "0.0.2" valid_formats = ["pydantic"] file_extension = "py" # ObjectVars array_representations: List[ArrayRepresentation] = field(default_factory=lambda: [ArrayRepresentation.LIST]) black: bool = False """ If black is present in the environment, format the serialized code with it """ pydantic_version: int = int(PYDANTIC_VERSION[0]) template_dir: Optional[Union[str, Path]] = None """ Override templates for each TemplateModel. Directory with templates that override the default :attr:`.TemplateModel.template` for each class. If a matching template is not found in the override directory, the default templates will be used. """ extra_fields: Literal["allow", "forbid", "ignore"] = "forbid" gen_mixin_inheritance: bool = True injected_classes: Optional[List[Union[Type, str]]] = None """ A list/tuple of classes to inject into the generated module. Accepts either live classes or strings. Live classes will have their source code extracted with inspect.get - so they need to be standard python classes declared in a source file (ie. the module they are contained in needs a ``__file__`` attr, see: :func:`inspect.getsource` ) """ injected_fields: Optional[List[str]] = None """ A list/tuple of field strings to inject into the base class. Examples: .. code-block:: python injected_fields = ( 'object_id: Optional[str] = Field(None, description="Unique UUID for each object")', ) """ imports: Optional[List[Import]] = None """ Additional imports to inject into generated module. Examples: .. code-block:: python from linkml.generators.pydanticgen.template import ( ConditionalImport, ObjectImport, Import, Imports ) imports = (Imports() + Import(module='sys') + Import(module='numpy', alias='np') + Import(module='pathlib', objects=[ ObjectImport(name="Path"), ObjectImport(name="PurePath", alias="RenamedPurePath") ]) + ConditionalImport( module="typing", objects=[ObjectImport(name="Literal")], condition="sys.version_info >= (3, 8)", alternative=Import( module="typing_extensions", objects=[ObjectImport(name="Literal")] ), ).imports ) becomes: .. code-block:: python import sys import numpy as np from pathlib import ( Path, PurePath as RenamedPurePath ) if sys.version_info >= (3, 8): from typing import Literal else: from typing_extensions import Literal """ # ObjectVars (identical to pythongen) gen_classvars: bool = True gen_slots: bool = True genmeta: bool = False emit_metadata: bool = True def __post_init__(self): super().__post_init__() if int(self.pydantic_version) == 1: deprecation_warning("pydanticgen-v1")
[docs] def compile_module(self, **kwargs) -> ModuleType: """ Compiles generated python code to a module :return: """ pycode = self.serialize(**kwargs) try: return compile_python(pycode) except NameError as e: logging.error(f"Code:\n{pycode}") logging.error(f"Error compiling generated python code: {e}") raise e
[docs] @staticmethod def sort_classes(clist: List[ClassDefinition]) -> List[ClassDefinition]: """ sort classes such that if C is a child of P then C appears after P in the list Overridden method include mixin classes TODO: This should move to SchemaView """ clist = list(clist) slist = [] # sorted while len(clist) > 0: can_add = False for i in range(len(clist)): candidate = clist[i] can_add = False if candidate.is_a: candidates = [candidate.is_a] + candidate.mixins else: candidates = candidate.mixins if not candidates: can_add = True else: if set(candidates) <= set([p.name for p in slist]): can_add = True if can_add: slist = slist + [candidate] del clist[i] break if not can_add: raise ValueError(f"could not find suitable element in {clist} that does not ref {slist}") return slist
[docs] def get_predefined_slot_values(self) -> Dict[str, Dict[str, str]]: """ :return: Dictionary of dictionaries with predefined slot values for each class """ sv = self.schemaview slot_values = defaultdict(dict) for class_def in sv.all_classes().values(): for slot_name in sv.class_slots(class_def.name): slot = sv.induced_slot(slot_name, class_def.name) if slot.designates_type: target_value = get_type_designator_value(sv, slot, class_def) slot_values[camelcase(class_def.name)][slot.name] = f'"{target_value}"' if slot.multivalued: slot_values[camelcase(class_def.name)][slot.name] = ( "[" + slot_values[camelcase(class_def.name)][slot.name] + "]" ) slot_values[camelcase(class_def.name)][slot.name] = slot_values[camelcase(class_def.name)][ slot.name ] elif slot.ifabsent is not None: value = ifabsent_value_declaration(slot.ifabsent, sv, class_def, slot) slot_values[camelcase(class_def.name)][slot.name] = value # Multivalued slots that are either not inlined (just an identifier) or are # inlined as lists should get default_factory list, if they're inlined but # not as a list, that means a dictionary elif "linkml:elements" in slot.implements: slot_values[camelcase(class_def.name)][slot.name] = None elif slot.multivalued: has_identifier_slot = self.range_class_has_identifier_slot(slot) if slot.inlined and not slot.inlined_as_list and has_identifier_slot: slot_values[camelcase(class_def.name)][slot.name] = "default_factory=dict" else: slot_values[camelcase(class_def.name)][slot.name] = "default_factory=list" return slot_values
[docs] def range_class_has_identifier_slot(self, slot): """ Check if the range class of a slot has an identifier slot, via both slot.any_of and slot.range Should return False if the range is not a class, and also if the range is a class but has no identifier slot :param slot: SlotDefinition :return: bool """ sv = self.schemaview has_identifier_slot = False if slot.any_of: for slot_range in slot.any_of: any_of_range = slot_range.range if any_of_range in sv.all_classes() and sv.get_identifier_slot(any_of_range, use_key=True) is not None: has_identifier_slot = True if slot.range in sv.all_classes() and sv.get_identifier_slot(slot.range, use_key=True) is not None: has_identifier_slot = True return has_identifier_slot
[docs] def get_class_isa_plus_mixins(self) -> Dict[str, List[str]]: """ Generate the inheritance list for each class from is_a plus mixins :return: """ sv = self.schemaview parents = {} for class_def in sv.all_classes().values(): class_parents = [] if class_def.is_a: class_parents.append(camelcase(class_def.is_a)) if self.gen_mixin_inheritance and class_def.mixins: class_parents.extend([camelcase(mixin) for mixin in class_def.mixins]) if len(class_parents) > 0: # Use the sorted list of classes to order the parent classes, but reversed to match MRO needs class_parents.sort(key=lambda x: self.sorted_class_names.index(x)) class_parents.reverse() parents[camelcase(class_def.name)] = class_parents return parents
def get_mixin_identifier_range(self, mixin) -> str: sv = self.schemaview id_ranges = list( { _get_pyrange(sv.get_type(sv.get_identifier_slot(c).range), sv) for c in sv.class_descendants(mixin.name, mixins=True) if sv.get_identifier_slot(c) is not None } ) if len(id_ranges) == 0: return None elif len(id_ranges) == 1: return id_ranges[0] else: return f"Union[{'.'.join(id_ranges)}]" def get_class_slot_range(self, slot_range: str, inlined: bool, inlined_as_list: bool) -> str: sv = self.schemaview range_cls = sv.get_class(slot_range) # Hardcoded handling for Any if range_cls.class_uri == "linkml:Any": return "Any" # Inline the class itself only if the class is defined as inline, or if the class has no # identifier slot and also isn't a mixin. if ( inlined or inlined_as_list or (sv.get_identifier_slot(range_cls.name, use_key=True) is None and not sv.is_mixin(range_cls.name)) ): if ( len([x for x in sv.class_induced_slots(slot_range) if x.designates_type]) > 0 and len(sv.class_descendants(slot_range)) > 1 ): return "Union[" + ",".join([camelcase(c) for c in sv.class_descendants(slot_range)]) + "]" else: return f"{camelcase(slot_range)}" # For the more difficult cases, set string as the default and attempt to improve it range_cls_identifier_slot_range = "str" # For mixins, try to use the identifier slot of descendant classes if self.gen_mixin_inheritance and sv.is_mixin(range_cls.name) and sv.get_identifier_slot(range_cls.name): range_cls_identifier_slot_range = self.get_mixin_identifier_range(range_cls) # If the class itself has an identifier slot, it can be allowed to overwrite a value from mixin above if ( sv.get_identifier_slot(range_cls.name) is not None and sv.get_identifier_slot(range_cls.name).range is not None ): range_cls_identifier_slot_range = _get_pyrange( sv.get_type(sv.get_identifier_slot(range_cls.name).range), sv ) return range_cls_identifier_slot_range
[docs] def generate_python_range(self, slot_range, slot_def: SlotDefinition, class_def: ClassDefinition) -> str: """ Generate the python range for a slot range value """ sv = self.schemaview if slot_def.designates_type: pyrange = ( "Literal[" + ",".join(['"' + x + '"' for x in get_accepted_type_designator_values(sv, slot_def, class_def)]) + "]" ) elif slot_range in sv.all_classes(): pyrange = self.get_class_slot_range( slot_range, inlined=slot_def.inlined, inlined_as_list=slot_def.inlined_as_list, ) elif slot_range in sv.all_enums(): pyrange = f"{camelcase(slot_range)}" elif slot_range in sv.all_types(): t = sv.get_type(slot_range) pyrange = _get_pyrange(t, sv) elif slot_range is None: pyrange = "str" else: # TODO: default ranges in schemagen # pyrange = 'str' # logging.error(f'range: {s.range} is unknown') raise Exception(f"range: {slot_range}") return pyrange
[docs] def generate_collection_key( self, slot_ranges: List[str], slot_def: SlotDefinition, class_def: ClassDefinition, ) -> Optional[str]: """ Find the python range value (str, int, etc) for the identifier slot of a class used as a slot range. If a pyrange value matches a class name, the range of the identifier slot will be returned. If more than one match is found and they don't match, an exception will be raised. :param slot_ranges: list of python range values """ collection_keys: Set[str] = set() if slot_ranges is None: return None for slot_range in slot_ranges: if slot_range is None or slot_range not in self.schemaview.all_classes(): continue # ignore non-class ranges identifier_slot = self.schemaview.get_identifier_slot(slot_range, use_key=True) if identifier_slot is not None: collection_keys.add(self.generate_python_range(identifier_slot.range, slot_def, class_def)) if len(collection_keys) > 1: raise Exception(f"Slot with any_of range has multiple identifier slot range types: {collection_keys}") if len(collection_keys) == 1: return list(collection_keys)[0] return None
@staticmethod def _inline_as_simple_dict_with_value(slot_def: SlotDefinition, sv: SchemaView) -> Optional[str]: """ Determine if a slot should be inlined as a simple dict with a value. For example, if we have a class such as Prefix, with two slots prefix and expansion, then an inlined list of prefixes can be serialized as: .. code-block:: yaml prefixes: prefix1: expansion1 prefix2: expansion2 ... Provided that the prefix slot is the identifier slot for the Prefix class. TODO: move this to SchemaView :param slot_def: SlotDefinition :param sv: SchemaView :return: str """ if slot_def.inlined and not slot_def.inlined_as_list: if slot_def.range in sv.all_classes(): id_slot = sv.get_identifier_slot(slot_def.range, use_key=True) if id_slot is not None: range_cls_slots = sv.class_induced_slots(slot_def.range) if len(range_cls_slots) == 2: non_id_slots = [slot for slot in range_cls_slots if slot.name != id_slot.name] if len(non_id_slots) == 1: value_slot = non_id_slots[0] value_slot_range_type = sv.get_type(value_slot.range) if value_slot_range_type is not None: return _get_pyrange(value_slot_range_type, sv) return None def _template_environment(self) -> Environment: env = TemplateModel.environment() if self.template_dir is not None: loader = ChoiceLoader([FileSystemLoader(self.template_dir), env.loader]) env.loader = loader return env
[docs] def get_array_representations_range(self, slot: SlotDefinition, range: str) -> List[SlotResult]: """ Generate the python range for array representations """ array_reps = [] for repr in self.array_representations: generator = ArrayRangeGenerator.get_generator(repr) result = generator(slot.array, range, self.pydantic_version).make() array_reps.append(result) if len(array_reps) == 0: raise ValueError("No array representation generated, but one was requested!") return array_reps
def render(self) -> PydanticModule: sv: SchemaView sv = self.schemaview schema = sv.schema pyschema = SchemaDefinition( id=schema.id, name=schema.name, description=schema.description.replace('"', '\\"') if schema.description else None, ) enums = self.generate_enums(sv.all_enums()) injected_classes = [] if self.injected_classes is not None: injected_classes += self.injected_classes imports = DEFAULT_IMPORTS if self.imports is not None: for i in self.imports: imports += i sorted_classes = self.sort_classes(list(sv.all_classes().values())) self.sorted_class_names = [camelcase(c.name) for c in sorted_classes] # Don't want to generate classes when class_uri is linkml:Any, will # just swap in typing.Any instead down below sorted_classes = [c for c in sorted_classes if c.class_uri != "linkml:Any"] for class_original in sorted_classes: class_def: ClassDefinition class_def = deepcopy(class_original) class_name = class_original.name class_def.name = camelcase(class_original.name) if class_def.is_a: class_def.is_a = camelcase(class_def.is_a) class_def.mixins = [camelcase(p) for p in class_def.mixins] if class_def.description: class_def.description = class_def.description.replace('"', '\\"') pyschema.classes[class_def.name] = class_def for attribute in list(class_def.attributes.keys()): del class_def.attributes[attribute] for sn in sv.class_slots(class_name): # TODO: fix runtime, copy should not be necessary s = deepcopy(sv.induced_slot(sn, class_name)) # logging.error(f'Induced slot {class_name}.{sn} == {s.name} {s.range}') s.name = underscore(s.name) if s.description: s.description = s.description.replace('"', '\\"') class_def.attributes[s.name] = s slot_ranges: List[str] = [] # Confirm that the original slot range (ignoring the default that comes in from # induced_slot) isn't in addition to setting any_of any_of_ranges = [a.range if a.range else s.range for a in s.any_of] if any_of_ranges: # list comprehension here is pulling ranges from within AnonymousSlotExpression slot_ranges.extend(any_of_ranges) else: slot_ranges.append(s.range) pyranges = [self.generate_python_range(slot_range, s, class_def) for slot_range in slot_ranges] pyranges = list(set(pyranges)) # remove duplicates pyranges.sort() if len(pyranges) == 1: pyrange = pyranges[0] elif len(pyranges) > 1: pyrange = f"Union[{', '.join(pyranges)}]" else: raise Exception(f"Could not generate python range for {class_name}.{s.name}") if s.array is not None: # TODO add support for xarray results = self.get_array_representations_range(s, pyrange) # TODO: Move results unpacking to own function that is used after each slot build stage :) for res in results: if res.injected_classes: injected_classes += res.injected_classes if res.imports: imports += res.imports if len(results) == 1: pyrange = results[0].annotation else: pyrange = f"Union[{', '.join([res.annotation for res in results])}]" if "linkml:ColumnOrderedArray" in class_def.implements: raise NotImplementedError("Cannot generate Pydantic code for ColumnOrderedArrays.") elif s.multivalued: if s.inlined or s.inlined_as_list: collection_key = self.generate_collection_key(slot_ranges, s, class_def) else: collection_key = None if s.inlined is False or collection_key is None or s.inlined_as_list is True: pyrange = f"List[{pyrange}]" else: simple_dict_value = None if len(slot_ranges) == 1: simple_dict_value = self._inline_as_simple_dict_with_value(s, sv) if simple_dict_value: # inlining as simple dict pyrange = f"Dict[str, {simple_dict_value}]" else: pyrange = f"Dict[{collection_key}, {pyrange}]" if not (s.required or s.identifier or s.key) and not s.designates_type: pyrange = f"Optional[{pyrange}]" ann = Annotation("python_range", pyrange) s.annotations[ann.tag] = ann # TODO: Make cleaning injected classes its own method injected_classes = list( dict.fromkeys([c if isinstance(c, str) else inspect.getsource(c) for c in injected_classes]) ) injected_classes = [textwrap.dedent(c) for c in injected_classes] base_model = PydanticBaseModel( pydantic_ver=self.pydantic_version, extra_fields=self.extra_fields, fields=self.injected_fields ) classes = {} predefined = self.get_predefined_slot_values() bases = self.get_class_isa_plus_mixins() for k, c in pyschema.classes.items(): attrs = {} for attr_name, src_attr in c.attributes.items(): src_attr = src_attr._as_dict new_fields = { k: src_attr.get(k, None) for k in PydanticAttribute.model_fields.keys() if src_attr.get(k, None) is not None } predef_slot = predefined.get(k, {}).get(attr_name, None) if predef_slot is not None: predef_slot = str(predef_slot) new_fields["predefined"] = predef_slot new_fields["name"] = attr_name attrs[attr_name] = PydanticAttribute(**new_fields, pydantic_ver=self.pydantic_version) new_class = PydanticClass( name=k, attributes=attrs, description=c.description, pydantic_ver=self.pydantic_version ) if k in bases: new_class.bases = bases[k] classes[k] = new_class module = PydanticModule( pydantic_ver=self.pydantic_version, metamodel_version=self.schema.metamodel_version, version=self.schema.version, imports=imports.imports, base_model=base_model, injected_classes=injected_classes, enums=enums, classes=classes, ) return module
[docs] def serialize(self) -> str: module = self.render() return module.render(self._template_environment(), self.black)
def default_value_for_type(self, typ: str) -> str: return "None"
def _subclasses(cls: Type): return set(cls.__subclasses__()).union([s for c in cls.__subclasses__() for s in _subclasses(c)]) _TEMPLATE_NAMES = sorted(list(set([c.template for c in _subclasses(TemplateModel)]))) @shared_arguments(PydanticGenerator) @click.option("--template-file", hidden=True) @click.option( "--template-dir", type=click.Path(), help=""" Optional jinja2 template directory to use for class generation. Pass a directory containing templates with the same name as any of the default :class:`.TemplateModel` templates to override them. The given directory will be searched for matching templates, and use the default templates as a fallback if an override is not found Available templates to override: \b """ + "\n".join(["- " + name for name in _TEMPLATE_NAMES]), ) @click.option( "--pydantic-version", type=click.IntRange(1, 2), default=1, help="Pydantic version to use (1 or 2)", ) @click.option( "--array-representations", type=click.Choice([k.value for k in ArrayRepresentation]), multiple=True, default=["list"], help="List of array representations to accept for array slots. Default is list of lists.", ) @click.option( "--extra-fields", type=click.Choice(["allow", "ignore", "forbid"], case_sensitive=False), default="forbid", help="How to handle extra fields in BaseModel.", ) @click.version_option(__version__, "-V", "--version") @click.command() def cli( yamlfile, template_file=None, template_dir: Optional[str] = None, head=True, genmeta=False, classvars=True, slots=True, array_representations=list("list"), pydantic_version=1, extra_fields: Literal["allow", "forbid", "ignore"] = "forbid", **args, ): """Generate pydantic classes to represent a LinkML model""" if template_file is not None: raise DeprecationWarning( ( "Passing a single template_file is deprecated. Pass a directory of template files instead. " "See help string for --template-dir" ) ) if template_dir is not None: if not Path(template_dir).exists(): raise FileNotFoundError(f"The template directory {template_dir} does not exist!") gen = PydanticGenerator( yamlfile, pydantic_version=pydantic_version, array_representations=[ArrayRepresentation(x) for x in array_representations], extra_fields=extra_fields, emit_metadata=head, genmeta=genmeta, gen_classvars=classvars, gen_slots=slots, template_dir=template_dir, **args, ) print(gen.serialize()) if __name__ == "__main__": cli()