Introduction
In this article, a detour from the ongoing series on building a financial data analyzer, I focus on the critical aspect of rigorously testing the server. Ensuring appropriate and accurate data handling is paramount. While Python doesn't enforce strict type-safety, I'll demonstrate how to use tools like `mypy`, `bandit`, and later `prospector` to maintain basic code quality and standards.
Prerequisite
You should read the article on building the preliminary AI service to follow along. This article builds upon the concepts and code established there.
Live version
Source code
An AI-powered financial behavior analyzer and advisor written in Python (aiohttp) and TypeScript (ExpressJS & SvelteKit with Svelte 5)
Implementation
Step 1: Improving the Initial AI Service
The AI service described in that article has several areas for improvement:
- Testability: The current structure makes automated testing (both integration and unit) difficult.
- Model Accuracy: The zero-shot classification model, originally designed for sentiment analysis, isn't optimal for categorizing financial transactions. A more suitable model is needed.
- Code Quality: The code requires refactoring, cleanup, and the addition of new features.
- Type Consistency: Type annotations need to be consistently applied and enforced throughout the codebase.
To address these, we will adopt this structure:
.
├── README.md
├── mypy.ini
├── requirements.dev.txt
├── requirements.txt
├── run.py
├── scripts
│   ├── static_check.sh
│   └── test_app.sh
├── src
│   ├── __init__.py
│   ├── app
│   │   ├── __init__.py
│   │   └── app_instance.py
│   ├── models
│   │   ├── __init__.py
│   │   └── base.py
│   └── utils
│       ├── __init__.py
│       ├── analyzer.py
│       ├── base.py
│       ├── extract_text.py
│       ├── resume_parser.py
│       ├── settings.py
│       ├── summarize.py
│       └── websocket.py
└── tests
    └── __init__.py
We introduced the `src/` directory to house the entire application. The `aiohttp` server setup was refactored into `src/app/app_instance.py`, with `run.py` simply responsible for running the created app instance:
import os

from aiohttp import web

from src.app.app_instance import init_app
from src.utils.settings import base_settings

if __name__ == '__main__':
    app = init_app()
    try:
        web.run_app(
            app,
            host='0.0.0.0',
            port=int(os.environ.get('PORT', 5173)),
        )
    except KeyboardInterrupt:
        base_settings.logger.info('Received keyboard interrupt...')
    except Exception as e:
        base_settings.logger.error(f'Server error: {e}')
    finally:
        base_settings.logger.info('Server shutdown complete.')
The `run.py` file initializes and starts the aiohttp application. The key changes in `app_instance.py` are highlighted below:
+ import asyncio
+ from weakref import WeakSet
...
- from utils.analyzer import analyze_transactions
- from utils.extract_text import extract_text_from_pdf
- from utils.resume_parser import extract_text_with_pymupdf, parse_resume_text
- from utils.settings import base_settings
- from utils.summarize import summarize_transactions
- from utils.websocket import WebSocketManager
+ from src.utils.analyzer import analyze_transactions
+ from src.utils.extract_text import extract_text_from_pdf
+ from src.utils.resume_parser import extract_text_with_pymupdf, parse_resume_text
+ from src.utils.settings import base_settings
+ from src.utils.summarize import summarize_transactions
+ from src.utils.websocket import WebSocketManager

  # Replace global ws_connections with typed version
- ws_connections: set[WebSocketResponse] = set()
- ws_lock = Lock()
+ WEBSOCKETS = web.AppKey("websockets", WeakSet[WebSocketResponse])

- async def start_background_tasks(app):
+ async def start_background_tasks(app: web.Application) -> None:
      """Initialize application background tasks."""
-     app['ws_connections'] = ws_connections
-     app['ws_lock'] = ws_lock
+     app[WEBSOCKETS] = WeakSet()

- async def cleanup_background_tasks(app):
-     """Cleanup application resources."""
-     await cleanup_ws(app)

- async def cleanup_ws(app):
+ async def cleanup_ws(app: web.Application) -> None:
      """Cleanup WebSocket connections on shutdown."""
-     async with ws_lock:
-         connections = set(ws_connections)  # Create a copy to iterate safely
-         for ws in connections:
-             await ws.close(code=WSMsgType.CLOSE, message='Server shutdown')
-         ws_connections.clear()
+     for websocket in set(app[WEBSOCKETS]):  # type: ignore
+         await websocket.close(code=WSCloseCode.GOING_AWAY, message=b'Server shutdown')

  async def websocket_handler(request: Request) -> WebSocketResponse:
      """WebSocket handler for real-time communication."""
      ws = web.WebSocketResponse()
      await ws.prepare(request)
-     async with ws_lock:
-         ws_connections.add(ws)
+     request.app[WEBSOCKETS].add(ws)
      ws_manager = WebSocketManager(ws)
      await ws_manager.prepare()

+     async def ping_server(ws: WebSocketResponse) -> None:
+         try:
+             while True:
+                 await ws.ping()
+                 await asyncio.sleep(25)
+         except ConnectionResetError:
+             base_settings.logger.info("Client disconnected")
+         finally:
+             await ws.close()
+
+     asyncio.create_task(ping_server(ws))

      base_settings.logger.info('WebSocket connection established')
      try:
          async for msg in ws:
+             if msg.type == WSMsgType.PING:
+                 base_settings.logger.info('Intercepted PING from client')
+                 await ws.pong(msg.data)
+             elif msg.type == WSMsgType.PONG:
+                 base_settings.logger.info('Intercepted PONG from client')
              if msg.type == WSMsgType.TEXT:
                  ...
-             elif msg.type == WSMsgType.ERROR:
+             elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR):
-                 base_settings.logger.error(f'WebSocket error: {ws.exception()}')
+                 base_settings.logger.info(
+                     'WebSocket is closing or encountered an error',
+                 )
+                 break
      except Exception as e:
          base_settings.logger.error(f'WebSocket handler error: {str(e)}')
      finally:
-         async with ws_lock:
-             ws_connections.remove(ws)
-             if not ws.closed:
-                 await ws.close()
+         request.app[WEBSOCKETS].discard(ws)
+         if not ws.closed:
+             await ws.close()
          base_settings.logger.info('WebSocket connection closed')
      return ws

  def init_app() -> web.Application:
      ...
      # Add startup/cleanup handlers
      app.on_startup.append(start_background_tasks)
-     app.on_cleanup.append(cleanup_background_tasks)
+     app.on_shutdown.append(cleanup_ws)
      return app
We improved type consistency throughout the codebase, using `# type: ignore` where necessary. We also replaced the global WebSocket connection set with a `weakref.WeakSet` for more robust connection management during shutdown. To keep connections alive during long-running processes like zero-shot classification, we implemented a ping/pong mechanism.
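What makes `WeakSet` attractive here can be seen in isolation: once the last strong reference to a connection object disappears, the set drops the entry automatically, so a crashed handler can't leave a stale connection behind. A minimal sketch, using a plain placeholder class rather than a real `WebSocketResponse`:

```python
import gc
from weakref import WeakSet


class Connection:
    """Stand-in for an object like aiohttp's WebSocketResponse."""


connections: WeakSet = WeakSet()

conn = Connection()
connections.add(conn)
assert len(connections) == 1

# Dropping the last strong reference removes the entry automatically;
# no explicit cleanup or locking is needed.
del conn
gc.collect()  # force collection (immediate under CPython refcounting anyway)
assert len(connections) == 0
```

This is why the refactor could delete `ws_lock` and the manual `ws_connections.clear()` logic: the set polices its own membership.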
Next, we consolidated common utility functions into a new `src/utils/base.py` file. This included functions like `validate_and_convert_transactions`, `get_device`, `detect_anomalies`, `analyze_spending`, `predict_trends`, `calculate_trend`, and `calculate_percentage_change`, previously located in `utils/summarize.py` and `utils/analyzer.py`. We also introduced new functions to estimate financial health (`calculate_financial_health`) and detect recurring transactions (`analyze_recurring_transactions`). The anomaly detection was enhanced to identify single-instance anomalies, and the transaction grouping algorithm now uses `difflib` for fuzzy matching of descriptions. For example, `difflib` considers these two descriptions similar enough to group (their similarity ratio clears the 0.69 cutoff): "Target T-12345 Anytown USA" and "Target 12345 Anytown USA":
def group_transactions_by_description(transactions: list[Transaction], cutoff: float = 0.69) -> dict[str, list[float]]:
    """
    Group transactions by description using fuzzy matching with difflib.

    Returns a dictionary mapping a representative description (the group key)
    to a list of transaction amounts. Two descriptions are grouped together if
    their similarity is above a certain threshold.
    """
    groups: dict[str, list[float]] = {}
    for tx in transactions:
        desc = tx.description.lower().strip()
        # Try to find an existing key similar to desc.
        # difflib.get_close_matches returns a list of close matches.
        close_matches = difflib.get_close_matches(desc, groups.keys(), n=1, cutoff=cutoff)
        matched_key = close_matches[0] if close_matches else None
        if matched_key:
            groups[matched_key].append(tx.amount)
        else:
            groups[desc] = [tx.amount]
    return groups


def find_group_key(description: str, group_keys: list[str], cutoff: float = 0.69) -> str:
    """
    Find the best matching key from group_keys for the given description using difflib.

    Returns the matched key if similarity is above cutoff; otherwise, returns the description.
    """
    desc = description.lower().strip()
    matches = difflib.get_close_matches(desc, group_keys, n=1, cutoff=cutoff)
    if matches:
        return matches[0]
    return desc
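To see why the 0.69 cutoff groups near-identical merchant strings, you can inspect the similarity score directly; `get_close_matches` relies on the same `SequenceMatcher` ratio under the hood:

```python
import difflib

# The two descriptions from the example above, lowercased as the
# grouping functions do before comparing.
a = 'target t-12345 anytown usa'
b = 'target 12345 anytown usa'

ratio = difflib.SequenceMatcher(None, a, b).ratio()
print(f'similarity: {ratio:.2f}')

# Anything scoring above the cutoff is grouped together.
assert ratio > 0.69

# get_close_matches applies the same idea against a pool of candidate keys.
match = difflib.get_close_matches(a, [b, 'walmart 999 othertown usa'], n=1, cutoff=0.69)
assert match == [b]
```

Tuning `cutoff` is the main lever: raise it and only near-duplicates merge; lower it and loosely related descriptions start collapsing into one group.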
We also encapsulated sending progress reports in a reusable function, `update_progress`.
In `src/utils/analyzer.py`, the major improvements are:
- Improved Model Accuracy: We switched from the `yiyanghkust/finbert-tone` model to `facebook/bart-large-mnli` for zero-shot classification. This significantly improves accuracy, although at the cost of speed. For multilingual support, `joeddav/xlm-roberta-large-xnli` is another option.
- Hybrid Classification Approach: We now use a hybrid approach, first attempting to classify transactions using pattern matching. Any remaining unclassified transactions are then processed by the ML model. To improve performance, we process transactions in batches, releasing the event loop after each batch to allow other operations to proceed and to clear memory.
- Offloading Calculations: To reduce the load on the classification process, we moved the calculation of `anomalies`, `spending_analysis`, `spending_trends`, `recurring_transactions`, and `financial_health` to `src/utils/summarize.py`, which is significantly faster.
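The batching idea can be sketched independently of the model: classify one slice, then `await asyncio.sleep(0)` to hand control back to the event loop (so pings, progress updates, and other handlers keep running) before taking the next slice. `classify_batch` below is a hypothetical stand-in for the real pipeline call, not the project's implementation:

```python
import asyncio


def classify_batch(batch: list[str]) -> list[str]:
    # Hypothetical stand-in for the zero-shot pipeline; the real code
    # would run the model over the whole batch at once.
    return ['other' for _ in batch]


async def classify_in_batches(descriptions: list[str], batch_size: int = 16) -> list[str]:
    labels: list[str] = []
    for start in range(0, len(descriptions), batch_size):
        batch = descriptions[start:start + batch_size]
        labels.extend(classify_batch(batch))
        # Yield to the event loop so WebSocket pings and progress
        # updates are not starved during a long classification run.
        await asyncio.sleep(0)
    return labels


result = asyncio.run(classify_in_batches([f'tx {i}' for i in range(40)]))
assert len(result) == 40
```

The `sleep(0)` costs almost nothing but keeps the server responsive, which is exactly why the ping/pong keepalive from Step 1 survives a long classification job.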
Step 2: Enforcing Type Safety, Security, and Style
Our type annotations are currently only decorative. To enforce type safety, ensure code security, and maintain a consistent code style, we'll use the following tools:
- `mypy`: a static type checker.
- `bandit`: a security linter.
- `black`: an uncompromising code formatter.
- `isort`: a tool for sorting imports.
Prospector
`prospector` provides comprehensive static analysis and ensures your code conforms to PEP 8 and other style guidelines. It's highly recommended for in-depth code quality checks.
Install these tools and add them to `requirements.dev.txt`:
(virtualenv) pip install mypy bandit black isort
Create a `mypy.ini` file at the root of the project with the following configuration:
# some config from:
# https://www.ralphminderhoud.com/blog/django-mypy-check-runs/
[mypy]
# The mypy configurations: https://mypy.readthedocs.io/en/latest/config_file.html
python_version = 3.13
check_untyped_defs = True
disallow_untyped_defs = True
disallow_incomplete_defs = True
disallow_any_generics = True
disallow_untyped_calls = True
# needs this because celery doesn't have typings
disallow_untyped_decorators = False
ignore_errors = False
ignore_missing_imports = True
implicit_reexport = False
strict_optional = True
strict_equality = True
no_implicit_optional = True
warn_unused_ignores = True
warn_redundant_casts = True
warn_unused_configs = True
warn_unreachable = True
warn_no_return = True
# these 2 options were added in mypy 0.800 to enable it to run in our code base
explicit_package_bases = True
namespace_packages = True
[mypy-*.migrations.*]
ignore_errors = True
This configuration enforces various type-checking rules. Each option is generally self-explanatory.
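As a quick illustration of what the stricter flags buy us: with `disallow_untyped_defs` and `disallow_any_generics` enabled, `mypy` rejects the first definition below and accepts the second (an illustrative fragment, not project code):

```python
# Rejected under disallow_untyped_defs:
#     def total(amounts):        # error: Function is missing a type annotation
#         return sum(amounts)

# Accepted: fully annotated, with concrete generic parameters
# (a bare `dict` return annotation would trip disallow_any_generics).
def total(amounts: list[float]) -> float:
    return sum(amounts)


def summarize(amounts: list[float]) -> dict[str, float]:
    return {'total': total(amounts), 'count': float(len(amounts))}


print(summarize([10.0, 20.5]))
```

None of this changes runtime behavior; the payoff is that `mypy src/` in the script below fails the build before an unannotated function slips into the codebase.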
Next, create a bash script (`scripts/static_check.sh`) to automate the static analysis process:
#!/usr/bin/env bash
set -e
# run black - make sure everyone uses same python style
black --skip-string-normalization --line-length 120 --check src/
black --skip-string-normalization --line-length 120 --check run.py
black --skip-string-normalization --line-length 120 --check tests/
# run isort for import structure checkup with black profile
isort --atomic --profile black -c src/
isort --atomic --profile black -c run.py
isort --atomic --profile black -c tests/
# run mypy
mypy src/
# run bandit - A security linter from OpenStack Security
bandit -r src/
# python static analysis
# prospector --profile=.prospector.yml --path=src --ignore-patterns=static
# prospector --profile=.prospector.yml --path=tests --ignore-patterns=static
This script checks the code against the defined standards. To ensure your code passes these checks, run the following commands before committing:
black --skip-string-normalization --line-length 120 src tests *.py
isort --atomic --profile black src tests *.py
To enforce these rules in a team environment, we'll use a CI/CD pipeline. The pipeline runs these checks, and any failure prevents the pull or merge request from being merged. We will use GitHub Actions for our CI/CD. Create a `.github/workflows/aiohttp.yml` file:
name: UTILITY-SERVER CI

on:
  push:
    branches: [utility]
  pull_request:
    branches: [utility]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      max-parallel: 4
      matrix:
        python-version: [3.13] # [3.7, 3.8, 3.9]
    steps:
      - uses: actions/checkout@v4
      - name: Install system dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y \
            poppler-utils \
            tesseract-ocr \
            libtesseract-dev \
            libglib2.0-0
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install Dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.dev.txt
      - name: Debug Environment
        run: |
          python -c "import sys; print(sys.version)"
          python -c "import platform; print(platform.platform())"
          env | sort
      - name: Run static analysis
        run: ./scripts/static_check.sh
      - name: Run Tests
        env:
          LC_ALL: en_NG.UTF-8
          LANG: en_NG.UTF-8
          LABELS: groceries,school,housing,transportation,gadgets,entertainment,utilities,credit cards,other,dining out,healthcare,insurance,savings,investments,childcare,travel,personal care,debts,charity,taxes,subscriptions,streaming services,home maintenance,shopping,pets,fitness,hobbies,gifts
        run: |
          coverage run --parallel-mode -m unittest discover tests && coverage combine && coverage report -m && coverage html
GitHub Actions uses `.yaml` or `.yml` files to define workflows, similar to `docker-compose.yml`. In this case, we're using the latest Ubuntu distribution as the environment and version 4 of the `actions/checkout` action to check out our repository. We also install system dependencies required by some of the Python packages: `poppler-utils` for `pdf2image`, and `tesseract-ocr` and `libtesseract-dev` for `pytesseract`. Since our project doesn't interact with a database, we don't need a `services` section. The remaining steps are self-explanatory: we execute our bash script to check the codebase against our defined standards, supply environment variables, and run the tests (which we'll write later). This CI/CD pipeline runs on every push or pull request to the utility branch.
Step 3: Writing the tests
The last part of our CI/CD pipeline runs the tests and collects coverage reports. In the Python ecosystem, pytest is an extremely popular testing framework. Though tempting (and we may still adopt it later), we will stick with Python's built-in testing library, unittest, and use coverage to measure our program's test coverage. Let's start with the test setup:
import unittest
import uuid
from datetime import datetime

from aiohttp.test_utils import AioHTTPTestCase

from src.app.app_instance import init_app
from src.utils.websocket import WebSocketManager


class Base:
    """Base class for tests."""

    def create_transaction_dict(
        self,
        date: datetime | str,
        description: str,
        amount: float,
        balance: float,
        type_str='expense',
        include_v: bool = True,
    ) -> dict:
        txn = {
            '_id': str(uuid.uuid4()),
            'date': date.isoformat() if isinstance(date, datetime) else date,
            'createdAt': date.isoformat() if isinstance(date, datetime) else date,
            'updatedAt': date.isoformat() if isinstance(date, datetime) else date,
            'description': description,
            'amount': amount,
            'balance': balance,
            'type': type_str,
            'userId': '1',
        }
        if include_v:
            txn['__v'] = 0
        return txn


# A simple fake WebSocketResponse to simulate aiohttp behavior.
class FakeWebSocket:
    def __init__(self, raise_on_send=False):
        self.messages = []  # will store the JSON messages sent
        self.closed = False
        self.raise_on_send = raise_on_send
        self.close_code = None
        self.close_message = None

    async def send_json(self, data):
        if self.raise_on_send:
            raise Exception('send_json error')
        self.messages.append(data)

    async def close(self, code=None, message=None):
        self.closed = True
        self.close_code = code
        self.close_message = message


class BaseAsyncTestClass(Base, unittest.IsolatedAsyncioTestCase):
    """Base class for async tests."""

    async def asyncSetUp(self):
        # Create a FakeWebSocket for each test.
        self.fake_ws = FakeWebSocket()
        self.websocket_manager = WebSocketManager(self.fake_ws)


class BaseTestClass(Base, unittest.TestCase):
    """Base class for sync tests."""


class BaseAioHTTPTestCase(Base, AioHTTPTestCase):
    """Base class for aiohttp tests."""

    async def get_application(self):
        return init_app()

    async def asyncSetUp(self):
        await super().asyncSetUp()
        # Create a FakeWebSocket for each test.
        self.fake_ws = FakeWebSocket()
        self.websocket_manager = WebSocketManager(self.fake_ws)


if __name__ == '__main__':
    unittest.main()
We simply have classes that provide blueprints for our tests. The `Base` class makes the `create_transaction_dict` method available to all its children, simplifying the creation of transaction data for tests. The `FakeWebSocket` class simulates aiohttp WebSocket behavior, which is essential for unit testing the project's WebSocket utilities. All asynchronous unit tests inherit from `BaseAsyncTestClass`, while synchronous tests inherit from `BaseTestClass`. `BaseAioHTTPTestCase` is used for integration-style tests that involve the aiohttp application; overriding `get_application` is required in this class to return our app's instance.
A unit test focuses on testing a single piece of code (like a function such as `analyze_recurring_transactions`), whereas integration tests examine how multiple units of code interact with each other within a system (like testing the behavior of sending a request to `/ws`).
Let's take an example integration-style test, especially for our WebSocket, and another unit test for some of the subprocesses to balance things out:
import asyncio
import json
from unittest.mock import AsyncMock, patch
from aiohttp import WSMsgType
from src.app.app_instance import WEBSOCKETS
from tests import BaseAioHTTPTestCase
class TestWebSocketHandler(BaseAioHTTPTestCase):
    """Exhaustively test the WebSocket handler."""

    async def setUpAsync(self):
        await super().setUpAsync()
        # Capture the original create_task function.
        self.orig_create_task = asyncio.create_task

    async def __dummy_analyze(self, transactions, ws_manager):
        """Dummy analyze implementation that returns a known result."""
        return {
            'categories': {
                'expenses': {
                    'groceries': 10.0,
                    'rent': 90.0,
                },
                'expense_percentages': {
                    'groceries': 5,
                    'rent': 45,
                },
                'income': 200.0,
            }
        }
    async def __dummy_summarize(self, transactions, ws_manager):
        """Dummy summarize implementation that returns a known result."""
        return {
            'income': {
                'total': 200.00,
                'trend': 'neutral',
                'change': 0.0,
            },
            'expenses': {
                'total': 100.00,
                'trend': 'neutral',
                'change': 0.0,
            },
            'savings': {
                'total': 100.00,
                'trend': 'neutral',
                'change': 0.0,
            },
            'total_transactions': 2,
            'expense_count': 1,
            'income_count': 1,
            'avg_expense': 100.00,
            'avg_income': 200.00,
            'start_date': '2022-01-01',
            'end_date': '2022-01-31',
            'largest_expense': 200.00,
            'largest_income': 200.00,
            'savings_rate': 50.0,
            'monthly_summary': {
                '2022-01': {
                    'income': 200.00,
                    'expenses': 100.00,
                    'savings': 100.00,
                },
            },
            'anomalies': [],
            'spending_analysis': {
                'total_spent': 100.00,
                'total_income': 200.00,
                'savings_rate': 50.0,
                'daily_summary': {
                    '2022-01-01': {
                        'total_spent': 100.00,
                        'total_income': 200.00,
                        'savings_rate': 50.0,
                    },
                },
                'cumulative_balance': {
                    '2022-01-01': 100.00,
                },
            },
            'spending_trends': {
                'total_spent': 100.00,
                'total_income': 200.00,
                'savings_rate': 50.0,
            },
            'recurring_transactions': [],
            'financial_health': {
                'debt_to_income_ratio': 0,
                'savings_rate': 0,
                'balance_growth_rate': 0,
                'financial_health_score': 0,
            },
        }
    def __dummy_create_task(self, coro):
        if hasattr(coro, 'cr_code') and 'ping_server' in coro.cr_code.co_qualname:
            # Explicitly close the ping_server coroutine so it doesn't leak.
            coro.close()
            # Return a dummy, already-completed future.
            fut = asyncio.Future()
            fut.set_result(None)
            return fut
        return self.orig_create_task(coro)

    async def __receive_messages(self, ws, count, timeout=5):
        """Helper to collect 'count' text messages from the WebSocket."""
        messages = []
        while len(messages) < count:
            msg = await ws.receive(timeout=timeout)
            if msg.type == WSMsgType.TEXT:
                messages.append(json.loads(msg.data))
            elif msg.type == WSMsgType.CLOSE:
                break
        return messages
    async def test_analyze_action(self):
        """Test that sending an 'analyze' action yields progress and result messages."""
        self.transactions = [
            self.create_transaction_dict('2022-01-01', 'Transaction 1', -100.0, 100.0),
            self.create_transaction_dict('2022-01-02', 'Transaction 2', 200.0, 300.0),
        ]
        # Patch the analyzer so that it returns a predictable result,
        # and patch create_task with our dummy version.
        with patch(
            'src.app.app_instance.analyze_transactions',
            new=AsyncMock(side_effect=self.__dummy_analyze),
        ), patch("asyncio.create_task", self.__dummy_create_task):
            ws = await self.client.ws_connect('/ws')
            msg_data = {'action': 'analyze', 'transactions': self.transactions}
            await ws.send_str(json.dumps(msg_data))
            # This helps avoid timeout errors when the server is slow to respond.
            messages = await self.__receive_messages(ws, 2)
            self.assertEqual(len(messages), 2)
            progress, result = messages
            # First response: progress message.
            self.assertEqual(progress.get('action'), 'progress')
            self.assertEqual(progress.get('message'), 'Analysis complete')
            self.assertEqual(progress.get('progress'), 1.0)
            self.assertEqual(progress.get('taskType'), 'Analysis')
            # Second response: result message.
            self.assertEqual(result.get('action'), 'analysis_complete')
            self.assertEqual(result.get('taskType'), 'Analysis')
            expected_data = await self.__dummy_analyze(self.transactions, self.websocket_manager)
            self.assertEqual(result.get('result'), expected_data)
            await ws.close()
    async def test_summary_action(self):
        """Test that sending a 'summary' action yields progress and result messages."""
        self.transactions = [
            self.create_transaction_dict('2022-01-01', 'Transaction 1', -100.0, 100.0),
            self.create_transaction_dict('2022-01-02', 'Transaction 2', 200.0, 300.0),
        ]
        # Patch the summarizer so that it returns a predictable result,
        # and patch create_task with our dummy version.
        with patch(
            'src.app.app_instance.summarize_transactions',
            new=AsyncMock(side_effect=self.__dummy_summarize),
        ), patch("asyncio.create_task", self.__dummy_create_task):
            ws = await self.client.ws_connect('/ws')
            msg_data = {'action': 'summary', 'transactions': self.transactions}
            await ws.send_str(json.dumps(msg_data))
            # This helps avoid timeout errors when the server is slow to respond.
            messages = await self.__receive_messages(ws, 2)
            self.assertEqual(len(messages), 2)
            progress, result = messages
            # First response: progress message.
            self.assertEqual(progress.get('action'), 'progress')
            self.assertEqual(progress.get('message'), 'Summary complete')
            self.assertEqual(progress.get('progress'), 1.0)
            self.assertEqual(progress.get('taskType'), 'Summarize')
            # Second response: result message.
            self.assertEqual(result.get('action'), 'summary_complete')
            self.assertEqual(result.get('taskType'), 'Summarize')
            expected_data = await self.__dummy_summarize(self.transactions, self.websocket_manager)
            self.assertEqual(result.get('result'), expected_data)
            await ws.close()
    async def test_unknown_action(self):
        """Test that an unknown action returns an error message."""
        ws = await self.client.ws_connect('/ws')
        msg_data = {'action': 'nonexistent'}
        await ws.send_str(json.dumps(msg_data))
        # This helps avoid timeout errors when the server is slow to respond.
        messages = await self.__receive_messages(ws, 1)
        self.assertEqual(len(messages), 1)
        error = messages[0]
        self.assertEqual(error.get('action'), 'error')
        self.assertEqual(error.get('taskType'), 'Error')
        self.assertEqual(error.get('result'), {'message': 'Unknown action'})
        await ws.close()

    async def test_message_processing_exception(self):
        """Test that sending invalid JSON produces an error message."""
        ws = await self.client.ws_connect('/ws')
        await ws.send_str('invalid json')
        # This helps avoid timeout errors when the server is slow to respond.
        messages = await self.__receive_messages(ws, 1)
        self.assertEqual(len(messages), 1)
        error = messages[0]
        self.assertEqual(error.get('action'), 'error')
        self.assertEqual(error.get('taskType'), 'Error')
        self.assertEqual(error.get('result'), {'error': 'Expecting value: line 1 column 1 (char 0)'})
        await ws.close()

    async def test_close_on_error(self):
        """Test that when a client closes the connection, the WebSocket is removed from the app."""
        ws = await self.client.ws_connect('/ws')
        await ws.send_str('invalid json')
        # This helps avoid timeout errors when the server is slow to respond.
        messages = await self.__receive_messages(ws, 1)
        self.assertEqual(len(messages), 1)
        await ws.close()
        self.assertNotIn(ws, self.app[WEBSOCKETS])
Overlooking the dummy data generators, the `__receive_messages` helper is crucial for accumulating WebSocket messages. Without it, attempting `await ws.receive_json(...)` multiple times could lead to timeout errors, resulting in cryptic tracebacks:
----------------------------------------------------------------------
Traceback (most recent call last):
File ".../utility/virtualenv/lib/python3.13/site-packages/aiohttp/client_ws.py", line 332, in receive
msg = await self._reader.read()
^^^^^^^^^^^^^^^^^^^^^^^^^
File "aiohttp/_websocket/reader_c.py", line 109, in read
File "aiohttp/_websocket/reader_c.py", line 106, in aiohttp._websocket.reader_c.WebSocketDataQueue.read
asyncio.exceptions.CancelledError
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/opt/homebrew/Cellar/python@3.13/3.13.2/Frameworks/Python.framework/Versions/3.13/lib/python3.13/asyncio/runners.py", line 118, in run
return self._loop.run_until_complete(task)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^
File "/opt/homebrew/Cellar/python@3.13/3.13.2/Frameworks/Python.framework/Versions/3.13/lib/python3.13/asyncio/base_events.py", line 725, in run_until_complete
return future.result()
~~~~~~~~~~~~~^^
File ".../utility/tests/app/websocket_handler/test_integration.py", line 114, in __receive_messages
msg = await ws.receive_json(timeout=timeout)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../utility/virtualenv/lib/python3.13/site-packages/aiohttp/client_ws.py", line 331, in receive
async with async_timeout.timeout(receive_timeout):
~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.13/3.13.2/Frameworks/Python.framework/Versions/3.13/lib/python3.13/asyncio/timeouts.py", line 116, in __aexit__
raise TimeoutError from exc_val
TimeoutError
The helper also aids in filtering messages of interest. We also created a dummy version of `ping_server` to properly close it and prevent memory leaks. With the dummy functions in place, we created test cases that interact with our WebSocket endpoint. Using async patches and mocks, we fed predictable responses to the tests. Note that we used our async dummy methods as the `side_effect` of the `AsyncMock`; using `return_value` instead of `side_effect` in the mocks prolonged the processes and caused timeout errors.
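The `side_effect` behavior is easy to reproduce in isolation: when `side_effect` is an async function, the mock awaits it afresh on every call, so each await yields a new result. A small standalone sketch (where `dummy_analyze` is a hypothetical stand-in, not the project's analyzer):

```python
import asyncio
from unittest.mock import AsyncMock


async def dummy_analyze(transactions, ws_manager):
    # Hypothetical stand-in for the real analyzer: a predictable payload.
    return {'categories': {'income': 200.0}}


async def main() -> None:
    # side_effect=<async function>: the mock awaits the function on each
    # call, so repeated awaits each produce a fresh result.
    mock = AsyncMock(side_effect=dummy_analyze)
    first = await mock([], None)
    second = await mock([], None)
    assert first == second == {'categories': {'income': 200.0}}
    assert mock.await_count == 2


asyncio.run(main())
```

A related trap to keep in mind: passing a coroutine *object* (rather than a value or function) as `return_value` hands every caller the same coroutine, which can only be awaited once.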
The other test cases handle various scenarios to provide better test coverage.
When supplying target paths to `patch`, use the path where the object is used, not where it was defined. For instance, `analyze_transactions` was defined in `src/utils/analyzer.py`, but since it is used in `src/app/app_instance.py`, we patched `src.app.app_instance.analyze_transactions`.
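The "patch where it is used" rule is easiest to see with two throwaway modules. Below, `consumer_mod` binds `get_rate` into its own namespace at import time, so patching the defining module afterwards has no effect on it (the module names here are invented purely for the demo):

```python
import sys
import types
from unittest.mock import patch

# A throwaway "defining" module.
defs_mod = types.ModuleType('defs_mod')
exec('def get_rate():\n    return 1.0', defs_mod.__dict__)
sys.modules['defs_mod'] = defs_mod

# A throwaway "consumer" module: `from defs_mod import get_rate`
# copies the function reference into the consumer's own namespace.
consumer_mod = types.ModuleType('consumer_mod')
exec(
    'from defs_mod import get_rate\n'
    'def price(amount):\n'
    '    return amount * get_rate()',
    consumer_mod.__dict__,
)
sys.modules['consumer_mod'] = consumer_mod

# Patching where the function was DEFINED does not touch the consumer's copy...
with patch('defs_mod.get_rate', return_value=2.0):
    assert consumer_mod.price(10) == 10.0

# ...patching where it is USED does.
with patch('consumer_mod.get_rate', return_value=2.0):
    assert consumer_mod.price(10) == 20.0
```

This is exactly why our tests target `src.app.app_instance.analyze_transactions` rather than the `src.utils.analyzer` path.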
However, the integration testing approach poses some limitations: we can't modify the internals of the `aiohttp` WebSocket instance. This is where unit testing comes to the rescue, as we can mock internals as needed to thoroughly test the desired feature. Hence the other test file for our WebSocket, `tests/app/websocket_handler/test_ping.py`.
To wrap up, let's see how we tested `src/utils/analyzer.py`:
import uuid
from datetime import datetime, timedelta
from unittest.mock import patch
import torch
from src.utils.analyzer import analyze_transactions, classify_transactions
from src.utils.base import (
analyze_recurring_transactions,
analyze_spending,
calculate_financial_health,
detect_anomalies,
predict_trends,
validate_and_convert_transactions,
)
from tests import BaseAsyncTestClass
class TestAnalyzer(BaseAsyncTestClass):
    @patch(
        'src.utils.analyzer.pipeline', return_value=lambda *args, **kwargs: [{'labels': ['groceries'], 'scores': [1.0]}]
    )
    async def test_analyze_transactions_valid(self, mock_pipeline):
        tx_data = [
            {
                '_id': str(uuid.uuid4()),
                'date': '2024-01-01T00:00:00',
                'createdAt': '2024-01-01T00:00:00',
                'updatedAt': '2024-01-01T00:00:00',
                'description': 'Test expense',
                'amount': -100,
                'balance': 900,
                'type': 'expense',
                'userId': '1',
            },
            {
                '_id': str(uuid.uuid4()),
                'date': '2024-01-02T00:00:00',
                'createdAt': '2024-01-02T00:00:00',
                'updatedAt': '2024-01-02T00:00:00',
                'description': 'Salary',
                'amount': 2000,
                'balance': 2900,
                'type': 'income',
                'userId': '1',
            },
        ]
        result = await analyze_transactions(tx_data)
        self.assertIn('categories', result)
        ...
    async def test_classify_transactions_pattern_matching(self):
        """
        Test that transactions with descriptions matching common patterns
        are categorized correctly without invoking the ML pipeline.
        """
        # Create dummy transactions that should match predefined patterns
        tx1 = self.create_transaction_dict('2024-01-01T00:00:00', 'Walmart grocery purchase', -50.0, 950.0, 'expense')
        tx2 = self.create_transaction_dict('2024-01-02T00:00:00', 'Uber ride', -20.0, 930.0, 'expense')
        tx3 = self.create_transaction_dict('2024-01-03T00:00:00', 'Netflix subscription', -15.0, 915.0, 'expense')
        tx4 = self.create_transaction_dict('2024-01-04T00:00:00', 'Salary', 3000.0, 3915.0, 'income')
        # We assume pattern matching is applied first
        transactions = await validate_and_convert_transactions([tx1, tx2, tx3, tx4])
        # Call classify_transactions without a WebSocket manager
        result = await classify_transactions(transactions)
        categories = result.get('expenses', {})
        income_total = result.get('income', 0)
        # Check that the descriptions are mapped to expected categories:
        # "Walmart grocery" should fall under 'groceries'
        self.assertIn('groceries', categories)
        self.assertGreater(categories['groceries'], 0)
        # "Uber ride" should fall under 'transportation'
        self.assertIn('transportation', categories)
        self.assertGreater(categories['transportation'], 0)
        # "Netflix subscription" should be captured under 'subscriptions'
        self.assertIn('subscriptions', categories)
        self.assertGreater(categories['subscriptions'], 0)
        # Income should include the salary
        self.assertEqual(income_total, 3000)
    @patch.dict(
        'os.environ',
        {"LABELS": "groceries,housing,transportation,entertainment,utilities,education,credit_cards,insurance,other"},
    )
    @patch('src.utils.analyzer.pipeline')
    async def test_classify_transactions_ml_fallback(self, mock_pipeline):
        # Simulate a transaction with an unmatched description
        tx1 = self.create_transaction_dict(
            '2024-01-05T00:00:00', 'Unusual expense with no pattern', -75.0, 840.0, 'expense'
        )
        # Setup the fake pipeline result
        fake_result = [{'labels': ['other'], 'scores': [0.95]}]
        mock_pipeline.return_value = lambda *args, **kwargs: fake_result
        transactions = await validate_and_convert_transactions([tx1])
        result = await classify_transactions(transactions)
        categories = result.get('expenses', {})
        # Expect that the ML fallback has assigned this expense to 'other'
        self.assertIn('other', categories)
        self.assertAlmostEqual(categories['other'], 75 * 0.95, places=2)
async def test_analyze_recurring_transactions_monthly(self):
"""
Test that transactions with the same description and a roughly monthly interval
are detected as recurring.
"""
base_date = datetime(2024, 1, 1)
# Create 3 monthly transactions (interval ~30 days)
tx1 = self.create_transaction_dict(base_date.isoformat(), 'Gym membership', -50.0, 950.0, 'expense')
tx2 = self.create_transaction_dict(
(base_date + timedelta(days=30)).isoformat(), 'Gym membership', -50.0, 900.0, 'expense'
)
tx3 = self.create_transaction_dict(
(base_date + timedelta(days=60)).isoformat(), 'Gym membership', -50.0, 850.0, 'expense'
)
transactions = await validate_and_convert_transactions([tx1, tx2, tx3])
recurring = analyze_recurring_transactions(transactions)
self.assertTrue(len(recurring) > 0)
monthly_recurring = next((r for r in recurring if r['frequency'] == 'Monthly'), None)
self.assertIsNotNone(monthly_recurring)
self.assertEqual(monthly_recurring['description'], 'gym membership')
async def test_analyze_recurring_transactions_weekly(self):
"""
Test that transactions with the same description and a roughly weekly interval
are detected as recurring.
"""
base_date = datetime(2024, 1, 1)
# Create 3 weekly transactions (interval ~7 days)
tx1 = self.create_transaction_dict(base_date.isoformat(), 'Weekly yoga class', -20.0, 980.0, 'expense')
tx2 = self.create_transaction_dict(
(base_date + timedelta(days=7)).isoformat(), 'Weekly yoga class', -20.0, 960.0, 'expense'
)
tx3 = self.create_transaction_dict(
(base_date + timedelta(days=14)).isoformat(), 'Weekly yoga class', -20.0, 940.0, 'expense'
)
transactions = await validate_and_convert_transactions([tx1, tx2, tx3])
recurring = analyze_recurring_transactions(transactions)
self.assertTrue(len(recurring) > 0)
weekly_recurring = next((r for r in recurring if r['frequency'] == 'Weekly'), None)
self.assertIsNotNone(weekly_recurring)
self.assertEqual(weekly_recurring['description'], 'weekly yoga class')
...
def test_edge_empty_transactions(self):
"""
Ensure that functions gracefully handle an empty list of transactions.
"""
# predict_trends should return a message indicating insufficient data
trends = predict_trends([])
self.assertIn('trend', trends)
self.assertEqual(trends['trend'], 'Not enough data')
# calculate_financial_health on empty list should not crash (might return infinity or 0)
health = calculate_financial_health([])
self.assertIn('debt_to_income_ratio', health)
self.assertIn('savings_rate', health)
self.assertIn('balance_growth_rate', health)
self.assertIn('financial_health_score', health)
async def test_not_transaction_analyzer(self):
"""
Ensure that functions gracefully handle invalid transaction data.
"""
analysis = await analyze_transactions(None, self.websocket_manager)
self.assertIn('error', analysis)
self.assertEqual(analysis['error'], 'No transactions provided')
self.assertTrue(self.fake_ws.messages)
async def test_analyze_transactions_with_websocket(self):
"""Test the analyze_transactions function with a WebSocketManager."""
# Create an invalid transaction
tx_1 = [
{
'_id': str(uuid.uuid4()),
'date': '2024-01-01T00:00:00',
'createdAt': '2024-01-01T00:00:00',
'updatedAt': '2024-01-01T00:00:00',
'description': 'Test expense',
'amount': -100,
'type': 'expense',
'userId': '1',
}
]
result = await analyze_transactions(tx_1, self.websocket_manager)
self.assertIn('error', result)
self.assertEqual(result['error'], 'No valid transactions provided')
# Check that progress messages were sent
self.assertTrue(self.fake_ws.messages)
# Create valid transactions
tx_2 = self.create_transaction_dict('2024-01-01T00:00:00', 'Salary', 2000, 2900, 'income')
result = await analyze_transactions([tx_2], self.websocket_manager)
self.assertIn('categories', result)
# Check that progress messages were sent
self.assertTrue(self.fake_ws.messages)
@patch('src.utils.analyzer.validate_and_convert_transactions')
async def test_analyze_transactions_validation_exception(self, mock_validate):
"""Test analyze_transactions handling when validation fails"""
valid_tx = self.create_transaction_dict('2024-01-01T00:00:00', 'Salary', 2000, 2900, 'income')
mock_validate.side_effect = ValueError('Mock validation error')
result = await analyze_transactions(valid_tx, self.websocket_manager)
self.assertIn('error', result)
msg = self.fake_ws.messages[-1]
self.assertEqual(msg['action'], 'progress')
self.assertEqual(msg['message'], 'Analysis failed')
self.assertEqual(msg['taskType'], 'Analysis')
@patch('src.utils.analyzer.pipeline')
async def test_classify_transactions_exception(self, mock_pipeline):
"""Test classify_transactions handling when pipeline fails"""
tx = self.create_transaction_dict('2024-01-01T00:00:00', 'Test expense', -100, 900, 'expense')
transactions = await validate_and_convert_transactions([tx])
mock_pipeline.side_effect = RuntimeError('Mock pipeline error')
result = await classify_transactions(transactions, self.websocket_manager)
# Check error response
self.assertIn('error', result)
self.assertTrue('Classification failed' in result['error'])
# Check websocket message
msg = self.fake_ws.messages[-1]
self.assertEqual(msg['action'], 'progress')
self.assertEqual(msg['message'], 'Analysis failed')
self.assertEqual(msg['taskType'], 'Analysis')
@patch('src.utils.analyzer.get_device')
@patch('src.utils.analyzer.pipeline')
async def test_classify_transactions_device_cpu(self, mock_pipeline, mock_device):
"""Test that classify_transactions uses CPU device for the pipeline."""
tx = self.create_transaction_dict('2024-01-01T00:00:00', 'Test expense', -100, 900, 'expense')
transactions = await validate_and_convert_transactions([tx])
mock_device.return_value = (torch.device('cpu'), 'CPU')
await classify_transactions(transactions, self.websocket_manager)
mock_pipeline.assert_called_once_with('zero-shot-classification', model='facebook/bart-large-mnli', device=-1)
@patch('src.utils.analyzer.get_device')
@patch('src.utils.analyzer.pipeline')
async def test_classify_transactions_device_gpu(self, mock_pipeline, mock_device):
"""Test that classify_transactions uses GPU device for the pipeline."""
tx = self.create_transaction_dict('2024-01-01T00:00:00', 'Test expense', -100, 900, 'expense')
transactions = await validate_and_convert_transactions([tx])
mock_device.return_value = (torch.device('cuda'), 'GPU')
await classify_transactions(transactions, self.websocket_manager)
mock_pipeline.assert_called_once_with('zero-shot-classification', model='facebook/bart-large-mnli', device=0)
@patch('src.utils.analyzer.get_device')
@patch('src.utils.analyzer.pipeline')
async def test_classify_transactions_device_mps(self, mock_pipeline, mock_device):
"""Test that classify_transactions uses MPS (Apple Metal) device for the pipeline."""
tx = self.create_transaction_dict('2024-01-01T00:00:00', 'Test expense', -100, 900, 'expense')
transactions = await validate_and_convert_transactions([tx])
mock_device.return_value = (torch.device('mps'), 'MPS (Apple Metal)')
await classify_transactions(transactions, self.websocket_manager)
mock_pipeline.assert_called_once_with('zero-shot-classification', model='facebook/bart-large-mnli', device=0)
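The three device tests above mock `get_device` and assert that `classify_transactions` translates the detected backend into the integer index `transformers` pipelines expect: `-1` for CPU and `0` for the first accelerator. As a reference point, here is a minimal sketch of that detection and mapping logic; `get_device_label` and `to_pipeline_device` are hypothetical names for illustration, while the real `get_device` in the service also returns the `torch.device` alongside its label:

```python
def get_device_label() -> str:
    """Hedged sketch of backend detection. Falls back to CPU when torch
    is unavailable or no accelerator is present."""
    try:
        import torch

        if torch.cuda.is_available():
            return 'GPU'
        mps = getattr(torch.backends, 'mps', None)
        if mps is not None and mps.is_available():
            return 'MPS (Apple Metal)'
    except ImportError:
        pass
    return 'CPU'


def to_pipeline_device(label: str) -> int:
    """transformers pipelines take device=-1 for CPU and device=0 for the
    first accelerator (CUDA or MPS) -- exactly what the tests assert."""
    return -1 if label == 'CPU' else 0
```

Mocking `get_device` rather than `torch` itself keeps these tests fast and lets them run identically on machines with or without a GPU.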
This thorough testing gives us confidence in the reliability of our code. The repository's tests
folder contains other test files that rigorously exercise our implementations. At present, the AI service has 100% test coverage, and static analysis is enforced.
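To make the coverage and static-analysis gates concrete, the scripts/test_app.sh runner might look something like the following. This is a hedged sketch, not the repository's exact script; the --fail-under threshold and tool flags are assumptions you would tune to your own setup:

```shell
#!/usr/bin/env bash
# scripts/test_app.sh -- run tests under coverage, then the static analyzers.
set -euo pipefail

# Run the test suite under coverage and fail the build below 100%.
coverage run -m unittest discover -s tests -v
coverage report --fail-under=100

# Static analysis: type checks, security linting, and general code quality.
mypy --config-file mypy.ini src
bandit -r src
prospector src
```

Running all three analyzers in one script keeps CI simple: any type error, security finding, or lint violation fails the pipeline before a deploy can happen.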
We will stop here. In the next article, we will return to implementing the dashboard.
Outro
Enjoyed this article? I'm a Software Engineer, Technical Writer, and Technical Support Engineer actively seeking new opportunities, particularly in areas related to web security, finance, healthcare, and education. If you think my expertise aligns with your team's needs, let's chat! You can find me on LinkedIn and X, or reach me by email.
If you found this article valuable, consider sharing it with your network to help spread the knowledge!