Skip to content

Commit

Permalink
Introduces tests for cron_job_ingest_events.py plus refactoring for e…
Browse files Browse the repository at this point in the history
…asier tests
  • Loading branch information
Mark Kasaboski committed Nov 6, 2024
1 parent 7c3f637 commit b050190
Show file tree
Hide file tree
Showing 3 changed files with 294 additions and 47 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ inspect-tags:

.PHONY: test
test: venv-tools
@if test -d "./tests" ; \
then venv-tools/bin/pytest ./**/*.py -vv ; \
@if test -d "./packages/flare/tests" ; then \
venv-tools/bin/pytest ./packages/flare/tests/**/*.py -vv ; \
fi

.PHONY: format setup-web
Expand Down
116 changes: 71 additions & 45 deletions packages/flare/python/cron_job_ingest_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,10 @@
from constants import PasswordKeys
from flare import FlareAPI
from logger import Logger
from vendor.splunklib.client import Service


def main() -> None:
logger = Logger(class_name=__file__)
try:
splunk_service = client.connect(
host=HOST,
port=SPLUNK_PORT,
app=APP_NAME,
token=sys.stdin.readline().strip(),
)
except Exception as e:
logger.error(str(e))
raise Exception(str(e))

app: client.Application = splunk_service.apps[APP_NAME]
def main(logger: Logger, app: client.Application) -> None:
create_collection(app=app)

# To avoid cron jobs from doing the same work at the same time, exit new cron jobs if a cron job is already doing work
Expand All @@ -49,6 +37,14 @@ def main() -> None:
)
return

api_key, tenant_id = get_api_credentials(app=app)

fetch_and_print_feed_results(
logger=logger, app=app, api_key=api_key, tenant_id=tenant_id
)


def get_api_credentials(app: client.Application) -> tuple[str, int]:
api_key: Optional[str] = None
tenant_id: Optional[int] = None
for item in app.service.storage_passwords.list():
Expand All @@ -66,36 +62,7 @@ def main() -> None:
if not tenant_id:
raise Exception("Tenant ID not found")

try:
flare_api = FlareAPI(app=app, api_key=api_key, tenant_id=tenant_id)

next = get_next(app=app, tenant_id=tenant_id)
start_date = get_start_date(app=app)
logger.debug(f"Fetching {tenant_id=}, {next=}, {start_date=}")
events_retrieved_count = 0
for response in flare_api.retrieve_feed(next=next, start_date=start_date):
save_last_fetched(app=app)

# Rate limiting.
time.sleep(1)

if response.status_code != 200:
logger.error(response.text)
return

event_feed = response.json()
save_start_date(app=app, tenant_id=tenant_id)
save_next(app=app, tenant_id=tenant_id, next=event_feed["next"])

if event_feed["items"]:
for item in event_feed["items"]:
print(json.dumps(item))

events_retrieved_count += len(event_feed["items"])
except Exception as e:
logger.error(f"Exception={e}")

logger.debug(f"Retrieved {events_retrieved_count} events")
return api_key, tenant_id


def get_next(app: client.Application, tenant_id: int) -> Optional[str]:
Expand Down Expand Up @@ -206,5 +173,64 @@ def save_collection_value(app: client.Application, key: str, value: Any) -> None
)


def fetch_and_print_feed_results(
logger: Logger,
app: client.Application,
api_key: str,
tenant_id: int,
) -> None:
try:
flare_api = FlareAPI(app=app, api_key=api_key, tenant_id=tenant_id)

next = get_next(app=app, tenant_id=tenant_id)
start_date = get_start_date(app=app)
logger.debug(f"Fetching {tenant_id=}, {next=}, {start_date=}")
events_retrieved_count = 0
for response in flare_api.retrieve_feed(next=next, start_date=start_date):
save_last_fetched(app=app)

# Rate limiting.
time.sleep(1)

if response.status_code != 200:
logger.error(response.text)
return

event_feed = response.json()
save_start_date(app=app, tenant_id=tenant_id)
save_next(app=app, tenant_id=tenant_id, next=event_feed["next"])

if event_feed["items"]:
for item in event_feed["items"]:
print(json.dumps(item))

events_retrieved_count += len(event_feed["items"])
except Exception as e:
logger.error(f"Exception={e}")

logger.debug(f"Retrieved {events_retrieved_count} events")


def get_splunk_service(logger: Logger) -> Service:
try:
splunk_service = client.connect(
host=HOST,
port=SPLUNK_PORT,
app=APP_NAME,
token=sys.stdin.readline().strip(),
)
except Exception as e:
logger.error(str(e))
raise Exception(str(e))

return splunk_service


if __name__ == "__main__":
main()
logger = Logger(class_name=__file__)
splunk_service = get_splunk_service(logger=logger)

main(
logger=logger,
app=splunk_service.apps[APP_NAME],
)
221 changes: 221 additions & 0 deletions packages/flare/tests/python/test_ingest_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
import json
import os
import pytest
import sys

from datetime import datetime
from datetime import timedelta
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import Mock
from unittest.mock import PropertyMock
from unittest.mock import patch


sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../python"))
from constants import CRON_JOB_THRESHOLD_SINCE_LAST_FETCH
from constants import KV_COLLECTION_NAME
from cron_job_ingest_events import fetch_and_print_feed_results
from cron_job_ingest_events import get_api_credentials
from cron_job_ingest_events import get_collection_value
from cron_job_ingest_events import main
from cron_job_ingest_events import save_collection_value


def test_get_collection_value_expect_none() -> None:
app = MagicMock()
assert get_collection_value(app=app, key="some_key") is None


def test_get_collection_value_expect_result() -> None:
app = MagicMock()
app.service.kvstore.__contains__.side_effect = lambda x: x == KV_COLLECTION_NAME
app.service.kvstore[KV_COLLECTION_NAME].data.query.return_value = [
{
"_key": "some_key",
"value": "some_value",
},
]

assert get_collection_value(app=app, key="some_key") == "some_value"


def test_save_collection_value_expect_insert() -> None:
key = "some_key"
value = "some_value"
app = MagicMock()
save_collection_value(app=app, key=key, value=value)
app.service.kvstore[KV_COLLECTION_NAME].data.insert.assert_called_once_with(
json.dumps({"_key": key, "value": value})
)


def test_save_collection_value_expect_update() -> None:
key = "some_key"
value = "update_value"
app = MagicMock()
app.service.kvstore.__contains__.side_effect = lambda x: x == KV_COLLECTION_NAME
app.service.kvstore[KV_COLLECTION_NAME].data.query.return_value = [
{
"_key": key,
"value": "old_value",
},
]
save_collection_value(app=app, key=key, value=value)
app.service.kvstore[KV_COLLECTION_NAME].data.update.assert_called_once_with(
id=key,
data=json.dumps({"value": value}),
)


def test_get_api_credentials_expect_exception() -> None:
app = MagicMock()

with pytest.raises(Exception, match="API key not found"):
get_api_credentials(app=app)

api_key_item = Mock()
type(api_key_item.content).username = PropertyMock(return_value="api_key")
type(api_key_item).clear_password = PropertyMock(return_value="some_api_key")
app.service.storage_passwords.list.return_value = [api_key_item]

with pytest.raises(Exception, match="Tenant ID not found"):
get_api_credentials(app=app)


def test_get_api_credentials_expect_api_key_and_tenant_id() -> None:
app = MagicMock()

api_key_item = Mock()
type(api_key_item.content).username = PropertyMock(return_value="api_key")
type(api_key_item).clear_password = PropertyMock(return_value="some_api_key")

tenant_id_item = Mock()
type(tenant_id_item.content).username = PropertyMock(return_value="tenant_id")
type(tenant_id_item).clear_password = PropertyMock(return_value=11111)

app.service.storage_passwords.list.return_value = [api_key_item, tenant_id_item]

api_key, tenant_id = get_api_credentials(app=app)
assert api_key == "some_api_key"
assert tenant_id == 11111


def test_fetch_and_print_feed_results_expect_exception() -> None:
logger = MagicMock()
app = MagicMock()
with patch.object(
app.service,
"confs",
{"flare": {"endpoints": {"me_feed_endpoint": "/firework/v2/me/feed"}}},
):
fetch_and_print_feed_results(
logger=logger, app=app, api_key="some_key", tenant_id=11111
)

logger.error.assert_called_once_with("Exception=Failed to fetch API Token")


@patch("cron_job_ingest_events.FlareAPI")
@patch("time.sleep", return_value=None)
def test_fetch_and_print_feed_results_expect_non_200_response(
sleep: Any,
flare_api_mock: MagicMock,
) -> None:
logger = MagicMock()
app = MagicMock()

response_mock = Mock()
type(response_mock).status_code = PropertyMock(return_value=400)
type(response_mock).text = PropertyMock(return_value="Bad Request")

flare_api_mock_instance = flare_api_mock.return_value
flare_api_mock_instance.retrieve_feed.return_value = [response_mock]

with patch.object(
app.service,
"confs",
{"flare": {"endpoints": {"me_feed_endpoint": "/firework/v2/me/feed"}}},
):
fetch_and_print_feed_results(
logger=logger, app=app, api_key="some_key", tenant_id=11111
)

logger.error.assert_called_once_with("Bad Request")


@patch("cron_job_ingest_events.FlareAPI")
@patch("time.sleep", return_value=None)
def test_fetch_and_print_feed_results_expect_feed_response(
sleep: Any, flare_api_mock: MagicMock, capfd: Any
) -> None:
logger = MagicMock()
app = MagicMock()

response_mock = Mock()
type(response_mock).status_code = PropertyMock(return_value=200)
response_mock.json.return_value = {
"next": "some_next_value",
"items": [
{
"actor": "this guy",
},
{
"actor": "some other guy",
},
],
}

flare_api_mock_instance = flare_api_mock.return_value
flare_api_mock_instance.retrieve_feed.return_value = [response_mock]

with patch.object(
app.service,
"confs",
{"flare": {"endpoints": {"me_feed_endpoint": "/firework/v2/me/feed"}}},
):
fetch_and_print_feed_results(
logger=logger, app=app, api_key="some_key", tenant_id=11111
)
captured = capfd.readouterr()
assert captured.out == '{"actor": "this guy"}\n{"actor": "some other guy"}\n'


@patch(
"cron_job_ingest_events.get_last_fetched",
return_value=datetime.now() - timedelta(minutes=5),
)
def test_main_expect_early_return(get_last_fetched_mock: MagicMock) -> None:
logger = MagicMock()
app = MagicMock()

main(logger=logger, app=app)
logger.debug.assert_called_once_with(
f"Fetched events less than {int(CRON_JOB_THRESHOLD_SINCE_LAST_FETCH.seconds / 60)} minutes ago, exiting"
)


@patch("cron_job_ingest_events.fetch_and_print_feed_results")
@patch(
"cron_job_ingest_events.get_api_credentials",
return_value=("some_api_key", "some_tenant_id"),
)
@patch(
"cron_job_ingest_events.get_last_fetched",
return_value=datetime.now() - timedelta(minutes=10),
)
def test_main_expect_normal_run(
get_last_fetched_mock: MagicMock,
get_api_credentials_mock: MagicMock,
fetch_and_print_feed_results_mock: MagicMock,
) -> None:
logger = MagicMock()
app = MagicMock()

main(logger=logger, app=app)
fetch_and_print_feed_results_mock.assert_called_once_with(
logger=logger,
app=app,
api_key="some_api_key",
tenant_id="some_tenant_id",
)

0 comments on commit b050190

Please sign in to comment.