| from __future__ import annotations |
|
|
| import socket |
| from abc import abstractmethod |
| from collections.abc import Callable, Collection, Mapping |
| from contextlib import AsyncExitStack |
| from io import IOBase |
| from ipaddress import IPv4Address, IPv6Address |
| from socket import AddressFamily |
| from types import TracebackType |
| from typing import Any, TypeVar, Union |
|
|
| from .._core._typedattr import ( |
| TypedAttributeProvider, |
| TypedAttributeSet, |
| typed_attribute, |
| ) |
| from ._streams import ByteStream, Listener, UnreliableObjectStream |
| from ._tasks import TaskGroup |
|
|
# An IP address given either as text or as an ``ipaddress`` object.
IPAddressType = Union[str, IPv4Address, IPv6Address]
# An IP socket address as a (host, port) pair.
IPSockAddrType = tuple[str, int]
# Any socket address: an IP (host, port) pair, or a string (e.g. a UNIX socket path).
SockAddrType = Union[IPSockAddrType, str]
# A UDP datagram: payload bytes paired with a remote (host, port) address.
UDPPacketType = tuple[bytes, IPSockAddrType]
# A UNIX datagram: payload bytes paired with a socket path.
UNIXDatagramPacketType = tuple[bytes, str]
# Generic return-value type variable (not referenced in this module chunk; presumably
# used by code elsewhere).
T_Retval = TypeVar("T_Retval")
|
|
|
|
| class _NullAsyncContextManager: |
| async def __aenter__(self) -> None: |
| pass |
|
|
| async def __aexit__( |
| self, |
| exc_type: type[BaseException] | None, |
| exc_val: BaseException | None, |
| exc_tb: TracebackType | None, |
| ) -> bool | None: |
| return None |
|
|
|
|
class SocketAttribute(TypedAttributeSet):
    """Typed attributes that the socket-based classes of this module can provide."""

    #: the address family of the underlying socket
    family: AddressFamily = typed_attribute()

    #: the local address of the underlying socket (from ``getsockname()``)
    local_address: SockAddrType = typed_attribute()

    #: for IP based (``AF_INET``/``AF_INET6``) sockets, the local port
    local_port: int = typed_attribute()

    #: the underlying standard library socket object
    raw_socket: socket.socket = typed_attribute()

    #: the remote address of the underlying socket, if it has a peer
    remote_address: SockAddrType = typed_attribute()

    #: for IP based (``AF_INET``/``AF_INET6``) sockets with a peer, the remote port
    remote_port: int = typed_attribute()
|
|
|
|
class _SocketProvider(TypedAttributeProvider):
    """
    Mixin deriving :class:`SocketAttribute` values from an underlying socket.

    Subclasses supply the socket through the abstract ``_raw_socket`` property.
    """

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """
        Build the mapping of socket attributes to zero-argument getters.

        ``family``, ``local_address`` and ``raw_socket`` are always present and are
        read from the socket each time they are requested. The remote address is
        looked up once, when this mapping is built, and is only included if the
        socket reports a peer. Port attributes are only included for ``AF_INET`` /
        ``AF_INET6`` sockets.
        """
        from .._core._sockets import convert_ipv6_sockaddr as convert

        attributes: dict[Any, Callable[[], Any]] = {
            SocketAttribute.family: lambda: self._raw_socket.family,
            SocketAttribute.local_address: lambda: convert(
                self._raw_socket.getsockname()
            ),
            SocketAttribute.raw_socket: lambda: self._raw_socket,
        }
        # getpeername() raises OSError when the socket has no peer (e.g. it is not
        # connected); treat that as "no remote address" rather than failing.
        # The address may be a (host, port) tuple or, for non-IP sockets, a str,
        # hence SockAddrType rather than a tuple-only annotation.
        try:
            peername: SockAddrType | None = convert(self._raw_socket.getpeername())
        except OSError:
            peername = None

        # Provide the remote address for connected sockets. It is captured now
        # instead of being re-queried on every access.
        if peername is not None:
            attributes[SocketAttribute.remote_address] = lambda: peername

        # Provide local and remote ports for IP based sockets, whose addresses
        # are (host, port) tuples with the port at index 1.
        if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6):
            attributes[SocketAttribute.local_port] = (
                lambda: self._raw_socket.getsockname()[1]
            )
            if peername is not None:
                remote_port = peername[1]
                attributes[SocketAttribute.remote_port] = lambda: remote_port

        return attributes

    @property
    @abstractmethod
    def _raw_socket(self) -> socket.socket:
        """The underlying socket object; must be provided by subclasses."""
|
|
|
|
class SocketStream(ByteStream, _SocketProvider):
    """
    Byte stream carried over a socket connection.

    Every applicable :class:`~SocketAttribute` is available as an extra attribute.
    """
|
|
|
|
class UNIXSocketStream(SocketStream):
    """Socket stream that can also send and receive file descriptors with data."""

    @abstractmethod
    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send file descriptors along with a message to the peer.

        :param message: a non-empty bytestring
        :param fds: a collection of files (either numeric file descriptors or open file
            or socket objects)
        """
        # NOTE(review): the docstring mentions socket objects, but ``fds`` is typed
        # ``int | IOBase`` and ``socket.socket`` is not an ``IOBase`` subclass —
        # confirm which object types the concrete implementations accept.

    @abstractmethod
    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive file descriptors along with a message from the peer.

        :param msglen: length of the message to expect from the peer
        :param maxfds: maximum number of file descriptors to expect from the peer
        :return: a tuple of (message, file descriptors)
        """
|
|
|
|
class SocketListener(Listener[SocketStream], _SocketProvider):
    """
    Accepts incoming connections on a listening socket.

    Every applicable :class:`~SocketAttribute` is available as an extra attribute.
    """

    @abstractmethod
    async def accept(self) -> SocketStream:
        """Wait for and return the next incoming connection."""

    async def serve(
        self,
        handler: Callable[[SocketStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        """
        Accept connections forever, starting ``handler`` in a new task for each one.

        :param handler: callable invoked with every accepted :class:`SocketStream`
        :param task_group: task group in which handler tasks are started; when
            omitted, one is created here and exited when this method exits
        """
        from .. import create_task_group

        async with AsyncExitStack() as exit_stack:
            group = task_group
            if group is None:
                group = await exit_stack.enter_async_context(create_task_group())

            while True:
                connection = await self.accept()
                group.start_soon(handler, connection)
|
|
|
|
class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider):
    """
    Unconnected UDP socket whose items are ``(payload, (host, port))`` packets.

    Every applicable :class:`~SocketAttribute` is available as an extra attribute.
    """

    async def sendto(self, data: bytes, host: str, port: int) -> None:
        """
        Send ``data`` to ``host`` on ``port``.

        Shorthand for :meth:`~.UnreliableObjectSendStream.send` with the packet
        ``(data, (host, port))``.
        """
        destination = (host, port)
        return await self.send((data, destination))
|
|
|
|
class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider):
    """
    Represents a connected UDP socket, whose items are plain ``bytes`` payloads.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """
|
|
|
|
class UNIXDatagramSocket(
    UnreliableObjectStream[UNIXDatagramPacketType], _SocketProvider
):
    """
    Unconnected Unix datagram socket whose items are ``(payload, path)`` packets.

    Every applicable :class:`~SocketAttribute` is available as an extra attribute.
    """

    async def sendto(self, data: bytes, path: str) -> None:
        """
        Send ``data`` to the address ``path``.

        Shorthand for :meth:`~.UnreliableObjectSendStream.send` with the packet
        ``(data, path)``.
        """
        packet = (data, path)
        return await self.send(packet)
|
|
|
|
class ConnectedUNIXDatagramSocket(UnreliableObjectStream[bytes], _SocketProvider):
    """
    Connected Unix datagram socket whose items are plain ``bytes`` payloads.

    Every applicable :class:`~SocketAttribute` is available as an extra attribute.
    """
|
|