Performance & Optimization
Advanced performance optimization techniques for GraphQL APIs including caching, batching, and query optimization.
DataLoader Pattern
Basic DataLoader Implementation
from promise import Promise
from promise.dataloader import DataLoader
import asyncio
class UserDataLoader(DataLoader):
    """Batch-loads user rows by id through a shared connection pool."""

    def __init__(self, db_pool):
        super().__init__()
        self.db_pool = db_pool

    def batch_load_fn(self, user_ids):
        """Wrap the async batch lookup for ``user_ids`` with ``Promise.resolve``."""
        pending = self.load_users(user_ids)
        return Promise.resolve(pending)

    async def load_users(self, user_ids):
        """Fetch every requested user in one query.

        Returns a list aligned with ``user_ids``; an id with no matching row
        yields ``None`` at its position.
        """
        async with self.db_pool.acquire() as connection:
            rows = await connection.fetch(
                "SELECT * FROM users WHERE id = ANY($1)",
                user_ids
            )
        by_id = {}
        for row in rows:
            by_id[row['id']] = row
        # DataLoader batches must come back in the same order as the keys.
        return [by_id.get(requested_id) for requested_id in user_ids]
class PostDataLoader(DataLoader):
    """Batch-loads the posts written by each requested author id."""

    def __init__(self, db_pool):
        super().__init__()
        self.db_pool = db_pool

    def batch_load_fn(self, user_ids):
        """Wrap the async batch lookup for ``user_ids`` with ``Promise.resolve``."""
        pending = self.load_posts_by_user(user_ids)
        return Promise.resolve(pending)

    async def load_posts_by_user(self, user_ids):
        """Fetch posts for all authors in one query.

        Returns a list aligned with ``user_ids``; each element is that
        author's list of post rows (an empty list when there are none).
        """
        async with self.db_pool.acquire() as connection:
            rows = await connection.fetch(
                "SELECT * FROM posts WHERE author_id = ANY($1)",
                user_ids
            )
        grouped = {}
        for row in rows:
            grouped.setdefault(row['author_id'], []).append(row)
        # DataLoader batches must come back in the same order as the keys.
        return [grouped.get(author_id, []) for author_id in user_ids]
# Context with DataLoaders
class GraphQLContext:
    """Holds one instance of each DataLoader, all built on the same pool."""

    def __init__(self, db_pool):
        # Each loader is constructed with the shared connection pool.
        self.user_loader = UserDataLoader(db_pool)
        self.post_loader = PostDataLoader(db_pool)
        self.comment_loader = CommentDataLoader(db_pool)
# Using DataLoader in resolvers
class PostType(ObjectType):
    """GraphQL post type whose author is resolved through the user DataLoader."""

    # UserType is declared after this class, so the field type is given lazily;
    # referencing the name directly here would raise NameError at import time.
    author = Field(lambda: UserType)

    def resolve_author(self, info):
        """Return the loader's result for this post's ``author_id``."""
        return info.context.user_loader.load(self.author_id)
class UserType(ObjectType):
    """GraphQL user type whose posts are resolved through the post DataLoader."""

    posts = List(PostType)

    def resolve_posts(self, info):
        """Return the loader's result for this user's ``id``."""
        loader = info.context.post_loader
        return loader.load(self.id)
Advanced DataLoader Patterns
import asyncio
from collections import defaultdict
class MultiFieldDataLoader(DataLoader):
    """DataLoader that can batch multiple fields.

    Keys are ``(field, entity_id)`` pairs. Each batch is split per field,
    every field is loaded with a single query, and the results are returned
    in the order the keys were requested.
    """

    def __init__(self, db_pool, fields):
        super().__init__()
        self.db_pool = db_pool
        self.fields = fields

    def batch_load_fn(self, keys):
        """Wrap the async batch lookup for ``keys`` with ``Promise.resolve``."""
        return Promise.resolve(self.load_multiple_fields(keys))

    async def load_multiple_fields(self, keys):
        """Load all requested ``(field, entity_id)`` keys.

        Returns a list aligned with ``keys``; a key with no loaded data
        yields ``None`` at its position.
        """
        # Group keys by field
        field_groups = defaultdict(list)
        for field, entity_id in keys:
            field_groups[field].append(entity_id)

        # Load all fields concurrently. Awaiting the coroutines one after
        # another would run the queries sequentially; gather() schedules
        # them together and returns results in argument order.
        field_names = list(field_groups)
        loaded = await asyncio.gather(
            *(self.load_field_data(name, field_groups[name]) for name in field_names)
        )

        results = {}
        for name, field_data in zip(field_names, loaded):
            for entity_id, data in field_data.items():
                results[(name, entity_id)] = data

        # Return results in original order
        return [results.get(key) for key in keys]

    async def load_field_data(self, field, entity_ids):
        """Dispatch to the loader for ``field``; unknown fields yield ``{}``.

        NOTE(review): ``load_post_comments`` and ``load_user_followers`` are
        not defined in this class as shown; they must be provided (here or in
        a subclass) before those fields are requested.
        """
        if field == 'user_posts':
            return await self.load_user_posts(entity_ids)
        elif field == 'post_comments':
            return await self.load_post_comments(entity_ids)
        elif field == 'user_followers':
            return await self.load_user_followers(entity_ids)
        return {}

    async def load_user_posts(self, user_ids):
        """Return a dict mapping each author id to its list of post rows."""
        async with self.db_pool.acquire() as conn:
            posts = await conn.fetch(
                "SELECT * FROM posts WHERE author_id = ANY($1)",
                user_ids
            )
        posts_by_user = defaultdict(list)
        for post in posts:
            posts_by_user[post['author_id']].append(post)
        return dict(posts_by_user)
# Prime DataLoader cache
async def prime_user_cache(user_loader, users):
    """Pre-populate DataLoader cache.

    For every user, the existing entry for its id is cleared and then
    primed with the given user record.
    """
    for record in users:
        key = record['id']
        # clear() first so prime() replaces any entry already cached for key.
        user_loader.clear(key).prime(key, record)
# Composite DataLoader
class CompositeDataLoader:
    """Loads a user row together with aggregate counts, memoised per instance."""

    def __init__(self, db_pool):
        self.db_pool = db_pool
        self.cache = {}

    async def load_user_with_stats(self, user_id):
        """Return the user row plus post/follower/following counts.

        The result of the first lookup for a ``user_id`` is stored in
        ``self.cache`` and returned as-is on later calls.
        """
        cache_key = f"user_stats:{user_id}"
        try:
            return self.cache[cache_key]
        except KeyError:
            pass

        async with self.db_pool.acquire() as connection:
            # COUNT(DISTINCT ...) keeps the counts correct even though the
            # three LEFT JOINs multiply the joined rows.
            row = await connection.fetchrow("""
                SELECT u.*,
                COUNT(DISTINCT p.id) as post_count,
                COUNT(DISTINCT f.follower_id) as follower_count,
                COUNT(DISTINCT fw.following_id) as following_count
                FROM users u
                LEFT JOIN posts p ON u.id = p.author_id
                LEFT JOIN follows f ON u.id = f.following_id
                LEFT JOIN follows fw ON u.id = fw.follower_id
                WHERE u.id = $1
                GROUP BY u.id
            """, user_id)

        self.cache[cache_key] = row
        return row
Caching Strategies
Query-Level Caching
import hashlib
import json
from functools import wraps
class QueryCache:
    """JSON-serialising query-result cache on top of an async Redis client."""

    # Number of keys requested per SCAN step and deleted per DEL call.
    _SCAN_BATCH = 500

    def __init__(self, redis_client, default_ttl=300):
        self.redis = redis_client
        self.default_ttl = default_ttl

    def generate_cache_key(self, query, variables, user_id=None):
        """Generate deterministic cache key.

        The key is an MD5 digest of the JSON-encoded query, variables and
        user id; ``sort_keys=True`` makes it independent of dict ordering.
        """
        cache_data = {
            'query': query,
            'variables': variables,
            'user_id': user_id
        }
        cache_string = json.dumps(cache_data, sort_keys=True)
        return f"gql:{hashlib.md5(cache_string.encode()).hexdigest()}"

    async def get(self, key):
        """Get cached query result, or ``None`` when nothing is stored."""
        data = await self.redis.get(key)
        return json.loads(data) if data else None

    async def set(self, key, value, ttl=None):
        """Cache query result as JSON; a falsy ``ttl`` uses ``default_ttl``."""
        ttl = ttl or self.default_ttl
        await self.redis.set(key, json.dumps(value), ex=ttl)

    async def invalidate_pattern(self, pattern):
        """Invalidate cache keys matching pattern.

        Keys are discovered with the incremental SCAN iterator instead of
        KEYS, which walks the whole keyspace in one blocking command, and
        are deleted in bounded batches.
        """
        batch = []
        async for key in self.redis.scan_iter(match=pattern, count=self._SCAN_BATCH):
            batch.append(key)
            if len(batch) >= self._SCAN_BATCH:
                await self.redis.delete(*batch)
                batch = []
        if batch:
            await self.redis.delete(*batch)
# Cache decorator for resolvers
def cache_resolver(ttl=300, key_template=None):
    """Decorate an async resolver so its result is cached via ``info.context.cache``.

    Args:
        ttl: Lifetime passed to ``cache.set`` for stored results.
        key_template: Optional ``str.format`` template; it receives
            ``root_id`` plus the resolver's keyword arguments. Without a
            template the key is derived from the resolver name, the root id
            and the keyword arguments.

    A cached value of ``None`` is treated as a miss, so resolvers returning
    ``None`` are executed on every call.
    """
    def decorator(resolver):
        @wraps(resolver)
        async def wrapper(root, info, **kwargs):
            root_id = getattr(root, 'id', 'root')

            # Generate cache key
            if key_template:
                cache_key = key_template.format(root_id=root_id, **kwargs)
            else:
                # Sort the arguments so the key does not depend on the order
                # in which the caller happened to pass them.
                raw_key = f"{resolver.__name__}:{root_id}:{sorted(kwargs.items())}"
                # Assumes only this fallback key is hashed and templated keys
                # are used verbatim — TODO confirm against the intended layout.
                cache_key = f"resolver:{hashlib.md5(raw_key.encode()).hexdigest()}"

            # Try cache first
            cached_result = await info.context.cache.get(cache_key)
            if cached_result is not None:
                return cached_result

            # Execute resolver
            result = await resolver(root, info, **kwargs)

            # Cache result
            await info.context.cache.set(cache_key, result, ttl)
            return result
        return wrapper
    return decorator
# Usage
class PostType(ObjectType):
    """Post type whose comments resolver is cached with a templated key."""

    comments = List(CommentType)

    @cache_resolver(ttl=600, key_template="post_comments:{root_id}")
    async def resolve_comments(self, info):
        """Fetch the comments for this post's ``id``."""
        post_comments = await get_comments_for_post(self.id)
        return post_comments
Field-Level Caching
class FieldCache:
    """Per-field, per-entity JSON cache on top of an async Redis client.

    Keys have the form ``field:<name>:<entity_id>`` with an optional
    ``:<user_id>`` suffix for user-specific values.
    """

    def __init__(self, redis_client):
        self.redis = redis_client
        self.field_configs = {}

    def register_field(self, field_name, ttl=300, invalidate_on=None):
        """Register field for caching with its TTL and invalidation events."""
        self.field_configs[field_name] = {
            'ttl': ttl,
            'invalidate_on': invalidate_on or []
        }

    async def get_field(self, field_name, entity_id, user_id=None):
        """Get cached field value, or ``None`` when nothing is stored."""
        cache_key = self.get_field_key(field_name, entity_id, user_id)
        data = await self.redis.get(cache_key)
        return json.loads(data) if data else None

    async def set_field(self, field_name, entity_id, value, user_id=None):
        """Cache field value; unregistered fields use a TTL of 300 seconds."""
        config = self.field_configs.get(field_name, {'ttl': 300})
        cache_key = self.get_field_key(field_name, entity_id, user_id)
        await self.redis.set(cache_key, json.dumps(value), ex=config['ttl'])

    def get_field_key(self, field_name, entity_id, user_id=None):
        """Build the cache key for a field value.

        ``user_id`` is compared against ``None`` rather than tested for
        truthiness, so a user id of ``0`` still gets its own key instead of
        sharing the user-independent one.
        """
        if user_id is not None:
            return f"field:{field_name}:{entity_id}:{user_id}"
        return f"field:{field_name}:{entity_id}"

    async def invalidate_field(self, field_name, entity_id=None):
        """Invalidate field cache.

        With ``entity_id`` the shared key and every per-user key of that one
        entity are removed; without it, every key of the field is removed.
        The patterns end at a ``:`` separator so that entity ``1`` does not
        also match ``10`` or ``123``, and field ``user_posts`` does not match
        ``user_posts_count``.
        """
        if entity_id is not None:
            stale = [self.get_field_key(field_name, entity_id)]
            pattern = f"field:{field_name}:{entity_id}:*"
        else:
            stale = []
            pattern = f"field:{field_name}:*"

        # SCAN iterates incrementally; KEYS would block on the whole keyspace.
        async for key in self.redis.scan_iter(match=pattern):
            stale.append(key)
        if stale:
            await self.redis.delete(*stale)
# Setup field caching
# Module-level FieldCache instance; redis_client is not defined in this snippet
# and is expected to be provided elsewhere.
field_cache = FieldCache(redis_client)
# Each registration stores the field's TTL and the event names listed in invalidate_on.
field_cache.register_field('user_posts', ttl=600, invalidate_on=['post_created', 'post_deleted'])
field_cache.register_field('post_comments', ttl=300, invalidate_on=['comment_added', 'comment_deleted'])
Persisted Queries
class PersistedQueryManager:
    """Stores and resolves persisted queries by hash.

    Queries are kept in Redis under ``pq:<hash>`` and mirrored in an
    in-process dict so repeated lookups avoid the round trip.
    """

    def __init__(self, redis_client):
        self.redis = redis_client
        self.queries = {}

    async def store_query(self, query_hash, query_string):
        """Store persisted query in Redis and in the local mirror."""
        await self.redis.set(f"pq:{query_hash}", query_string)
        self.queries[query_hash] = query_string

    async def get_query(self, query_hash):
        """Get persisted query text, or ``None`` when the hash is unknown."""
        if query_hash in self.queries:
            return self.queries[query_hash]

        query = await self.redis.get(f"pq:{query_hash}")
        if not query:
            return None
        # A Redis client that does not decode responses returns bytes; the
        # query must be text before it is cached and handed to the executor.
        if isinstance(query, (bytes, bytearray)):
            query = query.decode('utf-8')
        self.queries[query_hash] = query
        return query

    async def handle_persisted_query(self, request_data):
        """Handle persisted query request.

        When ``request_data`` carries a ``queryHash``, the stored query is
        written into ``request_data['query']`` (the dict is modified in
        place). An unknown hash yields an ``errors`` payload instead.
        Requests without ``queryHash`` are returned unchanged.
        """
        if 'queryHash' in request_data:
            query_hash = request_data['queryHash']
            query = await self.get_query(query_hash)
            if not query:
                return {'errors': [{'message': 'PersistedQueryNotFound'}]}
            request_data['query'] = query
        return request_data
# GraphQL endpoint with persisted queries
async def graphql_endpoint(request):
    """Execute a GraphQL request, resolving persisted queries first.

    Returns a 400 JSON response when the body is not a JSON object, when a
    persisted query cannot be resolved, or when no query text is available;
    otherwise returns the execution result.
    """
    try:
        request_data = await request.json()
    except ValueError:
        # json.JSONDecodeError is a ValueError; a malformed body is a client
        # error, not a server failure.
        return JSONResponse(
            {'errors': [{'message': 'Request body must be valid JSON'}]},
            status_code=400
        )
    if not isinstance(request_data, dict):
        return JSONResponse(
            {'errors': [{'message': 'Request body must be a JSON object'}]},
            status_code=400
        )

    # Handle persisted queries
    request_data = await persisted_query_manager.handle_persisted_query(request_data)
    if 'errors' in request_data:
        return JSONResponse(request_data, status_code=400)

    query = request_data.get('query')
    if not query:
        # Neither a query string nor a resolvable queryHash was supplied.
        return JSONResponse(
            {'errors': [{'message': 'Must provide query string'}]},
            status_code=400
        )

    # Execute GraphQL query
    result = await schema.execute_async(
        query,
        variable_values=request_data.get('variables'),
        context_value=create_context(request)
    )
    return JSONResponse(result.to_dict())
Query Optimization
Query Complexity Analysis
from graphql import validate, DocumentNode
from graphql.validation import ValidationRule
class QueryComplexityAnalyzer(ValidationRule):
    """Scores a query's complexity and nesting depth while its fields are visited.

    Each field contributes its registered complexity (default 1), multiplied
    by any ``first``/``last``/``limit`` argument. An ``Exception`` is raised
    as soon as the running complexity or the current depth exceeds its limit.

    NOTE(review): this class does not call ``ValidationRule.__init__`` and is
    used as a pre-built instance; confirm that the installed GraphQL library
    accepts an instance in ``rules=`` before relying on it.
    """

    # Argument names treated as list-size multipliers.
    _PAGINATION_ARGS = ('first', 'last', 'limit')

    def __init__(self, max_complexity=1000, max_depth=15):
        self.max_complexity = max_complexity
        self.max_depth = max_depth
        self.complexity = 0
        self.depth = 0
        self.field_configs = {}

    def register_field_complexity(self, type_name, field_name, complexity):
        """Register complexity score for specific field"""
        self.field_configs[f"{type_name}.{field_name}"] = complexity

    def enter_document(self, node, *args):
        """Reset the counters at the start of every document.

        The analyzer is shared across requests, so without a reset the
        complexity of earlier queries would keep accumulating until every
        query is rejected.
        """
        self.complexity = 0
        self.depth = 0

    def enter_field(self, node, *args):
        """Add this field's cost and enforce the complexity and depth limits."""
        # Calculate field complexity
        # NOTE(review): when the node has no parent_type attribute this falls
        # back to 'Query', so type-specific registrations may never match.
        type_name = node.parent_type.name if hasattr(node, 'parent_type') else 'Query'
        field_name = node.name.value
        field_key = f"{type_name}.{field_name}"
        field_complexity = self.field_configs.get(field_key, 1)

        # Factor in arguments (e.g., list sizes)
        for arg in node.arguments or ():
            if arg.name.value not in self._PAGINATION_ARGS:
                continue
            raw_value = getattr(arg.value, 'value', None)
            try:
                requested = int(raw_value)
            except (TypeError, ValueError):
                # Variables and non-numeric literals carry no usable size.
                continue
            # A zero or negative size must not shrink the running total.
            field_complexity *= max(requested, 1)

        self.complexity += field_complexity
        self.depth += 1

        if self.complexity > self.max_complexity:
            raise Exception(f"Query complexity {self.complexity} exceeds limit {self.max_complexity}")
        if self.depth > self.max_depth:
            raise Exception(f"Query depth {self.depth} exceeds limit {self.max_depth}")

    def leave_field(self, node, *args):
        """Step back up one nesting level."""
        self.depth -= 1
# Setup complexity analysis
# Module-level analyzer instance with the complexity and depth limits given explicitly.
complexity_analyzer = QueryComplexityAnalyzer(max_complexity=1000, max_depth=15)
# Per-field scores are stored under "<type_name>.<field_name>"; unregistered fields default to 1.
complexity_analyzer.register_field_complexity('User', 'posts', 5)
complexity_analyzer.register_field_complexity('Post', 'comments', 3)
complexity_analyzer.register_field_complexity('Query', 'search', 10)
# Validate query complexity
def validate_query_complexity(schema, document):
    """Validate ``document`` against ``schema`` with the complexity analyzer as the only rule.

    Returns whatever ``validate`` returns.
    """
    return validate(schema, document, rules=[complexity_analyzer])
Query Timeout and Rate Limiting
import asyncio
import time
from collections import defaultdict
class QueryRateLimiter:
    """Sliding 60-second limiter on request count and summed query complexity.

    Args:
        max_requests_per_minute: Requests allowed per client per window.
        max_complexity_per_minute: Total complexity allowed per client per window.
        clock: Optional zero-argument callable returning seconds. Defaults to
            ``time.monotonic`` so the window is unaffected by wall-clock
            adjustments.
    """

    # Length of the sliding window, in seconds.
    WINDOW_SECONDS = 60

    def __init__(self, max_requests_per_minute=60, max_complexity_per_minute=10000, clock=None):
        self.max_requests_per_minute = max_requests_per_minute
        self.max_complexity_per_minute = max_complexity_per_minute
        self.client_requests = defaultdict(list)
        self.client_complexity = defaultdict(list)
        self._clock = time.monotonic if clock is None else clock

    def is_allowed(self, client_id, query_complexity=1):
        """Check if request is allowed.

        Returns ``(True, None)`` and records the request when it fits both
        limits, otherwise ``(False, reason)`` without recording it.
        """
        now = self._clock()
        window_start = now - self.WINDOW_SECONDS

        # Clean old entries
        recent_requests = [
            req_time for req_time in self.client_requests[client_id]
            if req_time > window_start
        ]
        recent_complexity = [
            (complexity, req_time)
            for complexity, req_time in self.client_complexity[client_id]
            if req_time > window_start
        ]
        self.client_requests[client_id] = recent_requests
        self.client_complexity[client_id] = recent_complexity

        # Check limits
        if len(recent_requests) >= self.max_requests_per_minute:
            return False, "Request rate limit exceeded"

        total_complexity = sum(complexity for complexity, _ in recent_complexity)
        if total_complexity + query_complexity > self.max_complexity_per_minute:
            return False, "Complexity rate limit exceeded"

        # Record request
        recent_requests.append(now)
        recent_complexity.append((query_complexity, now))
        return True, None
class QueryTimeoutManager:
    """Runs coroutines under a default or per-field time limit."""

    def __init__(self, default_timeout=30):
        self.default_timeout = default_timeout
        self.field_timeouts = {}

    def register_field_timeout(self, field_name, timeout):
        """Register timeout for specific field"""
        self.field_timeouts[field_name] = timeout

    async def execute_with_timeout(self, coro, field_name=None):
        """Execute coroutine with timeout.

        Uses the timeout registered for ``field_name`` when there is one,
        otherwise ``default_timeout``. Raises ``Exception`` when the limit
        is exceeded.
        """
        limit = self.field_timeouts.get(field_name, self.default_timeout)
        try:
            outcome = await asyncio.wait_for(coro, timeout=limit)
        except asyncio.TimeoutError:
            raise Exception(f"Query timeout after {limit} seconds")
        return outcome
# Timeout decorator
def with_timeout(timeout_seconds=30):
    """Decorate an async resolver so it is aborted after ``timeout_seconds``.

    On timeout an ``Exception`` naming the resolver and the limit is raised.
    """
    def decorator(resolver):
        @wraps(resolver)
        async def wrapper(root, info, **kwargs):
            pending = resolver(root, info, **kwargs)
            try:
                outcome = await asyncio.wait_for(pending, timeout=timeout_seconds)
            except asyncio.TimeoutError:
                raise Exception(f"Resolver {resolver.__name__} timed out after {timeout_seconds}s")
            return outcome
        return wrapper
    return decorator
Database Optimization
Query Batching and Optimization
class DatabaseQueryOptimizer:
    """Runs a fixed set of named batch queries and ad-hoc joined batch loads.

    ``prepared_statements`` maps a statement name to its SQL text. The text,
    not a prepared-statement object, is stored because a prepared statement
    belongs to the connection that created it and that connection goes back
    to the pool as soon as preparation finishes.
    """

    def __init__(self, db_pool):
        self.db_pool = db_pool
        self.prepared_statements = {}

    async def prepare_statements(self):
        """Prepare frequently used SQL statements.

        Each statement is prepared once on a pooled connection, so invalid
        SQL fails here rather than on first use, and is then registered
        under its name.
        """
        statements = {
            'get_users_by_ids': """
                SELECT id, name, email, avatar, created_at
                FROM users
                WHERE id = ANY($1)
            """,
            'get_posts_by_user_ids': """
                SELECT id, title, content, author_id, created_at, published_at
                FROM posts
                WHERE author_id = ANY($1) AND status = 'published'
                ORDER BY published_at DESC
            """,
            'get_comments_by_post_ids': """
                SELECT id, content, author_id, post_id, created_at
                FROM comments
                WHERE post_id = ANY($1)
                ORDER BY created_at ASC
            """
        }
        async with self.db_pool.acquire() as conn:
            for name, sql in statements.items():
                await conn.prepare(sql)
                self.prepared_statements[name] = sql

    async def execute_prepared(self, statement_name, *args):
        """Execute prepared statement on a freshly acquired connection.

        Raises:
            ValueError: If ``statement_name`` has not been registered.
        """
        if statement_name not in self.prepared_statements:
            raise ValueError(f"Statement {statement_name} not prepared")
        sql = self.prepared_statements[statement_name]
        async with self.db_pool.acquire() as conn:
            # Run on the connection we actually hold, never on one that has
            # already been returned to the pool.
            return await conn.fetch(sql, *args)

    async def batch_load_with_joins(self, main_table, join_configs, ids):
        """Load data with optimized joins.

        Args:
            main_table: Table whose ``id`` column is matched against ``ids``.
            join_configs: Dicts with ``table``, ``on`` and optional ``fields``
                (default ``['*']``); each one adds a LEFT JOIN.
            ids: Values bound to the single ``$1`` parameter.

        WARNING: table names, field names and ON clauses are interpolated
        into the SQL text. Only ``ids`` is a bound parameter, so the other
        arguments must never come from untrusted input.
        """
        join_clauses = []
        select_fields = [f"{main_table}.*"]
        for join_config in join_configs:
            table = join_config['table']
            on_clause = join_config['on']
            fields = join_config.get('fields', ['*'])
            join_clauses.append(f"LEFT JOIN {table} ON {on_clause}")
            select_fields.extend(f"{table}.{field}" for field in fields)

        query = f"""
            SELECT {', '.join(select_fields)}
            FROM {main_table}
            {' '.join(join_clauses)}
            WHERE {main_table}.id = ANY($1)
        """
        async with self.db_pool.acquire() as conn:
            return await conn.fetch(query, ids)
# Optimized DataLoader with joins
class OptimizedUserLoader(DataLoader):
    """Batch-loads users together with their ``user_stats`` columns."""

    def __init__(self, db_optimizer):
        super().__init__()
        self.db_optimizer = db_optimizer

    def batch_load_fn(self, user_ids):
        """Wrap the async batch lookup for ``user_ids`` with ``Promise.resolve``."""
        pending = self.load_users_with_stats(user_ids)
        return Promise.resolve(pending)

    async def load_users_with_stats(self, user_ids):
        """Load users joined with their stats in a single query.

        Returns a list aligned with ``user_ids``; each element is a dict with
        ``id``, ``name``, ``email`` and a ``stats`` dict, or ``None`` for an
        id without a row. Falsy stat values are reported as 0.
        """
        stat_columns = ['post_count', 'follower_count', 'following_count']
        stats_join = {
            'table': 'user_stats',
            'on': 'users.id = user_stats.user_id',
            'fields': stat_columns
        }
        rows = await self.db_optimizer.batch_load_with_joins(
            'users', [stats_join], user_ids
        )
        by_id = {
            row['id']: {
                'id': row['id'],
                'name': row['name'],
                'email': row['email'],
                'stats': {column: row[column] or 0 for column in stat_columns},
            }
            for row in rows
        }
        # DataLoader batches must come back in the same order as the keys.
        return [by_id.get(requested_id) for requested_id in user_ids]
Monitoring and Performance Metrics
GraphQL Performance Monitoring
import time
from collections import defaultdict
class GraphQLPerformanceMonitor:
    """In-memory store of query and resolver timings with summary statistics."""

    # Most recent samples kept per query hash and per resolver name.
    MAX_ENTRIES = 1000

    def __init__(self):
        self.query_metrics = defaultdict(list)
        self.resolver_metrics = defaultdict(list)
        self.error_counts = defaultdict(int)

    def _append_capped(self, entries, sample):
        """Append ``sample`` and drop the oldest entries beyond ``MAX_ENTRIES``."""
        entries.append(sample)
        if len(entries) > self.MAX_ENTRIES:
            del entries[:-self.MAX_ENTRIES]

    def record_query(self, query_hash, duration, complexity, error_count=0):
        """Record query execution metrics"""
        self._append_capped(self.query_metrics[query_hash], {
            'duration': duration,
            'complexity': complexity,
            'error_count': error_count,
            'timestamp': time.time()
        })

    def record_resolver(self, resolver_name, duration, cache_hit=False):
        """Record resolver execution metrics.

        Capped like query metrics so a long-running process does not
        accumulate samples without bound.
        """
        self._append_capped(self.resolver_metrics[resolver_name], {
            'duration': duration,
            'cache_hit': cache_hit,
            'timestamp': time.time()
        })

    def get_query_stats(self, query_hash, window_seconds=3600):
        """Get query performance statistics.

        Returns ``None`` when the query has no samples inside the window.
        Looking up an unknown hash does not create an entry for it.
        """
        cutoff_time = time.time() - window_seconds
        recent_metrics = [
            m for m in self.query_metrics.get(query_hash, ())
            if m['timestamp'] > cutoff_time
        ]
        if not recent_metrics:
            return None

        durations = sorted(m['duration'] for m in recent_metrics)
        count = len(durations)
        return {
            'count': count,
            'avg_duration': sum(durations) / count,
            'min_duration': durations[0],
            'max_duration': durations[-1],
            # int(count * 0.95) is always a valid index for count >= 1.
            'p95_duration': durations[int(count * 0.95)],
            'error_rate': sum(m['error_count'] for m in recent_metrics) / count
        }

    def get_slowest_queries(self, limit=10):
        """Get slowest queries, ordered by descending average duration."""
        all_queries = [
            {
                'query_hash': query_hash,
                'avg_duration': sum(m['duration'] for m in metrics) / len(metrics),
                'execution_count': len(metrics)
            }
            for query_hash, metrics in self.query_metrics.items()
            if metrics
        ]
        return sorted(all_queries, key=lambda x: x['avg_duration'], reverse=True)[:limit]
# Performance monitoring middleware
class PerformanceMonitoringMiddleware:
    """ASGI middleware that records duration, complexity and errors per query.

    The request body is read up front to identify the query. Every message
    consumed that way is replayed to the wrapped application, which would
    otherwise find the body already drained.
    """

    def __init__(self, app, monitor):
        self.app = app
        self.monitor = monitor

    async def __call__(self, scope, receive, send):
        if scope['type'] != 'http':
            await self.app(scope, receive, send)
            return

        start_time = time.perf_counter()

        # Record each message handed to parse_request so it can be replayed.
        consumed = []

        async def recording_receive():
            message = await receive()
            consumed.append(message)
            return message

        # Parse GraphQL request
        request_data = await self.parse_request(recording_receive)
        if not isinstance(request_data, dict):
            request_data = {}
        query = request_data.get('query', '')
        if not isinstance(query, str):
            query = ''
        query_hash = self.generate_query_hash(query)

        replay = iter(consumed)

        async def replaying_receive():
            # Serve the buffered messages first, then fall through to the
            # real channel for anything the parser did not consume.
            for message in replay:
                return message
            return await receive()

        error_count = 0
        try:
            await self.app(scope, replaying_receive, send)
        except Exception:
            error_count = 1
            raise
        finally:
            duration = time.perf_counter() - start_time
            # Record metrics
            complexity = self.calculate_complexity(query)
            self.monitor.record_query(query_hash, duration, complexity, error_count)

    async def parse_request(self, receive):
        """Read the request body from ``receive`` and decode it as a JSON object.

        Returns ``{}`` when the body is empty, is not valid JSON, or is not
        a JSON object.
        """
        import json

        chunks = []
        while True:
            message = await receive()
            if message.get('type') != 'http.request':
                break
            chunks.append(message.get('body', b''))
            if not message.get('more_body', False):
                break
        try:
            payload = json.loads(b''.join(chunks))
        except ValueError:
            # Covers both JSON syntax errors and undecodable bytes.
            return {}
        return payload if isinstance(payload, dict) else {}

    def generate_query_hash(self, query):
        """Generate hash for query"""
        normalized_query = self.normalize_query(query)
        return hashlib.md5(normalized_query.encode()).hexdigest()

    def normalize_query(self, query):
        """Normalize query for consistent hashing"""
        # Collapse every whitespace run to a single space.
        import re
        return re.sub(r'\s+', ' ', query.strip())

    def calculate_complexity(self, query):
        """Calculate query complexity score"""
        # One point per opening brace, two per opening parenthesis.
        return query.count('{') + query.count('(') * 2
This comprehensive guide covers advanced performance optimization techniques for GraphQL APIs, including DataLoader patterns, caching strategies, query optimization, database optimization, and performance monitoring.