Skip to content

Commit

Permalink
fix: Don't abort a training task entirely if a 4xx error is encountered when fetching an artifact from a previous run (#673)
Browse files Browse the repository at this point in the history

* fix: remove unused 'reasons_created' in train taskcluster tests

* fix: Don't abort a training task entirely if a 4xx error is encountered when fetching an artifact from a previous run
  • Loading branch information
bhearsum authored Jun 13, 2024
1 parent 90a9d7a commit 25ceab5
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 42 deletions.
53 changes: 31 additions & 22 deletions taskcluster/scripts/pipeline/train_taskcluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,23 +74,23 @@ def main(args):

run_artifacts = set([os.path.basename(a["name"]) for a in resp.json()["artifacts"]])

resumable = True
if run_artifacts.issuperset(CONTINUATION_ARTIFACTS):
logging.info(
f"Run {prev_run_id} appears to have the artifacts we need! Downloading them..."
)
else:
logging.info(f"Run {prev_run_id} is missing some necessary artifacts...")
prev_run_id -= 1
continue
resumable = False

for artifact in resp.json()["artifacts"]:
# Skip Taskcluster logs - we only care about artifacts that the training tools create.
if artifact["name"].startswith("public/log"):
continue
out_name = os.path.basename(artifact["name"])
logging.info(f"Fetching {artifact['name']}...")
if resumable:
for artifact in resp.json()["artifacts"]:
# Skip Taskcluster logs - we only care about artifacts that the training tools create.
if artifact["name"].startswith("public/log"):
continue
out_name = os.path.basename(artifact["name"])
logging.info(f"Fetching {artifact['name']}...")

try:
r = requests.get(
ARTIFACT_URL.format(
root_url=root_url,
Expand All @@ -100,19 +100,28 @@ def main(args):
),
stream=True,
)
r.raise_for_status()
except Exception:
logging.exception("Caught exception, exiting with distinct code...")
sys.exit(DOWNLOAD_ERROR_EXIT_CODE)

with open(os.path.join(model_dir, out_name), "wb+") as fd:
for chunk in r.iter_content(chunk_size=8192):
fd.write(chunk)

# We successfully downloaded all the artifacts from a previous run. Override
# the pretrained model mode and we're done!
pretrained_model_mode = "continue"
break
if 400 <= r.status_code <= 500:
logging.exception(
f"Got 4xx error for {artifact['name']}, run {run_id} is not resumable..."
)
resumable = False
break
elif r.status_code >= 500:
logging.exception("Caught exception, exiting with distinct code...")
sys.exit(DOWNLOAD_ERROR_EXIT_CODE)

with open(os.path.join(model_dir, out_name), "wb+") as fd:
for chunk in r.iter_content(chunk_size=8192):
fd.write(chunk)

if resumable:
# We successfully downloaded all the artifacts from a previous run. Override
# the pretrained model mode and we're done!
pretrained_model_mode = "continue"
break
else:
# We weren't able to get all of the necessary artifacts; try the next previous run
prev_run_id -= 1

if pretrained_model_mode:
if len(script_args) < PRETRAINED_MODEL_MODE_ARG_NUMBER:
Expand Down
54 changes: 34 additions & 20 deletions tests/test_train_taskcluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,24 +132,24 @@ def test_all_args_forwarded(args):


@pytest.mark.parametrize(
"current_run_id,resumable_run_id,reasons_created,run_artifacts,orig_pretrained_model_mode,expected_pretrained_model_mode",
"current_run_id,resumable_run_id,run_artifacts,artifact_response_code,orig_pretrained_model_mode,expected_pretrained_model_mode",
(
pytest.param(
0,
None,
["scheduled"],
# not used unless resumable_run_id is set
{},
200,
"",
"",
id="run_0_no_continuation",
),
pytest.param(
0,
None,
["scheduled"],
# not used unless resumable_run_id is set
{},
200,
"init",
"init",
id="run_0_no_continuation_with_pretrained_model",
Expand All @@ -160,64 +160,73 @@ def test_all_args_forwarded(args):
pytest.param(
1,
0,
["scheduled", "retry"],
{0: FULL_ARTIFACTS},
200,
"",
"continue",
id="run_1_continues_run_0",
),
pytest.param(
2,
1,
["scheduled", "retry", "retry"],
{1: FULL_ARTIFACTS},
200,
"",
"continue",
id="run_2_continues_run_1",
),
pytest.param(
2,
0,
["scheduled", "rerun", "retry"],
{1: PARTIAL_ARTIFACTS, 0: FULL_ARTIFACTS},
200,
"",
"continue",
id="run_2_continues_run_0",
),
pytest.param(
3,
1,
["scheduled", "rerun", "exception", "retry"],
{2: PARTIAL_ARTIFACTS, 1: FULL_ARTIFACTS, 0: PARTIAL_ARTIFACTS},
200,
"",
"continue",
id="run_3_continues_run_1",
),
pytest.param(
2,
None,
["scheduled", "rerun", "exception"],
{1: PARTIAL_ARTIFACTS, 0: PARTIAL_ARTIFACTS},
200,
"",
"",
id="run_2_cant_continue_earlier_runs",
),
pytest.param(
2,
None,
["scheduled", "retry", "rerun"],
{1: PARTIAL_ARTIFACTS, 0: PARTIAL_ARTIFACTS},
200,
"use",
"use",
id="run_2_cant_continue_earlier_runs_preserves_pretrained_model_mode",
),
pytest.param(
2,
0,
{1: PARTIAL_ARTIFACTS, 0: FULL_ARTIFACTS},
404,
"",
"",
id="artifacts_are_404",
),
),
)
def test_autocontinue(
current_run_id,
resumable_run_id,
reasons_created,
run_artifacts,
artifact_response_code,
orig_pretrained_model_mode,
expected_pretrained_model_mode,
):
Expand Down Expand Up @@ -252,9 +261,10 @@ def fake_get(url, *args, **kwargs):
):
# No action needed here; we will check that the right calls were
# made based on the current_run_id later.
resp.status_code = 200
resp._content = b""
resp.raw = io.StringIO("")
resp.status_code = artifact_response_code
if resp.status_code == 200:
resp._content = b""
resp.raw = io.StringIO("")
elif url.endswith("live.log") or url.endswith("live_backing.log"):
resp.status_code = 400
resp._content = (
Expand Down Expand Up @@ -304,15 +314,19 @@ def fake_get(url, *args, **kwargs):

# However, we only expect to fetch the artifacts for the run we resume from...
if prev_run_id == resumable_run_id:
i = 0
for artifact in run_artifacts[prev_run_id]:
# ...but even then, we don't expect to download the Taskcluster logs.
# ...but even then, we don't expect to download the Taskcluster logs
if not artifact["name"].startswith("public/logs"):
calls.append(
mock.call(
f"https://some.cluster/api/queue/v1/task/abcdef/runs/{prev_run_id}/artifacts/{artifact['name']}",
stream=True,
),
)
# or anything after the first artifact if the response code is not 200
if artifact_response_code == 200 or i == 0:
i += 1
calls.append(
mock.call(
f"https://some.cluster/api/queue/v1/task/abcdef/runs/{prev_run_id}/artifacts/{artifact['name']}",
stream=True,
),
)
prev_run_id = prev_run_id - 1

assert tt_mock["requests"].get.call_args_list == calls
Expand Down

0 comments on commit 25ceab5

Please sign in to comment.