Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add logging of mappings in BaseLogger #3295

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 52 additions & 21 deletions ignite/handlers/base_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from collections.abc import Collection, Mapping
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union

import torch
import torch.nn as nn
def _setup_output_metrics_state_attrs(
    self, engine: Engine, log_text: Optional[bool] = False, key_tuple: Optional[bool] = True
) -> Dict[Any, Any]:
    """Helper method to set up metrics and state attributes to log.

    Values are collected from three sources, in order: ``engine.state.metrics``
    (filtered by ``self.metric_names``, or all of them when
    ``self.metric_names == "all"``), the result of
    ``self.output_transform(engine.state.output)`` (wrapped into a dict under
    the key ``"output"`` when it is not already a dict), and the engine state
    attributes listed in ``self.state_attributes``. The collected values are
    then flattened into loggable scalars by ``self._compute_tags``.

    Args:
        engine: engine whose state (metrics, output, attributes) is read.
        log_text: if True, string values are kept as-is instead of being
            rejected with a warning.
        key_tuple: if True, keys of the returned dict are tuples
            ``(tag, name, ...)``; otherwise they are "/"-joined strings
            ``"tag/name/..."``.

    Returns:
        An ordered dict mapping composed keys to scalar (or string) values.
    """
    state_metrics: Dict[str, Any] = OrderedDict()
    if self.metric_names is not None:
        if isinstance(self.metric_names, str) and self.metric_names == "all":
            state_metrics = OrderedDict(engine.state.metrics)
        else:
            for name in self.metric_names:
                if name not in engine.state.metrics:
                    warnings.warn(
                        f"Provided metric name '{name}' is missing "
                        f"in engine's state metrics: {list(engine.state.metrics.keys())}"
                    )
                    continue
                state_metrics[name] = engine.state.metrics[name]

    if self.output_transform is not None:
        output_dict = self.output_transform(engine.state.output)

        if not isinstance(output_dict, dict):
            output_dict = {"output": output_dict}

        state_metrics.update(output_dict)

    if self.state_attributes is not None:
        # Missing attributes are logged as None rather than raising.
        state_metrics.update({name: getattr(engine.state, name, None) for name in self.state_attributes})

    metrics: Dict[Any, Union[str, float, numbers.Number]] = OrderedDict()

    def concat_tuple(parent_tag: Tuple[str, ...], name: str, *args: str) -> Tuple[str, ...]:
        # Tuple-valued keys: ("tag", "name", ...).
        return parent_tag + (name, *args)

    def concat_str(parent_tag: str, name: str, *args: str) -> str:
        # String-valued keys: "tag/name/...".
        return "/".join((parent_tag, name) + args)

    concat_tf = concat_tuple if key_tuple else concat_str

    self._compute_tags(
        concat_tf,
        log_text,
        node=state_metrics.items(),
        parent_tag=(self.tag,) if key_tuple else self.tag,
        dest_dict=metrics,
    )
    return metrics

@classmethod
def _compute_tags(
    cls,
    concat_tf: Callable[..., Union[str, Tuple[str, ...]]],
    log_text: bool,
    node: Iterable[Tuple[str, Any]],
    parent_tag: Union[str, Tuple[str, ...]],
    dest_dict: Dict[Any, Union[str, float, numbers.Number]],
) -> None:
    """Recursively flatten ``node`` into ``dest_dict``.

    Numbers and 0-d tensors become single entries; 1-d tensors become one
    entry per element; mappings and generic collections are recursed into
    with the current name appended to the key prefix.

    Args:
        concat_tf: composes a key from ``parent_tag`` plus one or more name
            parts (tuple-based or "/"-joined string, chosen by the caller).
        log_text: if True, string values are stored; otherwise they trigger
            a warning like any other unsupported type.
        node: iterable of ``(name, value)`` pairs to flatten.
        parent_tag: key prefix accumulated so far.
        dest_dict: output mapping, mutated in place.
    """
    for name, value in node:
        if isinstance(value, numbers.Number):
            dest_dict[concat_tf(parent_tag, name)] = value
        elif isinstance(value, torch.Tensor) and value.ndimension() == 0:
            dest_dict[concat_tf(parent_tag, name)] = value.item()
        elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
            for i, v in enumerate(value):
                dest_dict[concat_tf(parent_tag, name, str(i))] = v.item()
        elif isinstance(value, str):
            # str must be handled before the generic Collection branch:
            # str is itself a Collection, and recursing into it character by
            # character never terminates (a 1-char string yields itself).
            if log_text:
                dest_dict[concat_tf(parent_tag, name)] = value
            else:
                warnings.warn(f"Logger output_handler can not log metrics value type {type(value)}")
        elif isinstance(value, Mapping):
            cls._compute_tags(
                concat_tf,
                log_text,
                node=value.items(),
                parent_tag=concat_tf(parent_tag, name),
                dest_dict=dest_dict,
            )
        elif isinstance(value, Collection):
            # Stringify indices so key parts stay homogeneous and the
            # string-key variant's "/".join does not fail on int names.
            cls._compute_tags(
                concat_tf,
                log_text,
                node=((str(i), v) for i, v in enumerate(value)),
                parent_tag=concat_tf(parent_tag, name),
                dest_dict=dest_dict,
            )
        else:
            warnings.warn(f"Logger output_handler can not log metrics value type {type(value)}")


class BaseWeightsScalarHandler(BaseWeightsHandler):
Expand Down