Source code for linkml.generators.erdiagramgen

import logging
import os
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

import click
import pydantic

from linkml._version import __version__
from linkml.utils.generator import Generator, shared_arguments
from linkml_runtime.linkml_model.meta import ClassDefinition, ClassDefinitionName, SlotDefinition
from linkml_runtime.utils.formatutils import camelcase, underscore
from linkml_runtime.utils.schemaview import SchemaView

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Alias for the serialized diagram type (a plain ``str``); used as the return
# annotation of the generator's ``serialize`` and ``serialize_classes`` methods.
MERMAID_SERIALIZATION = str


class Attribute(pydantic.BaseModel):
    """A single attribute line inside an ER-diagram entity.

    ``__str__`` renders the line as ``<datatype> <name> <key> "<comment>"``
    with a four-space indent; the key and the quoted comment are left blank
    when they are unset or empty.
    """

    datatype: str = "string"
    # These fields default to None, so the annotations must admit None.
    name: Optional[str] = None
    key: Optional[str] = None
    comment: Optional[str] = None

    def __str__(self):
        # Only quote the comment when there is one; an empty/missing comment
        # contributes nothing but the separating space.
        cmt = f'"{self.comment}"' if self.comment else ""
        return f"    {self.datatype} {self.name} {self.key if self.key else ''} {cmt}"


class IdentifyingType(str, Enum):
    """Line style joining the two ends of a relationship.

    Each member's value is the literal connector text written between the two
    cardinality markers.

    See: https://mermaid.js.org/syntax/entityRelationshipDiagram.html#identification
    """

    IDENTIFYING = "--"
    NON_IDENTIFYING = ".."

    def __str__(self):
        # Render as the bare connector text rather than "IdentifyingType.X".
        return str(self.value)


class Cardinality(pydantic.BaseModel):
    """Two-character cardinality marker for one end of a relationship.

    The marker combines an existence character (``|`` when required, ``o``
    otherwise) with a multiplicity character (a brace when multivalued, ``|``
    otherwise). ``is_left`` selects both the brace direction and the order in
    which the two characters are emitted.

    See https://mermaid.js.org/syntax/entityRelationshipDiagram.html#relationship-syntax
    """

    required: bool = True
    multivalued: bool = False
    is_left: bool = True

    def __str__(self):
        existence = "|" if self.required else "o"
        if not self.multivalued:
            multiplicity = "|"
        elif self.is_left:
            multiplicity = "}"
        else:
            multiplicity = "{"
        # Left-hand markers put the multiplicity first; right-hand markers
        # are the mirror image.
        if self.is_left:
            return multiplicity + existence
        return existence + multiplicity


class RelationshipType(pydantic.BaseModel):
    """The connector between two entities: left marker, line style, right marker.

    ``__str__`` concatenates the three parts with no separators.

    See https://mermaid.js.org/syntax/entityRelationshipDiagram.html#relationship-syntax
    """

    left_cardinality: Cardinality
    identifying: IdentifyingType = IdentifyingType.IDENTIFYING
    # Defaults to None, so the annotation must admit None. Note that an unset
    # right cardinality is rendered by __str__ as the literal text "None".
    right_cardinality: Optional[Cardinality] = None

    def __str__(self):
        return f"{self.left_cardinality}{self.identifying}{self.right_cardinality}"


class Entity(pydantic.BaseModel):
    """A named entity (box) in an ER diagram together with its attribute lines."""

    name: str
    attributes: list[Attribute] = []

    def __str__(self):
        # One rendered attribute per line, wrapped in "<name> { ... }".
        body = "\n".join(map(str, self.attributes))
        return f"{self.name} {{\n{body}\n}}"


class Relationship(pydantic.BaseModel):
    """A labelled edge between two entities in an ER diagram.

    ``__str__`` renders ``<first> <type> <second> : "<label>"``.
    """

    first_entity: str
    # The remaining fields default to None, so their annotations must admit
    # None. Unset fields are rendered by __str__ as the literal text "None".
    relationship_type: Optional[RelationshipType] = None
    second_entity: Optional[str] = None
    relationship_label: Optional[str] = None

    def __str__(self):
        return f'{self.first_entity} {self.relationship_type} {self.second_entity} : "{self.relationship_label}"'


class ERDiagram(pydantic.BaseModel):
    """A diagram made of entities and the relationships between them."""

    entities: list[Entity] = []
    relationships: list[Relationship] = []

    def _collapse_relationships(self, relationships: list[Relationship]) -> list[Relationship]:
        """
        Merge parallel relationships into one relationship with a combined label.

        Relationships are considered parallel when they share the same first
        entity, second entity and rendered relationship type, so relationships
        whose cardinalities differ are never merged. The merged label is the
        sorted, comma-separated list of the non-empty labels in the group.

        :param relationships: List of relationships
        :return: List of collapsed relationships, in order of first appearance
            of each group
        """
        # (first_entity, second_entity, rendered type) -> members of that group
        grouped = {}
        for relationship in relationships:
            type_key = str(relationship.relationship_type) if relationship.relationship_type else ""
            group_key = (relationship.first_entity, relationship.second_entity, type_key)
            grouped.setdefault(group_key, []).append(relationship)

        result = []
        for members in grouped.values():
            if len(members) == 1:
                # Nothing to merge; keep the original object untouched.
                result.append(members[0])
                continue

            representative = members[0]
            merged_label = ", ".join(sorted(m.relationship_label for m in members if m.relationship_label))
            result.append(
                Relationship(
                    first_entity=representative.first_entity,
                    relationship_type=representative.relationship_type,
                    second_entity=representative.second_entity,
                    relationship_label=merged_label,
                )
            )
        return result

    def __str__(self):
        # Entities are ordered by name and relationships by their rendered
        # text so that the output is stable between runs.
        entity_blocks = [str(entity) for entity in sorted(self.entities, key=lambda entity: entity.name)]
        relationship_lines = sorted(str(rel) for rel in self._collapse_relationships(self.relationships))

        ents = "\n".join(entity_blocks)
        rels = "\n".join(relationship_lines)
        return f"erDiagram\n{ents}\n\n{rels}\n"


@dataclass
class ERDiagramGenerator(Generator):
    """
    A generator for serializing schemas as Entity-Relationship diagrams.

    Currently this generates diagrams in mermaid syntax, but in future
    this could easily be extended to have for example a direct SVG or PNG
    generator using PyGraphViz, similar to erdantic.
    """

    # ClassVars
    generatorname = os.path.basename(__file__)
    generatorversion = "0.0.1"
    valid_formats = ["markdown", "mermaid"]
    uses_schemaloader = False
    requires_metamodel = False

    structural: bool = True
    """If True, then only the tree_root and entities reachable from the root are drawn"""

    exclude_attributes: bool = False
    """If True, do not include attributes in entities"""

    exclude_abstract_classes: bool = False
    """If True, do not include abstract classes in the diagram"""

    genmeta: bool = False
    gen_classvars: bool = True
    gen_slots: bool = True
    no_types_dir: bool = False
    use_slot_uris: bool = False

    preserve_names: bool = False
    """If true, preserve LinkML element names (classes/slots) in diagram labels."""

    def __post_init__(self):
        self.schemaview = SchemaView(self.schema)
        super().__post_init__()

    # Mermaid ERD reserved keywords that conflict with entity names
    # See: https://mermaid.js.org/syntax/entityRelationshipDiagram.html
    # "Class" confirmed as reserved; others may exist but are undocumented
    _MERMAID_ERD_RESERVED_KEYWORDS = {"Class"}

    def _sanitize_class_name_for_erd(self, name: str) -> str:
        """
        Sanitize class name for ERD diagram output.

        Mermaid ERD syntax has reserved keywords that conflict with common class names.
        This method prefixes problematic names with '__' to avoid conflicts.

        :param name: Original class name
        :return: Sanitized class name safe for ERD diagrams
        """
        if name in self._MERMAID_ERD_RESERVED_KEYWORDS:
            return f"__{name}"
        return name

    def serialize(self) -> MERMAID_SERIALIZATION:
        """
        Serialize a schema as an ER Diagram.

        If a tree_root is present in the schema (and ``structural`` is set),
        then only Entities traversable from here will be included. Otherwise,
        all Entities will be included.

        :return: mermaid string
        """
        sv = self.schemaview
        structural_roots = [cn for cn, c in sv.all_classes().items() if c.tree_root]
        if self.structural and structural_roots:
            return self.serialize_classes(structural_roots, follow_references=True)

        diagram = ERDiagram()
        for cn in sv.all_classes():
            self.add_class(cn, diagram)
        return self.serialize_diagram(diagram)

    def serialize_classes(
        self,
        class_names: list[Union[str, ClassDefinitionName]],
        follow_references=False,
        max_hops: Optional[int] = None,
        include_upstream: bool = False,
    ) -> MERMAID_SERIALIZATION:
        """
        Serialize a list of classes as an ER Diagram.

        This will also traverse the reference graph and include any Entities
        reachable from the specified classes.

        By default, all reachable Entities are included, unless max_hops is specified.

        :param class_names: initial seed
        :param follow_references: if True, follow references even if not inlined
        :param max_hops: maximum number of hops to follow references
        :param include_upstream: if True, also add classes that were not reached
            by the traversal but list a schema-level slot whose range is one of
            ``class_names``, together with their relationships to those classes
        :return: mermaid string
        """
        visited = set()
        sv = self.schemaview
        stack = [(cn, 0) for cn in class_names]
        diagram = ERDiagram()
        while stack:
            cn, depth = stack.pop()
            if cn in visited:
                continue
            self.add_class(cn, diagram)
            visited.add(cn)
            # Classes at the hop limit are drawn but not expanded further.
            if max_hops is not None and depth >= max_hops:
                continue
            for slot in sv.class_induced_slots(cn):
                rng = slot.range
                if rng in sv.all_classes():
                    if follow_references or sv.is_inlined(slot):
                        if rng not in visited:
                            stack.append((rng, depth + 1))

        # Now Add upstream classes if needed
        if include_upstream:
            targets = set(class_names)
            for sn in sv.all_slots():
                slot = sv.schema.slots.get(sn)
                if slot and slot.range in targets:
                    for cl in sv.all_classes():
                        if slot.name in sv.get_class(cl).slots and cl not in visited:
                            self.add_upstream_class(cl, targets, diagram)
                            # add_upstream_class already draws every relationship
                            # from this class to the targets, so mark the class as
                            # handled; otherwise a class with several matching
                            # slots would be added once per slot, duplicating its
                            # entity and its relationships.
                            visited.add(cl)

        return self.serialize_diagram(diagram)

    def serialize_diagram(self, diagram: ERDiagram) -> str:
        """
        Serialize an ER Diagram, from the native pydantic schema.

        When ``self.format`` is "markdown" the mermaid text is wrapped in a
        fenced ``mermaid`` code block; otherwise it is returned as is.

        :param diagram: ER Diagram
        :return: the rendered diagram text
        """
        er = str(diagram)
        if self.format == "markdown":
            return f"```mermaid\n{er}\n```\n"
        return er

    def add_upstream_class(self, class_name: ClassDefinitionName, targets: set[str], diagram: ERDiagram) -> None:
        """
        Add a class that refers to one of the target classes.

        The class is added as an entity without attributes; only those of its
        induced slots whose range is in ``targets`` are drawn, as relationships.

        :param class_name: name of the referring class
        :param targets: names of the classes being referred to
        :param diagram: ER Diagram
        :return:
        """
        sv = self.schemaview
        cls = sv.get_class(class_name)
        if self.exclude_abstract_classes and cls.abstract:
            return
        entity_name = cls.name if self.preserve_names else camelcase(cls.name)
        entity = Entity(name=self._sanitize_class_name_for_erd(entity_name))
        diagram.entities.append(entity)
        for slot in sv.class_induced_slots(class_name):
            if slot.range in targets:
                self.add_relationship(entity, slot, diagram)

    def _create_attribute_sort_key(self, slot: SlotDefinition, cls: ClassDefinition):
        """
        Create a sort key for attributes for stable, deterministic output.

        Attributes are sorted by:

        1. Source type (own/direct slots first, then inherited)
        2. Priority slots (id, name, description)
        3. Rank (if defined)
        4. Alphabetically

        :param slot: The slot to create a sort key for
        :param cls: The class this slot belongs to
        :return: Sort key tuple
        """
        sv = self.schemaview
        priority_slots = ["id", "name", "description"]

        # Determine if slot is inherited or own
        # A slot is inherited if it's in domain_of of a parent class AND not overridden
        # in the current class's slot_usage (which makes it "own")
        is_inherited = False
        domain_of = slot.domain_of if hasattr(slot, "domain_of") and slot.domain_of else []

        # Check if slot is explicitly customized in current class's slot_usage
        slot_overridden_in_current = cls.slot_usage and slot.name in cls.slot_usage

        if domain_of and hasattr(cls, "name") and cls.name and not slot_overridden_in_current:
            # Use class_ancestors for efficiency (includes is_a and mixins)
            for ancestor in sv.class_ancestors(cls.name, reflexive=False):
                if ancestor in domain_of:
                    is_inherited = True
                    break

        # Primary sort: own/direct (0) vs inherited (1)
        primary_sort = 1 if is_inherited else 0

        # Secondary sort: priority slots come first, in their listed order
        if slot.name in priority_slots:
            secondary_sort = priority_slots.index(slot.name)
        else:
            secondary_sort = len(priority_slots)

        # Tertiary sort: by rank if exists; unranked slots sort after ranked ones
        rank = getattr(slot, "rank", None)
        rank_value = rank if rank is not None else float("inf")

        # Quaternary sort: alphabetically
        alphabetical_sort = slot.name

        return (primary_sort, secondary_sort, rank_value, alphabetical_sort)

    def add_class(self, class_name: ClassDefinitionName, diagram: ERDiagram) -> None:
        """
        Add a class to the ER Diagram.

        The class becomes an entity. Slots whose range is a class become
        relationships; all other slots become attributes of the entity.
        Nothing is added for an abstract class when ``exclude_abstract_classes``
        is set.

        :param class_name: ClassDefinitionName
        :param diagram: ER Diagram
        :return:
        """
        sv = self.schemaview
        cls = sv.get_class(class_name)
        if self.exclude_abstract_classes and cls.abstract:
            return
        entity_name = cls.name if self.preserve_names else camelcase(cls.name)
        entity = Entity(name=self._sanitize_class_name_for_erd(entity_name))
        diagram.entities.append(entity)

        # Collect all slots first (to enable sorting before adding)
        all_slots = {}  # slot_name -> SlotDefinition
        relationship_slots = []
        attribute_slots = []

        # Process regular induced slots
        for slot in sv.class_induced_slots(class_name):
            # TODO: schemaview should infer this
            if slot.range is None:
                slot.range = sv.schema.default_range or "string"
            all_slots[slot.name] = slot

        # Also process slots defined in slot_usage (which may not appear in induced_slots)
        # This includes slot_usage from the current class and all ancestors (including mixins)
        slot_usage_to_check = set()

        # Collect slot_usage from current class
        if cls.slot_usage:
            slot_usage_to_check.update(cls.slot_usage.keys())

        # Collect slot_usage from all ancestor classes (includes is_a and mixins)
        for ancestor_name in sv.class_ancestors(class_name, reflexive=False):
            ancestor_cls = sv.get_class(ancestor_name)
            if ancestor_cls and ancestor_cls.slot_usage:
                slot_usage_to_check.update(ancestor_cls.slot_usage.keys())

        # Process all slot_usage entries
        for slot_name in slot_usage_to_check:
            if slot_name not in all_slots:
                # Get the induced slot to get the properly overridden range
                try:
                    slot = sv.induced_slot(slot_name, class_name)
                    if slot.range is None:
                        slot.range = sv.schema.default_range or "string"
                    all_slots[slot_name] = slot
                except (KeyError, AttributeError, ValueError) as e:
                    # Slot might not exist, have missing attributes, or not be a valid induced slot, skip it
                    logger.debug("Could not get induced slot '%s' for class '%s': %s", slot_name, class_name, e)

        # Separate into relationships and attributes
        for slot in all_slots.values():
            if slot.range in sv.all_classes():
                relationship_slots.append(slot)
            else:
                attribute_slots.append(slot)

        # Sort attribute slots using the same logic as the markdown data dictionary
        attribute_slots.sort(key=lambda s: self._create_attribute_sort_key(s, cls))

        # Add sorted attributes to entity
        for slot in attribute_slots:
            self.add_attribute(entity, slot)

        # Add relationships (order doesn't matter as much for relationships)
        for slot in relationship_slots:
            self.add_relationship(entity, slot, diagram)

    def add_relationship(self, entity: Entity, slot: SlotDefinition, diagram: ERDiagram) -> None:
        """
        Add a relationship to the ER Diagram.

        The relationship runs from ``entity`` to the class named by the slot's
        range and is labelled with the slot name.

        :param entity: Entity the relationship starts from
        :param slot: SlotDefinition whose range is the target class
        :param diagram: ER Diagram
        :return:
        """
        sv = self.schemaview
        rel_type = RelationshipType(
            right_cardinality=Cardinality(
                required=slot.required is True, multivalued=slot.multivalued is True, is_left=True
            ),
            left_cardinality=Cardinality(is_left=False),
        )
        range_class_name = sv.get_class(slot.range).name
        second_entity_name = range_class_name if self.preserve_names else camelcase(range_class_name)
        rel = Relationship(
            first_entity=entity.name,
            relationship_type=rel_type,
            second_entity=self._sanitize_class_name_for_erd(second_entity_name),
            relationship_label=slot.name,
        )
        diagram.relationships.append(rel)

    def add_attribute(self, entity: Entity, slot: SlotDefinition) -> None:
        """
        Add an attribute to an entity of the ER Diagram.

        Does nothing when ``exclude_attributes`` is set.

        :param entity: Entity that receives the attribute
        :param slot: SlotDefinition
        :return:
        """
        if self.exclude_attributes:
            return
        dt = slot.range
        if slot.multivalued:
            # NOTE: mermaid does not support []s or *s in attribute types
            dt = f"{dt}List"
        attr = Attribute(name=(slot.name if self.preserve_names else underscore(slot.name)), datatype=dt)
        entity.attributes.append(attr)
@shared_arguments(ERDiagramGenerator)
@click.option(
    "--structural/--no-structural",
    default=True,
    help="If True, then only the tree_root and entities reachable from the root are drawn",
)
@click.option(
    "--exclude-attributes/--no-exclude-attributes",
    default=False,
    help="If True, do not include attributes in entities",
)
@click.option(
    "--exclude-abstract-classes/--no-exclude-abstract-classes",
    default=False,
    help="If True, do not include abstract classes in the diagram",
)
@click.option(
    "--follow-references/--no-follow-references",
    default=False,
    help="If True, follow references even if not inlined",
)
@click.option("--max-hops", default=None, type=click.INT, help="Maximum number of hops")
@click.option("--classes", "-c", multiple=True, help="List of classes to serialize")
@click.option("--include-upstream", is_flag=True, help="Include upstream classes")
@click.version_option(__version__, "-V", "--version")
@click.command(name="erdiagram")
def cli(
    yamlfile,
    classes: list[str],
    max_hops: Optional[int],
    follow_references: bool,
    include_upstream: bool = False,
    **args,
):
    """Generate a mermaid ER diagram from a schema.

    By default, all entities traversable from the tree_root are included. If no tree_root is
    present, then all entities are included.

    To create an ER diagram for selected classes, use the --classes option.
    """
    # Remaining options (e.g. --structural, --exclude-attributes) arrive in
    # ``args`` and are forwarded to the generator unchanged.
    generator = ERDiagramGenerator(yamlfile, **args)
    if classes:
        rendered = generator.serialize_classes(
            classes,
            follow_references=follow_references,
            max_hops=max_hops,
            include_upstream=include_upstream,
        )
    else:
        rendered = generator.serialize()
    print(rendered)


if __name__ == "__main__":
    cli()