bitbake: bitbake: asyncrpc: Defer all asyncio to child process

Reworks the async I/O API so that the async loop is only created in the
child process. This requires deferring the creation of the server until
the child process and a queue to transfer the bound address back to the
parent process

(Bitbake rev: 8555869cde39f9e9a9ced5a3e5788209640f6d50)

Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
[small loop -> self.loop fix in serv.py]
Signed-off-by: Scott Murray <scott.murray@konsulko.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
This commit is contained in:
Joshua Watt 2021-08-19 12:46:41 -04:00 committed by Richard Purdie
parent e8182a794d
commit fdc908f321
2 changed files with 72 additions and 46 deletions

View File

@ -131,53 +131,58 @@ class AsyncServerConnection(object):
class AsyncServer(object): class AsyncServer(object):
def __init__(self, logger, loop=None): def __init__(self, logger):
if loop is None:
self.loop = asyncio.new_event_loop()
self.close_loop = True
else:
self.loop = loop
self.close_loop = False
self._cleanup_socket = None self._cleanup_socket = None
self.logger = logger self.logger = logger
self.start = None
self.address = None
@property
def loop(self):
return asyncio.get_event_loop()
def start_tcp_server(self, host, port): def start_tcp_server(self, host, port):
self.server = self.loop.run_until_complete( def start_tcp():
asyncio.start_server(self.handle_client, host, port, loop=self.loop) self.server = self.loop.run_until_complete(
) asyncio.start_server(self.handle_client, host, port)
)
for s in self.server.sockets: for s in self.server.sockets:
self.logger.debug('Listening on %r' % (s.getsockname(),)) self.logger.debug('Listening on %r' % (s.getsockname(),))
# Newer python does this automatically. Do it manually here for # Newer python does this automatically. Do it manually here for
# maximum compatibility # maximum compatibility
s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1) s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
name = self.server.sockets[0].getsockname() name = self.server.sockets[0].getsockname()
if self.server.sockets[0].family == socket.AF_INET6: if self.server.sockets[0].family == socket.AF_INET6:
self.address = "[%s]:%d" % (name[0], name[1]) self.address = "[%s]:%d" % (name[0], name[1])
else: else:
self.address = "%s:%d" % (name[0], name[1]) self.address = "%s:%d" % (name[0], name[1])
self.start = start_tcp
def start_unix_server(self, path): def start_unix_server(self, path):
def cleanup(): def cleanup():
os.unlink(path) os.unlink(path)
cwd = os.getcwd() def start_unix():
try: cwd = os.getcwd()
# Work around path length limits in AF_UNIX try:
os.chdir(os.path.dirname(path)) # Work around path length limits in AF_UNIX
self.server = self.loop.run_until_complete( os.chdir(os.path.dirname(path))
asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop) self.server = self.loop.run_until_complete(
) asyncio.start_unix_server(self.handle_client, os.path.basename(path))
finally: )
os.chdir(cwd) finally:
os.chdir(cwd)
self.logger.debug('Listening on %r' % path) self.logger.debug('Listening on %r' % path)
self._cleanup_socket = cleanup self._cleanup_socket = cleanup
self.address = "unix://%s" % os.path.abspath(path) self.address = "unix://%s" % os.path.abspath(path)
self.start = start_unix
@abc.abstractmethod @abc.abstractmethod
def accept_client(self, reader, writer): def accept_client(self, reader, writer):
@ -205,8 +210,7 @@ class AsyncServer(object):
self.logger.debug("Got exit signal") self.logger.debug("Got exit signal")
self.loop.stop() self.loop.stop()
def serve_forever(self): def _serve_forever(self):
asyncio.set_event_loop(self.loop)
try: try:
self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler) self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM]) signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
@ -217,28 +221,50 @@ class AsyncServer(object):
self.loop.run_until_complete(self.server.wait_closed()) self.loop.run_until_complete(self.server.wait_closed())
self.logger.debug('Server shutting down') self.logger.debug('Server shutting down')
finally: finally:
if self.close_loop:
if sys.version_info >= (3, 6):
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
self.loop.close()
if self._cleanup_socket is not None: if self._cleanup_socket is not None:
self._cleanup_socket() self._cleanup_socket()
def serve_forever(self):
"""
Serve requests in the current process
"""
self.start()
self._serve_forever()
def serve_as_process(self, *, prefunc=None, args=()): def serve_as_process(self, *, prefunc=None, args=()):
def run(): """
Serve requests in a child process
"""
def run(queue):
try:
self.start()
finally:
queue.put(self.address)
queue.close()
if prefunc is not None: if prefunc is not None:
prefunc(self, *args) prefunc(self, *args)
self.serve_forever()
self._serve_forever()
if sys.version_info >= (3, 6):
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
self.loop.close()
queue = multiprocessing.Queue()
# Temporarily block SIGTERM. The server process will inherit this # Temporarily block SIGTERM. The server process will inherit this
# block which will ensure it doesn't receive the SIGTERM until the # block which will ensure it doesn't receive the SIGTERM until the
# handler is ready for it # handler is ready for it
mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM]) mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
try: try:
self.process = multiprocessing.Process(target=run) self.process = multiprocessing.Process(target=run, args=(queue,))
self.process.start() self.process.start()
self.address = queue.get()
queue.close()
queue.join_thread()
return self.process return self.process
finally: finally:
signal.pthread_sigmask(signal.SIG_SETMASK, mask) signal.pthread_sigmask(signal.SIG_SETMASK, mask)

View File

@ -410,11 +410,11 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
class Server(bb.asyncrpc.AsyncServer): class Server(bb.asyncrpc.AsyncServer):
def __init__(self, db, loop=None, upstream=None, read_only=False): def __init__(self, db, upstream=None, read_only=False):
if upstream and read_only: if upstream and read_only:
raise bb.asyncrpc.ServerError("Read-only hashserv cannot pull from an upstream server") raise bb.asyncrpc.ServerError("Read-only hashserv cannot pull from an upstream server")
super().__init__(logger, loop) super().__init__(logger)
self.request_stats = Stats() self.request_stats = Stats()
self.db = db self.db = db