Advanced Features
Explore advanced FastAPI features including background tasks, WebSockets, middleware, events, custom responses, and performance optimization techniques.
Background Tasks
Basic Background Tasks
from fastapi import BackgroundTasks
def write_log(message: str):
    """Append *message* as one line to ``log.txt`` in the working directory.

    The file is opened with an explicit UTF-8 encoding, so non-ASCII text
    (e.g. internationalized e-mail addresses) is written identically on every
    platform instead of depending on the locale's default encoding.
    """
    with open("log.txt", "a", encoding="utf-8") as log:
        log.write(f"{message}\n")
@app.post("/send-notification/")
async def send_notification(email: str, background_tasks: BackgroundTasks):
    """Queue a log entry for *email* and respond right away.

    FastAPI runs the queued ``write_log`` call after the response is sent.
    """
    note = f"Notification sent to {email}"
    background_tasks.add_task(write_log, note)
    return {"message": "Notification sent in the background"}
Multiple Background Tasks
import time


def send_email(email: str, message: str):
    """Simulate e-mailing *message* to *email*: sleep 2 s, then print.

    Deliberately a plain ``def``: FastAPI runs sync background tasks in its
    threadpool, so ``time.sleep`` here does not block the event loop.
    """
    time.sleep(2)  # stand-in for a slow SMTP round-trip
    print(f"Email sent to {email}: {message}")


def update_database(user_id: int):
    """Simulate a database update for *user_id*: sleep 1 s, then print."""
    time.sleep(1)  # stand-in for a slow write
    print(f"Database updated for user {user_id}")
@app.post("/process/")
async def process_request(email: str, user_id: int, background_tasks: BackgroundTasks):
    """Accept the request and queue three follow-up tasks.

    The tasks run one after another, in the order added, after the response
    has been sent.
    """
    queued = (
        (send_email, (email, "Processing started")),
        (update_database, (user_id,)),
        (send_email, (email, "Processing completed")),
    )
    for func, args in queued:
        background_tasks.add_task(func, *args)
    return {"message": "Request accepted"}
Background Task with Dependencies
from sqlalchemy.orm import Session
from datetime import datetime, timedelta
from typing import Optional


def cleanup_old_records(db: Optional[Session] = None):
    """Delete ``Model`` rows created more than 30 days ago and commit.

    Args:
        db: Session to use. When omitted (the way to call this from a
            background task), a fresh ``SessionLocal()`` session is opened
            here and always closed when the cleanup finishes.
    """
    own_session = db is None
    if own_session:
        db = SessionLocal()
    try:
        cutoff = datetime.now() - timedelta(days=30)
        db.query(Model).filter(Model.created_at < cutoff).delete()
        db.commit()
    finally:
        if own_session:
            db.close()


@app.post("/trigger-cleanup/")
async def trigger_cleanup(background_tasks: BackgroundTasks):
    """Schedule the cleanup to run after the response is sent.

    The request-scoped session from ``Depends(get_db)`` is not handed to the
    task. The dependency's exit code (``db.close()``) may already have run by
    the time a background task executes, so the task opens its own session.
    """
    background_tasks.add_task(cleanup_old_records)
    return {"message": "Cleanup scheduled"}
WebSockets
Basic WebSocket
from fastapi import WebSocket, WebSocketDisconnect


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Echo every text frame back to the client until it disconnects."""
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Message received: {data}")
    except WebSocketDisconnect:
        # receive_text() raises WebSocketDisconnect when the client goes away;
        # the name must be imported above or this clause itself fails.
        print("Client disconnected")
WebSocket Connection Manager
from typing import List
from fastapi import WebSocket, WebSocketDisconnect
class ConnectionManager:
    """Track open WebSocket connections and fan messages out to them."""

    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept *websocket* and start tracking it."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket*; a no-op if it is no longer tracked."""
        # broadcast() may already have dropped this socket, so a second
        # removal must not raise ValueError.
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send *message* to a single connection."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Send *message* to every tracked connection.

        Iterates over a snapshot because other coroutines may connect or
        disconnect while this one is suspended in ``send_text``. A connection
        that fails to send is dropped instead of aborting the broadcast.
        Otherwise one dead peer would raise into the *sender's* handler,
        whose ``except WebSocketDisconnect`` would disconnect the wrong client.
        """
        for connection in list(self.active_connections):
            try:
                await connection.send_text(message)
            except (WebSocketDisconnect, RuntimeError):
                self.disconnect(connection)
manager = ConnectionManager()


@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: int):
    """Chat room: echo each message to its sender, then broadcast it to all."""
    await manager.connect(websocket)
    try:
        while True:
            text = await websocket.receive_text()
            echo = f"You wrote: {text}"
            await manager.send_personal_message(echo, websocket)
            announcement = f"Client #{client_id} says: {text}"
            await manager.broadcast(announcement)
    except WebSocketDisconnect:
        manager.disconnect(websocket)
        farewell = f"Client #{client_id} left the chat"
        await manager.broadcast(farewell)
WebSocket with Authentication
from fastapi import Depends, Query, WebSocket, WebSocketException, status


async def get_token(websocket: WebSocket):
    """Return the ``token`` query parameter, or reject the handshake.

    Raises:
        WebSocketException: policy violation (1008) when no token was sent.
            FastAPI turns this into the close frame itself, so the socket is
            not also closed here (that would attempt a second close).
    """
    token = websocket.query_params.get("token")
    if not token:
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
    return token


# NOTE(review): same path as the basic "/ws" example above; if both are
# registered on one app, only the first matching route is used.
@app.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    token: str = Depends(get_token),
):
    """Authenticated endpoint: the handshake is accepted only for a valid token."""
    user = verify_token(token)
    if not user:
        # Reject through the same path as a missing token.
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
    await websocket.accept()
    # ... handle WebSocket communication
Middleware
Custom Middleware
from fastapi import Request
import time
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Report, in seconds, how long the rest of the stack took (X-Process-Time).

    ``time.perf_counter`` is monotonic and high-resolution, so the value is not
    skewed by wall-clock adjustments the way ``time.time`` differences can be.
    """
    start = time.perf_counter()
    response = await call_next(request)
    response.headers["X-Process-Time"] = str(time.perf_counter() - start)
    return response
Request Logging Middleware
import logging
logger = logging.getLogger(__name__)


@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log each request's method and URL, then the response status code."""
    # %-style arguments defer formatting until a handler actually emits the
    # record, which matters because this runs on every request.
    logger.info("Request: %s %s", request.method, request.url)
    response = await call_next(request)
    logger.info("Response: %s", response.status_code)
    return response
CORS Middleware
from fastapi.middleware.cors import CORSMiddleware
# Origins are listed explicitly: with allow_credentials=True a wildcard origin
# is not usable for credentialed browser requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "https://yourdomain.com"],
    allow_credentials=True,  # permit cookies / Authorization on cross-origin requests
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["X-Total-Count"],  # response headers browser JS may read
    max_age=3600,  # seconds a browser may cache the preflight response
)
Gzip Compression Middleware
from fastapi.middleware.gzip import GZipMiddleware
# Gzip responses of at least 1000 bytes for clients that send
# "Accept-Encoding: gzip"; smaller bodies are sent uncompressed.
app.add_middleware(GZipMiddleware, minimum_size=1000)
Trusted Host Middleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
# Reject requests whose Host header matches none of these patterns.
# "*.yourdomain.com" matches subdomains only, which is why the bare domain is
# listed as well.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["yourdomain.com", "*.yourdomain.com"]
)
Application Events
Startup Events
# NOTE: FastAPI deprecates ``@app.on_event`` in favour of the ``lifespan``
# context manager shown later on this page.
@app.on_event("startup")
async def startup_event():
    """Run once at application startup (placeholder for real setup work)."""
    # Initialize resources
    print("Starting up...")
    # Connect to database
    # Load ML models
    # Initialize cache
# A second "startup" handler: registered handlers run in registration order.
@app.on_event("startup")
async def create_db_tables():
    """Create every table registered on ``Base.metadata`` at startup.

    ``async with engine.begin()`` requires a SQLAlchemy *async* engine, and
    ``create_all`` is synchronous, hence ``run_sync``.
    NOTE(review): the connection-pooling example builds a sync
    ``create_engine``, which would not work here; confirm which engine is meant.
    """
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
Shutdown Events
# NOTE: deprecated style; prefer the ``lifespan`` context manager below.
@app.on_event("shutdown")
async def shutdown_event():
    """Run once when the application shuts down (placeholder for cleanup)."""
    # Clean up resources
    print("Shutting down...")
    # Close database connections
    # Save state
    # Release resources
Lifespan Context Manager (Recommended; replaces the deprecated on_event handlers)
from contextlib import asynccontextmanager
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup code, serve the application, then run shutdown code.

    Code before ``yield`` runs once before the first request. The ``finally``
    block runs on shutdown even when an exception or cancellation is thrown
    into the generator at ``yield``, so resources are always released.
    """
    # Startup
    print("Starting up...")
    # Initialize resources
    try:
        yield
    finally:
        # Shutdown
        print("Shutting down...")
        # Clean up resources


app = FastAPI(lifespan=lifespan)
Custom Response Classes
JSON Response with Custom Headers
from fastapi.responses import JSONResponse
@app.get("/custom-response/")
def custom_response():
    """Return a JSON greeting plus a custom header and a one-hour cache hint."""
    return JSONResponse(
        status_code=200,
        headers={
            "X-Custom-Header": "Custom Value",
            "Cache-Control": "max-age=3600",
        },
        content={"message": "Hello World"},
    )
HTML Response
from fastapi.responses import HTMLResponse
@app.get("/html/", response_class=HTMLResponse)
def get_html():
    """Serve a fixed HTML page; ``response_class`` sets the text/html type."""
    return (
        "\n"
        "<!DOCTYPE html>\n"
        "<html>\n"
        "<head><title>FastAPI</title></head>\n"
        "<body>\n"
        "<h1>Hello from FastAPI!</h1>\n"
        "</body>\n"
        "</html>\n"
    )
File Response
from pathlib import Path

from fastapi import HTTPException
from fastapi.responses import FileResponse


@app.get("/download/{filename}")
def download_file(filename: str):
    """Send ``./files/<filename>`` as an ``application/octet-stream`` download.

    *filename* comes straight from the URL, so the resolved path must stay
    inside ``./files`` and name an existing regular file. This rejects
    ``..``, absolute paths and (on Windows) backslash traversal.

    Raises:
        HTTPException: 404 when the name escapes the directory or no such file
            exists. Both cases get the same status, so probing reveals nothing.
    """
    base_dir = Path("./files").resolve()
    file_path = (base_dir / filename).resolve()
    if base_dir not in file_path.parents or not file_path.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(
        path=file_path,
        filename=filename,
        media_type="application/octet-stream",
    )
Streaming Response
from fastapi.responses import StreamingResponse
import csv
from io import StringIO
@app.get("/export/")
def export_data():
    """Stream a CSV export (header plus 1000 rows) as an attachment.

    Each row is yielded as soon as it is formatted, so the body goes out in
    chunks instead of being built in memory and yielded once. A single
    ``StringIO`` buffer is reused: write a row, take its text, clear it.
    """

    def generate():
        buffer = StringIO()
        writer = csv.writer(buffer)

        def take():
            # Return everything written since the last call and reset.
            chunk = buffer.getvalue()
            buffer.seek(0)
            buffer.truncate(0)
            return chunk

        writer.writerow(["ID", "Name", "Email"])
        yield take()
        for i in range(1000):
            writer.writerow([i, f"User {i}", f"user{i}@example.com"])
            yield take()

    return StreamingResponse(
        generate(),
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=export.csv"},
    )
Redirect Response
from fastapi.responses import RedirectResponse
@app.get("/old-url")
def old_url():
    """Permanently redirect (301) ``/old-url`` to ``/new-url``."""
    target = "/new-url"
    return RedirectResponse(target, status_code=301)
Server-Sent Events (SSE)
from fastapi.responses import StreamingResponse
import asyncio
@app.get("/stream/")
async def stream_events():
    """Emit ten Server-Sent Events, pausing one second after each."""

    async def ticker():
        i = 0
        while i < 10:
            # SSE framing: "data: <payload>" terminated by a blank line.
            yield f"data: Message {i}\n\n"
            await asyncio.sleep(1)
            i += 1

    return StreamingResponse(ticker(), media_type="text/event-stream")
Custom Exception Handlers
Global Exception Handler
from fastapi import Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Return 422 with the validation errors, the offending body and a summary.

    ``jsonable_encoder`` is needed because ``JSONResponse`` cannot serialize
    these values on its own:
    - ``exc.errors()`` can contain raw exception objects (Pydantic v2 puts
      them under ``ctx``).
    - ``exc.body`` may be form data rather than parsed JSON.
    """
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content=jsonable_encoder(
            {
                "detail": exc.errors(),
                "body": exc.body,
                "message": "Validation error occurred",
            }
        ),
    )
class CustomException(Exception):
    """Application error that the handler below turns into an HTTP 418.

    Attributes:
        name: Who/what "did something wrong"; interpolated into the response
            message.
    """

    def __init__(self, name: str):
        # Initialise the base class so args/str()/pickling behave normally.
        super().__init__(name)
        self.name = name
@app.exception_handler(CustomException)
async def custom_exception_handler(request: Request, exc: CustomException):
    """Translate a ``CustomException`` into an HTTP 418 JSON response."""
    message = f"Oops! {exc.name} did something wrong."
    return JSONResponse(content={"message": message}, status_code=418)
GraphQL Integration
Using Strawberry GraphQL
pip install "strawberry-graphql[fastapi]"  # quoted so shells such as zsh don't treat [...] as a glob
import strawberry
from strawberry.fastapi import GraphQLRouter
@strawberry.type
class Query:
    # Root GraphQL query type.

    @strawberry.field
    def hello(self, name: str = "World") -> str:
        # Greets *name*; falls back to "World" when no argument is given.
        greeting = "Hello " + name
        return greeting
@strawberry.type
class Mutation:
    # Root GraphQL mutation type.

    @strawberry.mutation
    def create_user(self, name: str, email: str) -> str:
        # Placeholder: nothing is persisted; only a confirmation is returned.
        return "User " + name + " created"
# Build the executable schema from the root types and mount it at /graphql.
schema = strawberry.Schema(query=Query, mutation=Mutation)
graphql_app = GraphQLRouter(schema)
app.include_router(graphql_app, prefix="/graphql")
Template Rendering
Using Jinja2
pip install jinja2
from fastapi import Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates

templates = Jinja2Templates(directory="templates")


@app.get("/page/{item_id}", response_class=HTMLResponse)
async def read_item(request: Request, item_id: int):
    """Render ``templates/item.html`` with ``item_id`` in its context.

    Uses the keyword form ``TemplateResponse(request=..., name=...,
    context=...)``. Passing ``request`` inside the context dict is the older,
    deprecated convention; the request is still exposed to the template.
    """
    return templates.TemplateResponse(
        request=request,
        name="item.html",
        context={"item_id": item_id},
    )
Static Files
from fastapi.staticfiles import StaticFiles
# Serve the ./static directory under /static. name="static" lets templates
# build URLs with url_for("static", path=...).
app.mount("/static", StaticFiles(directory="static"), name="static")
# Access files at: http://localhost:8000/static/style.css
Request Context and State
Application State
from fastapi import Request
# NOTE: deprecated style; a lifespan handler can set the same app.state values.
@app.on_event("startup")
async def startup():
    """Create shared resources once and store them on ``app.state``."""
    app.state.db = create_db_connection()
    app.state.cache = create_cache()
@app.get("/items/")
async def read_items(request: Request):
    """Show how a handler reaches the objects stored on ``app.state``."""
    state = request.app.state
    db = state.db
    cache = state.cache
    # Use db and cache
    return {"items": []}
Request State
@app.middleware("http")
async def add_request_id(request: Request, call_next):
    """Tag the request with an ID and echo it back in ``X-Request-ID``."""
    rid = generate_request_id()
    request.state.request_id = rid
    response = await call_next(request)
    # Re-read from state, as the original does, in case a handler replaced it.
    response.headers["X-Request-ID"] = request.state.request_id
    return response
@app.get("/items/")
async def read_items(request: Request):
    """Return the ID that ``add_request_id`` stored on ``request.state``."""
    return {"request_id": request.state.request_id}
Custom OpenAPI Schema
Customizing OpenAPI
from fastapi.openapi.utils import get_openapi
def custom_openapi():
    """Build the OpenAPI schema once, add a logo, and cache it on ``app``."""
    if not app.openapi_schema:
        schema = get_openapi(
            title="My Custom API",
            version="2.0.0",
            description="This is a custom OpenAPI schema",
            routes=app.routes,
        )
        # Vendor extension read by ReDoc to show a logo.
        schema["info"]["x-logo"] = {"url": "https://example.com/logo.png"}
        app.openapi_schema = schema
    return app.openapi_schema


# Replace FastAPI's generator so /openapi.json serves the customized schema.
app.openapi = custom_openapi
Concurrency and Async
Mixing Sync and Async
import asyncio
# Async endpoint: awaiting asyncio.sleep yields control to the event loop.
@app.get("/async/")
async def async_endpoint():
    """Wait one second without blocking other requests."""
    await asyncio.sleep(1)
    return dict(message="Async")
# Sync endpoint: FastAPI runs plain ``def`` handlers in a threadpool, so the
# blocking sleep does not stall the event loop.
@app.get("/sync/")
def sync_endpoint():
    time.sleep(1)
    return dict(message="Sync")
# Mixed: calling sync code from async
@app.get("/mixed/")
async def mixed_endpoint():
    """Offload ``blocking_function`` to a worker thread and await its result."""
    return {"result": await asyncio.to_thread(blocking_function)}
Concurrent Requests
import httpx
@app.get("/aggregate/")
async def aggregate_data():
    """Fetch three upstream endpoints concurrently and return their JSON bodies."""
    urls = (
        "https://api1.example.com/data",
        "https://api2.example.com/data",
        "https://api3.example.com/data",
    )
    async with httpx.AsyncClient() as client:
        # gather() runs the requests concurrently and preserves input order.
        responses = await asyncio.gather(*(client.get(url) for url in urls))
        return {"data": [resp.json() for resp in responses]}
Performance Optimization
Caching with Redis
import redis
import json
redis_client = redis.Redis(host='localhost', port=6379, db=0)


@app.get("/items/{item_id}")
def read_item(item_id: int, db: Session = Depends(get_db)):
    """Return item *item_id*, served from Redis when cached.

    Declared with plain ``def`` because both the ``redis`` client and the
    SQLAlchemy session are blocking. FastAPI therefore runs this handler in
    its threadpool instead of stalling the event loop, as ``async def`` would.

    On a miss the row's columns are cached for one hour. Hits and misses
    return the same JSON-ready dict.

    Raises:
        HTTPException: 404 if no item has this id (nothing is cached).
    """
    from fastapi import HTTPException
    from fastapi.encoders import jsonable_encoder
    from sqlalchemy import inspect as sa_inspect

    cache_key = f"item:{item_id}"
    cached = redis_client.get(cache_key)
    if cached:
        return json.loads(cached)

    item = db.query(Item).filter(Item.id == item_id).first()
    if item is None:
        raise HTTPException(status_code=404, detail="Item not found")

    # ORM instances have no .dict(): copy the mapped column attributes, then
    # make the values JSON-safe (datetimes, Decimals, UUIDs, ...).
    payload = jsonable_encoder(
        {attr.key: getattr(item, attr.key) for attr in sa_inspect(item).mapper.column_attrs}
    )
    redis_client.setex(cache_key, 3600, json.dumps(payload))  # expire after 1 hour
    return payload
Connection Pooling
from sqlalchemy import create_engine
from sqlalchemy.pool import QueuePool
# QueuePool is already create_engine's default for most backends; it is
# spelled out here for clarity.
engine = create_engine(
    DATABASE_URL,
    poolclass=QueuePool,
    pool_size=20,  # connections kept open in the pool
    # No connections beyond pool_size: when all 20 are checked out, callers
    # wait (up to pool_timeout) instead of opening extras.
    max_overflow=0,
    pool_pre_ping=True,  # test a connection on checkout; replace it if stale
    pool_recycle=3600,  # replace connections older than 3600 seconds
)
Response Caching Headers
from fastapi import Response
@app.get("/cached-data/")
def cached_data(response: Response):
    """Return cacheable data with a one-hour Cache-Control and a content ETag.

    The ETag is a SHA-256 of the serialized payload, so it changes exactly
    when the data does. It is wrapped in double quotes, as HTTP requires for
    entity tags.
    """
    import hashlib
    import json

    payload = {"data": "This is cached"}
    digest = hashlib.sha256(
        json.dumps(payload, sort_keys=True).encode("utf-8")
    ).hexdigest()
    response.headers["Cache-Control"] = "public, max-age=3600"
    response.headers["ETag"] = f'"{digest}"'
    return payload
Advanced Dependency Patterns
Dependency with Yield and Finally
def get_db():
    """Yield one SQLAlchemy session per request and always close it afterwards.

    A plain (sync) generator: ``SessionLocal()`` and ``db.close()`` are
    blocking calls, and FastAPI runs sync dependencies in its threadpool
    rather than on the event loop.
    NOTE(review): assumes ``SessionLocal`` is a sync ``sessionmaker``, which
    matches the synchronous ``db.query(...)`` usage elsewhere on this page.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
async def get_resource():
    """Yield an acquired resource and release it once the request is done.

    An error raised while the resource is in use is logged with its traceback
    and re-raised, so FastAPI's normal error handling still runs.
    """
    resource = acquire_resource()
    try:
        yield resource
    except Exception as e:
        # logger.exception records the traceback (logger.error did not).
        # Re-raising keeps the dependency from swallowing the error.
        logger.exception("Error: %s", e)
        raise
    finally:
        release_resource(resource)
Best Practices
- Use Background Tasks: For work that doesn't need to finish before the response is sent
- WebSocket Manager: Implement proper connection management for WebSockets
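- Middleware Order: Register middleware deliberately. The last middleware added with `add_middleware` becomes the outermost layer, so add CORSMiddleware last if it must see every request (including preflight requests) before the others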
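- Async Operations: Use async for I/O-bound operations, and keep blocking calls (sync database drivers, sync Redis clients, time.sleep) out of async def handlers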
- Caching: Cache expensive operations and database queries
- Connection Pooling: Configure appropriate pool sizes
- Error Handling: Implement global exception handlers
- Monitoring: Add request tracking and performance monitoring
- Resource Management: Properly close connections and release resources
- Testing: Test advanced features thoroughly
These advanced features enable you to build sophisticated, high-performance FastAPI applications that can handle complex real-world requirements.