| |
|
|
| """asyncio library query support""" |
|
|
| import asyncio |
| import socket |
| import sys |
|
|
| import dns._asyncbackend |
| import dns._features |
| import dns.exception |
| import dns.inet |
|
|
# True when running on Windows; make_socket() consults this to decide whether
# an unbound datagram socket needs an explicit wildcard source address.
_is_win32 = sys.platform == "win32"
|
|
|
|
def _get_running_loop():
    """Return the current event loop.

    Uses ``asyncio.get_running_loop`` when the attribute exists and falls
    back to ``asyncio.get_event_loop`` when it does not.
    """
    try:
        loop_getter = asyncio.get_running_loop
    except AttributeError:
        loop_getter = asyncio.get_event_loop
    return loop_getter()
|
|
|
|
class _DatagramProtocol:
    """Datagram protocol that forwards events to one waiting future.

    ``transport`` is filled in by ``connection_made()``.  ``recvfrom`` is
    either ``None`` or the future of the caller currently waiting for a
    datagram; each callback resolves that future at most once and is a
    no-op when nobody is waiting.
    """

    def __init__(self):
        self.transport = None
        self.recvfrom = None

    def _pending(self):
        """Return the waiting future if it can still be resolved, else None."""
        waiter = self.recvfrom
        if waiter and not waiter.done():
            return waiter
        return None

    def connection_made(self, transport):
        """Remember the transport handed over once the endpoint exists."""
        self.transport = transport

    def datagram_received(self, data, addr):
        """Deliver ``(data, addr)`` to the waiting receiver, if any."""
        waiter = self._pending()
        if waiter is not None:
            waiter.set_result((data, addr))

    def error_received(self, exc):
        """Fail the waiting receiver with *exc*, if any."""
        waiter = self._pending()
        if waiter is not None:
            waiter.set_exception(exc)

    def connection_lost(self, exc):
        """Fail the waiting receiver when the connection goes away.

        A ``None`` *exc* is reported to the receiver as ``EOFError``;
        otherwise *exc* itself is delivered.
        """
        waiter = self._pending()
        if waiter is not None:
            waiter.set_exception(EOFError() if exc is None else exc)

    def close(self):
        """Close the underlying transport, if one was ever attached."""
        # connection_made() may not have run yet, in which case transport is
        # still None and there is nothing to close.
        if self.transport is not None:
            self.transport.close()
|
|
|
|
async def _maybe_wait_for(awaitable, timeout):
    """Await *awaitable*, bounded by *timeout* seconds when one is given.

    With ``timeout`` of ``None`` the awaitable is awaited directly.
    Otherwise it goes through ``asyncio.wait_for`` and an
    ``asyncio.TimeoutError`` is re-raised as ``dns.exception.Timeout``
    carrying the same timeout value.
    """
    if timeout is None:
        return await awaitable
    try:
        outcome = await asyncio.wait_for(awaitable, timeout)
    except asyncio.TimeoutError:
        raise dns.exception.Timeout(timeout=timeout)
    return outcome
|
|
|
|
class DatagramSocket(dns._asyncbackend.DatagramSocket):
    """Datagram socket built on an asyncio transport/protocol pair."""

    def __init__(self, family, transport, protocol):
        """Store the transport and protocol created for this endpoint."""
        super().__init__(family)
        self.transport = transport
        self.protocol = protocol

    async def sendto(self, what, destination, timeout):
        """Hand *what* to the transport for *destination*; return its length.

        *timeout* is unused here: the transport call is not awaited.
        """
        self.transport.sendto(what, destination)
        return len(what)

    async def recvfrom(self, size, timeout):
        """Wait for the next datagram and return what the protocol delivers.

        *size* is unused.  Raises ``dns.exception.Timeout`` (via
        ``_maybe_wait_for``) when *timeout* expires, or whatever exception
        the protocol sets on the future.  Only one receive may be pending
        at a time.
        """
        # Check for an already-pending receive *before* entering the try
        # block: if the check lived inside it, a failure would run the
        # cleanup below and discard the other caller's future.
        assert self.protocol.recvfrom is None
        done = _get_running_loop().create_future()
        self.protocol.recvfrom = done
        try:
            await _maybe_wait_for(done, timeout)
            return done.result()
        finally:
            # Always unregister our own future, including on timeout or
            # cancellation, so the next receive can register.
            self.protocol.recvfrom = None

    async def close(self):
        """Close the socket by closing the protocol."""
        self.protocol.close()

    async def getpeername(self):
        """Return the transport's ``peername`` extra info."""
        return self.transport.get_extra_info("peername")

    async def getsockname(self):
        """Return the transport's ``sockname`` extra info."""
        return self.transport.get_extra_info("sockname")

    async def getpeercert(self, timeout):
        """Not supported for datagram sockets; always raises."""
        raise NotImplementedError
|
|
|
|
class StreamSocket(dns._asyncbackend.StreamSocket):
    """Stream socket wrapping a reader/writer pair."""

    def __init__(self, af, reader, writer):
        """Record the address family and the two stream halves."""
        self.family = af
        self.reader = reader
        self.writer = writer

    async def sendall(self, what, timeout):
        """Write *what* and wait, up to *timeout*, for the writer to drain."""
        self.writer.write(what)
        draining = self.writer.drain()
        return await _maybe_wait_for(draining, timeout)

    async def recv(self, size, timeout):
        """Read up to *size* bytes, waiting at most *timeout*."""
        reading = self.reader.read(size)
        return await _maybe_wait_for(reading, timeout)

    async def close(self):
        """Close the writer."""
        self.writer.close()

    async def getpeername(self):
        """Return the writer's ``peername`` extra info."""
        return self.writer.get_extra_info("peername")

    async def getsockname(self):
        """Return the writer's ``sockname`` extra info."""
        return self.writer.get_extra_info("sockname")

    async def getpeercert(self, timeout):
        """Return the writer's ``peercert`` extra info; *timeout* is unused."""
        return self.writer.get_extra_info("peercert")
|
|
|
|
if dns._features.have("doh"):
    import anyio
    import httpcore
    import httpcore._backends.anyio
    import httpx

    _CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
    _CoreAnyIOStream = httpcore._backends.anyio.AnyIOStream

    from dns.query import _compute_times, _expiration_for_this_attempt, _remaining

    class _NetworkBackend(_CoreAsyncNetworkBackend):
        """httpcore network backend that resolves hosts with a dns resolver."""

        def __init__(self, resolver, local_port, bootstrap_address, family):
            """Store the connection settings; a non-zero local port is refused."""
            super().__init__()
            self._local_port = local_port
            self._resolver = resolver
            self._bootstrap_address = bootstrap_address
            self._family = family
            if local_port != 0:
                raise NotImplementedError(
                    "the asyncio transport for HTTPX cannot set the local port"
                )

        async def connect_tcp(
            self, host, port, timeout, local_address, socket_options=None
        ):
            """Connect to *host*:*port* and return the wrapped stream.

            Candidate addresses are, in order of preference: *host* itself
            when it is already an address, the bootstrap address when one
            was configured, or the addresses the resolver returns for
            *host*.  Each candidate is tried in turn; if none connects,
            ``httpcore.ConnectError`` is raised.  *socket_options* is
            accepted but not used.
            """
            _, expiration = _compute_times(timeout)
            if dns.inet.is_address(host):
                candidates = [host]
            elif self._bootstrap_address is not None:
                candidates = [self._bootstrap_address]
            else:
                lifetime = _remaining(expiration)
                if local_address:
                    # Resolve in the same family as the local address.
                    lookup_family = dns.inet.af_for_address(local_address)
                else:
                    lookup_family = self._family
                answers = await self._resolver.resolve_name(
                    host, family=lookup_family, lifetime=lifetime
                )
                candidates = answers.addresses()
            for candidate in candidates:
                try:
                    deadline = _expiration_for_this_attempt(2.0, expiration)
                    with anyio.fail_after(_remaining(deadline)):
                        stream = await anyio.connect_tcp(
                            remote_host=candidate,
                            remote_port=port,
                            local_host=local_address,
                        )
                    return _CoreAnyIOStream(stream)
                except Exception:
                    # Any failure for this candidate just moves on to the
                    # next one; the error itself is discarded.
                    continue
            raise httpcore.ConnectError

        async def connect_unix_socket(
            self, path, timeout, socket_options=None
        ):
            """Unix sockets are not supported by this backend."""
            raise NotImplementedError

        async def sleep(self, seconds):
            """Sleep for *seconds* using anyio."""
            await anyio.sleep(seconds)

    class _HTTPTransport(httpx.AsyncHTTPTransport):
        """httpx transport whose pool connects through ``_NetworkBackend``."""

        def __init__(
            self,
            *args,
            local_port=0,
            bootstrap_address=None,
            resolver=None,
            family=socket.AF_UNSPEC,
            **kwargs,
        ):
            """Build the transport and swap in the dns-aware network backend.

            When *resolver* is ``None`` a default ``dns.asyncresolver.Resolver``
            is created.  Remaining positional and keyword arguments go to
            ``httpx.AsyncHTTPTransport``.
            """
            if resolver is None:
                # Imported lazily inside the method -- presumably to avoid a
                # circular import at module load; confirm before moving it.
                import dns.asyncresolver

                resolver = dns.asyncresolver.Resolver()
            super().__init__(*args, **kwargs)
            backend = _NetworkBackend(
                resolver, local_port, bootstrap_address, family
            )
            self._pool._network_backend = backend

else:
    _HTTPTransport = dns._asyncbackend.NullTransport
|
|
|
|
class Backend(dns._asyncbackend.Backend):
    """Async backend implemented with the asyncio library."""

    def name(self):
        """Return the backend's name."""
        return "asyncio"

    async def make_socket(
        self,
        af,
        socktype,
        proto=0,
        source=None,
        destination=None,
        timeout=None,
        ssl_context=None,
        server_hostname=None,
    ):
        """Create a ``DatagramSocket`` or ``StreamSocket`` for *socktype*.

        Datagram sockets are created through the running loop's
        ``create_datagram_endpoint``; *timeout*, *ssl_context* and
        *server_hostname* are not used on that path.  Stream sockets are
        opened with ``asyncio.open_connection`` bounded by *timeout* and
        require *destination*.

        Raises ``ValueError`` for a stream socket without a destination
        and ``NotImplementedError`` for any other socket type.
        """
        loop = _get_running_loop()
        if socktype == socket.SOCK_DGRAM:
            if _is_win32 and source is None:
                # On Windows an unspecified source is replaced by the
                # wildcard address with port 0 -- presumably because the
                # endpoint must be bound explicitly there; TODO confirm.
                source = (dns.inet.any_for_af(af), 0)
            transport, protocol = await loop.create_datagram_endpoint(
                _DatagramProtocol,
                source,
                family=af,
                proto=proto,
                remote_addr=destination,
            )
            return DatagramSocket(af, transport, protocol)
        if socktype != socket.SOCK_STREAM:
            raise NotImplementedError(
                "unsupported socket " + f"type {socktype}"
            )
        if destination is None:
            raise ValueError("destination required for stream sockets")
        remote_host, remote_port = destination[0], destination[1]
        connecting = asyncio.open_connection(
            remote_host,
            remote_port,
            ssl=ssl_context,
            family=af,
            proto=proto,
            local_addr=source,
            server_hostname=server_hostname,
        )
        reader, writer = await _maybe_wait_for(connecting, timeout)
        return StreamSocket(af, reader, writer)

    async def sleep(self, interval):
        """Sleep for *interval* seconds."""
        await asyncio.sleep(interval)

    def datagram_connection_required(self):
        """Report that datagram sockets need not be connected."""
        return False

    def get_transport_class(self):
        """Return the HTTP transport class selected at import time."""
        return _HTTPTransport

    async def wait_for(self, awaitable, timeout):
        """Await *awaitable* with an optional *timeout*."""
        return await _maybe_wait_for(awaitable, timeout)
|
|