Skip to content

Commit

Permalink
Introduces tests for cron_job_injest_events.py plus refactoring for e…
Browse files Browse the repository at this point in the history
…asier tests
  • Loading branch information
Mark Kasaboski committed Nov 5, 2024
1 parent d00305c commit 1c681de
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 41 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ inspect-tags:

.PHONY: test
test: venv-tools
@if test -d "./tests" ; \
then venv-tools/bin/pytest ./**/*.py -vv ; \
@if test -d "./packages/flare/tests" ; \
then venv-tools/bin/pytest ./packages/flare/tests/**/*.py -vv ; \
fi

.PHONY: format setup-web
Expand Down
102 changes: 63 additions & 39 deletions packages/flare/python/cron_job_ingest_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,10 @@
from constants import PasswordKeys
from flare import FlareAPI
from logger import Logger
from vendor.splunklib.client import Service


def main() -> None:
logger = Logger(class_name=__file__)
try:
splunk_service = client.connect(
host=HOST,
port=SPLUNK_PORT,
app=APP_NAME,
token=sys.stdin.readline().strip(),
)
except Exception as e:
logger.error(str(e))
raise Exception(str(e))

app: client.Application = splunk_service.apps[APP_NAME]
def main(logger: Logger, app: client.Application) -> None:
create_collection(app=app)

# To avoid cron jobs from doing the same work at the same time, exit new cron jobs if a cron job is already doing work
Expand All @@ -51,6 +39,12 @@ def main() -> None:
)
return

api_key, tenant_id = get_api_credentials(app=app)

fetch_and_print_feed_results(logger=logger, app=app, api_key=api_key, tenant_id=tenant_id)


def get_api_credentials(app: client.Application) -> dict[str, str]:
api_key: Optional[str] = None
tenant_id: Optional[int] = None
for item in app.service.storage_passwords.list():
Expand All @@ -68,30 +62,7 @@ def main() -> None:
if not tenant_id:
raise Exception("Tenant ID not found")

try:
flare_api = FlareAPI(app=app, api_key=api_key, tenant_id=tenant_id)

next = get_next(app=app, tenant_id=tenant_id)
start_date = get_start_date(app=app)
for response in flare_api.retrieve_feed(next=next, start_date=start_date):
save_last_fetched(app=app)

# Rate limiting.
time.sleep(1)

if response.status_code != 200:
logger.error(response.text)
return

event_feed = response.json()
save_start_date(app=app, tenant_id=tenant_id)
save_next(app=app, tenant_id=tenant_id, next=event_feed["next"])

if event_feed["items"]:
for item in event_feed["items"]:
print(json.dumps(item))
except Exception as e:
logger.error("Exception={}".format(e))
return api_key, tenant_id


def get_next(app: client.Application, tenant_id: int) -> Optional[str]:
Expand Down Expand Up @@ -202,5 +173,58 @@ def save_collection_value(app: client.Application, key: str, value: Any) -> None
)


def fetch_and_print_feed_results(
logger: Logger,
app: client.Application,
api_key: str,
tenant_id: int,
) -> None:
try:
flare_api = FlareAPI(app=app, api_key=api_key, tenant_id=tenant_id)

next = get_next(app=app, tenant_id=tenant_id)
start_date = get_start_date(app=app)
for response in flare_api.retrieve_feed(next=next, start_date=start_date):
save_last_fetched(app=app)

# Rate limiting.
time.sleep(1)

if response.status_code != 200:
logger.error(response.text)
return

event_feed = response.json()
save_start_date(app=app, tenant_id=tenant_id)
save_next(app=app, tenant_id=tenant_id, next=event_feed["next"])

if event_feed["items"]:
for item in event_feed["items"]:
print(json.dumps(item))
except Exception as e:
logger.error("Exception={}".format(e))


def get_splunk_service(logger: Logger) -> Service:
try:
splunk_service = client.connect(
host=HOST,
port=SPLUNK_PORT,
app=APP_NAME,
token=sys.stdin.readline().strip(),
)
except Exception as e:
logger.error(str(e))
raise Exception(str(e))

return splunk_service


if __name__ == "__main__":
main()
logger = Logger(class_name=__file__)
splunk_service = get_splunk_service(logger=logger)

main(
logger=logger,
app=splunk_service.apps[APP_NAME],
)
74 changes: 74 additions & 0 deletions packages/flare/tests/python/test_ingest_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@

import json
import os
import pytest
import sys

from unittest.mock import MagicMock
from unittest.mock import patch


sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../python"))
from constants import KV_COLLECTION_NAME
from cron_job_ingest_events import fetch_and_print_feed_results
from cron_job_ingest_events import get_collection_value
from cron_job_ingest_events import get_next
from cron_job_ingest_events import main
from cron_job_ingest_events import save_collection_value


sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../python/vendor"))
import vendor.splunklib.client as client


def test_get_collection_value_expect_none() -> None:
app = MagicMock()
assert get_collection_value(app=app, key="some_key") is None


def test_get_collection_value_expect_result() -> None:
app = MagicMock()
app.service.kvstore.__contains__.side_effect = lambda x: x == KV_COLLECTION_NAME
app.service.kvstore[KV_COLLECTION_NAME].data.query.return_value = [
{
"_key": "some_key",
"value": "some_value",
},
]

assert get_collection_value(app=app, key="some_key") == "some_value"


def test_save_collection_value_expect_insert() -> None:
key = "some_key"
value = "some_value"
app = MagicMock()
save_collection_value(app=app, key=key, value=value)
app.service.kvstore[KV_COLLECTION_NAME].data.insert.assert_called_once_with(
json.dumps({"_key": key, "value": value})
)

def test_save_collection_value_expect_update() -> None:
key = "some_key"
value = "update_value"
app = MagicMock()
app.service.kvstore.__contains__.side_effect = lambda x: x == KV_COLLECTION_NAME
app.service.kvstore[KV_COLLECTION_NAME].data.query.return_value = [
{
"_key": key,
"value": "old_value",
},
]
save_collection_value(app=app, key=key, value=value)
app.service.kvstore[KV_COLLECTION_NAME].data.update.assert_called_once_with(
id=key,
data=json.dumps({"value": value}),
)

def test_fetch_and_print_feed_results_expect_exception() -> None:
logger = MagicMock()
app = MagicMock()
with patch.object(app.service, "confs", {"flare": {"endpoints": {"me_feed_endpoint": "https://api.flare.io/firework/api2/me/feed"}}}):
fetch_and_print_feed_results(logger=logger, app=app, api_key="some_key", tenant_id=11111)

logger.error.assert_called_once_with("Exception=Failed to fetch API Token")

0 comments on commit 1c681de

Please sign in to comment.