Skip to main content

Advanced Features

Middleware

Custom Middleware

from fastapi import Request, Response
import time

@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Measure how long each HTTP request takes and expose it as a header.

    Args:
        request: The incoming request.
        call_next: Callable that forwards ``request`` down the middleware
            stack to the route handler and returns the resulting response.

    Returns:
        The downstream response with an ``X-Process-Time`` header added
        (elapsed seconds, as a string).
    """
    # perf_counter() is monotonic and high-resolution. time.time() follows the
    # wall clock, which can jump (NTP sync, manual changes) and produce
    # negative or wildly wrong durations.
    start_time = time.perf_counter()
    response = await call_next(request)
    process_time = time.perf_counter() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    return response

CORS Middleware

from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    # Any origin may call the API without credentials.
    allow_origins=["*"],
    # Credentials (cookies, Authorization headers) must not be combined with a
    # wildcard origin. The CORS spec forbids it, and Starlette then reflects
    # the caller's Origin back with credentials allowed, which would let any
    # website make authenticated requests on a logged-in user's behalf.
    # To allow credentials, replace "*" above with an explicit list of trusted
    # origins and set this to True.
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

WebSockets

Basic WebSocket

from fastapi import WebSocket

from fastapi import WebSocketDisconnect


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Echo every text frame received on ``/ws`` back to the same client.

    The loop ends when the client hangs up: ``receive_text`` then raises
    ``WebSocketDisconnect``. It is caught here so a normal disconnect does not
    escape the endpoint as an unhandled exception.
    """
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Message text was: {data}")
    except WebSocketDisconnect:
        # Client went away; there is nothing left to send.
        pass

WebSocket Manager

from typing import List


class ConnectionManager:
    """Track open WebSocket connections and send messages to one or all of them."""

    def __init__(self):
        # Annotations on attribute targets are evaluated at runtime, so
        # ``List`` must be in scope here or construction raises NameError.
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake and start tracking ``websocket``."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking ``websocket``; does nothing if it is not tracked."""
        # Tolerate repeated clean-up (e.g. two error paths both disconnecting)
        # instead of raising ValueError from list.remove.
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send ``message`` as a text frame to a single connection."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Send ``message`` as a text frame to every tracked connection."""
        # Iterate over a snapshot. Each ``await`` yields to the event loop,
        # where other handlers may connect/disconnect and mutate the list;
        # iterating the live list could then skip entries or pick up new ones.
        for connection in list(self.active_connections):
            await connection.send_text(message)

from fastapi import WebSocketDisconnect

# Module-level, so every connection to the chat endpoint shares one registry.
manager = ConnectionManager()


@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: int):
    """Chat endpoint: echo each message to its sender, then broadcast it to all.

    Args:
        websocket: The client connection.
        client_id: Numeric id taken from the URL path; used only to label
            broadcast messages.
    """
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.send_personal_message(f"You wrote: {data}", websocket)
            await manager.broadcast(f"Client #{client_id} says: {data}")
    except WebSocketDisconnect:
        # Without the import above, this clause itself raised NameError and the
        # dead socket was never unregistered. Unregister first so the farewell
        # is not sent to the connection that just closed.
        manager.disconnect(websocket)
        await manager.broadcast(f"Client #{client_id} left the chat")

Background Tasks

Basic Background Tasks

from fastapi import BackgroundTasks

def write_notification(email: str, message="", *, log_path="log.txt"):
    """Append a one-line notification record for ``email`` to a log file.

    Args:
        email: Recipient address, included in the record.
        message: Free-form notification text.
        log_path: File to append to. Defaults to ``log.txt`` in the current
            working directory.
    """
    # Append instead of "w" so earlier notifications are not overwritten, and
    # terminate each record with a newline so entries stay on separate lines.
    with open(log_path, mode="a", encoding="utf-8") as email_file:
        content = f"notification for {email}: {message}\n"
        email_file.write(content)

@app.post("/send-notification/{email}")
async def send_notification(email: str, background_tasks: BackgroundTasks):
    """Queue a notification for ``email`` and reply without waiting for it.

    ``write_notification`` is registered on ``background_tasks`` and is run
    by FastAPI once this handler has produced its response.
    """
    background_tasks.add_task(
        write_notification,
        email,
        message="some notification",
    )
    payload = {"message": "Notification sent in the background"}
    return payload

Advanced Background Tasks

import asyncio
from typing import Dict, Any

class TaskManager:
    """Track ``asyncio`` tasks under caller-chosen string ids.

    Finished tasks stay in ``tasks`` so their status can still be queried;
    nothing in this class removes entries.
    """

    def __init__(self):
        # task id -> asyncio.Task (running or finished)
        self.tasks: Dict[str, asyncio.Task] = {}

    async def create_task(self, task_id: str, coro, *args, **kwargs):
        """Schedule ``coro(*args, **kwargs)`` on the running loop under ``task_id``.

        ``coro`` is a coroutine *function*; it is called here to create the
        coroutine that gets wrapped in a task. Reusing an id replaces the
        stored task without cancelling the previous one.

        Returns:
            ``task_id``.
        """
        task = asyncio.create_task(coro(*args, **kwargs))
        self.tasks[task_id] = task
        return task_id

    async def get_task_status(self, task_id: str):
        """Return ``{"id", "done", "cancelled"}`` for ``task_id``, or ``None`` if unknown."""
        task = self.tasks.get(task_id)
        if task is None:
            return None
        return {
            "id": task_id,
            "done": task.done(),
            "cancelled": task.cancelled(),
        }

    async def cancel_task(self, task_id: str):
        """Request cancellation of the task stored under ``task_id``.

        Returns:
            ``True`` if a cancellation request was delivered; ``False`` if the
            id is unknown or the task had already finished.
        """
        task = self.tasks.get(task_id)
        if task is None:
            return False
        # Task.cancel() returns False when the task is already done (the call
        # is then a no-op), so report that instead of claiming success.
        return task.cancel()


task_manager = TaskManager()

Dependency Injection

Advanced Dependencies

from fastapi import Depends, HTTPException
from typing import Optional

class CommonQueryParams:
    """Bundle of the shared ``q`` / ``skip`` / ``limit`` query parameters.

    Used as a class-based dependency: the ``__init__`` parameters (names and
    defaults) are what the endpoint accepts from the query string.
    """

    def __init__(self, q: Optional[str] = None, skip: int = 0, limit: int = 100):
        self.q, self.skip, self.limit = q, skip, limit

@app.get("/items/")
async def read_items(commons: CommonQueryParams = Depends(CommonQueryParams)):
    """Echo back the shared query parameters.

    ``commons`` is the ``CommonQueryParams`` instance resolved by the
    dependency system from the request's query string.
    """
    return commons

# Or as a function
def common_parameters(q: Optional[str] = None, skip: int = 0, limit: int = 100):
    """Collect the shared list-endpoint query parameters into a dict.

    Returns:
        ``{"q": q, "skip": skip, "limit": limit}``.
    """
    return dict(q=q, skip=skip, limit=limit)

@app.get("/items/")
async def read_items(commons: dict = Depends(common_parameters)):
    """Echo back the shared query parameters as produced by ``common_parameters``."""
    return commons

Sub-dependencies

def query_extractor(q: Optional[str] = None):
    """Return the optional ``q`` query parameter unchanged (``None`` if absent)."""
    extracted = q
    return extracted

from fastapi import Cookie


def query_or_cookie_extractor(
    q: Optional[str] = Depends(query_extractor),
    last_query: Optional[str] = Cookie(None),
):
    """Prefer the ``q`` query parameter, falling back to the ``last_query`` cookie.

    ``Cookie`` is imported above because default values are evaluated when
    this ``def`` runs; without the import, defining the function raised
    NameError.

    Args:
        q: Value produced by ``query_extractor``; may be ``None``.
        last_query: The ``last_query`` cookie, or ``None`` if not sent.

    Returns:
        ``q`` when it is truthy, otherwise ``last_query`` (which may itself be
        ``None``). An empty ``q`` also falls back to the cookie.
    """
    if not q:
        return last_query
    return q

@app.get("/items/")
async def read_query(query_or_default: str = Depends(query_or_cookie_extractor)):
    """Report ``q``, or the ``last_query`` cookie when ``q`` is missing or empty."""
    response = {"q_or_cookie": query_or_default}
    return response

Custom Response Classes

Custom JSON Response

from fastapi.responses import JSONResponse
import orjson

class ORJSONResponse(JSONResponse):
    """``JSONResponse`` variant whose body is serialised with ``orjson``."""

    media_type = "application/json"

    def render(self, content) -> bytes:
        """Encode ``content`` to JSON bytes via ``orjson.dumps``."""
        rendered = orjson.dumps(content)
        return rendered

@app.get("/items/", response_class=ORJSONResponse)
async def read_items():
    """Return a fixed item list, serialised by ``ORJSONResponse``."""
    items = [{"item_id": "Foo"}]
    return items

Custom HTML Response

from fastapi.responses import HTMLResponse

import html


@app.get("/items/{item_id}", response_class=HTMLResponse)
async def read_item(item_id: str):
    """Render a minimal HTML page for ``item_id``.

    ``item_id`` comes straight from the URL path, so it is HTML-escaped before
    interpolation. Otherwise a request such as a percent-encoded
    ``/items/<script>...</script>`` would inject markup into the page
    (reflected XSS).
    """
    safe_id = html.escape(item_id)
    return f"""
    <html>
        <head>
            <title>Item {safe_id}</title>
        </head>
        <body>
            <h1>Item: {safe_id}</h1>
        </body>
    </html>
    """

GraphQL Integration

GraphQL with Strawberry

import strawberry
from strawberry.fastapi import GraphQLRouter

@strawberry.type
class User:
    # GraphQL object type declared with @strawberry.type; the annotated
    # attributes below are its fields. Kept unchanged because these
    # annotations are what the schema is built from.
    id: int
    name: str
    email: str

@strawberry.type
class Query:
    # Root query type; its single field resolves a User by id.

    @strawberry.field
    def user(self, id: int) -> User:
        # Stub resolver: every requested id gets the same hard-coded profile,
        # with only the id filled in from the argument.
        profile = {"name": "John", "email": "john@example.com"}
        return User(id=id, **profile)

# Build the GraphQL schema from the root Query type.
schema = strawberry.Schema(query=Query)
# Router from strawberry.fastapi that serves the schema.
graphql_app = GraphQLRouter(schema)

# Mount the GraphQL endpoint on the app under /graphql.
app.include_router(graphql_app, prefix="/graphql")

Event Handlers

Startup and Shutdown Events

@app.on_event("startup")
async def startup_event():
    """Hook run when the application starts.

    Placeholder: initialise database connections, caches, etc. here.
    NOTE(review): recent FastAPI releases deprecate ``on_event`` in favour of
    a ``lifespan`` handler passed to ``FastAPI(...)`` - consider migrating.
    """
    banner = "Application is starting up..."
    print(banner)

@app.on_event("shutdown")
async def shutdown_event():
    """Hook run when the application shuts down.

    Placeholder: close database connections and release resources here.
    NOTE(review): recent FastAPI releases deprecate ``on_event`` in favour of
    a ``lifespan`` handler passed to ``FastAPI(...)`` - consider migrating.
    """
    banner = "Application is shutting down..."
    print(banner)

Custom Exception Handlers

Global Exception Handler

from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse

class CustomException(Exception):
    """Application error carrying the ``name`` of whatever misbehaved.

    ``custom_exception_handler`` turns it into an HTTP 418 JSON response.
    """

    def __init__(self, name: str):
        # Initialise the base class so ``args``, ``str(exc)`` and tracebacks
        # show the name instead of being empty.
        super().__init__(name)
        self.name = name

@app.exception_handler(CustomException)
async def custom_exception_handler(request: Request, exc: CustomException):
    """Convert a ``CustomException`` into a 418 JSON error response."""
    body = {"message": f"Oops! {exc.name} did something wrong"}
    return JSONResponse(content=body, status_code=418)

@app.exception_handler(500)
async def internal_error_handler(request: Request, exc):
    """Reply to server errors with a generic 500 JSON body.

    The exception itself is not included in the response.
    """
    generic_body = {"message": "Internal server error"}
    return JSONResponse(content=generic_body, status_code=500)

Request Validation

Custom Validators

from pydantic import BaseModel, validator, root_validator

class UserModel(BaseModel):
    """User payload with name, e-mail and age checks (v1-style validator API)."""

    name: str
    email: str
    age: int

    @validator('name')
    def name_must_contain_space(cls, v):
        """Require at least one space in ``name``; return it title-cased."""
        if ' ' not in v:
            raise ValueError('Name must contain a space')
        return v.title()

    @validator('email')
    def email_must_be_valid(cls, v):
        """Minimal sanity check: ``email`` must contain an ``@``."""
        if '@' not in v:
            raise ValueError('Invalid email')
        return v

    # skip_on_failure=True: run only after field validation succeeded.
    # Pydantic v2 rejects a bare post-mode ``@root_validator`` without it, and
    # v1 accepts the same keyword, so this form works on both.
    @root_validator(skip_on_failure=True)
    def validate_age_and_name(cls, values):
        """Reject users younger than 13 who supplied a name."""
        age = values.get('age')
        name = values.get('name')
        # ``is not None`` rather than truthiness, so age 0 counts as under 13
        # instead of silently passing the check.
        if age is not None and age < 13 and name:
            raise ValueError('Children must have parental consent')
        return values

File Handling

File Upload with Validation

from fastapi import File, UploadFile, HTTPException
import shutil
from pathlib import Path

ALLOWED_EXTENSIONS = {'txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif'}
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB


def validate_file(file: UploadFile):
    """Reject uploads with a missing name, a disallowed extension, or too large a size.

    Args:
        file: The uploaded file; only ``filename`` and ``size`` are inspected.

    Raises:
        HTTPException: 400 if the filename is missing, its extension is not in
            ``ALLOWED_EXTENSIONS``, or the reported size exceeds
            ``MAX_FILE_SIZE``.
    """
    # Previously a missing filename skipped the extension check entirely.
    if not file.filename:
        raise HTTPException(status_code=400, detail="Missing filename")

    # Path.suffix is empty for names without a dot ("txt", "README"), whereas
    # split('.')[-1] returned the whole name and let a file named "txt" pass.
    ext = Path(file.filename).suffix.lstrip('.').lower()
    if ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(status_code=400, detail="File type not allowed")

    # UploadFile.size is Optional; comparing None with an int would raise
    # TypeError (surfacing as a 500) instead of validating.
    if file.size is not None and file.size > MAX_FILE_SIZE:
        raise HTTPException(status_code=400, detail="File too large")

@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    """Validate an uploaded file and save it under the ``uploads/`` directory.

    Returns:
        The client-supplied filename and the size reported by ``UploadFile``.

    Raises:
        HTTPException: 400 if validation fails or the filename has no usable
            base name.
    """
    validate_file(file)

    # Never use the client's filename as a path: "../../etc/passwd" or an
    # absolute path would otherwise write outside the uploads directory.
    # Keep only the final component and reject the names that remain unsafe.
    safe_name = Path(file.filename or "").name
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid filename")

    upload_dir = Path("uploads")
    upload_dir.mkdir(exist_ok=True)
    file_path = upload_dir / safe_name

    # NOTE(review): this is blocking file I/O inside an async handler, which
    # stalls the event loop for large uploads. Consider a plain ``def``
    # endpoint or offloading the copy to a thread.
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    return {"filename": file.filename, "size": file.size}