Source code for reqivo.transport.connection

"""src/reqivo/transport/connection.py

TCP and TLS connection management module.

This module provides low-level connection handling with support for
TLS encryption, proxy connections, and proper error handling for
network operations.
"""

import asyncio
import contextlib
import select
import socket
import ssl
from typing import Any, Optional, Union

# pylint: disable=unused-import
# pylint: disable=redefined-builtin
from reqivo.exceptions import (
    ConnectTimeout,
    NetworkError,
    ReadTimeout,
    TimeoutError,
    TlsError,
)
from reqivo.utils.timing import Timeout

# Public API of this module: the blocking socket wrapper (Connection) and its
# asyncio-streams counterpart (AsyncConnection).
__all__ = ["Connection", "AsyncConnection"]


class Connection:
    """
    Manages TCP and TLS connection creation and lifecycle.

    Attributes:
        host: The target hostname or IP address.
        port: The target port number.
        use_ssl: Whether to use TLS encryption (minimum TLS 1.2).
        timeout: Connection timeout configuration, or ``None`` for no timeouts.
        sock: The underlying socket object, or ``None`` when not connected.
    """

    __slots__ = ("host", "port", "use_ssl", "timeout", "sock")

    def __init__(
        self,
        host: str,
        port: int,
        use_ssl: bool = False,
        timeout: "Union[float, Timeout, None]" = None,
    ) -> None:
        """
        Initialize connection parameters; no network activity happens here.

        Args:
            host: Hostname or IP address to connect to.
            port: TCP port to connect to.
            use_ssl: Upgrade the connection to TLS in :meth:`open` when True.
            timeout: ``None`` (no timeouts), a ``Timeout`` instance, or a
                number, which is converted with ``Timeout.from_float``.
        """
        self.host = host
        self.port = port
        self.use_ssl = use_ssl
        if timeout is None:
            self.timeout = None
        elif isinstance(timeout, Timeout):
            self.timeout = timeout
        else:
            self.timeout = Timeout.from_float(timeout)
        self.sock: Optional[socket.socket] = None

    def _connect_timeout(self) -> Optional[float]:
        """Return the connect-phase timeout: ``timeout.connect``, else ``timeout.total``."""
        if not self.timeout:
            return None
        if self.timeout.connect is not None:
            return self.timeout.connect
        return self.timeout.total

    def _read_timeout(self) -> Optional[float]:
        """Return the read-phase timeout: ``timeout.read``, else ``timeout.total``."""
        if not self.timeout:
            return None
        if self.timeout.read is not None:
            return self.timeout.read
        return self.timeout.total

    def _wrap_tls(self, raw_sock: socket.socket) -> ssl.SSLSocket:
        """Run the TLS handshake over *raw_sock* using a default client context (TLS >= 1.2)."""
        context = ssl.create_default_context()
        context.minimum_version = ssl.TLSVersion.TLSv1_2
        return context.wrap_socket(raw_sock, server_hostname=self.host)

    @staticmethod
    def _discard(*socks: socket.socket) -> None:
        """Best-effort close of sockets abandoned by a failed :meth:`open`."""
        for stale in socks:
            with contextlib.suppress(OSError):
                stale.close()

    def open(self) -> socket.socket:
        """
        Open a TCP connection, optionally upgraded to TLS.

        The connect-phase timeout (``timeout.connect``, else ``timeout.total``)
        applies to the TCP connect and to the TLS handshake. Once connected,
        the socket is switched to the read-phase timeout (``timeout.read``,
        else ``timeout.total``). Without a ``timeout`` the socket blocks.

        On failure, every socket created here is closed before the error is
        raised, and ``self.sock`` is left untouched.

        Returns:
            The connected socket, also stored in ``self.sock``.

        Raises:
            ConnectTimeout: The TCP connect or the TLS handshake timed out.
            TlsError: TLS context setup, negotiation or verification failed.
            NetworkError: Any other socket-level error.
        """
        target = f"{self.host}:{self.port}"

        # Phase 1: TCP connect. ``socket.timeout`` is spelled out because the
        # bare name ``TimeoutError`` in this module is the reqivo exception.
        try:
            raw_sock = socket.create_connection(
                (self.host, self.port), timeout=self._connect_timeout()
            )
        except socket.timeout as e:
            raise ConnectTimeout(f"Timeout connecting to {target}") from e
        except OSError as e:
            raise NetworkError(f"Connection error to {target} - {e}") from e

        # Phase 2: optional TLS handshake, then switch to the read timeout.
        sock: socket.socket = raw_sock
        established = False
        try:
            if self.use_ssl:
                sock = self._wrap_tls(raw_sock)
            sock.settimeout(self._read_timeout())
            established = True
        except socket.timeout as e:
            raise ConnectTimeout(f"Timeout during TLS handshake: {e}") from e
        except ssl.SSLError as e:
            raise TlsError(f"TLS Verification Error: {e}") from e
        except OSError as e:
            raise NetworkError(f"Connection error to {target} - {e}") from e
        finally:
            if not established:
                # Covers failures before wrap_socket() owns the descriptor
                # (e.g. context setup) and a failing settimeout(). Closing an
                # already-closed or detached socket does nothing.
                self._discard(sock, raw_sock)

        self.sock = sock
        return sock

    def close(self) -> None:
        """Close the connection if it is open; safe to call repeatedly."""
        sock, self.sock = self.sock, None
        if sock is not None:
            with contextlib.suppress(OSError):
                sock.close()

    def __enter__(self) -> "Connection":
        """Open the connection and return ``self`` for use in a ``with`` block."""
        self.open()
        return self

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        """Close the connection; exceptions from the ``with`` body propagate."""
        self.close()

    @staticmethod
    def _has_pending_input(sock: socket.socket) -> bool:
        """Return True, without blocking, if *sock* is readable (data or EOF) or in error."""
        if hasattr(select, "poll"):
            # poll() is not limited to FD_SETSIZE descriptors like select(),
            # and it reports hang-up/error conditions even when not requested.
            poller = select.poll()
            poller.register(sock, select.POLLIN)
            return bool(poller.poll(0))
        readable, _, _ = select.select([sock], [], [], 0)
        return bool(readable)

    def is_usable(self) -> bool:
        """
        Check whether this idle connection can be reused.

        Any pending input makes the connection unusable: it means the peer
        closed or reset the connection, or sent unsolicited data that would be
        mistaken for the next response. No peek is needed to tell these apart,
        which also avoids ``SSLSocket.recv`` rejecting ``MSG_PEEK``.

        Returns:
            True if a socket is open and has no pending input or error.
        """
        sock = self.sock
        if sock is None:
            return False
        try:
            # Decrypted bytes buffered inside the SSL object are invisible to
            # poll/select but are still unread, stale data.
            if isinstance(sock, ssl.SSLSocket) and sock.pending():
                return False
            return not self._has_pending_input(sock)
        except (OSError, ValueError):
            # ValueError: the descriptor was closed elsewhere (fileno() == -1)
            # or, with the select() fallback, exceeds FD_SETSIZE.
            return False
class AsyncConnection:
    """
    Manages asynchronous TCP and TLS connection creation and lifecycle.

    Usable as an async context manager: ``async with`` opens the connection
    on entry and closes it on exit.

    Attributes:
        host: The target hostname or IP address.
        port: The target port number.
        use_ssl: Whether to use TLS encryption (minimum TLS 1.2).
        timeout: Connection timeout configuration, or ``None`` for no timeout.
        reader: Stream reader of the open connection, or ``None``.
        writer: Stream writer of the open connection, or ``None``.
    """

    __slots__ = ("host", "port", "use_ssl", "timeout", "reader", "writer")

    def __init__(
        self,
        host: str,
        port: int,
        use_ssl: bool = False,
        timeout: "Union[float, Timeout, None]" = None,
    ) -> None:
        """
        Initialize connection parameters; no network activity happens here.

        Args:
            host: Hostname or IP address to connect to.
            port: TCP port to connect to.
            use_ssl: Connect with TLS in :meth:`open` when True.
            timeout: ``None`` (no timeout), a ``Timeout`` instance, or a
                number, which is converted with ``Timeout.from_float``.
        """
        self.host = host
        self.port = port
        self.use_ssl = use_ssl
        if timeout is None:
            self.timeout = None
        elif isinstance(timeout, Timeout):
            self.timeout = timeout
        else:
            self.timeout = Timeout.from_float(timeout)
        self.reader: Optional[asyncio.StreamReader] = None
        self.writer: Optional[asyncio.StreamWriter] = None

    def _connect_timeout(self) -> Optional[float]:
        """Return the connect-phase timeout: ``timeout.connect``, else ``timeout.total``."""
        if not self.timeout:
            return None
        if self.timeout.connect is not None:
            return self.timeout.connect
        return self.timeout.total

    async def open(self) -> None:
        """
        Open the connection, optionally with TLS, and store reader/writer.

        The connect-phase timeout (``timeout.connect``, else ``timeout.total``)
        bounds the TCP connect and TLS handshake together.

        Raises:
            ConnectTimeout: The connection attempt exceeded the timeout.
            TlsError: TLS context setup or the TLS handshake failed.
            NetworkError: Any other failure while connecting.
        """
        try:
            ssl_context: Optional[ssl.SSLContext] = None
            if self.use_ssl:
                ssl_context = ssl.create_default_context()
                ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
            # wait_for() only waits indefinitely for None, so a zero timeout
            # expires immediately instead of silently disabling the limit.
            self.reader, self.writer = await asyncio.wait_for(
                asyncio.open_connection(self.host, self.port, ssl=ssl_context),
                timeout=self._connect_timeout(),
            )
        except asyncio.TimeoutError as e:
            raise ConnectTimeout(
                f"Connection to {self.host}:{self.port} timed out"
            ) from e
        except ssl.SSLError as e:
            raise TlsError(f"TLS connection failed: {e}") from e
        except Exception as e:  # pylint: disable=broad-exception-caught
            raise NetworkError(
                f"Failed to connect to {self.host}:{self.port}: {e}"
            ) from e

    def is_usable(self) -> bool:
        """Return True if streams exist, the writer is not closing and the reader is not at EOF."""
        if not self.writer or not self.reader:
            return False
        return not self.writer.is_closing() and not self.reader.at_eof()

    async def close(self) -> None:
        """
        Close the connection if it is open; safe to call repeatedly.

        The stream references are cleared before awaiting shutdown, so an
        interrupted close never leaves a half-closed writer attached.
        Errors while waiting for the close to finish are suppressed.
        """
        writer = self.writer
        self.reader = None
        self.writer = None
        if writer is not None:
            writer.close()
            with contextlib.suppress(Exception):
                await writer.wait_closed()

    async def __aenter__(self) -> "AsyncConnection":
        """Open the connection and return ``self`` for use in ``async with``."""
        await self.open()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        """Close the connection; exceptions from the ``async with`` body propagate."""
        await self.close()