diff --git a/README.md b/README.md index 763df42..6582103 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,9 @@ [![Slack Status](https://empireslacking.herokuapp.com/badge.svg)](https://empireslacking.herokuapp.com) Graphtage is a command-line utility and [underlying library](https://trailofbits.github.io/graphtage/latest/library.html) -for semantically comparing and merging tree-like structures, such as JSON, XML, HTML, YAML, plist, and CSS files. Its name is a -portmanteau of “graph” and “graftage”—the latter being the horticultural practice of joining two trees together such -that they grow as one. +for semantically comparing and merging tree-like structures, such as JSON, XML, HTML, YAML, plist, CSS files, and +flame graphs. Its name is a portmanteau of “graph” and “graftage”—the latter being the horticultural practice of joining +two trees together such that they grow as one.

@@ -85,6 +85,32 @@ By default, Graphtage prints status messages and a progress bar to STDERR. To su option. To additionally suppress all but critical log messages, use `--quiet`. Fine-grained control of log messages is via the `--log-level` option. +### Specifying File Types + +By default, Graphtage makes a best-effort guess of the input file types based upon file extensions and, in some +cases, file contents. This is largely based off of the +[Python `mimetypes` library](https://docs.python.org/3/library/mimetypes.html#mimetypes.guess_type). + +The input files' mimetypes can be explicitly specified using the `--from-mime` and `--to-mime` arguments. + +#### Flame Graphs + +Graphtage has support for diffing +[flame graphs](https://trailofbits.github.io/graphtage/latest/graphtage.flamegraph.html). +This is useful to identify performance regressions between program refactors, _e.g._, when control flow is modified or +functions are added, removed, or renamed. + +There are many libraries in different languages to produce a flame graph from a profiling run. +There unfortunately isn't a standardized textual file format to represent flame graphs. +Graphtage uses this common format: +``` +function1 #samples +function1;function2 #samples +function1;function2;function3 #samples +``` +In other words, each line of the file is a stack trace represented by a ``;``-delimited list of function names +followed by a space and the integer number of times that stack trace was sampled in the profiling run. + ## Why does Graphtage exist? Diffing tree-like structures with unordered elements is tough. Say you want to compare two JSON files. diff --git a/graphtage/__init__.py b/graphtage/__init__.py index 3bae418..7d9eb48 100644 --- a/graphtage/__init__.py +++ b/graphtage/__init__.py @@ -7,7 +7,7 @@ from .version import __version__, VERSION_STRING from . import bounds, edits, expressions, fibonacci, formatter, levenshtein, matching, printer, \ search, sequences, tree, utils -from . import csv, json, xml, yaml, plist +from . import csv, json, xml, yaml, plist, flamegraph import inspect diff --git a/graphtage/__main__.py b/graphtage/__main__.py index f8964c7..b815160 100644 --- a/graphtage/__main__.py +++ b/graphtage/__main__.py @@ -283,8 +283,12 @@ def printer_type(*pos_args, **kwargs): with printer: with PathOrStdin(args.FROM_PATH) as from_path: with PathOrStdin(args.TO_PATH) as to_path: - from_format = graphtage.get_filetype(from_path, from_mime) - to_format = graphtage.get_filetype(to_path, to_mime) + try: + from_format = graphtage.get_filetype(from_path, from_mime) + to_format = graphtage.get_filetype(to_path, to_mime) + except ValueError as e: + log.error(str(e)) + sys.exit(1) from_tree = from_format.build_tree_handling_errors(from_path, options) if isinstance(from_tree, str): sys.stderr.write(from_tree) diff --git a/graphtage/bounds.py b/graphtage/bounds.py index ecfc622..9fa2306 100644 --- a/graphtage/bounds.py +++ b/graphtage/bounds.py @@ -30,6 +30,7 @@ from intervaltree import Interval, IntervalTree from .fibonacci import FibonacciHeap +from .printer import DEFAULT_PRINTER log = logging.getLogger(__name__) @@ -197,7 +198,7 @@ def definitive(self) -> bool: """ return self.lower_bound == self.upper_bound and not isinstance(self.lower_bound, Infinity) - def intersect(self, other) -> 'Range': + def intersect(self, other) -> "Range": """Intersects this range with another.""" if not self or not other or self < other or other < self: return Range() @@ -388,51 +389,67 @@ def make_distinct(*bounded: Bounded): if not b.bounds().finite: raise ValueError(f"Could not tighten {b!r} to a finite bound") tree.add(Interval(b.bounds().lower_bound, b.bounds().upper_bound + 1, b)) - while len(tree) > 1: - # find the biggest interval in the tree - biggest: Optional[Interval] = None - for m in tree: - m_size = m.end - m.begin - if biggest is None or m_size > biggest.end - biggest.begin: - biggest = m - assert biggest is not None - if biggest.data.bounds().definitive(): - # This means that all intervals are points, so we are done! - break - tree.remove(biggest) - matching = tree[biggest.begin:biggest.end] - if len(matching) < 1: - # This interval does not intersect any others, so it is distinct - continue - # now find the biggest other interval that intersects with biggest: - second_biggest: Optional[Interval] = None - for m in matching: - m_size = m.end - m.begin - if second_biggest is None or m_size > second_biggest.end - second_biggest.begin: - second_biggest = m - assert second_biggest is not None - tree.remove(second_biggest) - # Shrink the two biggest intervals until they are distinct - while True: - biggest_bound: Range = biggest.data.bounds() - second_biggest_bound: Range = second_biggest.data.bounds() - if (biggest_bound.definitive() and second_biggest_bound.definitive()) or \ - biggest_bound.upper_bound < second_biggest_bound.lower_bound or \ - second_biggest_bound.upper_bound < biggest_bound.lower_bound: + last_tree_len = len(tree) - 1 + with DEFAULT_PRINTER.tqdm( + desc="Making Bounds Distinct", unit=" bounds", leave=False, total=last_tree_len, delay=2.0) as d: + while len(tree) > 1: + remaining_nodes = len(tree) - 1 + if remaining_nodes < last_tree_len: + d.update(last_tree_len - remaining_nodes) + # find the biggest interval in the tree + biggest: Optional[Interval] = None + for m in tree: + m_size = m.end - m.begin + if biggest is None or m_size > biggest.end - biggest.begin: + biggest = m + assert biggest is not None + if biggest.data.bounds().definitive(): + # This means that all intervals are points, so we are done! break - biggest.data.tighten_bounds() - second_biggest.data.tighten_bounds() - new_interval = Interval( - begin=biggest.data.bounds().lower_bound, - end=biggest.data.bounds().upper_bound + 1, - data=biggest.data - ) - if tree.overlaps(new_interval.begin, new_interval.end): - tree.add(new_interval) - new_interval = Interval( - begin=second_biggest.data.bounds().lower_bound, - end=second_biggest.data.bounds().upper_bound + 1, - data=second_biggest.data - ) - if tree.overlaps(new_interval.begin, new_interval.end): - tree.add(new_interval) + tree.remove(biggest) + matching = tree[biggest.begin:biggest.end] + if len(matching) < 1: + # This interval does not intersect any others, so it is distinct + continue + # now find the biggest other interval that intersects with biggest: + second_biggest: Optional[Interval] = None + for m in matching: + m_size = m.end - m.begin + if second_biggest is None or m_size > second_biggest.end - second_biggest.begin: + second_biggest = m + assert second_biggest is not None + tree.remove(second_biggest) + # Shrink the two biggest intervals until they are distinct + with DEFAULT_PRINTER.tqdm(desc="Tightening Bounding Intervals", delay=2.0, leave=False, unit=" units") as t: + last_overlap: Optional[int] = None + while True: + biggest_bound: Range = biggest.data.bounds() + second_biggest_bound: Range = second_biggest.data.bounds() + if (biggest_bound.definitive() and second_biggest_bound.definitive()) or \ + biggest_bound.upper_bound < second_biggest_bound.lower_bound or \ + second_biggest_bound.upper_bound < biggest_bound.lower_bound: + break + # the ranges still overlap + overlap = min(biggest_bound.upper_bound, second_biggest_bound.upper_bound) - \ + max(biggest_bound.lower_bound, second_biggest_bound.lower_bound) + if last_overlap is None: + t.total = overlap + elif overlap < last_overlap: + t.update(last_overlap - overlap) + last_overlap = overlap + biggest.data.tighten_bounds() + second_biggest.data.tighten_bounds() + new_interval = Interval( + begin=biggest.data.bounds().lower_bound, + end=biggest.data.bounds().upper_bound + 1, + data=biggest.data + ) + if tree.overlaps(new_interval.begin, new_interval.end): + tree.add(new_interval) + new_interval = Interval( + begin=second_biggest.data.bounds().lower_bound, + end=second_biggest.data.bounds().upper_bound + 1, + data=second_biggest.data + ) + if tree.overlaps(new_interval.begin, new_interval.end): + tree.add(new_interval) diff --git a/graphtage/flamegraph.py b/graphtage/flamegraph.py new file mode 100644 index 0000000..5aeb7bb --- /dev/null +++ b/graphtage/flamegraph.py @@ -0,0 +1,228 @@ +"""A :class:`graphtage.Filetype` for parsing, diffing, and rendering `flame graphs`_. + +There are many libraries in different languages to produce a flame graph from a profiling run. +There unfortunately isn't a standardized textual file format to represent flame graphs. +Graphtage uses this common format: + +.. code-block:: none + + function1 #samples + function1;function2 #samples + function1;function2;function3 #samples + +In other words, each line of the file is a stack trace represented by a ``;``-delimited list of function names +followed by a space and the integer number of times that stack trace was sampled in the profiling run. + +.. _flame graphs: + https://www.brendangregg.com/flamegraphs.html +""" + +from typing import Iterable, List, Optional, Union, Iterator + +from . import Printer +from .edits import AbstractCompoundEdit, Match, Range, Replace +from .graphtage import ( + BuildOptions, ContainerNode, Edit, Filetype, IntegerNode, ListNode, MultiSetNode, StringNode, TreeNode +) +from .tree import GraphtageFormatter +from .sequences import SequenceFormatter + + +class FlameGraphParseError(ValueError): + pass + + +class Samples(IntegerNode): + def __init__(self, num_samples: int, total_samples: int): + if total_samples <= 0: + raise ValueError("total_samples must be a positive integer") + elif num_samples < 0: + raise ValueError("num_samples must be non-negative") + super().__init__(num_samples) + self.total_samples: int = total_samples + + @property + def num_samples(self) -> int: + return self.object + + @property + def percent(self) -> float: + return self.num_samples / self.total_samples + + def calculate_total_size(self) -> int: + return int(self.percent * 100.0 + 0.5) + + +class StackTrace(ContainerNode): + """A stack trace and sample count""" + + def __init__( + self, + functions: Iterable[StringNode], + samples: IntegerNode, + allow_list_edits: bool = True, + allow_list_edits_when_same_length: bool = True + ): + """Initializes a stack trace. + + Args: + functions: the functions in the stack trace, in order. + samples: the number of times this stack trace was sampled in the profiling run. + """ + if samples.object < 0: + raise ValueError(f"Invalid number of samples: {samples.object}; the sample count must be non-negative") + self.functions: ListNode[StringNode] = ListNode( + functions, allow_list_edits, allow_list_edits_when_same_length + ) + self.samples: IntegerNode = samples + + def calculate_total_size(self) -> int: + return self.functions.calculate_total_size() + self.samples.calculate_total_size() + + def print(self, printer: Printer): + StackTraceFormatter.DEFAULT_INSTANCE.print(printer, self) + + def __eq__(self, other): + """Two stack traces are the same if their functions exactly match (regardless of their sample count)""" + return isinstance(other, StackTrace) and self.functions == other.functions + + def __hash__(self): + return hash(self.functions) + + def __iter__(self): + yield self.functions + yield self.samples + + def __len__(self) -> int: + return 2 + + def to_obj(self): + return self.functions.to_obj() + [self.samples.to_obj()] + + def edits(self, node: TreeNode) -> Edit: + if self == node and self.samples == node.samples: + return Match(self, node, cost=0) + elif isinstance(node, StackTrace): + return StackTraceEdit(from_node=self, to_node=node) + else: + return Replace(self, node) + + def __str__(self): + return f"{';'.join((str(f.object) for f in self.functions))} {self.samples.object!s}" + + def __repr__(self): + return f"{self.__class__.__name__}({self.functions!r}, {self.samples!r})" + + +class StackTraceEdit(AbstractCompoundEdit): + """An edit on a stack trace.""" + + def __init__(self, from_node: "StackTrace", to_node: "StackTrace"): + """Initializes a stack trace edit. + + Args: + from_node: The node being edited. + to_node: The node to which :obj:`from_node` will be transformed. + """ + self.functions_edit = from_node.functions.edits(to_node.functions) + self.samples_edit = from_node.samples.edits(to_node.samples) + super().__init__(from_node, to_node) + + def print(self, formatter: GraphtageFormatter, printer: Printer): + formatter.get_formatter(self.from_node)(printer, self.from_node) + + def bounds(self) -> Range: + return self.functions_edit.bounds() + self.samples_edit.bounds() + + def edits(self) -> Iterator[Edit]: + yield self.functions_edit + yield self.samples_edit + + +class FlameGraph(MultiSetNode[StackTrace]): + pass + + +class StackTraceFormatter(SequenceFormatter): + is_partial = True + + def __init__(self): + super().__init__('', '', ';') + + def item_newline(self, printer: Printer, is_first: bool = False, is_last: bool = False): + pass + + def print(self, printer: Printer, *args, **kwargs): + # Flamegraphs are not indented + printer.indent_str = "" + super().print(printer, *args, **kwargs) + + def print_StackTrace(self, printer: Printer, node: StackTrace): + self.print_SequenceNode(printer, node.functions) + printer.write(" ") + self.print(printer, node.samples) + + +class FlameGraphFormatter(SequenceFormatter): + sub_format_types = [StackTraceFormatter] + + def __init__(self): + super().__init__('', '', '') + + def print(self, printer: Printer, *args, **kwargs): + # Flamegraphs are not indented + printer.indent_str = "" + super().print(printer, *args, **kwargs) + + +class FlameGraphFile(Filetype): + """A textual representation of a flame graph.""" + def __init__(self): + """Initializes the FlameGraph file type. + + There is no official MIME type associated with a flame graph. Graphtage assigns it the MIME type + ``text/x-flame-graph``. + + """ + super().__init__( + 'flamegraph', + 'text/x-flame-graph' + ) + + def build_tree(self, path: str, options: Optional[BuildOptions] = None) -> FlameGraph: + traces: List[StackTrace] = [] + allow_list_edits = options is None or options.allow_list_edits + allow_list_edits_when_same_length = options is None or options.allow_list_edits_when_same_length + with open(path, "r") as f: + for n, line in enumerate(f): + line = line.strip() + if not line: + continue + # first parse the int off the end: + final_int = "" + while ord('0') <= ord(line[-1]) <= ord('9'): + final_int = f"{line[-1]}{final_int}" + line = line[:-1] + if not line: + break + if not final_int: + raise FlameGraphParseError(f"{path}:{n+1} expected the line to end with an integer number of " + "samples") + samples = int(final_int) + functions = line.strip().split(";") + traces.append(StackTrace( + functions=(StringNode(f, quoted=False) for f in functions), + samples=IntegerNode(samples), + allow_list_edits=allow_list_edits, + allow_list_edits_when_same_length=allow_list_edits_when_same_length + )) + return FlameGraph(traces) + + def build_tree_handling_errors(self, path: str, options: Optional[BuildOptions] = None) -> Union[str, TreeNode]: + try: + return self.build_tree(path=path, options=options) + except FlameGraphParseError as e: + return str(e) + + def get_default_formatter(self) -> FlameGraphFormatter: + return FlameGraphFormatter.DEFAULT_INSTANCE diff --git a/graphtage/graphtage.py b/graphtage/graphtage.py index 4323d35..b332bb8 100644 --- a/graphtage/graphtage.py +++ b/graphtage/graphtage.py @@ -2,13 +2,14 @@ import mimetypes from abc import ABC, ABCMeta, abstractmethod -from typing import Any, Collection, Dict, Generic, Iterable, Iterator, List, Optional, Tuple, Type, TypeVar, Union +from collections import OrderedDict +from typing import Any, Collection, Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union from .bounds import Range from .edits import AbstractEdit, EditCollection from .edits import Insert, Match, Remove, Replace, AbstractCompoundEdit from .levenshtein import EditDistance, levenshtein_distance -from .multiset import MultiSetEdit +from .multiset import MultiSetEdit, SetEdit from .printer import Back, Fore, NullANSIContext, Printer from .sequences import FixedLengthSequenceEdit, SequenceEdit, SequenceNode from .tree import ContainerNode, Edit, GraphtageFormatter, TreeNode @@ -237,27 +238,6 @@ def __lt__(self, other): return self.key < other return (self.key < other.key) or (self.key == other.key and self.value < other.value) - def __eq__(self, other): - """Tests whether this key/value pair equals another. - - Equivalent to:: - - isinstance(other, KeyValuePair) and self.key == other.key and self.value == other.value - - Args: - other: The object to test. - - Returns: - bool: :const:`True` if this key/value pair is equal to :obj:`other`. - - """ - if not isinstance(other, KeyValuePairNode): - return False - return self.key == other.key and self.value == other.value - - def __hash__(self): - return hash((self.key, self.value)) - def __len__(self): return 2 @@ -335,6 +315,58 @@ def edits(self, node: TreeNode) -> Edit: return Replace(self, node) +class SetNode(SequenceNode[List[T]], Generic[T]): + def __init__(self, nodes: Iterable[T] = ()): + if isinstance(nodes, set): + super().__init__(list(nodes)) + self._node_set: Set[T] = nodes + else: + ordered_nodes = list(OrderedDict.fromkeys(nodes)) + self._node_set = set(ordered_nodes) + super().__init__(ordered_nodes) + + @property + def container_type(self) -> Type[List[T]]: + return list + + def to_obj(self): + return list(self) + + def add(self, node: T): + if node not in self._node_set: + self._node_set.add(node) + self._children.append(node) + + def __xor__(self, other) -> "SetNode[T]": + if not isinstance(other, SetNode): + raise NotImplementedError() + return SetNode(self._node_set ^ other._node_set) + + def __and__(self, other) -> "SetNode[T]": + if not isinstance(other, SetNode): + raise NotImplementedError() + return SetNode(self._node_set & other._node_set) + + def __or__(self, other) -> "SetNode[T]": + if not isinstance(other, SetNode): + raise NotImplementedError() + return SetNode(self._node_set | other._node_set) + + def __contains__(self, node: T): + return node in self._node_set + + def edits(self, node: 'TreeNode') -> Edit: + if isinstance(node, SetNode): + if len(self._children) == len(node._children) == 0: + return Match(self, node, 0) + elif self._node_set == node._node_set: + return Match(self, node, 0) + else: + return SetEdit(self, node, self._node_set, node._node_set) + else: + return Replace(self, node) + + class MultiSetNode(SequenceNode[HashableCounter[T]], Generic[T]): """A node representing a set that can contain duplicate items.""" @@ -1031,6 +1063,14 @@ def get_filetype(path: Optional[str] = None, mime_type: Optional[str] = None) -> elif mime_type is None: mime_type = mimetypes.guess_type(path)[0] if mime_type is None: + # do non-MIME based filetype tests here + try: + ft = FILETYPES_BY_TYPENAME["flamegraph"] + _ = ft.build_tree(path) + # this is a valid flamegraph! + return ft + except: + pass raise ValueError(f"Could not determine the filetype for {path}") elif mime_type not in FILETYPES_BY_MIME: raise ValueError(f"Unsupported MIME type {mime_type} for {path}") diff --git a/graphtage/matching.py b/graphtage/matching.py index 971f93a..8d72eee 100644 --- a/graphtage/matching.py +++ b/graphtage/matching.py @@ -46,6 +46,7 @@ from .bounds import Bounded, make_distinct, Range, repeat_until_tightened from .bounds import sort as bounds_sort from .fibonacci import FibonacciHeap +from .printer import DEFAULT_PRINTER from .utils import smallest, largest @@ -613,12 +614,17 @@ def edges(self) -> List[List[Optional[Bounded]]]: """ if self._edges is None: self._edges = [ - [self.get_edge(from_node, to_node) for to_node in self.to_nodes] for from_node in self.from_nodes + [self.get_edge(from_node, to_node) for to_node in self.to_nodes] + for from_node in DEFAULT_PRINTER.tqdm( + self.from_nodes, leave=False, unit=" rows", desc="Building Matching Graph" + ) ] return self._edges def bounds(self) -> Range: if self._bounds is None: + # first calculate and cache the edges to do equality optimizations: + _ = self.edges if not self.from_nodes or not self.to_nodes: lb = ub = 0 elif self._match is None: diff --git a/graphtage/multiset.py b/graphtage/multiset.py index e855fd1..ca3472b 100644 --- a/graphtage/multiset.py +++ b/graphtage/multiset.py @@ -5,7 +5,8 @@ """ -from typing import Iterator, List +from abc import ABC +from typing import Collection, Generic, Iterable, Iterator, List, Set, TypeVar from .bounds import Range from .edits import Insert, Match, Remove @@ -15,19 +16,23 @@ from .utils import HashableCounter, largest -class MultiSetEdit(SequenceEdit): +T = TypeVar("T", bound=Collection[TreeNode]) + + +class AbstractSetEdit(SequenceEdit, Generic[T], ABC): """An edit matching one unordered collection of items to another. - It works by using a :class:`graphtage.matching.WeightedBipartiteMatcher` to find the minimum cost matching from - the elements of one collection to the elements of the other. + It works by using a :class:`graphtage.matching.WeightedBipartiteMatcher` to find the minimum cost matching from + the elements of one collection to the elements of the other. + + """ - """ def __init__( self, from_node: SequenceNode, to_node: SequenceNode, - from_set: HashableCounter[TreeNode], - to_set: HashableCounter[TreeNode] + from_set: T, + to_set: T ): """Initializes the edit. @@ -40,15 +45,15 @@ def __init__( is neither checked nor enforced. """ - self.to_insert = to_set - from_set + self.to_insert: T = to_set - from_set """The set of nodes in :obj:`to_set` that do not exist in :obj:`from_set`.""" - self.to_remove = from_set - to_set + self.to_remove: T = from_set - to_set """The set of nodes in :obj:`from_set` that do not exist in :obj:`to_set`.""" to_match = from_set & to_set - self._edits: List[Edit] = [Match(n, n, 0) for n in to_match.elements()] + self._edits: List[Edit] = [Match(n, n, 0) for n in self.__class__.get_elements(to_match)] self._matcher = WeightedBipartiteMatcher( - from_nodes=self.to_remove.elements(), - to_nodes=self.to_insert.elements(), + from_nodes=self.__class__.get_elements(self.to_remove), + to_nodes=self.__class__.get_elements(self.to_insert), get_edge=lambda f, t: f.edits(t) ) super().__init__( @@ -56,22 +61,13 @@ def __init__( to_node=to_node ) + @classmethod + def get_elements(cls, collection: T) -> Iterable[TreeNode]: + return collection + def is_complete(self) -> bool: return self._matcher.is_complete() - def edits(self) -> Iterator[Edit]: - yield from self._edits - remove_matched: HashableCounter[TreeNode] = HashableCounter() - insert_matched: HashableCounter[TreeNode] = HashableCounter() - for (rem, (ins, edit)) in self._matcher.matching.items(): - yield edit - remove_matched[rem] += 1 - insert_matched[ins] += 1 - for rm in (self.to_remove - remove_matched).elements(): - yield Remove(to_remove=rm, remove_from=self.from_node) - for ins in (self.to_insert - insert_matched).elements(): - yield Insert(to_insert=ins, insert_into=self.from_node) - def tighten_bounds(self) -> bool: """Delegates to :meth:`WeightedBipartiteMatcher.tighten_bounds`.""" return self._matcher.tighten_bounds() @@ -93,3 +89,37 @@ def bounds(self) -> Range: ): b = b + edit.bounds() return b + + +class SetEdit(AbstractSetEdit[Set[TreeNode]]): + def edits(self) -> Iterator[Edit]: + yield from self._edits + remove_matched: Set[TreeNode] = set() + insert_matched: Set[TreeNode] = set() + for (rem, (ins, edit)) in self._matcher.matching.items(): + yield edit + remove_matched.add(rem) + insert_matched.add(ins) + for rm in (self.to_remove - remove_matched): + yield Remove(to_remove=rm, remove_from=self.from_node) + for ins in (self.to_insert - insert_matched): + yield Insert(to_insert=ins, insert_into=self.from_node) + + +class MultiSetEdit(AbstractSetEdit[HashableCounter[TreeNode]]): + @classmethod + def get_elements(cls, collection: HashableCounter[TreeNode]) -> Iterable[TreeNode]: + return collection.elements() + + def edits(self) -> Iterator[Edit]: + yield from self._edits + remove_matched: HashableCounter[TreeNode] = HashableCounter() + insert_matched: HashableCounter[TreeNode] = HashableCounter() + for (rem, (ins, edit)) in self._matcher.matching.items(): + yield edit + remove_matched[rem] += 1 + insert_matched[ins] += 1 + for rm in (self.to_remove - remove_matched).elements(): + yield Remove(to_remove=rm, remove_from=self.from_node) + for ins in (self.to_insert - insert_matched).elements(): + yield Insert(to_insert=ins, insert_into=self.from_node) diff --git a/graphtage/sequences.py b/graphtage/sequences.py index 7fcdde4..e138270 100644 --- a/graphtage/sequences.py +++ b/graphtage/sequences.py @@ -143,6 +143,9 @@ def __len__(self) -> int: """ return len(self._children) + def __getitem__(self, index) -> TreeNode: + return self._children[index] + def __iter__(self) -> Iterator[TreeNode]: """Iterates over this sequence's child nodes. diff --git a/graphtage/tree.py b/graphtage/tree.py index 5d9b2c8..dd28e38 100644 --- a/graphtage/tree.py +++ b/graphtage/tree.py @@ -192,6 +192,15 @@ def edits(self) -> Iterator[Edit]: """Returns an iterator over this edit's sub-edits""" raise NotImplementedError() + def is_complete(self) -> bool: + return all(e.is_complete() for e in self) + + def tighten_bounds(self) -> bool: + for e in self: + if e.tighten_bounds(): + return True + return False + def on_diff(self, from_node: 'EditedTreeNode'): """A callback for when an edit is assigned to an :class:`EditedTreeNode` in :meth:`TreeNode.diff`. @@ -546,6 +555,32 @@ def print(self, printer: Printer): class ContainerNode(TreeNode, Iterable, Sized, ABC): """A tree node that has children.""" + def __hash__(self): + """Hashes the contents of this container node. + + Equivalent to:: + + hash(tuple(self)) + + """ + return hash(tuple(self)) + + def __eq__(self, other): + """Tests whether this container node equals another. + + Equivalent to:: + + isinstance(other, ContainerNode) and all(a == b for a, b in zip(self, other)) + + Args: + other: The object to test. + + Returns: + bool: :const:`True` if this container node is equal to :obj:`other`. + + """ + return isinstance(other, ContainerNode) and all(a == b for a, b in zip(self, other)) + def children(self) -> List[TreeNode]: """The children of this node. diff --git a/test/test_formatting.py b/test/test_formatting.py index bc65098..3bf9389 100644 --- a/test/test_formatting.py +++ b/test/test_formatting.py @@ -40,7 +40,8 @@ def wrapper(self: 'TestFormatting'): raise ValueError(f'@filetype_test {name} must end with "{FILETYPE_TEST_SUFFIX}"') filetype_name = name[len(FILETYPE_TEST_PREFIX):-len(FILETYPE_TEST_SUFFIX)] if filetype_name not in graphtage.FILETYPES_BY_TYPENAME: - raise ValueError(f'Filetype "{filetype_name}" for @filetype_test {name} not found in graphtage.FILETYPES_BY_TYPENAME') + raise ValueError(f'Filetype "{filetype_name}" for @filetype_test {name} not found in ' + 'graphtage.FILETYPES_BY_TYPENAME') filetype = graphtage.FILETYPES_BY_TYPENAME[filetype_name] formatter = filetype.get_default_formatter() @@ -84,12 +85,16 @@ def make_random_bool() -> bool: return random.choice([True, False]) @staticmethod - def make_random_str(exclude_bytes: FrozenSet[str] = frozenset(), allow_empty_strings: bool = True) -> str: + def make_random_str( + exclude_bytes: FrozenSet[str] = frozenset(), + allow_empty_strings: bool = True, + max_length: int = 128 + ) -> str: if allow_empty_strings: min_length = 0 else: min_length = 1 - return ''.join(random.choices(list(STR_BYTES - exclude_bytes), k=random.randint(min_length, 128))) + return ''.join(random.choices(list(STR_BYTES - exclude_bytes), k=random.randint(min_length, max_length))) @staticmethod def make_random_non_container(exclude_bytes: FrozenSet[str] = frozenset(), allow_empty_strings: bool = True): @@ -253,3 +258,32 @@ def test_yaml_formatting(self): def test_plist_formatting(self): orig_obj = TestFormatting.make_random_obj(force_string_keys=True, exclude_bytes=frozenset('<>/\n&?|@{}[]')) return orig_obj, plistlib.dumps(orig_obj) + + @staticmethod + def make_random_flamegraph() -> str: + num_traces = random.randint(1, 500) + traces = [] + for _ in range(num_traces): + num_functions = random.randint(1, 32) + functions = [ + TestFormatting.make_random_str( + max_length=32, + exclude_bytes=frozenset({'\n', ' ', '\t', ';', '\r'}), + allow_empty_strings=False + ) + for _ in range(num_functions) + ] + stack_trace = ";".join(functions) + assert "\n" not in stack_trace + assert sum(1 for c in stack_trace if c == ";") == num_functions - 1 + num_samples = random.randint(1, 10000) + traces.append(f"{stack_trace} {num_samples}\n") + return "".join(traces) + + @filetype_test(iterations=10) + def test_flamegraph_formatting(self): + orig_obj = TestFormatting.make_random_flamegraph() + num_spaces = sum(1 for c in orig_obj if c == " ") + num_newlines = sum(1 for c in orig_obj if c == "\n") + self.assertEqual(num_spaces, num_newlines) + return orig_obj, orig_obj