From f4e9cda9487977db75939f8c30991df6f8083e1d Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Fri, 7 Jan 2022 10:17:19 -0500 Subject: [PATCH 1/9] Initial support for flamegraph diffing --- graphtage/__init__.py | 2 +- graphtage/flamegraph.py | 146 ++++++++++++++++++++++++++++++++++++++++ test/test_formatting.py | 37 +++++++++- 3 files changed, 182 insertions(+), 3 deletions(-) create mode 100644 graphtage/flamegraph.py diff --git a/graphtage/__init__.py b/graphtage/__init__.py index 3bae418..7d9eb48 100644 --- a/graphtage/__init__.py +++ b/graphtage/__init__.py @@ -7,7 +7,7 @@ from .version import __version__, VERSION_STRING from . import bounds, edits, expressions, fibonacci, formatter, levenshtein, matching, printer, \ search, sequences, tree, utils -from . import csv, json, xml, yaml, plist +from . import csv, json, xml, yaml, plist, flamegraph import inspect diff --git a/graphtage/flamegraph.py b/graphtage/flamegraph.py new file mode 100644 index 0000000..2900e68 --- /dev/null +++ b/graphtage/flamegraph.py @@ -0,0 +1,146 @@ +"""A :class:`graphtage.Filetype` for parsing, diffing, and rendering `flame graphs`_. + +There are many libraries in different languages to produce a flame graph from a profiling run. +There unfortunately isn't a standardized textual file format to represent flame graphs. +Graphtage uses this common format: + +.. code-block:: none + + function1 #samples + function1;function2 #samples + function1;function2;function3 #samples + +In other words, each line of the file is a stack trace represented by a ``;``-delimited list of function names +followed by a space and the integer number of times that stack trace was sampled in the profiling run. + +.. _flame graphs: + https://www.brendangregg.com/flamegraphs.html +""" + +from typing import Iterable, List, Optional, Union + +from . import Printer +from .edits import EditSequence +from .graphtage import BuildOptions, Edit, Filetype, IntegerNode, ListNode, MultiSetNode, StringNode, TreeNode +from .sequences import SequenceFormatter +from .tree import GraphtageFormatter + + +class FlameGraphParseError(ValueError): + pass + + +class StackTrace(ListNode[StringNode]): + def __init__( + self, + functions: Iterable[StringNode], + samples: IntegerNode, + allow_list_edits: bool = True, + allow_list_edits_when_same_length: bool = True + ): + super().__init__(functions, allow_list_edits, allow_list_edits_when_same_length) + self.samples: IntegerNode = samples + + def to_obj(self): + return [n.to_obj() for n in self] + [self.samples] + + def edits(self, node: TreeNode) -> Edit: + # first, match the functions: + edit = super().edits(node) + if not isinstance(node, StackTrace) or self.samples == node.samples: + return edit + # now match the samples: + return EditSequence( + from_node=self, + to_node=node, + edits=(edit, self.samples.edits(node.samples)) + ) + + +class FlameGraph(MultiSetNode[StackTrace]): + pass + + +class StackTraceFormatter(SequenceFormatter): + is_partial = True + + def __init__(self): + super().__init__('', '', ';') + + def item_newline(self, printer: Printer, is_first: bool = False, is_last: bool = False): + pass + + def print(self, printer: Printer, *args, **kwargs): + # Flamegraphs are not indented + printer.indent_str = "" + super().print(printer, *args, **kwargs) + + def print_StackTrace(self, printer: Printer, node: StackTrace): + self.print_SequenceNode(printer, node) + printer.write(" ") + self.print(printer, node.samples) + + +class FlameGraphFormatter(SequenceFormatter): + sub_format_types = [StackTraceFormatter] + + def __init__(self): + super().__init__('', '', '') + + def print(self, printer: Printer, *args, **kwargs): + # Flamegraphs are not indented + printer.indent_str = "" + super().print(printer, *args, **kwargs) + + +class FlameGraphFile(Filetype): + """A textual representation of a flame graph.""" + def __init__(self): + """Initializes the FlameGraph file type. + + There is no official MIME type associated with a flame graph. Graphtage assigns it the MIME type + ``text/x-flame-graph``. + + """ + super().__init__( + 'flamegraph', + 'text/x-flame-graph' + ) + + def build_tree(self, path: str, options: Optional[BuildOptions] = None) -> FlameGraph: + traces: List[StackTrace] = [] + allow_list_edits = options is None or options.allow_list_edits + allow_list_edits_when_same_length = options is None or options.allow_list_edits_when_same_length + with open(path, "r") as f: + for n, line in enumerate(f): + line = line.strip() + if not line: + continue + # first parse the int off the end: + final_int = "" + while ord('0') <= ord(line[-1]) <= ord('9'): + final_int = f"{line[-1]}{final_int}" + line = line[:-1] + if not final_int: + raise FlameGraphParseError(f"{path}:{n+1} expected the line to end with an integer number of " + "samples") + samples = int(final_int) + functions = line.strip().split(";") + if not functions: + raise FlameGraphParseError(f"{path}:{n+1} the line did not contain a stack trace") + traces.append(StackTrace( + functions=(StringNode(f, quoted=False) for f in functions), + samples=IntegerNode(samples), + allow_list_edits=allow_list_edits, + allow_list_edits_when_same_length=allow_list_edits_when_same_length + )) + return FlameGraph(traces) + + def build_tree_handling_errors(self, path: str, options: Optional[BuildOptions] = None) -> Union[str, TreeNode]: + try: + return self.build_tree(path=path, options=options) + except FlameGraphParseError as e: + return str(e) + + def get_default_formatter(self) -> FlameGraphFormatter: + return FlameGraphFormatter.DEFAULT_INSTANCE diff --git a/test/test_formatting.py b/test/test_formatting.py index bc65098..e018ffb 100644 --- a/test/test_formatting.py +++ b/test/test_formatting.py @@ -84,12 +84,16 @@ def make_random_bool() -> bool: return random.choice([True, False]) @staticmethod - def make_random_str(exclude_bytes: FrozenSet[str] = frozenset(), allow_empty_strings: bool = True) -> str: + def make_random_str( + exclude_bytes: FrozenSet[str] = frozenset(), + allow_empty_strings: bool = True, + max_length: int = 128 + ) -> str: if allow_empty_strings: min_length = 0 else: min_length = 1 - return ''.join(random.choices(list(STR_BYTES - exclude_bytes), k=random.randint(min_length, 128))) + return ''.join(random.choices(list(STR_BYTES - exclude_bytes), k=random.randint(min_length, max_length))) @staticmethod def make_random_non_container(exclude_bytes: FrozenSet[str] = frozenset(), allow_empty_strings: bool = True): @@ -253,3 +257,32 @@ def test_yaml_formatting(self): def test_plist_formatting(self): orig_obj = TestFormatting.make_random_obj(force_string_keys=True, exclude_bytes=frozenset('<>/\n&?|@{}[]')) return orig_obj, plistlib.dumps(orig_obj) + + @staticmethod + def make_random_flamegraph() -> str: + num_traces = random.randint(1, 500) + traces = [] + for _ in range(num_traces): + num_functions = random.randint(1, 32) + functions = [ + TestFormatting.make_random_str( + max_length=32, + exclude_bytes=frozenset({'\n', ' ', '\t', ';', '\r'}), + allow_empty_strings=False + ) + for _ in range(num_functions) + ] + stack_trace = ";".join(functions) + assert "\n" not in stack_trace + assert sum(1 for c in stack_trace if c == ";") == num_functions - 1 + num_samples = random.randint(1, 10000) + traces.append(f"{stack_trace} {num_samples}\n") + return "".join(traces) + + @filetype_test + def test_flamegraph_formatting(self): + orig_obj = TestFormatting.make_random_flamegraph() + num_spaces = sum(1 for c in orig_obj if c == " ") + num_newlines = sum(1 for c in orig_obj if c == "\n") + self.assertEqual(num_spaces, num_newlines) + return orig_obj, orig_obj From 88311b62c08ba1b3b368fbc603a103c1b0a2f6ce Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Fri, 7 Jan 2022 11:27:43 -0500 Subject: [PATCH 2/9] Custom FlameGraph edit type --- graphtage/__main__.py | 8 +++- graphtage/flamegraph.py | 87 ++++++++++++++++++++++++++++++++--------- graphtage/graphtage.py | 29 ++++---------- graphtage/tree.py | 35 +++++++++++++++++ test/test_formatting.py | 5 ++- 5 files changed, 121 insertions(+), 43 deletions(-) diff --git a/graphtage/__main__.py b/graphtage/__main__.py index f8964c7..b815160 100644 --- a/graphtage/__main__.py +++ b/graphtage/__main__.py @@ -283,8 +283,12 @@ def printer_type(*pos_args, **kwargs): with printer: with PathOrStdin(args.FROM_PATH) as from_path: with PathOrStdin(args.TO_PATH) as to_path: - from_format = graphtage.get_filetype(from_path, from_mime) - to_format = graphtage.get_filetype(to_path, to_mime) + try: + from_format = graphtage.get_filetype(from_path, from_mime) + to_format = graphtage.get_filetype(to_path, to_mime) + except ValueError as e: + log.error(str(e)) + sys.exit(1) from_tree = from_format.build_tree_handling_errors(from_path, options) if isinstance(from_tree, str): sys.stderr.write(from_tree) diff --git a/graphtage/flamegraph.py b/graphtage/flamegraph.py index 2900e68..50818ac 100644 --- a/graphtage/flamegraph.py +++ b/graphtage/flamegraph.py @@ -17,20 +17,24 @@ https://www.brendangregg.com/flamegraphs.html """ -from typing import Iterable, List, Optional, Union +from typing import Iterable, List, Optional, Union, Iterator from . import Printer -from .edits import EditSequence -from .graphtage import BuildOptions, Edit, Filetype, IntegerNode, ListNode, MultiSetNode, StringNode, TreeNode -from .sequences import SequenceFormatter +from .edits import AbstractCompoundEdit, Match, Range, Replace +from .graphtage import ( + BuildOptions, ContainerNode, Edit, Filetype, IntegerNode, ListNode, MultiSetNode, StringNode, TreeNode +) from .tree import GraphtageFormatter +from .sequences import SequenceFormatter class FlameGraphParseError(ValueError): pass -class StackTrace(ListNode[StringNode]): +class StackTrace(ContainerNode): + """A stack trace and sample count""" + def __init__( self, functions: Iterable[StringNode], @@ -38,23 +42,70 @@ def __init__( allow_list_edits: bool = True, allow_list_edits_when_same_length: bool = True ): - super().__init__(functions, allow_list_edits, allow_list_edits_when_same_length) + """Initializes a stack trace. + + Args: + functions: the functions in the stack trace, in order. + samples: the number of times this stack trace was sampled in the profiling run. + """ + if samples.object < 0: + raise ValueError(f"Invalid number of samples: {samples.object}; the sample count must be non-negative") + self.functions: ListNode[StringNode] = ListNode( + functions, allow_list_edits, allow_list_edits_when_same_length + ) self.samples: IntegerNode = samples + def calculate_total_size(self) -> int: + return self.functions.calculate_total_size() + self.samples.calculate_total_size() + + def print(self, printer: Printer): + StackTraceFormatter.DEFAULT_INSTANCE.print(printer, self) + + def __iter__(self): + yield self.functions + yield self.samples + + def __len__(self) -> int: + return 2 + def to_obj(self): - return [n.to_obj() for n in self] + [self.samples] + return self.functions.to_obj() + [self.samples.to_obj()] def edits(self, node: TreeNode) -> Edit: - # first, match the functions: - edit = super().edits(node) - if not isinstance(node, StackTrace) or self.samples == node.samples: - return edit - # now match the samples: - return EditSequence( - from_node=self, - to_node=node, - edits=(edit, self.samples.edits(node.samples)) - ) + if self == node: + return Match(self, node, cost=0) + elif isinstance(node, StackTrace): + return StackTraceEdit(from_node=self, to_node=node) + else: + return Replace(self, node) + + def __str__(self): + return f"{';'.join((str(f.object) for f in self.functions))} {self.samples.object!s}" + + +class StackTraceEdit(AbstractCompoundEdit): + """An edit on a stack trace.""" + + def __init__(self, from_node: "StackTrace", to_node: "StackTrace"): + """Initializes a stack trace edit. + + Args: + from_node: The node being edited. + to_node: The node to which :obj:`from_node` will be transformed. + """ + self.functions_edit = from_node.functions.edits(to_node.functions) + self.samples_edit = from_node.samples.edits(to_node.samples) + super().__init__(from_node, to_node) + + def print(self, formatter: GraphtageFormatter, printer: Printer): + formatter.get_formatter(self.from_node)(printer, self.from_node) + + def bounds(self) -> Range: + return self.functions_edit.bounds() + self.samples_edit.bounds() + + def edits(self) -> Iterator[Edit]: + yield self.functions_edit + yield self.samples_edit class FlameGraph(MultiSetNode[StackTrace]): @@ -76,7 +127,7 @@ def print(self, printer: Printer, *args, **kwargs): super().print(printer, *args, **kwargs) def print_StackTrace(self, printer: Printer, node: StackTrace): - self.print_SequenceNode(printer, node) + self.print_SequenceNode(printer, node.functions) printer.write(" ") self.print(printer, node.samples) diff --git a/graphtage/graphtage.py b/graphtage/graphtage.py index 4323d35..1279688 100644 --- a/graphtage/graphtage.py +++ b/graphtage/graphtage.py @@ -237,27 +237,6 @@ def __lt__(self, other): return self.key < other return (self.key < other.key) or (self.key == other.key and self.value < other.value) - def __eq__(self, other): - """Tests whether this key/value pair equals another. - - Equivalent to:: - - isinstance(other, KeyValuePair) and self.key == other.key and self.value == other.value - - Args: - other: The object to test. - - Returns: - bool: :const:`True` if this key/value pair is equal to :obj:`other`. - - """ - if not isinstance(other, KeyValuePairNode): - return False - return self.key == other.key and self.value == other.value - - def __hash__(self): - return hash((self.key, self.value)) - def __len__(self): return 2 @@ -1031,6 +1010,14 @@ def get_filetype(path: Optional[str] = None, mime_type: Optional[str] = None) -> elif mime_type is None: mime_type = mimetypes.guess_type(path)[0] if mime_type is None: + # do non-MIME based filetype tests here + try: + ft = FILETYPES_BY_TYPENAME["flamegraph"] + _ = ft.build_tree(path) + # this is a valid flamegraph! + return ft + except: + pass raise ValueError(f"Could not determine the filetype for {path}") elif mime_type not in FILETYPES_BY_MIME: raise ValueError(f"Unsupported MIME type {mime_type} for {path}") diff --git a/graphtage/tree.py b/graphtage/tree.py index 5d9b2c8..dd28e38 100644 --- a/graphtage/tree.py +++ b/graphtage/tree.py @@ -192,6 +192,15 @@ def edits(self) -> Iterator[Edit]: """Returns an iterator over this edit's sub-edits""" raise NotImplementedError() + def is_complete(self) -> bool: + return all(e.is_complete() for e in self) + + def tighten_bounds(self) -> bool: + for e in self: + if e.tighten_bounds(): + return True + return False + def on_diff(self, from_node: 'EditedTreeNode'): """A callback for when an edit is assigned to an :class:`EditedTreeNode` in :meth:`TreeNode.diff`. @@ -546,6 +555,32 @@ def print(self, printer: Printer): class ContainerNode(TreeNode, Iterable, Sized, ABC): """A tree node that has children.""" + def __hash__(self): + """Hashes the contents of this container node. + + Equivalent to:: + + hash(tuple(self)) + + """ + return hash(tuple(self)) + + def __eq__(self, other): + """Tests whether this container node equals another. + + Equivalent to:: + + isinstance(other, ContainerNode) and all(a == b for a, b in zip(self, other)) + + Args: + other: The object to test. + + Returns: + bool: :const:`True` if this container node is equal to :obj:`other`. + + """ + return isinstance(other, ContainerNode) and all(a == b for a, b in zip(self, other)) + def children(self) -> List[TreeNode]: """The children of this node. diff --git a/test/test_formatting.py b/test/test_formatting.py index e018ffb..3bf9389 100644 --- a/test/test_formatting.py +++ b/test/test_formatting.py @@ -40,7 +40,8 @@ def wrapper(self: 'TestFormatting'): raise ValueError(f'@filetype_test {name} must end with "{FILETYPE_TEST_SUFFIX}"') filetype_name = name[len(FILETYPE_TEST_PREFIX):-len(FILETYPE_TEST_SUFFIX)] if filetype_name not in graphtage.FILETYPES_BY_TYPENAME: - raise ValueError(f'Filetype "{filetype_name}" for @filetype_test {name} not found in graphtage.FILETYPES_BY_TYPENAME') + raise ValueError(f'Filetype "{filetype_name}" for @filetype_test {name} not found in ' + 'graphtage.FILETYPES_BY_TYPENAME') filetype = graphtage.FILETYPES_BY_TYPENAME[filetype_name] formatter = filetype.get_default_formatter() @@ -279,7 +280,7 @@ def make_random_flamegraph() -> str: traces.append(f"{stack_trace} {num_samples}\n") return "".join(traces) - @filetype_test + @filetype_test(iterations=10) def test_flamegraph_formatting(self): orig_obj = TestFormatting.make_random_flamegraph() num_spaces = sum(1 for c in orig_obj if c == " ") From 98ba5a2830919d25c63b0cd5f0204176ff69d275 Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Fri, 7 Jan 2022 15:45:18 -0500 Subject: [PATCH 3/9] Added Flame Graph info to the README --- README.md | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 763df42..6582103 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,9 @@ [![Slack Status](https://empireslacking.herokuapp.com/badge.svg)](https://empireslacking.herokuapp.com) Graphtage is a command-line utility and [underlying library](https://trailofbits.github.io/graphtage/latest/library.html) -for semantically comparing and merging tree-like structures, such as JSON, XML, HTML, YAML, plist, and CSS files. Its name is a -portmanteau of “graph” and “graftage”—the latter being the horticultural practice of joining two trees together such -that they grow as one. +for semantically comparing and merging tree-like structures, such as JSON, XML, HTML, YAML, plist, CSS files, and +flame graphs. Its name is a portmanteau of “graph” and “graftage”—the latter being the horticultural practice of joining +two trees together such that they grow as one.

@@ -85,6 +85,32 @@ By default, Graphtage prints status messages and a progress bar to STDERR. To su option. To additionally suppress all but critical log messages, use `--quiet`. Fine-grained control of log messages is via the `--log-level` option. +### Specifying File Types + +By default, Graphtage makes a best-effort guess of the input file types based upon file extensions and, in some +cases, file contents. This is largely based off of the +[Python `mimetypes` library](https://docs.python.org/3/library/mimetypes.html#mimetypes.guess_type). + +The input files' mimetypes can be explicitly specified using the `--from-mime` and `--to-mime` arguments. + +#### Flame Graphs + +Graphtage has support for diffing +[flame graphs](https://trailofbits.github.io/graphtage/latest/graphtage.flamegraph.html). +This is useful to identify performance regressions between program refactors, _e.g._, when control flow is modified or +functions are added, removed, or renamed. + +There are many libraries in different languages to produce a flame graph from a profiling run. +There unfortunately isn't a standardized textual file format to represent flame graphs. +Graphtage uses this common format: +``` +function1 #samples +function1;function2 #samples +function1;function2;function3 #samples +``` +In other words, each line of the file is a stack trace represented by a ``;``-delimited list of function names +followed by a space and the integer number of times that stack trace was sampled in the profiling run. + ## Why does Graphtage exist? Diffing tree-like structures with unordered elements is tough. Say you want to compare two JSON files. From 3003391ee4d9660037fc9cbb23f2c2058bd029fd Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Fri, 7 Jan 2022 16:50:08 -0500 Subject: [PATCH 4/9] Print progress as the dense matching graph is built --- graphtage/matching.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/graphtage/matching.py b/graphtage/matching.py index 971f93a..27239c2 100644 --- a/graphtage/matching.py +++ b/graphtage/matching.py @@ -46,6 +46,7 @@ from .bounds import Bounded, make_distinct, Range, repeat_until_tightened from .bounds import sort as bounds_sort from .fibonacci import FibonacciHeap +from .printer import DEFAULT_PRINTER from .utils import smallest, largest @@ -613,7 +614,10 @@ def edges(self) -> List[List[Optional[Bounded]]]: """ if self._edges is None: self._edges = [ - [self.get_edge(from_node, to_node) for to_node in self.to_nodes] for from_node in self.from_nodes + [self.get_edge(from_node, to_node) for to_node in self.to_nodes] + for from_node in DEFAULT_PRINTER.tqdm( + self.from_nodes, leave=False, unit=" rows", desc="Building Matching Graph" + ) ] return self._edges From b84f138208e43eadde3273949298b4b0946b3766 Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Mon, 10 Jan 2022 08:46:13 -0500 Subject: [PATCH 5/9] Class cleanup --- graphtage/flamegraph.py | 37 ++++++++++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/graphtage/flamegraph.py b/graphtage/flamegraph.py index 50818ac..5aeb7bb 100644 --- a/graphtage/flamegraph.py +++ b/graphtage/flamegraph.py @@ -32,6 +32,27 @@ class FlameGraphParseError(ValueError): pass +class Samples(IntegerNode): + def __init__(self, num_samples: int, total_samples: int): + if total_samples <= 0: + raise ValueError("total_samples must be a positive integer") + elif num_samples < 0: + raise ValueError("num_samples must be non-negative") + super().__init__(num_samples) + self.total_samples: int = total_samples + + @property + def num_samples(self) -> int: + return self.object + + @property + def percent(self) -> float: + return self.num_samples / self.total_samples + + def calculate_total_size(self) -> int: + return int(self.percent * 100.0 + 0.5) + + class StackTrace(ContainerNode): """A stack trace and sample count""" @@ -61,6 +82,13 @@ def calculate_total_size(self) -> int: def print(self, printer: Printer): StackTraceFormatter.DEFAULT_INSTANCE.print(printer, self) + def __eq__(self, other): + """Two stack traces are the same if their functions exactly match (regardless of their sample count)""" + return isinstance(other, StackTrace) and self.functions == other.functions + + def __hash__(self): + return hash(self.functions) + def __iter__(self): yield self.functions yield self.samples @@ -72,7 +100,7 @@ def to_obj(self): return self.functions.to_obj() + [self.samples.to_obj()] def edits(self, node: TreeNode) -> Edit: - if self == node: + if self == node and self.samples == node.samples: return Match(self, node, cost=0) elif isinstance(node, StackTrace): return StackTraceEdit(from_node=self, to_node=node) @@ -82,6 +110,9 @@ def edits(self, node: TreeNode) -> Edit: def __str__(self): return f"{';'.join((str(f.object) for f in self.functions))} {self.samples.object!s}" + def __repr__(self): + return f"{self.__class__.__name__}({self.functions!r}, {self.samples!r})" + class StackTraceEdit(AbstractCompoundEdit): """An edit on a stack trace.""" @@ -172,13 +203,13 @@ def build_tree(self, path: str, options: Optional[BuildOptions] = None) -> Flame while ord('0') <= ord(line[-1]) <= ord('9'): final_int = f"{line[-1]}{final_int}" line = line[:-1] + if not line: + break if not final_int: raise FlameGraphParseError(f"{path}:{n+1} expected the line to end with an integer number of " "samples") samples = int(final_int) functions = line.strip().split(";") - if not functions: - raise FlameGraphParseError(f"{path}:{n+1} the line did not contain a stack trace") traces.append(StackTrace( functions=(StringNode(f, quoted=False) for f in functions), samples=IntegerNode(samples), From 3873ad236681c866f197845c20093f6714bda7cb Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Mon, 10 Jan 2022 08:50:45 -0500 Subject: [PATCH 6/9] Ensure that the caches are calculated in all cases --- graphtage/matching.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/graphtage/matching.py b/graphtage/matching.py index 27239c2..8d72eee 100644 --- a/graphtage/matching.py +++ b/graphtage/matching.py @@ -623,6 +623,8 @@ def edges(self) -> List[List[Optional[Bounded]]]: def bounds(self) -> Range: if self._bounds is None: + # first calculate and cache the edges to do equality optimizations: + _ = self.edges if not self.from_nodes or not self.to_nodes: lb = ub = 0 elif self._match is None: From 3ea29d180adf442bcadcc46a8c64098db684b9d6 Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Mon, 10 Jan 2022 13:44:41 -0500 Subject: [PATCH 7/9] Print progress while making bounds distinct --- graphtage/bounds.py | 113 +++++++++++++++++++++++++------------------- 1 file changed, 65 insertions(+), 48 deletions(-) diff --git a/graphtage/bounds.py b/graphtage/bounds.py index ecfc622..9fa2306 100644 --- a/graphtage/bounds.py +++ b/graphtage/bounds.py @@ -30,6 +30,7 @@ from intervaltree import Interval, IntervalTree from .fibonacci import FibonacciHeap +from .printer import DEFAULT_PRINTER log = logging.getLogger(__name__) @@ -197,7 +198,7 @@ def definitive(self) -> bool: """ return self.lower_bound == self.upper_bound and not isinstance(self.lower_bound, Infinity) - def intersect(self, other) -> 'Range': + def intersect(self, other) -> "Range": """Intersects this range with another.""" if not self or not other or self < other or other < self: return Range() @@ -388,51 +389,67 @@ def make_distinct(*bounded: Bounded): if not b.bounds().finite: raise ValueError(f"Could not tighten {b!r} to a finite bound") tree.add(Interval(b.bounds().lower_bound, b.bounds().upper_bound + 1, b)) - while len(tree) > 1: - # find the biggest interval in the tree - biggest: Optional[Interval] = None - for m in tree: - m_size = m.end - m.begin - if biggest is None or m_size > biggest.end - biggest.begin: - biggest = m - assert biggest is not None - if biggest.data.bounds().definitive(): - # This means that all intervals are points, so we are done! - break - tree.remove(biggest) - matching = tree[biggest.begin:biggest.end] - if len(matching) < 1: - # This interval does not intersect any others, so it is distinct - continue - # now find the biggest other interval that intersects with biggest: - second_biggest: Optional[Interval] = None - for m in matching: - m_size = m.end - m.begin - if second_biggest is None or m_size > second_biggest.end - second_biggest.begin: - second_biggest = m - assert second_biggest is not None - tree.remove(second_biggest) - # Shrink the two biggest intervals until they are distinct - while True: - biggest_bound: Range = biggest.data.bounds() - second_biggest_bound: Range = second_biggest.data.bounds() - if (biggest_bound.definitive() and second_biggest_bound.definitive()) or \ - biggest_bound.upper_bound < second_biggest_bound.lower_bound or \ - second_biggest_bound.upper_bound < biggest_bound.lower_bound: + last_tree_len = len(tree) - 1 + with DEFAULT_PRINTER.tqdm( + desc="Making Bounds Distinct", unit=" bounds", leave=False, total=last_tree_len, delay=2.0) as d: + while len(tree) > 1: + remaining_nodes = len(tree) - 1 + if remaining_nodes < last_tree_len: + d.update(last_tree_len - remaining_nodes) + # find the biggest interval in the tree + biggest: Optional[Interval] = None + for m in tree: + m_size = m.end - m.begin + if biggest is None or m_size > biggest.end - biggest.begin: + biggest = m + assert biggest is not None + if biggest.data.bounds().definitive(): + # This means that all intervals are points, so we are done! break - biggest.data.tighten_bounds() - second_biggest.data.tighten_bounds() - new_interval = Interval( - begin=biggest.data.bounds().lower_bound, - end=biggest.data.bounds().upper_bound + 1, - data=biggest.data - ) - if tree.overlaps(new_interval.begin, new_interval.end): - tree.add(new_interval) - new_interval = Interval( - begin=second_biggest.data.bounds().lower_bound, - end=second_biggest.data.bounds().upper_bound + 1, - data=second_biggest.data - ) - if tree.overlaps(new_interval.begin, new_interval.end): - tree.add(new_interval) + tree.remove(biggest) + matching = tree[biggest.begin:biggest.end] + if len(matching) < 1: + # This interval does not intersect any others, so it is distinct + continue + # now find the biggest other interval that intersects with biggest: + second_biggest: Optional[Interval] = None + for m in matching: + m_size = m.end - m.begin + if second_biggest is None or m_size > second_biggest.end - second_biggest.begin: + second_biggest = m + assert second_biggest is not None + tree.remove(second_biggest) + # Shrink the two biggest intervals until they are distinct + with DEFAULT_PRINTER.tqdm(desc="Tightening Bounding Intervals", delay=2.0, leave=False, unit=" units") as t: + last_overlap: Optional[int] = None + while True: + biggest_bound: Range = biggest.data.bounds() + second_biggest_bound: Range = second_biggest.data.bounds() + if (biggest_bound.definitive() and second_biggest_bound.definitive()) or \ + biggest_bound.upper_bound < second_biggest_bound.lower_bound or \ + second_biggest_bound.upper_bound < biggest_bound.lower_bound: + break + # the ranges still overlap + overlap = min(biggest_bound.upper_bound, second_biggest_bound.upper_bound) - \ + max(biggest_bound.lower_bound, second_biggest_bound.lower_bound) + if last_overlap is None: + t.total = overlap + elif overlap < last_overlap: + t.update(last_overlap - overlap) + last_overlap = overlap + biggest.data.tighten_bounds() + second_biggest.data.tighten_bounds() + new_interval = Interval( + begin=biggest.data.bounds().lower_bound, + end=biggest.data.bounds().upper_bound + 1, + data=biggest.data + ) + if tree.overlaps(new_interval.begin, new_interval.end): + tree.add(new_interval) + new_interval = Interval( + begin=second_biggest.data.bounds().lower_bound, + end=second_biggest.data.bounds().upper_bound + 1, + data=second_biggest.data + ) + if tree.overlaps(new_interval.begin, new_interval.end): + tree.add(new_interval) From de282cef26030d2808049af1abf784997fee557a Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Mon, 10 Jan 2022 13:46:35 -0500 Subject: [PATCH 8/9] Adds a SetNode in addition to the MultiSet --- graphtage/graphtage.py | 57 ++++++++++++++++++++++++++++-- graphtage/multiset.py | 80 +++++++++++++++++++++++++++++------------- 2 files changed, 110 insertions(+), 27 deletions(-) diff --git a/graphtage/graphtage.py b/graphtage/graphtage.py index 1279688..b332bb8 100644 --- a/graphtage/graphtage.py +++ b/graphtage/graphtage.py @@ -2,13 +2,14 @@ import mimetypes from abc import ABC, ABCMeta, abstractmethod -from typing import Any, Collection, Dict, Generic, Iterable, Iterator, List, Optional, Tuple, Type, TypeVar, Union +from collections import OrderedDict +from typing import Any, Collection, Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union from .bounds import Range from .edits import AbstractEdit, EditCollection from .edits import Insert, Match, Remove, Replace, AbstractCompoundEdit from .levenshtein import EditDistance, levenshtein_distance -from .multiset import MultiSetEdit +from .multiset import MultiSetEdit, SetEdit from .printer import Back, Fore, NullANSIContext, Printer from .sequences import FixedLengthSequenceEdit, SequenceEdit, SequenceNode from .tree import ContainerNode, Edit, GraphtageFormatter, TreeNode @@ -314,6 +315,58 @@ def edits(self, node: TreeNode) -> Edit: return Replace(self, node) +class SetNode(SequenceNode[List[T]], Generic[T]): + def __init__(self, nodes: Iterable[T] = ()): + if isinstance(nodes, set): + super().__init__(list(nodes)) + self._node_set: Set[T] = nodes + else: + ordered_nodes = list(OrderedDict.fromkeys(nodes)) + self._node_set = set(ordered_nodes) + super().__init__(ordered_nodes) + + @property + def container_type(self) -> Type[List[T]]: + return list + + def to_obj(self): + return list(self) + + def add(self, node: T): + if node not in self._node_set: + self._node_set.add(node) + self._children.append(node) + + def __xor__(self, other) -> "SetNode[T]": + if not isinstance(other, SetNode): + raise NotImplementedError() + return SetNode(self._node_set ^ other._node_set) + + def __and__(self, other) -> "SetNode[T]": + if not isinstance(other, SetNode): + raise NotImplementedError() + return SetNode(self._node_set & other._node_set) + + def __or__(self, other) -> "SetNode[T]": + if not isinstance(other, SetNode): + raise NotImplementedError() + return SetNode(self._node_set | other._node_set) + + def __contains__(self, node: T): + return node in self._node_set + + def edits(self, node: 'TreeNode') -> Edit: + if isinstance(node, SetNode): + if len(self._children) == len(node._children) == 0: + return Match(self, node, 0) + elif self._node_set == node._node_set: + return Match(self, node, 0) + else: + return SetEdit(self, node, self._node_set, node._node_set) + else: + return Replace(self, node) + + class MultiSetNode(SequenceNode[HashableCounter[T]], Generic[T]): """A node representing a set that can contain duplicate items.""" diff --git a/graphtage/multiset.py b/graphtage/multiset.py index e855fd1..ca3472b 100644 --- a/graphtage/multiset.py +++ b/graphtage/multiset.py @@ -5,7 +5,8 @@ """ -from typing import Iterator, List +from abc import ABC +from typing import Collection, Generic, Iterable, Iterator, List, Set, TypeVar from .bounds import Range from .edits import Insert, Match, Remove @@ -15,19 +16,23 @@ from .utils import HashableCounter, largest -class MultiSetEdit(SequenceEdit): +T = TypeVar("T", bound=Collection[TreeNode]) + + +class AbstractSetEdit(SequenceEdit, Generic[T], ABC): """An edit matching one unordered collection of items to another. - It works by using a :class:`graphtage.matching.WeightedBipartiteMatcher` to find the minimum cost matching from - the elements of one collection to the elements of the other. + It works by using a :class:`graphtage.matching.WeightedBipartiteMatcher` to find the minimum cost matching from + the elements of one collection to the elements of the other. + + """ - """ def __init__( self, from_node: SequenceNode, to_node: SequenceNode, - from_set: HashableCounter[TreeNode], - to_set: HashableCounter[TreeNode] + from_set: T, + to_set: T ): """Initializes the edit. @@ -40,15 +45,15 @@ def __init__( is neither checked nor enforced. """ - self.to_insert = to_set - from_set + self.to_insert: T = to_set - from_set """The set of nodes in :obj:`to_set` that do not exist in :obj:`from_set`.""" - self.to_remove = from_set - to_set + self.to_remove: T = from_set - to_set """The set of nodes in :obj:`from_set` that do not exist in :obj:`to_set`.""" to_match = from_set & to_set - self._edits: List[Edit] = [Match(n, n, 0) for n in to_match.elements()] + self._edits: List[Edit] = [Match(n, n, 0) for n in self.__class__.get_elements(to_match)] self._matcher = WeightedBipartiteMatcher( - from_nodes=self.to_remove.elements(), - to_nodes=self.to_insert.elements(), + from_nodes=self.__class__.get_elements(self.to_remove), + to_nodes=self.__class__.get_elements(self.to_insert), get_edge=lambda f, t: f.edits(t) ) super().__init__( @@ -56,22 +61,13 @@ def __init__( to_node=to_node ) + @classmethod + def get_elements(cls, collection: T) -> Iterable[TreeNode]: + return collection + def is_complete(self) -> bool: return self._matcher.is_complete() - def edits(self) -> Iterator[Edit]: - yield from self._edits - remove_matched: HashableCounter[TreeNode] = HashableCounter() - insert_matched: HashableCounter[TreeNode] = HashableCounter() - for (rem, (ins, edit)) in self._matcher.matching.items(): - yield edit - remove_matched[rem] += 1 - insert_matched[ins] += 1 - for rm in (self.to_remove - remove_matched).elements(): - yield Remove(to_remove=rm, remove_from=self.from_node) - for ins in (self.to_insert - insert_matched).elements(): - yield Insert(to_insert=ins, insert_into=self.from_node) - def tighten_bounds(self) -> bool: """Delegates to :meth:`WeightedBipartiteMatcher.tighten_bounds`.""" return self._matcher.tighten_bounds() @@ -93,3 +89,37 @@ def bounds(self) -> Range: ): b = b + edit.bounds() return b + + +class SetEdit(AbstractSetEdit[Set[TreeNode]]): + def edits(self) -> Iterator[Edit]: + yield from self._edits + remove_matched: Set[TreeNode] = set() + insert_matched: Set[TreeNode] = set() + for (rem, (ins, edit)) in self._matcher.matching.items(): + yield edit + remove_matched.add(rem) + insert_matched.add(ins) + for rm in (self.to_remove - remove_matched): + yield Remove(to_remove=rm, remove_from=self.from_node) + for ins in (self.to_insert - insert_matched): + yield Insert(to_insert=ins, insert_into=self.from_node) + + +class MultiSetEdit(AbstractSetEdit[HashableCounter[TreeNode]]): + @classmethod + def get_elements(cls, collection: HashableCounter[TreeNode]) -> Iterable[TreeNode]: + return collection.elements() + + def edits(self) -> Iterator[Edit]: + yield from self._edits + remove_matched: HashableCounter[TreeNode] = HashableCounter() + insert_matched: HashableCounter[TreeNode] = HashableCounter() + for (rem, (ins, edit)) in self._matcher.matching.items(): + yield edit + remove_matched[rem] += 1 + insert_matched[ins] += 1 + for rm in (self.to_remove - remove_matched).elements(): + yield Remove(to_remove=rm, remove_from=self.from_node) + for ins in (self.to_insert - insert_matched).elements(): + yield Insert(to_insert=ins, insert_into=self.from_node) From 497a72cc27f2630755aa7f12304a7b52366d8392 Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Mon, 10 Jan 2022 13:54:31 -0500 Subject: [PATCH 9/9] Make SequenceNodes subscriptable --- graphtage/sequences.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/graphtage/sequences.py b/graphtage/sequences.py index 7fcdde4..e138270 100644 --- a/graphtage/sequences.py +++ b/graphtage/sequences.py @@ -143,6 +143,9 @@ def __len__(self) -> int: """ return len(self._children) + def __getitem__(self, index) -> TreeNode: + return self._children[index] + def __iter__(self) -> Iterator[TreeNode]: """Iterates over this sequence's child nodes.