Source code for linkml.generators.rustgen.rustgen

from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Literal, overload

from jinja2 import Environment

from linkml.generators.common.lifecycle import LifecycleMixin
from linkml.generators.common.subproperty import is_uri_range
from linkml.generators.common.template import ObjectImport
from linkml.generators.common.type_designators import get_accepted_type_designator_values
from linkml.generators.rustgen.build import (
    AttributeResult,
    ClassResult,
    CrateResult,
    EnumResult,
    FileResult,
    SlotResult,
    TypeResult,
)
from linkml.generators.rustgen.template import (
    AsKeyValue,
    ContainerType,
    Import,
    Imports,
    PolyContainersFile,
    PolyFile,
    PolyTrait,
    PolyTraitImpl,
    PolyTraitImplForSubtypeEnum,
    PolyTraitProperty,
    PolyTraitPropertyImpl,
    PolyTraitPropertyMatch,
    RustCargo,
    RustClassModule,
    RustEnum,
    RustEnumItem,
    RustFile,
    RustLibShim,
    RustProperty,
    RustPyProject,
    RustRange,
    RustStruct,
    RustStructOrSubtypeEnum,
    RustTemplateModel,
    RustTypeAlias,
    SerdeUtilsFile,
    SlotRangeAsUnion,
    StubGenBin,
    StubUtilsFile,
)
from linkml.utils.generator import Generator
from linkml_runtime.linkml_model.meta import (
    ClassDefinition,
    EnumDefinition,
    PermissibleValue,
    SlotDefinition,
    TypeDefinition,
)
from linkml_runtime.utils.formatutils import camelcase, uncamelcase, underscore
from linkml_runtime.utils.schemaview import OrderedBy, SchemaView

# Output modes accepted by the generator: a whole cargo crate, or a single ``.rs`` file.
RUST_MODES = Literal["crate", "file"]

# Lookup table used by ``get_rust_type`` to translate a resolved base type into a Rust type.
# Keys are either python type objects or the string names that appear as a type's ``base``.
PYTHON_TO_RUST = {
    int: "isize",
    float: "f64",
    str: "String",
    bool: "bool",
    "int": "isize",
    "float": "f64",
    "str": "String",
    # Already-translated names map to themselves so a recursive lookup result passes through unchanged.
    "String": "String",
    "bool": "bool",
    "Bool": "bool",
    "XSDDate": "NaiveDate",
    "date": "NaiveDate",
    "XSDDateTime": "NaiveDateTime",
    "datetime": "NaiveDateTime",
    # "Decimal": "dec",
    # Decimals are currently represented as f64 (the rust_decimal mapping above is disabled).
    "Decimal": "f64",
}
"""
Mapping from python types to rust types.

.. todo::

    - Add numpy types
    - make an enum wrapper for naivedatetime and datetime<fixedoffset> that can represent both of them

"""

# Identifiers that ``protect_name`` suffixes with an underscore before they are emitted.
PROTECTED_NAMES = ("type", "typeof", "abstract")

# Crate imports keyed by the Rust type name that requires them; consulted by
# ``RustGenerator.get_imports`` using the result of ``get_rust_type``.
RUST_IMPORTS = {
    "dec": Import(module="rust_decimal", version="1.36", objects=[ObjectImport(name="dec")]),
    "NaiveDate": Import(
        module="chrono", features=["serde"], version="0.4.41", objects=[ObjectImport(name="NaiveDate")]
    ),
    "NaiveDateTime": Import(
        module="chrono", features=["serde"], version="0.4.41", objects=[ObjectImport(name="NaiveDateTime")]
    ),
}

# Class annotation key; when present on a class, merge support is generated for it
# and ``MERGE_IMPORTS`` is added to the rendered imports.
MERGE_ANNOTATION = "rust.linkml.io/generate/merge"

MERGE_IMPORTS = Imports(
    imports=[Import(module="merge", version="0.2.0", objects=[ObjectImport(name="Merge")])],
)

# Imports that every rendered file starts from (``render`` copies this before adding more).
DEFAULT_IMPORTS = Imports(
    imports=[
        Import(module="std::collections", objects=[ObjectImport(name="HashMap")]),
        # Import(module="std::fmt", objects=[ObjectImport(name="Display")]),
    ]
)

# Serde-related imports; ``render`` adds these to every generated file/crate.
SERDE_IMPORTS = Imports(
    imports=[
        Import(
            module="serde",
            version="1.0",
            features=["derive"],
            objects=[
                ObjectImport(name="Serialize"),
                ObjectImport(name="Deserialize"),
                ObjectImport(name="de::IntoDeserializer"),
            ],
            feature_flag="serde",
        ),
        Import(module="serde-value", version="0.7.0", objects=[ObjectImport(name="Value")]),
        Import(module="serde_yml", version="0.0.12", feature_flag="serde", alias="_"),
        Import(module="serde_path_to_error", version="0.1.17", objects=[], feature_flag="serde"),
    ]
)

# PyO3 binding imports; ``render`` adds these to every generated file/crate.
PYTHON_IMPORTS = Imports(
    imports=[
        Import(
            module="pyo3",
            version="0.25.0",
            objects=[ObjectImport(name="prelude::*"), ObjectImport(name="FromPyObject")],
            feature_flag="pyo3",
            features=["chrono"],
        ),
        # Import(module="serde_pyobject", version="0.6.1", objects=[], feature_flag="pyo3", features=[]),
    ]
)

# pyo3-stub-gen imports; ``render`` adds these only when the generator's ``stubgen`` flag is set.
STUBGEN_IMPORTS = Imports(
    imports=[
        Import(
            module="pyo3-stub-gen",
            version="0.13.1",
            objects=[
                ObjectImport(name="define_stub_info_gatherer"),
                ObjectImport(name="derive::gen_stub_pyclass"),
                ObjectImport(name="derive::gen_stub_pymethods"),
            ],
            feature_flag="stubgen",
            feature_dependencies=["pyo3"],
        ),
    ]
)


class SlotContainerMode(Enum):
    """Container shape chosen for a slot by :func:`determine_slot_mode`."""

    # A non-multivalued slot.
    SINGLE_VALUE = "single_value"
    # A multivalued, inlined slot whose range class has a key or identifier slot.
    MAPPING = "mapping"
    # Any other multivalued slot.
    LIST = "list"


class SlotInlineMode(Enum):
    """How the range of a slot is embedded, as chosen by :func:`determine_slot_mode`."""

    # Class-ranged slot whose objects are embedded directly.
    INLINE = "inline"
    # Slot whose range is not a class.
    PRIMITIVE = "primitive"
    # Class-ranged slot that is not inlined (rendered as a ``String`` by ``get_rust_range_info``).
    REFERENCE = "reference"


def get_key_or_identifier_slot(cls: ClassDefinition, sv: SchemaView) -> SlotDefinition | None:
    """Return the first induced slot of ``cls`` marked as identifier or key.

    Slots are examined in the order ``sv.class_induced_slots`` yields them;
    ``None`` is returned when no slot carries either flag.
    """
    keyed = (candidate for candidate in sv.class_induced_slots(cls.name) if candidate.identifier or candidate.key)
    return next(keyed, None)


def get_identifier_slot(cls: ClassDefinition, sv: SchemaView) -> SlotDefinition | None:
    """Return the first induced slot of ``cls`` marked as identifier, or ``None``.

    Unlike :func:`get_key_or_identifier_slot`, slots that are only ``key`` are ignored.
    """
    identifiers = (candidate for candidate in sv.class_induced_slots(cls.name) if candidate.identifier)
    return next(identifiers, None)


def class_real_descendants(sv: SchemaView, class_name: str) -> list[str]:
    """List the descendants of ``class_name`` with the class itself filtered out.

    ``sv.class_descendants`` may report the class as its own descendant, which
    would skew every "does this class have subtypes?" decision (OrSubtype
    generation, trait typing). Any failure of the lookup is treated as
    "no descendants" rather than propagated.
    """
    try:
        reported = list(sv.class_descendants(class_name))
    except Exception:
        # Best effort: an unresolvable class simply has no known descendants.
        return []
    return [name for name in reported if name != class_name]


def has_real_subtypes(sv: SchemaView, class_name: str) -> bool:
    """Report whether ``class_name`` has any descendant other than itself."""
    return bool(class_real_descendants(sv, class_name))


def determine_slot_mode(s: SlotDefinition, sv: SchemaView) -> tuple[SlotContainerMode, SlotInlineMode]:
    """Decide how a slot is stored: its container shape and how its range is embedded.

    Returns a ``(container_mode, inline_mode)`` pair:

    * non-class ranges are always ``PRIMITIVE``, in a list when multivalued;
    * multivalued slots flagged ``inlined_as_list`` are an inline list;
    * otherwise the slot is inline when it says so, or when the range class has
      no identifier (and so cannot be referenced);
    * inlined multivalued slots become a mapping when the range class has a key
      or identifier slot, and a list otherwise.
    """
    if s.range not in sv.all_classes():
        container = SlotContainerMode.LIST if s.multivalued else SlotContainerMode.SINGLE_VALUE
        return (container, SlotInlineMode.PRIMITIVE)

    if s.multivalued and s.inlined_as_list:
        return (SlotContainerMode.LIST, SlotInlineMode.INLINE)

    key_slot = get_key_or_identifier_slot(sv.get_class(s.range), sv)
    identifier_slot = get_identifier_slot(sv.get_class(s.range), sv)
    # Without an identifier there is nothing to reference, so the object must be inlined.
    inlined = True if identifier_slot is None else s.inlined

    if not s.multivalued:
        inline_mode = SlotInlineMode.INLINE if inlined else SlotInlineMode.REFERENCE
        return (SlotContainerMode.SINGLE_VALUE, inline_mode)

    if not inlined:
        return (SlotContainerMode.LIST, SlotInlineMode.REFERENCE)

    container = SlotContainerMode.MAPPING if key_slot is not None else SlotContainerMode.LIST
    return (container, SlotInlineMode.INLINE)


def can_contain_reference_to_class(s: SlotDefinition, cls: ClassDefinition, sv: SchemaView) -> bool:
    """Report whether ``cls`` is reachable from the range of slot ``s``.

    Starting at ``s.range``, the attributes of each class are followed through
    their ``range`` until ``cls`` is found or every reachable class has been
    visited. Ranges that are not classes are dead ends. ``get_rust_range_info``
    uses the result to decide whether an inlined field needs boxing.

    Args:
        s: Slot whose range is the starting point of the search.
        cls: Class to look for.
        sv: Schema view used to resolve class names and induced attributes.

    Returns:
        ``True`` if ``cls.name`` is the slot's range or is reachable from it.
    """
    target = cls.name
    # The set of classes does not change during the walk, so look it up once.
    known_classes = sv.all_classes()
    visited = set()
    pending = [s.range]
    while pending:
        current = pending.pop()
        # A name can be queued several times before it is first popped
        # (e.g. two attributes sharing a range); only expand it once.
        if current in visited:
            continue
        visited.add(current)
        if current not in known_classes:
            continue
        if current == target:
            return True
        for attr in sv.induced_class(current).attributes.values():
            if attr.range not in visited:
                pending.append(attr.range)
    return False


def get_rust_type(
    t: TypeDefinition | type | str, sv: SchemaView, pyo3: bool = False, crate_ref: str | None = None
) -> str:
    """
    Resolve a linkml type, enum or class reference to the name of its Rust type.

    ``t`` may be a :class:`TypeDefinition` or the name of a type, enum or class
    in ``sv``. Names that end up in :data:`PYTHON_TO_RUST` are translated;
    anything that cannot be resolved falls back to the Rust type for ``str``.
    When ``crate_ref`` is given it is prefixed (``crate_ref::Name``) to names
    resolved directly from a type alias or a class, but not to enums or to
    results obtained through a recursive type lookup.
    """
    resolved = None
    # Set to False for results that must never receive the ``crate_ref`` prefix.
    may_qualify = True

    if isinstance(t, TypeDefinition):
        resolved = t.base
        if resolved is not None and resolved not in PYTHON_TO_RUST:
            # A type like URIorCURIE which is an alias for a rust type
            resolved = get_name(t)
        elif resolved is None and t.typeof is not None:
            # No base of its own: defer to the parent type
            may_qualify = False
            resolved = get_rust_type(sv.get_type(t.typeof), sv, pyo3)
    elif isinstance(t, str):
        type_def = sv.all_types().get(t, None)
        if type_def:
            resolved = get_rust_type(type_def, sv, pyo3)
            may_qualify = False
        elif t in sv.all_enums():
            # Map LinkML enums to generated Rust enums rather than collapsing to String
            resolved = get_name(sv.get_enum(t))
            may_qualify = False
        elif t in sv.all_classes():
            resolved = get_name(sv.get_class(t))

    # FIXME: Raise here once we have implemented all base types
    if resolved is None:
        return PYTHON_TO_RUST[str]
    if resolved in PYTHON_TO_RUST:
        return PYTHON_TO_RUST[resolved]
    if crate_ref is not None and may_qualify:
        return f"{crate_ref}::{resolved}"
    return resolved


def get_rust_range_info(
    cls: ClassDefinition, s: SlotDefinition, sv: SchemaView, crate_ref: str | None = None
) -> RustRange:
    """Describe the Rust-side range of slot ``s`` as used on class ``cls``.

    Every member of the slot's range union becomes a child :class:`RustRange`.
    When the union has more than one member the slot's own type is the
    generated ``<class>_utl::<slot>_range`` union, otherwise it is the Rust type
    of ``s.range``. Slots stored by reference are always typed ``String``.
    """
    container_mode, inline_mode = determine_slot_mode(s, sv)
    by_reference = inline_mode == SlotInlineMode.REFERENCE
    all_ranges = sv.slot_range_as_union(s)

    def rust_type_of(range_name):
        # Referenced objects are represented by their identifier string.
        if by_reference:
            return "String"
        return get_rust_type(range_name, sv, True, crate_ref)

    sub_ranges = []
    for member in all_ranges:
        sub_ranges.append(
            RustRange(
                type_=rust_type_of(member),
                is_class_range=member in sv.all_classes(),
                has_class_subtypes=has_real_subtypes(sv, member) if member in sv.all_classes() else False,
            )
        )
    is_union = len(sub_ranges) > 1

    if container_mode == SlotContainerMode.LIST:
        container_type = ContainerType.LIST
    elif container_mode == SlotContainerMode.MAPPING:
        container_type = ContainerType.MAPPING
    else:
        container_type = None

    single_class_range = len(all_ranges) == 1 and all_ranges[0] in sv.all_classes()

    if is_union:
        type_name = underscore(uncamelcase(cls.name)) + "_utl::" + get_name(s) + "_range"
    else:
        type_name = rust_type_of(s.range)

    return RustRange(
        optional=not s.required,
        has_default=not (s.required or False) or (s.multivalued or False),
        containerType=container_type,
        child_ranges=sub_ranges if is_union else None,
        # Inlined fields that can reach their own class would be infinitely sized without a Box.
        box_needed=inline_mode == SlotInlineMode.INLINE and can_contain_reference_to_class(s, cls, sv),
        is_class_range=single_class_range,
        is_reference=by_reference,
        has_class_subtypes=has_real_subtypes(sv, all_ranges[0]) if single_class_range else False,
        type_=type_name,
    )


def protect_name(v: str) -> str:
    """
    Return ``v`` with a trailing underscore when it is one of :data:`PROTECTED_NAMES`,
    and unchanged otherwise.
    """
    return f"{v}_" if v in PROTECTED_NAMES else v


def get_name(e: ClassDefinition | SlotDefinition | EnumDefinition | PermissibleValue | TypeDefinition) -> str:
    """Return the Rust identifier used for a schema element.

    Classes and enums are camel-cased from their name, permissible values are
    camel-cased from their text, and slots and types are underscored. The
    result is passed through :func:`protect_name`.

    Raises:
        ValueError: if ``e`` is none of the supported element kinds.
    """
    if isinstance(e, (ClassDefinition, EnumDefinition)):
        return protect_name(camelcase(e.name))
    if isinstance(e, PermissibleValue):
        return protect_name(camelcase(e.text))
    if isinstance(e, (SlotDefinition, TypeDefinition)):
        return protect_name(underscore(e.name))
    raise ValueError("Can only get the name from a slot or class!")


@dataclass
class RustGenerator(Generator, LifecycleMixin):
    """
    Generate rust types from a linkml schema
    """

    generatorname = "rustgenerator"
    generatorversion = "0.0.2"
    valid_formats = ["rust"]
    file_extension = "rs"
    crate_name: str | None = None

    pyo3: bool = True
    """Generate pyO3 bindings for the rust defs"""
    pyo3_version: str = ">=0.21.1"
    serde: bool = True
    """Generate serde derive serialization/deserialization attributes"""
    stubgen: bool = True
    """Generate pyo3-stub-gen instrumentation alongside PyO3 bindings"""
    handwritten_lib: bool = False
    """Place generated sources under src/generated and leave src/lib.rs for user code"""
    mode: RUST_MODES = "crate"
    """Whether to generate a whole crate (``"crate"``) or a single rust file (``"file"``)"""
    output: Path | None = None
    """
    * If ``mode == "crate"`` , a directory to contain the generated crate
    * If ``mode == "file"`` , a file with a ``.rs`` extension

    If output is not provided at object instantiation,
    it must be provided on a call to :meth:`.serialize`
    """
    expand_subproperty_of: bool = True
    """If True, expand subproperty_of to Rust enums with slot descendants"""

    _environment: Environment | None = None
    _subproperty_enums: dict = None  # Cache for generated subproperty enums

    def __post_init__(self):
        self.schemaview: SchemaView = SchemaView(self.schema)
        # Filled lazily by ``_get_subproperty_enum``; collected again in ``render``.
        self._subproperty_enums = {}
        super().__post_init__()

    def _select_root_class(self, class_defs: list[ClassDefinition]) -> ClassDefinition | None:
        """Return the schema-local class marked ``tree_root`` if present."""
        schema_id = getattr(self.schemaview.schema, "id", None)

        def is_local_non_mixin(candidate: ClassDefinition) -> bool:
            # Without a schema id, "local" means the class records no originating schema.
            if schema_id is None:
                local = candidate.from_schema is None
            else:
                local = candidate.from_schema == schema_id
            return local and not getattr(candidate, "mixin", False)

        roots = (c for c in class_defs if is_local_non_mixin(c) and getattr(c, "tree_root", False))
        return next(roots, None)

    def generate_type(self, type_: TypeDefinition) -> TypeResult:
        """Generate a type definition as a rust type alias, running the type lifecycle hooks."""
        type_ = self.before_generate_type(type_, self.schemaview)
        alias = RustTypeAlias(
            name=get_name(type_),
            type_=get_rust_type(type_.base, self.schemaview, self.pyo3),
            pyo3=self.pyo3,
            stubgen=self.stubgen,
        )
        result = TypeResult(source=type_, type_=alias, imports=self.get_imports(type_))
        return self.after_generate_type(result, self.schemaview)

    def generate_enum(self, enum: EnumDefinition) -> EnumResult:
        """Generate an enum definition as a rust enum, running the enum lifecycle hooks."""
        enum = self.before_generate_enum(enum, self.schemaview)
        items = []
        for pv_name, pv in enum.permissible_values.items():
            # Fall back to the dictionary key when the permissible value carries no text.
            items.append(RustEnumItem(variant=get_name(pv), text=pv.text or pv_name))
        rust_enum = RustEnum(
            name=get_name(enum),
            items=items,
            pyo3=self.pyo3,
            serde=self.serde,
            stubgen=self.stubgen,
        )
        result = EnumResult(source=enum, enum=rust_enum)
        return self.after_generate_enum(result, self.schemaview)
def generate_slot(self, slot: SlotDefinition) -> SlotResult:
    """
    Generate a top-level slot as a rust type alias, running the slot lifecycle hooks.
    """
    slot_def = self.before_generate_slot(slot, self.schemaview)
    alias = RustTypeAlias(
        name=get_name(slot_def),
        type_=get_rust_type(slot_def.range, self.schemaview, self.pyo3),
        multivalued=slot_def.multivalued,
        pyo3=self.pyo3,
        class_range=slot_def.range in self.schemaview.all_classes(),
        stubgen=self.stubgen,
    )
    result = SlotResult(source=slot_def, slot=alias, imports=self.get_imports(slot_def))
    return self.after_generate_slot(result, self.schemaview)
def generate_class(self, cls: ClassDefinition) -> ClassResult:
    """
    Generate a class as a struct!

    Besides the struct itself this builds the per-class module holding one
    union type for every slot whose range, taken across the class and its
    descendants, has more than one member. Imports of all attributes are
    merged into the returned result.

    Args:
        cls: The (induced) class definition to generate.

    Returns:
        The :class:`ClassResult` after the class lifecycle hooks have run.
    """
    cls = self.before_generate_class(cls, self.schemaview)
    sv = self.schemaview
    induced_attrs = [sv.induced_slot(sn, cls.name) for sn in sv.class_slots(cls.name)]
    induced_attrs = self.before_generate_slots(induced_attrs, sv)

    # The descendant list does not depend on the attribute; look it up once
    # instead of once per attribute.
    descendants = list(sv.class_descendants(cls.name))

    slot_range_unions = []
    for a in induced_attrs:
        # Promote union across descendants for canonical union enum in base module
        ranges = list(sv.slot_range_as_union(a))
        for d in descendants:
            sdesc = sv.induced_slot(a.name, d)
            if sdesc is None:
                continue
            for r in sv.slot_range_as_union(sdesc):
                if r not in ranges:
                    ranges.append(r)
        if len(ranges) > 1:
            slot_range_unions.append(
                SlotRangeAsUnion(
                    slot_name=get_name(a),
                    ranges=[get_rust_type(r, sv, True) for r in ranges],
                    stubgen=self.stubgen,
                )
            )

    cls_mod = RustClassModule(
        class_name=get_name(cls),
        class_name_snakecase=underscore(uncamelcase(cls.name)),
        slot_ranges=slot_range_unions,
        stubgen=self.stubgen,
    )

    attributes = [self.generate_attribute(attr, cls) for attr in induced_attrs]
    attributes = self.after_generate_slots(attributes, sv)

    # A struct holding other generated classes is flagged as unsendable.
    class_names = sv.all_classes()
    unsendable = any(a.range in class_names for a in induced_attrs)

    res = ClassResult(
        source=cls,
        cls=RustStruct(
            name=get_name(cls),
            properties=[a.attribute for a in attributes],
            special_case_enabled=sv.get_uri(cls, expand=True).startswith("https://w3id.org/linkml"),
            generate_merge=MERGE_ANNOTATION in cls.annotations,
            unsendable=unsendable,
            pyo3=self.pyo3,
            serde=self.serde,
            stubgen=self.stubgen,
            as_key_value=self.generate_class_as_key_value(cls),
            struct_or_subtype_enum=self.gen_struct_or_subtype_enum(cls),
            class_module=cls_mod,
        ),
    )
    # merge imports
    for attr in attributes:
        res = res.merge(attr)
    res = self.after_generate_class(res, sv)
    return res
def gen_struct_or_subtype_enum(self, cls: ClassDefinition) -> RustStructOrSubtypeEnum | None:
    """Describe the ``<Class>OrSubtype`` enum for ``cls``, or ``None`` when it has no real descendants."""
    sv = self.schemaview
    descendants = class_real_descendants(sv, cls.name)
    td = sv.get_type_designator_slot(cls.name)

    td_mapping = {}
    if td is not None:
        # Accepted type designator values per descendant class name.
        for d in descendants:
            td_mapping[d] = get_accepted_type_designator_values(sv, td, sv.get_class(d))

    if len(descendants) == 0:
        return None

    key_slot = get_key_or_identifier_slot(cls, sv)
    key_type = "String" if key_slot is None else get_rust_type(key_slot.range, sv, self.pyo3)
    return RustStructOrSubtypeEnum(
        enum_name=get_name(cls) + "OrSubtype",
        struct_names=[get_name(sv.get_class(d)) for d in descendants],
        type_designator_name=get_name(td) if td else None,
        as_key_value=key_slot is not None,
        type_designators=td_mapping,
        key_property_type=key_type,
    )


def generate_class_as_key_value(self, cls: ClassDefinition) -> AsKeyValue | None:
    """Describe how ``cls`` can be written as a ``key: value`` entry, if at all.

    Returns ``None`` when the class has no key/identifier slot, has more than
    one of them, or has no single-valued non-key slot to act as the value.
    """
    sv = self.schemaview
    induced_attrs = [sv.induced_slot(sn, cls.name) for sn in sv.class_slots(cls.name)]

    key_attr = None
    value_attrs = []
    value_args_no_default = []
    non_key_attrs = []
    for attr in induced_attrs:
        if attr.identifier or attr.key:
            if key_attr is not None:
                ## multiple identifiers/keys --> don't know what to do!
                return None
            key_attr = attr
            continue
        non_key_attrs.append(attr)
        # NOTE(review): the nesting of the ``required`` check under the
        # single-valued check was reconstructed from a whitespace-collapsed
        # listing -- confirm against the upstream file.
        if not attr.multivalued:
            value_attrs.append(attr)
            if attr.required:
                value_args_no_default.append(attr)

    if key_attr is None:
        return None
    # If there is a key/identifier but no single-valued non-multivalued
    # attribute to serve as the value, do not treat this as a key/value class.
    if len(value_attrs) == 0:
        return None

    value_attr = value_attrs[0]
    value_is_plain = value_attr.range not in sv.all_classes() or not bool(getattr(value_attr, "inlined", False))
    simple_dict_possible = len(non_key_attrs) == 1 and not value_attr.multivalued and value_is_plain
    return AsKeyValue(
        name=get_name(cls),
        key_property_name=get_name(key_attr),
        key_property_type=get_rust_type(key_attr.range, sv, self.pyo3),
        value_property_name=get_name(value_attr),
        value_property_type=get_rust_type(value_attr.range, sv, self.pyo3),
        can_convert_from_primitive=simple_dict_possible,
        can_convert_from_empty=len(value_args_no_default) == 0,
        value_property_optional=not bool(value_attr.required),
        serde=self.serde,
        pyo3=self.pyo3,
        stubgen=self.stubgen,
    )
def generate_attribute(self, attr: SlotDefinition, cls: ClassDefinition) -> AttributeResult:
    """
    Generate an attribute as a struct property

    Args:
        attr: The induced slot to generate.
        cls: The class that owns the slot; used for merge annotations and to
            name union range types.

    Returns:
        The :class:`AttributeResult` after the slot lifecycle hooks have run.
    """
    attr = self.before_generate_slot(attr, self.schemaview)
    is_class_range = attr.range in self.schemaview.all_classes()
    (container_mode, inline_mode) = determine_slot_mode(attr, self.schemaview)

    # The shared helper already handles both cases: a subproperty_of
    # constraint yields the generated enum type, anything else falls back to
    # the regular range info. Use it instead of duplicating that construction.
    range_info = self._get_range_info_with_subproperty(cls, attr)

    name = get_name(attr)
    res = AttributeResult(
        source=attr,
        attribute=RustProperty(
            name=name,
            inline_mode=inline_mode.value,
            # Only record an alias that differs from the generated field name.
            alias=attr.alias if attr.alias is not None and attr.alias != name else None,
            generate_merge=MERGE_ANNOTATION in cls.annotations,
            container_mode=container_mode.value,
            type_=range_info,
            required=bool(attr.required),
            multivalued=bool(attr.multivalued),
            is_key_value=is_class_range
            and self.generate_class_as_key_value(self.schemaview.get_class(attr.range)) is not None,
            pyo3=self.pyo3,
            serde=self.serde,
            stubgen=self.stubgen,
        ),
        imports=self.get_imports(attr),
    )
    res = self.after_generate_slot(res, self.schemaview)
    return res
def generate_cargo(self, imports: Imports) -> RustCargo:
    """
    Generate a Cargo.toml file

    The crate is named after ``crate_name`` when set, otherwise after the
    schema; a schema without a version is published as ``0.0.0``.
    """
    schema = self.schemaview.schema
    version = "0.0.0" if schema.version is None else schema.version
    name = schema.name if self.crate_name is None else self.crate_name
    return RustCargo(
        name=name,
        version=version,
        imports=imports,
        pyo3_version=self.pyo3_version,
        pyo3=self.pyo3,
        serde=self.serde,
        stubgen=self.stubgen,
    )
def generate_pyproject(self) -> RustPyProject:
    """
    Generate a pyproject.toml file for a pyo3 rust crate

    The project is named after the schema; a schema without a version is
    published as ``0.0.0``.
    """
    schema = self.schemaview.schema
    version = "0.0.0" if schema.version is None else schema.version
    return RustPyProject(name=schema.name, version=version)
def get_imports(self, element: SlotDefinition | TypeDefinition) -> Imports:
    """Return the crate imports required by the rust type of a slot's range or a type's base.

    Raises:
        TypeError: if ``element`` is neither a slot nor a type definition.
    """
    if isinstance(element, SlotDefinition):
        source = element.range
    elif isinstance(element, TypeDefinition):
        source = element.base
    else:
        raise TypeError("Must be a slot or type definition")

    rust_type = get_rust_type(source, self.schemaview, self.pyo3)
    required = RUST_IMPORTS.get(rust_type)
    if required is None:
        return Imports()
    return Imports(imports=[required])


def _get_subproperty_enum(self, slot: SlotDefinition) -> RustEnum | None:
    """
    Generate a Rust enum for subproperty_of constrained slot.

    Following metamodel semantics: "any ontological child (related to X via
    an is_a relationship), is a valid value for the slot"

    Values are formatted according to range type:
    - uri/uriorcurie: Uses CURIEs (e.g., "biolink:causes")
    - string: Uses snake_case slot names (e.g., "causes")

    Results (including ``None``) are cached per ``(slot name, root slot name)``.

    :param slot: SlotDefinition with subproperty_of set
    :return: RustEnum with variants for all valid values, or None if no values
    """
    if not self.expand_subproperty_of or not slot.subproperty_of:
        return None

    sv = self.schemaview
    root_slot_name = slot.subproperty_of

    cache_key = (slot.name, root_slot_name)
    if cache_key in self._subproperty_enums:
        return self._subproperty_enums[cache_key]

    # Reflexive: the root slot itself is a valid value too.
    descendants = sv.slot_descendants(root_slot_name, reflexive=True)
    use_uri = is_uri_range(sv, slot.range)

    # variant name -> serialized text; the first slot producing a variant wins.
    text_by_variant = {}
    for slot_name in sorted(descendants):
        descendant_slot = sv.get_slot(slot_name)
        if use_uri:
            text = sv.get_uri(descendant_slot, expand=False)
        else:
            text = underscore(slot_name)
        text_by_variant.setdefault(camelcase(slot_name), text)

    items = [RustEnumItem(variant=variant, text=text) for variant, text in text_by_variant.items()]

    rust_enum = None
    if items:
        rust_enum = RustEnum(
            name=camelcase(slot.name) + "Enum",
            items=items,
            pyo3=self.pyo3,
            serde=self.serde,
            stubgen=self.stubgen,
        )
    self._subproperty_enums[cache_key] = rust_enum
    return rust_enum


def _get_subproperty_enum_type(self, slot: SlotDefinition) -> str | None:
    """
    Get the Rust enum type name for a subproperty_of constrained slot.

    :param slot: SlotDefinition with subproperty_of set
    :return: Enum type name or None if not applicable
    """
    generated = self._get_subproperty_enum(slot)
    return generated.name if generated else None


def _get_range_info_with_subproperty(
    self, cls: ClassDefinition, slot: SlotDefinition, crate_ref: str | None = None
) -> RustRange:
    """
    Get RustRange info, considering subproperty_of constraint.

    If the slot has subproperty_of set and expand_subproperty_of is True,
    returns a RustRange with the generated enum type. Otherwise, falls back
    to the standard get_rust_range_info function.

    :param cls: ClassDefinition containing the slot
    :param slot: SlotDefinition to get range info for
    :param crate_ref: Optional crate reference for type paths
    :return: RustRange with appropriate type information
    """
    enum_type = self._get_subproperty_enum_type(slot)
    if not enum_type:
        return get_rust_range_info(cls, slot, self.schemaview, crate_ref)

    container_mode, _ = determine_slot_mode(slot, self.schemaview)
    if container_mode == SlotContainerMode.LIST:
        container_type = ContainerType.LIST
    elif container_mode == SlotContainerMode.MAPPING:
        container_type = ContainerType.MAPPING
    else:
        container_type = None

    return RustRange(
        optional=not slot.required,
        has_default=not (slot.required or False) or (slot.multivalued or False),
        containerType=container_type,
        is_class_range=False,  # It's an enum, not a class
        is_reference=False,
        type_=enum_type,
    )


@overload
def render(self, mode: Literal["file"] = "file") -> FileResult: ...


@overload
def render(self, mode: Literal["crate"] = "crate") -> CrateResult: ...
def render(self, mode: RUST_MODES | None = None) -> FileResult | CrateResult:
    """
    Render the template model of a rust file before serializing

    Types, enums, slots and classes are generated in that order, each wrapped
    in its ``before_generate_*`` / ``after_generate_*`` lifecycle hooks, and
    the imports of every generated element are merged into the file imports.

    Args:
        mode (:class:`.RUST_MODES`, optional): Override the instance-level generation mode

    Returns:
        A :class:`CrateResult` in ``"crate"`` mode, otherwise a :class:`FileResult`
        with the serde utilities inlined and the poly modules disabled.
    """
    if mode is None:
        mode = self.mode

    sv = self.schemaview

    types = list(sv.all_types(imports=True).values())
    types = self.before_generate_types(types, sv)
    types = [self.generate_type(t) for t in types]
    types = self.after_generate_types(types, sv)

    enums = list(sv.all_enums(imports=True).values())
    enums = self.before_generate_enums(enums, sv)
    enums = [self.generate_enum(e) for e in enums]
    enums = self.after_generate_enums(enums, sv)

    slots = [sv.induced_slot(s) for s in sv.all_slots()]
    slots = self.before_generate_slots(slots, sv)
    slots = [self.generate_slot(s) for s in slots]
    slots = self.after_generate_slots(slots, sv)

    class_defs = [sv.induced_class(c) for c in sv.all_classes(ordered_by=OrderedBy.INHERITANCE)]
    root_class_def = self._select_root_class(class_defs)
    root_struct_name = get_name(root_class_def) if root_class_def is not None else None

    # The merge crate is only pulled in when some class asks for it.
    need_merge_crate = any(MERGE_ANNOTATION in c.annotations for c in class_defs)

    classes = self.before_generate_classes(class_defs, sv)
    classes = [self.generate_class(c) for c in classes]
    classes = self.after_generate_classes(classes, sv)

    # Collect subproperty enums generated during class processing
    subproperty_enums = [e for e in self._subproperty_enums.values() if e is not None]

    poly_traits = [self.gen_poly_trait(sv.get_class(c)) for c in sv.all_classes(ordered_by=OrderedBy.INHERITANCE)]

    imports = DEFAULT_IMPORTS.model_copy()
    imports += PYTHON_IMPORTS
    imports += SERDE_IMPORTS
    if self.stubgen:
        imports += STUBGEN_IMPORTS
    if need_merge_crate:
        imports += MERGE_IMPORTS
    # Type results carry imports as well (see ``generate_type``); include them
    # so an alias never references a type whose crate was not imported.
    for result in [*types, *enums, *slots, *classes]:
        imports += result.imports

    # Combine schema enums with subproperty enums
    all_enums = [e.enum for e in enums] + subproperty_enums

    file = RustFile(
        name=sv.schema.name,
        imports=imports,
        slots=[t.slot for t in slots],
        types=[t.type_ for t in types],
        enums=all_enums,
        structs=[c.cls for c in classes],
        pyo3=self.pyo3,
        serde=self.serde,
        stubgen=self.stubgen,
        handwritten_lib=self.handwritten_lib,
        root_struct_name=root_struct_name,
    )

    if mode == "crate":
        extra_files = {}
        extra_files["serde_utils"] = SerdeUtilsFile()
        extra_files["poly"] = PolyFile(imports=imports, traits=poly_traits)
        extra_files["poly_containers"] = PolyContainersFile()
        if self.stubgen:
            extra_files["stub_utils"] = StubUtilsFile()
        cargo = self.generate_cargo(imports)
        pyproject = self.generate_pyproject()
        bin_files = {}
        if self.stubgen:
            bin_files["bin/stub_gen"] = StubGenBin(crate_name=cargo.name, stubgen=self.stubgen)
        res = CrateResult(
            cargo=cargo,
            file=file,
            pyproject=pyproject,
            source=sv.schema,
            extra_files=extra_files,
            bin_files=bin_files,
        )
        return res
    else:
        # Single file: inline serde utils, and skip poly modules
        file.inline_serde_utils = True
        file.emit_poly = False
        file.serde_utils = SerdeUtilsFile()
        res = FileResult(file=file, source=sv.schema)
        return res
def gen_poly_trait(self, cls: ClassDefinition) -> PolyTrait:
    """Build the polymorphic accessor trait for ``cls``.

    The trait exposes the slots that ``cls`` adds on top of its ``is_a``
    parent and mixins. An implementation is produced for every class returned
    by ``class_descendants`` and, for each of those that has real subtypes,
    for its ``<Class>OrSubtype`` enum.
    """
    sv = self.schemaview
    class_name = get_name(cls)

    superclass_names = []
    if cls.is_a is not None:
        superclass_names.append(cls.is_a)
    superclass_names.extend(cls.mixins)
    superclasses = [sv.get_class(sn) for sn in superclass_names if sn is not None]

    # Keep only the slots that no superclass already provides.
    attribs = sv.class_induced_slots(cls.name)
    for superclass in superclasses:
        inherited_names = [inherited.name for inherited in sv.class_induced_slots(superclass.name)]
        attribs = [a for a in attribs if a.name not in inherited_names]

    rust_attribs = [
        PolyTraitProperty(
            name=get_name(a),
            range=self._get_range_info_with_subproperty(cls, a),
            promoted_range=self.get_rust_range_info_across_descendants(cls, a),
        )
        for a in attribs
    ]

    impls = []
    subtype_impls = []
    for sc in sv.class_descendants(cls.name):
        sco = sv.get_class(sc)
        struct_name = get_name(sco)

        # First induced slot per name, as seen from the implementing class.
        slot_by_name = {}
        for induced in sv.class_induced_slots(sco.name):
            slot_by_name.setdefault(induced.name, induced)

        ptis = []
        for a in attribs:
            ptis.append(
                PolyTraitPropertyImpl(
                    name=get_name(a),
                    range=self._get_range_info_with_subproperty(sco, slot_by_name.get(a.name)),
                    definition_range=self.get_rust_range_info_across_descendants(cls, a),
                    trait_range=self.get_rust_range_info_across_descendants(cls, a),
                    struct_name=struct_name,
                )
            )
        impls.append(PolyTraitImpl(name=class_name, struct_name=struct_name, attrs=ptis))

        if not has_real_subtypes(sv, sc):
            continue
        enum_name = f"{struct_name}OrSubtype"
        cases = [get_name(sv.get_class(x)) for x in class_real_descendants(sv, sc)]
        matches = [
            PolyTraitPropertyMatch(
                name=get_name(a),
                range=self.get_rust_range_info_across_descendants(cls, a),
                cases=cases,
                struct_name=enum_name,
            )
            for a in attribs
        ]
        subtype_impls.append(PolyTraitImplForSubtypeEnum(name=class_name, enum_name=enum_name, attrs=matches))

    return PolyTrait(
        name=class_name,
        impls=impls,
        attrs=rust_attribs,
        superclass_names=[get_name(scla) for scla in superclasses],
        subtypes=subtype_impls,
    )
def serialize(self, output: Path | None = None, mode: RUST_MODES | None = None, force: bool = False) -> str:
    """
    Serialize a schema to a rust crate or file.

    Args:
        output (Path, optional): A ``.rs`` file if in ``file`` mode,
            directory otherwise.
        mode (:class:`.RUST_MODES`, optional): Override the instance-level generation mode
        force (bool): If the output already exists, overwrite it.
            Otherwise raise a :class:`FileExistsError`

    Returns:
        The serialized output: whatever ``write_crate`` returns in ``crate``
        mode, otherwise the text written to ``output`` (ending in exactly one newline).
    """
    if mode is None:
        mode = self.mode

    output = self._validate_output(output, mode, force)
    rendered = self.render(mode=mode)
    if mode == "crate":
        serialized = self.write_crate(output, rendered, force)
    else:
        serialized = rendered.file.render(self.template_environment)
        serialized = serialized.rstrip("\n") + "\n"
        # Rust source files must be UTF-8; do not depend on the platform's
        # default encoding when the schema contains non-ASCII text.
        with open(output, "w", encoding="utf-8") as f:
            f.write(serialized)

    return serialized
def get_rust_range_info_across_descendants(self, cls: ClassDefinition, s: SlotDefinition) -> RustRange:
    """Compute a RustRange covering a slot's ranges across a class and everything derived from it.

    The induced slot is inspected on ``cls`` itself, on every class in
    ``class_descendants(cls.name)``, and on every class that lists ``cls`` as a
    mixin together with that class's descendants.

    * Type: each observed range contributes one rust type name. Class ranges
      contribute their rust type when the slot is inlined (``Anything`` and
      ``AnyValue`` are always treated as inlined), ``"String"`` when the slot is
      explicitly non-inlined, and nothing when inlining is unspecified.
    * Optionality: the result is optional unless every observed slot is required.
    * Container: taken from ``determine_slot_mode`` on the slot ``s`` as passed in.

    Args:
        cls: The class whose slot is being described.
        s: The slot definition to compute the range for.

    Returns:
        * the per-class subproperty range when the slot has a subproperty enum type;
        * a union range with one child per observed type when several distinct
          types are observed;
        * a plain range when exactly one type is observed;
        * ``get_rust_range_info(cls, s, sv)`` when no type is observed at all.
    """
    # Subproperty enums are fixed types that don't vary across descendants, so
    # the per-class range info is already the final answer for them.
    subproperty_range = self._get_range_info_with_subproperty(cls, s)
    if self._get_subproperty_enum_type(s):
        return subproperty_range

    sv = self.schemaview
    base_slot = sv.induced_slot(s.name, cls.name)

    def iter_slot_definitions():
        """Yield every induced definition of the slot that the result must account for."""
        if base_slot is not None:
            yield base_slot
        # Classes in the inheritance tree below ``cls``.
        for descendant in sv.class_descendants(cls.name):
            induced = sv.induced_slot(s.name, descendant)
            if induced is not None:
                yield induced
        # If ``cls`` is used as a mixin, include its users and their descendants.
        # Enumerating all classes is best-effort: a failure here only means no
        # mixin users are considered.
        try:
            candidate_names = list(sv.all_classes())
        except Exception:
            candidate_names = []
        for candidate_name in candidate_names:
            candidate = sv.get_class(candidate_name)
            if candidate is None or cls.name not in (candidate.mixins or []):
                continue
            induced = sv.induced_slot(s.name, candidate_name)
            if induced is not None:
                yield induced
            for descendant in sv.class_descendants(candidate_name):
                induced = sv.induced_slot(s.name, descendant)
                if induced is not None:
                    yield induced

    slot_definitions = list(iter_slot_definitions())

    # Rust type names in first-seen order, plus the schema class (if any) that
    # produced each one. Subtype presence must later be checked against the
    # schema class name, not the rust type identifier.
    type_names: list[str] = []
    rust_to_class: dict[str, str | None] = {}
    for slot_def in slot_definitions:
        inl = slot_def.inlined
        inl_list = slot_def.inlined_as_list
        is_inlined = inl is True or inl_list is True
        is_reference = inl is False and (inl_list is False or inl_list is None)
        for r in sv.slot_range_as_union(slot_def):
            if r not in sv.all_classes():
                tname, source_class = get_rust_type(r, sv, True), None
            elif r in {"Anything", "AnyValue"} or is_inlined:
                # Anything/AnyValue are always inline so that promoted unions
                # include the corresponding variant.
                tname, source_class = get_rust_type(r, sv, True), r
            elif is_reference:
                # Explicitly non-inlined class ranges are referenced by identifier.
                tname, source_class = "String", None
            else:
                # Unknown inlining at this definition; skip adding a guess.
                continue
            rust_to_class[tname] = source_class
            if tname not in type_names:
                type_names.append(tname)

    container_mode, _ = determine_slot_mode(s, sv)
    if container_mode == SlotContainerMode.LIST:
        container_type = ContainerType.LIST
    elif container_mode == SlotContainerMode.MAPPING:
        container_type = ContainerType.MAPPING
    else:
        container_type = None

    # Optional unless every observed definition of the slot is required.
    base_optional = not all(bool(slot_def.required) for slot_def in slot_definitions)
    has_default = base_optional or (s.multivalued or False)
    primitive_types = ("String", "bool", "f64", "isize")

    if len(type_names) > 1:
        return RustRange(
            optional=base_optional,
            has_default=has_default,
            containerType=container_type,
            child_ranges=[RustRange(type_=t, is_class_range=t not in primitive_types) for t in type_names],
            is_class_range=False,
            is_reference=False,
            type_=underscore(uncamelcase(cls.name)) + "_utl::" + get_name(s) + "_range",
        )

    if not type_names:
        # Nothing was observed concretely (e.g. only class ranges with
        # unspecified inlining): fall back to the per-class range info.
        return get_rust_range_info(cls, s, sv)

    single = type_names[0]
    single_src_class = rust_to_class.get(single, None)
    return RustRange(
        optional=base_optional,
        has_default=has_default,
        containerType=container_type,
        child_ranges=None,
        is_class_range=single not in primitive_types,
        is_reference=False,
        has_class_subtypes=(
            has_real_subtypes(self.schemaview, single_src_class) if single_src_class is not None else False
        ),
        type_=single,
    )
def write_crate( self, output: Path | None = None, rendered: FileResult | CrateResult = None, force: bool = False ) -> str: output = self._validate_output(output, mode="crate", force=force) if rendered is None: rendered = self.render(mode="crate") cargo = rendered.cargo.render(self.template_environment) cargo_file = output / "Cargo.toml" self._write_text_file(cargo_file, cargo, crate_root=output) pyproject = rendered.pyproject.render(self.template_environment) pyproject_file = output / "pyproject.toml" self._write_text_file(pyproject_file, pyproject, crate_root=output) rust_file = rendered.file.render(self.template_environment) src_dir = output / "src" src_dir.mkdir(exist_ok=True) if self.handwritten_lib: generated_dir = src_dir / "generated" generated_dir.mkdir(exist_ok=True) lib_file = generated_dir / "mod.rs" else: generated_dir = src_dir lib_file = src_dir / "lib.rs" self._write_text_file(lib_file, rust_file, crate_root=output) for k, f in rendered.extra_files.items(): extra_file = f.render(self.template_environment) extra_file_name = f"{k}.rs" extra_file_path = self._safe_subpath(generated_dir, extra_file_name) self._write_text_file(extra_file_path, extra_file, crate_root=output) if getattr(rendered, "bin_files", None): for rel_path, template in rendered.bin_files.items(): rendered_bin = template.render(self.template_environment) safe_bin_base = self._safe_subpath(src_dir, rel_path) bin_path = safe_bin_base.with_suffix(".rs") self._write_text_file(bin_path, rendered_bin, crate_root=output) if self.handwritten_lib: shim_path = src_dir / "lib.rs" if not shim_path.exists(): root_struct_name = getattr(rendered.file, "root_struct_name", None) root_struct_fn_snake = underscore(uncamelcase(root_struct_name)) if root_struct_name else None shim_template = RustLibShim( module_name=rendered.file.name, pyo3=self.pyo3, serde=self.serde, stubgen=self.stubgen, handwritten_lib=self.handwritten_lib, root_struct_name=root_struct_name, root_struct_fn_snake=root_struct_fn_snake, 
) shim = shim_template.render(self.template_environment) self._write_text_file(shim_path, shim, crate_root=output) return rust_file def _validate_output(self, output: Path | None = None, mode: RUST_MODES | None = None, force: bool = False) -> Path: """Raise a ValueError if given a dir when in file mode or vice versa""" if output is None: if self.output is None: raise ValueError("Must provide an output if generator doesn't already have one") else: output = Path(self.output) else: output = Path(output) if mode == "file": assert output.suffix == ".rs", "Output must be a rust file in file mode" if not force and output.exists(): raise FileExistsError(f"{output} already exists and force is False! pass force=True to overwrite") output.parent.mkdir(exist_ok=True, parents=True) elif mode == "crate": if not force and len([d for d in output.iterdir()]) != 0: raise FileExistsError( f"{output} already exists, is not empty, and force is False! pass force=True to overwrite" ) output.mkdir(exist_ok=True, parents=True) else: raise ValueError(f"Invalid generation mode: {mode}") return output def _safe_subpath(self, base: Path, relative: str | Path) -> Path: """Return a path nested under base, validating it does not escape.""" rel_path = Path(relative) if rel_path.is_absolute(): raise ValueError(f"Relative path expected, got absolute path: {relative}") if not rel_path.parts: raise ValueError("Relative path must contain at least one segment") for part in rel_path.parts: if part in (".", ".."): raise ValueError(f"Invalid path segment: {part}") if "/" in part or "\\" in part: raise ValueError(f"Path segment must not contain separators: {part}") candidate = base / rel_path base_resolved = base.resolve() try: candidate.resolve().relative_to(base_resolved) except ValueError as exc: # pragma: no cover - defensive raise ValueError(f"Path {candidate} escapes base directory {base}") from exc return candidate def _write_text_file(self, path: Path, content: str, *, crate_root: Path) -> None: 
"""Normalize trailing newline, ensure parent dirs, and write text.""" base_resolved = crate_root.resolve() try: path.resolve().relative_to(base_resolved) except ValueError as exc: raise ValueError(f"Path {path} escapes crate root {crate_root}") from exc normalized = content.rstrip("\n") + "\n" path.parent.mkdir(parents=True, exist_ok=True) path.write_text(normalized) @property def template_environment(self) -> Environment: if self._environment is None: self._environment = RustTemplateModel.environment() return self._environment