| import re |
| from typing import AnyStr, cast, List, overload, Sequence, Tuple, TYPE_CHECKING, Union |
|
|
| from ._abnf import field_name, field_value |
| from ._util import bytesify, LocalProtocolError, validate |
|
|
| if TYPE_CHECKING: |
| from ._events import Request |
|
|
| try: |
| from typing import Literal |
| except ImportError: |
| from typing_extensions import Literal |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# Byte-level regexes, compiled once at import time and used by
# normalize_and_validate() to check Content-Length values and header
# names/values. The ABNF sources are text patterns, so they are turned into
# ASCII bytes before compiling.
_content_length_re = re.compile(rb"[0-9]+")
_field_name_re = re.compile(bytes(field_name, "ascii"))
_field_value_re = re.compile(bytes(field_value, "ascii"))
|
|
|
|
class Headers(Sequence[Tuple[bytes, bytes]]):
    """
    A list-like interface that allows iterating over headers as byte-pairs
    of (lowercased-name, value).

    Internally we actually store the representation as three-tuples,
    including both the raw original casing, in order to preserve casing
    over-the-wire, and the lowercased name, for case-insensitive comparisons.

    Indexing with an int returns a single (lowercased-name, value) pair;
    indexing with a slice returns a new ``Headers`` over the selected entries
    (so raw casing is preserved in the slice too).

    r = Request(
        method="GET",
        target="/",
        headers=[("Host", "example.org"), ("Connection", "keep-alive")],
        http_version="1.1",
    )
    assert r.headers == [
        (b"host", b"example.org"),
        (b"connection", b"keep-alive")
    ]
    assert r.headers.raw_items() == [
        (b"Host", b"example.org"),
        (b"Connection", b"keep-alive")
    ]
    """

    __slots__ = "_full_items"

    def __init__(self, full_items: List[Tuple[bytes, bytes, bytes]]) -> None:
        # Each entry is (raw_name, lowercased_name, value). The list is stored
        # as given, not copied.
        self._full_items = full_items

    def __bool__(self) -> bool:
        return bool(self._full_items)

    def __eq__(self, other: object) -> bool:
        # List-like equality: compare our (lowercased-name, value) pairs with
        # the items of any iterable. Non-iterables (e.g. None) return
        # NotImplemented so Python falls back to its default comparison
        # instead of this method raising TypeError.
        try:
            other_items = list(other)  # type: ignore
        except TypeError:
            return NotImplemented
        return list(self) == other_items

    def __len__(self) -> int:
        return len(self._full_items)

    def __repr__(self) -> str:
        return "<Headers(%s)>" % repr(list(self))

    @overload
    def __getitem__(self, idx: int) -> Tuple[bytes, bytes]:
        ...

    @overload
    def __getitem__(self, idx: slice) -> "Headers":
        ...

    def __getitem__(
        self, idx: Union[int, slice]
    ) -> Union[Tuple[bytes, bytes], "Headers"]:
        if isinstance(idx, slice):
            # Keep the full three-tuples so the slice still knows raw casing.
            return Headers(self._full_items[idx])
        _, name, value = self._full_items[idx]
        return (name, value)

    def raw_items(self) -> List[Tuple[bytes, bytes]]:
        """Return (original-case name, value) pairs in stored order."""
        return [(raw_name, value) for raw_name, _, value in self._full_items]
|
|
|
|
# Every accepted input shape for a header list: a list of (name, value)
# pairs where each side may independently be bytes or str. These are the
# non-Headers inputs accepted by normalize_and_validate().
HeaderTypes = Union[
    List[Tuple[bytes, bytes]],  # (bytes name, bytes value)
    List[Tuple[bytes, str]],  # (bytes name, str value)
    List[Tuple[str, bytes]],  # (str name, bytes value)
    List[Tuple[str, str]],  # (str name, str value)
]
|
|
|
|
@overload
def normalize_and_validate(headers: Headers, _parsed: Literal[True]) -> Headers:
    ...


@overload
def normalize_and_validate(headers: HeaderTypes, _parsed: Literal[False]) -> Headers:
    ...


@overload
def normalize_and_validate(
    headers: Union[Headers, HeaderTypes], _parsed: bool = False
) -> Headers:
    ...


def normalize_and_validate(
    headers: Union[Headers, HeaderTypes], _parsed: bool = False
) -> Headers:
    """Normalize a header list into a validated ``Headers`` object.

    Unless ``_parsed`` is true, each (name, value) pair is passed through
    ``bytesify`` and checked against the header-name / header-value regexes.
    When ``_parsed`` is true the caller vouches that both are already bytes
    and already validated (presumably parser output), so those steps are
    skipped. Names are lowercased for comparison while the original casing
    is kept as the raw name.

    Extra per-header normalization:

    * ``content-length``: a comma-separated list of identical values is
      collapsed to one value, which must match ``[0-9]+``. Repeated
      Content-Length headers carrying the same value are kept only once;
      differing values are rejected.
    * ``transfer-encoding``: the value is lowercased and must be exactly
      ``chunked``; a second Transfer-Encoding header is rejected.

    If ``headers`` is already a ``Headers`` instance, its raw (original-case)
    items are used, so re-normalizing an existing object preserves casing.

    :param headers: a ``Headers`` object or a list of (name, value) pairs of
        bytes and/or str.
    :param _parsed: skip bytes conversion and regex validation.
    :returns: a new ``Headers`` instance.
    :raises LocalProtocolError: on conflicting Content-Length values, on
        multiple or non-chunked Transfer-Encoding headers (both with
        ``error_status_hint=501``). ``validate`` is also called on mismatched
        names/values/lengths; what it raises is defined outside this block.
    """
    new_headers: List[Tuple[bytes, bytes, bytes]] = []
    seen_content_length = None
    saw_transfer_encoding = False
    # Iterating a Headers instance yields (lowercased-name, value) pairs,
    # which would discard the original casing; walk its raw items instead.
    items = headers.raw_items() if isinstance(headers, Headers) else headers
    for name, value in items:
        # Pre-parsed input is already bytes and already matched these
        # regexes, so the conversion and checks can be skipped for it.
        if not _parsed:
            name = bytesify(name)
            value = bytesify(value)
            validate(_field_name_re, name, "Illegal header name {!r}", name)
            validate(_field_value_re, value, "Illegal header value {!r}", value)
        # Type narrowing for checkers; holds for both branches above.
        assert isinstance(name, bytes)
        assert isinstance(value, bytes)

        raw_name = name
        name = name.lower()
        if name == b"content-length":
            # "Content-Length: 5, 5" is tolerated as a single 5; any mix of
            # distinct values is a conflict.
            lengths = {length.strip() for length in value.split(b",")}
            if len(lengths) != 1:
                raise LocalProtocolError("conflicting Content-Length headers")
            value = lengths.pop()
            validate(_content_length_re, value, "bad Content-Length")
            if seen_content_length is None:
                seen_content_length = value
                new_headers.append((raw_name, name, value))
            elif seen_content_length != value:
                raise LocalProtocolError("conflicting Content-Length headers")
            # Otherwise this duplicates an already-recorded length: drop it.
        elif name == b"transfer-encoding":
            # Only one Transfer-Encoding header is accepted, and its value
            # must be exactly "chunked" (case-insensitively).
            if saw_transfer_encoding:
                raise LocalProtocolError(
                    "multiple Transfer-Encoding headers", error_status_hint=501
                )
            value = value.lower()
            if value != b"chunked":
                raise LocalProtocolError(
                    "Only Transfer-Encoding: chunked is supported",
                    error_status_hint=501,
                )
            saw_transfer_encoding = True
            new_headers.append((raw_name, name, value))
        else:
            new_headers.append((raw_name, name, value))
    return Headers(new_headers)
|
|
|
|
def get_comma_header(headers: Headers, name: bytes) -> List[bytes]:
    """Collect the comma-separated values of header ``name``.

    Every stored header whose lowercased name equals ``name`` contributes its
    value, lowercased and split on ``b","``. Each piece is whitespace-stripped
    and empty pieces are dropped. Results keep header order, then in-value
    order.

    ``name`` is compared against the lowercased stored names, so it must
    itself be lowercase bytes to match anything.

    Caution: the split is a plain byte split with no handling of quoted
    strings, and values are lowercased. Use this only for headers whose
    values are simple case-insensitive comma-separated tokens.
    """
    pieces = (
        piece.strip()
        for _, lowered_name, stored_value in headers._full_items
        if lowered_name == name
        for piece in stored_value.lower().split(b",")
    )
    return [piece for piece in pieces if piece]
|
|
|
|
def set_comma_header(headers: Headers, name: bytes, new_values: List[bytes]) -> Headers:
    """Return a copy of ``headers`` with header ``name`` replaced by ``new_values``.

    Every existing header whose lowercased name equals ``name`` is removed.
    All other headers keep their order, raw casing, and stored values. One
    header per element of ``new_values`` is then appended at the end, named
    ``name.title()``. The combined list is re-run through
    ``normalize_and_validate`` (not as pre-parsed), so the new values are
    validated too. The input ``headers`` is not modified.

    ``name`` should be lowercase bytes, since it is compared against the
    lowercased stored names.
    """
    rebuilt: List[Tuple[bytes, bytes]] = [
        (raw_name, stored_value)
        for raw_name, lowered_name, stored_value in headers._full_items
        if lowered_name != name
    ]
    display_name = name.title()
    rebuilt.extend((display_name, value) for value in new_values)
    return normalize_and_validate(rebuilt)
|
|
|
|
def has_expect_100_continue(request: "Request") -> bool:
    """Report whether ``request`` carries an ``Expect: 100-continue`` header.

    Requests whose ``http_version`` compares below ``b"1.1"`` (a plain bytes
    comparison) always yield False, and their headers are not inspected.
    Otherwise the ``expect`` header values are gathered with
    ``get_comma_header`` (lowercased and comma-split) and searched for
    ``b"100-continue"``.
    """
    # `and` short-circuits: headers are scanned only for HTTP/1.1+ requests.
    return request.http_version >= b"1.1" and b"100-continue" in get_comma_header(
        request.headers, b"expect"
    )
|
|