Skip to main content

Middleware & Applications

Advanced middleware patterns and custom application architectures for WSGI and ASGI.

WSGI Middleware Patterns

Request/Response Middleware

class RequestResponseMiddleware:
    """WSGI middleware that observes each request and its response.

    ``process_request`` runs before the wrapped application is called.
    ``process_response`` runs once the response iterable is finished, which
    includes three cases:

    * full iteration;
    * early ``close()``, e.g. on a client disconnect;
    * the application raising.

    Timing and logging hooks therefore always fire. Override either hook in
    a subclass to customise behaviour.

    Note: ``__call__`` is a generator. Nothing, not even ``process_request``,
    runs until the server starts iterating the returned object.
    """

    def __init__(self, app):
        self.app = app

    def __call__(self, environ, start_response):
        # Pre-process request
        request_data = self.process_request(environ)

        # Captured response metadata. Body chunks are kept so that
        # ``process_response`` can inspect what was sent.
        response_data = {'status': None, 'headers': None, 'body': []}

        def capture_start_response(status, headers, exc_info=None):
            response_data['status'] = status
            response_data['headers'] = headers
            return start_response(status, headers, exc_info)

        app_iter = None
        try:
            app_iter = self.app(environ, capture_start_response)
            for data in app_iter:
                response_data['body'].append(data)
                yield data
        finally:
            close = getattr(app_iter, 'close', None)
            try:
                # PEP 3333: close() must be called if the iterable provides
                # it, even when iteration is abandoned early.
                if close is not None:
                    close()
            finally:
                # Running post-processing here instead of after the loop makes
                # it happen on early close and on application errors too.
                self.process_response(request_data, response_data)

    def process_request(self, environ):
        """Return request metadata recorded before the application runs."""
        return {
            'method': environ['REQUEST_METHOD'],
            # PATH_INFO may be empty or absent when the app root is requested.
            'path': environ.get('PATH_INFO', ''),
            'timestamp': time.time(),
        }

    def process_response(self, request_data, response_data):
        """Print the request method, path and elapsed wall-clock seconds."""
        duration = time.time() - request_data['timestamp']
        print(f"Request: {request_data['method']} {request_data['path']} - {duration:.3f}s")

Session Middleware

import json
import logging
import math
import time
import uuid
from urllib.parse import parse_qs

class SessionMiddleware:
    """WSGI middleware providing a server-side session keyed by a cookie.

    The session dict is exposed to the application as ``environ['session']``
    and its ID as ``environ['session_id']``. The session is saved, and the
    ``session_id`` cookie (re)issued, when the application calls
    ``start_response``.

    Args:
        app: The wrapped WSGI application.
        secret_key: Stored on the instance but not used by this class; the
            session ID is a random UUID and is not signed.
        session_store: Mapping used to persist sessions. Defaults to a new
            in-process dict. An explicitly passed empty mapping is used
            as-is. Expired entries are never purged here.
        session_timeout: Session lifetime in seconds (default 1 hour).
    """

    def __init__(self, app, secret_key, session_store=None, session_timeout=3600):
        self.app = app
        self.secret_key = secret_key
        # `is None` rather than `or`: an empty dict is falsy, so `or` would
        # silently replace a caller-supplied store with a private one.
        self.session_store = {} if session_store is None else session_store
        self.session_timeout = session_timeout

    def __call__(self, environ, start_response):
        session_id = self.get_session_id(environ)
        # Only honour IDs this server issued that are still live. Adopting an
        # arbitrary client-chosen ID would allow session fixation.
        if not session_id or not self._is_active(session_id):
            session_id = str(uuid.uuid4())

        session_data = self.load_session(session_id)
        environ['session'] = session_data
        environ['session_id'] = session_id

        def session_start_response(status, headers, exc_info=None):
            self.save_session(session_id, session_data)
            cookie = f"session_id={session_id}; HttpOnly; Path=/; Max-Age={self.session_timeout}"
            # Build a new list instead of mutating the application's headers.
            return start_response(status, [*headers, ('Set-Cookie', cookie)], exc_info)

        return self.app(environ, session_start_response)

    def _is_active(self, session_id):
        """Return True if *session_id* is in the store and not yet expired."""
        session = self.session_store.get(session_id)
        return session is not None and session.get('expires', 0) >= time.time()

    def get_session_id(self, environ):
        """Return the ``session_id`` cookie value from the request, or None."""
        cookie_header = environ.get('HTTP_COOKIE', '')
        for cookie in cookie_header.split(';'):
            if '=' in cookie:
                key, value = cookie.strip().split('=', 1)
                if key == 'session_id':
                    return value
        return None

    def load_session(self, session_id):
        """Return the stored data dict for *session_id*, or ``{}`` if missing/expired."""
        session = self.session_store.get(session_id, {})
        if session.get('expires', 0) < time.time():
            return {}
        return session.get('data', {})

    def save_session(self, session_id, data):
        """Store *data* under *session_id* with a fresh expiry time."""
        self.session_store[session_id] = {
            'data': data,
            'expires': time.time() + self.session_timeout,
        }

Rate Limiting Middleware

import time
from collections import defaultdict

class RateLimitMiddleware:
    """WSGI middleware enforcing a sliding 60-second request limit per client.

    Args:
        app: The wrapped WSGI application.
        requests_per_minute: Maximum requests a client may make in any
            60-second window before receiving ``429 Too Many Requests``.
        burst_size: Stored on the instance; not currently enforced.
        trust_proxy_headers: When True (the default, matching the previous
            behaviour), the client IP is taken from ``X-Forwarded-For`` or
            ``X-Real-IP``. Clients control these headers unless a trusted
            reverse proxy overwrites them. If the app is reachable directly,
            set this to False; otherwise a client can evade the limit by
            sending a different forged IP with each request.
    """

    WINDOW = 60  # length of the sliding window, in seconds

    def __init__(self, app, requests_per_minute=60, burst_size=10, trust_proxy_headers=True):
        self.app = app
        self.requests_per_minute = requests_per_minute
        self.burst_size = burst_size
        self.trust_proxy_headers = trust_proxy_headers
        self.clients = defaultdict(list)  # client_ip -> request timestamps
        self.cleanup_interval = 60
        # Use the monotonic clock: wall-clock jumps (NTP, manual changes)
        # would otherwise stretch or collapse the window.
        self.last_cleanup = time.monotonic()

    def __call__(self, environ, start_response):
        client_ip = self.get_client_ip(environ)

        # Periodically drop idle clients so the table does not grow forever.
        if time.monotonic() - self.last_cleanup > self.cleanup_interval:
            self.cleanup_old_entries()

        if not self.is_allowed(client_ip):
            start_response('429 Too Many Requests', [
                ('Content-Type', 'text/plain'),
                ('Retry-After', str(self._retry_after(client_ip))),
            ])
            return [b'Rate limit exceeded']

        self.record_request(client_ip)
        return self.app(environ, start_response)

    def _retry_after(self, client_ip):
        """Whole seconds until the oldest request leaves the window (at least 1)."""
        timestamps = self.clients[client_ip]
        if not timestamps:
            return self.WINDOW
        remaining = self.WINDOW - (time.monotonic() - min(timestamps))
        return max(1, math.ceil(remaining))

    def get_client_ip(self, environ):
        """Return the client address, honouring proxy headers if trusted."""
        if self.trust_proxy_headers:
            forwarded_for = environ.get('HTTP_X_FORWARDED_FOR')
            if forwarded_for:
                return forwarded_for.split(',')[0].strip()

            real_ip = environ.get('HTTP_X_REAL_IP')
            if real_ip:
                return real_ip

        return environ.get('REMOTE_ADDR', '127.0.0.1')

    def is_allowed(self, client_ip):
        """Prune expired timestamps and report whether another request fits."""
        now = time.monotonic()
        timestamps = self.clients[client_ip]
        timestamps[:] = [t for t in timestamps if now - t < self.WINDOW]
        return len(timestamps) < self.requests_per_minute

    def record_request(self, client_ip):
        """Record one request for *client_ip* at the current time."""
        self.clients[client_ip].append(time.monotonic())

    def cleanup_old_entries(self):
        """Prune expired timestamps for all clients; forget idle clients."""
        now = time.monotonic()
        for client_ip in list(self.clients):
            recent = [t for t in self.clients[client_ip] if now - t < self.WINDOW]
            if recent:
                self.clients[client_ip][:] = recent
            else:
                del self.clients[client_ip]
        self.last_cleanup = now

ASGI Middleware Patterns

Request/Response Middleware

class ASGIRequestResponseMiddleware:
    """ASGI middleware that observes each HTTP request and its response.

    Non-HTTP scopes (e.g. ``websocket``, ``lifespan``) are passed straight
    through. For HTTP, ``process_request`` runs before the app and
    ``process_response`` runs afterwards. It also runs when the app raises,
    in which case the exception is re-raised after post-processing.

    ``response_data['body']`` holds the complete response body as ``bytes``,
    so the whole body is buffered in memory.
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope['type'] == 'http':
            await self.handle_http(scope, receive, send)
        else:
            await self.app(scope, receive, send)

    async def handle_http(self, scope, receive, send):
        request_data = await self.process_request(scope, receive)

        response_data = {'status': None, 'headers': None, 'body': b''}
        # Gather chunks in a list and join once at the end: repeated
        # `bytes +=` is quadratic in the number of body messages.
        body_chunks = []

        async def capture_send(message):
            if message['type'] == 'http.response.start':
                response_data['status'] = message['status']
                response_data['headers'] = message.get('headers', [])
            elif message['type'] == 'http.response.body':
                body_chunks.append(message.get('body', b''))

            await send(message)

        try:
            await self.app(scope, receive, capture_send)
        finally:
            response_data['body'] = b''.join(body_chunks)
            await self.process_response(request_data, response_data)

    async def process_request(self, scope, receive):
        """Return request metadata recorded before the app runs."""
        return {
            'method': scope['method'],
            'path': scope['path'],
            'timestamp': time.time(),
        }

    async def process_response(self, request_data, response_data):
        """Print the request method, path and elapsed wall-clock seconds."""
        duration = time.time() - request_data['timestamp']
        print(f"Request: {request_data['method']} {request_data['path']} - {duration:.3f}s")

WebSocket Middleware

class WebSocketMiddleware:
    """ASGI middleware that tracks open WebSocket connections and can
    broadcast a message to all of them.

    ``connections`` holds one identifier per accepted connection. The
    matching server ``send`` callables are kept in ``_senders`` so that
    ``broadcast`` can deliver messages. Other scope types are passed
    through untouched.
    """

    def __init__(self, app):
        self.app = app
        self.connections = set()
        self._senders = {}  # connection_id -> the server's `send` callable

    async def __call__(self, scope, receive, send):
        if scope['type'] == 'websocket':
            await self.handle_websocket(scope, receive, send)
        else:
            await self.app(scope, receive, send)

    async def handle_websocket(self, scope, receive, send):
        # id() is unique among live objects. `send` stays referenced for the
        # whole connection, so the ID cannot be reused while it is tracked.
        connection_id = id(send)

        async def websocket_send(message):
            if message['type'] == 'websocket.accept':
                self.connections.add(connection_id)
                self._senders[connection_id] = send
                print(f"WebSocket connected: {connection_id}")
            elif message['type'] == 'websocket.close':
                self._forget(connection_id)
                print(f"WebSocket disconnected: {connection_id}")

            await send(message)

        try:
            await self.app(scope, receive, websocket_send)
        finally:
            # Covers client-initiated disconnects and app errors as well.
            self._forget(connection_id)

    def _forget(self, connection_id):
        """Stop tracking *connection_id* (no-op if it is not tracked)."""
        self.connections.discard(connection_id)
        self._senders.pop(connection_id, None)

    async def broadcast(self, message):
        """Broadcast *message* to all connected WebSockets.

        ``bytes``/``bytearray`` are sent as binary frames and anything else
        as text. A connection whose ``send`` raises is dropped from tracking
        and the broadcast continues with the remaining connections.
        """
        if isinstance(message, (bytes, bytearray)):
            event = {'type': 'websocket.send', 'bytes': bytes(message)}
        else:
            event = {'type': 'websocket.send', 'text': message}

        for connection_id, sender in list(self._senders.items()):
            try:
                await sender(event)
            except Exception:
                # A failed send means this peer is gone. Forget it rather than
                # let one dead socket abort delivery to everyone else.
                self._forget(connection_id)

Authentication Middleware

import jwt
import json

class JWTAuthMiddleware:
    """ASGI middleware requiring an HS256 JWT bearer token on protected HTTP paths.

    Args:
        app: The wrapped ASGI application.
        secret_key: Key passed to ``jwt.decode`` to verify tokens.
        protected_paths: Path prefixes that require authentication.

    On success the decoded payload is exposed to the app as ``scope['user']``.
    It is set on a copy of the scope, because ASGI middleware must not mutate
    the scope it received. On failure a JSON ``401`` is sent and the app is
    not called.

    NOTE(review): only ``http`` scopes are checked. WebSocket connections to
    protected prefixes currently pass through unauthenticated.
    """

    def __init__(self, app, secret_key, protected_paths=None):
        self.app = app
        self.secret_key = secret_key
        self.protected_paths = protected_paths or []

    async def __call__(self, scope, receive, send):
        if scope['type'] == 'http' and self.is_protected_path(scope['path']):
            # Copy so `authenticate` can attach `user` without leaking it into
            # the server's (or an outer middleware's) scope object.
            scope = dict(scope)
            if not await self.authenticate(scope, receive, send):
                return

        await self.app(scope, receive, send)

    def is_protected_path(self, path):
        """Return True if *path* starts with any configured protected prefix."""
        return any(path.startswith(protected) for protected in self.protected_paths)

    async def authenticate(self, scope, receive, send):
        """Validate the bearer token; set ``scope['user']`` or send a 401.

        Returns True when the request may proceed, False after a 401 was sent.
        """
        headers = dict(scope.get('headers', []))
        # Header values are raw octets. latin-1 decodes any byte sequence,
        # whereas UTF-8 raises on malformed input and would turn a bad header
        # into a server error instead of a 401.
        auth_header = headers.get(b'authorization', b'').decode('latin-1')

        # RFC 7235: the auth-scheme name is case-insensitive.
        scheme, _, token = auth_header.partition(' ')
        token = token.strip()
        if scheme.lower() != 'bearer' or not token:
            await self.unauthorized(send)
            return False

        try:
            payload = jwt.decode(token, self.secret_key, algorithms=['HS256'])
        except jwt.InvalidTokenError:
            await self.unauthorized(send)
            return False

        scope['user'] = payload
        return True

    async def unauthorized(self, send):
        """Send a complete JSON 401 response with a Bearer challenge."""
        await send({
            'type': 'http.response.start',
            'status': 401,
            'headers': [
                [b'content-type', b'application/json'],
                # RFC 6750 section 3: a 401 for bearer auth should carry this challenge.
                [b'www-authenticate', b'Bearer'],
            ],
        })
        await send({
            'type': 'http.response.body',
            'body': json.dumps({'error': 'Unauthorized'}).encode('utf-8'),
        })

Custom Application Architectures

MVC Pattern

class Controller:
    """Routes ``/users`` HTTP requests to model calls and renders views as JSON.

    Routes:
        GET  /users        -> list all users
        GET  /users/<id>   -> fetch one user (404 if the model returns nothing)
        POST /users        -> create a user from a JSON object body (201, or
                              400 if the body is not a JSON object)
    Any other method/path combination gets a JSON 404.
    """

    def __init__(self, model, view):
        self.model = model
        self.view = view

    async def handle_request(self, scope, receive, send):
        method = scope['method']
        path = scope['path']
        user_id = self._user_id_from_path(path)

        if method == 'GET' and path == '/users':
            await self.list_users(scope, receive, send)
        elif method == 'GET' and user_id is not None:
            await self.get_user(scope, receive, send, user_id)
        elif method == 'POST' and path == '/users':
            await self.create_user(scope, receive, send)
        else:
            await self.not_found(scope, receive, send)

    @staticmethod
    def _user_id_from_path(path):
        """Return ``<id>`` for paths shaped exactly ``/users/<id>``, else None."""
        prefix = '/users/'
        if not path.startswith(prefix):
            return None
        user_id = path[len(prefix):]
        # Reject '/users/' and nested paths like '/users/1/posts', which would
        # otherwise reach get_user with a bogus id ('' or 'posts').
        if not user_id or '/' in user_id:
            return None
        return user_id

    async def list_users(self, scope, receive, send):
        users = await self.model.get_all_users()
        response = self.view.render_users_list(users)
        await self.send_response(send, response)

    async def get_user(self, scope, receive, send, user_id):
        user = await self.model.get_user(user_id)
        if user:
            response = self.view.render_user(user)
            await self.send_response(send, response)
        else:
            await self.not_found(scope, receive, send)

    async def create_user(self, scope, receive, send):
        try:
            body = await self.parse_body(receive)
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors:
            # a malformed body is the client's fault, not a server error.
            await self.send_response(send, {'error': 'Invalid JSON body'}, status=400)
            return
        if not isinstance(body, dict):
            await self.send_response(send, {'error': 'Expected a JSON object'}, status=400)
            return

        user = await self.model.create_user(body)
        response = self.view.render_user(user)
        await self.send_response(send, response, status=201)

    async def parse_body(self, receive):
        """Read the whole request body from ASGI ``receive`` and decode it as JSON.

        Consumes ``http.request`` messages until ``more_body`` is false. An
        ``http.disconnect`` stops reading early.

        Raises:
            ValueError: the body is not valid JSON (an empty body included).
        """
        chunks = []
        while True:
            message = await receive()
            if message['type'] == 'http.request':
                chunks.append(message.get('body', b''))
                if not message.get('more_body', False):
                    break
            elif message['type'] == 'http.disconnect':
                break
        return json.loads(b''.join(chunks))

    async def send_response(self, send, response, status=200):
        """Send *response* serialized as JSON with the given status code."""
        await send({
            'type': 'http.response.start',
            'status': status,
            'headers': [[b'content-type', b'application/json']],
        })
        await send({
            'type': 'http.response.body',
            'body': json.dumps(response).encode('utf-8'),
        })

    async def not_found(self, scope, receive, send):
        """Send a JSON 404 response."""
        await send({
            'type': 'http.response.start',
            'status': 404,
            'headers': [[b'content-type', b'application/json']],
        })
        await send({
            'type': 'http.response.body',
            'body': json.dumps({'error': 'Not found'}).encode('utf-8'),
        })

class Model:
    """Data-access layer for the ``users`` table.

    ``db_pool`` must provide an async-context-manager ``acquire()`` that
    yields a connection with awaitable ``fetch`` / ``fetchrow`` methods
    accepting ``$n`` placeholders. This looks like asyncpg's API — confirm.
    Every method acquires its own connection and releases it before
    returning.
    """

    def __init__(self, db_pool):
        self.db_pool = db_pool

    async def get_all_users(self):
        """Return all rows from ``users``."""
        async with self.db_pool.acquire() as connection:
            rows = await connection.fetch("SELECT * FROM users")
        return rows

    async def get_user(self, user_id):
        """Return the ``users`` row whose ``id`` equals *user_id* (driver-defined if absent)."""
        async with self.db_pool.acquire() as connection:
            row = await connection.fetchrow("SELECT * FROM users WHERE id = $1", user_id)
        return row

    async def create_user(self, user_data):
        """Insert a user from ``user_data['name']``/``['email']`` and return the new row."""
        async with self.db_pool.acquire() as connection:
            created = await connection.fetchrow(
                "INSERT INTO users (name, email) VALUES ($1, $2) RETURNING *",
                user_data['name'],
                user_data['email'],
            )
        return created

class View:
    """Turns model rows into plain-dict payloads ready for JSON encoding."""

    def render_users_list(self, users):
        """Return ``{'users': [...]}`` with each row converted via ``dict()``."""
        rendered = list(map(dict, users))
        return {'users': rendered}

    def render_user(self, user):
        """Return ``{'user': {...}}`` with the row converted via ``dict()``."""
        payload = dict(user)
        return {'user': payload}

# Application setup
async def create_app():
    """Build the MVC ASGI application.

    ``create_db_pool`` is not defined in this snippet. It is assumed to be
    provided elsewhere and to return a pool usable by ``Model`` — TODO
    confirm.

    The returned app answers ASGI ``lifespan`` startup/shutdown events
    itself. It returns immediately for any other non-HTTP scope. Only
    ``http`` scopes reach the controller, which reads ``scope['method']``
    and would raise ``KeyError`` on the other scope types.
    """
    db_pool = await create_db_pool()
    model = Model(db_pool)
    view = View()
    controller = Controller(model, view)

    async def app(scope, receive, send):
        if scope['type'] == 'lifespan':
            while True:
                message = await receive()
                if message['type'] == 'lifespan.startup':
                    await send({'type': 'lifespan.startup.complete'})
                elif message['type'] == 'lifespan.shutdown':
                    # NOTE(review): release db_pool here if its API supports
                    # it (e.g. `await db_pool.close()`) — confirm.
                    await send({'type': 'lifespan.shutdown.complete'})
                    return
        if scope['type'] != 'http':
            return
        await controller.handle_request(scope, receive, send)

    return app

Plugin Architecture

class PluginManager:
    """Registry of named plugins and the async hook callbacks they install."""

    def __init__(self):
        self.plugins = {}
        self.hooks = defaultdict(list)  # hook name -> callbacks, in registration order

    def register_plugin(self, name, plugin):
        """Store *plugin* under *name*, then let it attach its hooks to this manager."""
        self.plugins[name] = plugin
        plugin.register_hooks(self)

    def register_hook(self, hook_name, callback):
        """Add *callback* to the list run for *hook_name*."""
        self.hooks[hook_name].append(callback)

    async def call_hook(self, hook_name, *args, **kwargs):
        """Await every callback for *hook_name* sequentially, in registration order.

        Returns the list of callback results, or ``[]`` when nothing is
        registered. An exception from a callback propagates and stops the
        remaining callbacks.
        """
        return [await callback(*args, **kwargs) for callback in self.hooks[hook_name]]

class Plugin:
    """Base class for plugins handled by ``PluginManager``.

    Subclasses override ``register_hooks`` to attach their callbacks.
    """

    def register_hooks(self, plugin_manager):
        """Attach this plugin's callbacks to *plugin_manager*; must be overridden."""
        raise NotImplementedError()

class LoggingPlugin(Plugin):
    """Prints one line when a request starts and one when it finishes."""

    def register_hooks(self, plugin_manager):
        hook_table = (
            ('request_started', self.log_request),
            ('request_finished', self.log_response),
        )
        for hook_name, callback in hook_table:
            plugin_manager.register_hook(hook_name, callback)

    async def log_request(self, scope, receive, send):
        """``request_started`` hook: print method and path (*receive*/*send* unused)."""
        method, path = scope['method'], scope['path']
        print(f"Request started: {method} {path}")

    async def log_response(self, scope, response_data):
        """``request_finished`` hook: print method, path and ``response_data['status']``."""
        method, path = scope['method'], scope['path']
        status = response_data['status']
        print(f"Request finished: {method} {path} - {status}")

class CachePlugin(Plugin):
    """Stores the response data of GET requests in an in-memory dict.

    Keys are ``"GET:<path>"``, with ``"?<query string>"`` appended when the
    request has a query string. This class neither reads entries back nor
    evicts them, so the dict grows without bound. NOTE(review): add
    eviction before relying on this outside a demo.
    """

    def __init__(self):
        self.cache = {}

    def register_hooks(self, plugin_manager):
        plugin_manager.register_hook('before_response', self.cache_response)

    async def cache_response(self, scope, response_data):
        """``before_response`` hook: store *response_data* for GET requests."""
        if scope['method'] != 'GET':
            return
        cache_key = f"{scope['method']}:{scope['path']}"
        # Include the query string: /search?q=a and /search?q=b are distinct
        # resources and must not overwrite each other's entry. ASGI delivers
        # it as raw bytes; latin-1 maps every byte losslessly.
        query_string = scope.get('query_string', b'')
        if query_string:
            cache_key += '?' + query_string.decode('latin-1')
        self.cache[cache_key] = response_data

class PluginApplication:
    """ASGI application whose request lifecycle is extended by plugins.

    For each HTTP request it does the following, in order:

    1. Runs the ``request_started`` hooks.
    2. Calls ``handle_request``.
    3. Runs ``before_response`` and ``request_finished`` with the returned
       response data.

    ``lifespan`` events are acknowledged directly. Other non-HTTP scopes are
    ignored, because the hooks and ``handle_request`` assume an HTTP
    request: ``LoggingPlugin`` reads ``scope['method']``.
    """

    def __init__(self):
        self.plugin_manager = PluginManager()
        self.setup_plugins()

    def setup_plugins(self):
        self.plugin_manager.register_plugin('logging', LoggingPlugin())
        self.plugin_manager.register_plugin('cache', CachePlugin())

    async def __call__(self, scope, receive, send):
        if scope['type'] == 'lifespan':
            await self._handle_lifespan(receive, send)
            return
        if scope['type'] != 'http':
            return

        await self.plugin_manager.call_hook('request_started', scope, receive, send)

        response_data = await self.handle_request(scope, receive, send)

        await self.plugin_manager.call_hook('before_response', scope, response_data)
        await self.plugin_manager.call_hook('request_finished', scope, response_data)

    async def _handle_lifespan(self, receive, send):
        """Acknowledge ASGI lifespan startup and shutdown events."""
        while True:
            message = await receive()
            if message['type'] == 'lifespan.startup':
                await send({'type': 'lifespan.startup.complete'})
            elif message['type'] == 'lifespan.shutdown':
                await send({'type': 'lifespan.shutdown.complete'})
                return

    async def handle_request(self, scope, receive, send):
        """Send a fixed plain-text 'Hello World' response and return its metadata."""
        await send({
            'type': 'http.response.start',
            'status': 200,
            'headers': [[b'content-type', b'text/plain']],
        })
        await send({
            'type': 'http.response.body',
            'body': b'Hello World',
        })
        return {'status': 200}

Event-Driven Architecture

import asyncio
from typing import Callable, Any

class EventBus:
    """In-process publish/subscribe dispatcher for async event handlers."""

    def __init__(self):
        self.subscribers = defaultdict(list)  # event type -> handlers

    def subscribe(self, event_type: str, handler: Callable):
        """Register async *handler* to be awaited with each *event_type* payload."""
        self.subscribers[event_type].append(handler)

    def unsubscribe(self, event_type: str, handler: Callable):
        """Remove *handler* from *event_type*.

        Raises:
            ValueError: *handler* is not subscribed to *event_type*.
        """
        self.subscribers[event_type].remove(handler)

    async def publish(self, event_type: str, data: Any):
        """Run all handlers for *event_type* concurrently and wait for them.

        A failing handler does not stop the others, and its exception is not
        raised to the publisher. It is logged instead of being silently
        discarded, as it was before.
        """
        # `.get` avoids creating an empty entry for event types nobody uses.
        handlers = list(self.subscribers.get(event_type, ()))
        if not handlers:
            return

        results = await asyncio.gather(
            *(handler(data) for handler in handlers), return_exceptions=True
        )
        logger = logging.getLogger(__name__)
        for handler, result in zip(handlers, results):
            if isinstance(result, BaseException):
                logger.error("Handler %r for event %r failed", handler, event_type,
                             exc_info=result)

class Event:
    """A named event with a payload and the time it occurred.

    Args:
        type: Event name, e.g. ``'user.created'``.
        data: Arbitrary payload.
        timestamp: Seconds since the epoch. Defaults to ``time.time()`` at
            construction. An explicit ``0.0`` is kept as given.
    """

    def __init__(self, type: str, data: Any, timestamp: "float | None" = None):
        self.type = type
        self.data = data
        # `is None`, not `or`: 0.0 is a valid but falsy timestamp.
        self.timestamp = time.time() if timestamp is None else timestamp

class EventDrivenApp:
    """Example application that reacts to domain events via an ``EventBus``.

    ``create_user`` and ``place_order`` persist a record through
    ``save_user`` / ``save_order`` and then publish ``user.created`` /
    ``order.placed``. The persistence methods are stubs that raise
    ``NotImplementedError``; override them in a subclass. The built-in
    handlers read ``'email'`` and ``'id'`` from the saved user and ``'id'``
    from the saved order.
    """

    def __init__(self):
        self.event_bus = EventBus()
        self.setup_event_handlers()

    def setup_event_handlers(self):
        self.event_bus.subscribe('user.created', self.send_welcome_email)
        self.event_bus.subscribe('user.created', self.update_statistics)
        self.event_bus.subscribe('order.placed', self.process_payment)
        self.event_bus.subscribe('order.placed', self.update_inventory)

    async def send_welcome_email(self, user_data):
        print(f"Sending welcome email to {user_data['email']}")
        # Email sending logic here

    async def update_statistics(self, user_data):
        print(f"Updating user statistics for {user_data['id']}")
        # Statistics update logic here

    async def process_payment(self, order_data):
        print(f"Processing payment for order {order_data['id']}")
        # Payment processing logic here

    async def update_inventory(self, order_data):
        print(f"Updating inventory for order {order_data['id']}")
        # Inventory update logic here

    async def save_user(self, user_data):
        """Persist a new user and return the stored record (override in a subclass)."""
        # Previously this method did not exist, so create_user failed with an
        # opaque AttributeError. An explicit stub documents the contract.
        raise NotImplementedError("save_user must be implemented by a subclass")

    async def save_order(self, order_data):
        """Persist a new order and return the stored record (override in a subclass)."""
        raise NotImplementedError("save_order must be implemented by a subclass")

    async def create_user(self, user_data):
        """Save a user, publish ``user.created`` with the saved record, and return it."""
        user = await self.save_user(user_data)
        await self.event_bus.publish('user.created', user)
        return user

    async def place_order(self, order_data):
        """Save an order, publish ``order.placed`` with the saved record, and return it."""
        order = await self.save_order(order_data)
        await self.event_bus.publish('order.placed', order)
        return order

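These examples cover common middleware patterns and application architectures for both WSGI and ASGI. They are deliberately simplified; for example, the session and rate-limit state lives in process memory by default. Treat them as starting points, and add persistent storage, error handling and a security review before using them in production.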