+ multirun tests
imtambovtcev committed Aug 20, 2024
1 parent 127a903 commit fe6050e
Showing 7 changed files with 174 additions and 14 deletions.
7 changes: 5 additions & 2 deletions hari_plotter/graph.py
@@ -79,15 +79,18 @@ def merge(cls, graphs: list['Graph']) -> 'Graph':
Returns:
'Graph': A new Graph instance representing the merged graph.
"""
merged_graph = nx.Graph()
merged_graph = cls()

for i, graph in enumerate(graphs):
# Create a mapping that relabels each node by tagging every element of its ID with the source graph index
mapping = {node: f"{node}_g{i}" for node in graph.nodes()}
mapping = {node: tuple([(n, i) for n in node])
for node in graph.nodes()}
renamed_graph = nx.relabel_nodes(graph if isinstance(graph, Graph) else graph.get_graph(
), mapping) # different behavior for graphs and LazyGraphs
merged_graph = nx.compose(merged_graph, renamed_graph)

merged_graph.set_gatherer(type(graphs[0].gatherer))

return merged_graph

def has_self_loops(self) -> bool:
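For readers unfamiliar with the new relabelling scheme, the sketch below illustrates it outside of `Graph`: node IDs are assumed to be tuples, and every element gets paired with the index of its source graph before the graphs are composed, so merged node IDs stay hashable tuples instead of string-prefixed names. The path graphs are toy placeholders, and the final `set_gatherer` call from the diff is omitted here.

```python
import networkx as nx

# Toy illustration of the relabelling used in Graph.merge, assuming node IDs
# are tuples of ints such as (0,) or (1, 2); the path graphs stand in for
# real Graph instances.
raw = [nx.path_graph(3), nx.path_graph(2)]
graphs = [nx.relabel_nodes(g, {n: (n,) for n in g.nodes()}) for g in raw]

merged = nx.Graph()
for i, graph in enumerate(graphs):
    # Pair every element of a node ID with its source-graph index:
    # (1,) from graph 0 becomes ((1, 0),), (1,) from graph 1 becomes ((1, 1),).
    mapping = {node: tuple((n, i) for n in node) for node in graph.nodes()}
    merged = nx.compose(merged, nx.relabel_nodes(graph, mapping))

print(sorted(merged.nodes()))
# [((0, 0),), ((0, 1),), ((1, 0),), ((1, 1),), ((2, 0),)]
```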
10 changes: 9 additions & 1 deletion hari_plotter/group.py
@@ -149,13 +149,21 @@ def clustering_graph(self, merge_remaining: bool = False, reinitialize: bool = F

clustering_key = self.request_to_tuple(clustering_settings)

def to_ints_and_tuples(node):
if isinstance(node, np.integer):
return int(node)
elif isinstance(node, (tuple, list, np.ndarray)):
return tuple(to_ints_and_tuples(n) for n in node)
else:
raise TypeError(f"Unsupported type: {type(node)}")

# initialize the graph
if reinitialize or (clustering_key not in self.clusterings or 'graph' not in self.clusterings[clustering_key]):

clustering = self.clustering(**clustering_settings)

clustering_nodes = [
list([tuple(node) for node in nodes]) for nodes in clustering.labels_nodes_dict().values()]
list([to_ints_and_tuples(node) for node in nodes]) for nodes in clustering.labels_nodes_dict().values()]
cluster_labels = clustering.cluster_labels

clustering_graph = self.mean_graph.copy()
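The helper added above only needs to run inside `clustering_graph`, but its behaviour is easiest to see in isolation. A standalone sketch follows, with example inputs chosen here rather than taken from the library; note that this variant accepts numpy integers and array-likes but, as written, raises on plain Python ints.

```python
import numpy as np

# Standalone copy of the helper added above: clustering output can arrive as
# numpy integers or nested arrays, and this converts it into plain ints and
# nested tuples so the values can serve as hashable graph-node IDs.
def to_ints_and_tuples(node):
    if isinstance(node, np.integer):
        return int(node)
    elif isinstance(node, (tuple, list, np.ndarray)):
        return tuple(to_ints_and_tuples(n) for n in node)
    else:
        raise TypeError(f"Unsupported type: {type(node)}")

print(to_ints_and_tuples(np.int64(7)))                  # 7
print(to_ints_and_tuples(np.array([[1, 0], [2, 0]])))   # ((1, 0), (2, 0))
```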
72 changes: 63 additions & 9 deletions hari_plotter/interface.py
@@ -11,6 +11,7 @@
from .dynamics import Dynamics
from .graph import Graph
from .group import Group
from .multirun import Multirun
from .simulation import Simulation


@@ -213,7 +214,7 @@ def info(cls) -> str:

@property
@abstractmethod
def available_parameters(self) -> list:
def available_parameters(self) -> list[str]:
raise NotImplementedError(
"This method must be implemented in subclasses")

@@ -283,6 +284,21 @@ def cluster_graph(self, cluster_settings: dict[str, Any], clusters_dynamics: lis
G.add_node(node_id, frame=frame_index,
cluster=cluster_name, Label=cluster_name)

def cluster_array_to_set_of_tuples(cluster):
"""
Returns a set of tuples from a cluster array.
If the shape of the cluster is (n, m), the code returns set[tuple[int]].
If the shape is (n, m, 2), the code returns set[tuple[tuple[int, int]]].
"""
if cluster.ndim == 2:
# Case when shape is (n, m): return set of tuples
return set(map(tuple, cluster))
elif cluster.ndim == 3 and cluster.shape[2] == 2:
# Case when shape is (n, m, 2): return set of tuples of tuples
return set(tuple(map(tuple, sub_cluster)) for sub_cluster in cluster)
else:
raise ValueError("Unsupported array shape")

if frame_index > 0: # There are previous frame clusters to connect to
current_frame_clusters = frame
prev_frame_clusters = clusters_dynamics[frame_index - 1]
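The two array shapes that `cluster_array_to_set_of_tuples` distinguishes correspond to plain node IDs and to the (element, run) node IDs produced by merged graphs. A small sketch with made-up arrays:

```python
import numpy as np

# Shape (n, m): n nodes, each an m-tuple of ints -> set of m-tuples.
flat = np.array([[0, 1], [2, 3]])
# Shape (n, m, 2): n nodes, each a tuple of (element, run) pairs as produced
# by the merged graphs in this commit -> set of tuples of pairs.
nested = np.array([[[0, 0], [1, 0]], [[0, 1], [1, 1]]])

assert set(map(tuple, flat)) == {(0, 1), (2, 3)}
assert set(tuple(map(tuple, c)) for c in nested) == {((0, 0), (1, 0)),
                                                     ((0, 1), (1, 1))}
```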
@@ -295,12 +311,13 @@ def cluster_graph(self, cluster_settings: dict[str, Any], clusters_dynamics: lis
for current_cluster_name, current_cluster in current_frame_clusters.items():
current_node_id = self.generate_node_id(
frame_index, current_cluster_name)
current_set = set(
map(tuple, current_cluster))
current_set = cluster_array_to_set_of_tuples(
current_cluster)
best_match = None
max_overlap = 0
for prev_cluster_name, prev_cluster in prev_frame_clusters.items():
prev_set = set(map(tuple, prev_cluster))
prev_set = cluster_array_to_set_of_tuples(
prev_cluster)
prev_node_id = self.generate_node_id(
frame_index - 1, prev_cluster_name)

@@ -316,7 +333,8 @@ def cluster_graph(self, cluster_settings: dict[str, Any], clusters_dynamics: lis
else:
# More clusters in the previous frame, match each previous cluster to one from the current frame
for prev_cluster_name, prev_cluster in prev_frame_clusters.items():
prev_set = set(map(tuple, prev_cluster))
prev_set = cluster_array_to_set_of_tuples(
prev_cluster)
prev_node_id = self.generate_node_id(
frame_index - 1, prev_cluster_name)

@@ -325,7 +343,8 @@ def cluster_graph(self, cluster_settings: dict[str, Any], clusters_dynamics: lis
for current_cluster_name, current_cluster in current_frame_clusters.items():
current_node_id = self.generate_node_id(
frame_index, current_cluster_name)
current_set = set(map(tuple, current_cluster))
current_set = cluster_array_to_set_of_tuples(
current_cluster)
overlap = len(
prev_set.intersection(current_set))
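The matching logic in both branches boils down to the same rule: compute the intersection size between a cluster's node-ID set and every candidate in the other frame, and keep the candidate with the largest overlap. A minimal sketch with invented cluster sets:

```python
# Invented previous-frame clusters (already converted to sets of tuples) and
# one current-frame cluster; the rule keeps the first candidate with the
# largest intersection.
prev_frame = {'A': {((0, 0),), ((1, 0),)}, 'B': {((2, 0),), ((3, 0),)}}
current_set = {((1, 0),), ((2, 0),), ((3, 0),)}

best_match, max_overlap = None, 0
for prev_name, prev_set in prev_frame.items():
    overlap = len(prev_set & current_set)
    if overlap > max_overlap:
        best_match, max_overlap = prev_name, overlap

print(best_match, max_overlap)  # B 2
```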

@@ -552,7 +571,7 @@ def _regroup_dynamics(self, num_intervals: int, interval_size: int = 1, offset:
self._group_size = interval_size

@property
def available_parameters(self) -> list:
def available_parameters(self) -> list[str]:
"""
Retrieves the list of available parameters/methods from the data gatherer.
@@ -586,7 +605,7 @@ def _regroup_dynamics(self, num_intervals: int, interval_size: int = 1, offset:
self.data.group(num_intervals, interval_size, offset)

@property
def available_parameters(self) -> list:
def available_parameters(self) -> list[str]:
"""
Retrieves the list of available parameters/methods from the data gatherer.
@@ -620,7 +639,7 @@ def _regroup_dynamics(self, num_intervals: int, interval_size: int = 1, offset:
self.data.dynamics.group(num_intervals, interval_size, offset)

@property
def available_parameters(self) -> list:
def available_parameters(self) -> list[str]:
"""
Retrieves the list of available parameters/methods from the data gatherer.
@@ -632,3 +651,38 @@ def available_parameters(self) -> list:
@property
def time_range(self) -> list[float]:
return [0., float(len(self.data)-1) * self.data.dt]


class MultirunInterface(Interface):

REQUIRED_TYPE = Multirun

def __init__(self, data):
super().__init__(data=data, group_length=len(
data.simulations[0].dynamics.groups))
self.merge = self.data.merge()
self.data: Multirun

def __len__(self):
return len(self.merge)

@property
def available_parameters(self) -> list[str]:
"""
Retrieves the list of available parameters/methods from the data gatherer.
Returns:
list: A list of available parameters or methods.
"""
return self.merge.available_parameters

def _initialize_group(self, i: int) -> Group:
group = self.merge.dynamics.groups[i]
return Group([self.merge.dynamics[j] for j in group], time=np.array(group) * self.merge.dt)

def _regroup_dynamics(self, num_intervals: int, interval_size: int = 1, offset: int = 0):
self.merge.dynamics.group(num_intervals, interval_size, offset)

@property
def time_range(self) -> list[float]:
return [0., float(len(self.merge)-1) * self.merge.dt]
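A possible way to exercise the new class, assuming the interface is constructed directly on a `Multirun` (the TOML paths are placeholders, and whether a generic `Interface` factory should be used instead is not shown in this diff):

```python
from hari_plotter.interface import MultirunInterface
from hari_plotter.multirun import Multirun
from hari_plotter.simulation import Simulation

# Placeholder configs; any two compatible simulations of equal length work.
runs = Multirun([Simulation.from_toml('run_a.toml'),
                 Simulation.from_toml('run_b.toml')])

interface = MultirunInterface(runs)      # merges the runs via Multirun.merge()
print(interface.available_parameters)    # parameters of the merged simulation
print(interface.time_range)              # [0., (len - 1) * dt]
```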
17 changes: 17 additions & 0 deletions hari_plotter/multirun.py
@@ -10,10 +10,27 @@ def from_dirs(cls, dirs: list[str]):
return cls([Simulation.from_dir(d) for d in dirs])

def join(self, other: 'Multirun'):
assert len(self) == 0 or len(self.simulations) == len(
other.simulations), "Multiruns must have the same length"
self.simulations.extend(other.simulations)

def append(self, simulation: Simulation):
assert isinstance(
simulation, Simulation), "Can only append Simulation objects"
assert len(self) == 0 or len(self.simulations[0]) == len(
simulation), "Simulations must have the same length"
self.simulations.append(simulation)

def merge(self) -> Simulation:
return Simulation.merge(self.simulations)

@property
def available_parameters(self) -> list[str]:
if len(self.simulations) == 0:
return []
return self.simulations[0].dynamics[0].gatherer.node_parameters

def __len__(self):
if len(self.simulations) == 0:
return 0
return len(self.simulations[0])
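Taken together, the new methods support the following workflow. This is a sketch using only the API shown in this file; the TOML paths are placeholders.

```python
from hari_plotter.multirun import Multirun
from hari_plotter.simulation import Simulation

mr = Multirun([Simulation.from_toml('run_a.toml')])
mr.append(Simulation.from_toml('run_b.toml'))         # must match in length
mr.join(Multirun([Simulation.from_toml('run_c.toml')]))

print(len(mr))                     # length of the first simulation
print(mr.available_parameters)     # node parameters of the first gatherer
merged = mr.merge()                # a single Simulation combining all runs
```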
28 changes: 26 additions & 2 deletions hari_plotter/node_gatherer.py
@@ -200,6 +200,30 @@ def gather_everything(self) -> dict[str, Any]:
"""
return self.gather_unprocessed(list(self.node_parameter_logger.keys()))

@staticmethod
def merged_node_id(cluster):
"""
Returns a new node ID for a merged cluster.
Parameters:
cluster (list[tuple[int]]): A list of node IDs to be merged.
Returns:
tuple[int]: A new node ID representing the merged cluster.
"""
def to_ints_and_tuples(node):
if isinstance(node, (int, np.integer)):
return int(node)
elif isinstance(node, (list, tuple, np.ndarray)):
return tuple(node)
else:
raise TypeError(f"Unsupported type: {type(node)}")

cluster_id = []
for node in cluster:
cluster_id.extend(tuple([to_ints_and_tuples(n) for n in node]))
return tuple(sorted(cluster_id))


class DefaultNodeEdgeGatherer(NodeEdgeGatherer):
"""
@@ -262,7 +286,7 @@ def merge_clusters(self, clusters: list[list[tuple[int]]], labels: Union[list[st
labels.append(None)

for cluster, label in zip(clusters, labels):
new_node_id = tuple(sorted(sum(cluster, ())))
new_node_id = NodeEdgeGatherer.merged_node_id(cluster)
merged_attributes = self.merge(cluster)

self.G.add_node(new_node_id, **merged_attributes)
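To make the effect of the switch away from `tuple(sorted(sum(cluster, ())))` concrete, the snippet below reproduces `merged_node_id` on its own and feeds it two invented clusters: one with plain single-int node IDs and one with the (element, run) IDs produced by merged graphs.

```python
import numpy as np

# Self-contained copy of NodeEdgeGatherer.merged_node_id for illustration.
def merged_node_id(cluster):
    def to_ints_and_tuples(node):
        if isinstance(node, (int, np.integer)):
            return int(node)
        elif isinstance(node, (list, tuple, np.ndarray)):
            return tuple(node)
        else:
            raise TypeError(f"Unsupported type: {type(node)}")

    cluster_id = []
    for node in cluster:
        cluster_id.extend(tuple([to_ints_and_tuples(n) for n in node]))
    return tuple(sorted(cluster_id))

print(merged_node_id([(3,), (1,), (2,)]))       # (1, 2, 3)
print(merged_node_id([((1, 0),), ((1, 1),)]))   # ((1, 0), (1, 1))
```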
@@ -565,7 +589,7 @@ def merge_clusters(self, clusters: list[list[tuple[int]]], labels: Union[list[st
labels.append(None)

for cluster, label in zip(clusters, labels):
new_node_id = tuple(sorted(sum(cluster, ())))
new_node_id = NodeEdgeGatherer.merged_node_id(cluster)
merged_attributes = self.merge(cluster)

self.G.add_node(new_node_id, **merged_attributes)
2 changes: 2 additions & 0 deletions hari_plotter/plotter.py
Expand Up @@ -365,6 +365,8 @@ def __init__(self, interfaces: Interface | list[Interface] | None = None):
interface : Interface
Interface instance to be used for plotting.
"""
assert interfaces is None or isinstance(
interfaces, (Interface, list)), "Interface must be an Interface instance or a list of Interface instances"
self._interfaces: list[Interface] | None = [interfaces] if isinstance(
interfaces, Interface) else interfaces
self.default_color_scheme: ColorScheme = ColorScheme()
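The new assert only guards the normalisation that follows it. Below is a generic sketch of that pattern, with a stand-in class because constructing a real `Interface` is outside this diff.

```python
# Stand-in for hari_plotter's Interface; only the isinstance checks matter here.
class FakeInterface:
    pass

def normalise(interfaces):
    assert interfaces is None or isinstance(interfaces, (FakeInterface, list)), \
        "Interface must be an Interface instance or a list of Interface instances"
    # A single instance is wrapped in a list, a list is kept, None stays None.
    return [interfaces] if isinstance(interfaces, FakeInterface) else interfaces

print(normalise(FakeInterface()))                     # [<FakeInterface ...>]
print(normalise([FakeInterface(), FakeInterface()]))  # list passed through
print(normalise(None))                                # None
```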
52 changes: 52 additions & 0 deletions tests/test_multirun.py
@@ -0,0 +1,52 @@
import pytest
from hari_plotter.multirun import Multirun
from hari_plotter.simulation import Simulation
from hari_plotter.model import Model


class TestSimulation:

@classmethod
def setup_class(cls):
cls.degroot = Simulation.from_toml('tests/degroot.toml')
cls.activity = Simulation.from_toml('tests/activity.toml')

def test_from_toml(self):
assert isinstance(self.degroot, Simulation)
assert isinstance(self.activity, Simulation)
assert isinstance(self.degroot.model, Model)
assert isinstance(self.activity.model, Model)

def test_to_toml(self, tmp_path):
filename = tmp_path / "test.toml"
self.degroot.to_toml(filename)
assert filename.exists(), 'File not created'


@pytest.fixture
def simulation_fixture():
return TestSimulation()


@pytest.fixture
def mock_multirun(simulation_fixture):
return Multirun([simulation_fixture.degroot, simulation_fixture.activity])


def test_multirun_initialization(simulation_fixture):
mr = Multirun([simulation_fixture.degroot, simulation_fixture.activity])
assert len(mr.simulations) == 2
assert mr.simulations[0] == simulation_fixture.degroot
assert mr.simulations[1] == simulation_fixture.activity


def test_multirun_join(mock_multirun, simulation_fixture):
another_multirun = Multirun([simulation_fixture.degroot])
mock_multirun.join(another_multirun)
assert len(mock_multirun.simulations) == 3


def test_multirun_append(mock_multirun, simulation_fixture):
new_simulation = Simulation.from_toml('tests/activity.toml')
mock_multirun.append(new_simulation)
assert len(mock_multirun.simulations) == 3
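A natural follow-up test, not part of this commit, would cover Multirun.merge using only the fixtures defined above; sketch:

```python
def test_multirun_merge(mock_multirun):
    # merge() is expected to collapse all runs into a single Simulation.
    merged = mock_multirun.merge()
    assert isinstance(merged, Simulation)
```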
