Skip to content

Commit

Permalink
Improve tests
Browse files Browse the repository at this point in the history
Un-hardcoding these numbers makes it possible to arbitrarily extend the metric set definitions without needing to change the test code
  • Loading branch information
cthoyt committed Jan 4, 2022
1 parent 4f0815a commit d16e7dd
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions tests/integration/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,37 @@ def test_classification(self):
score_card = ScoreCard(metric_set)

performance_metrics = score_card.generate_report(self.scores)
assert performance_metrics.shape == (1, 11)
assert performance_metrics.shape == (1, len(metric_set))

performance_metrics = score_card.generate_report(self.scores, grouping=["source_group"])
assert performance_metrics.shape == (5, 11)
assert performance_metrics.shape == (5, len(metric_set))

performance_metrics = score_card.generate_report(self.scores, grouping=["source_group", "target_group"])
assert performance_metrics.shape == (20, 11)
assert performance_metrics.shape == (20, len(metric_set))

def test_regression(self):
metric_set = RatingMetricSet()
metric_set.normalize_metrics()
score_card = ScoreCard(metric_set)

performance_metrics = score_card.generate_report(self.scores)
assert performance_metrics.shape == (1, 7)
assert performance_metrics.shape == (1, len(metric_set))

performance_metrics = score_card.generate_report(self.scores, grouping=["source_group"])
assert performance_metrics.shape == (5, 7)
assert performance_metrics.shape == (5, len(metric_set))

performance_metrics = score_card.generate_report(self.scores, grouping=["source_group", "target_group"])
assert performance_metrics.shape == (20, 7)
assert performance_metrics.shape == (20, len(metric_set))

def test_addition(self):
metric_set = RatingMetricSet() + ClassificationMetricSet()
score_card = ScoreCard(metric_set)

performance_metrics = score_card.generate_report(self.scores)
assert performance_metrics.shape == (1, 18)
assert performance_metrics.shape == (1, len(metric_set))

performance_metrics = score_card.generate_report(self.scores, grouping=["source_group"])
assert performance_metrics.shape == (5, 18)
assert performance_metrics.shape == (5, len(metric_set))

performance_metrics = score_card.generate_report(self.scores, grouping=["source_group", "target_group"])
assert performance_metrics.shape == (20, 18)
assert performance_metrics.shape == (20, len(metric_set))

0 comments on commit d16e7dd

Please sign in to comment.