Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 29 additions & 22 deletions flatdata-generator/flatdata/generator/generators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class BaseGenerator(metaclass=ABCMeta):

def __init__(self, template: str) -> None:
self._template = template
self._env: Environment | None = None

@abstractmethod
def supported_nodes(self) -> list[type]:
Expand All @@ -49,30 +50,36 @@ def get_import_directives(self, tree: SyntaxTree) -> list[str]:
"""Return language-specific import directives. Override in subclasses."""
return []

def _get_environment(self, tree: SyntaxTree) -> Environment:
if self._env is None:
env = Environment(loader=PackageLoader('flatdata.generator', 'templates'), lstrip_blocks=True,
trim_blocks=True, autoescape=False, extensions=[RaiseExtension])
env.filters['is_archive'] = lambda n: isinstance(n, Archive)
env.filters['is_instance'] = lambda n: isinstance(n, Instance)
env.filters['is_raw_data'] = lambda n: isinstance(n, RawData)
env.filters['is_archive_resource'] = lambda n: isinstance(
n, ArchiveResource)
env.filters['is_structure'] = lambda n: isinstance(n, Structure)
env.filters['is_enumeration_reference'] = lambda n: isinstance(n, EnumerationReference)
env.filters['is_enumeration'] = lambda n: isinstance(n, Enumeration)
env.filters['is_constant'] = lambda n: isinstance(n, Constant)
env.filters['is_namespace'] = lambda n: isinstance(n, Namespace)
env.filters['is_resource'] = lambda n: isinstance(n, ResourceBase)
env.filters['is_bound_resource'] = lambda n: isinstance(
n, BoundResource)
env.filters['is_vector'] = lambda n: isinstance(n, Vector)
env.filters['is_multivector'] = lambda n: isinstance(n, Multivector)
env.filters['is_multivector_index'] = lambda n: (isinstance(
n, Structure) and "_builtin.multivector" in SyntaxTree.namespace_path(n))
env.filters['namespaces'] = SyntaxTree.namespaces
env.filters['not_auto_generated'] = lambda n: [ x for x in n if not x.auto_generated]
self._populate_environment(env, tree)
self._env = env
return self._env

def render(self, tree: SyntaxTree) -> str:
"""Generate the language implementation from the AST"""
env = Environment(loader=PackageLoader('flatdata.generator', 'templates'), lstrip_blocks=True,
trim_blocks=True, autoescape=False, extensions=[RaiseExtension])
env.filters['is_archive'] = lambda n: isinstance(n, Archive)
env.filters['is_instance'] = lambda n: isinstance(n, Instance)
env.filters['is_raw_data'] = lambda n: isinstance(n, RawData)
env.filters['is_archive_resource'] = lambda n: isinstance(
n, ArchiveResource)
env.filters['is_structure'] = lambda n: isinstance(n, Structure)
env.filters['is_enumeration_reference'] = lambda n: isinstance(n, EnumerationReference)
env.filters['is_enumeration'] = lambda n: isinstance(n, Enumeration)
env.filters['is_constant'] = lambda n: isinstance(n, Constant)
env.filters['is_namespace'] = lambda n: isinstance(n, Namespace)
env.filters['is_resource'] = lambda n: isinstance(n, ResourceBase)
env.filters['is_bound_resource'] = lambda n: isinstance(
n, BoundResource)
env.filters['is_vector'] = lambda n: isinstance(n, Vector)
env.filters['is_multivector'] = lambda n: isinstance(n, Multivector)
env.filters['is_multivector_index'] = lambda n: (isinstance(
n, Structure) and "_builtin.multivector" in SyntaxTree.namespace_path(n))
env.filters['namespaces'] = SyntaxTree.namespaces
env.filters['not_auto_generated'] = lambda n: [ x for x in n if not x.auto_generated]
self._populate_environment(env, tree)
env = self._get_environment(tree)
template = env.get_template(self._template)

flatdata_nodes = [n for n, _ in DfsTraversal(tree).dependency_order() if
Expand Down
36 changes: 29 additions & 7 deletions flatdata-generator/flatdata/generator/tree/nodes/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@
_T = TypeVar('_T', bound='Node')


class _NodeNotFoundError(RuntimeError):
"""Lazy error that defers expensive symbols() computation to str()."""

def __init__(self, path: str, tree: 'Node') -> None:
self._path = path
self._tree = tree

def __str__(self) -> str:
return "Path '{path}' not found in tree. Options: {options}".format(
path=self._path, options=tuple(self._tree.symbols()))


class Node:
"""
Node of a Syntax Tree.
Expand Down Expand Up @@ -48,6 +60,7 @@ def __init__(self, name: str, properties: ParseResults | None = None) -> None:
self._parent: Node | None = None
self._source_file: str | None = None
self._is_local: bool = True
self._cached_path: str | None = None

@property
def source_file(self) -> str | None:
Expand Down Expand Up @@ -110,9 +123,17 @@ def path(self) -> str:
"""
Returns nodes' path in a tree.
"""
if self._parent is None:
return self.name
return Node.jointwo(self._parent.path, self.name)
if self._cached_path is None:
if self._parent is None:
self._cached_path = self.name
else:
self._cached_path = Node.jointwo(self._parent.path, self.name)
return self._cached_path

def _invalidate_path_cache(self) -> None:
self._cached_path = None
for child in self._children.values():
child._invalidate_path_cache()

def path_with(self, separator: str = '_') -> str:
"""
Expand Down Expand Up @@ -142,6 +163,7 @@ def set_name(self, value: str) -> None:
"Cannot rename the node, name {value} is already in use".format(value=value))

self._name = value
self._invalidate_path_cache()
if self.parent is not None:
self.parent.reindex()

Expand All @@ -157,14 +179,12 @@ def find(self, path: str) -> Node:
try:
target = self
if target.name != keys[0]:
raise RuntimeError("Path {path} not found in tree. Options: {options}".format(
path=path, options=tuple(self.symbols())))
raise _NodeNotFoundError(path, self)

for key in keys[1:]:
target = target._children[key]
except (KeyError, IndexError):
raise RuntimeError("Path '{path}' not found in tree. Options: {options}".format(
path=path, options=tuple(self.symbols())))
raise _NodeNotFoundError(path, self)
return target

def get(self, path: str, default: Node | None = None) -> Node | None:
Expand Down Expand Up @@ -251,6 +271,7 @@ def insert(self, *nodes: Node) -> Node:

self._children[node.name] = node
node._parent = self
node._invalidate_path_cache()
return self

def erase(self, key: str) -> None:
Expand Down Expand Up @@ -300,6 +321,7 @@ def detach(self) -> Node:
return self
del self._parent._children[self.name]
self._parent = None
self._invalidate_path_cache()
return self

def symbols(self, include_types: bool = False) -> set[str] | dict[str, type]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class Reference(Node):

def __init__(self, name: str) -> None:
super().__init__(name=Reference._referencify(name))
self._cached_node: Node | None = None

@property
def target(self) -> str:
Expand All @@ -23,11 +24,14 @@ def update_reference(self, new_value: str) -> None:
assert new_value.endswith(self.target), \
"References can only be updated during resolution for the same symbol: %s -> %s" % \
(self.target, new_value)
self._cached_node = None
self.set_name(Reference._referencify(new_value))

@property
def node(self) -> Node:
return self.root.find(self.target)
if self._cached_node is None:
self._cached_node = self.root.find(self.target)
return self._cached_node

@property
def is_qualified(self) -> bool:
Expand Down
12 changes: 9 additions & 3 deletions flatdata-generator/flatdata/generator/tree/syntax_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

from collections.abc import Iterator, Sequence
from typing import TYPE_CHECKING

from flatdata.generator.tree.nodes.references import TypeReference
from flatdata.generator.tree.nodes.trivial import Namespace
Expand All @@ -15,6 +16,9 @@
from flatdata.generator.tree.nodes.root import Root
from flatdata.generator.tree.importer import ImportInfo

if TYPE_CHECKING:
from flatdata.generator.generators import BaseGenerator

class SyntaxTree:
"""
Flatdata Syntax Tree.
Expand Down Expand Up @@ -99,14 +103,16 @@ def _unique(sequence: list[Node]) -> list[Node]:
nodes.extend(dependent_type)
return _unique(nodes)

_schema_generator: 'BaseGenerator | None' = None

@staticmethod
def schema(node: Node) -> str:
from ..generators.flatdata import FlatdataGenerator
generator = FlatdataGenerator()
if SyntaxTree._schema_generator is None:
SyntaxTree._schema_generator = FlatdataGenerator()

# extract subtree from syntax tree
subtree = node.extract_subtree()
return str(generator.render(SyntaxTree(subtree)))
return str(SyntaxTree._schema_generator.render(SyntaxTree(subtree)))

@staticmethod
def namespaces(node: Node) -> Iterator[Node]:
Expand Down
Loading