bitbake: bitbake: asyncrpc: Catch early SIGTERM

If the SIGTERM signal is sent to an asyncrpc server before it has
installed the SIGTERM handler in the main loop, it may miss the signal
which will can cause the calling process to wait forever on the join().
To resolve this, the calling process should mask of SIGTERM before
forking the server process and the server should unmask the signal only
after the handler is installed. To simplify the usage of the server, an
new helper function called serve_as_process() is added to do this
automatically and correctly.

Thanks: Scott Murray <scott.murray@konsulko.com> for helping debug
(Bitbake rev: ef2865efa98ad20823267364f2159d8d8c931400)

Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
This commit is contained in:
Joshua Watt 2021-07-22 11:19:37 -05:00 committed by Richard Purdie
parent 445c5b9324
commit a83ea01b99
3 changed files with 64 additions and 14 deletions

View File

@ -9,6 +9,7 @@ import os
import signal
import socket
import sys
import multiprocessing
from . import chunkify, DEFAULT_MAX_CHUNK
@ -201,12 +202,14 @@ class AsyncServer(object):
pass
def signal_handler(self):
self.logger.debug("Got exit signal")
self.loop.stop()
def serve_forever(self):
asyncio.set_event_loop(self.loop)
try:
self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
self.run_loop_forever()
self.server.close()
@ -221,3 +224,21 @@ class AsyncServer(object):
if self._cleanup_socket is not None:
self._cleanup_socket()
def serve_as_process(self, *, prefunc=None, args=()):
def run():
if prefunc is not None:
prefunc(self, *args)
self.serve_forever()
# Temporarily block SIGTERM. The server process will inherit this
# block which will ensure it doesn't receive the SIGTERM until the
# handler is ready for it
mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
try:
self.process = multiprocessing.Process(target=run)
self.process.start()
return self.process
finally:
signal.pthread_sigmask(signal.SIG_SETMASK, mask)

View File

@ -390,8 +390,7 @@ class BBCooker:
dbfile = (self.data.getVar("PERSISTENT_DIR") or self.data.getVar("CACHE")) + "/hashserv.db"
self.hashservaddr = "unix://%s/hashserve.sock" % self.data.getVar("TOPDIR")
self.hashserv = hashserv.create_server(self.hashservaddr, dbfile, sync=False)
self.hashserv.process = multiprocessing.Process(target=self.hashserv.serve_forever)
self.hashserv.process.start()
self.hashserv.serve_as_process()
self.data.setVar("BB_HASHSERVE", self.hashservaddr)
self.databuilder.origdata.setVar("BB_HASHSERVE", self.hashservaddr)
self.databuilder.data.setVar("BB_HASHSERVE", self.hashservaddr)

View File

@ -15,28 +15,32 @@ import tempfile
import threading
import unittest
import socket
import time
import signal
def _run_server(server, idx):
# logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w',
# format='%(levelname)s %(filename)s:%(lineno)d %(message)s')
def server_prefunc(server, idx):
logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w',
format='%(levelname)s %(filename)s:%(lineno)d %(message)s')
server.logger.debug("Running server %d" % idx)
sys.stdout = open('bbhashserv-%d.log' % idx, 'w')
sys.stderr = sys.stdout
server.serve_forever()
class HashEquivalenceTestSetup(object):
METHOD = 'TestMethod'
server_index = 0
def start_server(self, dbpath=None, upstream=None, read_only=False):
def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc):
self.server_index += 1
if dbpath is None:
dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
def cleanup_thread(thread):
thread.terminate()
thread.join()
def cleanup_server(server):
if server.process.exitcode is not None:
return
server.process.terminate()
server.process.join()
server = create_server(self.get_server_addr(self.server_index),
dbpath,
@ -44,9 +48,8 @@ class HashEquivalenceTestSetup(object):
read_only=read_only)
server.dbpath = dbpath
server.thread = multiprocessing.Process(target=_run_server, args=(server, self.server_index))
server.thread.start()
self.addCleanup(cleanup_thread, server.thread)
server.serve_as_process(prefunc=prefunc, args=(self.server_index,))
self.addCleanup(cleanup_server, server)
def cleanup_client(client):
client.close()
@ -283,6 +286,33 @@ class HashEquivalenceCommonTests(object):
self.assertClientGetHash(self.client, taskhash2, None)
def test_slow_server_start(self):
"""
Ensures that the server will exit correctly even if it gets a SIGTERM
before entering the main loop
"""
event = multiprocessing.Event()
def prefunc(server, idx):
nonlocal event
server_prefunc(server, idx)
event.wait()
def do_nothing(signum, frame):
pass
old_signal = signal.signal(signal.SIGTERM, do_nothing)
self.addCleanup(signal.signal, signal.SIGTERM, old_signal)
_, server = self.start_server(prefunc=prefunc)
server.process.terminate()
time.sleep(30)
event.set()
server.process.join(300)
self.assertIsNotNone(server.process.exitcode, "Server did not exit in a timely manner!")
class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def get_server_addr(self, server_idx):
return "unix://" + os.path.join(self.temp_dir.name, 'sock%d' % server_idx)