Advanced Features
Middleware
Custom Middleware
from fastapi import Request, Response
import time
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Attach the time spent handling the request to the response.

    Times the downstream call and stores the elapsed seconds, rendered with
    ``str``, in the ``X-Process-Time`` response header.

    Args:
        request: The incoming HTTP request.
        call_next: Awaitable callable that receives the request and returns
            the response.

    Returns:
        The response returned by ``call_next`` with the extra header set.
    """
    # perf_counter() is monotonic and high-resolution. time.time() follows the
    # wall clock, which can be adjusted mid-request and produce negative or
    # inflated durations.
    started = time.perf_counter()
    response = await call_next(request)
    elapsed = time.perf_counter() - started
    response.headers["X-Process-Time"] = str(elapsed)
    return response
CORS Middleware
from fastapi.middleware.cors import CORSMiddleware
# Allowing every origin together with credentials would let any website make
# authenticated cross-origin requests on behalf of a signed-in user, so the
# trusted origins are listed explicitly. The localhost entries below are
# development placeholders: replace them with the real front-end origins.
# NOTE(review): consider narrowing allow_methods / allow_headers as well.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost",
        "http://localhost:8080",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
WebSockets
Basic WebSocket
from fastapi import WebSocket
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Echo each text message back to the client that sent it.

    Accepts the connection, then replies to every received text frame with
    ``"Message text was: <text>"`` until the client disconnects.

    Args:
        websocket: The connection for this client.
    """
    await websocket.accept()
    # iter_text() ends the loop when the client disconnects. A bare
    # ``while True`` around receive_text() would instead let the
    # WebSocketDisconnect exception escape the handler.
    async for data in websocket.iter_text():
        await websocket.send_text(f"Message text was: {data}")
WebSocket Manager
class ConnectionManager:
    """Keep track of accepted WebSocket connections and send text to them."""

    def __init__(self):
        # Built-in generic: avoids depending on a typing.List import. This
        # annotation sits on an attribute inside a function, so it is never
        # evaluated at runtime.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the connection and register it as active."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Forget a connection; does nothing if it is not registered."""
        # Guarded so a second disconnect for the same socket does not raise
        # ValueError from list.remove().
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send ``message`` to one specific connection."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Send ``message`` to every connection registered right now."""
        # Iterate over a snapshot: each await yields to the event loop, and
        # another handler may connect or disconnect meanwhile. Mutating a
        # list while iterating it silently skips entries.
        for connection in list(self.active_connections):
            await connection.send_text(message)
manager = ConnectionManager()


@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: int):
    """Chat endpoint: echo to the sender and relay to every client.

    Each received text is confirmed to the sender ("You wrote: ...") and
    broadcast to all connections ("Client #<id> says: ..."). When the client
    disconnects it is removed from the manager and the remaining clients are
    told that it left.

    Args:
        websocket: The connection for this client.
        client_id: Identifier taken from the URL path.
    """
    # Imported here because the except clause below needs this name and the
    # surrounding imports only bring in WebSocket.
    from fastapi import WebSocketDisconnect

    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.send_personal_message(f"You wrote: {data}", websocket)
            await manager.broadcast(f"Client #{client_id} says: {data}")
    except WebSocketDisconnect:
        manager.disconnect(websocket)
        await manager.broadcast(f"Client #{client_id} left the chat")
Background Tasks
Basic Background Tasks
from fastapi import BackgroundTasks
def write_notification(email: str, message=""):
    """Write a single notification record to ``log.txt``.

    The file is opened in write mode, so it is truncated on every call and
    only ever holds the most recent notification.

    Args:
        email: Recipient address recorded in the entry.
        message: Notification text; defaults to an empty string.
    """
    # Explicit encoding: the platform default may be unable to encode
    # non-ASCII addresses or messages.
    with open("log.txt", mode="w", encoding="utf-8") as email_file:
        email_file.write(f"notification for {email}: {message}")
@app.post("/send-notification/{email}")
async def send_notification(email: str, background_tasks: BackgroundTasks):
    # Queues write_notification for the given address on the request's
    # background tasks, then returns the confirmation payload.
    background_tasks.add_task(
        write_notification,
        email,
        message="some notification",
    )
    confirmation = {"message": "Notification sent in the background"}
    return confirmation
Advanced Background Tasks
import asyncio
from typing import Dict, Any
class TaskManager:
    """Registry of asyncio tasks keyed by a caller-chosen string id."""

    def __init__(self):
        self.tasks: Dict[str, asyncio.Task] = {}

    async def create_task(self, task_id: str, coro, *args, **kwargs):
        """Schedule ``coro(*args, **kwargs)`` as a task and register it.

        Args:
            task_id: Key under which the task is stored. An existing entry
                with the same id is replaced in the registry; the earlier
                task is not cancelled.
            coro: Coroutine function; it is called here with the given
                arguments to obtain the coroutine to run.

        Returns:
            The ``task_id`` that was passed in.
        """
        task = asyncio.create_task(coro(*args, **kwargs))
        self.tasks[task_id] = task
        return task_id

    async def get_task_status(self, task_id: str):
        """Describe a registered task.

        Returns:
            ``None`` for an unknown id, otherwise a dict with the keys
            ``"id"``, ``"done"`` and ``"cancelled"``.
        """
        task = self.tasks.get(task_id)
        if task is None:
            return None
        return {
            "id": task_id,
            "done": task.done(),
            "cancelled": task.cancelled(),
        }

    async def cancel_task(self, task_id: str):
        """Request cancellation of a registered task.

        Returns:
            ``True`` when cancellation was actually requested; ``False`` when
            the id is unknown or the task had already finished.
        """
        task = self.tasks.get(task_id)
        if task is None:
            return False
        # Task.cancel() returns False for a task that is already done, so
        # pass that on instead of always reporting success.
        return task.cancel()


task_manager = TaskManager()
Dependency Injection
Advanced Dependencies
from typing import Optional

from fastapi import Cookie, Depends, HTTPException
class CommonQueryParams:
    """Holder for the shared ``q`` / ``skip`` / ``limit`` parameters.

    Attributes:
        q: Optional search string; ``None`` when not supplied.
        skip: Number of items to skip; defaults to 0.
        limit: Maximum number of items; defaults to 100.
    """

    def __init__(self, q: Optional[str] = None, skip: int = 0, limit: int = 100):
        self.q, self.skip, self.limit = q, skip, limit
@app.get("/items/")
async def read_items(
    commons: CommonQueryParams = Depends(CommonQueryParams),
):
    # The dependency builds a CommonQueryParams instance, which is returned
    # to the client unchanged.
    return commons
# Or as a function
def common_parameters(q: Optional[str] = None, skip: int = 0, limit: int = 100):
    """Return the shared parameters as a dict.

    Returns:
        A dict with the keys ``"q"``, ``"skip"`` and ``"limit"`` holding the
        values that were passed in.
    """
    return dict(q=q, skip=skip, limit=limit)
@app.get("/items/")
async def read_items(
    commons: dict = Depends(common_parameters),
):
    # Returns the dict produced by the common_parameters dependency as-is.
    return commons
Sub-dependencies
def query_extractor(q: Optional[str] = None):
    """Return ``q`` unchanged (``None`` when it was not supplied)."""
    return q
def query_or_cookie_extractor(
    q: Optional[str] = Depends(query_extractor),
    last_query: Optional[str] = Cookie(None),
):
    """Prefer the extracted query; fall back to the ``last_query`` cookie.

    Args:
        q: Value produced by ``query_extractor``. It may be ``None``, hence
            the Optional annotation.
        last_query: Cookie value, ``None`` when the cookie is absent.

    Returns:
        ``q`` when it is truthy, otherwise ``last_query``. An empty ``q``
        therefore also falls back to the cookie.
    """
    return q if q else last_query
@app.get("/items/")
async def read_query(
    query_or_default: str = Depends(query_or_cookie_extractor),
):
    # Wraps whatever the query-or-cookie dependency resolved under the
    # "q_or_cookie" key.
    payload = {"q_or_cookie": query_or_default}
    return payload
Custom Response Classes
Custom JSON Response
from fastapi.responses import JSONResponse
import orjson
class ORJSONResponse(JSONResponse):
    """JSON response whose body is serialized with ``orjson``."""

    media_type = "application/json"

    def render(self, content) -> bytes:
        """Serialize ``content`` to bytes using ``orjson.dumps``."""
        encoded = orjson.dumps(content)
        return encoded
@app.get("/items/", response_class=ORJSONResponse)
async def read_items():
    # Fixed single-item list, rendered through ORJSONResponse.
    items = [{"item_id": "Foo"}]
    return items
Custom HTML Response
from fastapi.responses import HTMLResponse
@app.get("/items/{item_id}", response_class=HTMLResponse)
async def read_item(item_id: str):
    # Renders a small HTML page showing the requested item id in the title
    # and the heading.
    import html

    # item_id comes straight from the URL. Escape it before embedding it in
    # markup so a crafted path cannot inject HTML or script into the page.
    safe_id = html.escape(item_id)
    return f"""
<html>
<head>
<title>Item {safe_id}</title>
</head>
<body>
<h1>Item: {safe_id}</h1>
</body>
</html>
"""
GraphQL Integration
GraphQL with Strawberry
import strawberry
from strawberry.fastapi import GraphQLRouter
# GraphQL object type describing a user.
@strawberry.type
class User:
    # Numeric identifier of the user.
    id: int
    # Display name.
    name: str
    # Contact e-mail address.
    email: str
# Root GraphQL query type.
@strawberry.type
class Query:
    @strawberry.field
    def user(self, id: int) -> User:
        # Resolver returning a fixed sample user that carries the requested
        # id; name and e-mail are hard-coded.
        sample_user = User(id=id, name="John", email="john@example.com")
        return sample_user
# Build the schema from the root Query type, wrap it in a router and mount
# that router on the application under the /graphql prefix.
schema = strawberry.Schema(query=Query)
graphql_app = GraphQLRouter(schema)
app.include_router(graphql_app, prefix="/graphql")
Event Handlers
Startup and Shutdown Events
# NOTE(review): FastAPI documents on_event as deprecated in favour of
# lifespan handlers; consider migrating.
@app.on_event("startup")
async def startup_event():
    """Startup hook: currently only prints a start-up message."""
    # Initialize database connections, cache, etc.
    print("Application is starting up...")
# NOTE(review): FastAPI documents on_event as deprecated in favour of
# lifespan handlers; consider migrating.
@app.on_event("shutdown")
async def shutdown_event():
    """Shutdown hook: currently only prints a shutdown message."""
    # Close database connections, clean up resources
    print("Application is shutting down...")
Custom Exception Handlers
Global Exception Handler
from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
class CustomException(Exception):
    """Application error that carries the name of the offending party.

    Attributes:
        name: The name supplied when the exception was created.
    """

    def __init__(self, name: str):
        # Pass the name on to Exception so str(exc), repr(exc) and exc.args
        # are populated instead of being empty.
        super().__init__(name)
        self.name = name
@app.exception_handler(CustomException)
async def custom_exception_handler(request: Request, exc: CustomException):
    """Turn a ``CustomException`` into a 418 JSON response.

    The body is ``{"message": ...}`` with a message that names ``exc.name``.
    """
    body = {"message": f"Oops! {exc.name} did something wrong"}
    return JSONResponse(status_code=418, content=body)
@app.exception_handler(500)
async def internal_error_handler(request: Request, exc):
    """Return a generic 500 JSON body; details of ``exc`` are not exposed."""
    body = {"message": "Internal server error"}
    return JSONResponse(status_code=500, content=body)
Request Validation
Custom Validators
from pydantic import BaseModel, validator, root_validator
# Request model for a user with per-field checks on name and email and a
# cross-field age rule.
# NOTE(review): validator / root_validator are the pydantic v1-style
# decorators; confirm which pydantic major version the project targets.
class UserModel(BaseModel):
    name: str
    email: str
    age: int

    @validator('name')
    def name_must_contain_space(cls, v):
        """Require a space in the name and return it title-cased."""
        if ' ' not in v:
            raise ValueError('Name must contain a space')
        return v.title()

    @validator('email')
    def email_must_be_valid(cls, v):
        """Reject values without an ``@``; this is only a minimal check."""
        if '@' not in v:
            raise ValueError('Invalid email')
        return v

    @root_validator
    def validate_age_and_name(cls, values):
        """Reject users younger than 13 when a name is present."""
        # .get() because either field may be missing from ``values``.
        age = values.get('age')
        name = values.get('name')
        # Compare against None rather than relying on truthiness: 0 is a
        # valid age and must not bypass the under-13 rule.
        if age is not None and age < 13 and name:
            raise ValueError('Children must have parental consent')
        return values
File Handling
File Upload with Validation
from fastapi import File, UploadFile, HTTPException
import shutil
from pathlib import Path
# Lower-case extensions (without the dot) that uploads may use.
ALLOWED_EXTENSIONS = {'txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif'}
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB


def validate_file(file: UploadFile):
    """Check an upload's extension and reported size.

    Args:
        file: The uploaded file to inspect.

    Raises:
        HTTPException: 400 when the filename is missing or has no extension,
            when the extension is not in ``ALLOWED_EXTENSIONS``, or when the
            reported size exceeds ``MAX_FILE_SIZE``.
    """
    # A missing filename used to skip the type check entirely, and a name
    # without a dot (e.g. "png") was treated as its own extension. Require a
    # real "<name>.<ext>" split instead.
    _, dot, ext = (file.filename or "").rpartition('.')
    if not dot or ext.lower() not in ALLOWED_EXTENSIONS:
        raise HTTPException(status_code=400, detail="File type not allowed")
    # The size may not be known (None); comparing None with an int would
    # raise TypeError, so only compare when a size is reported.
    if file.size is not None and file.size > MAX_FILE_SIZE:
        raise HTTPException(status_code=400, detail="File too large")
@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    # Validates the upload, stores it under ./uploads and reports the stored
    # name together with the size reported for the upload.
    validate_file(file)
    if not file.filename:
        raise HTTPException(status_code=400, detail="Filename is required")
    # The filename is chosen by the client. Keep only its last path component
    # so values like "../../x" or "/etc/x" cannot escape the uploads
    # directory. Backslashes are normalised first so Windows-style paths are
    # stripped on every platform.
    safe_name = Path(file.filename.replace("\\", "/")).name
    if safe_name in {"", ".", ".."}:
        raise HTTPException(status_code=400, detail="Invalid filename")
    file_path = Path("uploads") / safe_name
    file_path.parent.mkdir(exist_ok=True)
    # NOTE(review): copyfileobj is blocking I/O inside an async handler;
    # consider offloading it if uploads can be large.
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    return {"filename": safe_name, "size": file.size}