import inspect
import logging
import os
import re
import textwrap
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from types import ModuleType
from typing import ClassVar, Dict, List, Literal, Optional, Set, Tuple, Type, TypeVar, Union, overload
import click
from jinja2 import ChoiceLoader, Environment, FileSystemLoader, Template
from linkml_runtime.linkml_model.meta import (
ClassDefinition,
ElementName,
SchemaDefinition,
SlotDefinition,
TypeDefinition,
)
from linkml_runtime.utils.compile_python import compile_python
from linkml_runtime.utils.formatutils import camelcase, remove_empty_items, underscore
from linkml_runtime.utils.schemaview import SchemaView
from pydantic.version import VERSION as PYDANTIC_VERSION
from linkml._version import __version__
from linkml.generators.common.lifecycle import LifecycleMixin
from linkml.generators.common.type_designators import get_accepted_type_designator_values, get_type_designator_value
from linkml.generators.oocodegen import OOCodeGenerator
from linkml.generators.pydanticgen import includes
from linkml.generators.pydanticgen.array import ArrayRangeGenerator, ArrayRepresentation
from linkml.generators.pydanticgen.build import ClassResult, SlotResult, SplitResult
from linkml.generators.pydanticgen.template import (
Import,
Imports,
ObjectImport,
PydanticAttribute,
PydanticBaseModel,
PydanticClass,
PydanticModule,
PydanticTemplateModel,
)
from linkml.generators.python.python_ifabsent_processor import PythonIfAbsentProcessor
from linkml.utils import deprecation_warning
from linkml.utils.generator import shared_arguments
# Emit the "pydantic-v1" deprecation warning when running under pydantic major version 1.
# Parse the whole major component: indexing the first character would treat e.g. "10.0" as v1.
if int(PYDANTIC_VERSION.split(".", 1)[0]) == 1:
    deprecation_warning("pydantic-v1")
def _get_pyrange(t: TypeDefinition, sv: SchemaView) -> str:
pyrange = t.repr if t is not None else None
if pyrange is None:
pyrange = t.base
if t.base == "XSDDateTime":
pyrange = "datetime "
if t.base == "XSDDate":
pyrange = "date"
if pyrange is None and t.typeof is not None:
pyrange = _get_pyrange(sv.get_type(t.typeof), sv)
if pyrange is None:
raise Exception(f"No python type for range: {t.name} // {t}")
return pyrange
# Imports every generated module starts from; ``PydanticGenerator.render`` adds any user-supplied
# and split-mode imports on top of these. Order matters when import sorting is disabled.
DEFAULT_IMPORTS = (
    Imports()
    # presumably so generated annotations can forward-reference classes defined later in the module
    + Import(module="__future__", objects=[ObjectImport(name="annotations")])
    + Import(
        module="datetime", objects=[ObjectImport(name="datetime"), ObjectImport(name="date"), ObjectImport(name="time")]
    )
    + Import(module="decimal", objects=[ObjectImport(name="Decimal")])
    + Import(module="enum", objects=[ObjectImport(name="Enum")])
    + Import(module="re")
    + Import(module="sys")
    + Import(
        module="typing",
        objects=[
            ObjectImport(name="Any"),
            ObjectImport(name="ClassVar"),
            ObjectImport(name="List"),
            ObjectImport(name="Literal"),
            ObjectImport(name="Dict"),
            ObjectImport(name="Optional"),
            ObjectImport(name="Union"),
        ],
    )
    + Import(
        module="pydantic",
        objects=[
            ObjectImport(name="BaseModel"),
            ObjectImport(name="ConfigDict"),
            ObjectImport(name="Field"),
            ObjectImport(name="RootModel"),
            ObjectImport(name="field_validator"),
        ],
    )
)
# Classes whose source is injected into every generated module; ``render`` copies this list
# before appending generator- and class-specific injections.
DEFAULT_INJECTS = [includes.LinkMLMeta]
class MetadataMode(str, Enum):
    """
    How much of the source schema's metadata is included in generated pydantic models
    (selected with ``PydanticGenerator.metadata_mode``).
    """

    FULL = "full"
    """
    all metadata from the source schema will be included, even if it is represented by the template classes,
    and even if it is represented by some child class (eg. "classes" will be included with schema metadata)
    """
    EXCEPT_CHILDREN = "except_children"
    """
    all metadata from the source schema will be included, even if it is represented by the template classes,
    except if it is represented by some child template class (eg. "classes" will be excluded from schema metadata)
    """
    AUTO = "auto"
    """
    Only the metadata that isn't represented by the template classes or excluded with ``meta_exclude`` will be included
    """
    # NOTE(review): ``None`` as the value of a ``str``-mixin enum member; depending on the Python
    # version ``.value`` may be ``None`` or the string ``"None"`` - confirm how callers compare it.
    NONE = None
    """
    No metadata will be included.
    """
class SplitMode(str, Enum):
    """
    Which classes to import from imported schemas when generating in split mode
    (selected with ``PydanticGenerator.split_mode``).
    """

    FULL = "full"
    """
    Import all classes defined in imported schemas
    """
    AUTO = "auto"
    """
    Only import those classes that are actually used in the generated schema as

    * parents (``is_a``)
    * mixins
    * slot ranges
    """
# Type variables bounded by the LinkML definition types and by the pydantic template models built from them.
DefinitionType = TypeVar("DefinitionType", bound=Union[SchemaDefinition, ClassDefinition, SlotDefinition])
TemplateType = TypeVar("TemplateType", bound=Union[PydanticModule, PydanticClass, PydanticAttribute])
@dataclass
class PydanticGenerator(OOCodeGenerator, LifecycleMixin):
    """
    Generates Pydantic-compliant classes from a schema

    This is an alternative to the dataclasses-based Pythongen

    Lifecycle methods (see :class:`.LifecycleMixin` ) supported:

    * :meth:`~.LifecycleMixin.before_generate_enums`

    Slot generation is nested within class generation, since the pydantic generator currently doesn't
    create an independent representation of slots aside from their materialization as class fields.
    Accordingly, the ``before_`` and ``after_generate_slots`` are called before and after each class's
    slot generation, rather than all slot generation.

    * :meth:`~.LifecycleMixin.before_generate_classes`
    * :meth:`~.LifecycleMixin.before_generate_class`
    * :meth:`~.LifecycleMixin.after_generate_class`
    * :meth:`~.LifecycleMixin.after_generate_classes`
    * :meth:`~.LifecycleMixin.before_generate_slots`
    * :meth:`~.LifecycleMixin.before_generate_slot`
    * :meth:`~.LifecycleMixin.after_generate_slot`
    * :meth:`~.LifecycleMixin.after_generate_slots`
    * :meth:`~.LifecycleMixin.before_render_template`
    * :meth:`~.LifecycleMixin.after_render_template`
    """

    # ClassVar overrides
    generatorname = os.path.basename(__file__)
    generatorversion = "0.0.2"
    valid_formats = ["pydantic"]
    file_extension = "py"

    # ObjectVars
    # Array representations emitted for slots with an ``array`` specification; several become a ``Union``
    array_representations: List[ArrayRepresentation] = field(default_factory=lambda: [ArrayRepresentation.LIST])
    black: bool = False
    """
    If black is present in the environment, format the serialized code with it
    """

    template_dir: Optional[Union[str, Path]] = None
    """
    Override templates for each PydanticTemplateModel.

    Directory with templates that override the default :attr:`.PydanticTemplateModel.template`
    for each class. If a matching template is not found in the override directory,
    the default templates will be used.
    """

    # Passed to the generated base model as ``extra_fields``; presumably mapped to pydantic's ``extra`` config
    extra_fields: Literal["allow", "forbid", "ignore"] = "forbid"
    # When true, mixins become python base classes (see ``class_bases``) and references to mixin
    # classes are typed with their descendants' identifier ranges
    gen_mixin_inheritance: bool = True
    injected_classes: Optional[List[Union[Type, str]]] = None
    """
    A list/tuple of classes to inject into the generated module.

    Accepts either live classes or strings. Live classes will have their source code
    extracted with :func:`inspect.getsource` - so they need to be standard python classes declared in a
    source file (ie. the module they are contained in needs a ``__file__`` attr,
    see: :func:`inspect.getsource` )
    """

    injected_fields: Optional[List[str]] = None
    """
    A list/tuple of field strings to inject into the base class.

    Examples:

    .. code-block:: python

        injected_fields = (
            'object_id: Optional[str] = Field(None, description="Unique UUID for each object")',
        )

    """

    imports: Optional[List[Import]] = None
    """
    Additional imports to inject into generated module.

    Examples:

    .. code-block:: python

        from linkml.generators.pydanticgen.template import (
            ConditionalImport,
            ObjectImport,
            Import,
            Imports
        )

        imports = (Imports() +
            Import(module='sys') +
            Import(module='numpy', alias='np') +
            Import(module='pathlib', objects=[
                ObjectImport(name="Path"),
                ObjectImport(name="PurePath", alias="RenamedPurePath")
            ]) +
            ConditionalImport(
                module="typing",
                objects=[ObjectImport(name="Literal")],
                condition="sys.version_info >= (3, 8)",
                alternative=Import(
                    module="typing_extensions",
                    objects=[ObjectImport(name="Literal")]
                ),
            ).imports
        )

    becomes:

    .. code-block:: python

        import sys
        import numpy as np
        from pathlib import (
            Path,
            PurePath as RenamedPurePath
        )
        if sys.version_info >= (3, 8):
            from typing import Literal
        else:
            from typing_extensions import Literal

    """

    sort_imports: bool = True
    """
    Before returning from :meth:`.PydanticGenerator.render`, sort imports with :meth:`.Imports.sort`

    Default ``True``, but optional in case import order must be explicitly given,
    eg. to avoid circular import errors in complex generator subclasses.
    """

    metadata_mode: Union[MetadataMode, str, None] = MetadataMode.AUTO
    """
    How to include schema metadata in generated pydantic models.

    See :class:`.MetadataMode` for mode documentation
    """

    split: bool = False
    """
    Generate schemas that import other schemas as separate python modules
    that import from one another, rather than rolling everything into a single
    module (the default, ``False``).
    """

    split_pattern: str = ".{{ schema.name }}"
    """
    When splitting generation, imported modules need to be generated separately
    and placed in a python package and import from each other. Since the
    location of those imported modules is variable -- e.g. one might want to
    generate schema in multiple packages depending on their version -- this
    pattern is used to generate the module portion of the import statement.

    These patterns should generally yield a relative module import,
    since functions like :func:`.generate_split` will generate and write files
    relative to some base file, though this is not a requirement since custom
    split generation logic is also allowed.

    The pattern is a jinja template string that is given the ``SchemaDefinition``
    of the imported schema in the environment. Additional variables can be passed
    into the jinja environment with the :attr:`.split_context` argument.

    Further modification is possible by using jinja filters.

    After templating, the string is passed through a :attr:`SNAKE_CASE` pattern
    to replace whitespace and other characters that can't be used in module names.

    See also :meth:`.generate_module_import`, which is used to generate the
    module portion of the import statement (and can be overridden in subclasses).

    Examples:

        for a schema named ``ExampleSchema`` and version ``1.2.3`` ...

        ``".{{ schema.name }}"`` (the default) becomes

        ``from .example_schema import ClassA, ...``

        ``"...{{ schema.name }}.v{{ schema.version | replace('.', '_') }}"`` becomes

        ``from ...example_schema.v1_2_3 import ClassA, ...``

    """

    split_context: Optional[dict] = None
    """
    Additional variables to pass into ``split_pattern`` when
    generating imported module names.

    Passed in as ``**kwargs`` , so e.g. if ``split_context = {'myval': 1}``
    then one would use it in a template string like ``{{ myval }}``
    """

    split_mode: SplitMode = SplitMode.AUTO
    """
    How to filter imports from imported schema.

    See :class:`.SplitMode` for description of options
    """

    # ObjectVars (identical to pythongen)
    gen_classvars: bool = True
    gen_slots: bool = True
    genmeta: bool = False
    emit_metadata: bool = True

    # ClassVars
    SNAKE_CASE: ClassVar[str] = r"(((?<!^)(?<!\.))(?=[A-Z][a-z]))|([^\w\.]+)"
    """Substitute CamelCase and non-word characters with _"""

    # Private attributes: caches for the ``predefined_slot_values`` and ``class_bases`` properties
    _predefined_slot_values: Optional[Dict[str, Dict[str, str]]] = None
    _class_bases: Optional[Dict[str, List[str]]] = None
def __post_init__(self):
    """Delegate to the next ``__post_init__`` in the MRO; this class adds no setup of its own here."""
    super().__post_init__()
def compile_module(self, **kwargs) -> ModuleType:
    """
    Serialize the schema and compile the resulting source into a python module.

    :param kwargs: forwarded unchanged to :meth:`.serialize`
    :return: the module produced by ``compile_python`` from the generated source
    :raises NameError: re-raised (after logging the generated source) when compiling fails with it
    """
    source = self.serialize(**kwargs)
    try:
        compiled = compile_python(source)
    except NameError as err:
        logging.error(f"Code:\n{source}")
        logging.error(f"Error compiling generated python code: {err}")
        raise
    return compiled
def _get_classes(self, sv: SchemaView) -> Tuple[List[ClassDefinition], Optional[List[ClassDefinition]]]:
    """
    Split the schema's classes into those to generate here and those imported from other schemas.

    :param sv: schema view to read classes from
    :return: ``(classes, imported)``. Without :attr:`.split`, ``classes`` holds every class (imported
        ones included) and ``imported`` is ``None``. With :attr:`.split`, ``classes`` holds only the
        classes local to this schema and ``imported`` holds the remaining ones.
    """
    everything = sv.all_classes(imports=True).values()
    if not self.split:
        return list(everything), None
    local = sv.all_classes(imports=False).values()
    foreign = [class_def for class_def in everything if class_def not in local]
    return list(local), foreign
@staticmethod
def sort_classes(
    clist: List[ClassDefinition], imported: Optional[List[ClassDefinition]] = None
) -> List[ClassDefinition]:
    """
    Order classes so that each one comes after its ``is_a`` parent and all of its mixins.

    Repeatedly scans the remaining classes and moves the first one whose parents have all been
    placed into the output. Parents named in ``imported`` (split mode) are ignored, since those
    classes come from other modules rather than being generated here.

    TODO: This should move to SchemaView

    :param clist: classes to order
    :param imported: classes imported from other schemas (only their names are used)
    :return: the classes in dependency order
    :raises ValueError: if no remaining class can be placed (its parents are never placed)
    """
    imported_names = None if imported is None else [imp.name for imp in imported]
    remaining = list(clist)
    ordered = []
    placed_names = set()
    while remaining:
        for position, candidate in enumerate(remaining):
            parents = ([candidate.is_a] if candidate.is_a else []) + candidate.mixins
            # remove blocking classes imported from other schemas if in split mode
            if imported_names:
                parents = [parent for parent in parents if parent not in imported_names]
            if set(parents) <= placed_names:
                ordered.append(candidate)
                placed_names.add(candidate.name)
                del remaining[position]
                break
        else:
            raise ValueError(f"could not find suitable element in {remaining} that does not ref {ordered}")
    return ordered
def generate_class(self, cls: ClassDefinition) -> ClassResult:
    """
    Build the :class:`.PydanticClass` for one schema class, fields included.

    Every induced slot goes through ``before_generate_slot`` -> :meth:`.generate_slot` ->
    ``after_generate_slot``, and the whole list is wrapped by ``before_generate_slots`` /
    ``after_generate_slots``. Each slot result (with its imports) is merged into the class result.

    :param cls: the schema class to render
    :return: a :class:`.ClassResult` with the pydantic class model, its source class and imports
    """
    sv = self.schemaview
    class_name = camelcase(cls.name)
    description = None if cls.description is None else cls.description.replace('"', '\\"')
    pyclass = PydanticClass(
        name=class_name,
        bases=self.class_bases.get(class_name, PydanticBaseModel.default_name),
        description=description,
    )
    result = ClassResult(cls=pyclass, source=cls, imports=self._get_imports(cls) if self.split else None)

    induced = [sv.induced_slot(slot_name, cls.name) for slot_name in sv.class_slots(cls.name)]
    induced = self.before_generate_slots(induced, sv)
    generated = []
    for slot_def in induced:
        slot_def = self.before_generate_slot(slot_def, sv)
        slot_result = self.generate_slot(slot_def, cls)
        slot_result = self.after_generate_slot(slot_result, sv)
        generated.append(slot_result)
        result = result.merge(slot_result)
    generated = self.after_generate_slots(generated, sv)

    result.cls.attributes = {item.attribute.name: item.attribute for item in generated}
    result.cls = self.include_metadata(result.cls, cls)
    return result
def generate_slot(self, slot: SlotDefinition, cls: ClassDefinition) -> SlotResult:
    """
    Build the :class:`.PydanticAttribute` (class field) for one induced slot of ``cls``.

    1. Copy every non-``None`` slot property whose name is also a :class:`.PydanticAttribute` field,
       then set ``name`` (from the alias when present), an escaped ``description`` and any
       ``predefined`` value from :attr:`.predefined_slot_values`.
    2. Compute the python range from the ``any_of`` ranges (or the slot's own range); several
       distinct ranges become a sorted ``Union[...]``.
    3. Wrap the range for arrays (one entry per configured array representation) or, for
       multivalued slots, as ``List[...]`` or ``Dict[key, ...]`` depending on inlining.
    4. Wrap in ``Optional[...]`` unless the slot is required, an identifier, a key or a type designator.

    :param slot: the induced slot definition
    :param cls: the class the slot is generated for
    :return: a :class:`.SlotResult` with the attribute, the source slot and any needed imports
    :raises Exception: if no python range could be generated
    """
    # Start from the slot properties that map one-to-one onto PydanticAttribute fields
    slot_args = {
        k: slot._as_dict.get(k, None)
        for k in PydanticAttribute.model_fields.keys()
        if slot._as_dict.get(k, None) is not None
    }
    slot_alias = slot.alias if slot.alias else slot.name
    slot_args["name"] = underscore(slot_alias)
    # Escape double quotes, presumably because the template renders the description inside a quoted string
    slot_args["description"] = slot.description.replace('"', '\\"') if slot.description is not None else None
    predef = self.predefined_slot_values.get(camelcase(cls.name), {}).get(slot.name, None)
    if predef is not None:
        slot_args["predefined"] = str(predef)

    pyslot = PydanticAttribute(**slot_args)
    pyslot = self.include_metadata(pyslot, slot)

    slot_ranges = []
    # Confirm that the original slot range (ignoring the default that comes in from
    # induced_slot) isn't in addition to setting any_of
    any_of_ranges = [a.range if a.range else slot.range for a in slot.any_of]
    if any_of_ranges:
        # list comprehension here is pulling ranges from within AnonymousSlotExpression
        slot_ranges.extend(any_of_ranges)
    else:
        slot_ranges.append(slot.range)

    pyranges = [self.generate_python_range(slot_range, slot, cls) for slot_range in slot_ranges]

    pyranges = list(set(pyranges))  # remove duplicates
    pyranges.sort()  # sort so the order inside a generated Union is deterministic

    if len(pyranges) == 1:
        pyrange = pyranges[0]
    elif len(pyranges) > 1:
        pyrange = f"Union[{', '.join(pyranges)}]"
    else:
        raise Exception(f"Could not generate python range for {cls.name}.{slot.name}")

    pyslot.range = pyrange

    imports = self._get_imports(slot) if self.split else None

    result = SlotResult(attribute=pyslot, source=slot, imports=imports)

    if slot.array is not None:
        # One range per configured array representation; several become a Union
        results = self.get_array_representations_range(slot, result.attribute.range)
        if len(results) == 1:
            result.attribute.range = results[0].range
        else:
            result.attribute.range = f"Union[{', '.join([res.range for res in results])}]"
        for res in results:
            result = result.merge(res)

    elif slot.multivalued:
        # A collection key is only looked up for inlined slots
        if slot.inlined or slot.inlined_as_list:
            collection_key = self.generate_collection_key(slot_ranges, slot, cls)
        else:
            collection_key = None
        if slot.inlined is False or collection_key is None or slot.inlined_as_list is True:
            result.attribute.range = f"List[{result.attribute.range}]"
        else:
            # Inlined as a dict keyed by the range class's identifier
            simple_dict_value = None
            if len(slot_ranges) == 1:
                simple_dict_value = self._inline_as_simple_dict_with_value(slot)
            if simple_dict_value:
                # simple_dict_value might be the range of the identifier of a class when range is a class,
                # so we specify either that identifier or the range itself
                if simple_dict_value != result.attribute.range:
                    simple_dict_value = f"Union[{simple_dict_value}, {result.attribute.range}]"
                result.attribute.range = f"Dict[str, {simple_dict_value}]"
            else:
                result.attribute.range = f"Dict[{collection_key}, {result.attribute.range}]"
    if not (slot.required or slot.identifier or slot.key) and not slot.designates_type:
        result.attribute.range = f"Optional[{result.attribute.range}]"
    return result
@property
def predefined_slot_values(self) -> Dict[str, Dict[str, str]]:
    """
    Python-source default values for slots, keyed by camelcased class name, then slot name.

    Computed once and cached in ``_predefined_slot_values``. Only two kinds of slot get an entry,
    so classes without either kind have no key at all:

    * type designator slots: a double-quoted string literal of the class's designator value,
      wrapped in ``[...]`` when the slot is multivalued
    * slots with ``ifabsent``: the expression produced by :class:`.PythonIfAbsentProcessor`

    :return: Dictionary of dictionaries with predefined slot values for each class
    """
    if self._predefined_slot_values is None:
        sv = self.schemaview
        ifabsent_processor = PythonIfAbsentProcessor(sv)
        slot_values = defaultdict(dict)
        for class_def in sv.all_classes().values():
            class_key = camelcase(class_def.name)
            for slot_name in sv.class_slots(class_def.name):
                slot = sv.induced_slot(slot_name, class_def.name)
                if slot.designates_type:
                    target_value = get_type_designator_value(sv, slot, class_def)
                    value = f'"{target_value}"'
                    if slot.multivalued:
                        value = "[" + value + "]"
                    # Index lazily so the defaultdict only gains keys for classes with values
                    slot_values[class_key][slot.name] = value
                elif slot.ifabsent is not None:
                    slot_values[class_key][slot.name] = ifabsent_processor.process_slot(slot, class_def)

        self._predefined_slot_values = slot_values

    return self._predefined_slot_values
@property
def class_bases(self) -> Dict[str, List[str]]:
    """
    Inheritance list (``is_a`` plus mixins) for each class, keyed by camelcased class name.

    Mixins are included only when :attr:`.gen_mixin_inheritance` is true. Parents are ordered by
    their first index in ``sorted_class_names`` (unknown names count as ``-1``) and the result is
    then reversed to match MRO needs. Classes without parents get no entry. The mapping is computed
    once and cached in ``_class_bases``. It reads ``sorted_class_names`` (assigned by
    :meth:`.render`) as soon as any class has a parent.

    :return: mapping of class name to its ordered list of base class names
    """
    if self._class_bases is None:
        bases = {}
        rank = None  # first index of each name in sorted_class_names; built on first use
        for class_def in self.schemaview.all_classes().values():
            parents = []
            if class_def.is_a:
                parents.append(camelcase(class_def.is_a))
            if self.gen_mixin_inheritance and class_def.mixins:
                parents.extend(camelcase(mixin) for mixin in class_def.mixins)
            if not parents:
                continue
            if rank is None:
                rank = {}
                for index, name in enumerate(self.sorted_class_names):
                    rank.setdefault(name, index)
            # Stable sort then reverse() (not reverse=True) so tied names are reversed as well
            parents.sort(key=lambda name: rank.get(name, -1))
            parents.reverse()
            bases[camelcase(class_def.name)] = parents
        self._class_bases = bases
    return self._class_bases
def get_mixin_identifier_range(self, mixin) -> Optional[str]:
    """
    Python type of the identifiers of classes that descend from ``mixin``.

    Collects the python range of the identifier slot of every descendant of ``mixin``
    (mixin relations included) that has one.

    :param mixin: the mixin class definition
    :return: ``None`` if no descendant has an identifier slot, the single range if they all agree,
        otherwise a ``Union[...]`` of the distinct ranges in sorted (deterministic) order
    """
    sv = self.schemaview
    id_ranges = set()
    for descendant in sv.class_descendants(mixin.name, mixins=True):
        id_slot = sv.get_identifier_slot(descendant)
        if id_slot is not None:
            id_ranges.add(_get_pyrange(sv.get_type(id_slot.range), sv))
    # Sort: set iteration order varies between runs, which would make the generated code unstable.
    ranges = sorted(id_ranges)
    if not ranges:
        return None
    if len(ranges) == 1:
        return ranges[0]
    # Members must be comma-separated; joining with "." produced invalid annotations like ``Union[int.str]``.
    return f"Union[{', '.join(ranges)}]"
def get_class_slot_range(self, slot_range: str, inlined: bool, inlined_as_list: bool) -> str:
    """
    Python range for a slot whose range is a schema class.

    * A class with ``class_uri`` ``linkml:Any`` maps to ``Any``.
    * The class itself is used when the slot is inlined (as a dict or list), or when the class has
      no identifier/key slot and is not a mixin. If the class has a type designator slot and more
      than one descendant, a ``Union`` of the descendants is used instead.
    * Otherwise the slot holds a reference, typed by an identifier range: ``str`` by default,
      replaced by the descendants' identifier range for mixins (with mixin inheritance on), and
      finally by the class's own identifier slot range when that is set.
    """
    sv = self.schemaview
    range_cls = sv.get_class(slot_range)

    # Hardcoded handling for Any
    if range_cls.class_uri == "linkml:Any":
        return "Any"

    must_inline = (
        inlined
        or inlined_as_list
        or (sv.get_identifier_slot(range_cls.name, use_key=True) is None and not sv.is_mixin(range_cls.name))
    )
    if must_inline:
        has_designator = any(s.designates_type for s in sv.class_induced_slots(slot_range))
        if has_designator:
            descendants = sv.class_descendants(slot_range)
            if len(descendants) > 1:
                return "Union[" + ",".join(camelcase(c) for c in descendants) + "]"
        return f"{camelcase(slot_range)}"

    # Referenced rather than inlined: start from str and try to improve it
    id_range = "str"
    own_id_slot = sv.get_identifier_slot(range_cls.name)
    if self.gen_mixin_inheritance and sv.is_mixin(range_cls.name) and own_id_slot:
        id_range = self.get_mixin_identifier_range(range_cls)
    # The class's own identifier range takes precedence over the mixin-derived one
    if own_id_slot is not None and own_id_slot.range is not None:
        id_range = _get_pyrange(sv.get_type(own_id_slot.range), sv)
    return id_range
def generate_python_range(self, slot_range, slot_def: SlotDefinition, class_def: ClassDefinition) -> str:
    """
    Python type annotation for a single range value of ``slot_def``.

    Checked in order: type designator slots (``Literal`` of accepted designator values),
    ``equals_string`` / ``equals_string_in`` (``Literal``), class ranges
    (see :meth:`.get_class_slot_range`), enums (the camelcased enum name), types
    (their python equivalent), and a missing range (``str``).

    :raises Exception: if ``slot_range`` is not a known class, enum or type
    """
    sv = self.schemaview
    if slot_def.designates_type:
        accepted = get_accepted_type_designator_values(sv, slot_def, class_def)
        return "Literal[" + ",".join(['"' + value + '"' for value in accepted]) + "]"
    if slot_def.equals_string:
        return f'Literal["{slot_def.equals_string}"]'
    if slot_def.equals_string_in:
        return "Literal[" + ", ".join([f'"{choice}"' for choice in slot_def.equals_string_in]) + "]"
    if slot_range in sv.all_classes():
        return self.get_class_slot_range(
            slot_range,
            inlined=slot_def.inlined,
            inlined_as_list=slot_def.inlined_as_list,
        )
    if slot_range in sv.all_enums():
        return f"{camelcase(slot_range)}"
    if slot_range in sv.all_types():
        return _get_pyrange(sv.get_type(slot_range), sv)
    if slot_range is None:
        return "str"
    # TODO: default ranges in schemagen
    raise Exception(f"range: {slot_range}")
def generate_collection_key(
    self,
    slot_ranges: List[str],
    slot_def: SlotDefinition,
    class_def: ClassDefinition,
) -> Optional[str]:
    """
    Python type of the dict key used when a slot's class range(s) are inlined as a dict.

    Non-class ranges (and ``None`` entries) are skipped. For each class range with an
    identifier/key slot, that slot's range is converted with :meth:`.generate_python_range`.

    :param slot_ranges: the slot's range names
    :param slot_def: the slot being generated
    :param class_def: the class the slot belongs to
    :return: the single key type, or ``None`` if no range has an identifier slot
    :raises Exception: if the ranges' identifier slots map to more than one python type
    """
    if slot_ranges is None:
        return None
    sv = self.schemaview
    keys: Set[str] = set()
    for candidate in slot_ranges:
        if candidate is None or candidate not in sv.all_classes():
            continue  # only class ranges have identifier slots
        id_slot = sv.get_identifier_slot(candidate, use_key=True)
        if id_slot is not None:
            keys.add(self.generate_python_range(id_slot.range, slot_def, class_def))
    if len(keys) > 1:
        raise Exception(f"Slot with any_of range has multiple identifier slot range types: {keys}")
    return next(iter(keys)) if keys else None
def _clean_injected_classes(self, injected_classes: List[Union[str, Type]]) -> Optional[List[str]]:
    """
    Convert injected classes to dedented source strings and drop duplicates.

    Live classes are converted with :func:`inspect.getsource`; strings are used as-is. Sources are
    dedented *before* deduplication, so the same source given with different leading indentation
    is injected only once. First-seen order is preserved.

    :param injected_classes: classes and/or source strings to inject
    :return: unique dedented sources, or ``None`` when there is nothing to inject
    """
    if not injected_classes:
        return None
    sources = (c if isinstance(c, str) else inspect.getsource(c) for c in injected_classes)
    return list(dict.fromkeys(textwrap.dedent(src) for src in sources))
def _inline_as_simple_dict_with_value(self, slot_def: SlotDefinition) -> Optional[str]:
    """
    Python type of the value when a slot can be inlined as a simple ``{identifier: value}`` dict.

    For example, if we have a class such as Prefix, with two slots prefix and expansion,
    then an inlined list of prefixes can be serialized as:

    .. code-block:: yaml

        prefixes:
            prefix1: expansion1
            prefix2: expansion2
            ...

    Provided that the prefix slot is the identifier slot for the Prefix class.

    This applies when the slot is inlined (not as a list), its range is a class with an
    identifier/key slot, the class has exactly one other slot, and that slot's range is a type.

    TODO: move this to SchemaView

    :param slot_def: SlotDefinition
    :return: python type of the value slot, or ``None`` if the slot does not qualify
    """
    sv = self.schemaview
    if not slot_def.inlined or slot_def.inlined_as_list:
        return None
    if slot_def.range not in sv.all_classes():
        return None
    id_slot = sv.get_identifier_slot(slot_def.range, use_key=True)
    if id_slot is None:
        return None
    range_slots = sv.class_induced_slots(slot_def.range)
    if len(range_slots) != 2:
        return None
    others = [candidate for candidate in range_slots if candidate.name != id_slot.name]
    if len(others) != 1:
        return None
    value_type = sv.get_type(others[0].range)
    if value_type is None:
        return None
    return _get_pyrange(value_type, sv)
def _template_environment(self) -> Environment:
    """
    Jinja environment used to render the template models.

    Based on :meth:`.PydanticTemplateModel.environment`. When :attr:`.template_dir` is set, templates
    found there take precedence, falling back to the environment's original loader.
    """
    env = PydanticTemplateModel.environment()
    if self.template_dir is None:
        return env
    env.loader = ChoiceLoader([FileSystemLoader(self.template_dir), env.loader])
    return env
def get_array_representations_range(self, slot: SlotDefinition, range: str) -> List[SlotResult]:
    """
    Generate the python range for array representations

    One :class:`.SlotResult` is produced per entry in :attr:`.array_representations`, each made by
    that representation's :class:`.ArrayRangeGenerator` from ``slot.array`` and the element ``range``.
    (``range`` shadows the builtin but is kept as-is because it is part of the public signature.)

    :raises ValueError: if no representation produced a result
    """
    reps = [
        ArrayRangeGenerator.get_generator(representation)(slot.array, range).make()
        for representation in self.array_representations
    ]
    if not reps:
        raise ValueError("No array representation generated, but one was requested!")
    return reps
# Typed signatures for ``include_metadata``: the template model passed in is returned with the
# same type, for modules (with the schema), classes (with the class definition) and attributes
# (with the slot definition).
# NOTE(review): no concrete (non-@overload) ``include_metadata`` is visible in this part of the
# file, although generate_class, generate_slot and render call it. If none follows, those calls
# would hit typing.overload's placeholder, which raises NotImplementedError - confirm it exists.
@overload
def include_metadata(self, model: PydanticModule, source: SchemaDefinition) -> PydanticModule: ...
@overload
def include_metadata(self, model: PydanticClass, source: ClassDefinition) -> PydanticClass: ...
@overload
def include_metadata(self, model: PydanticAttribute, source: SlotDefinition) -> PydanticAttribute: ...
def _get_imports(self, element: Union[ClassDefinition, SlotDefinition, None] = None) -> Imports:
    """
    Get imports that are implied by their usage in slots or classes
    (and thus need to be imported when generating schemas in :attr:`.split` == ``True`` mode).

    **Note:**
    Since in pydantic (currently) the only things that are materialized are classes, we don't
    import class slots from imported schemas and abandon slots, directly expressing them
    in the model.

    This is a parent placeholder method in case that changes, "give me something and return
    a set of imports" that calls subordinate methods. If slots become materialized, keep
    this as the directly called method rather than spaghetti-ing out another
    independent method. This method is also isolated in anticipation of structured imports,
    where we will need to revise our expectations of what is imported when.

    Args:
        element (:class:`.ClassDefinition` , :class:`.SlotDefinition` , None): The element
            to get import for. If ``None`` , get all needed imports (see :attr:`.split_mode` )

    Returns:
        :class:`.Imports`: empty when not splitting, or when :attr:`.split_mode` is ``FULL`` and a
        specific element is given; otherwise one import per needed non-local class.

    Raises:
        ValueError: if ``element`` is not a class definition, a slot definition or ``None``
    """
    # import from local references, rather than serializing every class in every file
    if not self.split or (self.split_mode == SplitMode.FULL and element is not None):
        # we are either compiling this whole thing in one big file (default)
        # or going to import all classes from the imported schemas,
        # so we don't import anything
        return Imports()

    # gather a list of class names,
    # remove local classes and transform to Imports later.
    needed_classes = []

    # fine to call rather than pass bc it's cached
    all_classes = self.schemaview.all_classes(imports=True)
    local_classes = self.schemaview.all_classes(imports=False)

    if isinstance(element, ClassDefinition):
        if element.is_a:
            needed_classes.append(element.is_a)
        if element.mixins:
            needed_classes.extend(element.mixins)

    elif isinstance(element, SlotDefinition):
        # collapses `slot.range`, `slot.any_of`, and `slot.one_of` to a list
        slot_ranges = self.schemaview.slot_range_as_union(element)
        needed_classes.extend([a_range for a_range in slot_ranges if a_range in all_classes])

    elif element is None:
        # get all imports
        needed_classes.extend([cls for cls in all_classes if cls not in local_classes])

    else:
        raise ValueError(f"Unsupported type of element to get imports from: {type(element)}")

    # SPECIAL CASE: classes that are not generated for structural reasons.
    # TODO: Do we want to have a general means of skipping class generation?
    skips = ("AnyType",)

    class_imports = [
        self._get_element_import(cls) for cls in needed_classes if (cls not in local_classes and cls not in skips)
    ]
    imports = Imports(imports=class_imports)

    return imports
def generate_module_import(self, schema: SchemaDefinition, context: Optional[dict] = None) -> str:
    """
    Module path used to import from the module generated for an imported ``schema`` in split mode.

    :attr:`.split_pattern` is rendered as a jinja template with ``schema`` plus any ``context``
    variables. Matches of :attr:`.SNAKE_CASE` (when set) are replaced with ``_``, and the result
    is lowercased.

    :param schema: the imported schema
    :param context: extra template variables
    :return: the module portion of the import statement
    """
    variables = {} if context is None else context
    rendered = Template(self.split_pattern).render(schema=schema, **variables)
    if self.SNAKE_CASE:
        rendered = re.sub(self.SNAKE_CASE, "_", rendered)
    return rendered.lower()
def _get_element_import(self, class_name: ElementName) -> Import:
    """
    Import of ``class_name`` from the module generated for the schema that defines it.

    The module path comes from :meth:`.generate_module_import`, i.e. :attr:`.split_pattern`
    rendered with :attr:`.split_context`.

    :param class_name: name of a class defined in an imported schema
    :return: an :class:`.Import` of the camelcased class, flagged with ``is_schema=True``
    """
    sv = self.schemaview
    defining_schema_name = sv.element_by_schema_map()[class_name]
    matching_schemas = [s for s in sv.schema_map.values() if s.name == defining_schema_name]
    module = self.generate_module_import(matching_schemas[0], self.split_context)
    return Import(module=module, objects=[ObjectImport(name=camelcase(class_name))], is_schema=True)
def render(self) -> PydanticModule:
    """
    Render the schema to a :class:`PydanticModule` model

    Gathers imports, injected classes, enums, the base model and one :class:`.PydanticClass` per
    schema class. Lifecycle hooks run around enum and class generation, and
    ``before_render_template`` runs on the assembled module.

    Returns:
        :class:`.PydanticModule`: the module model, ready to be rendered by :meth:`.serialize`
    """
    from copy import deepcopy

    sv: SchemaView
    sv = self.schemaview

    # imports
    # Work on a private copy of DEFAULT_IMPORTS: if nothing were merged in below, ``imports`` would
    # be the module-level object itself, and setting ``render_sorted`` (or a lifecycle hook mutating
    # the module's imports) would leak into every later render.
    imports = deepcopy(DEFAULT_IMPORTS)
    if self.imports is not None:
        for i in self.imports:
            imports += i
    if self.split_mode == SplitMode.FULL:
        imports += self._get_imports()

    # injected classes
    injected_classes = DEFAULT_INJECTS.copy()
    if self.injected_classes is not None:
        injected_classes += self.injected_classes.copy()

    # enums
    enums = self.before_generate_enums(list(sv.all_enums().values()), sv)
    enums = self.generate_enums({e.name: e for e in enums})

    base_model = PydanticBaseModel(extra_fields=self.extra_fields, fields=self.injected_fields)

    # schema classes
    class_results = []
    source_classes, imported_classes = self._get_classes(sv)
    source_classes = self.sort_classes(source_classes, imported_classes)
    # Don't want to generate classes when class_uri is linkml:Any, will
    # just swap in typing.Any instead down below
    source_classes = [c for c in source_classes if c.class_uri != "linkml:Any"]
    source_classes = self.before_generate_classes(source_classes, sv)
    # Read by the ``class_bases`` property to order each class's base classes
    self.sorted_class_names = [camelcase(c.name) for c in source_classes]
    for cls in source_classes:
        cls = self.before_generate_class(cls, sv)
        result = self.generate_class(cls)
        result = self.after_generate_class(result, sv)
        class_results.append(result)
        if result.imports is not None:
            imports += result.imports
        if result.injected_classes is not None:
            injected_classes.extend(result.injected_classes)

    class_results = self.after_generate_classes(class_results, sv)

    classes = {r.cls.name: r.cls for r in class_results}
    injected_classes = self._clean_injected_classes(injected_classes)

    imports.render_sorted = self.sort_imports

    module = PydanticModule(
        metamodel_version=self.schema.metamodel_version,
        version=self.schema.version,
        python_imports=imports,
        base_model=base_model,
        injected_classes=injected_classes,
        enums=enums,
        classes=classes,
    )
    module = self.include_metadata(module, self.schemaview.schema)
    module = self.before_render_template(module, self.schemaview)
    return module
def serialize(self, rendered_module: Optional[PydanticModule] = None) -> str:
    """
    Serialize the schema to a pydantic module as a string

    Args:
        rendered_module ( :class:`.PydanticModule` ): Optional. If the schema was already
            rendered with :meth:`~.PydanticGenerator.render`, that module is reused; otherwise
            :meth:`~.PydanticGenerator.render` is called.

    Returns:
        str: the generated python source, after the ``after_render_template`` hook
    """
    module = self.render() if rendered_module is None else rendered_module
    code = module.render(self._template_environment(), self.black)
    return self.after_render_template(code, self.schemaview)
def default_value_for_type(self, typ: str) -> str:
    """
    Python source for a default value: always the literal ``"None"``, whatever ``typ`` is.

    NOTE(review): ``typ`` is ignored; presumably this overrides a hook on the base code generator - confirm.
    """
    return "None"
@classmethod
def generate_split(
    cls,
    schema: Union[str, Path, SchemaDefinition],
    output_path: Union[str, Path] = Path("."),
    split_pattern: Optional[str] = None,
    split_context: Optional[dict] = None,
    split_mode: SplitMode = SplitMode.AUTO,
    **kwargs,
) -> List[SplitResult]:
    """
    Generate a schema and the schemas it imports as a set of python modules that
    import from one another, instead of one module containing every imported class.

    The main schema from ``schema`` is written to ``output_path``. Each imported schema
    from which classes are actually used is written to a module whose location is
    derived from the module name produced by ``split_pattern``
    (see :attr:`.PydanticGenerator.split_pattern`). That location is interpreted
    relative to the directory containing ``output_path``.

    For example, given

    * an ``output_path`` of ``my_dir/v1_2_3/main.py``
    * a schema ``main`` with version ``v1.2.3``
    * that imports from ``s2`` with version ``v4.5.6``,
    * and a ``split_pattern`` of ``..{{ schema.version | replace('.', '_') }}.{{ schema.name }}``

    one would get:

    * ``my_dir/v1_2_3/main.py``, as expected
    * that imports ``from ..v4_5_6.s2``
    * a module at ``my_dir/v4_5_6/s2.py``

    ``__init__.py`` files are generated for any directories between the generated
    modules and their highest common directory.

    Args:
        schema (str, :class:`.Path`, :class:`.SchemaDefinition`): Main schema to generate
        output_path (str, :class:`.Path`): Python ``.py`` module to write the main schema to
        split_pattern (str): Pattern used to generate module names, see :attr:`.PydanticGenerator.split_pattern`
        split_context (dict): Additional variables passed into the jinja context when generating module import names.
        split_mode (:class:`.SplitMode`): Forwarded as ``split_mode`` to every generator created here.
        **kwargs: Extra keyword arguments forwarded to every generator constructor.

    Returns:
        list[:class:`.SplitResult`]: the main module's result first (``main=True``),
        followed by one result per generated imported module.

    Raises:
        ValueError: if ``output_path`` does not end in ``.py``.
    """
    main_path = Path(output_path)
    if main_path.suffix != ".py":
        raise ValueError(f"output path must be a python file to write the main schema to, got {main_path}")

    # Every generator (main and imported) is built with the same split options
    gen_kwargs = {
        **kwargs,
        "split": True,
        "split_pattern": split_pattern,
        "split_context": split_context,
        "split_mode": split_mode,
    }

    # --------------------------------------------------
    # Main schema
    # --------------------------------------------------
    main_generator = cls(schema, **gen_kwargs)
    # Render first: the rendered module's imports tell us which imported schemas are needed
    main_module = main_generator.render()
    # The main module goes to ``main_path``; imported module paths are relative to its directory
    main_path.parent.mkdir(parents=True, exist_ok=True)
    main_source = main_generator.serialize(rendered_module=main_module)
    with open(main_path, "w", encoding="utf-8") as handle:
        handle.write(main_source)
    results = [
        SplitResult(main=True, source=main_generator.schemaview.schema, path=main_path, serialized_module=main_source)
    ]

    # --------------------------------------------------
    # Imported schemas
    # --------------------------------------------------
    schema_by_module = {
        main_generator.generate_module_import(source_schema): source_schema
        for source_schema in main_generator.schemaview.schema_map.values()
    }
    schema_imports = [imp for imp in main_module.python_imports if imp.is_schema]
    for schema_import in schema_imports:
        module_name = schema_import.module
        source_schema = schema_by_module[module_name]
        module_source = cls(source_schema, **gen_kwargs).serialize()
        module_path = (main_path.parent / _import_to_path(module_name)).resolve()
        module_path.parent.mkdir(parents=True, exist_ok=True)
        with open(module_path, "w", encoding="utf-8") as handle:
            handle.write(module_source)
        results.append(
            SplitResult(
                main=False,
                source=source_schema,
                path=module_path,
                serialized_module=module_source,
                module_import=module_name,
            )
        )

    _ensure_inits([result.path for result in results])
    return results
def _subclasses(cls: Type):
return set(cls.__subclasses__()).union([s for c in cls.__subclasses__() for s in _subclasses(c)])
# Sorted, de-duplicated names of the templates declared by every (transitive) subclass
# of PydanticTemplateModel; listed in the ``--template-dir`` CLI help as overridable.
_TEMPLATE_NAMES = sorted({c.template for c in _subclasses(PydanticTemplateModel)})
def _import_to_path(module: str) -> Path:
"""Make a (relative) ``Path`` object from a python module import string"""
# handle leading .'s separately..
_, dots, module = re.split(r"(^\.*)(?=\w)", module, maxsplit=1)
# treat zero or one dots as a relative import to the current directory
dir_pieces = ["../" for _ in range(max(len(dots) - 1, 0))]
dir_pieces.extend(module.split("."))
dir_pieces[-1] = dir_pieces[-1] + ".py"
return Path(*dir_pieces)
def _ensure_inits(paths: List[Path]):
"""For a set of paths, find the common root and it and all the subdirectories have an __init__.py"""
# if there is only one file, there is no relative importing to be done
if len(paths) <= 1:
return
common_path = Path(os.path.commonpath(paths))
if not (ipath := (common_path / "__init__.py")).exists():
with open(ipath, "w", encoding="utf-8") as ifile:
ifile.write(" \n")
for path in paths:
# ensure __init__ for each directory from this path up to the common path
path = path.parent
while path != common_path:
if not (ipath := (path / "__init__.py")).exists():
with open(ipath, "w", encoding="utf-8") as ifile:
ifile.write(" \n")
path = path.parent
@shared_arguments(PydanticGenerator)
@click.option("--template-file", hidden=True)
@click.option(
    "--template-dir",
    type=click.Path(),
    help="""
Optional jinja2 template directory to use for class generation.
Pass a directory containing templates with the same name as any of the default
:class:`.PydanticTemplateModel` templates to override them. The given directory will be
searched for matching templates, and use the default templates as a fallback
if an override is not found
Available templates to override:
\b
"""
    + "\n".join(["- " + name for name in _TEMPLATE_NAMES]),
)
@click.option(
    "--array-representations",
    type=click.Choice([k.value for k in ArrayRepresentation]),
    multiple=True,
    default=["list"],
    help="List of array representations to accept for array slots. Default is list of lists.",
)
@click.option(
    "--extra-fields",
    type=click.Choice(["allow", "ignore", "forbid"], case_sensitive=False),
    default="forbid",
    help="How to handle extra fields in BaseModel.",
)
@click.option(
    "--black",
    is_flag=True,
    default=False,
    help="Format generated models with black (must be present in the environment)",
)
@click.option(
    "--meta",
    type=click.Choice([k for k in MetadataMode]),
    default="auto",
    help="How to include linkml schema metadata in generated pydantic classes. "
    "See docs for MetadataMode for full description of choices. "
    "Default (auto) is to include all metadata that can't be otherwise represented",
)
@click.version_option(__version__, "-V", "--version")
@click.command(name="pydantic")
def cli(
    yamlfile,
    template_file=None,
    template_dir: Optional[str] = None,
    head=True,
    genmeta=False,
    classvars=True,
    slots=True,
    # Immutable default mirroring the click option's ``default=["list"]``. The previous
    # ``list("list")`` evaluated to ['l', 'i', 's', 't'], four separate characters
    # rather than the single representation name "list".
    array_representations=("list",),
    extra_fields: Literal["allow", "forbid", "ignore"] = "forbid",
    black: bool = False,
    meta: MetadataMode = "auto",
    **args,
):
    """Generate pydantic classes to represent a LinkML model"""
    # NOTE: the docstring above doubles as the click command's help text; keep it unchanged.
    # --template-file is hidden and only kept so old invocations fail loudly with guidance.
    if template_file is not None:
        raise DeprecationWarning(
            (
                "Passing a single template_file is deprecated. Pass a directory of template files instead. "
                "See help string for --template-dir"
            )
        )
    # click.Path() is declared without exists=True, so check existence explicitly
    if template_dir is not None:
        if not Path(template_dir).exists():
            raise FileNotFoundError(f"The template directory {template_dir} does not exist!")
    gen = PydanticGenerator(
        yamlfile,
        array_representations=[ArrayRepresentation(x) for x in array_representations],
        extra_fields=extra_fields,
        emit_metadata=head,
        genmeta=genmeta,
        gen_classvars=classvars,
        gen_slots=slots,
        template_dir=template_dir,
        black=black,
        metadata_mode=meta,
        **args,
    )
    print(gen.serialize())
if __name__ == "__main__":
    # Run the click command when this file is executed directly as a script.
    cli()