Skip to content

Commit

Permalink
Introduces tests for cron_job_ingest_events.py plus refactoring for e…
Browse files Browse the repository at this point in the history
…asier tests
  • Loading branch information
Mark Kasaboski committed Nov 6, 2024
1 parent 7c3f637 commit e042697
Show file tree
Hide file tree
Showing 3 changed files with 269 additions and 47 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ inspect-tags:

.PHONY: test
test: venv-tools
@if test -d "./tests" ; \
then venv-tools/bin/pytest ./**/*.py -vv ; \
@if test -d "./packages/flare/tests" ; then \
venv-tools/bin/pytest ./packages/flare/tests/**/*.py -vv ; \
fi

.PHONY: format setup-web
Expand Down
116 changes: 71 additions & 45 deletions packages/flare/python/cron_job_ingest_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,10 @@
from constants import PasswordKeys
from flare import FlareAPI
from logger import Logger
from vendor.splunklib.client import Service


def main() -> None:
logger = Logger(class_name=__file__)
try:
splunk_service = client.connect(
host=HOST,
port=SPLUNK_PORT,
app=APP_NAME,
token=sys.stdin.readline().strip(),
)
except Exception as e:
logger.error(str(e))
raise Exception(str(e))

app: client.Application = splunk_service.apps[APP_NAME]
def main(logger: Logger, app: client.Application) -> None:
create_collection(app=app)

# To avoid cron jobs from doing the same work at the same time, exit new cron jobs if a cron job is already doing work
Expand All @@ -49,6 +37,14 @@ def main() -> None:
)
return

api_key, tenant_id = get_api_credentials(app=app)

fetch_and_print_feed_results(
logger=logger, app=app, api_key=api_key, tenant_id=tenant_id
)


def get_api_credentials(app: client.Application) -> tuple[str, int]:
api_key: Optional[str] = None
tenant_id: Optional[int] = None
for item in app.service.storage_passwords.list():
Expand All @@ -66,36 +62,7 @@ def main() -> None:
if not tenant_id:
raise Exception("Tenant ID not found")

try:
flare_api = FlareAPI(app=app, api_key=api_key, tenant_id=tenant_id)

next = get_next(app=app, tenant_id=tenant_id)
start_date = get_start_date(app=app)
logger.debug(f"Fetching {tenant_id=}, {next=}, {start_date=}")
events_retrieved_count = 0
for response in flare_api.retrieve_feed(next=next, start_date=start_date):
save_last_fetched(app=app)

# Rate limiting.
time.sleep(1)

if response.status_code != 200:
logger.error(response.text)
return

event_feed = response.json()
save_start_date(app=app, tenant_id=tenant_id)
save_next(app=app, tenant_id=tenant_id, next=event_feed["next"])

if event_feed["items"]:
for item in event_feed["items"]:
print(json.dumps(item))

events_retrieved_count += len(event_feed["items"])
except Exception as e:
logger.error(f"Exception={e}")

logger.debug(f"Retrieved {events_retrieved_count} events")
return api_key, tenant_id


def get_next(app: client.Application, tenant_id: int) -> Optional[str]:
Expand Down Expand Up @@ -206,5 +173,64 @@ def save_collection_value(app: client.Application, key: str, value: Any) -> None
)


def fetch_and_print_feed_results(
logger: Logger,
app: client.Application,
api_key: str,
tenant_id: int,
) -> None:
try:
flare_api = FlareAPI(app=app, api_key=api_key, tenant_id=tenant_id)

next = get_next(app=app, tenant_id=tenant_id)
start_date = get_start_date(app=app)
logger.debug(f"Fetching {tenant_id=}, {next=}, {start_date=}")
events_retrieved_count = 0
for response in flare_api.retrieve_feed(next=next, start_date=start_date):
save_last_fetched(app=app)

# Rate limiting.
time.sleep(1)

if response.status_code != 200:
logger.error(response.text)
return

event_feed = response.json()
save_start_date(app=app, tenant_id=tenant_id)
save_next(app=app, tenant_id=tenant_id, next=event_feed["next"])

if event_feed["items"]:
for item in event_feed["items"]:
print(json.dumps(item))

events_retrieved_count += len(event_feed["items"])
except Exception as e:
logger.error(f"Exception={e}")

logger.debug(f"Retrieved {events_retrieved_count} events")


def get_splunk_service(logger: Logger) -> Service:
try:
splunk_service = client.connect(
host=HOST,
port=SPLUNK_PORT,
app=APP_NAME,
token=sys.stdin.readline().strip(),
)
except Exception as e:
logger.error(str(e))
raise Exception(str(e))

return splunk_service


if __name__ == "__main__":
main()
logger = Logger(class_name=__file__)
splunk_service = get_splunk_service(logger=logger)

main(
logger=logger,
app=splunk_service.apps[APP_NAME],
)
196 changes: 196 additions & 0 deletions packages/flare/tests/python/test_ingest_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
from datetime import datetime, timedelta
import json
import os
import pytest
import sys

from typing import Any
from unittest.mock import MagicMock
from unittest.mock import Mock
from unittest.mock import PropertyMock
from unittest.mock import patch


sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../python"))
from constants import CRON_JOB_THRESHOLD_SINCE_LAST_FETCH
from constants import KV_COLLECTION_NAME
from cron_job_ingest_events import fetch_and_print_feed_results
from cron_job_ingest_events import get_api_credentials
from cron_job_ingest_events import get_collection_value
from cron_job_ingest_events import main
from cron_job_ingest_events import save_collection_value


def test_get_collection_value_expect_none() -> None:
app = MagicMock()
assert get_collection_value(app=app, key="some_key") is None


def test_get_collection_value_expect_result() -> None:
app = MagicMock()
app.service.kvstore.__contains__.side_effect = lambda x: x == KV_COLLECTION_NAME
app.service.kvstore[KV_COLLECTION_NAME].data.query.return_value = [
{
"_key": "some_key",
"value": "some_value",
},
]

assert get_collection_value(app=app, key="some_key") == "some_value"


def test_save_collection_value_expect_insert() -> None:
key = "some_key"
value = "some_value"
app = MagicMock()
save_collection_value(app=app, key=key, value=value)
app.service.kvstore[KV_COLLECTION_NAME].data.insert.assert_called_once_with(
json.dumps({"_key": key, "value": value})
)


def test_save_collection_value_expect_update() -> None:
key = "some_key"
value = "update_value"
app = MagicMock()
app.service.kvstore.__contains__.side_effect = lambda x: x == KV_COLLECTION_NAME
app.service.kvstore[KV_COLLECTION_NAME].data.query.return_value = [
{
"_key": key,
"value": "old_value",
},
]
save_collection_value(app=app, key=key, value=value)
app.service.kvstore[KV_COLLECTION_NAME].data.update.assert_called_once_with(
id=key,
data=json.dumps({"value": value}),
)


def test_get_api_credentials_expect_exception() -> None:
app = MagicMock()

with pytest.raises(Exception, match="API key not found"):
get_api_credentials(app=app)

api_key_item = Mock()
type(api_key_item.content).username = PropertyMock(return_value="api_key")
type(api_key_item).clear_password = PropertyMock(return_value="some_api_key")
app.service.storage_passwords.list.return_value = [api_key_item]

with pytest.raises(Exception, match="Tenant ID not found"):
get_api_credentials(app=app)


def test_get_api_credentials_expect_api_key_and_tenant_id() -> None:
app = MagicMock()

api_key_item = Mock()
type(api_key_item.content).username = PropertyMock(return_value="api_key")
type(api_key_item).clear_password = PropertyMock(return_value="some_api_key")

tenant_id_item = Mock()
type(tenant_id_item.content).username = PropertyMock(return_value="tenant_id")
type(tenant_id_item).clear_password = PropertyMock(return_value=11111)

app.service.storage_passwords.list.return_value = [api_key_item, tenant_id_item]

api_key, tenant_id = get_api_credentials(app=app)
assert api_key == "some_api_key"
assert tenant_id == 11111


def test_fetch_and_print_feed_results_expect_exception() -> None:
logger = MagicMock()
app = MagicMock()
with patch.object(
app.service,
"confs",
{"flare": {"endpoints": {"me_feed_endpoint": "/firework/v2/me/feed"}}},
):
fetch_and_print_feed_results(
logger=logger, app=app, api_key="some_key", tenant_id=11111
)

logger.error.assert_called_once_with("Exception=Failed to fetch API Token")


def test_fetch_and_print_feed_results_expect_non_200_response() -> None:
logger = MagicMock()
app = MagicMock()

with patch("cron_job_ingest_events.FlareAPI") as MockFlareAPI:
response_mock = Mock()
type(response_mock).status_code = PropertyMock(return_value=400)
type(response_mock).text = PropertyMock(return_value="Bad Request")

flare_api_mock_instance = MockFlareAPI.return_value
flare_api_mock_instance.retrieve_feed.return_value = [response_mock]

with patch("time.sleep", return_value=None):
with patch.object(
app.service,
"confs",
{"flare": {"endpoints": {"me_feed_endpoint": "/firework/v2/me/feed"}}},
):
fetch_and_print_feed_results(
logger=logger, app=app, api_key="some_key", tenant_id=11111
)

logger.error.assert_called_once_with("Bad Request")


def test_fetch_and_print_feed_results_expect_feed_response(capfd: Any) -> None:
logger = MagicMock()
app = MagicMock()

with patch("cron_job_ingest_events.FlareAPI") as MockFlareAPI:
response_mock = Mock()
type(response_mock).status_code = PropertyMock(return_value=200)
response_mock.json.return_value = {
"next": "some_next_value",
"items": [
{
"actor": "this guy",
},
{
"actor": "some other guy",
},
],
}

flare_api_mock_instance = MockFlareAPI.return_value
flare_api_mock_instance.retrieve_feed.return_value = [response_mock]

with patch("time.sleep", return_value=None):
with patch.object(
app.service,
"confs",
{"flare": {"endpoints": {"me_feed_endpoint": "/firework/v2/me/feed"}}},
):
fetch_and_print_feed_results(
logger=logger, app=app, api_key="some_key", tenant_id=11111
)
captured = capfd.readouterr()
assert (
captured.out
== '{"actor": "this guy"}\n{"actor": "some other guy"}\n'
)

def test_main_expect_early_return() -> None:
logger = MagicMock()
app = MagicMock()

with patch('cron_job_ingest_events.get_last_fetched', return_value=datetime.now() - timedelta(minutes=5)):
main(logger=logger, app=app)
logger.debug.assert_called_once_with(f"Fetched events less than {int(CRON_JOB_THRESHOLD_SINCE_LAST_FETCH.seconds / 60)} minutes ago, exiting")

def test_main_expect_normal_run() -> None:
logger = MagicMock()
app = MagicMock()

with patch('cron_job_ingest_events.get_last_fetched', return_value=datetime.now() - timedelta(minutes=11)):
with patch('cron_job_ingest_events.get_api_credentials', return_value=("some_api_key", "some_tenant_id")):
with patch('cron_job_ingest_events.fetch_and_print_feed_results') as fetch_and_print_feed_results_mock:
main(logger=logger, app=app)
fetch_and_print_feed_results_mock.assert_called_once_with(logger=logger, app=app, api_key="some_api_key", tenant_id="some_tenant_id")

0 comments on commit e042697

Please sign in to comment.