| | from __future__ import annotations |
| |
|
| | import hashlib |
| | import os |
| | import re |
| | import time |
| | import typing |
| | from base64 import b64encode |
| | from urllib.request import parse_http_list |
| |
|
| | from ._exceptions import ProtocolError |
| | from ._models import Cookies, Request, Response |
| | from ._utils import to_bytes, to_str, unquote |
| |
|
| | if typing.TYPE_CHECKING: |
| | from hashlib import _Hash |
| |
|
| |
|
# Names exported by `from <this module> import *`; `FunctionAuth` is not listed.
__all__ = ["Auth", "BasicAuth", "DigestAuth", "NetRCAuth"]
| |
|
| |
|
class Auth:
    """
    Base class for all authentication schemes.

    Custom schemes subclass `Auth` and override `.auth_flow()`.

    Schemes that perform I/O (disk access, network calls) or rely on
    synchronization primitives such as locks should instead override
    `.sync_auth_flow()` and/or `.async_auth_flow()`; `Client` drives the
    former and `AsyncClient` the latter.
    """

    # When True, `.read()` / `.aread()` is called on the request before
    # `.auth_flow()` is started.
    requires_request_body = False
    # When True, `.read()` / `.aread()` is called on every response before
    # it is sent back into `.auth_flow()`.
    requires_response_body = False

    def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
        """
        Execute the authentication flow.

        Dispatch a request by yielding it:

        ```
        yield request
        ```

        The client sends the resulting response back into the generator, so
        it can be captured with:

        ```
        response = yield request
        ```

        Returning (or running off the end of the generator) makes the client
        return the last response it received from the server.

        Any number of requests may be dispatched.
        """
        yield request

    def sync_auth_flow(
        self, request: Request
    ) -> typing.Generator[Request, Response, None]:
        """
        Drive `.auth_flow()` for a synchronous client.

        Override this instead of `.auth_flow()` when the scheme performs I/O
        or uses concurrency primitives.
        """
        if self.requires_request_body:
            request.read()

        inner = self.auth_flow(request)
        outgoing = next(inner)

        while True:
            incoming = yield outgoing
            if self.requires_response_body:
                incoming.read()

            try:
                outgoing = inner.send(incoming)
            except StopIteration:
                # The wrapped flow finished: end this generator too.
                return

    async def async_auth_flow(
        self, request: Request
    ) -> typing.AsyncGenerator[Request, Response]:
        """
        Drive `.auth_flow()` for an asynchronous client.

        Override this instead of `.auth_flow()` when the scheme performs I/O
        or uses concurrency primitives.
        """
        if self.requires_request_body:
            await request.aread()

        inner = self.auth_flow(request)
        outgoing = next(inner)

        while True:
            incoming = yield outgoing
            if self.requires_response_body:
                await incoming.aread()

            try:
                outgoing = inner.send(incoming)
            except StopIteration:
                # The wrapped flow finished: end this generator too.
                return
| |
|
| |
|
class FunctionAuth(Auth):
    """
    Adapter that lets a plain callable serve as the 'auth' argument.

    The callable receives the outgoing request and returns the request that
    should be sent in its place.
    """

    def __init__(self, func: typing.Callable[[Request], Request]) -> None:
        self._func = func

    def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
        # Single round trip: send whatever the callable hands back.
        prepared = self._func(request)
        yield prepared
| |
|
| |
|
class BasicAuth(Auth):
    """
    HTTP Basic authentication from a (username, password) pair supplied as
    the 'auth' argument.
    """

    def __init__(self, username: str | bytes, password: str | bytes) -> None:
        # The credentials never change, so the header value is built once here
        # and reused for every request.
        self._auth_header = self._build_auth_header(username, password)

    def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
        request.headers["Authorization"] = self._auth_header
        yield request

    def _build_auth_header(self, username: str | bytes, password: str | bytes) -> str:
        """Return `Basic <base64(username:password)>`."""
        credentials = to_bytes(username) + b":" + to_bytes(password)
        encoded = b64encode(credentials).decode()
        return "Basic " + encoded
| |
|
| |
|
class NetRCAuth(Auth):
    """
    Look up Basic auth credentials for the request's URL host in a 'netrc'
    file.
    """

    def __init__(self, file: str | None = None) -> None:
        # `netrc` is imported here rather than at module level, so the module
        # is only loaded once a NetRCAuth is actually constructed.
        import netrc

        # `file` is handed straight to `netrc.netrc`.
        self._netrc_info = netrc.netrc(file)

    def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
        entry = self._netrc_info.authenticators(request.url.host)
        usable = entry is not None and bool(entry[2])
        if usable:
            # Entry layout: (login, account, password).
            request.headers["Authorization"] = self._build_auth_header(
                username=entry[0], password=entry[2]
            )
        # Without a usable entry (none, or an empty password) the request is
        # sent unchanged.
        yield request

    def _build_auth_header(self, username: str | bytes, password: str | bytes) -> str:
        """Return `Basic <base64(username:password)>`."""
        credentials = to_bytes(username) + b":" + to_bytes(password)
        encoded = b64encode(credentials).decode()
        return "Basic " + encoded
| |
|
| |
|
class DigestAuth(Auth):
    """
    Allows the 'auth' argument to be passed as a (username, password) pair,
    and uses HTTP Digest authentication.

    If a challenge from an earlier exchange is remembered, the request is sent
    with an `Authorization` header built from it; otherwise it is sent as-is.
    When the server replies `401` with a `WWW-Authenticate: Digest ...` header,
    that challenge is parsed and stored, and the request is re-sent once with
    a freshly computed `Authorization` header.
    """

    # Maps the upper-cased `algorithm` challenge parameter to a hash
    # constructor. The "-SESS" variants share the base hash; the extra session
    # step is applied in `_build_auth_header`.
    _ALGORITHM_TO_HASH_FUNCTION: dict[str, typing.Callable[[bytes], _Hash]] = {
        "MD5": hashlib.md5,
        "MD5-SESS": hashlib.md5,
        "SHA": hashlib.sha1,
        "SHA-SESS": hashlib.sha1,
        "SHA-256": hashlib.sha256,
        "SHA-256-SESS": hashlib.sha256,
        "SHA-512": hashlib.sha512,
        "SHA-512-SESS": hashlib.sha512,
    }

    def __init__(self, username: str | bytes, password: str | bytes) -> None:
        self._username = to_bytes(username)
        self._password = to_bytes(password)
        # Most recent challenge; reused to authenticate later requests up front.
        self._last_challenge: _DigestAuthChallenge | None = None
        # "nc" value for the next header built from the current challenge;
        # reset to 1 whenever a new challenge arrives.
        self._nonce_count = 1

    def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
        if self._last_challenge:
            request.headers["Authorization"] = self._build_auth_header(
                request, self._last_challenge
            )

        response = yield request

        if response.status_code != 401 or "www-authenticate" not in response.headers:
            # Not a challenge response, so there is nothing to answer.
            return

        for auth_header in response.headers.get_list("www-authenticate"):
            if auth_header.lower().startswith("digest "):
                break
        else:
            # The server offered no Digest challenge, so this scheme cannot
            # answer it; the 401 response is returned as-is.
            return

        self._last_challenge = self._parse_challenge(request, response, auth_header)
        self._nonce_count = 1

        request.headers["Authorization"] = self._build_auth_header(
            request, self._last_challenge
        )
        # Carry cookies set on the 401 response over to the retried request.
        if response.cookies:
            Cookies(response.cookies).set_cookie_header(request=request)
        yield request

    def _parse_challenge(
        self, request: Request, response: Response, auth_header: str
    ) -> _DigestAuthChallenge:
        """
        Returns a challenge from a Digest WWW-Authenticate header.
        These take the form of:
        `Digest realm="realm@host.com",qop="auth,auth-int",nonce="abc",opaque="xyz"`

        Parameter names are matched case-insensitively and whitespace around
        `=` is ignored (RFC 7235 section 2.1). Empty list elements are skipped
        (RFC 7230 section 7).

        Raises `ProtocolError` if a parameter is not of the form `name=value`,
        or if `realm` or `nonce` is missing.
        """
        message = "Malformed Digest WWW-Authenticate header"
        scheme, _, fields = auth_header.partition(" ")

        # Callers only pass headers already checked to start with "digest ".
        assert scheme.lower() == "digest"

        header_dict: dict[str, str] = {}
        for field in parse_http_list(fields):
            field = field.strip()
            if not field:
                # e.g. `a=1, , b=2` yields an empty element; ignore it.
                continue
            key, sep, value = field.partition("=")
            if not sep:
                raise ProtocolError(message, request=request)
            value = value.strip()
            # Guard the empty value (`realm=`) so `unquote` never sees "".
            header_dict[key.strip().lower()] = unquote(value) if value else value

        try:
            realm = header_dict["realm"].encode()
            nonce = header_dict["nonce"].encode()
        except KeyError as exc:
            raise ProtocolError(message, request=request) from exc

        algorithm = header_dict.get("algorithm", "MD5")
        opaque = header_dict["opaque"].encode() if "opaque" in header_dict else None
        qop = header_dict["qop"].encode() if "qop" in header_dict else None
        return _DigestAuthChallenge(
            realm=realm, nonce=nonce, algorithm=algorithm, opaque=opaque, qop=qop
        )

    def _build_auth_header(
        self, request: Request, challenge: _DigestAuthChallenge
    ) -> str:
        """
        Compute the `Authorization` header value answering `challenge`.

        Increments the nonce count as a side effect. An `algorithm` that is not
        a key of `_ALGORITHM_TO_HASH_FUNCTION` raises `KeyError`.
        """
        hash_func = self._ALGORITHM_TO_HASH_FUNCTION[challenge.algorithm.upper()]

        def digest(data: bytes) -> bytes:
            # Hex digest as bytes, so it can be joined with the other fields.
            return hash_func(data).hexdigest().encode()

        A1 = b":".join((self._username, challenge.realm, self._password))

        path = request.url.raw_path
        A2 = b":".join((request.method.encode(), path))
        # auth-int (which would hash the body into A2) is not supported.
        HA2 = digest(A2)

        nc_value = b"%08x" % self._nonce_count
        cnonce = self._get_client_nonce(self._nonce_count, challenge.nonce)
        self._nonce_count += 1

        HA1 = digest(A1)
        if challenge.algorithm.lower().endswith("-sess"):
            HA1 = digest(b":".join((HA1, challenge.nonce, cnonce)))

        qop = self._resolve_qop(challenge.qop, request=request)
        if qop is None:
            # Server sent no qop: legacy RFC 2069 response form.
            digest_data = [HA1, challenge.nonce, HA2]
        else:
            # qop present: response also covers nc, cnonce and qop.
            digest_data = [HA1, challenge.nonce, nc_value, cnonce, qop, HA2]

        format_args = {
            "username": self._username,
            "realm": challenge.realm,
            "nonce": challenge.nonce,
            "uri": path,
            "response": digest(b":".join(digest_data)),
            "algorithm": challenge.algorithm.encode(),
        }
        if challenge.opaque:
            format_args["opaque"] = challenge.opaque
        if qop:
            format_args["qop"] = b"auth"
            format_args["nc"] = nc_value
            format_args["cnonce"] = cnonce

        return "Digest " + self._get_header_value(format_args)

    def _get_client_nonce(self, nonce_count: int, nonce: bytes) -> bytes:
        """
        Return a 16-hex-character client nonce.

        It is derived from SHA-1 over the nonce count, the server nonce, the
        current time and 8 random bytes.
        """
        s = str(nonce_count).encode()
        s += nonce
        s += time.ctime().encode()
        s += os.urandom(8)

        return hashlib.sha1(s).hexdigest()[:16].encode()

    def _get_header_value(self, header_fields: dict[str, bytes]) -> str:
        """
        Render `header_fields` as a comma-separated `name=value` list.

        Values are double-quoted except for `algorithm`, `qop` and `nc`.
        """
        NON_QUOTED_FIELDS = ("algorithm", "qop", "nc")
        QUOTED_TEMPLATE = '{}="{}"'
        NON_QUOTED_TEMPLATE = "{}={}"

        return ", ".join(
            (
                NON_QUOTED_TEMPLATE if field in NON_QUOTED_FIELDS else QUOTED_TEMPLATE
            ).format(field, to_str(value))
            for field, value in header_fields.items()
        )

    def _resolve_qop(self, qop: bytes | None, request: Request) -> bytes | None:
        """
        Pick the quality-of-protection to use from the server's `qop` list.

        Returns None when the challenge had no qop, and b"auth" when "auth" is
        offered. Raises `NotImplementedError` when only "auth-int" is offered,
        and `ProtocolError` for anything else.
        """
        if qop is None:
            return None
        # Comma-separated list: allow optional whitespace on either side of
        # each comma and drop empty elements.
        qops = [value for value in re.split(rb"\s*,\s*", qop.strip()) if value]
        if b"auth" in qops:
            return b"auth"

        if qops == [b"auth-int"]:
            raise NotImplementedError("Digest auth-int support is not yet implemented")

        message = f'Unexpected qop value "{qop!r}" in digest auth'
        raise ProtocolError(message, request=request)
| |
|
| |
|
class _DigestAuthChallenge(typing.NamedTuple):
    """Parameters parsed from a `WWW-Authenticate: Digest ...` challenge."""

    realm: bytes
    nonce: bytes
    # Algorithm name as sent by the server; kept as `str` because it is
    # upper-cased and used as a lookup key when hashing.
    algorithm: str
    # None when the challenge omitted the parameter.
    opaque: bytes | None
    # Raw comma-separated qop list; None when the challenge omitted it.
    qop: bytes | None
| |
|