# Source code for reqivo.transport.connection_pool

"""src/reqivo/transport/connection_pool.py

Connection pooling module.

This module provides connection pool management for efficient reuse of
TCP/TLS connections across multiple HTTP requests.
"""

import asyncio
import threading
import time
from collections import deque
from typing import Deque, Dict, List, Tuple, Union

from reqivo.transport.connection import AsyncConnection, Connection
from reqivo.utils.timing import Timeout

# Public API of this module: the blocking pool and its asyncio counterpart.
__all__ = [
    "ConnectionPool",
    "AsyncConnectionPool",
]


class ConnectionPool:
    """
    Pool of reusable open connections.

    Maintains a cache of idle open connections keyed by
    ``(host, port, use_ssl)`` to enable connection reuse across multiple
    requests. This implementation is thread-safe and supports multiple
    connections per host.

    Slot accounting: each key has a semaphore with ``max_size`` permits.
    ``get_connection`` takes a permit, and ``put_connection`` /
    ``discard_connection`` give it back. Idle connections stored in the pool
    therefore hold no permit, so closing them (``release_connection``,
    ``close_all``) must not release any.
    """

    __slots__ = ("_pool", "_lock", "_semaphores", "max_size", "max_idle_time")

    def __init__(self, max_size: int = 10, max_idle_time: float = 30.0) -> None:
        """
        Initialize connection pool.

        Args:
            max_size: Maximum number of connections per host, both checked
                out at the same time and kept idle in the pool.
            max_idle_time: Max time (seconds) a connection can be idle.
        """
        # Key -> deque of (connection, last-used timestamp). Used as a LIFO
        # stack: the right end holds the most recently returned connection.
        self._pool: Dict[
            Tuple[str, int, bool], Deque[Tuple[Connection, float]]
        ] = {}
        # Key -> semaphore bounding concurrently checked-out connections.
        self._semaphores: Dict[Tuple[str, int, bool], threading.Semaphore] = {}
        self._lock = threading.Lock()
        self.max_size = max_size
        self.max_idle_time = max_idle_time

    def _release_slot(self, key: Tuple[str, int, bool]) -> None:
        """Give back one checkout permit for ``key`` (no-op if it has none)."""
        semaphore = self._semaphores.get(key)
        if semaphore is not None:
            semaphore.release()

    def get_connection(
        self,
        host: str,
        port: int,
        use_ssl: bool,
        timeout: Union[float, Timeout, None] = None,
    ) -> Connection:
        """
        Get an existing connection or create a new one.

        Blocks while ``max_size`` connections for this key are checked out,
        until one of them is returned or discarded.

        Args:
            host: Target host.
            port: Target port.
            use_ssl: Whether the connection uses TLS.
            timeout: Forwarded to ``Connection`` when a new one is created.

        Returns:
            A connection that the caller must hand back through
            ``put_connection`` or ``discard_connection`` to free its slot.

        Raises:
            Whatever creating or opening a ``Connection`` raises; the slot
            acquired by this call is released before re-raising.
        """
        key = (host, port, use_ssl)
        with self._lock:
            semaphore = self._semaphores.get(key)
            if semaphore is None:
                semaphore = threading.Semaphore(self.max_size)
                self._semaphores[key] = semaphore

        # Block until a checkout slot is free (no timeout is applied here).
        semaphore.acquire()
        try:
            with self._lock:
                self._cleanup_expired(key)
                connections = self._pool.get(key)
                while connections:
                    conn, last_used = connections.pop()  # LIFO: newest first
                    # Defensive re-check mirroring _cleanup_expired.
                    if (
                        time.time() - last_used < self.max_idle_time
                        and conn.is_usable()
                    ):
                        return conn
                    conn.close()

            conn = Connection(host, port, use_ssl, timeout=timeout)
            conn.open()
            return conn
        except BaseException:
            # BaseException (not just Exception) so that e.g. KeyboardInterrupt
            # during open() cannot permanently leak the slot.
            semaphore.release()
            raise

    def put_connection(self, conn: Connection) -> None:
        """
        Return a checked-out connection to the pool and free its slot.

        Connections without a socket, or that report themselves unusable, are
        closed instead of pooled. If the idle queue for the key already holds
        ``max_size`` connections, the oldest one is evicted and closed.
        """
        key = (conn.host, conn.port, conn.use_ssl)

        if not conn.sock or not conn.is_usable():
            try:
                conn.close()
            finally:
                # Free the slot even if close() raises.
                self._release_slot(key)
            return

        evicted = None
        with self._lock:
            queue = self._pool.get(key)
            if queue is None:
                queue = deque()
                self._pool[key] = queue
            if len(queue) >= self.max_size:
                evicted, _ = queue.popleft()  # left end holds the oldest entry
            queue.append((conn, time.time()))
        self._release_slot(key)

        # Close outside the lock and after bookkeeping, so a failing close()
        # neither loses the returned connection nor leaks the slot.
        if evicted is not None:
            evicted.close()

    def _cleanup_expired(self, key: Tuple[str, int, bool]) -> None:
        """
        Remove and close idle connections for ``key`` that are expired or
        no longer usable.

        Must be called with _lock held.
        """
        connections = self._pool.get(key)
        if connections is None:
            return

        now = time.time()
        keep: Deque[Tuple[Connection, float]] = deque()
        stale: List[Connection] = []
        for conn, last_used in connections:
            if now - last_used < self.max_idle_time and conn.is_usable():
                keep.append((conn, last_used))
            else:
                stale.append(conn)

        # Update the pool before closing so a failing close() cannot leave
        # already-closed connections behind in it.
        self._pool[key] = keep
        for conn in stale:
            conn.close()

    def discard_connection(self, conn: Connection) -> None:
        """Close a checked-out connection without pooling it; free its slot."""
        try:
            conn.close()
        finally:
            self._release_slot((conn.host, conn.port, conn.use_ssl))

    def release_connection(self, host: str, port: int, use_ssl: bool) -> None:
        """
        Deprecated: close and drop every idle connection for the key.

        Idle connections hold no checkout slot (``put_connection`` already
        freed it), so no semaphore permits are released here. Releasing them
        would let more than ``max_size`` connections be checked out.
        """
        with self._lock:
            connections = self._pool.pop((host, port, use_ssl), None)
        for conn, _ in connections or ():
            conn.close()

    def close_all(self) -> None:
        """
        Close all idle connections in the pool.

        Checked-out connections are not affected; their slots are freed when
        they are returned or discarded. No permits are released for idle
        connections because they hold none.
        """
        with self._lock:
            queues = list(self._pool.values())
            self._pool.clear()
        for queue in queues:
            for conn, _ in queue:
                conn.close()
class AsyncConnectionPool:
    """
    Pool of reusable asynchronous connections.

    Idle connections are cached per ``(host, port, use_ssl)`` key. A per-key
    ``asyncio.Semaphore`` with ``max_size`` permits bounds how many
    connections for one key may be checked out at once. ``get_connection``
    takes a permit and ``put_connection`` / ``discard_connection`` give it
    back, so idle pooled connections hold no permit.

    Every method mutates the pool without awaiting in between, and only
    awaits ``close()`` on connections it has already removed. Interleaved
    coroutines therefore never share a stale view of the idle lists.
    """

    __slots__ = ("_pool", "_semaphores", "max_size", "max_idle_time")

    def __init__(self, max_size: int = 10, max_idle_time: float = 30.0):
        """
        Initialize the pool.

        Args:
            max_size: Maximum number of connections per host, both checked
                out at the same time and kept idle in the pool.
            max_idle_time: Max time (seconds) a connection can be idle.
        """
        # Key -> list of (connection, last-used timestamp); the end of the
        # list holds the most recently returned connection.
        self._pool: Dict[
            Tuple[str, int, bool], List[Tuple[AsyncConnection, float]]
        ] = {}
        # Key -> semaphore bounding concurrently checked-out connections.
        self._semaphores: Dict[Tuple[str, int, bool], asyncio.Semaphore] = {}
        self.max_size = max_size
        self.max_idle_time = max_idle_time

    def _release_slot(self, key: Tuple[str, int, bool]) -> None:
        """Give back one checkout permit for ``key`` (no-op if it has none)."""
        semaphore = self._semaphores.get(key)
        if semaphore is not None:
            semaphore.release()

    async def get_connection(
        self,
        host: str,
        port: int,
        use_ssl: bool,
        timeout: Union[float, Timeout, None] = None,
    ) -> AsyncConnection:
        """
        Returns an existing connection or creates a new one.

        Waits while ``max_size`` connections for this key are checked out.

        Args:
            host: Target host.
            port: Target port.
            use_ssl: Whether the connection uses TLS.
            timeout: Forwarded to ``AsyncConnection`` for new connections.

        Returns:
            A connection that the caller must hand back through
            ``put_connection`` or ``discard_connection`` to free its slot.

        Raises:
            Whatever creating or opening an ``AsyncConnection`` raises,
            including cancellation; the acquired slot is released first.
        """
        key = (host, port, use_ssl)
        semaphore = self._semaphores.get(key)
        if semaphore is None:
            semaphore = asyncio.Semaphore(self.max_size)
            self._semaphores[key] = semaphore

        await semaphore.acquire()
        try:
            await self._cleanup_expired(key)

            # Re-read the idle list on every pass: other coroutines may run
            # while we await close(), and a list reference held across that
            # await could be stale, letting one connection be handed out twice.
            while True:
                idle = self._pool.get(key)
                if not idle:
                    break
                conn, last_used = idle.pop()  # newest first
                if (
                    time.time() - last_used < self.max_idle_time
                    and conn.is_usable()
                ):
                    return conn
                await conn.close()

            conn = AsyncConnection(host, port, use_ssl, timeout=timeout)
            await conn.open()
            return conn
        except BaseException:
            # asyncio.CancelledError derives from BaseException; catching only
            # Exception would leak the slot whenever the task is cancelled.
            semaphore.release()
            raise

    async def put_connection(self, conn: AsyncConnection) -> None:
        """
        Returns a checked-out connection to the pool and frees its slot.

        Unusable connections are closed instead of pooled. If the idle list
        for the key already holds ``max_size`` connections, the oldest one is
        evicted and closed.
        """
        key = (conn.host, conn.port, conn.use_ssl)

        if not conn.is_usable():
            try:
                await conn.close()
            finally:
                # Free the slot even if close() raises or is cancelled.
                self._release_slot(key)
            return

        # All bookkeeping happens before the first await, so the entry cannot
        # be appended to a list that another coroutine has meanwhile replaced.
        idle = self._pool.setdefault(key, [])
        evicted = None
        if len(idle) >= self.max_size:
            evicted, _ = idle.pop(0)  # oldest entry sits at the front
        idle.append((conn, time.time()))
        self._release_slot(key)

        if evicted is not None:
            await evicted.close()

    async def _cleanup_expired(self, key: Tuple[str, int, bool]) -> None:
        """
        Remove and close idle connections for ``key`` that are expired or
        no longer usable.
        """
        connections = self._pool.get(key)
        if connections is None:
            return

        now = time.time()
        keep: List[Tuple[AsyncConnection, float]] = []
        stale: List[AsyncConnection] = []
        for conn, last_used in connections:
            if now - last_used < self.max_idle_time and conn.is_usable():
                keep.append((conn, last_used))
            else:
                stale.append(conn)

        # Swap in the survivors before the first await, so no coroutine can
        # append to (and lose a connection in) the list being replaced.
        self._pool[key] = keep
        for conn in stale:
            await conn.close()

    async def discard_connection(self, conn: AsyncConnection) -> None:
        """Discard async connection and release slot."""
        try:
            await conn.close()
        finally:
            self._release_slot((conn.host, conn.port, conn.use_ssl))

    async def release_connection(self, host: str, port: int, use_ssl: bool) -> None:
        """
        Closes and removes all idle connections for the key.

        Idle connections hold no checkout slot, so no semaphore permits are
        released here. Releasing them would raise the limit above
        ``max_size``.
        """
        connections = self._pool.pop((host, port, use_ssl), [])
        for conn, _ in connections:
            await conn.close()

    async def close_all(self) -> None:
        """
        Closes all idle connections in the pool.

        Checked-out connections are unaffected, and no permits are released
        for idle ones because they hold none.
        """
        # Detach everything first: awaiting while iterating self._pool would
        # raise RuntimeError if another coroutine added a key meanwhile.
        queues = list(self._pool.values())
        self._pool.clear()
        for queue in queues:
            for conn, _ in queue:
                await conn.close()