Skip to content

Commit

Permalink
Try to reduce network usage in cuML tests. (#6174)
Browse files Browse the repository at this point in the history
This PR tries to use `"session"` scope pytest fixtures and cached data downloads to reduce network usage in cuML's nightly tests.

Authors:
  - Bradley Dice (https://github.com/bdice)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #6174
  • Loading branch information
bdice authored Dec 12, 2024
1 parent 029b708 commit 84a858b
Showing 1 changed file with 26 additions and 10 deletions.
36 changes: 26 additions & 10 deletions python/cuml/cuml/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
from ssl import create_default_context
from urllib.request import build_opener, HTTPSHandler, install_opener
import certifi
import functools
import hypothesis
from cuml.internals.safe_imports import gpu_only_import
import pytest
import os
import subprocess
import time
import pandas as pd
import cudf.pandas

Expand Down Expand Up @@ -212,7 +214,7 @@ def pytest_pyfunc_call(pyfuncitem):
pytest.skip("Test requires cudf.pandas accelerator")


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def nlp_20news():
try:
twenty_train = fetch_20newsgroups(
Expand All @@ -228,7 +230,7 @@ def nlp_20news():
return X, Y


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def housing_dataset():
try:
data = fetch_california_housing()
Expand All @@ -245,16 +247,30 @@ def housing_dataset():
return X, y, feature_names


@pytest.fixture(scope="module")
@functools.cache
def get_boston_data():
n_retries = 3
url = "https://raw.githubusercontent.com/scikit-learn/scikit-learn/baf828ca126bcb2c0ad813226963621cafe38adb/sklearn/datasets/data/boston_house_prices.csv" # noqa: E501
for _ in range(n_retries):
try:
return pd.read_csv(url, header=None)
except Exception:
time.sleep(1)
raise RuntimeError(
f"Failed to download file from {url} after {n_retries} retries."
)


@pytest.fixture(scope="session")
def deprecated_boston_dataset():
# dataset was removed in Scikit-learn 1.2, we should change it for a
# better dataset for tests, see
# https://github.com/rapidsai/cuml/issues/5158

df = pd.read_csv(
"https://raw.githubusercontent.com/scikit-learn/scikit-learn/baf828ca126bcb2c0ad813226963621cafe38adb/sklearn/datasets/data/boston_house_prices.csv",
header=None,
) # noqa: E501
try:
df = get_boston_data()
except: # noqa E722
pytest.xfail(reason="Error fetching Boston housing dataset")
n_samples = int(df[0][0])
data = df[list(np.arange(13))].values[2:n_samples].astype(np.float64)
targets = df[13].values[2:n_samples].astype(np.float64)
Expand All @@ -266,7 +282,7 @@ def deprecated_boston_dataset():


@pytest.fixture(
scope="module",
scope="session",
params=["digits", "deprecated_boston_dataset", "diabetes", "cancer"],
)
def test_datasets(request, deprecated_boston_dataset):
Expand Down Expand Up @@ -313,7 +329,7 @@ def failure_logger(request):
print(error_msg)


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def exact_shap_regression_dataset():
return create_synthetic_dataset(
generator=skl_make_reg,
Expand All @@ -326,7 +342,7 @@ def exact_shap_regression_dataset():
)


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def exact_shap_classification_dataset():
return create_synthetic_dataset(
generator=skl_make_clas,
Expand Down

0 comments on commit 84a858b

Please sign in to comment.