Skip to content

Commit

Permalink
Introduces tests for cron_job_ingest_events.py plus refactoring for e…
Browse files Browse the repository at this point in the history
…asier tests
  • Loading branch information
Mark Kasaboski committed Nov 6, 2024
1 parent 7c3f637 commit af8e73c
Show file tree
Hide file tree
Showing 3 changed files with 266 additions and 47 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ inspect-tags:

.PHONY: test
test: venv-tools
@if test -d "./tests" ; \
then venv-tools/bin/pytest ./**/*.py -vv ; \
@if test -d "./packages/flare/tests" ; then \
venv-tools/bin/pytest ./packages/flare/tests/**/*.py -vv ; \
fi

.PHONY: format setup-web
Expand Down
116 changes: 71 additions & 45 deletions packages/flare/python/cron_job_ingest_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,10 @@
from constants import PasswordKeys
from flare import FlareAPI
from logger import Logger
from vendor.splunklib.client import Service


def main() -> None:
logger = Logger(class_name=__file__)
try:
splunk_service = client.connect(
host=HOST,
port=SPLUNK_PORT,
app=APP_NAME,
token=sys.stdin.readline().strip(),
)
except Exception as e:
logger.error(str(e))
raise Exception(str(e))

app: client.Application = splunk_service.apps[APP_NAME]
def main(logger: Logger, app: client.Application) -> None:
create_collection(app=app)

# To avoid cron jobs from doing the same work at the same time, exit new cron jobs if a cron job is already doing work
Expand All @@ -49,6 +37,14 @@ def main() -> None:
)
return

api_key, tenant_id = get_api_credentials(app=app)

fetch_and_print_feed_results(
logger=logger, app=app, api_key=api_key, tenant_id=tenant_id
)


def get_api_credentials(app: client.Application) -> tuple[str, int]:
api_key: Optional[str] = None
tenant_id: Optional[int] = None
for item in app.service.storage_passwords.list():
Expand All @@ -66,36 +62,7 @@ def main() -> None:
if not tenant_id:
raise Exception("Tenant ID not found")

try:
flare_api = FlareAPI(app=app, api_key=api_key, tenant_id=tenant_id)

next = get_next(app=app, tenant_id=tenant_id)
start_date = get_start_date(app=app)
logger.debug(f"Fetching {tenant_id=}, {next=}, {start_date=}")
events_retrieved_count = 0
for response in flare_api.retrieve_feed(next=next, start_date=start_date):
save_last_fetched(app=app)

# Rate limiting.
time.sleep(1)

if response.status_code != 200:
logger.error(response.text)
return

event_feed = response.json()
save_start_date(app=app, tenant_id=tenant_id)
save_next(app=app, tenant_id=tenant_id, next=event_feed["next"])

if event_feed["items"]:
for item in event_feed["items"]:
print(json.dumps(item))

events_retrieved_count += len(event_feed["items"])
except Exception as e:
logger.error(f"Exception={e}")

logger.debug(f"Retrieved {events_retrieved_count} events")
return api_key, tenant_id


def get_next(app: client.Application, tenant_id: int) -> Optional[str]:
Expand Down Expand Up @@ -206,5 +173,64 @@ def save_collection_value(app: client.Application, key: str, value: Any) -> None
)


def fetch_and_print_feed_results(
logger: Logger,
app: client.Application,
api_key: str,
tenant_id: int,
) -> None:
try:
flare_api = FlareAPI(app=app, api_key=api_key, tenant_id=tenant_id)

next = get_next(app=app, tenant_id=tenant_id)
start_date = get_start_date(app=app)
logger.debug(f"Fetching {tenant_id=}, {next=}, {start_date=}")
events_retrieved_count = 0
for response in flare_api.retrieve_feed(next=next, start_date=start_date):
save_last_fetched(app=app)

# Rate limiting.
time.sleep(1)

if response.status_code != 200:
logger.error(response.text)
return

event_feed = response.json()
save_start_date(app=app, tenant_id=tenant_id)
save_next(app=app, tenant_id=tenant_id, next=event_feed["next"])

if event_feed["items"]:
for item in event_feed["items"]:
print(json.dumps(item))

events_retrieved_count += len(event_feed["items"])
except Exception as e:
logger.error(f"Exception={e}")

logger.debug(f"Retrieved {events_retrieved_count} events")


def get_splunk_service(logger: Logger) -> Service:
try:
splunk_service = client.connect(
host=HOST,
port=SPLUNK_PORT,
app=APP_NAME,
token=sys.stdin.readline().strip(),
)
except Exception as e:
logger.error(str(e))
raise Exception(str(e))

return splunk_service


if __name__ == "__main__":
main()
logger = Logger(class_name=__file__)
splunk_service = get_splunk_service(logger=logger)

main(
logger=logger,
app=splunk_service.apps[APP_NAME],
)
193 changes: 193 additions & 0 deletions packages/flare/tests/python/test_ingest_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import json
import os
import pytest
import sys

from typing import Any
from unittest.mock import MagicMock
from unittest.mock import Mock
from unittest.mock import PropertyMock
from unittest.mock import patch


sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../python"))
from constants import KV_COLLECTION_NAME
from cron_job_ingest_events import fetch_and_print_feed_results
from cron_job_ingest_events import get_api_credentials
from cron_job_ingest_events import get_collection_value
from cron_job_ingest_events import save_collection_value


def test_get_collection_value_expect_none() -> None:
app = MagicMock()
assert get_collection_value(app=app, key="some_key") is None


def test_get_collection_value_expect_result() -> None:
app = MagicMock()
app.service.kvstore.__contains__.side_effect = lambda x: x == KV_COLLECTION_NAME
app.service.kvstore[KV_COLLECTION_NAME].data.query.return_value = [
{
"_key": "some_key",
"value": "some_value",
},
]

assert get_collection_value(app=app, key="some_key") == "some_value"


def test_save_collection_value_expect_insert() -> None:
key = "some_key"
value = "some_value"
app = MagicMock()
save_collection_value(app=app, key=key, value=value)
app.service.kvstore[KV_COLLECTION_NAME].data.insert.assert_called_once_with(
json.dumps({"_key": key, "value": value})
)


def test_save_collection_value_expect_update() -> None:
key = "some_key"
value = "update_value"
app = MagicMock()
app.service.kvstore.__contains__.side_effect = lambda x: x == KV_COLLECTION_NAME
app.service.kvstore[KV_COLLECTION_NAME].data.query.return_value = [
{
"_key": key,
"value": "old_value",
},
]
save_collection_value(app=app, key=key, value=value)
app.service.kvstore[KV_COLLECTION_NAME].data.update.assert_called_once_with(
id=key,
data=json.dumps({"value": value}),
)


def test_get_api_credentials_expect_exception() -> None:
app = MagicMock()

with pytest.raises(Exception, match="API key not found"):
get_api_credentials(app=app)

api_key_item = Mock()
type(api_key_item.content).username = PropertyMock(return_value="api_key")
type(api_key_item).clear_password = PropertyMock(return_value="some_api_key")
app.service.storage_passwords.list.return_value = [api_key_item]

with pytest.raises(Exception, match="Tenant ID not found"):
get_api_credentials(app=app)


def test_get_api_credentials_expect_api_key_and_tenant_id() -> None:
app = MagicMock()

api_key_item = Mock()
type(api_key_item.content).username = PropertyMock(return_value="api_key")
type(api_key_item).clear_password = PropertyMock(return_value="some_api_key")

tenant_id_item = Mock()
type(tenant_id_item.content).username = PropertyMock(return_value="tenant_id")
type(tenant_id_item).clear_password = PropertyMock(return_value=11111)

app.service.storage_passwords.list.return_value = [api_key_item, tenant_id_item]

api_key, tenant_id = get_api_credentials(app=app)
assert api_key == "some_api_key"
assert tenant_id == 11111


def test_fetch_and_print_feed_results_expect_exception() -> None:
logger = MagicMock()
app = MagicMock()
with patch.object(
app.service,
"confs",
{
"flare": {
"endpoints": {
"me_feed_endpoint": "/firework/v2/me/feed"
}
}
},
):
fetch_and_print_feed_results(
logger=logger, app=app, api_key="some_key", tenant_id=11111
)

logger.error.assert_called_once_with("Exception=Failed to fetch API Token")


def test_fetch_and_print_feed_results_expect_non_200_response() -> None:
logger = MagicMock()
app = MagicMock()

with patch("cron_job_ingest_events.FlareAPI") as MockFlareAPI:
response_mock = Mock()
type(response_mock).status_code = PropertyMock(return_value=400)
type(response_mock).text = PropertyMock(return_value="Bad Request")

flare_api_mock_instance = MockFlareAPI.return_value
flare_api_mock_instance.retrieve_feed.return_value = [response_mock]

with patch("time.sleep", return_value=None):
with patch.object(
app.service,
"confs",
{
"flare": {
"endpoints": {
"me_feed_endpoint": "/firework/v2/me/feed"
}
}
},
):
fetch_and_print_feed_results(
logger=logger, app=app, api_key="some_key", tenant_id=11111
)

logger.error.assert_called_once_with("Bad Request")


def test_fetch_and_print_feed_results_expect_feed_response(capfd: Any) -> None:
logger = MagicMock()
app = MagicMock()

with patch("cron_job_ingest_events.FlareAPI") as MockFlareAPI:
response_mock = Mock()
type(response_mock).status_code = PropertyMock(return_value=200)
response_mock.json.return_value = {
"next": "some_next_value",
"items": [
{
"actor": "this guy",
},
{
"actor": "some other guy",
},
],
}

flare_api_mock_instance = MockFlareAPI.return_value
flare_api_mock_instance.retrieve_feed.return_value = [response_mock]

with patch("time.sleep", return_value=None):
with patch.object(
app.service,
"confs",
{
"flare": {
"endpoints": {
"me_feed_endpoint": "/firework/v2/me/feed"
}
}
},
):
fetch_and_print_feed_results(
logger=logger, app=app, api_key="some_key", tenant_id=11111
)
captured = capfd.readouterr()
assert (
captured.out
== '{"actor": "this guy"}\n{"actor": "some other guy"}\n'
)

0 comments on commit af8e73c

Please sign in to comment.