Skip to main content

Advanced Features

Explore advanced FastAPI features including background tasks, WebSockets, middleware, events, custom responses, and performance optimization techniques.

Background Tasks

Basic Background Tasks

from fastapi import BackgroundTasks

def write_log(message: str):
    """Append *message* as one line to ``log.txt`` in the working directory."""
    # Explicit UTF-8 so non-ASCII content (e.g. internationalized emails)
    # doesn't depend on the platform's default locale encoding.
    with open("log.txt", "a", encoding="utf-8") as log:
        log.write(f"{message}\n")

@app.post("/send-notification/")
async def send_notification(email: str, background_tasks: BackgroundTasks):
    """Queue a log entry for *email* and respond without waiting for it."""
    log_line = f"Notification sent to {email}"
    # Runs after the response has been returned to the client.
    background_tasks.add_task(write_log, log_line)
    return {"message": "Notification sent in the background"}

Multiple Background Tasks

def send_email(email: str, message: str):
    """Pretend to send *message* to *email* (a two-second stand-in for SMTP)."""
    simulated_delay_s = 2
    time.sleep(simulated_delay_s)
    print(f"Email sent to {email}: {message}")

def update_database(user_id: int):
    """Pretend to update *user_id*'s record (a one-second stand-in for a DB write)."""
    simulated_delay_s = 1
    time.sleep(simulated_delay_s)
    print(f"Database updated for user {user_id}")

@app.post("/process/")
async def process_request(
    email: str,
    user_id: int,
    background_tasks: BackgroundTasks,
):
    """Accept the request immediately and run the follow-up work afterwards.

    Background tasks run in the order they were added: start email,
    database update, completion email.
    """
    pipeline = (
        (send_email, (email, "Processing started")),
        (update_database, (user_id,)),
        (send_email, (email, "Processing completed")),
    )
    for task, args in pipeline:
        background_tasks.add_task(task, *args)

    return {"message": "Request accepted"}

Background Task with Dependencies

from sqlalchemy.orm import Session

def cleanup_old_records(db: "Session | None" = None):
    """Delete ``Model`` rows created more than 30 days ago.

    When *db* is omitted the task opens its own session and closes it when
    done. Don't pass in a request-scoped session from ``Depends(get_db)``:
    the FastAPI docs say the cleanup of dependencies with ``yield`` may run
    before background tasks do, so that session can already be closed.

    Raises:
        Exception: any database error is re-raised after rolling back.
    """
    owns_session = db is None
    if owns_session:
        db = SessionLocal()
    try:
        cutoff = datetime.now() - timedelta(days=30)
        db.query(Model).filter(Model.created_at < cutoff).delete()
        db.commit()
    except Exception:
        # Leave the session usable and the transaction clean on failure.
        db.rollback()
        raise
    finally:
        if owns_session:
            db.close()


@app.post("/trigger-cleanup/")
async def trigger_cleanup(background_tasks: BackgroundTasks):
    """Schedule the cleanup; the task creates its own database session."""
    background_tasks.add_task(cleanup_old_records)
    return {"message": "Cleanup scheduled"}

WebSockets

Basic WebSocket

from fastapi import WebSocket, WebSocketDisconnect

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Echo every text message back to the client until it disconnects."""
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Message received: {data}")
    except WebSocketDisconnect:
        # The client closed the connection; leave the loop quietly.
        print("Client disconnected")

WebSocket Connection Manager

from typing import List
from fastapi import WebSocket, WebSocketDisconnect

class ConnectionManager:
    """Keep track of open WebSocket connections and fan messages out to them."""

    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake and start tracking *websocket*."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket*; a no-op if it is no longer tracked."""
        # broadcast() may already have dropped a dead connection, so a later
        # disconnect for the same socket must not raise ValueError.
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send *message* to a single connection."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Send *message* to every tracked connection.

        Iterates over a snapshot: other tasks can connect/disconnect while
        this coroutine is suspended in ``send_text``, and mutating the list
        being iterated would skip connections. A connection whose send fails
        is dropped so one dead client cannot abort delivery to the rest.
        """
        for connection in list(self.active_connections):
            try:
                await connection.send_text(message)
            except Exception:
                self.disconnect(connection)

manager = ConnectionManager()

@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: int):
    """Chat room: echo each message to its sender and relay it to everyone."""
    await manager.connect(websocket)
    try:
        while True:
            text = await websocket.receive_text()
            echo = f"You wrote: {text}"
            relay = f"Client #{client_id} says: {text}"
            await manager.send_personal_message(echo, websocket)
            await manager.broadcast(relay)
    except WebSocketDisconnect:
        manager.disconnect(websocket)
        await manager.broadcast(f"Client #{client_id} left the chat")

WebSocket with Authentication

from fastapi import WebSocket, Query, WebSocketException, status

async def get_token(websocket: WebSocket):
    """Return the ``token`` query parameter, rejecting the connection without one.

    Raises:
        WebSocketException: code 1008 (policy violation) when ``token`` is
            missing or empty.
    """
    token = websocket.query_params.get("token")
    if not token:
        # Raising is sufficient: FastAPI closes the socket with this code.
        # Calling websocket.close() first would send a second close.
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
    return token

@app.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    token: str = Depends(get_token),
):
    """Accept the socket only when ``verify_token`` accepts *token*."""
    user = verify_token(token)
    if user:
        await websocket.accept()
        # ... handle WebSocket communication
    else:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)

Middleware

Custom Middleware

from fastapi import Request
import time

@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Add ``X-Process-Time``: seconds spent producing the response."""
    # perf_counter() is monotonic and high-resolution; time.time() follows
    # the wall clock and can jump (NTP/DST adjustments), skewing durations.
    start_time = time.perf_counter()
    response = await call_next(request)
    process_time = time.perf_counter() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    return response

Request Logging Middleware

import logging

logger = logging.getLogger(__name__)

@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log each request's method and URL, then its response status."""
    # %-style args: the message is only formatted if INFO is enabled.
    logger.info("Request: %s %s", request.method, request.url)
    response = await call_next(request)
    logger.info("Response: %s", response.status_code)
    return response

CORS Middleware

from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    # Browser origins allowed to call this API.
    allow_origins=["http://localhost:3000", "https://yourdomain.com"],
    allow_methods=["*"],
    allow_headers=["*"],
    # Allow cookies / Authorization headers on cross-origin requests.
    allow_credentials=True,
    # Response headers readable by front-end JavaScript.
    expose_headers=["X-Total-Count"],
    # Seconds a browser may cache a preflight response.
    max_age=3600,
)

Gzip Compression Middleware

from fastapi.middleware.gzip import GZipMiddleware

# Only compress responses of at least 1000 bytes; smaller ones aren't worth it.
app.add_middleware(GZipMiddleware, minimum_size=1_000)

Trusted Host Middleware

from fastapi.middleware.trustedhost import TrustedHostMiddleware

# Only serve requests whose Host header matches the apex domain or any
# of its subdomains.
trusted = ["yourdomain.com", "*.yourdomain.com"]
app.add_middleware(TrustedHostMiddleware, allowed_hosts=trusted)

Application Events

Startup Events

@app.on_event("startup")
async def startup_event():
    """Run once before the application starts handling requests.

    Typical work here: connect to the database, load ML models and
    initialize caches. ``on_event`` is deprecated in recent FastAPI
    releases in favour of a ``lifespan`` context manager.
    """
    print("Starting up...")

@app.on_event("startup")
async def create_db_tables():
    """Create the tables declared on ``Base.metadata`` at startup."""
    async with engine.begin() as connection:
        # metadata.create_all is synchronous; run_sync bridges it onto the
        # async connection.
        create_all = Base.metadata.create_all
        await connection.run_sync(create_all)

Shutdown Events

@app.on_event("shutdown")
async def shutdown_event():
    """Run once when the application shuts down.

    Typical work here: close database connections, save state and release
    any other resources acquired at startup.
    """
    print("Shutting down...")

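Lifespan Context Manager (Modern Approach — replaces the deprecated `@app.on_event`; once `lifespan` is passed to `FastAPI`, `startup`/`shutdown` event handlers are no longer called)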

from contextlib import asynccontextmanager

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup code goes before ``yield``; shutdown code goes after it."""
    print("Starting up...")  # startup: initialize resources here
    yield  # the application serves requests while suspended here
    print("Shutting down...")  # shutdown: clean up resources here


app = FastAPI(lifespan=lifespan)

Custom Response Classes

JSON Response with Custom Headers

from fastapi.responses import JSONResponse

@app.get("/custom-response/")
def custom_response():
    """Return a JSON body together with custom response headers."""
    return JSONResponse(
        status_code=200,
        content={"message": "Hello World"},
        headers={
            "X-Custom-Header": "Custom Value",
            "Cache-Control": "max-age=3600",
        },
    )

HTML Response

from fastapi.responses import HTMLResponse

@app.get("/html/", response_class=HTMLResponse)
def get_html():
    """Return a small static HTML page (sent as HTML via ``response_class``)."""
    return """
<!DOCTYPE html>
<html>
<head><title>FastAPI</title></head>
<body>
<h1>Hello from FastAPI!</h1>
</body>
</html>
"""

File Response

from fastapi.responses import FileResponse

@app.get("/download/{filename}")
def download_file(filename: str):
    """Send ``./files/<filename>`` as an attachment.

    *filename* comes straight from the URL, so the resolved path must stay
    inside the download directory.

    Raises:
        HTTPException: 404 if the name escapes ``./files`` or is not a file.
    """
    from pathlib import Path

    from fastapi import HTTPException

    base_dir = Path("./files").resolve()
    file_path = (base_dir / filename).resolve()
    # Reject traversal (e.g. "..") and anything that isn't a regular file;
    # without the existence check a missing file surfaces as a 500.
    if base_dir not in file_path.parents or not file_path.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(
        path=file_path,
        filename=filename,
        media_type="application/octet-stream",
    )

Streaming Response

from fastapi.responses import StreamingResponse
import csv
from io import StringIO

@app.get("/export/")
def export_data():
    """Stream a 1000-row CSV download without building it all in memory."""

    def generate():
        # One reusable buffer; each chunk is whatever was written since the
        # last reset (the first chunk carries the header plus row 0).
        buffer = StringIO()
        writer = csv.writer(buffer)
        writer.writerow(["ID", "Name", "Email"])
        for idx in range(1000):
            writer.writerow([idx, f"User {idx}", f"user{idx}@example.com"])
            chunk = buffer.getvalue()
            buffer.seek(0)
            buffer.truncate(0)
            yield chunk

    download_headers = {"Content-Disposition": "attachment; filename=export.csv"}
    return StreamingResponse(generate(), media_type="text/csv", headers=download_headers)

Redirect Response

from fastapi.responses import RedirectResponse

@app.get("/old-url")
def old_url():
    """Permanently (301) redirect the legacy path to ``/new-url``."""
    target = "/new-url"
    return RedirectResponse(target, status_code=301)

Server-Sent Events (SSE)

from fastapi.responses import StreamingResponse
import asyncio

@app.get("/stream/")
async def stream_events():
    """Emit ten Server-Sent Events, one per second."""

    async def event_generator():
        for seq in range(10):
            # SSE frame: a "data:" field terminated by a blank line.
            frame = f"data: Message {seq}\n\n"
            yield frame
            await asyncio.sleep(1)

    return StreamingResponse(event_generator(), media_type="text/event-stream")

Custom Exception Handlers

Global Exception Handler

from fastapi import Request, status
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Return 422 with the validation errors, the received body and a message.

    ``exc.errors()`` and ``exc.body`` may contain values ``json.dumps``
    cannot serialize (e.g. exception objects in an error's ``ctx``, or raw
    bytes when the body failed to parse). Passing them through
    ``jsonable_encoder`` avoids a TypeError that would turn this 422 into a 500.
    """
    from fastapi.encoders import jsonable_encoder

    payload = {
        "detail": exc.errors(),
        "body": exc.body,
        "message": "Validation error occurred",
    }
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content=jsonable_encoder(payload),
    )

class CustomException(Exception):
    """Application error carrying the *name* of whatever misbehaved."""

    def __init__(self, name: str):
        # Populate Exception.args even for keyword construction
        # (CustomException(name=...)); otherwise str(exc) is empty and the
        # exception cannot be unpickled.
        super().__init__(name)
        self.name = name

@app.exception_handler(CustomException)
async def custom_exception_handler(request: Request, exc: CustomException):
    """Turn a CustomException into an HTTP 418 naming the culprit."""
    body = {"message": f"Oops! {exc.name} did something wrong."}
    return JSONResponse(content=body, status_code=418)

GraphQL Integration

Using Strawberry GraphQL

pip install "strawberry-graphql[fastapi]"
import strawberry
from strawberry.fastapi import GraphQLRouter

# Root query type.
@strawberry.type
class Query:
    @strawberry.field
    def hello(self, name: str = "World") -> str:
        # Greets the caller; name defaults to "World".
        greeting = f"Hello {name}"
        return greeting


# Root mutation type.
@strawberry.type
class Mutation:
    @strawberry.mutation
    def create_user(self, name: str, email: str) -> str:
        # Placeholder: a real implementation would persist the user here.
        confirmation = f"User {name} created"
        return confirmation


# Expose the schema under /graphql.
schema = strawberry.Schema(mutation=Mutation, query=Query)
graphql_app = GraphQLRouter(schema)
app.include_router(graphql_app, prefix="/graphql")

Template Rendering

Using Jinja2

pip install jinja2
from fastapi.templating import Jinja2Templates
from fastapi import Request

templates = Jinja2Templates(directory="templates")

@app.get("/page/{item_id}", response_class=HTMLResponse)
async def read_item(request: Request, item_id: int):
    """Render ``templates/item.html`` with the requested item id."""
    # Current Starlette signature: request first. It also puts ``request``
    # into the template context. The older TemplateResponse(name,
    # {"request": request, ...}) form is deprecated.
    return templates.TemplateResponse(
        request=request,
        name="item.html",
        context={"item_id": item_id},
    )

Static Files

from fastapi.staticfiles import StaticFiles

# Serve the local "static" directory under the /static URL prefix; the
# route name "static" can be used to build URLs to these files.
static_app = StaticFiles(directory="static")
app.mount("/static", static_app, name="static")

# Access files at: http://localhost:8000/static/style.css

Request Context and State

Application State

from fastapi import Request

@app.on_event("startup")
async def startup():
    """Create shared resources once and attach them to ``app.state``."""
    state = app.state
    state.db = create_db_connection()
    state.cache = create_cache()

@app.get("/items/")
async def read_items(request: Request):
    """Reach the shared resources through the request's application state."""
    shared = request.app.state
    db, cache = shared.db, shared.cache
    # Use db and cache
    return {"items": []}

Request State

@app.middleware("http")
async def add_request_id(request: Request, call_next):
    """Give each request an id and echo it back in ``X-Request-ID``."""
    state = request.state
    state.request_id = generate_request_id()
    response = await call_next(request)
    # Read back from state after the handler ran, as the handler may see/use it.
    response.headers["X-Request-ID"] = state.request_id
    return response

@app.get("/items/")
async def read_items(request: Request):
    """Return the id that the add_request_id middleware stored on the request."""
    return {"request_id": request.state.request_id}

Custom OpenAPI Schema

Customizing OpenAPI

from fastapi.openapi.utils import get_openapi

def custom_openapi():
    """Build the OpenAPI schema on first use, add a logo, and cache it on the app."""
    if not app.openapi_schema:
        schema = get_openapi(
            title="My Custom API",
            version="2.0.0",
            description="This is a custom OpenAPI schema",
            routes=app.routes,
        )
        # Vendor extension (x-logo) added to the "info" section.
        schema["info"]["x-logo"] = {"url": "https://example.com/logo.png"}
        app.openapi_schema = schema
    return app.openapi_schema


app.openapi = custom_openapi

Concurrency and Async

Mixing Sync and Async

import asyncio

# Async endpoint: awaiting frees the event loop for other requests.
@app.get("/async/")
async def async_endpoint():
    """Wait one second without blocking, then respond."""
    pause_s = 1
    await asyncio.sleep(pause_s)
    return {"message": "Async"}

# Sync endpoint: FastAPI runs plain ``def`` handlers in a worker thread.
@app.get("/sync/")
def sync_endpoint():
    """Block for one second, then respond."""
    pause_s = 1
    time.sleep(pause_s)
    return {"message": "Sync"}

# Mixed: calling blocking sync code from an async endpoint.
@app.get("/mixed/")
async def mixed_endpoint():
    """Offload ``blocking_function`` to a thread so the event loop stays free."""
    return {"result": await asyncio.to_thread(blocking_function)}

Concurrent Requests

import httpx

@app.get("/aggregate/")
async def aggregate_data():
    """Fetch three upstream APIs concurrently and return their JSON in order."""
    sources = (
        "https://api1.example.com/data",
        "https://api2.example.com/data",
        "https://api3.example.com/data",
    )
    async with httpx.AsyncClient() as client:
        # gather() runs the requests concurrently and keeps input order.
        responses = await asyncio.gather(*[client.get(url) for url in sources])

        return {"data": [resp.json() for resp in responses]}

Performance Optimization

Caching with Redis

import redis
import json

redis_client = redis.Redis(host='localhost', port=6379, db=0)

@app.get("/items/{item_id}")
async def read_item(item_id: int, db: Session = Depends(get_db)):
    """Return an item, serving it from Redis when a cached copy exists.

    Raises:
        HTTPException: 404 when no item with *item_id* exists.
    """
    from fastapi import HTTPException

    cache_key = f"item:{item_id}"
    cached = redis_client.get(cache_key)
    if cached:
        return json.loads(cached)

    item = db.query(Item).filter(Item.id == item_id).first()
    if item is None:
        # .first() returns None for a missing row; without this check the
        # .dict() call below raises AttributeError and the client gets a 500.
        raise HTTPException(status_code=404, detail="Item not found")

    # Store in cache (expire after 1 hour).
    # NOTE(review): assumes Item provides .dict(); plain SQLAlchemy models don't.
    # NOTE(review): redis.Redis and this Session are synchronous, so these
    # calls block the event loop inside an ``async def`` endpoint.
    redis_client.setex(cache_key, 3600, json.dumps(item.dict()))

    return item

Connection Pooling

from sqlalchemy import create_engine
from sqlalchemy.pool import QueuePool

engine = create_engine(
    DATABASE_URL,
    poolclass=QueuePool,
    pool_pre_ping=True,   # check connections on checkout; replace stale ones
    pool_recycle=3600,    # recycle connections older than an hour (seconds)
    pool_size=20,         # connections kept open in the pool
    max_overflow=0,       # never open extra connections beyond pool_size
)

Response Caching Headers

from fastapi import Response

@app.get("/cached-data/")
def cached_data(response: Response):
    """Return data with headers that allow shared caching for one hour."""
    response.headers["Cache-Control"] = "public, max-age=3600"
    # An entity-tag must be a double-quoted string (RFC 9110 §8.8.3);
    # an unquoted value is malformed and caches/clients may mishandle it.
    response.headers["ETag"] = '"some-etag-value"'
    return {"data": "This is cached"}

Advanced Dependency Patterns

Dependency with Yield and Finally

async def get_db():
    """Yield a session from ``SessionLocal`` and always close it afterwards."""
    session = SessionLocal()
    try:
        yield session
    finally:
        # Runs whether the consumer finished normally or raised.
        session.close()

async def get_resource():
    """Yield an acquired resource and always release it.

    Errors raised while the resource is in use are logged (with traceback)
    and re-raised unchanged.
    """
    resource = acquire_resource()
    try:
        yield resource
    except Exception as e:
        # logger.exception also records the traceback, which logger.error dropped.
        logger.exception("Error: %s", e)
        raise
    finally:
        release_resource(resource)

Best Practices

  1. Use Background Tasks: For work that doesn't need to finish before the response is sent
  2. WebSocket Manager: Implement proper connection management for WebSockets
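  3. Middleware Order: Each `add_middleware` call wraps the middleware added before it, so the last one added runs first; add `CORSMiddleware` last so it handles CORS (including preflight requests) before other middleware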
  4. Async Operations: Use async for I/O-bound operations
  5. Caching: Cache expensive operations and database queries
  6. Connection Pooling: Configure appropriate pool sizes
  7. Error Handling: Implement global exception handlers
  8. Monitoring: Add request tracking and performance monitoring
  9. Resource Management: Properly close connections and release resources
  10. Testing: Test advanced features thoroughly

These advanced features enable you to build sophisticated, high-performance FastAPI applications that can handle complex real-world requirements.